From dd62d6fa4a85f18a29de009a972f5599b19ec946 Mon Sep 17 00:00:00 2001
From: natanielruiz <nataniel777@hotmail.com>
Date: 星期四, 14 九月 2017 00:51:53 +0800
Subject: [PATCH] Fixing hopenet

---
 code/train.py |   96 +++++++++++++++++++++++++++---------------------
 1 files changed, 54 insertions(+), 42 deletions(-)

diff --git a/code/train.py b/code/train.py
index 826793d..6e1ae5b 100644
--- a/code/train.py
+++ b/code/train.py
@@ -43,6 +43,11 @@
           default='', type=str)
     parser.add_argument('--filename_list', dest='filename_list', help='Path to text file containing relative paths for every example.',
           default='', type=str)
+    parser.add_argument('--output_string', dest='output_string', help='String appended to output snapshots.', default = '', type=str)
+    parser.add_argument('--alpha', dest='alpha', help='Regression loss coefficient.',
+          default=0.001, type=float)
+    parser.add_argument('--iter_ref', dest='iter_ref', help='Number of iterative refinement passes.',
+          default=1, type=int)
     args = parser.parse_args()
     return args
 
@@ -51,26 +56,37 @@
     b = []
     b.append(model.conv1)
     b.append(model.bn1)
+    for i in range(len(b)):
+        for module_name, module in b[i].named_modules():
+            if 'bn' in module_name:
+                module.eval()
+            for name, param in module.named_parameters():
+                yield param
+
+def get_non_ignored_params(model):
+    # Generator function that yields params that will be optimized.
+    b = []
     b.append(model.layer1)
     b.append(model.layer2)
     b.append(model.layer3)
     b.append(model.layer4)
     for i in range(len(b)):
-        for j in b[i].modules():
-            for k in j.parameters():
-                yield k
+        for module_name, module in b[i].named_modules():
+            if 'bn' in module_name:
+                module.eval()
+            for name, param in module.named_parameters():
+                yield param
 
-def get_non_ignored_params(model):
-    # Generator function that yields params that will be optimized.
+def get_fc_params(model):
     b = []
     b.append(model.fc_yaw)
     b.append(model.fc_pitch)
     b.append(model.fc_roll)
     b.append(model.fc_finetune)
     for i in range(len(b)):
-        for j in b[i].modules():
-            for k in j.parameters():
-                    yield k
+        for module_name, module in b[i].named_modules():
+            for name, param in module.named_parameters():
+                yield param
 
 def load_filtered_state_dict(model, snapshot):
     # By user apaszke from discuss.pytorch.org
@@ -97,18 +113,14 @@
     # ResNet101 with 3 outputs
     # model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], 66)
     # ResNet50
-    model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66)
+    model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66, args.iter_ref)
     # ResNet18
     # model = hopenet.Hopenet(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], 66)
     load_filtered_state_dict(model, model_zoo.load_url(model_urls['resnet50']))
 
     print 'Loading data.'
 
-    # transformations = transforms.Compose([transforms.Scale(224),
-    #                                       transforms.RandomCrop(224),
-    #                                       transforms.ToTensor()])
-
-    transformations = transforms.Compose([transforms.Scale(250),
+    transformations = transforms.Compose([transforms.Scale(240),
     transforms.RandomCrop(224), transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
 
@@ -120,17 +132,19 @@
                                                num_workers=2)
 
     model.cuda(gpu)
+    softmax = nn.Softmax()
     criterion = nn.CrossEntropyLoss().cuda()
     reg_criterion = nn.MSELoss().cuda()
     # Regression loss coefficient
-    alpha = 0.01
+    alpha = args.alpha
 
     idx_tensor = [idx for idx in xrange(66)]
-    idx_tensor = torch.FloatTensor(idx_tensor).cuda(gpu)
+    idx_tensor = Variable(torch.FloatTensor(idx_tensor)).cuda(gpu)
 
-    optimizer = torch.optim.Adam([{'params': get_ignored_params(model), 'lr': args.lr},
-                                  {'params': get_non_ignored_params(model), 'lr': args.lr * 10}],
-                                  lr = args.lr)
+    optimizer = torch.optim.Adam([{'params': get_ignored_params(model), 'lr': 0},
+                                  {'params': get_non_ignored_params(model), 'lr': args.lr},
+                                  {'params': get_fc_params(model), 'lr': args.lr * 2}],
+                                   lr = args.lr)
 
     print 'Ready to train network.'
 
@@ -153,13 +167,13 @@
             loss_roll = criterion(pre_roll, label_roll)
 
             # MSE loss
-            yaw_predicted = F.softmax(pre_yaw)
-            pitch_predicted = F.softmax(pre_pitch)
-            roll_predicted = F.softmax(pre_roll)
+            yaw_predicted = softmax(pre_yaw)
+            pitch_predicted = softmax(pre_pitch)
+            roll_predicted = softmax(pre_roll)
 
-            yaw_predicted = torch.sum(yaw_predicted.data * idx_tensor, 1)
-            pitch_predicted = torch.sum(pitch_predicted.data * idx_tensor, 1)
-            roll_predicted = torch.sum(roll_predicted.data * idx_tensor, 1)
+            yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1)
+            pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1)
+            roll_predicted = torch.sum(roll_predicted * idx_tensor, 1)
 
             loss_reg_yaw = reg_criterion(yaw_predicted, label_yaw.float())
             loss_reg_pitch = reg_criterion(pitch_predicted, label_pitch.float())
@@ -180,13 +194,13 @@
                        %(epoch+1, num_epochs, i+1, len(pose_dataset)//batch_size, loss_yaw.data[0], loss_pitch.data[0], loss_roll.data[0]))
                 # if epoch == 0:
                 #     torch.save(model.state_dict(),
-                #     'output/snapshots/hopenet50_epoch_'+ str(i+1) + '.pkl')
+                #     'output/snapshots/' + args.output_string + '_iter_'+ str(i+1) + '.pkl')
 
         # Save models at numbered epochs.
         if epoch % 1 == 0 and epoch < num_epochs:
             print 'Taking snapshot...'
             torch.save(model.state_dict(),
-            'output/snapshots/hopenet50_epoch_'+ str(epoch+1) + '.pkl')
+            'output/snapshots/' + args.output_string + '_epoch_'+ str(epoch+1) + '.pkl')
 
     print 'Second phase of training (finetuning layer).'
     for epoch in range(num_epochs_ft):
@@ -208,13 +222,13 @@
             loss_roll = criterion(pre_roll, label_roll)
 
             # MSE loss
-            yaw_predicted = F.softmax(pre_yaw)
-            pitch_predicted = F.softmax(pre_pitch)
-            roll_predicted = F.softmax(pre_roll)
+            yaw_predicted = softmax(pre_yaw)
+            pitch_predicted = softmax(pre_pitch)
+            roll_predicted = softmax(pre_roll)
 
-            yaw_predicted = torch.sum(yaw_predicted.data * idx_tensor, 1)
-            pitch_predicted = torch.sum(pitch_predicted.data * idx_tensor, 1)
-            roll_predicted = torch.sum(roll_predicted.data * idx_tensor, 1)
+            yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1)
+            pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1)
+            roll_predicted = torch.sum(roll_predicted * idx_tensor, 1)
 
             loss_reg_yaw = reg_criterion(yaw_predicted, label_yaw.float())
             loss_reg_pitch = reg_criterion(pitch_predicted, label_pitch.float())
@@ -226,9 +240,11 @@
             loss_roll += alpha * loss_reg_roll
 
             # Finetuning loss
-            loss_angles = reg_criterion(angles[0], label_angles.float())
+            loss_seq = [loss_yaw, loss_pitch, loss_roll]
+            for idx in xrange(args.iter_ref+1):
+                loss_angles = reg_criterion(angles[idx], label_angles.float())
+                loss_seq.append(loss_angles)
 
-            loss_seq = [loss_yaw, loss_pitch, loss_roll, loss_angles]
             grad_seq = [torch.Tensor(1).cuda(gpu) for _ in range(len(loss_seq))]
             torch.autograd.backward(loss_seq, grad_seq)
             optimizer.step()
@@ -238,14 +254,10 @@
                        %(epoch+1, num_epochs_ft, i+1, len(pose_dataset)//batch_size, loss_yaw.data[0], loss_pitch.data[0], loss_roll.data[0], loss_angles.data[0]))
                 # if epoch == 0:
                 #     torch.save(model.state_dict(),
-                #     'output/snapshots/hopenet50_iter_'+ str(i+1) + '.pkl')
+                #     'output/snapshots/' + args.output_string + '_iter_'+ str(i+1) + '.pkl')
 
         # Save models at numbered epochs.
-        if epoch % 1 == 0 and epoch < num_epochs_ft - 1:
+        if epoch % 1 == 0 and epoch < num_epochs_ft:
             print 'Taking snapshot...'
             torch.save(model.state_dict(),
-            'output/snapshots/hopenet50_epoch_'+ str(num_epochs+epoch+1) + '.pkl')
-
-
-    # Save the final Trained Model
-    torch.save(model.state_dict(), 'output/snapshots/hopenet50_epoch_' + str(num_epochs+epoch+1) + '.pkl')
+            'output/snapshots/' + args.output_string + '_epoch_'+ str(num_epochs+epoch+1) + '.pkl')

--
Gitblit v1.8.0