From 92ed4cb2ea68be44b1ff153e00410c2082ee62df Mon Sep 17 00:00:00 2001
From: natanielruiz <nataniel777@hotmail.com>
Date: 星期二, 15 八月 2017 00:28:28 +0800
Subject: [PATCH] New experiments with hourglass

---
 code/train_shape.py |    5 ++---
 1 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/code/train_shape.py b/code/train_shape.py
index f6baddf..fcebadb 100644
--- a/code/train_shape.py
+++ b/code/train_shape.py
@@ -116,7 +116,7 @@
     transformations = transforms.Compose([transforms.Scale(224),transforms.RandomCrop(224),
                                           transforms.ToTensor()])
 
-    pose_dataset = datasets.Pose_300W_LP_binned(args.data_dir, args.filename_list,
+    pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list,
                                 transformations)
     train_loader = torch.utils.data.DataLoader(dataset=pose_dataset,
                                                batch_size=batch_size,
@@ -128,13 +128,12 @@
     reg_criterion = nn.MSELoss().cuda(gpu)
     # Regression loss coefficient
     alpha = 0.1
-    lsm = nn.Softmax()
 
     idx_tensor = [idx for idx in xrange(66)]
     idx_tensor = torch.FloatTensor(idx_tensor).cuda(gpu)
 
     optimizer = torch.optim.Adam([{'params': get_ignored_params(model), 'lr': args.lr},
-                                  {'params': get_non_ignored_params(model), 'lr': args.lr}],
+                                  {'params': get_non_ignored_params(model), 'lr': args.lr * 10}],
                                   lr = args.lr)
 
     print 'Ready to train network.'

--
Gitblit v1.8.0