natanielruiz
2017-07-07 8736b9753604a2e88843ba87c7e0e688dce072e6
Next
1个文件已修改
3 ■■■■■ 已修改文件
code/test_resnet_bins.py 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
code/test_resnet_bins.py
@@ -46,6 +46,7 @@
    # ResNet50 with 3 outputs.
    model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66)
    # model = hopenet.Hopenet(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], 66)
    print 'Loading snapshot.'
    # Load snapshot
@@ -107,6 +108,8 @@
        pitch_error += abs(pitch_predicted - label_pitch[0]) * 3
        roll_error += abs(roll_predicted - label_roll[0]) * 3
        # print yaw_predicted * 3, label_yaw[0] * 3, abs(yaw_predicted - label_yaw[0]) * 3
        # for er in xrange(0,n_margins):
        #     yaw_correct[er] += (label_yaw[0] in range(yaw_predicted[0,0] - er, yaw_predicted[0,0] + er + 1))
        #     pitch_correct[er] += (label_pitch[0] in range(pitch_predicted[0,0] - er, pitch_predicted[0,0] + er + 1))