From 5483d8fec0814e9cc9f5e6fdbb69810f74c76ac9 Mon Sep 17 00:00:00 2001
From: natanielruiz <nruiz9@gatech.edu>
Date: 星期一, 30 十月 2017 07:09:27 +0800
Subject: [PATCH] next

---
 code/test_preangles.py |   18 ++++++++++++++++--
 1 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/code/test_preangles.py b/code/test_preangles.py
index 3d70bb0..9cdc8e3 100644
--- a/code/test_preangles.py
+++ b/code/test_preangles.py
@@ -36,6 +36,13 @@
 
     return args
 
+def load_filtered_state_dict(model, snapshot):
+    # By user apaszke from discuss.pytorch.org
+    model_dict = model.state_dict()
+    snapshot = {k: v for k, v in snapshot.items() if k in model_dict}
+    model_dict.update(snapshot)
+    model.load_state_dict(model_dict)
+
 if __name__ == '__main__':
     args = parse_args()
 
@@ -50,6 +57,7 @@
     # Load snapshot
     saved_state_dict = torch.load(snapshot_path)
     model.load_state_dict(saved_state_dict)
+    # load_filtered_state_dict(model, saved_state_dict)
 
     print 'Loading data.'
 
@@ -63,6 +71,8 @@
         pose_dataset = datasets.Pose_300W_LP_random_ds(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'AFLW2000':
         pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list, transformations)
+    elif args.dataset == 'AFLW2000_ds':
+        pose_dataset = datasets.AFLW2000_ds(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'BIWI':
         pose_dataset = datasets.BIWI(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'AFLW':
@@ -86,11 +96,16 @@
     model.eval()  # Change model to 'eval' mode (BN uses moving mean/var).
     total = 0
 
+    idx_tensor = [idx for idx in xrange(66)]
+    idx_tensor = torch.FloatTensor(idx_tensor).cuda(gpu)
+
     yaw_error = .0
     pitch_error = .0
     roll_error = .0
 
     l1loss = torch.nn.L1Loss(size_average=False)
+
+
 
     for i, (images, labels, cont_labels, name) in enumerate(test_loader):
         images = Variable(images).cuda(gpu)
@@ -121,8 +136,7 @@
         pitch_error += torch.sum(torch.abs(pitch_predicted - label_pitch))
         roll_error += torch.sum(torch.abs(roll_predicted - label_roll))
 
-        # Save images with pose cube.
-        # TODO: fix for larger batch size
+        # Save first image in batch with pose cube or axis.
         if args.save_viz:
             name = name[0]
             if args.dataset == 'BIWI':

--
Gitblit v1.8.0