From 5483d8fec0814e9cc9f5e6fdbb69810f74c76ac9 Mon Sep 17 00:00:00 2001
From: natanielruiz <nruiz9@gatech.edu>
Date: 星期一, 30 十月 2017 07:09:27 +0800
Subject: [PATCH] next

---
 code/test_preangles.py |   18 ++++++++++++++----
 1 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/code/test_preangles.py b/code/test_preangles.py
index 05f621a..9cdc8e3 100644
--- a/code/test_preangles.py
+++ b/code/test_preangles.py
@@ -56,7 +56,8 @@
     print 'Loading snapshot.'
     # Load snapshot
     saved_state_dict = torch.load(snapshot_path)
-    load_filtered_state_dict(model, saved_state_dict)
+    model.load_state_dict(saved_state_dict)
+    # load_filtered_state_dict(model, saved_state_dict)
 
     print 'Loading data.'
 
@@ -95,11 +96,16 @@
     model.eval()  # Change model to 'eval' mode (BN uses moving mean/var).
     total = 0
 
+    idx_tensor = [idx for idx in xrange(66)]
+    idx_tensor = torch.FloatTensor(idx_tensor).cuda(gpu)
+
     yaw_error = .0
     pitch_error = .0
     roll_error = .0
 
     l1loss = torch.nn.L1Loss(size_average=False)
+
+
 
     for i, (images, labels, cont_labels, name) in enumerate(test_loader):
         images = Variable(images).cuda(gpu)
@@ -117,9 +123,13 @@
         _, roll_bpred = torch.max(roll.data, 1)
 
         # Continuous predictions
-        yaw_predicted = angles[:,0].data.cpu()
-        pitch_predicted = angles[:,1].data.cpu()
-        roll_predicted = angles[:,2].data.cpu()
+        yaw_predicted = utils.softmax_temperature(yaw.data, 1)
+        pitch_predicted = utils.softmax_temperature(pitch.data, 1)
+        roll_predicted = utils.softmax_temperature(roll.data, 1)
+
+        yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1).cpu() * 3 - 99
+        pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1).cpu() * 3 - 99
+        roll_predicted = torch.sum(roll_predicted * idx_tensor, 1).cpu() * 3 - 99
 
         # Mean absolute error
         yaw_error += torch.sum(torch.abs(yaw_predicted - label_yaw))

--
Gitblit v1.8.0