From beb9f36419d0df03c3248757f54af032a633e05c Mon Sep 17 00:00:00 2001 From: natanielruiz <nataniel777@hotmail.com> Date: 星期六, 12 八月 2017 11:49:10 +0800 Subject: [PATCH] AFLW training ready. --- code/train_AFLW.py | 17 ++++++++++------- 1 files changed, 10 insertions(+), 7 deletions(-) diff --git a/code/train_AFLW.py b/code/train_AFLW.py index 13bfc29..65ea83c 100644 --- a/code/train_AFLW.py +++ b/code/train_AFLW.py @@ -41,6 +41,10 @@ default='', type=str) parser.add_argument('--filename_list', dest='filename_list', help='Path to text file containing relative paths for every example.', default='', type=str) + parser.add_argument('--finetune', dest='finetune', help='Boolean: finetune or from Imagenet pretrain.', + default=False, type=bool) + parser.add_argument('--snapshot', dest='snapshot', help='Path to finetune snapshot.', + default='', type=str) args = parser.parse_args() @@ -105,7 +109,7 @@ transformations = transforms.Compose([transforms.Scale(224),transforms.RandomCrop(224), transforms.ToTensor()]) - pose_dataset = datasets.Pose_300W_LP_binned(args.data_dir, args.filename_list, + pose_dataset = datasets.AFLW(args.data_dir, args.filename_list, transformations) train_loader = torch.utils.data.DataLoader(dataset=pose_dataset, batch_size=batch_size, @@ -113,10 +117,10 @@ num_workers=2) model.cuda(gpu) - criterion = nn.CrossEntropyLoss().cuda() - reg_criterion = nn.MSELoss().cuda() + criterion = nn.CrossEntropyLoss().cuda(gpu) + reg_criterion = nn.MSELoss().cuda(gpu) # Regression loss coefficient - alpha = 0.01 + alpha = 0.1 idx_tensor = [idx for idx in xrange(66)] idx_tensor = torch.FloatTensor(idx_tensor).cuda(gpu) @@ -135,7 +139,6 @@ label_roll = Variable(labels[:,2].cuda(gpu)) optimizer.zero_grad() - model.zero_grad() yaw, pitch, roll = model(images) @@ -175,13 +178,13 @@ %(epoch+1, num_epochs, i+1, len(pose_dataset)//batch_size, loss_yaw.data[0], loss_pitch.data[0], loss_roll.data[0])) if epoch == 0: torch.save(model.state_dict(), - 'output/snapshots/resnet50_AFW_iter_'+ str(i+1) + '.pkl') + 'output/snapshots/resnet50_AFLW_iter_'+ str(i+1) + '.pkl') # Save models at numbered epochs. if epoch % 1 == 0 and epoch < num_epochs - 1: print 'Taking snapshot...' torch.save(model.state_dict(), - 'output/snapshots/resnet50_AFW_epoch_'+ str(epoch+1) + '.pkl') + 'output/snapshots/resnet50_AFLW_epoch_'+ str(epoch+1) + '.pkl') # Save the final Trained Model torch.save(model.state_dict(), 'output/snapshots/resnet50_AFLW_epoch' + str(epoch+1) + '.pkl') -- Gitblit v1.8.0