From 653b3608ebe6272510b4c66f445f6f552fdc9ec9 Mon Sep 17 00:00:00 2001 From: natanielruiz <nataniel777@hotmail.com> Date: 星期一, 11 九月 2017 05:53:10 +0800 Subject: [PATCH] Starting serious experiment without regression or iterative finetuning --- code/train.py | 78 +++++++++++++++++++++++---------------- 1 files changed, 46 insertions(+), 32 deletions(-) diff --git a/code/train.py b/code/train.py index 826793d..e339b10 100644 --- a/code/train.py +++ b/code/train.py @@ -43,6 +43,9 @@ default='', type=str) parser.add_argument('--filename_list', dest='filename_list', help='Path to text file containing relative paths for every example.', default='', type=str) + parser.add_argument('--output_string', dest='output_string', help='String appended to output snapshots.', default = '', type=str) + parser.add_argument('--alpha', dest='alpha', help='Regression loss coefficient.', + default=0.001, type=float) args = parser.parse_args() return args @@ -51,26 +54,37 @@ b = [] b.append(model.conv1) b.append(model.bn1) + for i in range(len(b)): + for module_name, module in b[i].named_modules(): + if 'bn' in module_name: + module.eval() + for name, param in module.named_parameters(): + yield param + +def get_non_ignored_params(model): + # Generator function that yields params that will be optimized. + b = [] b.append(model.layer1) b.append(model.layer2) b.append(model.layer3) b.append(model.layer4) for i in range(len(b)): - for j in b[i].modules(): - for k in j.parameters(): - yield k + for module_name, module in b[i].named_modules(): + if 'bn' in module_name: + module.eval() + for name, param in module.named_parameters(): + yield param -def get_non_ignored_params(model): - # Generator function that yields params that will be optimized. +def get_fc_params(model): b = [] b.append(model.fc_yaw) b.append(model.fc_pitch) b.append(model.fc_roll) b.append(model.fc_finetune) for i in range(len(b)): - for j in b[i].modules(): - for k in j.parameters(): - yield k + for module_name, module in b[i].named_modules(): + for name, param in module.named_parameters(): + yield param def load_filtered_state_dict(model, snapshot): # By user apaszke from discuss.pytorch.org @@ -104,11 +118,7 @@ print 'Loading data.' - # transformations = transforms.Compose([transforms.Scale(224), - # transforms.RandomCrop(224), - # transforms.ToTensor()]) - - transformations = transforms.Compose([transforms.Scale(250), + transformations = transforms.Compose([transforms.Scale(240), transforms.RandomCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) @@ -120,17 +130,19 @@ num_workers=2) model.cuda(gpu) + softmax = nn.Softmax() criterion = nn.CrossEntropyLoss().cuda() reg_criterion = nn.MSELoss().cuda() # Regression loss coefficient - alpha = 0.01 + alpha = 0.00 idx_tensor = [idx for idx in xrange(66)] - idx_tensor = torch.FloatTensor(idx_tensor).cuda(gpu) + idx_tensor = Variable(torch.FloatTensor(idx_tensor)).cuda(gpu) - optimizer = torch.optim.Adam([{'params': get_ignored_params(model), 'lr': args.lr}, - {'params': get_non_ignored_params(model), 'lr': args.lr * 10}], - lr = args.lr) + optimizer = torch.optim.Adam([{'params': get_ignored_params(model), 'lr': 0}, + {'params': get_non_ignored_params(model), 'lr': args.lr}, + {'params': get_fc_params(model), 'lr': args.lr * 2}], + lr = args.lr) print 'Ready to train network.' @@ -153,24 +165,26 @@ loss_roll = criterion(pre_roll, label_roll) # MSE loss - yaw_predicted = F.softmax(pre_yaw) - pitch_predicted = F.softmax(pre_pitch) - roll_predicted = F.softmax(pre_roll) + yaw_predicted = softmax(pre_yaw) + pitch_predicted = softmax(pre_pitch) + roll_predicted = softmax(pre_roll) - yaw_predicted = torch.sum(yaw_predicted.data * idx_tensor, 1) - pitch_predicted = torch.sum(pitch_predicted.data * idx_tensor, 1) - roll_predicted = torch.sum(roll_predicted.data * idx_tensor, 1) + yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1) + pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1) + roll_predicted = torch.sum(roll_predicted * idx_tensor, 1) loss_reg_yaw = reg_criterion(yaw_predicted, label_yaw.float()) loss_reg_pitch = reg_criterion(pitch_predicted, label_pitch.float()) loss_reg_roll = reg_criterion(roll_predicted, label_roll.float()) + # print yaw_predicted, label_yaw.float(), loss_reg_yaw # Total loss loss_yaw += alpha * loss_reg_yaw loss_pitch += alpha * loss_reg_pitch loss_roll += alpha * loss_reg_roll loss_seq = [loss_yaw, loss_pitch, loss_roll] + # loss_seq = [loss_reg_yaw, loss_reg_pitch, loss_reg_roll] grad_seq = [torch.Tensor(1).cuda(gpu) for _ in range(len(loss_seq))] torch.autograd.backward(loss_seq, grad_seq) optimizer.step() @@ -180,13 +194,13 @@ %(epoch+1, num_epochs, i+1, len(pose_dataset)//batch_size, loss_yaw.data[0], loss_pitch.data[0], loss_roll.data[0])) # if epoch == 0: # torch.save(model.state_dict(), - # 'output/snapshots/hopenet50_epoch_'+ str(i+1) + '.pkl') + # 'output/snapshots/' + args.output_string + '_iter_'+ str(i+1) + '.pkl') # Save models at numbered epochs. if epoch % 1 == 0 and epoch < num_epochs: print 'Taking snapshot...' torch.save(model.state_dict(), - 'output/snapshots/hopenet50_epoch_'+ str(epoch+1) + '.pkl') + 'output/snapshots/' + args.output_string + '_epoch_'+ str(epoch+1) + '.pkl') print 'Second phase of training (finetuning layer).' for epoch in range(num_epochs_ft): @@ -208,9 +222,9 @@ loss_roll = criterion(pre_roll, label_roll) # MSE loss - yaw_predicted = F.softmax(pre_yaw) - pitch_predicted = F.softmax(pre_pitch) - roll_predicted = F.softmax(pre_roll) + yaw_predicted = softmax(pre_yaw) + pitch_predicted = softmax(pre_pitch) + roll_predicted = softmax(pre_roll) yaw_predicted = torch.sum(yaw_predicted.data * idx_tensor, 1) pitch_predicted = torch.sum(pitch_predicted.data * idx_tensor, 1) @@ -238,14 +252,14 @@ %(epoch+1, num_epochs_ft, i+1, len(pose_dataset)//batch_size, loss_yaw.data[0], loss_pitch.data[0], loss_roll.data[0], loss_angles.data[0])) # if epoch == 0: # torch.save(model.state_dict(), - # 'output/snapshots/hopenet50_iter_'+ str(i+1) + '.pkl') + # 'output/snapshots/' + args.output_string + '_iter_'+ str(i+1) + '.pkl') # Save models at numbered epochs. if epoch % 1 == 0 and epoch < num_epochs_ft - 1: print 'Taking snapshot...' torch.save(model.state_dict(), - 'output/snapshots/hopenet50_epoch_'+ str(num_epochs+epoch+1) + '.pkl') + 'output/snapshots/' + args.output_string + '_epoch_'+ str(num_epochs+epoch+1) + '.pkl') # Save the final Trained Model - torch.save(model.state_dict(), 'output/snapshots/hopenet50_epoch_' + str(num_epochs+epoch+1) + '.pkl') + torch.save(model.state_dict(), 'output/snapshots/' + args.output_string + '_epoch_' + str(num_epochs+epoch+1) + '.pkl') -- Gitblit v1.8.0