From 6f71fb102f509d705d3abaa1f44638a19f57e92e Mon Sep 17 00:00:00 2001 From: natanielruiz <nataniel777@hotmail.com> Date: 星期一, 07 八月 2017 05:15:52 +0800 Subject: [PATCH] next --- code/train_resnet_bins.py | 12 +++++++++--- 1 files changed, 9 insertions(+), 3 deletions(-) diff --git a/code/train_resnet_bins.py b/code/train_resnet_bins.py index f33ffd6..6b07747 100644 --- a/code/train_resnet_bins.py +++ b/code/train_resnet_bins.py @@ -109,7 +109,13 @@ model.cuda(gpu) criterion = nn.CrossEntropyLoss() - optimizer = torch.optim.Adam([{'params': get_ignored_params(model), 'lr': args.lr}, + # optimizer = torch.optim.Adam([{'params': get_ignored_params(model), 'lr': args.lr}, + # {'params': get_non_ignored_params(model), 'lr': args.lr * 10}], + # lr = args.lr) + # optimizer = torch.optim.SGD([{'params': get_ignored_params(model), 'lr': args.lr}, + # {'params': get_non_ignored_params(model), 'lr': args.lr}], + # lr = args.lr, momentum=0.9) + optimizer = torch.optim.RMSprop([{'params': get_ignored_params(model), 'lr': args.lr}, {'params': get_non_ignored_params(model), 'lr': args.lr * 10}], lr = args.lr) @@ -141,7 +147,7 @@ if epoch % 1 == 0 and epoch < num_epochs - 1: print 'Taking snapshot...' torch.save(model.state_dict(), - 'output/snapshots/resnet50_binned_epoch_' + str(epoch+1) + '.pkl') + 'output/snapshots/resnet50_binned_RMSprop_epoch_' + str(epoch+1) + '.pkl') # Save the final Trained Model - torch.save(model.state_dict(), 'output/snapshots/resnet50_binned_epoch_' + str(epoch+1) + '.pkl') + torch.save(model.state_dict(), 'output/snapshots/resnet50_binned_RMSprop_epoch_' + str(epoch+1) + '.pkl') -- Gitblit v1.8.0