natanielruiz
2017-08-07 6f71fb102f509d705d3abaa1f44638a19f57e92e
code/test_resnet_bins.py
@@ -47,8 +47,8 @@
    snapshot_path = os.path.join('output/snapshots', args.snapshot + '.pkl')
    # ResNet50 with 3 outputs.
    model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66)
    # model = hopenet.Hopenet(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], 66)
    # model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66)
    model = hopenet.Hopenet(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], 66)
    print 'Loading snapshot.'
    # Load snapshot
@@ -87,7 +87,6 @@
    for i, (images, labels, name) in enumerate(test_loader):
        images = Variable(images).cuda(gpu)
        total += labels.size(0)
        label_yaw = labels[:,0]
        label_pitch = labels[:,1]
@@ -126,8 +125,7 @@
        if args.save_viz:
            name = name[0]
            cv2_img = cv2.imread(os.path.join(args.data_dir, name + '.jpg'))
            #cv2_img = cv2.cvtColor(cv2_img, cv2.COLOR_RGB2BGR)
            #print name
            #print os.path.join('output/images', name + '.jpg')
            #print label_yaw[0] * 3 - 99, label_pitch[0] * 3 - 99, label_roll[0] * 3 - 99
            #print yaw_predicted * 3 - 99, pitch_predicted * 3 - 99, roll_predicted * 3 - 99
            utils.plot_pose_cube(cv2_img, yaw_predicted * 3 - 99, pitch_predicted * 3 - 99, roll_predicted * 3 - 99)