From 6664c6d52fad58e396861946a3bed7d5afc4d44d Mon Sep 17 00:00:00 2001
From: natanielruiz <nataniel777@hotmail.com>
Date: 星期五, 07 七月 2017 10:53:52 +0800
Subject: [PATCH] Training for hopenet works.

---
 code/test_resnet_bins.py |   26 +++++++++++++++-----------
 1 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/code/test_resnet_bins.py b/code/test_resnet_bins.py
index 34fc8f5..0a093ee 100644
--- a/code/test_resnet_bins.py
+++ b/code/test_resnet_bins.py
@@ -13,7 +13,7 @@
 import os
 import argparse
 
-from datasets import AFLW2000
+import datasets
 import hopenet
 import utils
 
@@ -55,9 +55,10 @@
 
     print 'Loading data.'
 
-    transformations = transforms.Compose([transforms.Scale(224),transforms.RandomCrop(224), transforms.ToTensor()])
+    transformations = transforms.Compose([transforms.Scale(224),
+    transforms.RandomCrop(224), transforms.ToTensor()])
 
-    pose_dataset = AFLW2000(args.data_dir, args.filename_list,
+    pose_dataset = datasets.AFLW2000_binned(args.data_dir, args.filename_list,
                                 transformations)
     test_loader = torch.utils.data.DataLoader(dataset=pose_dataset,
                                                batch_size=batch_size,
@@ -69,7 +70,9 @@
 
     # Test the Model
     model.eval()  # Change model to 'eval' mode (BN uses moving mean/var).
-    error = .0
+    yaw_correct = 0
+    pitch_correct = 0
+    roll_correct = 0
     total = 0
     for i, (images, labels, name) in enumerate(test_loader):
         images = Variable(images).cuda(gpu)
@@ -78,13 +81,14 @@
         _, predicted = torch.max(outputs.data, 1)
         total += labels.size(0)
         # TODO: There are more efficient ways.
+        yaw_correct += (outputs[:][0] == labels[:][0])
+        pitch_correct += (outputs[:][])
         for idx in xrange(len(outputs)):
-            # if abs(outputs[idx].data[1] - labels[idx].data[1]) * 180 / np.pi > 30:
-            print name
-            print abs(outputs[idx].data - labels[idx].data) * 180 / np.pi, 180 * outputs[idx].data / np.pi, labels[idx].data * 180 / np.pi
-            # error += utils.mse_loss(outputs[idx], labels[idx])
-            error += abs(outputs[idx].data - labels[idx].data) * 180 / np.pi
+            yaw_correct += (outputs[idx].data[0] == labels[idx].data[0])
+            pitch_correct += (outputs[idx].data[1] == labels[idx].data[1])
+            roll_correct += (outputs[idx].data[2] == labels[idx].data[2])
 
 
-    print('Test MSE error of the model on the ' + str(total) +
-    ' test images: %.4f' % (error / total))
+    print('Test accuracies of the model on the ' + str(total) +
+    ' test images. Yaw: %.4f %%, Pitch: %.4f %%, Roll: %.4f %%' % (yaw_correct / total,
+    pitch_correct / total, roll_correct / total))

--
Gitblit v1.8.0