From 43416c4717d2430c3e11f042294d12b781fee2e1 Mon Sep 17 00:00:00 2001 From: natanielruiz <nataniel777@hotmail.com> Date: 星期三, 27 九月 2017 04:09:30 +0800 Subject: [PATCH] Failed lstm experiment --- code/test.py | 8 +++++--- 1 files changed, 5 insertions(+), 3 deletions(-) diff --git a/code/test.py b/code/test.py index b69c3a9..4983105 100644 --- a/code/test.py +++ b/code/test.py @@ -72,6 +72,8 @@ pose_dataset = datasets.BIWI(args.data_dir, args.filename_list, transformations) elif args.dataset == 'AFLW': pose_dataset = datasets.AFLW(args.data_dir, args.filename_list, transformations) + elif args.dataset == 'Pose_300W_LP': + pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations) elif args.dataset == 'AFW': pose_dataset = datasets.AFW(args.data_dir, args.filename_list, transformations) else: @@ -107,9 +109,9 @@ roll = angles[0][:,2].cpu().data * 3 - 99 for idx in xrange(1,args.iter_ref+1): - yaw += angles[idx][:,0].cpu().data - pitch += angles[idx][:,1].cpu().data - roll += angles[idx][:,2].cpu().data + yaw += angles[idx][:,0].cpu().data * 3 - 99 + pitch += angles[idx][:,1].cpu().data * 3 - 99 + roll += angles[idx][:,2].cpu().data * 3 - 99 # Mean absolute error yaw_error += torch.sum(torch.abs(yaw - label_yaw)) -- Gitblit v1.8.0