From 2c764e41e2fde6244b87da58d12c40d09a14fcb4 Mon Sep 17 00:00:00 2001
From: natanielruiz <nruiz9@gatech.edu>
Date: 星期一, 30 十月 2017 06:49:01 +0800
Subject: [PATCH] Next

---
 code/test_on_video_noconf.py |   29 +++++++++++++----------------
 1 files changed, 13 insertions(+), 16 deletions(-)

diff --git a/code/test_on_video_noconf.py b/code/test_on_video_noconf.py
index e040de7..b6a8d2c 100644
--- a/code/test_on_video_noconf.py
+++ b/code/test_on_video_noconf.py
@@ -1,4 +1,9 @@
+import sys, os, argparse
+
 import numpy as np
+import cv2
+import matplotlib.pyplot as plt
+
 import torch
 import torch.nn as nn
 from torch.autograd import Variable
@@ -8,10 +13,6 @@
 import torchvision
 import torch.nn.functional as F
 from PIL import Image
-
-import cv2
-import matplotlib.pyplot as plt
-import sys, os, argparse
 
 import datasets, hopenet, utils
 
@@ -47,10 +48,8 @@
     if not os.path.exists(args.video_path):
         sys.exit('Video does not exist')
 
-    # ResNet101 with 3 outputs.
-    # model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], 66)
-    # ResNet50
-    model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66, 0)
+    # ResNet50 structure
+    model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66)
 
     print 'Loading snapshot.'
     # Load snapshot
@@ -145,7 +144,7 @@
             y_min = max(y_min, 0)
             x_max = min(frame.shape[1], x_max)
             y_max = min(frame.shape[0], y_max)
-            # Crop image
+            # Crop face loosely
             img = frame[y_min:y_max,x_min:x_max]
             img = Image.fromarray(img)
 
@@ -154,15 +153,13 @@
             img_shape = img.size()
             img = img.view(1, img_shape[0], img_shape[1], img_shape[2])
             img = Variable(img).cuda(gpu)
+            
             yaw, pitch, roll, angles = model(img)
 
-            yaw_predicted = F.softmax(yaw)
-            pitch_predicted = F.softmax(pitch)
-            roll_predicted = F.softmax(roll)
-            # Get continuous predictions in degrees.
-            yaw_predicted = torch.sum(yaw_predicted.data[0] * idx_tensor) * 3 - 99
-            pitch_predicted = torch.sum(pitch_predicted.data[0] * idx_tensor) * 3 - 99
-            roll_predicted = torch.sum(roll_predicted.data[0] * idx_tensor) * 3 - 99
+            yaw_predicted = angles[:,0].data[0].cpu()
+            pitch_predicted = angles[:,1].data[0].cpu()
+            roll_predicted = angles[:,2].data[0].cpu()
+
             # Print new frame with cube and axis
             txt_out.write(str(frame_num) + ' %f %f %f\n' % (yaw_predicted, pitch_predicted, roll_predicted))
             # utils.plot_pose_cube(frame, yaw_predicted, pitch_predicted, roll_predicted, (x_min + x_max) / 2, (y_min + y_max) / 2, size = bbox_width)

--
Gitblit v1.8.0