natanielruiz
2017-09-15 8e5d7dbfe49d194b7d0b616307663e9a88fbcd88
Training AFLW
5个文件已修改
33 ■■■■ 已修改文件
code/batch_testing.py 11 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
code/train.py 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
code/train_preangles.py 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
practice/create_filtered_datasets.ipynb 10 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
practice/remove_KEPLER_test_split.ipynb 8 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
code/batch_testing.py
@@ -123,9 +123,14 @@
            label_roll = labels[:,2].float()
            pre_yaw, pre_pitch, pre_roll, angles = model(images)
            yaw = angles[-1][:,0].cpu().data
            pitch = angles[-1][:,1].cpu().data
            roll = angles[-1][:,2].cpu().data
            yaw = angles[0][:,0].cpu().data
            pitch = angles[0][:,1].cpu().data
            roll = angles[0][:,2].cpu().data
            for idx in xrange(1,args.iter_ref+1):
                yaw += angles[idx][:,0].cpu().data
                pitch += angles[idx][:,1].cpu().data
                roll += angles[idx][:,2].cpu().data
            # Mean absolute error
            yaw_error += torch.sum(torch.abs(yaw - label_yaw) * 3)
code/train.py
@@ -251,8 +251,6 @@
            loss_pitch += alpha * loss_reg_pitch
            loss_roll += alpha * loss_reg_roll
            loss_yaw *= 1
            # Finetuning loss
            loss_seq = [loss_yaw, loss_pitch, loss_roll]
            for idx in xrange(1,len(angles)):
code/train_preangles.py
@@ -197,8 +197,6 @@
            loss_pitch += alpha * loss_reg_pitch
            loss_roll += alpha * loss_reg_roll
            loss_yaw *= 1
            loss_seq = [loss_yaw, loss_pitch, loss_roll]
            # loss_seq = [loss_reg_yaw, loss_reg_pitch, loss_reg_roll]
            grad_seq = [torch.Tensor(1).cuda(gpu) for _ in range(len(loss_seq))]
practice/create_filtered_datasets.ipynb
@@ -2,7 +2,7 @@
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
@@ -96,7 +96,7 @@
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
@@ -108,7 +108,7 @@
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
@@ -117,7 +117,7 @@
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "289\n"
      "250\n"
     ]
    }
   ],
@@ -136,7 +136,7 @@
    "    yaw = float(annot[1]) * 180 / np.pi\n",
    "    pitch = float(annot[2]) * 180 / np.pi\n",
    "    roll = float(annot[3]) * 180 / np.pi\n",
    "    if abs(pitch) > 98 or abs(yaw) > 98 or abs(roll) > 98:\n",
    "    if abs(pitch) > 98.99 or abs(yaw) > 98.99 or abs(roll) > 98.99:\n",
    "        counter += 1\n",
    "        continue\n",
    "    out.write(original_line)\n",
practice/remove_KEPLER_test_split.ipynb
@@ -23,7 +23,7 @@
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
@@ -35,7 +35,7 @@
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
@@ -60,7 +60,7 @@
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
@@ -69,7 +69,7 @@
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "943 19537\n"
      "956 19872\n"
     ]
    }
   ],