From f415df3448622f30c3d1eb680596871672b38dac Mon Sep 17 00:00:00 2001 From: natanielruiz <nataniel777@hotmail.com> Date: 星期四, 26 十月 2017 02:51:37 +0800 Subject: [PATCH] after fg --- code/test.py | 8 +++++--- 1 files changed, 5 insertions(+), 3 deletions(-) diff --git a/code/test.py b/code/test.py index b69c3a9..4983105 100644 --- a/code/test.py +++ b/code/test.py @@ -72,6 +72,8 @@ pose_dataset = datasets.BIWI(args.data_dir, args.filename_list, transformations) elif args.dataset == 'AFLW': pose_dataset = datasets.AFLW(args.data_dir, args.filename_list, transformations) + elif args.dataset == 'Pose_300W_LP': + pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations) elif args.dataset == 'AFW': pose_dataset = datasets.AFW(args.data_dir, args.filename_list, transformations) else: @@ -107,9 +109,9 @@ roll = angles[0][:,2].cpu().data * 3 - 99 for idx in xrange(1,args.iter_ref+1): - yaw += angles[idx][:,0].cpu().data - pitch += angles[idx][:,1].cpu().data - roll += angles[idx][:,2].cpu().data + yaw += angles[idx][:,0].cpu().data * 3 - 99 + pitch += angles[idx][:,1].cpu().data * 3 - 99 + roll += angles[idx][:,2].cpu().data * 3 - 99 # Mean absolute error yaw_error += torch.sum(torch.abs(yaw - label_yaw)) -- Gitblit v1.8.0