From 4b67b5c8ed5566ec3030d537536282e830d87e40 Mon Sep 17 00:00:00 2001
From: natanielruiz <nruiz9@gatech.edu>
Date: 星期一, 30 十月 2017 07:15:49 +0800
Subject: [PATCH] next

---
 code/test_preangles.py |   24 ++++++++++++------------
 1 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/code/test_preangles.py b/code/test_preangles.py
index 05f621a..d4a9f5f 100644
--- a/code/test_preangles.py
+++ b/code/test_preangles.py
@@ -36,13 +36,6 @@
 
     return args
 
-def load_filtered_state_dict(model, snapshot):
-    # By user apaszke from discuss.pytorch.org
-    model_dict = model.state_dict()
-    snapshot = {k: v for k, v in snapshot.items() if k in model_dict}
-    model_dict.update(snapshot)
-    model.load_state_dict(model_dict)
-
 if __name__ == '__main__':
     args = parse_args()
 
@@ -56,7 +49,7 @@
     print 'Loading snapshot.'
     # Load snapshot
     saved_state_dict = torch.load(snapshot_path)
-    load_filtered_state_dict(model, saved_state_dict)
+    model.load_state_dict(saved_state_dict)
 
     print 'Loading data.'
 
@@ -95,6 +88,9 @@
     model.eval()  # Change model to 'eval' mode (BN uses moving mean/var).
     total = 0
 
+    idx_tensor = [idx for idx in xrange(66)]
+    idx_tensor = torch.FloatTensor(idx_tensor).cuda(gpu)
+
     yaw_error = .0
     pitch_error = .0
     roll_error = .0
@@ -109,7 +105,7 @@
         label_pitch = cont_labels[:,1].float()
         label_roll = cont_labels[:,2].float()
 
-        yaw, pitch, roll, angles = model(images)
+        yaw, pitch, roll = model(images)
 
         # Binned predictions
         _, yaw_bpred = torch.max(yaw.data, 1)
@@ -117,9 +113,13 @@
         _, roll_bpred = torch.max(roll.data, 1)
 
         # Continuous predictions
-        yaw_predicted = angles[:,0].data.cpu()
-        pitch_predicted = angles[:,1].data.cpu()
-        roll_predicted = angles[:,2].data.cpu()
+        yaw_predicted = utils.softmax_temperature(yaw.data, 1)
+        pitch_predicted = utils.softmax_temperature(pitch.data, 1)
+        roll_predicted = utils.softmax_temperature(roll.data, 1)
+
+        yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1).cpu() * 3 - 99
+        pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1).cpu() * 3 - 99
+        roll_predicted = torch.sum(roll_predicted * idx_tensor, 1).cpu() * 3 - 99
 
         # Mean absolute error
         yaw_error += torch.sum(torch.abs(yaw_predicted - label_yaw))

--
Gitblit v1.8.0