From f111cb002b9c6065fdf6bb274ce5857a9e875e8c Mon Sep 17 00:00:00 2001
From: chenshijun <csj_sky@126.com>
Date: 星期三, 05 六月 2019 15:38:49 +0800
Subject: [PATCH] face rectangle

---
 code/test_alexnet.py |   45 ++++++++++++++++++++++-----------------------
 1 files changed, 22 insertions(+), 23 deletions(-)

diff --git a/code/test_alexnet.py b/code/test_alexnet.py
index d9cc0a3..45ad25f 100644
--- a/code/test_alexnet.py
+++ b/code/test_alexnet.py
@@ -1,4 +1,9 @@
+import sys, os, argparse
+
 import numpy as np
+import cv2
+import matplotlib.pyplot as plt
+
 import torch
 import torch.nn as nn
 from torch.autograd import Variable
@@ -8,15 +13,7 @@
 import torchvision
 import torch.nn.functional as F
 
-import cv2
-import matplotlib.pyplot as plt
-import sys
-import os
-import argparse
-
-import datasets
-import hopenet
-import utils
+import datasets, hopenet, utils
 
 def parse_args():
     """Parse input arguments."""
@@ -48,33 +45,35 @@
 
     model = hopenet.AlexNet(66)
 
-    print 'Loading snapshot.'
+    print('Loading snapshot.')
     # Load snapshot
     saved_state_dict = torch.load(snapshot_path)
     model.load_state_dict(saved_state_dict)
 
-    print 'Loading data.'
+    print('Loading data.')
 
     transformations = transforms.Compose([transforms.Scale(224),
     transforms.CenterCrop(224), transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
 
-    if args.dataset == 'AFLW2000':
-        pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list,
-                                transformations)
+    if args.dataset == 'Pose_300W_LP':
+        pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations)
+    elif args.dataset == 'Pose_300W_LP_random_ds':
+        pose_dataset = datasets.Pose_300W_LP_random_ds(args.data_dir, args.filename_list, transformations)
+    elif args.dataset == 'AFLW2000':
+        pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'AFLW2000_ds':
-        pose_dataset = datasets.AFLW2000_ds(args.data_dir, args.filename_list,
-                                transformations)
+        pose_dataset = datasets.AFLW2000_ds(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'BIWI':
         pose_dataset = datasets.BIWI(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'AFLW':
         pose_dataset = datasets.AFLW(args.data_dir, args.filename_list, transformations)
-    elif args.dataset == 'Pose_300W_LP':
-        pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations)
+    elif args.dataset == 'AFLW_aug':
+        pose_dataset = datasets.AFLW_aug(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'AFW':
         pose_dataset = datasets.AFW(args.data_dir, args.filename_list, transformations)
     else:
-        print 'Error: not a valid dataset name'
+        print('Error: not a valid dataset name')
         sys.exit()
     test_loader = torch.utils.data.DataLoader(dataset=pose_dataset,
                                                batch_size=args.batch_size,
@@ -82,7 +81,7 @@
 
     model.cuda(gpu)
 
-    print 'Ready to test network.'
+    print ('Ready to test network.')
 
     # Test the Model
     model.eval()  # Change model to 'eval' mode (BN uses moving mean/var).
@@ -125,8 +124,7 @@
         pitch_error += torch.sum(torch.abs(pitch_predicted - label_pitch))
         roll_error += torch.sum(torch.abs(roll_predicted - label_roll))
 
-        # Save images with pose cube.
-        # TODO: fix for larger batch size
+        # Save first image in batch with pose cube or axis.
         if args.save_viz:
             name = name[0]
             if args.dataset == 'BIWI':
@@ -136,7 +134,8 @@
             if args.batch_size == 1:
                 error_string = 'y %.2f, p %.2f, r %.2f' % (torch.sum(torch.abs(yaw_predicted - label_yaw)), torch.sum(torch.abs(pitch_predicted - label_pitch)), torch.sum(torch.abs(roll_predicted - label_roll)))
                 cv2.putText(cv2_img, error_string, (30, cv2_img.shape[0]- 30), fontFace=1, fontScale=1, color=(0,0,255), thickness=1)
-            utils.plot_pose_cube(cv2_img, yaw_predicted[0], pitch_predicted[0], roll_predicted[0])
+            # utils.plot_pose_cube(cv2_img, yaw_predicted[0], pitch_predicted[0], roll_predicted[0], size=100)
+            utils.draw_axis(cv2_img, yaw_predicted[0], pitch_predicted[0], roll_predicted[0], tdx = 200, tdy= 200, size=100)
             cv2.imwrite(os.path.join('output/images', name + '.jpg'), cv2_img)
 
     print('Test error in degrees of the model on the ' + str(total) +

--
Gitblit v1.8.0