| | |
| | | default=1, type=int) |
| | | parser.add_argument('--save_viz', dest='save_viz', help='Save images with pose cube.', |
| | | default=False, type=bool) |
| | | parser.add_argument('--iter_ref', dest='iter_ref', default=1, type=int) |
| | | |
| | | args = parser.parse_args() |
| | | |
| | |
| | | # ResNet101 with 3 outputs. |
| | | # model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], 66) |
| | | # ResNet50 |
| | | model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66) |
| | | model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66, args.iter_ref) |
| | | # ResNet18 |
| | | # model = hopenet.Hopenet(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], 66) |
| | | |
| | |
| | | label_roll = labels[:,2].float() |
| | | |
| | | pre_yaw, pre_pitch, pre_roll, angles = model(images) |
| | | yaw = angles[0][:,0].cpu().data |
| | | pitch = angles[0][:,1].cpu().data |
| | | roll = angles[0][:,2].cpu().data |
| | | yaw = angles[args.iter_ref-1][:,0].cpu().data |
| | | pitch = angles[args.iter_ref-1][:,1].cpu().data |
| | | roll = angles[args.iter_ref-1][:,2].cpu().data |
| | | |
| | | # Mean absolute error |
| | | yaw_error += torch.sum(torch.abs(yaw - label_yaw) * 3) |