From dfa3664a0f56445b023020a0ddb5eedc2780169a Mon Sep 17 00:00:00 2001 From: natanielruiz <nataniel777@hotmail.com> Date: 星期三, 13 九月 2017 21:38:25 +0800 Subject: [PATCH] Center crop instead of random crop for testing. --- code/test_shape.py | 2 +- code/hopenet.py | 2 +- code/test_old.py | 2 +- code/test_on_video.py | 2 +- code/batch_testing_preangles.py | 2 +- code/batch_testing.py | 2 +- code/test.py | 2 +- code/test_AFW.py | 2 +- code/test_preangles.py | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/code/batch_testing.py b/code/batch_testing.py index 74ad58b..237ea54 100644 --- a/code/batch_testing.py +++ b/code/batch_testing.py @@ -62,7 +62,7 @@ print 'Loading data.' transformations = transforms.Compose([transforms.Scale(224), - transforms.RandomCrop(224), transforms.ToTensor(), + transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) if args.dataset == 'AFLW2000': diff --git a/code/batch_testing_preangles.py b/code/batch_testing_preangles.py index bf0d32b..bb2047f 100644 --- a/code/batch_testing_preangles.py +++ b/code/batch_testing_preangles.py @@ -61,7 +61,7 @@ print 'Loading data.' transformations = transforms.Compose([transforms.Scale(224), - transforms.RandomCrop(224), transforms.ToTensor(), + transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) if args.dataset == 'AFLW2000': diff --git a/code/hopenet.py b/code/hopenet.py index 81e645c..d122243 100644 --- a/code/hopenet.py +++ b/code/hopenet.py @@ -120,7 +120,7 @@ angles.append(torch.cat([yaw, pitch, roll], 1)) for idx in xrange(self.iter_ref): - angles.append(self.fc_finetune(torch.cat((angles[-1], x), 1))) + angles.append(self.fc_finetune(torch.cat((angles[idx], x), 1))) return pre_yaw, pre_pitch, pre_roll, angles diff --git a/code/test.py b/code/test.py index f2baf63..41db842 100644 --- a/code/test.py +++ b/code/test.py @@ -62,7 +62,7 @@ print 'Loading data.' transformations = transforms.Compose([transforms.Scale(224), - transforms.RandomCrop(224), transforms.ToTensor(), + transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) if args.dataset == 'AFLW2000': diff --git a/code/test_AFW.py b/code/test_AFW.py index b8bd0bf..8a1c047 100644 --- a/code/test_AFW.py +++ b/code/test_AFW.py @@ -60,7 +60,7 @@ print 'Loading data.' transformations = transforms.Compose([transforms.Scale(224), - transforms.RandomCrop(224), transforms.ToTensor()]) + transforms.CenterCrop(224), transforms.ToTensor()]) pose_dataset = datasets.AFW(args.data_dir, args.filename_list, transformations) diff --git a/code/test_old.py b/code/test_old.py index b9be11e..e831e22 100644 --- a/code/test_old.py +++ b/code/test_old.py @@ -63,7 +63,7 @@ # transforms.RandomCrop(224), transforms.ToTensor()]) transformations = transforms.Compose([transforms.Scale(224), - transforms.RandomCrop(224), transforms.ToTensor(), + transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list, diff --git a/code/test_on_video.py b/code/test_on_video.py index 8d6c5fd..d384b08 100644 --- a/code/test_on_video.py +++ b/code/test_on_video.py @@ -60,7 +60,7 @@ print 'Loading data.' transformations = transforms.Compose([transforms.Scale(224), - transforms.RandomCrop(224), transforms.ToTensor(), + transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) model.cuda(gpu) diff --git a/code/test_preangles.py b/code/test_preangles.py index f2039f2..7cf8ebb 100644 --- a/code/test_preangles.py +++ b/code/test_preangles.py @@ -64,7 +64,7 @@ # transforms.RandomCrop(224), transforms.ToTensor()]) transformations = transforms.Compose([transforms.Scale(224), - transforms.RandomCrop(224), transforms.ToTensor(), + transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) if args.dataset == 'AFLW2000': diff --git a/code/test_shape.py b/code/test_shape.py index aa644d5..a89e5ec 100644 --- a/code/test_shape.py +++ b/code/test_shape.py @@ -60,7 +60,7 @@ print 'Loading data.' transformations = transforms.Compose([transforms.Scale(224), - transforms.RandomCrop(224), transforms.ToTensor()]) + transforms.CenterCrop(224), transforms.ToTensor()]) pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list, transformations) -- Gitblit v1.8.0