From 2f6778c2db9ce1a887f04fdc85ad0d5db4ba84b8 Mon Sep 17 00:00:00 2001
From: natanielruiz <nruiz9@gatech.edu>
Date: 星期一, 30 十月 2017 06:15:30 +0800
Subject: [PATCH] Cleaned up a bit

---
 code/train_alexnet.py |   70 +++++++++++------------------------
 1 files changed, 22 insertions(+), 48 deletions(-)

diff --git a/code/train_alexnet.py b/code/train_alexnet.py
index 5f60211..9254ee7 100644
--- a/code/train_alexnet.py
+++ b/code/train_alexnet.py
@@ -1,4 +1,9 @@
+import sys, os, argparse, time
+
 import numpy as np
+import cv2
+import matplotlib.pyplot as plt
+
 import torch
 import torch.nn as nn
 from torch.autograd import Variable
@@ -8,17 +13,8 @@
 import torch.backends.cudnn as cudnn
 import torch.nn.functional as F
 
-import cv2
-import matplotlib.pyplot as plt
-import sys
-import os
-import argparse
-
-import datasets
-import hopenet
+import datasets, hopenet
 import torch.utils.model_zoo as model_zoo
-
-import time
 
 model_urls = {
     'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
@@ -43,16 +39,12 @@
     parser.add_argument('--alpha', dest='alpha', help='Regression loss coefficient.',
           default=0.001, type=float)
     parser.add_argument('--dataset', dest='dataset', help='Dataset type.', default='Pose_300W_LP', type=str)
-
     args = parser.parse_args()
     return args
 
 def get_ignored_params(model):
     # Generator function that yields ignored params.
-    b = []
-    b.append(model.features[0])
-    b.append(model.features[1])
-    b.append(model.features[2])
+    b = [model.features[0], model.features[1], model.features[2]]
     for i in range(len(b)):
         for module_name, module in b[i].named_modules():
             if 'bn' in module_name:
@@ -75,10 +67,7 @@
                 yield param
 
 def get_fc_params(model):
-    b = []
-    b.append(model.fc_yaw)
-    b.append(model.fc_pitch)
-    b.append(model.fc_roll)
+    b = [model.fc_yaw, model.fc_pitch, model.fc_roll]
     for i in range(len(b)):
         for module_name, module in b[i].named_modules():
             for name, param in module.named_parameters():
@@ -87,11 +76,8 @@
 def load_filtered_state_dict(model, snapshot):
     # By user apaszke from discuss.pytorch.org
     model_dict = model.state_dict()
-    # 1. filter out unnecessary keys
     snapshot = {k: v for k, v in snapshot.items() if k in model_dict}
-    # 2. overwrite entries in the existing state dict
     model_dict.update(snapshot)
-    # 3. load the new state dict
     model.load_state_dict(model_dict)
 
 if __name__ == '__main__':
@@ -116,6 +102,8 @@
 
     if args.dataset == 'Pose_300W_LP':
         pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations)
+    elif args.dataset == 'Pose_300W_LP_random_ds':
+        pose_dataset = datasets.Pose_300W_LP_random_ds(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'AFLW2000':
         pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'BIWI':
@@ -141,48 +129,38 @@
     # Regression loss coefficient
     alpha = args.alpha
 
-    idx_tensor = [idx for idx in xrange(66)]
-    idx_tensor = Variable(torch.FloatTensor(idx_tensor)).cuda(gpu)
-
     optimizer = torch.optim.Adam([{'params': get_ignored_params(model), 'lr': 0},
                                   {'params': get_non_ignored_params(model), 'lr': args.lr},
                                   {'params': get_fc_params(model), 'lr': args.lr * 5}],
                                    lr = args.lr)
 
     print 'Ready to train network.'
-    print 'First phase of training.'
     for epoch in range(num_epochs):
-        # start = time.time()
         for i, (images, labels, cont_labels, name) in enumerate(train_loader):
-            # print i
-            # print 'start: ', time.time() - start
             images = Variable(images).cuda(gpu)
+
+            # Binned labels
             label_yaw = Variable(labels[:,0]).cuda(gpu)
             label_pitch = Variable(labels[:,1]).cuda(gpu)
             label_roll = Variable(labels[:,2]).cuda(gpu)
 
-            label_angles = Variable(cont_labels[:,:3]).cuda(gpu)
+            # Continuous labels
             label_yaw_cont = Variable(cont_labels[:,0]).cuda(gpu)
             label_pitch_cont = Variable(cont_labels[:,1]).cuda(gpu)
             label_roll_cont = Variable(cont_labels[:,2]).cuda(gpu)
 
-            optimizer.zero_grad()
-            model.zero_grad()
+            # Forward pass
+            yaw, pitch, roll, angles = model(images)
 
-            pre_yaw, pre_pitch, pre_roll = model(images)
             # Cross entropy loss
-            loss_yaw = criterion(pre_yaw, label_yaw)
-            loss_pitch = criterion(pre_pitch, label_pitch)
-            loss_roll = criterion(pre_roll, label_roll)
+            loss_yaw = criterion(yaw, label_yaw)
+            loss_pitch = criterion(pitch, label_pitch)
+            loss_roll = criterion(roll, label_roll)
 
             # MSE loss
-            yaw_predicted = softmax(pre_yaw)
-            pitch_predicted = softmax(pre_pitch)
-            roll_predicted = softmax(pre_roll)
-
-            yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1) * 3 - 99
-            pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1) * 3 - 99
-            roll_predicted = torch.sum(roll_predicted * idx_tensor, 1) * 3 - 99
+            yaw_predicted = angles[:,0]
+            pitch_predicted = angles[:,1]
+            roll_predicted = angles[:,2]
 
             loss_reg_yaw = reg_criterion(yaw_predicted, label_yaw_cont)
             loss_reg_pitch = reg_criterion(pitch_predicted, label_pitch_cont)
@@ -195,17 +173,13 @@
 
             loss_seq = [loss_yaw, loss_pitch, loss_roll]
             grad_seq = [torch.Tensor(1).cuda(gpu) for _ in range(len(loss_seq))]
+            optimizer.zero_grad()
             torch.autograd.backward(loss_seq, grad_seq)
             optimizer.step()
-
-            # print 'end: ', time.time() - start
 
             if (i+1) % 100 == 0:
                 print ('Epoch [%d/%d], Iter [%d/%d] Losses: Yaw %.4f, Pitch %.4f, Roll %.4f'
                        %(epoch+1, num_epochs, i+1, len(pose_dataset)//batch_size, loss_yaw.data[0], loss_pitch.data[0], loss_roll.data[0]))
-                # if epoch == 0:
-                #     torch.save(model.state_dict(),
-                #     'output/snapshots/' + args.output_string + '_iter_'+ str(i+1) + '.pkl')
 
         # Save models at numbered epochs.
         if epoch % 1 == 0 and epoch < num_epochs:

--
Gitblit v1.8.0