From 3c7129a5434b086a3253a28479d8fa43b32823a9 Mon Sep 17 00:00:00 2001
From: natanielruiz <nruiz9@gatech.edu>
Date: 星期一, 30 十月 2017 06:19:01 +0800
Subject: [PATCH] next

---
 code/test_preangles.py |   56 +++++++++++++++++++++-----------------------------------
 1 files changed, 21 insertions(+), 35 deletions(-)

diff --git a/code/test_preangles.py b/code/test_preangles.py
index cfee8d1..be9bfda 100644
--- a/code/test_preangles.py
+++ b/code/test_preangles.py
@@ -1,4 +1,9 @@
+import sys, os, argparse
+
 import numpy as np
+import cv2
+import matplotlib.pyplot as plt
+
 import torch
 import torch.nn as nn
 from torch.autograd import Variable
@@ -8,15 +13,7 @@
 import torchvision
 import torch.nn.functional as F
 
-import cv2
-import matplotlib.pyplot as plt
-import sys
-import os
-import argparse
-
-import datasets
-import hopenet
-import utils
+import datasets, hopenet, utils
 
 def parse_args():
     """Parse input arguments."""
@@ -46,12 +43,8 @@
     gpu = args.gpu_id
     snapshot_path = args.snapshot
 
-    # ResNet101 with 3 outputs.
-    # model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], 66)
-    # ResNet50
-    model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66, 0)
-    # ResNet18
-    # model = hopenet.Hopenet(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], 66)
+    # ResNet50 structure
+    model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66)
 
     print 'Loading snapshot.'
     # Load snapshot
@@ -64,18 +57,18 @@
     transforms.CenterCrop(224), transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
 
-    if args.dataset == 'AFLW2000':
-        pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list,
-                                transformations)
-    elif args.dataset == 'AFLW2000_ds':
-        pose_dataset = datasets.AFLW2000_ds(args.data_dir, args.filename_list,
-                                transformations)
+    if args.dataset == 'Pose_300W_LP':
+        pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations)
+    elif args.dataset == 'Pose_300W_LP_random_ds':
+        pose_dataset = datasets.Pose_300W_LP_random_ds(args.data_dir, args.filename_list, transformations)
+    elif args.dataset == 'AFLW2000':
+        pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'BIWI':
         pose_dataset = datasets.BIWI(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'AFLW':
         pose_dataset = datasets.AFLW(args.data_dir, args.filename_list, transformations)
-    elif args.dataset == 'Pose_300W_LP':
-        pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations)
+    elif args.dataset == 'AFLW_aug':
+        pose_dataset = datasets.AFLW_aug(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'AFW':
         pose_dataset = datasets.AFW(args.data_dir, args.filename_list, transformations)
     else:
@@ -93,9 +86,6 @@
     model.eval()  # Change model to 'eval' mode (BN uses moving mean/var).
     total = 0
 
-    idx_tensor = [idx for idx in xrange(66)]
-    idx_tensor = torch.FloatTensor(idx_tensor).cuda(gpu)
-
     yaw_error = .0
     pitch_error = .0
     roll_error = .0
@@ -105,6 +95,7 @@
     for i, (images, labels, cont_labels, name) in enumerate(test_loader):
         images = Variable(images).cuda(gpu)
         total += cont_labels.size(0)
+
         label_yaw = cont_labels[:,0].float()
         label_pitch = cont_labels[:,1].float()
         label_roll = cont_labels[:,2].float()
@@ -117,21 +108,16 @@
         _, roll_bpred = torch.max(roll.data, 1)
 
         # Continuous predictions
-        yaw_predicted = utils.softmax_temperature(yaw.data, 1)
-        pitch_predicted = utils.softmax_temperature(pitch.data, 1)
-        roll_predicted = utils.softmax_temperature(roll.data, 1)
-
-        yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1).cpu() * 3 - 99
-        pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1).cpu() * 3 - 99
-        roll_predicted = torch.sum(roll_predicted * idx_tensor, 1).cpu() * 3 - 99
+        yaw_predicted = angles[:,0].data.cpu()
+        pitch_predicted = angles[:,1].data.cpu()
+        roll_predicted = angles[:,2].data.cpu()
 
         # Mean absolute error
         yaw_error += torch.sum(torch.abs(yaw_predicted - label_yaw))
         pitch_error += torch.sum(torch.abs(pitch_predicted - label_pitch))
         roll_error += torch.sum(torch.abs(roll_predicted - label_roll))
 
-        # Save images with pose cube.
-        # TODO: fix for larger batch size
+        # Save first image in batch with pose cube or axis.
         if args.save_viz:
             name = name[0]
             if args.dataset == 'BIWI':

--
Gitblit v1.8.0