From 2eb13d63b15a8ac908d6fa324c7f3d19141ca570 Mon Sep 17 00:00:00 2001
From: natanielruiz <nataniel777@hotmail.com>
Date: 星期六, 12 八月 2017 08:57:15 +0800
Subject: [PATCH] Temperature softmax and 10 shape PCA regression.

---
 code/test_resnet_bins.py |   16 ++++++----------
 1 files changed, 6 insertions(+), 10 deletions(-)

diff --git a/code/test_resnet_bins.py b/code/test_resnet_bins.py
index 699c9c9..4b1a655 100644
--- a/code/test_resnet_bins.py
+++ b/code/test_resnet_bins.py
@@ -103,18 +103,14 @@
         _, pitch_bpred = torch.max(pitch.data, 1)
         _, roll_bpred = torch.max(roll.data, 1)
 
-        yaw_predicted = F.softmax(yaw)
-        pitch_predicted = F.softmax(pitch)
-        roll_predicted = F.softmax(roll)
-
         # Continuous predictions
-        yaw_predicted = torch.sum(yaw_predicted.data * idx_tensor, 1)
-        pitch_predicted = torch.sum(pitch_predicted.data * idx_tensor, 1)
-        roll_predicted = torch.sum(roll_predicted.data * idx_tensor, 1)
+        yaw_predicted = utils.softmax_temperature(yaw.data, 1)
+        pitch_predicted = utils.softmax_temperature(pitch.data, 1)
+        roll_predicted = utils.softmax_temperature(roll.data, 1)
 
-        yaw_predicted = yaw_predicted.cpu()
-        pitch_predicted = pitch_predicted.cpu()
-        roll_predicted = roll_predicted.cpu()
+        yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1).cpu()
+        pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1).cpu()
+        roll_predicted = torch.sum(roll_predicted * idx_tensor, 1).cpu()
 
         # Mean absolute error
         yaw_error += torch.sum(torch.abs(yaw_predicted - label_yaw) * 3)

--
Gitblit v1.8.0