From 4b67b5c8ed5566ec3030d537536282e830d87e40 Mon Sep 17 00:00:00 2001
From: natanielruiz <nruiz9@gatech.edu>
Date: 星期一, 30 十月 2017 07:15:49 +0800
Subject: [PATCH] next

---
 code/test_preangles.py |   10 +++++++---
 1 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/code/test_preangles.py b/code/test_preangles.py
index 3d70bb0..d4a9f5f 100644
--- a/code/test_preangles.py
+++ b/code/test_preangles.py
@@ -63,6 +63,8 @@
         pose_dataset = datasets.Pose_300W_LP_random_ds(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'AFLW2000':
         pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list, transformations)
+    elif args.dataset == 'AFLW2000_ds':
+        pose_dataset = datasets.AFLW2000_ds(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'BIWI':
         pose_dataset = datasets.BIWI(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'AFLW':
@@ -86,6 +88,9 @@
     model.eval()  # Change model to 'eval' mode (BN uses moving mean/var).
     total = 0
 
+    idx_tensor = [idx for idx in xrange(66)]
+    idx_tensor = torch.FloatTensor(idx_tensor).cuda(gpu)
+
     yaw_error = .0
     pitch_error = .0
     roll_error = .0
@@ -100,7 +105,7 @@
         label_pitch = cont_labels[:,1].float()
         label_roll = cont_labels[:,2].float()
 
-        yaw, pitch, roll, angles = model(images)
+        yaw, pitch, roll = model(images)
 
         # Binned predictions
         _, yaw_bpred = torch.max(yaw.data, 1)
@@ -121,8 +126,7 @@
         pitch_error += torch.sum(torch.abs(pitch_predicted - label_pitch))
         roll_error += torch.sum(torch.abs(roll_predicted - label_roll))
 
-        # Save images with pose cube.
-        # TODO: fix for larger batch size
+        # Save first image in batch with pose cube or axis.
         if args.save_viz:
             name = name[0]
             if args.dataset == 'BIWI':

--
Gitblit v1.8.0