Nataniel Ruiz
2017-12-01 ea0e6821e44dca377ba790dbe3eeede1703014ab
code/test_alexnet.py
@@ -1,4 +1,9 @@
import sys, os, argparse
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.autograd import Variable
@@ -8,15 +13,7 @@
import torchvision
import torch.nn.functional as F
import cv2
import matplotlib.pyplot as plt
import sys
import os
import argparse
import datasets
import hopenet
import utils
import datasets, hopenet, utils
def parse_args():
    """Parse input arguments."""
@@ -39,13 +36,6 @@
    return args
def load_filtered_state_dict(model, snapshot):
    # By user apaszke from discuss.pytorch.org
    model_dict = model.state_dict()
    snapshot = {k: v for k, v in snapshot.items() if k in model_dict}
    model_dict.update(snapshot)
    model.load_state_dict(model_dict)
if __name__ == '__main__':
    args = parse_args()
@@ -58,7 +48,7 @@
    print 'Loading snapshot.'
    # Load snapshot
    saved_state_dict = torch.load(snapshot_path)
    load_filtered_state_dict(model, saved_state_dict)
    model.load_state_dict(saved_state_dict)
    print 'Loading data.'
@@ -134,8 +124,7 @@
        pitch_error += torch.sum(torch.abs(pitch_predicted - label_pitch))
        roll_error += torch.sum(torch.abs(roll_predicted - label_roll))
        # Save images with pose cube.
        # TODO: fix for larger batch size
        # Save first image in batch with pose cube or axis.
        if args.save_viz:
            name = name[0]
            if args.dataset == 'BIWI':
@@ -145,7 +134,8 @@
            if args.batch_size == 1:
                error_string = 'y %.2f, p %.2f, r %.2f' % (torch.sum(torch.abs(yaw_predicted - label_yaw)), torch.sum(torch.abs(pitch_predicted - label_pitch)), torch.sum(torch.abs(roll_predicted - label_roll)))
                cv2.putText(cv2_img, error_string, (30, cv2_img.shape[0]- 30), fontFace=1, fontScale=1, color=(0,0,255), thickness=1)
            utils.plot_pose_cube(cv2_img, yaw_predicted[0], pitch_predicted[0], roll_predicted[0])
            # utils.plot_pose_cube(cv2_img, yaw_predicted[0], pitch_predicted[0], roll_predicted[0], size=100)
            utils.draw_axis(cv2_img, yaw_predicted[0], pitch_predicted[0], roll_predicted[0], tdx = 200, tdy= 200, size=100)
            cv2.imwrite(os.path.join('output/images', name + '.jpg'), cv2_img)
    print('Test error in degrees of the model on the ' + str(total) +