hyhmrright
2019-05-31 e65c915e5bdbcca56b37aa13bcff4911beffbe37
code/test_on_video_dockerface.py
@@ -51,12 +51,12 @@
    # ResNet50 structure
    model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66)
    print 'Loading snapshot.'
    print('Loading snapshot.')
    # Load snapshot
    saved_state_dict = torch.load(snapshot_path)
    model.load_state_dict(saved_state_dict)
    print 'Loading data.'
    print('Loading data.')
    transformations = transforms.Compose([transforms.Scale(224),
    transforms.CenterCrop(224), transforms.ToTensor(),
@@ -64,7 +64,7 @@
    model.cuda(gpu)
    print 'Ready to test network.'
    print('Ready to test network.')
    # Test the Model
    model.eval()  # Change model to 'eval' mode (BN uses moving mean/var).
@@ -105,7 +105,7 @@
        line = line.split(' ')
        det_frame_num = int(line[0])
        print frame_num
        print(frame_num)
        # Stop at a certain frame number
        if frame_num > args.n_frames:
@@ -166,7 +166,7 @@
                pitch_predicted = torch.sum(pitch_predicted.data[0] * idx_tensor) * 3 - 99
                roll_predicted = torch.sum(roll_predicted.data[0] * idx_tensor) * 3 - 99
                # Print new frame with cube and axis
                # print(new frame with cube and axis
                txt_out.write(str(frame_num) + ' %f %f %f\n' % (yaw_predicted, pitch_predicted, roll_predicted))
                # utils.plot_pose_cube(frame, yaw_predicted, pitch_predicted, roll_predicted, (x_min + x_max) / 2, (y_min + y_max) / 2, size = bbox_width)
                utils.draw_axis(frame, yaw_predicted, pitch_predicted, roll_predicted, tdx = (x_min + x_max) / 2, tdy= (y_min + y_max) / 2, size = bbox_height/2)
@@ -175,7 +175,7 @@
            # Peek next frame detection
            next_frame_num = int(bbox_line_list[idx+1].strip('\n').split(' ')[0])
            # print 'next_frame_num ', next_frame_num
            # print('next_frame_num ', next_frame_num
            if next_frame_num == det_frame_num:
                idx += 1
                line = bbox_line_list[idx].strip('\n').split(' ')