From 4b67b5c8ed5566ec3030d537536282e830d87e40 Mon Sep 17 00:00:00 2001
From: natanielruiz <nruiz9@gatech.edu>
Date: 星期一, 30 十月 2017 07:15:49 +0800
Subject: [PATCH] next

---
 code/test_preangles.py |   45 ++++++++++++++++++++-------------------------
 1 files changed, 20 insertions(+), 25 deletions(-)

diff --git a/code/test_preangles.py b/code/test_preangles.py
index cfee8d1..d4a9f5f 100644
--- a/code/test_preangles.py
+++ b/code/test_preangles.py
@@ -1,4 +1,9 @@
+import sys, os, argparse
+
 import numpy as np
+import cv2
+import matplotlib.pyplot as plt
+
 import torch
 import torch.nn as nn
 from torch.autograd import Variable
@@ -8,15 +13,7 @@
 import torchvision
 import torch.nn.functional as F
 
-import cv2
-import matplotlib.pyplot as plt
-import sys
-import os
-import argparse
-
-import datasets
-import hopenet
-import utils
+import datasets, hopenet, utils
 
 def parse_args():
     """Parse input arguments."""
@@ -46,12 +43,8 @@
     gpu = args.gpu_id
     snapshot_path = args.snapshot
 
-    # ResNet101 with 3 outputs.
-    # model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], 66)
-    # ResNet50
-    model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66, 0)
-    # ResNet18
-    # model = hopenet.Hopenet(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], 66)
+    # ResNet50 structure
+    model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66)
 
     print 'Loading snapshot.'
     # Load snapshot
@@ -64,18 +57,20 @@
     transforms.CenterCrop(224), transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
 
-    if args.dataset == 'AFLW2000':
-        pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list,
-                                transformations)
+    if args.dataset == 'Pose_300W_LP':
+        pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations)
+    elif args.dataset == 'Pose_300W_LP_random_ds':
+        pose_dataset = datasets.Pose_300W_LP_random_ds(args.data_dir, args.filename_list, transformations)
+    elif args.dataset == 'AFLW2000':
+        pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'AFLW2000_ds':
-        pose_dataset = datasets.AFLW2000_ds(args.data_dir, args.filename_list,
-                                transformations)
+        pose_dataset = datasets.AFLW2000_ds(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'BIWI':
         pose_dataset = datasets.BIWI(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'AFLW':
         pose_dataset = datasets.AFLW(args.data_dir, args.filename_list, transformations)
-    elif args.dataset == 'Pose_300W_LP':
-        pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations)
+    elif args.dataset == 'AFLW_aug':
+        pose_dataset = datasets.AFLW_aug(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'AFW':
         pose_dataset = datasets.AFW(args.data_dir, args.filename_list, transformations)
     else:
@@ -105,11 +100,12 @@
     for i, (images, labels, cont_labels, name) in enumerate(test_loader):
         images = Variable(images).cuda(gpu)
         total += cont_labels.size(0)
+
         label_yaw = cont_labels[:,0].float()
         label_pitch = cont_labels[:,1].float()
         label_roll = cont_labels[:,2].float()
 
-        yaw, pitch, roll, angles = model(images)
+        yaw, pitch, roll = model(images)
 
         # Binned predictions
         _, yaw_bpred = torch.max(yaw.data, 1)
@@ -130,8 +126,7 @@
         pitch_error += torch.sum(torch.abs(pitch_predicted - label_pitch))
         roll_error += torch.sum(torch.abs(roll_predicted - label_roll))
 
-        # Save images with pose cube.
-        # TODO: fix for larger batch size
+        # Save first image in batch with pose cube or axis.
         if args.save_viz:
             name = name[0]
             if args.dataset == 'BIWI':

--
Gitblit v1.8.0