| | |
| | | label_roll = labels[:,2].float() |
| | | |
| | | pre_yaw, pre_pitch, pre_roll, angles = model(images) |
| | | yaw = angles[-1][:,0].cpu().data |
| | | pitch = angles[-1][:,1].cpu().data |
| | | roll = angles[-1][:,2].cpu().data |
| | | yaw = angles[0][:,0].cpu().data |
| | | pitch = angles[0][:,1].cpu().data |
| | | roll = angles[0][:,2].cpu().data |
| | | |
| | | for idx in xrange(1,args.iter_ref+1): |
| | | yaw += angles[idx][:,0].cpu().data |
| | | pitch += angles[idx][:,1].cpu().data |
| | | roll += angles[idx][:,2].cpu().data |
| | | |
| | | # Mean absolute error |
| | | yaw_error += torch.sum(torch.abs(yaw - label_yaw) * 3) |
| | |
| | | loss_pitch += alpha * loss_reg_pitch |
| | | loss_roll += alpha * loss_reg_roll |
| | | |
| | | loss_yaw *= 1 |
| | | |
| | | # Finetuning loss |
| | | loss_seq = [loss_yaw, loss_pitch, loss_roll] |
| | | for idx in xrange(1,len(angles)): |
| | |
| | | loss_pitch += alpha * loss_reg_pitch |
| | | loss_roll += alpha * loss_reg_roll |
| | | |
| | | loss_yaw *= 1 |
| | | |
| | | loss_seq = [loss_yaw, loss_pitch, loss_roll] |
| | | # loss_seq = [loss_reg_yaw, loss_reg_pitch, loss_reg_roll] |
| | | grad_seq = [torch.Tensor(1).cuda(gpu) for _ in range(len(loss_seq))] |
| | |
| | | "cells": [ |
| | | { |
| | | "cell_type": "code", |
| | | "execution_count": 6, |
| | | "execution_count": 4, |
| | | "metadata": { |
| | | "collapsed": true |
| | | }, |
| | |
| | | }, |
| | | { |
| | | "cell_type": "code", |
| | | "execution_count": 19, |
| | | "execution_count": 9, |
| | | "metadata": { |
| | | "collapsed": true |
| | | }, |
| | |
| | | }, |
| | | { |
| | | "cell_type": "code", |
| | | "execution_count": 20, |
| | | "execution_count": 10, |
| | | "metadata": { |
| | | "collapsed": false |
| | | }, |
| | |
| | | "name": "stdout", |
| | | "output_type": "stream", |
| | | "text": [ |
| | | "289\n" |
| | | "250\n" |
| | | ] |
| | | } |
| | | ], |
| | |
| | | " yaw = float(annot[1]) * 180 / np.pi\n", |
| | | " pitch = float(annot[2]) * 180 / np.pi\n", |
| | | " roll = float(annot[3]) * 180 / np.pi\n", |
| | | " if abs(pitch) > 98 or abs(yaw) > 98 or abs(roll) > 98:\n", |
| | | " if abs(pitch) > 98.99 or abs(yaw) > 98.99 or abs(roll) > 98.99:\n", |
| | | " counter += 1\n", |
| | | " continue\n", |
| | | " out.write(original_line)\n", |
| | |
| | | }, |
| | | { |
| | | "cell_type": "code", |
| | | "execution_count": 34, |
| | | "execution_count": 1, |
| | | "metadata": { |
| | | "collapsed": true |
| | | }, |
| | |
| | | }, |
| | | { |
| | | "cell_type": "code", |
| | | "execution_count": 35, |
| | | "execution_count": 2, |
| | | "metadata": { |
| | | "collapsed": false |
| | | }, |
| | |
| | | }, |
| | | { |
| | | "cell_type": "code", |
| | | "execution_count": 36, |
| | | "execution_count": 4, |
| | | "metadata": { |
| | | "collapsed": false |
| | | }, |
| | |
| | | "name": "stdout", |
| | | "output_type": "stream", |
| | | "text": [ |
| | | "943 19537\n" |
| | | "956 19872\n" |
| | | ] |
| | | } |
| | | ], |