| | |
| | | pose_dataset = datasets.BIWI(args.data_dir, args.filename_list, transformations) |
| | | elif args.dataset == 'AFLW': |
| | | pose_dataset = datasets.AFLW(args.data_dir, args.filename_list, transformations) |
| | | elif args.dataset == 'Pose_300W_LP': |
| | | pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations) |
| | | elif args.dataset == 'AFW': |
| | | pose_dataset = datasets.AFW(args.data_dir, args.filename_list, transformations) |
| | | else: |
| | |
| | | roll = angles[0][:,2].cpu().data * 3 - 99 |
| | | |
| | | for idx in xrange(1,args.iter_ref+1): |
| | | yaw += angles[idx][:,0].cpu().data |
| | | pitch += angles[idx][:,1].cpu().data |
| | | roll += angles[idx][:,2].cpu().data |
| | | yaw += angles[idx][:,0].cpu().data * 3 - 99 |
| | | pitch += angles[idx][:,1].cpu().data * 3 - 99 |
| | | roll += angles[idx][:,2].cpu().data * 3 - 99 |
| | | |
| | | # Mean absolute error |
| | | yaw_error += torch.sum(torch.abs(yaw - label_yaw)) |