From f111cb002b9c6065fdf6bb274ce5857a9e875e8c Mon Sep 17 00:00:00 2001 From: chenshijun <csj_sky@126.com> Date: 星期三, 05 六月 2019 15:38:49 +0800 Subject: [PATCH] face rectangle --- code/test_alexnet.py | 17 +++++------------ 1 files changed, 5 insertions(+), 12 deletions(-) diff --git a/code/test_alexnet.py b/code/test_alexnet.py index 529d566..45ad25f 100644 --- a/code/test_alexnet.py +++ b/code/test_alexnet.py @@ -36,13 +36,6 @@ return args -def load_filtered_state_dict(model, snapshot): - # By user apaszke from discuss.pytorch.org - model_dict = model.state_dict() - snapshot = {k: v for k, v in snapshot.items() if k in model_dict} - model_dict.update(snapshot) - model.load_state_dict(model_dict) - if __name__ == '__main__': args = parse_args() @@ -52,12 +45,12 @@ model = hopenet.AlexNet(66) - print 'Loading snapshot.' + print('Loading snapshot.') # Load snapshot saved_state_dict = torch.load(snapshot_path) - load_filtered_state_dict(model, saved_state_dict) + model.load_state_dict(saved_state_dict) - print 'Loading data.' + print('Loading data.') transformations = transforms.Compose([transforms.Scale(224), transforms.CenterCrop(224), transforms.ToTensor(), @@ -80,7 +73,7 @@ elif args.dataset == 'AFW': pose_dataset = datasets.AFW(args.data_dir, args.filename_list, transformations) else: - print 'Error: not a valid dataset name' + print('Error: not a valid dataset name') sys.exit() test_loader = torch.utils.data.DataLoader(dataset=pose_dataset, batch_size=args.batch_size, @@ -88,7 +81,7 @@ model.cuda(gpu) - print 'Ready to test network.' + print ('Ready to test network.') # Test the Model model.eval() # Change model to 'eval' mode (BN uses moving mean/var). -- Gitblit v1.8.0