chenshijun
2019-06-05 f111cb002b9c6065fdf6bb274ce5857a9e875e8c
code/test_alexnet.py
@@ -36,13 +36,6 @@
    return args
def load_filtered_state_dict(model, snapshot):
    # By user apaszke from discuss.pytorch.org
    model_dict = model.state_dict()
    snapshot = {k: v for k, v in snapshot.items() if k in model_dict}
    model_dict.update(snapshot)
    model.load_state_dict(model_dict)
if __name__ == '__main__':
    args = parse_args()
@@ -52,12 +45,12 @@
    model = hopenet.AlexNet(66)
    print 'Loading snapshot.'
    print('Loading snapshot.')
    # Load snapshot
    saved_state_dict = torch.load(snapshot_path)
    load_filtered_state_dict(model, saved_state_dict)
    model.load_state_dict(saved_state_dict)
    print 'Loading data.'
    print('Loading data.')
    transformations = transforms.Compose([transforms.Scale(224),
    transforms.CenterCrop(224), transforms.ToTensor(),
@@ -80,7 +73,7 @@
    elif args.dataset == 'AFW':
        pose_dataset = datasets.AFW(args.data_dir, args.filename_list, transformations)
    else:
        print 'Error: not a valid dataset name'
        print('Error: not a valid dataset name')
        sys.exit()
    test_loader = torch.utils.data.DataLoader(dataset=pose_dataset,
                                               batch_size=args.batch_size,
@@ -88,7 +81,7 @@
    model.cuda(gpu)
    print 'Ready to test network.'
    print ('Ready to test network.')
    # Test the Model
    model.eval()  # Change model to 'eval' mode (BN uses moving mean/var).