code/test_AFLW.py @@ -97,7 +97,7 @@ label_pitch = labels[:,1].float() label_roll = labels[:,2].float() yaw, pitch, roll = model(images) yaw, pitch, roll, angles = model(images) # Binned predictions _, yaw_bpred = torch.max(yaw.data, 1)