natanielruiz
2017-07-07 6664c6d52fad58e396861946a3bed7d5afc4d44d
code/test_resnet_bins.py
@@ -13,7 +13,7 @@
import os
import argparse
from datasets import AFLW2000
import datasets
import hopenet
import utils
@@ -55,9 +55,10 @@
    print 'Loading data.'
    transformations = transforms.Compose([transforms.Scale(224),transforms.RandomCrop(224), transforms.ToTensor()])
    transformations = transforms.Compose([transforms.Scale(224),
    transforms.RandomCrop(224), transforms.ToTensor()])
    pose_dataset = AFLW2000(args.data_dir, args.filename_list,
    pose_dataset = datasets.AFLW2000_binned(args.data_dir, args.filename_list,
                                transformations)
    test_loader = torch.utils.data.DataLoader(dataset=pose_dataset,
                                               batch_size=batch_size,
@@ -69,7 +70,9 @@
    # Test the Model
    model.eval()  # Change model to 'eval' mode (BN uses moving mean/var).
    error = .0
    yaw_correct = 0
    pitch_correct = 0
    roll_correct = 0
    total = 0
    for i, (images, labels, name) in enumerate(test_loader):
        images = Variable(images).cuda(gpu)
@@ -78,13 +81,14 @@
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        # TODO: There are more efficient ways.
        yaw_correct += (outputs[:][0] == labels[:][0])
        pitch_correct += (outputs[:][])
        for idx in xrange(len(outputs)):
            # if abs(outputs[idx].data[1] - labels[idx].data[1]) * 180 / np.pi > 30:
            print name
            print abs(outputs[idx].data - labels[idx].data) * 180 / np.pi, 180 * outputs[idx].data / np.pi, labels[idx].data * 180 / np.pi
            # error += utils.mse_loss(outputs[idx], labels[idx])
            error += abs(outputs[idx].data - labels[idx].data) * 180 / np.pi
            yaw_correct += (outputs[idx].data[0] == labels[idx].data[0])
            pitch_correct += (outputs[idx].data[1] == labels[idx].data[1])
            roll_correct += (outputs[idx].data[2] == labels[idx].data[2])
    print('Test MSE error of the model on the ' + str(total) +
    ' test images: %.4f' % (error / total))
    print('Test accuracies of the model on the ' + str(total) +
    ' test images. Yaw: %.4f %%, Pitch: %.4f %%, Roll: %.4f %%' % (yaw_correct / total,
    pitch_correct / total, roll_correct / total))