From 6f71fb102f509d705d3abaa1f44638a19f57e92e Mon Sep 17 00:00:00 2001
From: natanielruiz <nataniel777@hotmail.com>
Date: 星期一, 07 八月 2017 05:15:52 +0800
Subject: [PATCH] next

---
 code/train_resnet_bins.py |   12 +++++++++---
 1 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/code/train_resnet_bins.py b/code/train_resnet_bins.py
index f33ffd6..6b07747 100644
--- a/code/train_resnet_bins.py
+++ b/code/train_resnet_bins.py
@@ -109,7 +109,13 @@
 
     model.cuda(gpu)
     criterion = nn.CrossEntropyLoss()
-    optimizer = torch.optim.Adam([{'params': get_ignored_params(model), 'lr': args.lr},
+    # optimizer = torch.optim.Adam([{'params': get_ignored_params(model), 'lr': args.lr},
+    #                               {'params': get_non_ignored_params(model), 'lr': args.lr * 10}],
+    #                               lr = args.lr)
+    # optimizer = torch.optim.SGD([{'params': get_ignored_params(model), 'lr': args.lr},
+    #                               {'params': get_non_ignored_params(model), 'lr': args.lr}],
+    #                               lr = args.lr, momentum=0.9)
+    optimizer = torch.optim.RMSprop([{'params': get_ignored_params(model), 'lr': args.lr},
                                   {'params': get_non_ignored_params(model), 'lr': args.lr * 10}],
                                   lr = args.lr)
 
@@ -141,7 +147,7 @@
         if epoch % 1 == 0 and epoch < num_epochs - 1:
             print 'Taking snapshot...'
             torch.save(model.state_dict(),
-            'output/snapshots/resnet50_binned_epoch_' + str(epoch+1) + '.pkl')
+            'output/snapshots/resnet50_binned_RMSprop_epoch_' + str(epoch+1) + '.pkl')
 
     # Save the final Trained Model
-    torch.save(model.state_dict(), 'output/snapshots/resnet50_binned_epoch_' + str(epoch+1) + '.pkl')
+    torch.save(model.state_dict(), 'output/snapshots/resnet50_binned_RMSprop_epoch_' + str(epoch+1) + '.pkl')

--
Gitblit v1.8.0