From 93a4f337f2fd0280634024d2ff15790831813bed Mon Sep 17 00:00:00 2001 From: natanielruiz <nataniel777@hotmail.com> Date: 星期五, 07 七月 2017 14:33:47 +0800 Subject: [PATCH] Resnet50, and changed test error --- code/test_resnet_bins.py | 67 +++++++++++++++++++++++---------- 1 files changed, 47 insertions(+), 20 deletions(-) diff --git a/code/test_resnet_bins.py b/code/test_resnet_bins.py index 0a093ee..f5be4f8 100644 --- a/code/test_resnet_bins.py +++ b/code/test_resnet_bins.py @@ -6,6 +6,7 @@ from torchvision import transforms import torch.backends.cudnn as cudnn import torchvision +import torch.nn.functional as F import cv2 import matplotlib.pyplot as plt @@ -43,10 +44,8 @@ gpu = args.gpu_id snapshot_path = os.path.join('output/snapshots', args.snapshot + '.pkl') - model = torchvision.models.resnet18() - # Parameters of newly constructed modules have requires_grad=True by default - num_ftrs = model.fc.in_features - model.fc = nn.Linear(num_ftrs, 3) + # ResNet50 with 3 outputs. + model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66) print 'Loading snapshot.' # Load snapshot @@ -70,25 +69,53 @@ # Test the Model model.eval() # Change model to 'eval' mode (BN uses moving mean/var). - yaw_correct = 0 - pitch_correct = 0 - roll_correct = 0 total = 0 + n_margins = 20 + yaw_correct = np.zeros(n_margins) + pitch_correct = np.zeros(n_margins) + roll_correct = np.zeros(n_margins) + + idx_tensor = [idx for idx in xrange(66)] + idx_tensor = torch.FloatTensor(idx_tensor).cuda(gpu) + + yaw_error = .0 + pitch_error = .0 + roll_error = .0 + for i, (images, labels, name) in enumerate(test_loader): images = Variable(images).cuda(gpu) - labels = Variable(labels).cuda(gpu) - outputs = model(images) - _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) - # TODO: There are more efficient ways. - yaw_correct += (outputs[:][0] == labels[:][0]) - pitch_correct += (outputs[:][]) - for idx in xrange(len(outputs)): - yaw_correct += (outputs[idx].data[0] == labels[idx].data[0]) - pitch_correct += (outputs[idx].data[1] == labels[idx].data[1]) - roll_correct += (outputs[idx].data[2] == labels[idx].data[2]) + label_yaw = labels[:,0] + label_pitch = labels[:,1] + label_roll = labels[:,2] + yaw, pitch, roll = model(images) + # _, yaw_predicted = torch.max(yaw.data, 1) + # _, pitch_predicted = torch.max(pitch.data, 1) + # _, roll_predicted = torch.max(roll.data, 1) - print('Test accuracies of the model on the ' + str(total) + - ' test images. Yaw: %.4f %%, Pitch: %.4f %%, Roll: %.4f %%' % (yaw_correct / total, - pitch_correct / total, roll_correct / total)) + yaw_predicted = F.softmax(yaw) + pitch_predicted = F.softmax(pitch) + roll_predicted = F.softmax(roll) + + yaw_predicted = torch.sum(yaw_predicted.data[0] * idx_tensor) + pitch_predicted = torch.sum(pitch_predicted.data[0] * idx_tensor) + roll_predicted = torch.sum(roll_predicted.data[0] * idx_tensor) + + yaw_error += abs(yaw_predicted - label_yaw[0]) * 3 + pitch_error += abs(pitch_predicted - label_pitch[0]) * 3 + roll_error += abs(roll_predicted - label_roll[0]) * 3 + + # for er in xrange(0,n_margins): + # yaw_correct[er] += (label_yaw[0] in range(yaw_predicted[0,0] - er, yaw_predicted[0,0] + er + 1)) + # pitch_correct[er] += (label_pitch[0] in range(pitch_predicted[0,0] - er, pitch_predicted[0,0] + er + 1)) + # roll_correct[er] += (label_roll[0] in range(roll_predicted[0,0] - er, roll_predicted[0,0] + er + 1)) + + # print label_yaw[0], yaw_predicted[0,0] + # 4 -> 15 + print('Test error in degrees of the model on the ' + str(total) + + ' test images. Yaw: %.4f, Pitch: %.4f, Roll: %.4f' % (yaw_error / total, + pitch_error / total, roll_error / total)) + # for idx in xrange(len(yaw_correct)): + # print yaw_correct[idx] / total, pitch_correct[idx] / total, roll_correct[idx] / total -- Gitblit v1.8.0