From ec44ac453f794a5368e702315addfedcea3a4299 Mon Sep 17 00:00:00 2001
From: natanielruiz <nataniel777@hotmail.com>
Date: 星期二, 19 九月 2017 06:01:47 +0800
Subject: [PATCH] Added continuous labels

---
 code/train.py |   43 +++++++++++++++++++++++++++----------------
 1 files changed, 27 insertions(+), 16 deletions(-)

diff --git a/code/train.py b/code/train.py
index a41edc0..2f0cce3 100644
--- a/code/train.py
+++ b/code/train.py
@@ -133,6 +133,8 @@
         pose_dataset = datasets.BIWI(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'AFLW':
         pose_dataset = datasets.AFLW(args.data_dir, args.filename_list, transformations)
+    elif args.dataset == 'AFLW_aug':
+        pose_dataset = datasets.AFLW_aug(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'AFW':
         pose_dataset = datasets.AFW(args.data_dir, args.filename_list, transformations)
     else:
@@ -162,11 +164,16 @@
 
     print 'First phase of training.'
     for epoch in range(num_epochs):
-        for i, (images, labels, name) in enumerate(train_loader):
+        for i, (images, labels, cont_labels, name) in enumerate(train_loader):
             images = Variable(images.cuda(gpu))
             label_yaw = Variable(labels[:,0].cuda(gpu))
             label_pitch = Variable(labels[:,1].cuda(gpu))
             label_roll = Variable(labels[:,2].cuda(gpu))
+
+            label_angles = Variable(cont_labels[:,:3].cuda(gpu))
+            label_yaw_cont = Variable(cont_labels[:,0].cuda(gpu))
+            label_pitch_cont = Variable(cont_labels[:,1].cuda(gpu))
+            label_roll_cont = Variable(cont_labels[:,2].cuda(gpu))
 
             optimizer.zero_grad()
             model.zero_grad()
@@ -183,13 +190,13 @@
             pitch_predicted = softmax(pre_pitch)
             roll_predicted = softmax(pre_roll)
 
-            yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1)
-            pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1)
-            roll_predicted = torch.sum(roll_predicted * idx_tensor, 1)
+            yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1) * 3 - 99
+            pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1) * 3 - 99
+            roll_predicted = torch.sum(roll_predicted * idx_tensor, 1) * 3 - 99
 
-            loss_reg_yaw = reg_criterion(yaw_predicted, label_yaw.float())
-            loss_reg_pitch = reg_criterion(pitch_predicted, label_pitch.float())
-            loss_reg_roll = reg_criterion(roll_predicted, label_roll.float())
+            loss_reg_yaw = reg_criterion(yaw_predicted, label_yaw_cont)
+            loss_reg_pitch = reg_criterion(pitch_predicted, label_pitch_cont)
+            loss_reg_roll = reg_criterion(roll_predicted, label_roll_cont)
 
             # Total loss
             loss_yaw += alpha * loss_reg_yaw
@@ -216,12 +223,16 @@
 
     print 'Second phase of training (finetuning layer).'
     for epoch in range(num_epochs_ft):
-        for i, (images, labels, name) in enumerate(train_loader):
+        for i, (images, labels, cont_labels, name) in enumerate(train_loader):
             images = Variable(images.cuda(gpu))
             label_yaw = Variable(labels[:,0].cuda(gpu))
             label_pitch = Variable(labels[:,1].cuda(gpu))
             label_roll = Variable(labels[:,2].cuda(gpu))
-            label_angles = Variable(labels[:,:3].cuda(gpu))
+
+            label_angles = Variable(cont_labels[:,:3].cuda(gpu))
+            label_yaw_cont = Variable(cont_labels[:,0].cuda(gpu))
+            label_pitch_cont = Variable(cont_labels[:,1].cuda(gpu))
+            label_roll_cont = Variable(cont_labels[:,2].cuda(gpu))
 
             optimizer.zero_grad()
             model.zero_grad()
@@ -238,13 +249,13 @@
             pitch_predicted = softmax(pre_pitch)
             roll_predicted = softmax(pre_roll)
 
-            yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1)
-            pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1)
-            roll_predicted = torch.sum(roll_predicted * idx_tensor, 1)
+            yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1) * 3 - 99
+            pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1) * 3 - 99
+            roll_predicted = torch.sum(roll_predicted * idx_tensor, 1) * 3 - 99
 
-            loss_reg_yaw = reg_criterion(yaw_predicted, label_yaw.float())
-            loss_reg_pitch = reg_criterion(pitch_predicted, label_pitch.float())
-            loss_reg_roll = reg_criterion(roll_predicted, label_roll.float())
+            loss_reg_yaw = reg_criterion(yaw_predicted, label_yaw_cont)
+            loss_reg_pitch = reg_criterion(pitch_predicted, label_pitch_cont)
+            loss_reg_roll = reg_criterion(roll_predicted, label_roll_cont)
 
             # Total loss
             loss_yaw += alpha * loss_reg_yaw
@@ -254,7 +265,7 @@
             # Finetuning loss
             loss_seq = [loss_yaw, loss_pitch, loss_roll]
             for idx in xrange(1,len(angles)):
-                label_angles_residuals = label_angles.float() - angles[0]
+                label_angles_residuals = label_angles - angles[0] * 3 - 99
                 label_angles_residuals = label_angles_residuals.detach()
                 loss_angles = reg_criterion(angles[idx], label_angles_residuals)
                 loss_seq.append(loss_angles)

--
Gitblit v1.8.0