natanielruiz
2017-09-27 43416c4717d2430c3e11f042294d12b781fee2e1
code/test.py
@@ -72,6 +72,8 @@
        pose_dataset = datasets.BIWI(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'AFLW':
        pose_dataset = datasets.AFLW(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'Pose_300W_LP':
        pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'AFW':
        pose_dataset = datasets.AFW(args.data_dir, args.filename_list, transformations)
    else:
@@ -107,9 +109,9 @@
        roll = angles[0][:,2].cpu().data * 3 - 99
        for idx in xrange(1,args.iter_ref+1):
            yaw += angles[idx][:,0].cpu().data
            pitch += angles[idx][:,1].cpu().data
            roll += angles[idx][:,2].cpu().data
            yaw += angles[idx][:,0].cpu().data * 3 - 99
            pitch += angles[idx][:,1].cpu().data * 3 - 99
            roll += angles[idx][:,2].cpu().data * 3 - 99
        # Mean absolute error
        yaw_error += torch.sum(torch.abs(yaw - label_yaw))