From 134330d8a9c41d498078a78b87d06086d964d273 Mon Sep 17 00:00:00 2001 From: chenshijun <csj_sky@126.com> Date: 星期三, 05 六月 2019 13:26:15 +0800 Subject: [PATCH] multiplayer display --- code/test_alexnet.py | 38 ++++++++++++++------------------------ 1 files changed, 14 insertions(+), 24 deletions(-) diff --git a/code/test_alexnet.py b/code/test_alexnet.py index 7a3989a..45ad25f 100644 --- a/code/test_alexnet.py +++ b/code/test_alexnet.py @@ -1,4 +1,9 @@ +import sys, os, argparse + import numpy as np +import cv2 +import matplotlib.pyplot as plt + import torch import torch.nn as nn from torch.autograd import Variable @@ -8,15 +13,7 @@ import torchvision import torch.nn.functional as F -import cv2 -import matplotlib.pyplot as plt -import sys -import os -import argparse - -import datasets -import hopenet -import utils +import datasets, hopenet, utils def parse_args(): """Parse input arguments.""" @@ -39,13 +36,6 @@ return args -def load_filtered_state_dict(model, snapshot): - # By user apaszke from discuss.pytorch.org - model_dict = model.state_dict() - snapshot = {k: v for k, v in snapshot.items() if k in model_dict} - model_dict.update(snapshot) - model.load_state_dict(model_dict) - if __name__ == '__main__': args = parse_args() @@ -55,12 +45,12 @@ model = hopenet.AlexNet(66) - print 'Loading snapshot.' + print('Loading snapshot.') # Load snapshot saved_state_dict = torch.load(snapshot_path) - load_filtered_state_dict(model, saved_state_dict) + model.load_state_dict(saved_state_dict) - print 'Loading data.' + print('Loading data.') transformations = transforms.Compose([transforms.Scale(224), transforms.CenterCrop(224), transforms.ToTensor(), @@ -83,7 +73,7 @@ elif args.dataset == 'AFW': pose_dataset = datasets.AFW(args.data_dir, args.filename_list, transformations) else: - print 'Error: not a valid dataset name' + print('Error: not a valid dataset name') sys.exit() test_loader = torch.utils.data.DataLoader(dataset=pose_dataset, batch_size=args.batch_size, @@ -91,7 +81,7 @@ model.cuda(gpu) - print 'Ready to test network.' + print ('Ready to test network.') # Test the Model model.eval() # Change model to 'eval' mode (BN uses moving mean/var). @@ -134,8 +124,7 @@ pitch_error += torch.sum(torch.abs(pitch_predicted - label_pitch)) roll_error += torch.sum(torch.abs(roll_predicted - label_roll)) - # Save images with pose cube. - # TODO: fix for larger batch size + # Save first image in batch with pose cube or axis. if args.save_viz: name = name[0] if args.dataset == 'BIWI': @@ -145,7 +134,8 @@ if args.batch_size == 1: error_string = 'y %.2f, p %.2f, r %.2f' % (torch.sum(torch.abs(yaw_predicted - label_yaw)), torch.sum(torch.abs(pitch_predicted - label_pitch)), torch.sum(torch.abs(roll_predicted - label_roll))) cv2.putText(cv2_img, error_string, (30, cv2_img.shape[0]- 30), fontFace=1, fontScale=1, color=(0,0,255), thickness=1) - utils.plot_pose_cube(cv2_img, yaw_predicted[0], pitch_predicted[0], roll_predicted[0]) + # utils.plot_pose_cube(cv2_img, yaw_predicted[0], pitch_predicted[0], roll_predicted[0], size=100) + utils.draw_axis(cv2_img, yaw_predicted[0], pitch_predicted[0], roll_predicted[0], tdx = 200, tdy= 200, size=100) cv2.imwrite(os.path.join('output/images', name + '.jpg'), cv2_img) print('Test error in degrees of the model on the ' + str(total) + -- Gitblit v1.8.0