From 2c764e41e2fde6244b87da58d12c40d09a14fcb4 Mon Sep 17 00:00:00 2001
From: natanielruiz <nruiz9@gatech.edu>
Date: 星期一, 30 十月 2017 06:49:01 +0800
Subject: [PATCH] Next

---
 code/test_alexnet.py |   21 +++++++++------------
 1 files changed, 9 insertions(+), 12 deletions(-)

diff --git a/code/test_alexnet.py b/code/test_alexnet.py
index 7a3989a..529d566 100644
--- a/code/test_alexnet.py
+++ b/code/test_alexnet.py
@@ -1,4 +1,9 @@
+import sys, os, argparse
+
 import numpy as np
+import cv2
+import matplotlib.pyplot as plt
+
 import torch
 import torch.nn as nn
 from torch.autograd import Variable
@@ -8,15 +13,7 @@
 import torchvision
 import torch.nn.functional as F
 
-import cv2
-import matplotlib.pyplot as plt
-import sys
-import os
-import argparse
-
-import datasets
-import hopenet
-import utils
+import datasets, hopenet, utils
 
 def parse_args():
     """Parse input arguments."""
@@ -134,8 +131,7 @@
         pitch_error += torch.sum(torch.abs(pitch_predicted - label_pitch))
         roll_error += torch.sum(torch.abs(roll_predicted - label_roll))
 
-        # Save images with pose cube.
-        # TODO: fix for larger batch size
+        # Save first image in batch with pose cube or axis.
         if args.save_viz:
             name = name[0]
             if args.dataset == 'BIWI':
@@ -145,7 +141,8 @@
             if args.batch_size == 1:
                 error_string = 'y %.2f, p %.2f, r %.2f' % (torch.sum(torch.abs(yaw_predicted - label_yaw)), torch.sum(torch.abs(pitch_predicted - label_pitch)), torch.sum(torch.abs(roll_predicted - label_roll)))
                 cv2.putText(cv2_img, error_string, (30, cv2_img.shape[0]- 30), fontFace=1, fontScale=1, color=(0,0,255), thickness=1)
-            utils.plot_pose_cube(cv2_img, yaw_predicted[0], pitch_predicted[0], roll_predicted[0])
+            # utils.plot_pose_cube(cv2_img, yaw_predicted[0], pitch_predicted[0], roll_predicted[0], size=100)
+            utils.draw_axis(cv2_img, yaw_predicted[0], pitch_predicted[0], roll_predicted[0], tdx = 200, tdy= 200, size=100)
             cv2.imwrite(os.path.join('output/images', name + '.jpg'), cv2_img)
 
     print('Test error in degrees of the model on the ' + str(total) +

--
Gitblit v1.8.0