From af51d0ecb51ad4d6c8ed086855bd3c411ebc4aa0 Mon Sep 17 00:00:00 2001
From: natanielruiz <nruiz9@gatech.edu>
Date: 星期一, 30 十月 2017 06:29:51 +0800
Subject: [PATCH] Fixed stuff

---
 code/test_alexnet.py |   25 +++++++++++++++++--------
 1 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/code/test_alexnet.py b/code/test_alexnet.py
index d9cc0a3..7a3989a 100644
--- a/code/test_alexnet.py
+++ b/code/test_alexnet.py
@@ -39,6 +39,13 @@
 
     return args
 
+def load_filtered_state_dict(model, snapshot):
+    # By user apaszke from discuss.pytorch.org
+    model_dict = model.state_dict()
+    snapshot = {k: v for k, v in snapshot.items() if k in model_dict}
+    model_dict.update(snapshot)
+    model.load_state_dict(model_dict)
+
 if __name__ == '__main__':
     args = parse_args()
 
@@ -51,7 +58,7 @@
     print 'Loading snapshot.'
     # Load snapshot
     saved_state_dict = torch.load(snapshot_path)
-    model.load_state_dict(saved_state_dict)
+    load_filtered_state_dict(model, saved_state_dict)
 
     print 'Loading data.'
 
@@ -59,18 +66,20 @@
     transforms.CenterCrop(224), transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
 
-    if args.dataset == 'AFLW2000':
-        pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list,
-                                transformations)
+    if args.dataset == 'Pose_300W_LP':
+        pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations)
+    elif args.dataset == 'Pose_300W_LP_random_ds':
+        pose_dataset = datasets.Pose_300W_LP_random_ds(args.data_dir, args.filename_list, transformations)
+    elif args.dataset == 'AFLW2000':
+        pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'AFLW2000_ds':
-        pose_dataset = datasets.AFLW2000_ds(args.data_dir, args.filename_list,
-                                transformations)
+        pose_dataset = datasets.AFLW2000_ds(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'BIWI':
         pose_dataset = datasets.BIWI(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'AFLW':
         pose_dataset = datasets.AFLW(args.data_dir, args.filename_list, transformations)
-    elif args.dataset == 'Pose_300W_LP':
-        pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations)
+    elif args.dataset == 'AFLW_aug':
+        pose_dataset = datasets.AFLW_aug(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'AFW':
         pose_dataset = datasets.AFW(args.data_dir, args.filename_list, transformations)
     else:

--
Gitblit v1.8.0