From 2eb13d63b15a8ac908d6fa324c7f3d19141ca570 Mon Sep 17 00:00:00 2001 From: natanielruiz <nataniel777@hotmail.com> Date: 星期六, 12 八月 2017 08:57:15 +0800 Subject: [PATCH] Temperature softmax and 10 shape PCA regression. --- code/test_resnet_bins.py | 16 ++++++---------- 1 files changed, 6 insertions(+), 10 deletions(-) diff --git a/code/test_resnet_bins.py b/code/test_resnet_bins.py index 699c9c9..4b1a655 100644 --- a/code/test_resnet_bins.py +++ b/code/test_resnet_bins.py @@ -103,18 +103,14 @@ _, pitch_bpred = torch.max(pitch.data, 1) _, roll_bpred = torch.max(roll.data, 1) - yaw_predicted = F.softmax(yaw) - pitch_predicted = F.softmax(pitch) - roll_predicted = F.softmax(roll) - # Continuous predictions - yaw_predicted = torch.sum(yaw_predicted.data * idx_tensor, 1) - pitch_predicted = torch.sum(pitch_predicted.data * idx_tensor, 1) - roll_predicted = torch.sum(roll_predicted.data * idx_tensor, 1) + yaw_predicted = utils.softmax_temperature(yaw.data, 1) + pitch_predicted = utils.softmax_temperature(pitch.data, 1) + roll_predicted = utils.softmax_temperature(roll.data, 1) - yaw_predicted = yaw_predicted.cpu() - pitch_predicted = pitch_predicted.cpu() - roll_predicted = roll_predicted.cpu() + yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1).cpu() + pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1).cpu() + roll_predicted = torch.sum(roll_predicted * idx_tensor, 1).cpu() # Mean absolute error yaw_error += torch.sum(torch.abs(yaw_predicted - label_yaw) * 3) -- Gitblit v1.8.0