import os
import argparse

import numpy as np
import torch
from torch.autograd import Variable

from datasets import AFLW2000
import datasets
import hopenet
import utils
# NOTE(review): the script also uses torchvision's `transforms`; make sure it is
# imported (e.g. `import torchvision.transforms as transforms`).
print('Loading data.')

# Evaluation must be reproducible, so use a deterministic CenterCrop rather
# than RandomCrop (which picks a different window on every pass).
# NOTE(review): newer torchvision renamed transforms.Scale to transforms.Resize.
transformations = transforms.Compose([transforms.Scale(224),
                                      transforms.CenterCrop(224),
                                      transforms.ToTensor()])

# NOTE(review): SOURCE contained two competing constructors here
# (AFLW2000 and datasets.AFLW2000_binned), apparently left over from a diff.
# The evaluation below treats labels as continuous (yaw, pitch, roll) angles in
# radians, which matches the non-binned dataset -- confirm against datasets.py.
pose_dataset = AFLW2000(args.data_dir, args.filename_list, transformations)
test_loader = torch.utils.data.DataLoader(dataset=pose_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

RAD_TO_DEG = 180.0 / np.pi

# Test the Model
model.eval()  # Change model to 'eval' mode (BN uses moving mean/var).
total = 0
# Running sum of per-sample absolute errors in degrees, one slot per angle
# (yaw, pitch, roll).
abs_error_sum = torch.zeros(3)
for images, labels, name in test_loader:
    images = Variable(images).cuda(gpu)
    outputs = model(images)

    # Bring predictions to the CPU as float so they can be compared with the
    # CPU label tensor. Assumes both are (batch, 3) in radians -- TODO confirm.
    predictions = outputs.data.cpu().float()
    labels = labels.float()
    batch_error = (predictions - labels).abs() * RAD_TO_DEG

    # view(-1) keeps the running sum 1-D even on PyTorch versions whose
    # sum(dim) keeps the reduced dimension.
    abs_error_sum += batch_error.sum(0).view(-1)
    total += labels.size(0)

    # Per-sample diagnostics. `name` is assumed to be the collated list of
    # per-sample identifiers, so index it instead of printing the whole batch.
    for idx in range(labels.size(0)):
        print('%s  abs error: %s  predicted: %s  label: %s' % (
            name[idx],
            batch_error[idx].tolist(),
            (predictions[idx] * RAD_TO_DEG).tolist(),
            (labels[idx] * RAD_TO_DEG).tolist()))

if total == 0:
    print('No test images were loaded; nothing to evaluate.')
else:
    # tolist() yields plain Python floats, so the %-formatting below behaves
    # the same across PyTorch versions.
    yaw_error, pitch_error, roll_error = (abs_error_sum / total).tolist()
    print('Test mean absolute error (degrees) of the model on the %d test '
          'images. Yaw: %.4f, Pitch: %.4f, Roll: %.4f, Mean: %.4f'
          % (total, yaw_error, pitch_error, roll_error,
             (yaw_error + pitch_error + roll_error) / 3.0))