From 0b8e19c1cc8ad03805d4ca68f32df6e4806a36e8 Mon Sep 17 00:00:00 2001
From: natanielruiz <nataniel777@hotmail.com>
Date: 星期五, 08 九月 2017 11:15:10 +0800
Subject: [PATCH] Finetune layer working
---
code/train.py | 95 ++++++++--
code/datasets.py | 16
code/hopenet.py | 36 +++
code/test_old.py | 149 ++++++++++++++++
code/test.py | 38 ---
code/test_preangles.py | 149 ++++++++++++++++
6 files changed, 420 insertions(+), 63 deletions(-)
diff --git a/code/datasets.py b/code/datasets.py
index f73c0a1..f24f063 100644
--- a/code/datasets.py
+++ b/code/datasets.py
@@ -60,14 +60,14 @@
img = img.transpose(Image.FLIP_LEFT_RIGHT)
# Rotate?
- rnd = np.random.random_sample()
- if rnd < 0.5:
- if roll >= 0:
- img = img.rotate(30)
- roll -= 30
- else:
- img = img.rotate(-30)
- roll += 30
+ # rnd = np.random.random_sample()
+ # if rnd < 0.5:
+ # if roll >= 0:
+ # img = img.rotate(30)
+ # roll -= 30
+ # else:
+ # img = img.rotate(-30)
+ # roll += 30
# Bin values
bins = np.array(range(-99, 102, 3))
diff --git a/code/hopenet.py b/code/hopenet.py
index 1b94fa1..274044f 100644
--- a/code/hopenet.py
+++ b/code/hopenet.py
@@ -1,8 +1,8 @@
import torch
import torch.nn as nn
-import torchvision.datasets as dsets
from torch.autograd import Variable
import math
+import torch.nn.functional as F
# CNN Model (2 conv layer)
class Simple_CNN(nn.Module):
@@ -58,6 +58,11 @@
self.fc_pitch = nn.Linear(512 * block.expansion, num_bins)
self.fc_roll = nn.Linear(512 * block.expansion, num_bins)
+ self.softmax = nn.Softmax()
+ self.fc_finetune = nn.Linear(512 * block.expansion + 3, 3)
+
+ self.idx_tensor = Variable(torch.FloatTensor(range(66))).cuda()
+
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
@@ -83,6 +88,12 @@
return nn.Sequential(*layers)
+ def get_expectation(angle):
+ angle_pred = F.softmax(angle)
+
+ angle_pred = torch.sum(angle_pred.data * self.idx_tensor, 1)
+ return angle_pred
+
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
@@ -96,11 +107,26 @@
x = self.avgpool(x)
x = x.view(x.size(0), -1)
- yaw = self.fc_yaw(x)
- pitch = self.fc_pitch(x)
- roll = self.fc_roll(x)
+ pre_yaw = self.fc_yaw(x)
+ pre_pitch = self.fc_pitch(x)
+ pre_roll = self.fc_roll(x)
- return yaw, pitch, roll
+ yaw = self.softmax(pre_yaw)
+ yaw = Variable(torch.sum(yaw.data * self.idx_tensor.data, 1), requires_grad=True)
+ pitch = self.softmax(pre_pitch)
+ pitch = Variable(torch.sum(pitch.data * self.idx_tensor.data, 1), requires_grad=True)
+ roll = self.softmax(pre_roll)
+ roll = Variable(torch.sum(roll.data * self.idx_tensor.data, 1), requires_grad=True)
+ yaw = yaw.view(yaw.size(0), 1)
+ pitch = pitch.view(pitch.size(0), 1)
+ roll = roll.view(roll.size(0), 1)
+ angles = []
+ angles.append(torch.cat([yaw, pitch, roll], 1))
+
+ for idx in xrange(1):
+ angles.append(self.fc_finetune(torch.cat((angles[-1], x), 1)))
+
+ return pre_yaw, pre_pitch, pre_roll, angles
class Hopenet_shape(nn.Module):
# This is just Hopenet with 3 output layers for yaw, pitch and roll.
diff --git a/code/test.py b/code/test.py
index b9be11e..8e8fe50 100644
--- a/code/test.py
+++ b/code/test.py
@@ -100,44 +100,22 @@
label_pitch = labels[:,1].float()
label_roll = labels[:,2].float()
- yaw, pitch, roll = model(images)
-
- # Binned predictions
- _, yaw_bpred = torch.max(yaw.data, 1)
- _, pitch_bpred = torch.max(pitch.data, 1)
- _, roll_bpred = torch.max(roll.data, 1)
-
- # Continuous predictions
- yaw_predicted = utils.softmax_temperature(yaw.data, 1)
- pitch_predicted = utils.softmax_temperature(pitch.data, 1)
- roll_predicted = utils.softmax_temperature(roll.data, 1)
-
- yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1).cpu()
- pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1).cpu()
- roll_predicted = torch.sum(roll_predicted * idx_tensor, 1).cpu()
+ pre_yaw, pre_pitch, pre_roll, angles = model(images)
+ yaw = angles[:,0].cpu().data
+ pitch = angles[:,1].cpu().data
+ roll = angles[:,2].cpu().data
# Mean absolute error
- yaw_error += torch.sum(torch.abs(yaw_predicted - label_yaw) * 3)
- pitch_error += torch.sum(torch.abs(pitch_predicted - label_pitch) * 3)
- roll_error += torch.sum(torch.abs(roll_predicted - label_roll) * 3)
-
- # Binned Accuracy
- # for er in xrange(n_margins):
- # yaw_bpred[er] += (label_yaw[0] in range(yaw_bpred[0,0] - er, yaw_bpred[0,0] + er + 1))
- # pitch_bpred[er] += (label_pitch[0] in range(pitch_bpred[0,0] - er, pitch_bpred[0,0] + er + 1))
- # roll_bpred[er] += (label_roll[0] in range(roll_bpred[0,0] - er, roll_bpred[0,0] + er + 1))
-
- # print label_yaw[0], yaw_bpred[0,0]
+ yaw_error += torch.sum(torch.abs(yaw - label_yaw) * 3)
+ pitch_error += torch.sum(torch.abs(pitch - label_pitch) * 3)
+ roll_error += torch.sum(torch.abs(roll - label_roll) * 3)
# Save images with pose cube.
# TODO: fix for larger batch size
if args.save_viz:
name = name[0]
cv2_img = cv2.imread(os.path.join(args.data_dir, name + '.jpg'))
- #print os.path.join('output/images', name + '.jpg')
- #print label_yaw[0] * 3 - 99, label_pitch[0] * 3 - 99, label_roll[0] * 3 - 99
- #print yaw_predicted * 3 - 99, pitch_predicted * 3 - 99, roll_predicted * 3 - 99
- utils.plot_pose_cube(cv2_img, yaw_predicted[0] * 3 - 99, pitch_predicted[0] * 3 - 99, roll_predicted[0] * 3 - 99)
+ utils.plot_pose_cube(cv2_img, yaw[0] * 3 - 99, pitch[0] * 3 - 99, roll[0] * 3 - 99)
cv2.imwrite(os.path.join('output/images', name + '.jpg'), cv2_img)
print('Test error in degrees of the model on the ' + str(total) +
diff --git a/code/test_old.py b/code/test_old.py
new file mode 100644
index 0000000..b9be11e
--- /dev/null
+++ b/code/test_old.py
@@ -0,0 +1,149 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+from torch.utils.data import DataLoader
+from torchvision import transforms
+import torch.backends.cudnn as cudnn
+import torchvision
+import torch.nn.functional as F
+
+import cv2
+import matplotlib.pyplot as plt
+import sys
+import os
+import argparse
+
+import datasets
+import hopenet
+import utils
+
+def parse_args():
+ """Parse input arguments."""
+ parser = argparse.ArgumentParser(description='Head pose estimation using the Hopenet network.')
+ parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]',
+ default=0, type=int)
+ parser.add_argument('--data_dir', dest='data_dir', help='Directory path for data.',
+ default='', type=str)
+ parser.add_argument('--filename_list', dest='filename_list', help='Path to text file containing relative paths for every example.',
+ default='', type=str)
+ parser.add_argument('--snapshot', dest='snapshot', help='Name of model snapshot.',
+ default='', type=str)
+ parser.add_argument('--batch_size', dest='batch_size', help='Batch size.',
+ default=1, type=int)
+ parser.add_argument('--save_viz', dest='save_viz', help='Save images with pose cube.',
+ default=False, type=bool)
+
+ args = parser.parse_args()
+
+ return args
+
+if __name__ == '__main__':
+ args = parse_args()
+
+ cudnn.enabled = True
+ gpu = args.gpu_id
+ snapshot_path = os.path.join('output/snapshots', args.snapshot + '.pkl')
+
+ # ResNet101 with 3 outputs.
+ # model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], 66)
+ # ResNet50
+ model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66)
+ # ResNet18
+ # model = hopenet.Hopenet(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], 66)
+
+ print 'Loading snapshot.'
+ # Load snapshot
+ saved_state_dict = torch.load(snapshot_path)
+ model.load_state_dict(saved_state_dict)
+
+ print 'Loading data.'
+
+ # transformations = transforms.Compose([transforms.Scale(224),
+ # transforms.RandomCrop(224), transforms.ToTensor()])
+
+ transformations = transforms.Compose([transforms.Scale(224),
+ transforms.RandomCrop(224), transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
+
+ pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list,
+ transformations)
+ test_loader = torch.utils.data.DataLoader(dataset=pose_dataset,
+ batch_size=args.batch_size,
+ num_workers=2)
+
+ model.cuda(gpu)
+
+ print 'Ready to test network.'
+
+ # Test the Model
+ model.eval() # Change model to 'eval' mode (BN uses moving mean/var).
+ total = 0
+ n_margins = 20
+ yaw_correct = np.zeros(n_margins)
+ pitch_correct = np.zeros(n_margins)
+ roll_correct = np.zeros(n_margins)
+
+ idx_tensor = [idx for idx in xrange(66)]
+ idx_tensor = torch.FloatTensor(idx_tensor).cuda(gpu)
+
+ yaw_error = .0
+ pitch_error = .0
+ roll_error = .0
+
+ l1loss = torch.nn.L1Loss(size_average=False)
+
+ for i, (images, labels, name) in enumerate(test_loader):
+ images = Variable(images).cuda(gpu)
+ total += labels.size(0)
+ label_yaw = labels[:,0].float()
+ label_pitch = labels[:,1].float()
+ label_roll = labels[:,2].float()
+
+ yaw, pitch, roll = model(images)
+
+ # Binned predictions
+ _, yaw_bpred = torch.max(yaw.data, 1)
+ _, pitch_bpred = torch.max(pitch.data, 1)
+ _, roll_bpred = torch.max(roll.data, 1)
+
+ # Continuous predictions
+ yaw_predicted = utils.softmax_temperature(yaw.data, 1)
+ pitch_predicted = utils.softmax_temperature(pitch.data, 1)
+ roll_predicted = utils.softmax_temperature(roll.data, 1)
+
+ yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1).cpu()
+ pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1).cpu()
+ roll_predicted = torch.sum(roll_predicted * idx_tensor, 1).cpu()
+
+ # Mean absolute error
+ yaw_error += torch.sum(torch.abs(yaw_predicted - label_yaw) * 3)
+ pitch_error += torch.sum(torch.abs(pitch_predicted - label_pitch) * 3)
+ roll_error += torch.sum(torch.abs(roll_predicted - label_roll) * 3)
+
+ # Binned Accuracy
+ # for er in xrange(n_margins):
+ # yaw_bpred[er] += (label_yaw[0] in range(yaw_bpred[0,0] - er, yaw_bpred[0,0] + er + 1))
+ # pitch_bpred[er] += (label_pitch[0] in range(pitch_bpred[0,0] - er, pitch_bpred[0,0] + er + 1))
+ # roll_bpred[er] += (label_roll[0] in range(roll_bpred[0,0] - er, roll_bpred[0,0] + er + 1))
+
+ # print label_yaw[0], yaw_bpred[0,0]
+
+ # Save images with pose cube.
+ # TODO: fix for larger batch size
+ if args.save_viz:
+ name = name[0]
+ cv2_img = cv2.imread(os.path.join(args.data_dir, name + '.jpg'))
+ #print os.path.join('output/images', name + '.jpg')
+ #print label_yaw[0] * 3 - 99, label_pitch[0] * 3 - 99, label_roll[0] * 3 - 99
+ #print yaw_predicted * 3 - 99, pitch_predicted * 3 - 99, roll_predicted * 3 - 99
+ utils.plot_pose_cube(cv2_img, yaw_predicted[0] * 3 - 99, pitch_predicted[0] * 3 - 99, roll_predicted[0] * 3 - 99)
+ cv2.imwrite(os.path.join('output/images', name + '.jpg'), cv2_img)
+
+ print('Test error in degrees of the model on the ' + str(total) +
+ ' test images. Yaw: %.4f, Pitch: %.4f, Roll: %.4f' % (yaw_error / total,
+ pitch_error / total, roll_error / total))
+
+ # Binned accuracy
+ # for idx in xrange(len(yaw_correct)):
+ # print yaw_correct[idx] / total, pitch_correct[idx] / total, roll_correct[idx] / total
diff --git a/code/test_preangles.py b/code/test_preangles.py
new file mode 100644
index 0000000..67e4744
--- /dev/null
+++ b/code/test_preangles.py
@@ -0,0 +1,149 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+from torch.utils.data import DataLoader
+from torchvision import transforms
+import torch.backends.cudnn as cudnn
+import torchvision
+import torch.nn.functional as F
+
+import cv2
+import matplotlib.pyplot as plt
+import sys
+import os
+import argparse
+
+import datasets
+import hopenet
+import utils
+
+def parse_args():
+ """Parse input arguments."""
+ parser = argparse.ArgumentParser(description='Head pose estimation using the Hopenet network.')
+ parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]',
+ default=0, type=int)
+ parser.add_argument('--data_dir', dest='data_dir', help='Directory path for data.',
+ default='', type=str)
+ parser.add_argument('--filename_list', dest='filename_list', help='Path to text file containing relative paths for every example.',
+ default='', type=str)
+ parser.add_argument('--snapshot', dest='snapshot', help='Name of model snapshot.',
+ default='', type=str)
+ parser.add_argument('--batch_size', dest='batch_size', help='Batch size.',
+ default=1, type=int)
+ parser.add_argument('--save_viz', dest='save_viz', help='Save images with pose cube.',
+ default=False, type=bool)
+
+ args = parser.parse_args()
+
+ return args
+
+if __name__ == '__main__':
+ args = parse_args()
+
+ cudnn.enabled = True
+ gpu = args.gpu_id
+ snapshot_path = os.path.join('output/snapshots', args.snapshot + '.pkl')
+
+ # ResNet101 with 3 outputs.
+ # model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], 66)
+ # ResNet50
+ model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66)
+ # ResNet18
+ # model = hopenet.Hopenet(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], 66)
+
+ print 'Loading snapshot.'
+ # Load snapshot
+ saved_state_dict = torch.load(snapshot_path)
+ model.load_state_dict(saved_state_dict)
+
+ print 'Loading data.'
+
+ # transformations = transforms.Compose([transforms.Scale(224),
+ # transforms.RandomCrop(224), transforms.ToTensor()])
+
+ transformations = transforms.Compose([transforms.Scale(224),
+ transforms.RandomCrop(224), transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
+
+ pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list,
+ transformations)
+ test_loader = torch.utils.data.DataLoader(dataset=pose_dataset,
+ batch_size=args.batch_size,
+ num_workers=2)
+
+ model.cuda(gpu)
+
+ print 'Ready to test network.'
+
+ # Test the Model
+ model.eval() # Change model to 'eval' mode (BN uses moving mean/var).
+ total = 0
+ n_margins = 20
+ yaw_correct = np.zeros(n_margins)
+ pitch_correct = np.zeros(n_margins)
+ roll_correct = np.zeros(n_margins)
+
+ idx_tensor = [idx for idx in xrange(66)]
+ idx_tensor = torch.FloatTensor(idx_tensor).cuda(gpu)
+
+ yaw_error = .0
+ pitch_error = .0
+ roll_error = .0
+
+ l1loss = torch.nn.L1Loss(size_average=False)
+
+ for i, (images, labels, name) in enumerate(test_loader):
+ images = Variable(images).cuda(gpu)
+ total += labels.size(0)
+ label_yaw = labels[:,0].float()
+ label_pitch = labels[:,1].float()
+ label_roll = labels[:,2].float()
+
+ yaw, pitch, roll, angles = model(images)
+
+ # Binned predictions
+ _, yaw_bpred = torch.max(yaw.data, 1)
+ _, pitch_bpred = torch.max(pitch.data, 1)
+ _, roll_bpred = torch.max(roll.data, 1)
+
+ # Continuous predictions
+ yaw_predicted = utils.softmax_temperature(yaw.data, 1)
+ pitch_predicted = utils.softmax_temperature(pitch.data, 1)
+ roll_predicted = utils.softmax_temperature(roll.data, 1)
+
+ yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1).cpu()
+ pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1).cpu()
+ roll_predicted = torch.sum(roll_predicted * idx_tensor, 1).cpu()
+
+ # Mean absolute error
+ yaw_error += torch.sum(torch.abs(yaw_predicted - label_yaw) * 3)
+ pitch_error += torch.sum(torch.abs(pitch_predicted - label_pitch) * 3)
+ roll_error += torch.sum(torch.abs(roll_predicted - label_roll) * 3)
+
+ # Binned Accuracy
+ # for er in xrange(n_margins):
+ # yaw_bpred[er] += (label_yaw[0] in range(yaw_bpred[0,0] - er, yaw_bpred[0,0] + er + 1))
+ # pitch_bpred[er] += (label_pitch[0] in range(pitch_bpred[0,0] - er, pitch_bpred[0,0] + er + 1))
+ # roll_bpred[er] += (label_roll[0] in range(roll_bpred[0,0] - er, roll_bpred[0,0] + er + 1))
+
+ # print label_yaw[0], yaw_bpred[0,0]
+
+ # Save images with pose cube.
+ # TODO: fix for larger batch size
+ if args.save_viz:
+ name = name[0]
+ cv2_img = cv2.imread(os.path.join(args.data_dir, name + '.jpg'))
+ #print os.path.join('output/images', name + '.jpg')
+ #print label_yaw[0] * 3 - 99, label_pitch[0] * 3 - 99, label_roll[0] * 3 - 99
+ #print yaw_predicted * 3 - 99, pitch_predicted * 3 - 99, roll_predicted * 3 - 99
+ utils.plot_pose_cube(cv2_img, yaw_predicted[0] * 3 - 99, pitch_predicted[0] * 3 - 99, roll_predicted[0] * 3 - 99)
+ cv2.imwrite(os.path.join('output/images', name + '.jpg'), cv2_img)
+
+ print('Test error in degrees of the model on the ' + str(total) +
+ ' test images. Yaw: %.4f, Pitch: %.4f, Roll: %.4f' % (yaw_error / total,
+ pitch_error / total, roll_error / total))
+
+ # Binned accuracy
+ # for idx in xrange(len(yaw_correct)):
+ # print yaw_correct[idx] / total, pitch_correct[idx] / total, roll_correct[idx] / total
diff --git a/code/train.py b/code/train.py
index 5d7fc7d..826793d 100644
--- a/code/train.py
+++ b/code/train.py
@@ -33,6 +33,8 @@
default=0, type=int)
parser.add_argument('--num_epochs', dest='num_epochs', help='Maximum number of training epochs.',
default=5, type=int)
+ parser.add_argument('--num_epochs_ft', dest='num_epochs_ft', help='Maximum number of finetuning epochs.',
+ default=5, type=int)
parser.add_argument('--batch_size', dest='batch_size', help='Batch size.',
default=16, type=int)
parser.add_argument('--lr', dest='lr', help='Base learning rate.',
@@ -41,9 +43,7 @@
default='', type=str)
parser.add_argument('--filename_list', dest='filename_list', help='Path to text file containing relative paths for every example.',
default='', type=str)
-
args = parser.parse_args()
-
return args
def get_ignored_params(model):
@@ -66,6 +66,7 @@
b.append(model.fc_yaw)
b.append(model.fc_pitch)
b.append(model.fc_roll)
+ b.append(model.fc_finetune)
for i in range(len(b)):
for j in b[i].modules():
for k in j.parameters():
@@ -86,6 +87,7 @@
cudnn.enabled = True
num_epochs = args.num_epochs
+ num_epochs_ft = args.num_epochs_ft
batch_size = args.batch_size
gpu = args.gpu_id
@@ -129,13 +131,10 @@
optimizer = torch.optim.Adam([{'params': get_ignored_params(model), 'lr': args.lr},
{'params': get_non_ignored_params(model), 'lr': args.lr * 10}],
lr = args.lr)
- # optimizer = torch.optim.SGD([{'params': get_ignored_params(model), 'lr': args.lr},
- # {'params': get_non_ignored_params(model), 'lr': args.lr}],
- # lr = args.lr,
- # momentum = 0.9, weight_decay=0.01)
print 'Ready to train network.'
+ print 'First phase of training.'
for epoch in range(num_epochs):
for i, (images, labels, name) in enumerate(train_loader):
images = Variable(images.cuda(gpu))
@@ -146,17 +145,17 @@
optimizer.zero_grad()
model.zero_grad()
- yaw, pitch, roll = model(images)
+ pre_yaw, pre_pitch, pre_roll, angles = model(images)
# Cross entropy loss
- loss_yaw = criterion(yaw, label_yaw)
- loss_pitch = criterion(pitch, label_pitch)
- loss_roll = criterion(roll, label_roll)
+ loss_yaw = criterion(pre_yaw, label_yaw)
+ loss_pitch = criterion(pre_pitch, label_pitch)
+ loss_roll = criterion(pre_roll, label_roll)
# MSE loss
- yaw_predicted = F.softmax(yaw)
- pitch_predicted = F.softmax(pitch)
- roll_predicted = F.softmax(roll)
+ yaw_predicted = F.softmax(pre_yaw)
+ pitch_predicted = F.softmax(pre_pitch)
+ roll_predicted = F.softmax(pre_roll)
yaw_predicted = torch.sum(yaw_predicted.data * idx_tensor, 1)
pitch_predicted = torch.sum(pitch_predicted.data * idx_tensor, 1)
@@ -176,21 +175,77 @@
torch.autograd.backward(loss_seq, grad_seq)
optimizer.step()
- # print ('Epoch [%d/%d], Iter [%d/%d] Losses: Yaw %.4f, Pitch %.4f, Roll %.4f'
- # %(epoch+1, num_epochs, i+1, len(pose_dataset)//batch_size, loss_yaw.data[0], loss_pitch.data[0], loss_roll.data[0]))
-
if (i+1) % 100 == 0:
print ('Epoch [%d/%d], Iter [%d/%d] Losses: Yaw %.4f, Pitch %.4f, Roll %.4f'
%(epoch+1, num_epochs, i+1, len(pose_dataset)//batch_size, loss_yaw.data[0], loss_pitch.data[0], loss_roll.data[0]))
# if epoch == 0:
# torch.save(model.state_dict(),
- # 'output/snapshots/resnet50_lbatch_iter_'+ str(i+1) + '.pkl')
+ # 'output/snapshots/hopenet50_epoch_'+ str(i+1) + '.pkl')
# Save models at numbered epochs.
- if epoch % 1 == 0 and epoch < num_epochs - 1:
+ if epoch % 1 == 0 and epoch < num_epochs:
print 'Taking snapshot...'
torch.save(model.state_dict(),
- 'output/snapshots/resnet50_norm_30rot_epoch_'+ str(epoch+1) + '.pkl')
+ 'output/snapshots/hopenet50_epoch_'+ str(epoch+1) + '.pkl')
+
+ print 'Second phase of training (finetuning layer).'
+ for epoch in range(num_epochs_ft):
+ for i, (images, labels, name) in enumerate(train_loader):
+ images = Variable(images.cuda(gpu))
+ label_yaw = Variable(labels[:,0].cuda(gpu))
+ label_pitch = Variable(labels[:,1].cuda(gpu))
+ label_roll = Variable(labels[:,2].cuda(gpu))
+ label_angles = Variable(labels[:,:3].cuda(gpu))
+
+ optimizer.zero_grad()
+ model.zero_grad()
+
+ pre_yaw, pre_pitch, pre_roll, angles = model(images)
+
+ # Cross entropy loss
+ loss_yaw = criterion(pre_yaw, label_yaw)
+ loss_pitch = criterion(pre_pitch, label_pitch)
+ loss_roll = criterion(pre_roll, label_roll)
+
+ # MSE loss
+ yaw_predicted = F.softmax(pre_yaw)
+ pitch_predicted = F.softmax(pre_pitch)
+ roll_predicted = F.softmax(pre_roll)
+
+ yaw_predicted = torch.sum(yaw_predicted.data * idx_tensor, 1)
+ pitch_predicted = torch.sum(pitch_predicted.data * idx_tensor, 1)
+ roll_predicted = torch.sum(roll_predicted.data * idx_tensor, 1)
+
+ loss_reg_yaw = reg_criterion(yaw_predicted, label_yaw.float())
+ loss_reg_pitch = reg_criterion(pitch_predicted, label_pitch.float())
+ loss_reg_roll = reg_criterion(roll_predicted, label_roll.float())
+
+ # Total loss
+ loss_yaw += alpha * loss_reg_yaw
+ loss_pitch += alpha * loss_reg_pitch
+ loss_roll += alpha * loss_reg_roll
+
+ # Finetuning loss
+ loss_angles = reg_criterion(angles[0], label_angles.float())
+
+ loss_seq = [loss_yaw, loss_pitch, loss_roll, loss_angles]
+ grad_seq = [torch.Tensor(1).cuda(gpu) for _ in range(len(loss_seq))]
+ torch.autograd.backward(loss_seq, grad_seq)
+ optimizer.step()
+
+ if (i+1) % 100 == 0:
+ print ('Epoch [%d/%d], Iter [%d/%d] Losses: pre-yaw %.4f, pre-pitch %.4f, pre-roll %.4f, finetuning %.4f'
+ %(epoch+1, num_epochs_ft, i+1, len(pose_dataset)//batch_size, loss_yaw.data[0], loss_pitch.data[0], loss_roll.data[0], loss_angles.data[0]))
+ # if epoch == 0:
+ # torch.save(model.state_dict(),
+ # 'output/snapshots/hopenet50_iter_'+ str(i+1) + '.pkl')
+
+ # Save models at numbered epochs.
+ if epoch % 1 == 0 and epoch < num_epochs_ft - 1:
+ print 'Taking snapshot...'
+ torch.save(model.state_dict(),
+ 'output/snapshots/hopenet50_epoch_'+ str(num_epochs+epoch+1) + '.pkl')
+
# Save the final Trained Model
- torch.save(model.state_dict(), 'output/snapshots/resnet50_norm_30rot_epoch_' + str(epoch+1) + '.pkl')
+ torch.save(model.state_dict(), 'output/snapshots/hopenet50_epoch_' + str(num_epochs+epoch+1) + '.pkl')
--
Gitblit v1.8.0