hyhmrright
2019-05-31 e65c915e5bdbcca56b37aa13bcff4911beffbe37
code/test_alexnet.py
@@ -1,4 +1,9 @@
import sys, os, argparse
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.autograd import Variable
@@ -8,15 +13,7 @@
import torchvision
import torch.nn.functional as F
import cv2
import matplotlib.pyplot as plt
import sys
import os
import argparse
import datasets
import hopenet
import utils
import datasets, hopenet, utils
def parse_args():
    """Parse input arguments."""
@@ -39,13 +36,6 @@
    return args
def load_filtered_state_dict(model, snapshot):
    # By user apaszke from discuss.pytorch.org
    model_dict = model.state_dict()
    snapshot = {k: v for k, v in snapshot.items() if k in model_dict}
    model_dict.update(snapshot)
    model.load_state_dict(model_dict)
if __name__ == '__main__':
    args = parse_args()
@@ -55,12 +45,12 @@
    model = hopenet.AlexNet(66)
    print 'Loading snapshot.'
    print('Loading snapshot.')
    # Load snapshot
    saved_state_dict = torch.load(snapshot_path)
    load_filtered_state_dict(model, saved_state_dict)
    model.load_state_dict(saved_state_dict)
    print 'Loading data.'
    print('Loading data.')
    transformations = transforms.Compose([transforms.Scale(224),
    transforms.CenterCrop(224), transforms.ToTensor(),
@@ -83,7 +73,7 @@
    elif args.dataset == 'AFW':
        pose_dataset = datasets.AFW(args.data_dir, args.filename_list, transformations)
    else:
        print 'Error: not a valid dataset name'
        print('Error: not a valid dataset name')
        sys.exit()
    test_loader = torch.utils.data.DataLoader(dataset=pose_dataset,
                                               batch_size=args.batch_size,
@@ -91,7 +81,7 @@
    model.cuda(gpu)
    print 'Ready to test network.'
    print ('Ready to test network.')
    # Test the Model
    model.eval()  # Change model to 'eval' mode (BN uses moving mean/var).
@@ -134,8 +124,7 @@
        pitch_error += torch.sum(torch.abs(pitch_predicted - label_pitch))
        roll_error += torch.sum(torch.abs(roll_predicted - label_roll))
        # Save images with pose cube.
        # TODO: fix for larger batch size
        # Save first image in batch with pose cube or axis.
        if args.save_viz:
            name = name[0]
            if args.dataset == 'BIWI':
@@ -145,7 +134,8 @@
            if args.batch_size == 1:
                error_string = 'y %.2f, p %.2f, r %.2f' % (torch.sum(torch.abs(yaw_predicted - label_yaw)), torch.sum(torch.abs(pitch_predicted - label_pitch)), torch.sum(torch.abs(roll_predicted - label_roll)))
                cv2.putText(cv2_img, error_string, (30, cv2_img.shape[0]- 30), fontFace=1, fontScale=1, color=(0,0,255), thickness=1)
            utils.plot_pose_cube(cv2_img, yaw_predicted[0], pitch_predicted[0], roll_predicted[0])
            # utils.plot_pose_cube(cv2_img, yaw_predicted[0], pitch_predicted[0], roll_predicted[0], size=100)
            utils.draw_axis(cv2_img, yaw_predicted[0], pitch_predicted[0], roll_predicted[0], tdx = 200, tdy= 200, size=100)
            cv2.imwrite(os.path.join('output/images', name + '.jpg'), cv2_img)
    print('Test error in degrees of the model on the ' + str(total) +