natanielruiz
2017-08-12 2eb13d63b15a8ac908d6fa324c7f3d19141ca570
code/test_resnet_bins.py
@@ -103,18 +103,14 @@
        _, pitch_bpred = torch.max(pitch.data, 1)
        _, roll_bpred = torch.max(roll.data, 1)
        yaw_predicted = F.softmax(yaw)
        pitch_predicted = F.softmax(pitch)
        roll_predicted = F.softmax(roll)
        # Continuous predictions
        yaw_predicted = torch.sum(yaw_predicted.data * idx_tensor, 1)
        pitch_predicted = torch.sum(pitch_predicted.data * idx_tensor, 1)
        roll_predicted = torch.sum(roll_predicted.data * idx_tensor, 1)
        yaw_predicted = utils.softmax_temperature(yaw.data, 1)
        pitch_predicted = utils.softmax_temperature(pitch.data, 1)
        roll_predicted = utils.softmax_temperature(roll.data, 1)
        yaw_predicted = yaw_predicted.cpu()
        pitch_predicted = pitch_predicted.cpu()
        roll_predicted = roll_predicted.cpu()
        yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1).cpu()
        pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1).cpu()
        roll_predicted = torch.sum(roll_predicted * idx_tensor, 1).cpu()
        # Mean absolute error
        yaw_error += torch.sum(torch.abs(yaw_predicted - label_yaw) * 3)