From dd62d6fa4a85f18a29de009a972f5599b19ec946 Mon Sep 17 00:00:00 2001 From: natanielruiz <nataniel777@hotmail.com> Date: 星期四, 14 九月 2017 00:51:53 +0800 Subject: [PATCH] Fixing hopenet --- code/train_AFLW.py | 28 +++++++++++++++++++--------- 1 files changed, 19 insertions(+), 9 deletions(-) diff --git a/code/train_AFLW.py b/code/train_AFLW.py index 13bfc29..f355f63 100644 --- a/code/train_AFLW.py +++ b/code/train_AFLW.py @@ -41,6 +41,10 @@ default='', type=str) parser.add_argument('--filename_list', dest='filename_list', help='Path to text file containing relative paths for every example.', default='', type=str) + parser.add_argument('--finetune', dest='finetune', help='Boolean: finetune or from Imagenet pretrain.', + default=False, type=bool) + parser.add_argument('--snapshot', dest='snapshot', help='Path to finetune snapshot.', + default='', type=str) args = parser.parse_args() @@ -98,14 +102,18 @@ model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66) # ResNet18 # model = hopenet.Hopenet(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], 66) - load_filtered_state_dict(model, model_zoo.load_url(model_urls['resnet50'])) + + if args.finetune: + model.load_state_dict(torch.load(args.snapshot)) + else: + load_filtered_state_dict(model, model_zoo.load_url(model_urls['resnet50'])) print 'Loading data.' transformations = transforms.Compose([transforms.Scale(224),transforms.RandomCrop(224), transforms.ToTensor()]) - pose_dataset = datasets.Pose_300W_LP_binned(args.data_dir, args.filename_list, + pose_dataset = datasets.AFLW(args.data_dir, args.filename_list, transformations) train_loader = torch.utils.data.DataLoader(dataset=pose_dataset, batch_size=batch_size, @@ -113,10 +121,10 @@ num_workers=2) model.cuda(gpu) - criterion = nn.CrossEntropyLoss().cuda() - reg_criterion = nn.MSELoss().cuda() + criterion = nn.CrossEntropyLoss().cuda(gpu) + reg_criterion = nn.MSELoss().cuda(gpu) # Regression loss coefficient - alpha = 0.01 + alpha = 0.1 idx_tensor = [idx for idx in xrange(66)] idx_tensor = torch.FloatTensor(idx_tensor).cuda(gpu) @@ -124,6 +132,9 @@ optimizer = torch.optim.Adam([{'params': get_ignored_params(model), 'lr': args.lr}, {'params': get_non_ignored_params(model), 'lr': args.lr * 10}], lr = args.lr) + # optimizer = torch.optim.SGD([{'params': get_ignored_params(model), 'lr': args.lr}, + # {'params': get_non_ignored_params(model), 'lr': args.lr * 10}], + # lr = args.lr, momentum = 0.9) print 'Ready to train network.' @@ -135,7 +146,6 @@ label_roll = Variable(labels[:,2].cuda(gpu)) optimizer.zero_grad() - model.zero_grad() yaw, pitch, roll = model(images) @@ -175,13 +185,13 @@ %(epoch+1, num_epochs, i+1, len(pose_dataset)//batch_size, loss_yaw.data[0], loss_pitch.data[0], loss_roll.data[0])) if epoch == 0: torch.save(model.state_dict(), - 'output/snapshots/resnet50_AFW_iter_'+ str(i+1) + '.pkl') + 'output/snapshots/resnet50_AFLW_finetuned_iter_'+ str(i+1) + '.pkl') # Save models at numbered epochs. if epoch % 1 == 0 and epoch < num_epochs - 1: print 'Taking snapshot...' torch.save(model.state_dict(), - 'output/snapshots/resnet50_AFW_epoch_'+ str(epoch+1) + '.pkl') + 'output/snapshots/resnet50_AFLW_finetuned_epoch_'+ str(epoch+1) + '.pkl') # Save the final Trained Model - torch.save(model.state_dict(), 'output/snapshots/resnet50_AFLW_epoch' + str(epoch+1) + '.pkl') + torch.save(model.state_dict(), 'output/snapshots/resnet50_AFLW_finetuned_epoch_' + str(epoch+1) + '.pkl') -- Gitblit v1.8.0