From 8e5d7dbfe49d194b7d0b616307663e9a88fbcd88 Mon Sep 17 00:00:00 2001
From: natanielruiz <nataniel777@hotmail.com>
Date: 星期五, 15 九月 2017 06:04:49 +0800
Subject: [PATCH] Training AFLW

---
 code/train.py                           |    2 --
 practice/remove_KEPLER_test_split.ipynb |    8 ++++----
 practice/create_filtered_datasets.ipynb |   10 +++++-----
 code/batch_testing.py                   |   11 ++++++++---
 code/train_preangles.py                 |    2 --
 5 files changed, 17 insertions(+), 16 deletions(-)

diff --git a/code/batch_testing.py b/code/batch_testing.py
index 58d3b30..4a593d6 100644
--- a/code/batch_testing.py
+++ b/code/batch_testing.py
@@ -123,9 +123,14 @@
             label_roll = labels[:,2].float()
 
             pre_yaw, pre_pitch, pre_roll, angles = model(images)
-            yaw = angles[-1][:,0].cpu().data
-            pitch = angles[-1][:,1].cpu().data
-            roll = angles[-1][:,2].cpu().data
+            yaw = angles[0][:,0].cpu().data
+            pitch = angles[0][:,1].cpu().data
+            roll = angles[0][:,2].cpu().data
+
+            for idx in xrange(1,args.iter_ref+1):
+                yaw += angles[idx][:,0].cpu().data
+                pitch += angles[idx][:,1].cpu().data
+                roll += angles[idx][:,2].cpu().data
 
             # Mean absolute error
             yaw_error += torch.sum(torch.abs(yaw - label_yaw) * 3)
diff --git a/code/train.py b/code/train.py
index fd0735a..a41edc0 100644
--- a/code/train.py
+++ b/code/train.py
@@ -251,8 +251,6 @@
             loss_pitch += alpha * loss_reg_pitch
             loss_roll += alpha * loss_reg_roll
 
-            loss_yaw *= 1
-
             # Finetuning loss
             loss_seq = [loss_yaw, loss_pitch, loss_roll]
             for idx in xrange(1,len(angles)):
diff --git a/code/train_preangles.py b/code/train_preangles.py
index 6328ef2..31144f6 100644
--- a/code/train_preangles.py
+++ b/code/train_preangles.py
@@ -197,8 +197,6 @@
             loss_pitch += alpha * loss_reg_pitch
             loss_roll += alpha * loss_reg_roll
 
-            loss_yaw *= 1
-
             loss_seq = [loss_yaw, loss_pitch, loss_roll]
             # loss_seq = [loss_reg_yaw, loss_reg_pitch, loss_reg_roll]
             grad_seq = [torch.Tensor(1).cuda(gpu) for _ in range(len(loss_seq))]
diff --git a/practice/create_filtered_datasets.ipynb b/practice/create_filtered_datasets.ipynb
index a6210e9..7801b4a 100644
--- a/practice/create_filtered_datasets.ipynb
+++ b/practice/create_filtered_datasets.ipynb
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 4,
    "metadata": {
     "collapsed": true
    },
@@ -96,7 +96,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": 9,
    "metadata": {
     "collapsed": true
    },
@@ -108,7 +108,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 20,
+   "execution_count": 10,
    "metadata": {
     "collapsed": false
    },
@@ -117,7 +117,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "289\n"
+      "250\n"
      ]
     }
    ],
@@ -136,7 +136,7 @@
     "    yaw = float(annot[1]) * 180 / np.pi\n",
     "    pitch = float(annot[2]) * 180 / np.pi\n",
     "    roll = float(annot[3]) * 180 / np.pi\n",
-    "    if abs(pitch) > 98 or abs(yaw) > 98 or abs(roll) > 98:\n",
+    "    if abs(pitch) > 98.99 or abs(yaw) > 98.99 or abs(roll) > 98.99:\n",
     "        counter += 1\n",
     "        continue\n",
     "    out.write(original_line)\n",
diff --git a/practice/remove_KEPLER_test_split.ipynb b/practice/remove_KEPLER_test_split.ipynb
index 6898b20..6216b8c 100644
--- a/practice/remove_KEPLER_test_split.ipynb
+++ b/practice/remove_KEPLER_test_split.ipynb
@@ -23,7 +23,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 34,
+   "execution_count": 1,
    "metadata": {
     "collapsed": true
    },
@@ -35,7 +35,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 35,
+   "execution_count": 2,
    "metadata": {
     "collapsed": false
    },
@@ -60,7 +60,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 36,
+   "execution_count": 4,
    "metadata": {
     "collapsed": false
    },
@@ -69,7 +69,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "943 19537\n"
+      "956 19872\n"
      ]
     }
    ],

--
Gitblit v1.8.0