| | |
| | | |
| | | return args |
| | | |
def load_filtered_state_dict(model, snapshot):
    """Load into ``model`` only the ``snapshot`` entries whose keys it already has.

    The model's current state dict is taken as the base. Every entry of
    ``snapshot`` whose key is also present in that base overwrites it. Entries
    whose keys the model does not know are discarded. Model entries that
    ``snapshot`` does not provide keep their current values. The merged mapping
    is then handed to ``model.load_state_dict``.

    Adapted from a post by user apaszke on discuss.pytorch.org.

    Args:
        model: object exposing ``state_dict()`` and ``load_state_dict(mapping)``
            (presumably a ``torch.nn.Module``).
        snapshot: mapping of state-dict keys to values, e.g. the result of
            ``torch.load`` on a checkpoint.
    """
    merged = model.state_dict()
    # Only keys already in ``merged`` are assigned, so the membership test is
    # unaffected by the assignments and the base key order is preserved.
    # NOTE(review): keys are matched by name only; a value whose shape differs
    # from the model's is still passed through to load_state_dict.
    for key, value in snapshot.items():
        if key in merged:
            merged[key] = value
    model.load_state_dict(merged)
| | | |
| | | if __name__ == '__main__': |
| | | args = parse_args() |
| | | |
| | |
| | | # Load snapshot |
| | | saved_state_dict = torch.load(snapshot_path) |
| | | model.load_state_dict(saved_state_dict) |
| | | # load_filtered_state_dict(model, saved_state_dict) |
| | | |
| | | print 'Loading data.' |
| | | |
| | |
| | | pose_dataset = datasets.Pose_300W_LP_random_ds(args.data_dir, args.filename_list, transformations) |
| | | elif args.dataset == 'AFLW2000': |
| | | pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list, transformations) |
| | | elif args.dataset == 'AFLW2000_ds': |
| | | pose_dataset = datasets.AFLW2000_ds(args.data_dir, args.filename_list, transformations) |
| | | elif args.dataset == 'BIWI': |
| | | pose_dataset = datasets.BIWI(args.data_dir, args.filename_list, transformations) |
| | | elif args.dataset == 'AFLW': |
| | |
| | | _, roll_bpred = torch.max(roll.data, 1) |
| | | |
| | | # Continuous predictions |
| | | yaw_predicted = angles[:,0].data.cpu() |
| | | pitch_predicted = angles[:,1].data.cpu() |
| | | roll_predicted = angles[:,2].data.cpu() |
| | | yaw_predicted = utils.softmax_temperature(yaw.data, 1) |
| | | pitch_predicted = utils.softmax_temperature(pitch.data, 1) |
| | | roll_predicted = utils.softmax_temperature(roll.data, 1) |
| | | |
| | | yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1).cpu() * 3 - 99 |
| | | pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1).cpu() * 3 - 99 |
| | | roll_predicted = torch.sum(roll_predicted * idx_tensor, 1).cpu() * 3 - 99 |
| | | |
| | | # Mean absolute error |
| | | yaw_error += torch.sum(torch.abs(yaw_predicted - label_yaw)) |