From af51d0ecb51ad4d6c8ed086855bd3c411ebc4aa0 Mon Sep 17 00:00:00 2001 From: natanielruiz <nruiz9@gatech.edu> Date: 星期一, 30 十月 2017 06:29:51 +0800 Subject: [PATCH] Fixed stuff --- code/train_preangles.py | 96 ++++++++++++++---------------------------------- 1 files changed, 28 insertions(+), 68 deletions(-) diff --git a/code/train_preangles.py b/code/train_preangles.py index 6622d3f..1fe626c 100644 --- a/code/train_preangles.py +++ b/code/train_preangles.py @@ -1,4 +1,9 @@ +import sys, os, argparse, time + import numpy as np +import cv2 +import matplotlib.pyplot as plt + import torch import torch.nn as nn from torch.autograd import Variable @@ -8,25 +13,8 @@ import torch.backends.cudnn as cudnn import torch.nn.functional as F -import cv2 -import matplotlib.pyplot as plt -import sys -import os -import argparse - -import datasets -import hopenet +import datasets, hopenet import torch.utils.model_zoo as model_zoo - -import time - -model_urls = { - 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', - 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', - 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', - 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', - 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', -} def parse_args(): """Parse input arguments.""" @@ -53,10 +41,7 @@ def get_ignored_params(model): # Generator function that yields ignored params. - b = [] - b.append(model.conv1) - b.append(model.bn1) - b.append(model.fc_finetune) + b = [model.conv1, model.bn1, model.fc_finetune] for i in range(len(b)): for module_name, module in b[i].named_modules(): if 'bn' in module_name: @@ -66,11 +51,7 @@ def get_non_ignored_params(model): # Generator function that yields params that will be optimized. - b = [] - b.append(model.layer1) - b.append(model.layer2) - b.append(model.layer3) - b.append(model.layer4) + b = [model.layer1, model.layer2, model.layer3, model.layer4] for i in range(len(b)): for module_name, module in b[i].named_modules(): if 'bn' in module_name: @@ -79,10 +60,8 @@ yield param def get_fc_params(model): - b = [] - b.append(model.fc_yaw) - b.append(model.fc_pitch) - b.append(model.fc_roll) + # Generator function that yields fc layer params. + b = [model.fc_yaw, model.fc_pitch, model.fc_roll] for i in range(len(b)): for module_name, module in b[i].named_modules(): for name, param in module.named_parameters(): @@ -91,11 +70,8 @@ def load_filtered_state_dict(model, snapshot): # By user apaszke from discuss.pytorch.org model_dict = model.state_dict() - # 1. filter out unnecessary keys snapshot = {k: v for k, v in snapshot.items() if k in model_dict} - # 2. overwrite entries in the existing state dict model_dict.update(snapshot) - # 3. load the new state dict model.load_state_dict(model_dict) if __name__ == '__main__': @@ -109,13 +85,9 @@ if not os.path.exists('output/snapshots'): os.makedirs('output/snapshots') - # ResNet101 with 3 outputs - # model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], 66) - # ResNet50 - model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66, 0) - # ResNet18 - # model = hopenet.Hopenet(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], 66) - load_filtered_state_dict(model, model_zoo.load_url(model_urls['resnet50'])) + # ResNet50 structure + model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66) + load_filtered_state_dict(model, model_zoo.load_url('https://download.pytorch.org/models/resnet50-19c8e357.pth')) print 'Loading data.' @@ -125,6 +97,8 @@ if args.dataset == 'Pose_300W_LP': pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations) + elif args.dataset == 'Pose_300W_LP_random_ds': + pose_dataset = datasets.Pose_300W_LP_random_ds(args.data_dir, args.filename_list, transformations) elif args.dataset == 'AFLW2000': pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list, transformations) elif args.dataset == 'BIWI': @@ -138,20 +112,17 @@ else: print 'Error: not a valid dataset name' sys.exit() + train_loader = torch.utils.data.DataLoader(dataset=pose_dataset, batch_size=batch_size, shuffle=True, num_workers=2) model.cuda(gpu) - softmax = nn.Softmax().cuda(gpu) criterion = nn.CrossEntropyLoss().cuda(gpu) reg_criterion = nn.MSELoss().cuda(gpu) # Regression loss coefficient alpha = args.alpha - - idx_tensor = [idx for idx in xrange(66)] - idx_tensor = Variable(torch.FloatTensor(idx_tensor)).cuda(gpu) optimizer = torch.optim.Adam([{'params': get_ignored_params(model), 'lr': 0}, {'params': get_non_ignored_params(model), 'lr': args.lr}, @@ -159,39 +130,32 @@ lr = args.lr) print 'Ready to train network.' - print 'First phase of training.' for epoch in range(num_epochs): - start = time.time() for i, (images, labels, cont_labels, name) in enumerate(train_loader): - print i - print 'start: ', time.time() - start images = Variable(images).cuda(gpu) + + # Binned labels label_yaw = Variable(labels[:,0]).cuda(gpu) label_pitch = Variable(labels[:,1]).cuda(gpu) label_roll = Variable(labels[:,2]).cuda(gpu) - label_angles = Variable(cont_labels[:,:3]).cuda(gpu) + # Continuous labels label_yaw_cont = Variable(cont_labels[:,0]).cuda(gpu) label_pitch_cont = Variable(cont_labels[:,1]).cuda(gpu) label_roll_cont = Variable(cont_labels[:,2]).cuda(gpu) - optimizer.zero_grad() - model.zero_grad() + # Forward pass + yaw, pitch, roll, angles = model(images) - pre_yaw, pre_pitch, pre_roll, angles = model(images) # Cross entropy loss - loss_yaw = criterion(pre_yaw, label_yaw) - loss_pitch = criterion(pre_pitch, label_pitch) - loss_roll = criterion(pre_roll, label_roll) + loss_yaw = criterion(yaw, label_yaw) + loss_pitch = criterion(pitch, label_pitch) + loss_roll = criterion(roll, label_roll) # MSE loss - yaw_predicted = softmax(pre_yaw) - pitch_predicted = softmax(pre_pitch) - roll_predicted = softmax(pre_roll) - - yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1) * 3 - 99 - pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1) * 3 - 99 - roll_predicted = torch.sum(roll_predicted * idx_tensor, 1) * 3 - 99 + yaw_predicted = angles[:,0] + pitch_predicted = angles[:,1] + roll_predicted = angles[:,2] loss_reg_yaw = reg_criterion(yaw_predicted, label_yaw_cont) loss_reg_pitch = reg_criterion(pitch_predicted, label_pitch_cont) @@ -204,17 +168,13 @@ loss_seq = [loss_yaw, loss_pitch, loss_roll] grad_seq = [torch.Tensor(1).cuda(gpu) for _ in range(len(loss_seq))] + optimizer.zero_grad() torch.autograd.backward(loss_seq, grad_seq) optimizer.step() - - print 'end: ', time.time() - start if (i+1) % 100 == 0: print ('Epoch [%d/%d], Iter [%d/%d] Losses: Yaw %.4f, Pitch %.4f, Roll %.4f' %(epoch+1, num_epochs, i+1, len(pose_dataset)//batch_size, loss_yaw.data[0], loss_pitch.data[0], loss_roll.data[0])) - # if epoch == 0: - # torch.save(model.state_dict(), - # 'output/snapshots/' + args.output_string + '_iter_'+ str(i+1) + '.pkl') # Save models at numbered epochs. if epoch % 1 == 0 and epoch < num_epochs: -- Gitblit v1.8.0