From 4b67b5c8ed5566ec3030d537536282e830d87e40 Mon Sep 17 00:00:00 2001
From: natanielruiz <nruiz9@gatech.edu>
Date: 星期一, 30 十月 2017 07:15:49 +0800
Subject: [PATCH] next

---
 code/test_alexnet.py             |    9 ----
 code/test_resnet50_regression.py |    9 ----
 code/hopenet.py                  |   18 +--------
 code/test_on_video.py            |   15 +++++--
 code/test_on_video_noconf.py     |   14 ++++--
 code/train_preangles.py          |   14 +++++-
 code/test_preangles.py           |   12 -----
 7 files changed, 35 insertions(+), 56 deletions(-)

diff --git a/code/hopenet.py b/code/hopenet.py
index 0a98a66..c9e0b74 100644
--- a/code/hopenet.py
+++ b/code/hopenet.py
@@ -24,11 +24,8 @@
         self.fc_pitch = nn.Linear(512 * block.expansion, num_bins)
         self.fc_roll = nn.Linear(512 * block.expansion, num_bins)
 
+        # Vestigial layer from previous experiments
         self.fc_finetune = nn.Linear(512 * block.expansion + 3, 3)
-
-        # Used to get the expected value of angle from bins
-        self.softmax = nn.Softmax()
-        self.idx_tensor = Variable(torch.FloatTensor(range(66))).cuda()
 
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
@@ -72,18 +69,7 @@
         pre_pitch = self.fc_pitch(x)
         pre_roll = self.fc_roll(x)
 
-        yaw = self.softmax(pre_yaw)
-        yaw = Variable(torch.sum(yaw.data * self.idx_tensor.data, 1), requires_grad=True)
-        pitch = self.softmax(pre_pitch)
-        pitch = Variable(torch.sum(pitch.data * self.idx_tensor.data, 1), requires_grad=True)
-        roll = self.softmax(pre_roll)
-        roll = Variable(torch.sum(roll.data * self.idx_tensor.data, 1), requires_grad=True)
-        yaw = yaw.view(yaw.size(0), 1)
-        pitch = pitch.view(pitch.size(0), 1)
-        roll = roll.view(roll.size(0), 1)
-        preangles = torch.cat([yaw, pitch, roll], 1)
-
-        return pre_yaw, pre_pitch, pre_roll, preangles
+        return pre_yaw, pre_pitch, pre_roll
 
 class ResNet(nn.Module):
     # ResNet for regression of 3 Euler angles.
diff --git a/code/test_alexnet.py b/code/test_alexnet.py
index 529d566..81a9148 100644
--- a/code/test_alexnet.py
+++ b/code/test_alexnet.py
@@ -36,13 +36,6 @@
 
     return args
 
-def load_filtered_state_dict(model, snapshot):
-    # By user apaszke from discuss.pytorch.org
-    model_dict = model.state_dict()
-    snapshot = {k: v for k, v in snapshot.items() if k in model_dict}
-    model_dict.update(snapshot)
-    model.load_state_dict(model_dict)
-
 if __name__ == '__main__':
     args = parse_args()
 
@@ -55,7 +48,7 @@
     print 'Loading snapshot.'
     # Load snapshot
     saved_state_dict = torch.load(snapshot_path)
-    load_filtered_state_dict(model, saved_state_dict)
+    model.load_state_dict(saved_state_dict)
 
     print 'Loading data.'
 
diff --git a/code/test_on_video.py b/code/test_on_video.py
index c4172da..bbafbd8 100644
--- a/code/test_on_video.py
+++ b/code/test_on_video.py
@@ -47,7 +47,7 @@
     if not os.path.exists(args.video_path):
         sys.exit('Video does not exist')
 
-    # ResNet50
+    # ResNet50 structure
     model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66)
 
     print 'Loading snapshot.'
@@ -154,11 +154,16 @@
                 img = img.view(1, img_shape[0], img_shape[1], img_shape[2])
                 img = Variable(img).cuda(gpu)
 
-                yaw, pitch, roll, angles = model(img)
+                yaw, pitch, roll = model(img)
 
-                yaw_predicted = angles[:,0].data[0].cpu()
-                pitch_predicted = angles[:,1].data[0].cpu()
-                roll_predicted = angles[:,2].data[0].cpu()
+                yaw_predicted = F.softmax(yaw)
+                pitch_predicted = F.softmax(pitch)
+                roll_predicted = F.softmax(roll)
+                # Get continuous predictions in degrees.
+                yaw_predicted = torch.sum(yaw_predicted.data[0] * idx_tensor) * 3 - 99
+                pitch_predicted = torch.sum(pitch_predicted.data[0] * idx_tensor) * 3 - 99
+                roll_predicted = torch.sum(roll_predicted.data[0] * idx_tensor) * 3 - 99
+
                 # Print new frame with cube and axis
                 txt_out.write(str(frame_num) + ' %f %f %f\n' % (yaw_predicted, pitch_predicted, roll_predicted))
                 # utils.plot_pose_cube(frame, yaw_predicted, pitch_predicted, roll_predicted, (x_min + x_max) / 2, (y_min + y_max) / 2, size = bbox_width)
diff --git a/code/test_on_video_noconf.py b/code/test_on_video_noconf.py
index b6a8d2c..a76a396 100644
--- a/code/test_on_video_noconf.py
+++ b/code/test_on_video_noconf.py
@@ -153,12 +153,16 @@
             img_shape = img.size()
             img = img.view(1, img_shape[0], img_shape[1], img_shape[2])
             img = Variable(img).cuda(gpu)
-            
-            yaw, pitch, roll, angles = model(img)
 
-            yaw_predicted = angles[:,0].data[0].cpu()
-            pitch_predicted = angles[:,1].data[0].cpu()
-            roll_predicted = angles[:,2].data[0].cpu()
+            yaw, pitch, roll = model(img)
+
+            yaw_predicted = F.softmax(yaw)
+            pitch_predicted = F.softmax(pitch)
+            roll_predicted = F.softmax(roll)
+            # Get continuous predictions in degrees.
+            yaw_predicted = torch.sum(yaw_predicted.data[0] * idx_tensor) * 3 - 99
+            pitch_predicted = torch.sum(pitch_predicted.data[0] * idx_tensor) * 3 - 99
+            roll_predicted = torch.sum(roll_predicted.data[0] * idx_tensor) * 3 - 99
 
             # Print new frame with cube and axis
             txt_out.write(str(frame_num) + ' %f %f %f\n' % (yaw_predicted, pitch_predicted, roll_predicted))
diff --git a/code/test_preangles.py b/code/test_preangles.py
index 9cdc8e3..d4a9f5f 100644
--- a/code/test_preangles.py
+++ b/code/test_preangles.py
@@ -36,13 +36,6 @@
 
     return args
 
-def load_filtered_state_dict(model, snapshot):
-    # By user apaszke from discuss.pytorch.org
-    model_dict = model.state_dict()
-    snapshot = {k: v for k, v in snapshot.items() if k in model_dict}
-    model_dict.update(snapshot)
-    model.load_state_dict(model_dict)
-
 if __name__ == '__main__':
     args = parse_args()
 
@@ -57,7 +50,6 @@
     # Load snapshot
     saved_state_dict = torch.load(snapshot_path)
     model.load_state_dict(saved_state_dict)
-    # load_filtered_state_dict(model, saved_state_dict)
 
     print 'Loading data.'
 
@@ -105,8 +97,6 @@
 
     l1loss = torch.nn.L1Loss(size_average=False)
 
-
-
     for i, (images, labels, cont_labels, name) in enumerate(test_loader):
         images = Variable(images).cuda(gpu)
         total += cont_labels.size(0)
@@ -115,7 +105,7 @@
         label_pitch = cont_labels[:,1].float()
         label_roll = cont_labels[:,2].float()
 
-        yaw, pitch, roll, angles = model(images)
+        yaw, pitch, roll = model(images)
 
         # Binned predictions
         _, yaw_bpred = torch.max(yaw.data, 1)
diff --git a/code/test_resnet50_regression.py b/code/test_resnet50_regression.py
index 6945269..67c63af 100644
--- a/code/test_resnet50_regression.py
+++ b/code/test_resnet50_regression.py
@@ -36,13 +36,6 @@
 
     return args
 
-def load_filtered_state_dict(model, snapshot):
-    # By user apaszke from discuss.pytorch.org
-    model_dict = model.state_dict()
-    snapshot = {k: v for k, v in snapshot.items() if k in model_dict}
-    model_dict.update(snapshot)
-    model.load_state_dict(model_dict)
-
 if __name__ == '__main__':
     args = parse_args()
 
@@ -55,7 +48,7 @@
     print 'Loading snapshot.'
     # Load snapshot
     saved_state_dict = torch.load(snapshot_path)
-    load_filtered_state_dict(model, saved_state_dict)
+    model.load_state_dict(saved_state_dict)
 
     print 'Loading data.'
 
diff --git a/code/train_preangles.py b/code/train_preangles.py
index 1fe626c..600a9ae 100644
--- a/code/train_preangles.py
+++ b/code/train_preangles.py
@@ -124,6 +124,10 @@
     # Regression loss coefficient
     alpha = args.alpha
 
+    softmax = nn.Softmax().cuda(gpu)
+    idx_tensor = [idx for idx in xrange(66)]
+    idx_tensor = Variable(torch.FloatTensor(idx_tensor)).cuda(gpu)
+
     optimizer = torch.optim.Adam([{'params': get_ignored_params(model), 'lr': 0},
                                   {'params': get_non_ignored_params(model), 'lr': args.lr},
                                   {'params': get_fc_params(model), 'lr': args.lr * 5}],
@@ -153,9 +157,13 @@
             loss_roll = criterion(roll, label_roll)
 
             # MSE loss
-            yaw_predicted = angles[:,0]
-            pitch_predicted = angles[:,1]
-            roll_predicted = angles[:,2]
+            yaw_predicted = softmax(yaw)
+            pitch_predicted = softmax(pitch)
+            roll_predicted = softmax(roll)
+
+            yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1) * 3 - 99
+            pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1) * 3 - 99
+            roll_predicted = torch.sum(roll_predicted * idx_tensor, 1) * 3 - 99
 
             loss_reg_yaw = reg_criterion(yaw_predicted, label_yaw_cont)
             loss_reg_pitch = reg_criterion(pitch_predicted, label_pitch_cont)

--
Gitblit v1.8.0