From 6664c6d52fad58e396861946a3bed7d5afc4d44d Mon Sep 17 00:00:00 2001 From: natanielruiz <nataniel777@hotmail.com> Date: 星期五, 07 七月 2017 10:53:52 +0800 Subject: [PATCH] Training for hopenet works. --- code/test_resnet_bins.py | 26 +++++++++++++++----------- 1 files changed, 15 insertions(+), 11 deletions(-) diff --git a/code/test_resnet_bins.py b/code/test_resnet_bins.py index 34fc8f5..0a093ee 100644 --- a/code/test_resnet_bins.py +++ b/code/test_resnet_bins.py @@ -13,7 +13,7 @@ import os import argparse -from datasets import AFLW2000 +import datasets import hopenet import utils @@ -55,9 +55,10 @@ print 'Loading data.' - transformations = transforms.Compose([transforms.Scale(224),transforms.RandomCrop(224), transforms.ToTensor()]) + transformations = transforms.Compose([transforms.Scale(224), + transforms.RandomCrop(224), transforms.ToTensor()]) - pose_dataset = AFLW2000(args.data_dir, args.filename_list, + pose_dataset = datasets.AFLW2000_binned(args.data_dir, args.filename_list, transformations) test_loader = torch.utils.data.DataLoader(dataset=pose_dataset, batch_size=batch_size, @@ -69,7 +70,9 @@ # Test the Model model.eval() # Change model to 'eval' mode (BN uses moving mean/var). - error = .0 + yaw_correct = 0 + pitch_correct = 0 + roll_correct = 0 total = 0 for i, (images, labels, name) in enumerate(test_loader): images = Variable(images).cuda(gpu) @@ -78,13 +81,14 @@ _, predicted = torch.max(outputs.data, 1) total += labels.size(0) # TODO: There are more efficient ways. + yaw_correct += (outputs[:][0] == labels[:][0]) + pitch_correct += (outputs[:][]) for idx in xrange(len(outputs)): - # if abs(outputs[idx].data[1] - labels[idx].data[1]) * 180 / np.pi > 30: - print name - print abs(outputs[idx].data - labels[idx].data) * 180 / np.pi, 180 * outputs[idx].data / np.pi, labels[idx].data * 180 / np.pi - # error += utils.mse_loss(outputs[idx], labels[idx]) - error += abs(outputs[idx].data - labels[idx].data) * 180 / np.pi + yaw_correct += (outputs[idx].data[0] == labels[idx].data[0]) + pitch_correct += (outputs[idx].data[1] == labels[idx].data[1]) + roll_correct += (outputs[idx].data[2] == labels[idx].data[2]) - print('Test MSE error of the model on the ' + str(total) + - ' test images: %.4f' % (error / total)) + print('Test accuracies of the model on the ' + str(total) + + ' test images. Yaw: %.4f %%, Pitch: %.4f %%, Roll: %.4f %%' % (yaw_correct / total, + pitch_correct / total, roll_correct / total)) -- Gitblit v1.8.0