| | |
| | | |
| | | return args |
| | | |
| | | def load_filtered_state_dict(model, snapshot): |
| | | # By user apaszke from discuss.pytorch.org |
| | | model_dict = model.state_dict() |
| | | snapshot = {k: v for k, v in snapshot.items() if k in model_dict} |
| | | model_dict.update(snapshot) |
| | | model.load_state_dict(model_dict) |
| | | |
| | | if __name__ == '__main__': |
| | | args = parse_args() |
| | | |
| | |
| | | # Load snapshot |
| | | saved_state_dict = torch.load(snapshot_path) |
| | | model.load_state_dict(saved_state_dict) |
| | | # load_filtered_state_dict(model, saved_state_dict) |
| | | |
| | | print 'Loading data.' |
| | | |
| | |
| | | |
| | | l1loss = torch.nn.L1Loss(size_average=False) |
| | | |
| | | |
| | | |
| | | for i, (images, labels, cont_labels, name) in enumerate(test_loader): |
| | | images = Variable(images).cuda(gpu) |
| | | total += cont_labels.size(0) |
| | |
| | | label_pitch = cont_labels[:,1].float() |
| | | label_roll = cont_labels[:,2].float() |
| | | |
| | | yaw, pitch, roll, angles = model(images) |
| | | yaw, pitch, roll = model(images) |
| | | |
| | | # Binned predictions |
| | | _, yaw_bpred = torch.max(yaw.data, 1) |