From af51d0ecb51ad4d6c8ed086855bd3c411ebc4aa0 Mon Sep 17 00:00:00 2001 From: natanielruiz <nruiz9@gatech.edu> Date: 星期一, 30 十月 2017 06:29:51 +0800 Subject: [PATCH] Fixed stuff --- code/test_alexnet.py | 25 ++++++++---- code/test_resnet50_regression.py | 46 +++++++++++++---------- code/train_alexnet.py | 22 +++++++---- code/test_preangles.py | 11 +++++ 4 files changed, 67 insertions(+), 37 deletions(-) diff --git a/code/test_alexnet.py b/code/test_alexnet.py index d9cc0a3..7a3989a 100644 --- a/code/test_alexnet.py +++ b/code/test_alexnet.py @@ -39,6 +39,13 @@ return args +def load_filtered_state_dict(model, snapshot): + # By user apaszke from discuss.pytorch.org + model_dict = model.state_dict() + snapshot = {k: v for k, v in snapshot.items() if k in model_dict} + model_dict.update(snapshot) + model.load_state_dict(model_dict) + if __name__ == '__main__': args = parse_args() @@ -51,7 +58,7 @@ print 'Loading snapshot.' # Load snapshot saved_state_dict = torch.load(snapshot_path) - model.load_state_dict(saved_state_dict) + load_filtered_state_dict(model, saved_state_dict) print 'Loading data.' @@ -59,18 +66,20 @@ transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) - if args.dataset == 'AFLW2000': - pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list, - transformations) + if args.dataset == 'Pose_300W_LP': + pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations) + elif args.dataset == 'Pose_300W_LP_random_ds': + pose_dataset = datasets.Pose_300W_LP_random_ds(args.data_dir, args.filename_list, transformations) + elif args.dataset == 'AFLW2000': + pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list, transformations) elif args.dataset == 'AFLW2000_ds': - pose_dataset = datasets.AFLW2000_ds(args.data_dir, args.filename_list, - transformations) + pose_dataset = datasets.AFLW2000_ds(args.data_dir, args.filename_list, transformations) elif args.dataset == 'BIWI': pose_dataset = datasets.BIWI(args.data_dir, args.filename_list, transformations) elif args.dataset == 'AFLW': pose_dataset = datasets.AFLW(args.data_dir, args.filename_list, transformations) - elif args.dataset == 'Pose_300W_LP': - pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations) + elif args.dataset == 'AFLW_aug': + pose_dataset = datasets.AFLW_aug(args.data_dir, args.filename_list, transformations) elif args.dataset == 'AFW': pose_dataset = datasets.AFW(args.data_dir, args.filename_list, transformations) else: diff --git a/code/test_preangles.py b/code/test_preangles.py index be9bfda..05f621a 100644 --- a/code/test_preangles.py +++ b/code/test_preangles.py @@ -36,6 +36,13 @@ return args +def load_filtered_state_dict(model, snapshot): + # By user apaszke from discuss.pytorch.org + model_dict = model.state_dict() + snapshot = {k: v for k, v in snapshot.items() if k in model_dict} + model_dict.update(snapshot) + model.load_state_dict(model_dict) + if __name__ == '__main__': args = parse_args() @@ -49,7 +56,7 @@ print 'Loading snapshot.' # Load snapshot saved_state_dict = torch.load(snapshot_path) - model.load_state_dict(saved_state_dict) + load_filtered_state_dict(model, saved_state_dict) print 'Loading data.' @@ -63,6 +70,8 @@ pose_dataset = datasets.Pose_300W_LP_random_ds(args.data_dir, args.filename_list, transformations) elif args.dataset == 'AFLW2000': pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list, transformations) + elif args.dataset == 'AFLW2000_ds': + pose_dataset = datasets.AFLW2000_ds(args.data_dir, args.filename_list, transformations) elif args.dataset == 'BIWI': pose_dataset = datasets.BIWI(args.data_dir, args.filename_list, transformations) elif args.dataset == 'AFLW': diff --git a/code/test_resnet50_regression.py b/code/test_resnet50_regression.py index 85207f8..6945269 100644 --- a/code/test_resnet50_regression.py +++ b/code/test_resnet50_regression.py @@ -1,4 +1,9 @@ +import sys, os, argparse + import numpy as np +import cv2 +import matplotlib.pyplot as plt + import torch import torch.nn as nn from torch.autograd import Variable @@ -8,15 +13,7 @@ import torchvision import torch.nn.functional as F -import cv2 -import matplotlib.pyplot as plt -import sys -import os -import argparse - -import datasets -import hopenet -import utils +import datasets, hopenet, utils def parse_args(): """Parse input arguments.""" @@ -39,6 +36,13 @@ return args +def load_filtered_state_dict(model, snapshot): + # By user apaszke from discuss.pytorch.org + model_dict = model.state_dict() + snapshot = {k: v for k, v in snapshot.items() if k in model_dict} + model_dict.update(snapshot) + model.load_state_dict(model_dict) + if __name__ == '__main__': args = parse_args() @@ -51,7 +55,7 @@ print 'Loading snapshot.' # Load snapshot saved_state_dict = torch.load(snapshot_path) - model.load_state_dict(saved_state_dict) + load_filtered_state_dict(model, saved_state_dict) print 'Loading data.' @@ -59,18 +63,20 @@ transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) - if args.dataset == 'AFLW2000': - pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list, - transformations) + if args.dataset == 'Pose_300W_LP': + pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations) + elif args.dataset == 'Pose_300W_LP_random_ds': + pose_dataset = datasets.Pose_300W_LP_random_ds(args.data_dir, args.filename_list, transformations) + elif args.dataset == 'AFLW2000': + pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list, transformations) elif args.dataset == 'AFLW2000_ds': - pose_dataset = datasets.AFLW2000_ds(args.data_dir, args.filename_list, - transformations) + pose_dataset = datasets.AFLW2000_ds(args.data_dir, args.filename_list, transformations) elif args.dataset == 'BIWI': pose_dataset = datasets.BIWI(args.data_dir, args.filename_list, transformations) elif args.dataset == 'AFLW': pose_dataset = datasets.AFLW(args.data_dir, args.filename_list, transformations) - elif args.dataset == 'Pose_300W_LP': - pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations) + elif args.dataset == 'AFLW_aug': + pose_dataset = datasets.AFLW_aug(args.data_dir, args.filename_list, transformations) elif args.dataset == 'AFW': pose_dataset = datasets.AFW(args.data_dir, args.filename_list, transformations) else: @@ -111,8 +117,7 @@ pitch_error += torch.sum(torch.abs(pitch_predicted - label_pitch)) roll_error += torch.sum(torch.abs(roll_predicted - label_roll)) - # Save images with pose cube. - # TODO: fix for larger batch size + # Save first image in batch with pose cube or axis. if args.save_viz: name = name[0] if args.dataset == 'BIWI': @@ -122,7 +127,8 @@ if args.batch_size == 1: error_string = 'y %.2f, p %.2f, r %.2f' % (torch.sum(torch.abs(yaw_predicted - label_yaw)), torch.sum(torch.abs(pitch_predicted - label_pitch)), torch.sum(torch.abs(roll_predicted - label_roll))) cv2.putText(cv2_img, error_string, (30, cv2_img.shape[0]- 30), fontFace=1, fontScale=1, color=(0,0,255), thickness=1) - utils.plot_pose_cube(cv2_img, yaw_predicted[0], pitch_predicted[0], roll_predicted[0]) + # utils.plot_pose_cube(cv2_img, yaw_predicted[0], pitch_predicted[0], roll_predicted[0], size=100) + utils.draw_axis(cv2_img, yaw_predicted[0], pitch_predicted[0], roll_predicted[0], tdx = 200, tdy= 200, size=100) cv2.imwrite(os.path.join('output/images', name + '.jpg'), cv2_img) print('Test error in degrees of the model on the ' + str(total) + diff --git a/code/train_alexnet.py b/code/train_alexnet.py index 9254ee7..51cf43b 100644 --- a/code/train_alexnet.py +++ b/code/train_alexnet.py @@ -129,6 +129,9 @@ # Regression loss coefficient alpha = args.alpha + idx_tensor = [idx for idx in xrange(66)] + idx_tensor = Variable(torch.FloatTensor(idx_tensor)).cuda(gpu) + optimizer = torch.optim.Adam([{'params': get_ignored_params(model), 'lr': 0}, {'params': get_non_ignored_params(model), 'lr': args.lr}, {'params': get_fc_params(model), 'lr': args.lr * 5}], @@ -150,17 +153,21 @@ label_roll_cont = Variable(cont_labels[:,2]).cuda(gpu) # Forward pass - yaw, pitch, roll, angles = model(images) + pre_yaw, pre_pitch, pre_roll = model(images) # Cross entropy loss - loss_yaw = criterion(yaw, label_yaw) - loss_pitch = criterion(pitch, label_pitch) - loss_roll = criterion(roll, label_roll) + loss_yaw = criterion(pre_yaw, label_yaw) + loss_pitch = criterion(pre_pitch, label_pitch) + loss_roll = criterion(pre_roll, label_roll) # MSE loss - yaw_predicted = angles[:,0] - pitch_predicted = angles[:,1] - roll_predicted = angles[:,2] + yaw_predicted = softmax(pre_yaw) + pitch_predicted = softmax(pre_pitch) + roll_predicted = softmax(pre_roll) + + yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1) * 3 - 99 + pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1) * 3 - 99 + roll_predicted = torch.sum(roll_predicted * idx_tensor, 1) * 3 - 99 loss_reg_yaw = reg_criterion(yaw_predicted, label_yaw_cont) loss_reg_pitch = reg_criterion(pitch_predicted, label_pitch_cont) @@ -173,7 +180,6 @@ loss_seq = [loss_yaw, loss_pitch, loss_roll] grad_seq = [torch.Tensor(1).cuda(gpu) for _ in range(len(loss_seq))] - optimizer.zero_grad() torch.autograd.backward(loss_seq, grad_seq) optimizer.step() -- Gitblit v1.8.0