From af51d0ecb51ad4d6c8ed086855bd3c411ebc4aa0 Mon Sep 17 00:00:00 2001
From: natanielruiz <nruiz9@gatech.edu>
Date: 星期一, 30 十月 2017 06:29:51 +0800
Subject: [PATCH] Fixed stuff

---
 code/test_alexnet.py             |   25 ++++++++----
 code/test_resnet50_regression.py |   46 +++++++++++++----------
 code/train_alexnet.py            |   22 +++++++----
 code/test_preangles.py           |   11 +++++
 4 files changed, 67 insertions(+), 37 deletions(-)

diff --git a/code/test_alexnet.py b/code/test_alexnet.py
index d9cc0a3..7a3989a 100644
--- a/code/test_alexnet.py
+++ b/code/test_alexnet.py
@@ -39,6 +39,13 @@
 
     return args
 
+def load_filtered_state_dict(model, snapshot):
+    # By user apaszke from discuss.pytorch.org
+    model_dict = model.state_dict()
+    snapshot = {k: v for k, v in snapshot.items() if k in model_dict}
+    model_dict.update(snapshot)
+    model.load_state_dict(model_dict)
+
 if __name__ == '__main__':
     args = parse_args()
 
@@ -51,7 +58,7 @@
     print 'Loading snapshot.'
     # Load snapshot
     saved_state_dict = torch.load(snapshot_path)
-    model.load_state_dict(saved_state_dict)
+    load_filtered_state_dict(model, saved_state_dict)
 
     print 'Loading data.'
 
@@ -59,18 +66,20 @@
     transforms.CenterCrop(224), transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
 
-    if args.dataset == 'AFLW2000':
-        pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list,
-                                transformations)
+    if args.dataset == 'Pose_300W_LP':
+        pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations)
+    elif args.dataset == 'Pose_300W_LP_random_ds':
+        pose_dataset = datasets.Pose_300W_LP_random_ds(args.data_dir, args.filename_list, transformations)
+    elif args.dataset == 'AFLW2000':
+        pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'AFLW2000_ds':
-        pose_dataset = datasets.AFLW2000_ds(args.data_dir, args.filename_list,
-                                transformations)
+        pose_dataset = datasets.AFLW2000_ds(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'BIWI':
         pose_dataset = datasets.BIWI(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'AFLW':
         pose_dataset = datasets.AFLW(args.data_dir, args.filename_list, transformations)
-    elif args.dataset == 'Pose_300W_LP':
-        pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations)
+    elif args.dataset == 'AFLW_aug':
+        pose_dataset = datasets.AFLW_aug(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'AFW':
         pose_dataset = datasets.AFW(args.data_dir, args.filename_list, transformations)
     else:
diff --git a/code/test_preangles.py b/code/test_preangles.py
index be9bfda..05f621a 100644
--- a/code/test_preangles.py
+++ b/code/test_preangles.py
@@ -36,6 +36,13 @@
 
     return args
 
+def load_filtered_state_dict(model, snapshot):
+    # By user apaszke from discuss.pytorch.org
+    model_dict = model.state_dict()
+    snapshot = {k: v for k, v in snapshot.items() if k in model_dict}
+    model_dict.update(snapshot)
+    model.load_state_dict(model_dict)
+
 if __name__ == '__main__':
     args = parse_args()
 
@@ -49,7 +56,7 @@
     print 'Loading snapshot.'
     # Load snapshot
     saved_state_dict = torch.load(snapshot_path)
-    model.load_state_dict(saved_state_dict)
+    load_filtered_state_dict(model, saved_state_dict)
 
     print 'Loading data.'
 
@@ -63,6 +70,8 @@
         pose_dataset = datasets.Pose_300W_LP_random_ds(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'AFLW2000':
         pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list, transformations)
+    elif args.dataset == 'AFLW2000_ds':
+        pose_dataset = datasets.AFLW2000_ds(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'BIWI':
         pose_dataset = datasets.BIWI(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'AFLW':
diff --git a/code/test_resnet50_regression.py b/code/test_resnet50_regression.py
index 85207f8..6945269 100644
--- a/code/test_resnet50_regression.py
+++ b/code/test_resnet50_regression.py
@@ -1,4 +1,9 @@
+import sys, os, argparse
+
 import numpy as np
+import cv2
+import matplotlib.pyplot as plt
+
 import torch
 import torch.nn as nn
 from torch.autograd import Variable
@@ -8,15 +13,7 @@
 import torchvision
 import torch.nn.functional as F
 
-import cv2
-import matplotlib.pyplot as plt
-import sys
-import os
-import argparse
-
-import datasets
-import hopenet
-import utils
+import datasets, hopenet, utils
 
 def parse_args():
     """Parse input arguments."""
@@ -39,6 +36,13 @@
 
     return args
 
+def load_filtered_state_dict(model, snapshot):
+    # By user apaszke from discuss.pytorch.org
+    model_dict = model.state_dict()
+    snapshot = {k: v for k, v in snapshot.items() if k in model_dict}
+    model_dict.update(snapshot)
+    model.load_state_dict(model_dict)
+
 if __name__ == '__main__':
     args = parse_args()
 
@@ -51,7 +55,7 @@
     print 'Loading snapshot.'
     # Load snapshot
     saved_state_dict = torch.load(snapshot_path)
-    model.load_state_dict(saved_state_dict)
+    load_filtered_state_dict(model, saved_state_dict)
 
     print 'Loading data.'
 
@@ -59,18 +63,20 @@
     transforms.CenterCrop(224), transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
 
-    if args.dataset == 'AFLW2000':
-        pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list,
-                                transformations)
+    if args.dataset == 'Pose_300W_LP':
+        pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations)
+    elif args.dataset == 'Pose_300W_LP_random_ds':
+        pose_dataset = datasets.Pose_300W_LP_random_ds(args.data_dir, args.filename_list, transformations)
+    elif args.dataset == 'AFLW2000':
+        pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'AFLW2000_ds':
-        pose_dataset = datasets.AFLW2000_ds(args.data_dir, args.filename_list,
-                                transformations)
+        pose_dataset = datasets.AFLW2000_ds(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'BIWI':
         pose_dataset = datasets.BIWI(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'AFLW':
         pose_dataset = datasets.AFLW(args.data_dir, args.filename_list, transformations)
-    elif args.dataset == 'Pose_300W_LP':
-        pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations)
+    elif args.dataset == 'AFLW_aug':
+        pose_dataset = datasets.AFLW_aug(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'AFW':
         pose_dataset = datasets.AFW(args.data_dir, args.filename_list, transformations)
     else:
@@ -111,8 +117,7 @@
         pitch_error += torch.sum(torch.abs(pitch_predicted - label_pitch))
         roll_error += torch.sum(torch.abs(roll_predicted - label_roll))
 
-        # Save images with pose cube.
-        # TODO: fix for larger batch size
+        # Save first image in batch with pose cube or axis.
         if args.save_viz:
             name = name[0]
             if args.dataset == 'BIWI':
@@ -122,7 +127,8 @@
             if args.batch_size == 1:
                 error_string = 'y %.2f, p %.2f, r %.2f' % (torch.sum(torch.abs(yaw_predicted - label_yaw)), torch.sum(torch.abs(pitch_predicted - label_pitch)), torch.sum(torch.abs(roll_predicted - label_roll)))
                 cv2.putText(cv2_img, error_string, (30, cv2_img.shape[0]- 30), fontFace=1, fontScale=1, color=(0,0,255), thickness=1)
-            utils.plot_pose_cube(cv2_img, yaw_predicted[0], pitch_predicted[0], roll_predicted[0])
+            # utils.plot_pose_cube(cv2_img, yaw_predicted[0], pitch_predicted[0], roll_predicted[0], size=100)
+            utils.draw_axis(cv2_img, yaw_predicted[0], pitch_predicted[0], roll_predicted[0], tdx = 200, tdy= 200, size=100)
             cv2.imwrite(os.path.join('output/images', name + '.jpg'), cv2_img)
 
     print('Test error in degrees of the model on the ' + str(total) +
diff --git a/code/train_alexnet.py b/code/train_alexnet.py
index 9254ee7..51cf43b 100644
--- a/code/train_alexnet.py
+++ b/code/train_alexnet.py
@@ -129,6 +129,9 @@
     # Regression loss coefficient
     alpha = args.alpha
 
+    idx_tensor = [idx for idx in xrange(66)]
+    idx_tensor = Variable(torch.FloatTensor(idx_tensor)).cuda(gpu)
+
     optimizer = torch.optim.Adam([{'params': get_ignored_params(model), 'lr': 0},
                                   {'params': get_non_ignored_params(model), 'lr': args.lr},
                                   {'params': get_fc_params(model), 'lr': args.lr * 5}],
@@ -150,17 +153,21 @@
             label_roll_cont = Variable(cont_labels[:,2]).cuda(gpu)
 
             # Forward pass
-            yaw, pitch, roll, angles = model(images)
+            pre_yaw, pre_pitch, pre_roll = model(images)
 
             # Cross entropy loss
-            loss_yaw = criterion(yaw, label_yaw)
-            loss_pitch = criterion(pitch, label_pitch)
-            loss_roll = criterion(roll, label_roll)
+            loss_yaw = criterion(pre_yaw, label_yaw)
+            loss_pitch = criterion(pre_pitch, label_pitch)
+            loss_roll = criterion(pre_roll, label_roll)
 
             # MSE loss
-            yaw_predicted = angles[:,0]
-            pitch_predicted = angles[:,1]
-            roll_predicted = angles[:,2]
+            yaw_predicted = softmax(pre_yaw)
+            pitch_predicted = softmax(pre_pitch)
+            roll_predicted = softmax(pre_roll)
+
+            yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1) * 3 - 99
+            pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1) * 3 - 99
+            roll_predicted = torch.sum(roll_predicted * idx_tensor, 1) * 3 - 99
 
             loss_reg_yaw = reg_criterion(yaw_predicted, label_yaw_cont)
             loss_reg_pitch = reg_criterion(pitch_predicted, label_pitch_cont)
@@ -173,7 +180,6 @@
 
             loss_seq = [loss_yaw, loss_pitch, loss_roll]
             grad_seq = [torch.Tensor(1).cuda(gpu) for _ in range(len(loss_seq))]
-            optimizer.zero_grad()
             torch.autograd.backward(loss_seq, grad_seq)
             optimizer.step()
 

--
Gitblit v1.8.0