From 8e5d7dbfe49d194b7d0b616307663e9a88fbcd88 Mon Sep 17 00:00:00 2001 From: natanielruiz <nataniel777@hotmail.com> Date: 星期五, 15 九月 2017 06:04:49 +0800 Subject: [PATCH] Training AFLW --- code/train.py | 2 -- practice/remove_KEPLER_test_split.ipynb | 8 ++++---- practice/create_filtered_datasets.ipynb | 10 +++++----- code/batch_testing.py | 11 ++++++++--- code/train_preangles.py | 2 -- 5 files changed, 17 insertions(+), 16 deletions(-) diff --git a/code/batch_testing.py b/code/batch_testing.py index 58d3b30..4a593d6 100644 --- a/code/batch_testing.py +++ b/code/batch_testing.py @@ -123,9 +123,14 @@ label_roll = labels[:,2].float() pre_yaw, pre_pitch, pre_roll, angles = model(images) - yaw = angles[-1][:,0].cpu().data - pitch = angles[-1][:,1].cpu().data - roll = angles[-1][:,2].cpu().data + yaw = angles[0][:,0].cpu().data + pitch = angles[0][:,1].cpu().data + roll = angles[0][:,2].cpu().data + + for idx in xrange(1,args.iter_ref+1): + yaw += angles[idx][:,0].cpu().data + pitch += angles[idx][:,1].cpu().data + roll += angles[idx][:,2].cpu().data # Mean absolute error yaw_error += torch.sum(torch.abs(yaw - label_yaw) * 3) diff --git a/code/train.py b/code/train.py index fd0735a..a41edc0 100644 --- a/code/train.py +++ b/code/train.py @@ -251,8 +251,6 @@ loss_pitch += alpha * loss_reg_pitch loss_roll += alpha * loss_reg_roll - loss_yaw *= 1 - # Finetuning loss loss_seq = [loss_yaw, loss_pitch, loss_roll] for idx in xrange(1,len(angles)): diff --git a/code/train_preangles.py b/code/train_preangles.py index 6328ef2..31144f6 100644 --- a/code/train_preangles.py +++ b/code/train_preangles.py @@ -197,8 +197,6 @@ loss_pitch += alpha * loss_reg_pitch loss_roll += alpha * loss_reg_roll - loss_yaw *= 1 - loss_seq = [loss_yaw, loss_pitch, loss_roll] # loss_seq = [loss_reg_yaw, loss_reg_pitch, loss_reg_roll] grad_seq = [torch.Tensor(1).cuda(gpu) for _ in range(len(loss_seq))] diff --git a/practice/create_filtered_datasets.ipynb b/practice/create_filtered_datasets.ipynb index a6210e9..7801b4a 100644 --- a/practice/create_filtered_datasets.ipynb +++ b/practice/create_filtered_datasets.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 6, + "execution_count": 4, "metadata": { "collapsed": true }, @@ -96,7 +96,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 9, "metadata": { "collapsed": true }, @@ -108,7 +108,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 10, "metadata": { "collapsed": false }, @@ -117,7 +117,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "289\n" + "250\n" ] } ], @@ -136,7 +136,7 @@ " yaw = float(annot[1]) * 180 / np.pi\n", " pitch = float(annot[2]) * 180 / np.pi\n", " roll = float(annot[3]) * 180 / np.pi\n", - " if abs(pitch) > 98 or abs(yaw) > 98 or abs(roll) > 98:\n", + " if abs(pitch) > 98.99 or abs(yaw) > 98.99 or abs(roll) > 98.99:\n", " counter += 1\n", " continue\n", " out.write(original_line)\n", diff --git a/practice/remove_KEPLER_test_split.ipynb b/practice/remove_KEPLER_test_split.ipynb index 6898b20..6216b8c 100644 --- a/practice/remove_KEPLER_test_split.ipynb +++ b/practice/remove_KEPLER_test_split.ipynb @@ -23,7 +23,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 1, "metadata": { "collapsed": true }, @@ -35,7 +35,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 2, "metadata": { "collapsed": false }, @@ -60,7 +60,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 4, "metadata": { "collapsed": false }, @@ -69,7 +69,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "943 19537\n" + "956 19872\n" ] } ], -- Gitblit v1.8.0