From dfa3664a0f56445b023020a0ddb5eedc2780169a Mon Sep 17 00:00:00 2001
From: natanielruiz <nataniel777@hotmail.com>
Date: 星期三, 13 九月 2017 21:38:25 +0800
Subject: [PATCH] Center crop instead of random crop for testing.
---
code/test_shape.py | 2 +-
code/hopenet.py | 2 +-
code/test_old.py | 2 +-
code/test_on_video.py | 2 +-
code/batch_testing_preangles.py | 2 +-
code/batch_testing.py | 2 +-
code/test.py | 2 +-
code/test_AFW.py | 2 +-
code/test_preangles.py | 2 +-
9 files changed, 9 insertions(+), 9 deletions(-)
diff --git a/code/batch_testing.py b/code/batch_testing.py
index 74ad58b..237ea54 100644
--- a/code/batch_testing.py
+++ b/code/batch_testing.py
@@ -62,7 +62,7 @@
print 'Loading data.'
transformations = transforms.Compose([transforms.Scale(224),
- transforms.RandomCrop(224), transforms.ToTensor(),
+ transforms.CenterCrop(224), transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
if args.dataset == 'AFLW2000':
diff --git a/code/batch_testing_preangles.py b/code/batch_testing_preangles.py
index bf0d32b..bb2047f 100644
--- a/code/batch_testing_preangles.py
+++ b/code/batch_testing_preangles.py
@@ -61,7 +61,7 @@
print 'Loading data.'
transformations = transforms.Compose([transforms.Scale(224),
- transforms.RandomCrop(224), transforms.ToTensor(),
+ transforms.CenterCrop(224), transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
if args.dataset == 'AFLW2000':
diff --git a/code/hopenet.py b/code/hopenet.py
index 81e645c..d122243 100644
--- a/code/hopenet.py
+++ b/code/hopenet.py
@@ -120,7 +120,7 @@
angles.append(torch.cat([yaw, pitch, roll], 1))
for idx in xrange(self.iter_ref):
- angles.append(self.fc_finetune(torch.cat((angles[-1], x), 1)))
+ angles.append(self.fc_finetune(torch.cat((angles[idx], x), 1)))
return pre_yaw, pre_pitch, pre_roll, angles
diff --git a/code/test.py b/code/test.py
index f2baf63..41db842 100644
--- a/code/test.py
+++ b/code/test.py
@@ -62,7 +62,7 @@
print 'Loading data.'
transformations = transforms.Compose([transforms.Scale(224),
- transforms.RandomCrop(224), transforms.ToTensor(),
+ transforms.CenterCrop(224), transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
if args.dataset == 'AFLW2000':
diff --git a/code/test_AFW.py b/code/test_AFW.py
index b8bd0bf..8a1c047 100644
--- a/code/test_AFW.py
+++ b/code/test_AFW.py
@@ -60,7 +60,7 @@
print 'Loading data.'
transformations = transforms.Compose([transforms.Scale(224),
- transforms.RandomCrop(224), transforms.ToTensor()])
+ transforms.CenterCrop(224), transforms.ToTensor()])
pose_dataset = datasets.AFW(args.data_dir, args.filename_list,
transformations)
diff --git a/code/test_old.py b/code/test_old.py
index b9be11e..e831e22 100644
--- a/code/test_old.py
+++ b/code/test_old.py
@@ -63,7 +63,7 @@
# transforms.RandomCrop(224), transforms.ToTensor()])
transformations = transforms.Compose([transforms.Scale(224),
- transforms.RandomCrop(224), transforms.ToTensor(),
+ transforms.CenterCrop(224), transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list,
diff --git a/code/test_on_video.py b/code/test_on_video.py
index 8d6c5fd..d384b08 100644
--- a/code/test_on_video.py
+++ b/code/test_on_video.py
@@ -60,7 +60,7 @@
print 'Loading data.'
transformations = transforms.Compose([transforms.Scale(224),
- transforms.RandomCrop(224), transforms.ToTensor(),
+ transforms.CenterCrop(224), transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
model.cuda(gpu)
diff --git a/code/test_preangles.py b/code/test_preangles.py
index f2039f2..7cf8ebb 100644
--- a/code/test_preangles.py
+++ b/code/test_preangles.py
@@ -64,7 +64,7 @@
# transforms.RandomCrop(224), transforms.ToTensor()])
transformations = transforms.Compose([transforms.Scale(224),
- transforms.RandomCrop(224), transforms.ToTensor(),
+ transforms.CenterCrop(224), transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
if args.dataset == 'AFLW2000':
diff --git a/code/test_shape.py b/code/test_shape.py
index aa644d5..a89e5ec 100644
--- a/code/test_shape.py
+++ b/code/test_shape.py
@@ -60,7 +60,7 @@
print 'Loading data.'
transformations = transforms.Compose([transforms.Scale(224),
- transforms.RandomCrop(224), transforms.ToTensor()])
+ transforms.CenterCrop(224), transforms.ToTensor()])
pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list,
transformations)
--
Gitblit v1.8.0