From e624d2ace8296e130a4fa4d2d307041798c538e0 Mon Sep 17 00:00:00 2001 From: natanielruiz <nataniel777@hotmail.com> Date: 星期二, 19 九月 2017 12:22:56 +0800 Subject: [PATCH] next --- code/test_AFW_preangles.py | 18 +++++------------- 1 files changed, 5 insertions(+), 13 deletions(-) diff --git a/code/test_AFW_preangles.py b/code/test_AFW_preangles.py index daa1f24..3dae2e8 100644 --- a/code/test_AFW_preangles.py +++ b/code/test_AFW_preangles.py @@ -94,12 +94,12 @@ yaw_correct = .0 yaw_margin = args.margin - for i, (images, labels, name) in enumerate(test_loader): + for i, (images, labels, cont_labels, name) in enumerate(test_loader): images = Variable(images).cuda(gpu) total += labels.size(0) - label_yaw = labels[:,0].float() * 3 - 99 - label_pitch = labels[:,1].float() * 3 - 99 - label_roll = labels[:,2].float() * 3 - 99 + label_yaw = cont_labels[:,0] + label_pitch = cont_labels[:,1].float() + label_roll = cont_labels[:,2].float() yaw, pitch, roll, angles = model(images) @@ -109,7 +109,7 @@ _, roll_bpred = torch.max(roll.data, 1) # Continuous predictions - yaw_predicted = utils.softmax_temperature(yaw.data, 0.4) + yaw_predicted = utils.softmax_temperature(yaw.data, 0.85) pitch_predicted = utils.softmax_temperature(pitch.data, 0.8) roll_predicted = utils.softmax_temperature(roll.data, 0.8) @@ -129,14 +129,6 @@ if yaw_tensor_error[0] > yaw_margin: print name[0] + ' ' + str(yaw_predicted[0]) + ' ' + str(label_yaw[0]) + ' ' + str(yaw_tensor_error[0]) - - # Binned Accuracy - # for er in xrange(n_margins): - # yaw_bpred[er] += (label_yaw[0] in range(yaw_bpred[0,0] - er, yaw_bpred[0,0] + er + 1)) - # pitch_bpred[er] += (label_pitch[0] in range(pitch_bpred[0,0] - er, pitch_bpred[0,0] + er + 1)) - # roll_bpred[er] += (label_roll[0] in range(roll_bpred[0,0] - er, roll_bpred[0,0] + er + 1)) - - # print label_yaw[0], yaw_bpred[0,0] # Save images with pose cube. # TODO: fix for larger batch size -- Gitblit v1.8.0