From dfa3664a0f56445b023020a0ddb5eedc2780169a Mon Sep 17 00:00:00 2001
From: natanielruiz <nataniel777@hotmail.com>
Date: 星期三, 13 九月 2017 21:38:25 +0800
Subject: [PATCH] Center crop instead of random crop for testing.

---
 code/test_shape.py              |    2 +-
 code/hopenet.py                 |    2 +-
 code/test_old.py                |    2 +-
 code/test_on_video.py           |    2 +-
 code/batch_testing_preangles.py |    2 +-
 code/batch_testing.py           |    2 +-
 code/test.py                    |    2 +-
 code/test_AFW.py                |    2 +-
 code/test_preangles.py          |    2 +-
 9 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/code/batch_testing.py b/code/batch_testing.py
index 74ad58b..237ea54 100644
--- a/code/batch_testing.py
+++ b/code/batch_testing.py
@@ -62,7 +62,7 @@
     print 'Loading data.'
 
     transformations = transforms.Compose([transforms.Scale(224),
-    transforms.RandomCrop(224), transforms.ToTensor(),
+    transforms.CenterCrop(224), transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
 
     if args.dataset == 'AFLW2000':
diff --git a/code/batch_testing_preangles.py b/code/batch_testing_preangles.py
index bf0d32b..bb2047f 100644
--- a/code/batch_testing_preangles.py
+++ b/code/batch_testing_preangles.py
@@ -61,7 +61,7 @@
     print 'Loading data.'
 
     transformations = transforms.Compose([transforms.Scale(224),
-    transforms.RandomCrop(224), transforms.ToTensor(),
+    transforms.CenterCrop(224), transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
 
     if args.dataset == 'AFLW2000':
diff --git a/code/hopenet.py b/code/hopenet.py
index 81e645c..d122243 100644
--- a/code/hopenet.py
+++ b/code/hopenet.py
@@ -120,7 +120,7 @@
         angles.append(torch.cat([yaw, pitch, roll], 1))
 
         for idx in xrange(self.iter_ref):
-            angles.append(self.fc_finetune(torch.cat((angles[-1], x), 1)))
+            angles.append(self.fc_finetune(torch.cat((angles[idx], x), 1)))
 
         return pre_yaw, pre_pitch, pre_roll, angles
 
diff --git a/code/test.py b/code/test.py
index f2baf63..41db842 100644
--- a/code/test.py
+++ b/code/test.py
@@ -62,7 +62,7 @@
     print 'Loading data.'
 
     transformations = transforms.Compose([transforms.Scale(224),
-    transforms.RandomCrop(224), transforms.ToTensor(),
+    transforms.CenterCrop(224), transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
 
     if args.dataset == 'AFLW2000':
diff --git a/code/test_AFW.py b/code/test_AFW.py
index b8bd0bf..8a1c047 100644
--- a/code/test_AFW.py
+++ b/code/test_AFW.py
@@ -60,7 +60,7 @@
     print 'Loading data.'
 
     transformations = transforms.Compose([transforms.Scale(224),
-    transforms.RandomCrop(224), transforms.ToTensor()])
+    transforms.CenterCrop(224), transforms.ToTensor()])
 
     pose_dataset = datasets.AFW(args.data_dir, args.filename_list,
                                 transformations)
diff --git a/code/test_old.py b/code/test_old.py
index b9be11e..e831e22 100644
--- a/code/test_old.py
+++ b/code/test_old.py
@@ -63,7 +63,7 @@
     # transforms.RandomCrop(224), transforms.ToTensor()])
 
     transformations = transforms.Compose([transforms.Scale(224),
-    transforms.RandomCrop(224), transforms.ToTensor(),
+    transforms.CenterCrop(224), transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
 
     pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list,
diff --git a/code/test_on_video.py b/code/test_on_video.py
index 8d6c5fd..d384b08 100644
--- a/code/test_on_video.py
+++ b/code/test_on_video.py
@@ -60,7 +60,7 @@
     print 'Loading data.'
 
     transformations = transforms.Compose([transforms.Scale(224),
-    transforms.RandomCrop(224), transforms.ToTensor(),
+    transforms.CenterCrop(224), transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
 
     model.cuda(gpu)
diff --git a/code/test_preangles.py b/code/test_preangles.py
index f2039f2..7cf8ebb 100644
--- a/code/test_preangles.py
+++ b/code/test_preangles.py
@@ -64,7 +64,7 @@
     # transforms.RandomCrop(224), transforms.ToTensor()])
 
     transformations = transforms.Compose([transforms.Scale(224),
-    transforms.RandomCrop(224), transforms.ToTensor(),
+    transforms.CenterCrop(224), transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
 
     if args.dataset == 'AFLW2000':
diff --git a/code/test_shape.py b/code/test_shape.py
index aa644d5..a89e5ec 100644
--- a/code/test_shape.py
+++ b/code/test_shape.py
@@ -60,7 +60,7 @@
     print 'Loading data.'
 
     transformations = transforms.Compose([transforms.Scale(224),
-    transforms.RandomCrop(224), transforms.ToTensor()])
+    transforms.CenterCrop(224), transforms.ToTensor()])
 
     pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list,
                                 transformations)

--
Gitblit v1.8.0