From 6664c6d52fad58e396861946a3bed7d5afc4d44d Mon Sep 17 00:00:00 2001
From: natanielruiz <nataniel777@hotmail.com>
Date: Fri, 07 Jul 2017 10:53:52 +0800
Subject: [PATCH] Training for hopenet works.

---
 code/train_resnet_bins.py | 94 +++++++++++++++++++++++++++++++++++-----------
 1 files changed, 71 insertions(+), 23 deletions(-)

diff --git a/code/train_resnet_bins.py b/code/train_resnet_bins.py
index f2ec5f2..1bbf5be 100644
--- a/code/train_resnet_bins.py
+++ b/code/train_resnet_bins.py
@@ -13,8 +13,17 @@
 import os
 import argparse
 
-from datasets import Pose_300W_LP
+import datasets
 import hopenet
+import torch.utils.model_zoo as model_zoo
+
+model_urls = {
+    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+}
 
 def parse_args():
     """Parse input arguments."""
@@ -36,6 +45,41 @@
 
     return args
 
+def get_ignored_params(model):
+    # Generator function that yields ignored params.
+    b = []
+    b.append(model.conv1)
+    b.append(model.bn1)
+    b.append(model.layer1)
+    b.append(model.layer2)
+    b.append(model.layer3)
+    b.append(model.layer4)
+    for i in range(len(b)):
+        for j in b[i].modules():
+            for k in j.parameters():
+                yield k
+
+def get_non_ignored_params(model):
+    # Generator function that yields params that will be optimized.
+    b = []
+    b.append(model.fc_yaw)
+    b.append(model.fc_pitch)
+    b.append(model.fc_roll)
+    for i in range(len(b)):
+        for j in b[i].modules():
+            for k in j.parameters():
+                yield k
+
+def load_filtered_state_dict(model, snapshot):
+    # By user apaszke from discuss.pytorch.org
+    model_dict = model.state_dict()
+    # 1. filter out unnecessary keys
+    snapshot = {k: v for k, v in snapshot.items() if k in model_dict}
+    # 2. overwrite entries in the existing state dict
+    model_dict.update(snapshot)
+    # 3. load the new state dict
+    model.load_state_dict(model_dict)
+
 if __name__ == '__main__':
     args = parse_args()
 
@@ -47,21 +91,16 @@
     if not os.path.exists('output/snapshots'):
         os.makedirs('output/snapshots')
 
-    model = torchvision.models.resnet18(pretrained=True)
-    for param in model.parameters():
-        param.requires_grad = False
-    # Parameters of newly constructed modules have requires_grad=True by default
-    num_ftrs = model.fc.in_features
-    model.fc_pitch = nn.Linear(num_ftrs, 3)
-    model.fc_yaw = nn.Linear(num_ftrs, 3)
-    model.fc_roll = nn.Linear(num_ftrs, )
-
+    # ResNet18 with 3 outputs.
+    model = hopenet.Hopenet(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], 66)
+    load_filtered_state_dict(model, model_zoo.load_url(model_urls['resnet18']))
+
     print 'Loading data.'
 
-    transformations = transforms.Compose([transforms.Scale(230),transforms.RandomCrop(224),
+    transformations = transforms.Compose([transforms.Scale(224),transforms.RandomCrop(224),
                                           transforms.ToTensor()])
 
-    pose_dataset = Pose_300W_LP(args.data_dir, args.filename_list,
+    pose_dataset = datasets.Pose_300W_LP_binned(args.data_dir, args.filename_list,
                                 transformations)
     train_loader = torch.utils.data.DataLoader(dataset=pose_dataset,
                                                batch_size=batch_size,
@@ -69,31 +108,40 @@
                                                num_workers=2)
 
     model.cuda(gpu)
-    criterion = nn.MSELoss(size_average = True)
-    optimizer = torch.optim.Adam(model.fc.parameters(), lr = args.lr)
+    criterion = nn.CrossEntropyLoss()
+    optimizer = torch.optim.Adam([{'params': get_ignored_params(model), 'lr': .0},
+                                  {'params': get_non_ignored_params(model), 'lr': args.lr}],
+                                 lr = args.lr)
 
     print 'Ready to train network.'
 
     for epoch in range(num_epochs):
-        for i, (images, labels) in enumerate(train_loader):
+        for i, (images, labels, name) in enumerate(train_loader):
             images = Variable(images).cuda(gpu)
-            labels = Variable(labels).cuda(gpu)
+            label_yaw = Variable(labels[:,0]).cuda(gpu)
+            label_pitch = Variable(labels[:,1]).cuda(gpu)
+            label_roll = Variable(labels[:,2]).cuda(gpu)
 
             optimizer.zero_grad()
-            outputs = model(images)
-            loss = criterion(outputs, labels)
-            loss.backward()
+            yaw, pitch, roll = model(images)
+            loss_yaw = criterion(yaw, label_yaw)
+            loss_pitch = criterion(pitch, label_pitch)
+            loss_roll = criterion(roll, label_roll)
+
+            loss_seq = [loss_yaw, loss_pitch, loss_roll]
+            grad_seq = [torch.Tensor(1).cuda(gpu) for _ in range(len(loss_seq))]
+            torch.autograd.backward(loss_seq, grad_seq)
             optimizer.step()
 
             if (i+1) % 100 == 0:
-                print ('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f'
-                       %(epoch+1, num_epochs, i+1, len(pose_dataset)//batch_size, loss.data[0]))
+                print ('Epoch [%d/%d], Iter [%d/%d] Losses: Yaw %.4f, Pitch %.4f, Roll %.4f'
+                       %(epoch+1, num_epochs, i+1, len(pose_dataset)//batch_size, loss_yaw.data[0], loss_pitch.data[0], loss_roll.data[0]))
 
         # Save models at even numbered epochs.
         if epoch % 1 == 0 and epoch < num_epochs - 1:
             print 'Taking snapshot...'
             torch.save(model.state_dict(),
-                       'output/snapshots/resnet18_epoch_' + str(epoch+1) + '.pkl')
+                       'output/snapshots/resnet18_binned_epoch_' + str(epoch+1) + '.pkl')
 
     # Save the final Trained Model
-    torch.save(model.state_dict(), 'output/snapshots/resnet18_epoch_' + str(epoch+1) + '.pkl')
+    torch.save(model.state_dict(), 'output/snapshots/resnet18_binned_epoch_' + str(epoch+1) + '.pkl')
-- 
Gitblit v1.8.0