From 653b3608ebe6272510b4c66f445f6f552fdc9ec9 Mon Sep 17 00:00:00 2001
From: natanielruiz <nataniel777@hotmail.com>
Date: 星期一, 11 九月 2017 05:53:10 +0800
Subject: [PATCH] Starting serious experiment without regression or iterative finetuning
---
code/train.py | 78 ++++++----
code/test_AFLW.py | 3
code/hopenet.py | 6
code/train_preangles.py | 265 +++++++++++++++++++++++++++++++++++++
code/test.py | 13 -
code/test_preangles.py | 2
6 files changed, 319 insertions(+), 48 deletions(-)
diff --git a/code/hopenet.py b/code/hopenet.py
index 274044f..5bac804 100644
--- a/code/hopenet.py
+++ b/code/hopenet.py
@@ -88,12 +88,6 @@
return nn.Sequential(*layers)
- def get_expectation(angle):
- angle_pred = F.softmax(angle)
-
- angle_pred = torch.sum(angle_pred.data * self.idx_tensor, 1)
- return angle_pred
-
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
diff --git a/code/test.py b/code/test.py
index 8e8fe50..b01d07e 100644
--- a/code/test.py
+++ b/code/test.py
@@ -27,7 +27,7 @@
default='', type=str)
parser.add_argument('--filename_list', dest='filename_list', help='Path to text file containing relative paths for every example.',
default='', type=str)
- parser.add_argument('--snapshot', dest='snapshot', help='Name of model snapshot.',
+ parser.add_argument('--snapshot', dest='snapshot', help='Path of model snapshot.',
default='', type=str)
parser.add_argument('--batch_size', dest='batch_size', help='Batch size.',
default=1, type=int)
@@ -43,7 +43,7 @@
cudnn.enabled = True
gpu = args.gpu_id
- snapshot_path = os.path.join('output/snapshots', args.snapshot + '.pkl')
+ snapshot_path = args.snapshot
# ResNet101 with 3 outputs.
# model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], 66)
@@ -58,9 +58,6 @@
model.load_state_dict(saved_state_dict)
print 'Loading data.'
-
- # transformations = transforms.Compose([transforms.Scale(224),
- # transforms.RandomCrop(224), transforms.ToTensor()])
transformations = transforms.Compose([transforms.Scale(224),
transforms.RandomCrop(224), transforms.ToTensor(),
@@ -101,9 +98,9 @@
label_roll = labels[:,2].float()
pre_yaw, pre_pitch, pre_roll, angles = model(images)
- yaw = angles[:,0].cpu().data
- pitch = angles[:,1].cpu().data
- roll = angles[:,2].cpu().data
+ yaw = angles[0][:,0].cpu().data
+ pitch = angles[0][:,1].cpu().data
+ roll = angles[0][:,2].cpu().data
# Mean absolute error
yaw_error += torch.sum(torch.abs(yaw - label_yaw) * 3)
diff --git a/code/test_AFLW.py b/code/test_AFLW.py
index 1e1dff3..f61ab98 100644
--- a/code/test_AFLW.py
+++ b/code/test_AFLW.py
@@ -60,7 +60,8 @@
print 'Loading data.'
transformations = transforms.Compose([transforms.Scale(224),
- transforms.RandomCrop(224), transforms.ToTensor()])
+ transforms.RandomCrop(224), transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
pose_dataset = datasets.AFLW(args.data_dir, args.filename_list,
transformations)
diff --git a/code/test_preangles.py b/code/test_preangles.py
index 67e4744..4aedfd8 100644
--- a/code/test_preangles.py
+++ b/code/test_preangles.py
@@ -43,7 +43,7 @@
cudnn.enabled = True
gpu = args.gpu_id
- snapshot_path = os.path.join('output/snapshots', args.snapshot + '.pkl')
+ snapshot_path = args.snapshot
# ResNet101 with 3 outputs.
# model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], 66)
diff --git a/code/train.py b/code/train.py
index 826793d..e339b10 100644
--- a/code/train.py
+++ b/code/train.py
@@ -43,6 +43,9 @@
default='', type=str)
parser.add_argument('--filename_list', dest='filename_list', help='Path to text file containing relative paths for every example.',
default='', type=str)
+ parser.add_argument('--output_string', dest='output_string', help='String appended to output snapshots.', default = '', type=str)
+ parser.add_argument('--alpha', dest='alpha', help='Regression loss coefficient.',
+ default=0.001, type=float)
args = parser.parse_args()
return args
@@ -51,26 +54,37 @@
b = []
b.append(model.conv1)
b.append(model.bn1)
+ for i in range(len(b)):
+ for module_name, module in b[i].named_modules():
+ if 'bn' in module_name:
+ module.eval()
+ for name, param in module.named_parameters():
+ yield param
+
+def get_non_ignored_params(model):
+ # Generator function that yields params that will be optimized.
+ b = []
b.append(model.layer1)
b.append(model.layer2)
b.append(model.layer3)
b.append(model.layer4)
for i in range(len(b)):
- for j in b[i].modules():
- for k in j.parameters():
- yield k
+ for module_name, module in b[i].named_modules():
+ if 'bn' in module_name:
+ module.eval()
+ for name, param in module.named_parameters():
+ yield param
-def get_non_ignored_params(model):
- # Generator function that yields params that will be optimized.
+def get_fc_params(model):
b = []
b.append(model.fc_yaw)
b.append(model.fc_pitch)
b.append(model.fc_roll)
b.append(model.fc_finetune)
for i in range(len(b)):
- for j in b[i].modules():
- for k in j.parameters():
- yield k
+ for module_name, module in b[i].named_modules():
+ for name, param in module.named_parameters():
+ yield param
def load_filtered_state_dict(model, snapshot):
# By user apaszke from discuss.pytorch.org
@@ -104,11 +118,7 @@
print 'Loading data.'
- # transformations = transforms.Compose([transforms.Scale(224),
- # transforms.RandomCrop(224),
- # transforms.ToTensor()])
-
- transformations = transforms.Compose([transforms.Scale(250),
+ transformations = transforms.Compose([transforms.Scale(240),
transforms.RandomCrop(224), transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
@@ -120,17 +130,19 @@
num_workers=2)
model.cuda(gpu)
+ softmax = nn.Softmax()
criterion = nn.CrossEntropyLoss().cuda()
reg_criterion = nn.MSELoss().cuda()
# Regression loss coefficient
- alpha = 0.01
+ alpha = 0.00
idx_tensor = [idx for idx in xrange(66)]
- idx_tensor = torch.FloatTensor(idx_tensor).cuda(gpu)
+ idx_tensor = Variable(torch.FloatTensor(idx_tensor)).cuda(gpu)
- optimizer = torch.optim.Adam([{'params': get_ignored_params(model), 'lr': args.lr},
- {'params': get_non_ignored_params(model), 'lr': args.lr * 10}],
- lr = args.lr)
+ optimizer = torch.optim.Adam([{'params': get_ignored_params(model), 'lr': 0},
+ {'params': get_non_ignored_params(model), 'lr': args.lr},
+ {'params': get_fc_params(model), 'lr': args.lr * 2}],
+ lr = args.lr)
print 'Ready to train network.'
@@ -153,24 +165,26 @@
loss_roll = criterion(pre_roll, label_roll)
# MSE loss
- yaw_predicted = F.softmax(pre_yaw)
- pitch_predicted = F.softmax(pre_pitch)
- roll_predicted = F.softmax(pre_roll)
+ yaw_predicted = softmax(pre_yaw)
+ pitch_predicted = softmax(pre_pitch)
+ roll_predicted = softmax(pre_roll)
- yaw_predicted = torch.sum(yaw_predicted.data * idx_tensor, 1)
- pitch_predicted = torch.sum(pitch_predicted.data * idx_tensor, 1)
- roll_predicted = torch.sum(roll_predicted.data * idx_tensor, 1)
+ yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1)
+ pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1)
+ roll_predicted = torch.sum(roll_predicted * idx_tensor, 1)
loss_reg_yaw = reg_criterion(yaw_predicted, label_yaw.float())
loss_reg_pitch = reg_criterion(pitch_predicted, label_pitch.float())
loss_reg_roll = reg_criterion(roll_predicted, label_roll.float())
+ # print yaw_predicted, label_yaw.float(), loss_reg_yaw
# Total loss
loss_yaw += alpha * loss_reg_yaw
loss_pitch += alpha * loss_reg_pitch
loss_roll += alpha * loss_reg_roll
loss_seq = [loss_yaw, loss_pitch, loss_roll]
+ # loss_seq = [loss_reg_yaw, loss_reg_pitch, loss_reg_roll]
grad_seq = [torch.Tensor(1).cuda(gpu) for _ in range(len(loss_seq))]
torch.autograd.backward(loss_seq, grad_seq)
optimizer.step()
@@ -180,13 +194,13 @@
%(epoch+1, num_epochs, i+1, len(pose_dataset)//batch_size, loss_yaw.data[0], loss_pitch.data[0], loss_roll.data[0]))
# if epoch == 0:
# torch.save(model.state_dict(),
- # 'output/snapshots/hopenet50_epoch_'+ str(i+1) + '.pkl')
+ # 'output/snapshots/' + args.output_string + '_iter_'+ str(i+1) + '.pkl')
# Save models at numbered epochs.
if epoch % 1 == 0 and epoch < num_epochs:
print 'Taking snapshot...'
torch.save(model.state_dict(),
- 'output/snapshots/hopenet50_epoch_'+ str(epoch+1) + '.pkl')
+ 'output/snapshots/' + args.output_string + '_epoch_'+ str(epoch+1) + '.pkl')
print 'Second phase of training (finetuning layer).'
for epoch in range(num_epochs_ft):
@@ -208,9 +222,9 @@
loss_roll = criterion(pre_roll, label_roll)
# MSE loss
- yaw_predicted = F.softmax(pre_yaw)
- pitch_predicted = F.softmax(pre_pitch)
- roll_predicted = F.softmax(pre_roll)
+ yaw_predicted = softmax(pre_yaw)
+ pitch_predicted = softmax(pre_pitch)
+ roll_predicted = softmax(pre_roll)
yaw_predicted = torch.sum(yaw_predicted.data * idx_tensor, 1)
pitch_predicted = torch.sum(pitch_predicted.data * idx_tensor, 1)
@@ -238,14 +252,14 @@
%(epoch+1, num_epochs_ft, i+1, len(pose_dataset)//batch_size, loss_yaw.data[0], loss_pitch.data[0], loss_roll.data[0], loss_angles.data[0]))
# if epoch == 0:
# torch.save(model.state_dict(),
- # 'output/snapshots/hopenet50_iter_'+ str(i+1) + '.pkl')
+ # 'output/snapshots/' + args.output_string + '_iter_'+ str(i+1) + '.pkl')
# Save models at numbered epochs.
if epoch % 1 == 0 and epoch < num_epochs_ft - 1:
print 'Taking snapshot...'
torch.save(model.state_dict(),
- 'output/snapshots/hopenet50_epoch_'+ str(num_epochs+epoch+1) + '.pkl')
+ 'output/snapshots/' + args.output_string + '_epoch_'+ str(num_epochs+epoch+1) + '.pkl')
# Save the final Trained Model
- torch.save(model.state_dict(), 'output/snapshots/hopenet50_epoch_' + str(num_epochs+epoch+1) + '.pkl')
+ torch.save(model.state_dict(), 'output/snapshots/' + args.output_string + '_epoch_' + str(num_epochs+epoch+1) + '.pkl')
diff --git a/code/train_preangles.py b/code/train_preangles.py
new file mode 100644
index 0000000..65a2017
--- /dev/null
+++ b/code/train_preangles.py
@@ -0,0 +1,265 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+from torch.utils.data import DataLoader
+from torchvision import transforms
+import torchvision
+import torch.backends.cudnn as cudnn
+import torch.nn.functional as F
+
+import cv2
+import matplotlib.pyplot as plt
+import sys
+import os
+import argparse
+
+import datasets
+import hopenet
+import torch.utils.model_zoo as model_zoo
+
+model_urls = {
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+}
+
+def parse_args():
+ """Parse input arguments."""
+ parser = argparse.ArgumentParser(description='Head pose estimation using the Hopenet network.')
+ parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]',
+ default=0, type=int)
+ parser.add_argument('--num_epochs', dest='num_epochs', help='Maximum number of training epochs.',
+ default=5, type=int)
+ parser.add_argument('--num_epochs_ft', dest='num_epochs_ft', help='Maximum number of finetuning epochs.',
+ default=5, type=int)
+ parser.add_argument('--batch_size', dest='batch_size', help='Batch size.',
+ default=16, type=int)
+ parser.add_argument('--lr', dest='lr', help='Base learning rate.',
+ default=0.001, type=float)
+ parser.add_argument('--data_dir', dest='data_dir', help='Directory path for data.',
+ default='', type=str)
+ parser.add_argument('--filename_list', dest='filename_list', help='Path to text file containing relative paths for every example.',
+ default='', type=str)
+ parser.add_argument('--output_string', dest='output_string', help='String appended to output snapshots.', default = '', type=str)
+ parser.add_argument('--alpha', dest='alpha', help='Regression loss coefficient.',
+ default=0.001, type=float)
+ args = parser.parse_args()
+ return args
+
+def get_ignored_params(model):
+ # Generator function that yields ignored params.
+ b = []
+ b.append(model.conv1)
+ b.append(model.bn1)
+ b.append(model.fc_finetune)
+ for i in range(len(b)):
+ for module_name, module in b[i].named_modules():
+ if 'bn' in module_name:
+ module.eval()
+ for name, param in module.named_parameters():
+ yield param
+
+def get_non_ignored_params(model):
+ # Generator function that yields params that will be optimized.
+ b = []
+ b.append(model.layer1)
+ b.append(model.layer2)
+ b.append(model.layer3)
+ b.append(model.layer4)
+ for i in range(len(b)):
+ for module_name, module in b[i].named_modules():
+ if 'bn' in module_name:
+ module.eval()
+ for name, param in module.named_parameters():
+ yield param
+
+def get_fc_params(model):
+ b = []
+ b.append(model.fc_yaw)
+ b.append(model.fc_pitch)
+ b.append(model.fc_roll)
+ for i in range(len(b)):
+ for module_name, module in b[i].named_modules():
+ for name, param in module.named_parameters():
+ yield param
+
+def load_filtered_state_dict(model, snapshot):
+ # By user apaszke from discuss.pytorch.org
+ model_dict = model.state_dict()
+ # 1. filter out unnecessary keys
+ snapshot = {k: v for k, v in snapshot.items() if k in model_dict}
+ # 2. overwrite entries in the existing state dict
+ model_dict.update(snapshot)
+ # 3. load the new state dict
+ model.load_state_dict(model_dict)
+
+if __name__ == '__main__':
+ args = parse_args()
+
+ cudnn.enabled = True
+ num_epochs = args.num_epochs
+ num_epochs_ft = args.num_epochs_ft
+ batch_size = args.batch_size
+ gpu = args.gpu_id
+
+ if not os.path.exists('output/snapshots'):
+ os.makedirs('output/snapshots')
+
+ # ResNet101 with 3 outputs
+ # model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], 66)
+ # ResNet50
+ model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66)
+ # ResNet18
+ # model = hopenet.Hopenet(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], 66)
+ load_filtered_state_dict(model, model_zoo.load_url(model_urls['resnet50']))
+
+ print 'Loading data.'
+
+ transformations = transforms.Compose([transforms.Scale(240),
+ transforms.RandomCrop(224), transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
+
+ pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list,
+ transformations)
+ train_loader = torch.utils.data.DataLoader(dataset=pose_dataset,
+ batch_size=batch_size,
+ shuffle=True,
+ num_workers=2)
+
+ model.cuda(gpu)
+ softmax = nn.Softmax()
+ criterion = nn.CrossEntropyLoss().cuda()
+ reg_criterion = nn.MSELoss().cuda()
+ # Regression loss coefficient
+ alpha = 0.00
+
+ idx_tensor = [idx for idx in xrange(66)]
+ idx_tensor = Variable(torch.FloatTensor(idx_tensor)).cuda(gpu)
+
+ optimizer = torch.optim.Adam([{'params': get_ignored_params(model), 'lr': 0},
+ {'params': get_non_ignored_params(model), 'lr': args.lr},
+ {'params': get_fc_params(model), 'lr': args.lr * 2}],
+ lr = args.lr)
+
+ print 'Ready to train network.'
+
+ print 'First phase of training.'
+ for epoch in range(num_epochs):
+ for i, (images, labels, name) in enumerate(train_loader):
+ images = Variable(images.cuda(gpu))
+ label_yaw = Variable(labels[:,0].cuda(gpu))
+ label_pitch = Variable(labels[:,1].cuda(gpu))
+ label_roll = Variable(labels[:,2].cuda(gpu))
+
+ optimizer.zero_grad()
+ model.zero_grad()
+
+ pre_yaw, pre_pitch, pre_roll, angles = model(images)
+
+ # Cross entropy loss
+ loss_yaw = criterion(pre_yaw, label_yaw)
+ loss_pitch = criterion(pre_pitch, label_pitch)
+ loss_roll = criterion(pre_roll, label_roll)
+
+ # MSE loss
+ yaw_predicted = softmax(pre_yaw)
+ pitch_predicted = softmax(pre_pitch)
+ roll_predicted = softmax(pre_roll)
+
+ yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1)
+ pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1)
+ roll_predicted = torch.sum(roll_predicted * idx_tensor, 1)
+
+ loss_reg_yaw = reg_criterion(yaw_predicted, label_yaw.float())
+ loss_reg_pitch = reg_criterion(pitch_predicted, label_pitch.float())
+ loss_reg_roll = reg_criterion(roll_predicted, label_roll.float())
+
+ # print yaw_predicted, label_yaw.float(), loss_reg_yaw
+ # Total loss
+ loss_yaw += alpha * loss_reg_yaw
+ loss_pitch += alpha * loss_reg_pitch
+ loss_roll += alpha * loss_reg_roll
+
+ loss_seq = [loss_yaw, loss_pitch, loss_roll]
+ # loss_seq = [loss_reg_yaw, loss_reg_pitch, loss_reg_roll]
+ grad_seq = [torch.Tensor(1).cuda(gpu) for _ in range(len(loss_seq))]
+ torch.autograd.backward(loss_seq, grad_seq)
+ optimizer.step()
+
+ if (i+1) % 100 == 0:
+ print ('Epoch [%d/%d], Iter [%d/%d] Losses: Yaw %.4f, Pitch %.4f, Roll %.4f'
+ %(epoch+1, num_epochs, i+1, len(pose_dataset)//batch_size, loss_yaw.data[0], loss_pitch.data[0], loss_roll.data[0]))
+ # if epoch == 0:
+ # torch.save(model.state_dict(),
+ # 'output/snapshots/' + args.output_string + '_iter_'+ str(i+1) + '.pkl')
+
+ # Save models at numbered epochs.
+ if epoch % 1 == 0 and epoch < num_epochs:
+ print 'Taking snapshot...'
+ torch.save(model.state_dict(),
+ 'output/snapshots/' + args.output_string + '_epoch_'+ str(epoch+1) + '.pkl')
+
+ print 'Second phase of training (finetuning layer).'
+ for epoch in range(num_epochs_ft):
+ for i, (images, labels, name) in enumerate(train_loader):
+ images = Variable(images.cuda(gpu))
+ label_yaw = Variable(labels[:,0].cuda(gpu))
+ label_pitch = Variable(labels[:,1].cuda(gpu))
+ label_roll = Variable(labels[:,2].cuda(gpu))
+ label_angles = Variable(labels[:,:3].cuda(gpu))
+
+ optimizer.zero_grad()
+ model.zero_grad()
+
+ pre_yaw, pre_pitch, pre_roll, angles = model(images)
+
+ # Cross entropy loss
+ loss_yaw = criterion(pre_yaw, label_yaw)
+ loss_pitch = criterion(pre_pitch, label_pitch)
+ loss_roll = criterion(pre_roll, label_roll)
+
+ # MSE loss
+ yaw_predicted = softmax(pre_yaw)
+ pitch_predicted = softmax(pre_pitch)
+ roll_predicted = softmax(pre_roll)
+
+ yaw_predicted = torch.sum(yaw_predicted.data * idx_tensor, 1)
+ pitch_predicted = torch.sum(pitch_predicted.data * idx_tensor, 1)
+ roll_predicted = torch.sum(roll_predicted.data * idx_tensor, 1)
+
+ loss_reg_yaw = reg_criterion(yaw_predicted, label_yaw.float())
+ loss_reg_pitch = reg_criterion(pitch_predicted, label_pitch.float())
+ loss_reg_roll = reg_criterion(roll_predicted, label_roll.float())
+
+ # Total loss
+ loss_yaw += alpha * loss_reg_yaw
+ loss_pitch += alpha * loss_reg_pitch
+ loss_roll += alpha * loss_reg_roll
+
+ # Finetuning loss
+ loss_angles = reg_criterion(angles[0], label_angles.float())
+
+ loss_seq = [loss_yaw, loss_pitch, loss_roll, loss_angles]
+ grad_seq = [torch.Tensor(1).cuda(gpu) for _ in range(len(loss_seq))]
+ torch.autograd.backward(loss_seq, grad_seq)
+ optimizer.step()
+
+ if (i+1) % 100 == 0:
+ print ('Epoch [%d/%d], Iter [%d/%d] Losses: pre-yaw %.4f, pre-pitch %.4f, pre-roll %.4f, finetuning %.4f'
+ %(epoch+1, num_epochs_ft, i+1, len(pose_dataset)//batch_size, loss_yaw.data[0], loss_pitch.data[0], loss_roll.data[0], loss_angles.data[0]))
+ # if epoch == 0:
+ # torch.save(model.state_dict(),
+ # 'output/snapshots/' + args.output_string + '_iter_'+ str(i+1) + '.pkl')
+
+ # Save models at numbered epochs.
+ if epoch % 1 == 0 and epoch < num_epochs_ft - 1:
+ print 'Taking snapshot...'
+ torch.save(model.state_dict(),
+ 'output/snapshots/' + args.output_string + '_epoch_'+ str(num_epochs+epoch+1) + '.pkl')
+
+
+ # Save the final Trained Model
+ torch.save(model.state_dict(), 'output/snapshots/' + args.output_string + '_epoch_' + str(num_epochs+epoch+1) + '.pkl')
--
Gitblit v1.8.0