From 8f2b081586161a55388934456a9b0193d02dd267 Mon Sep 17 00:00:00 2001 From: natanielruiz <nataniel777@hotmail.com> Date: 星期三, 27 九月 2017 10:15:27 +0800 Subject: [PATCH] Random downsample experiment --- code/datasets.py | 10 ++++------ code/hopenet.py | 2 +- code/train_preangles.py | 2 +- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/code/datasets.py b/code/datasets.py index f5941ae..aee5246 100644 --- a/code/datasets.py +++ b/code/datasets.py @@ -128,12 +128,10 @@ yaw = pose[1] * 180 / np.pi roll = pose[2] * 180 / np.pi - rnd = np.random.random_sample() - if rnd < 0.5: - ds = 10 - original_size = img.size - img = img.resize((img.size[0] / ds, img.size[1] / ds), resample=Image.NEAREST) - img = img.resize((original_size[0], original_size[1]), resample=Image.NEAREST) + ds = np.random.randint(1,11) + original_size = img.size + img = img.resize((img.size[0] / ds, img.size[1] / ds), resample=Image.NEAREST) + img = img.resize((original_size[0], original_size[1]), resample=Image.NEAREST) # Flip? rnd = np.random.random_sample() diff --git a/code/hopenet.py b/code/hopenet.py index de2f4ec..80160d9 100644 --- a/code/hopenet.py +++ b/code/hopenet.py @@ -339,7 +339,7 @@ preangles = torch.cat([yaw, pitch, roll], 1) angles.append(preangles) - return pre_yaw, pre_pitch, pre_roll, angles, sr_output + return pre_yaw, pre_pitch, pre_roll, angles, sr_y class Hopenet_new(nn.Module): # This is just Hopenet with 3 output layers for yaw, pitch and roll. diff --git a/code/train_preangles.py b/code/train_preangles.py index 4752aef..ffceee2 100644 --- a/code/train_preangles.py +++ b/code/train_preangles.py @@ -125,7 +125,7 @@ if args.dataset == 'Pose_300W_LP': pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations) - if args.dataset == 'Pose_300W_LP_random_ds': + elif args.dataset == 'Pose_300W_LP_random_ds': pose_dataset = datasets.Pose_300W_LP_random_ds(args.data_dir, args.filename_list, transformations) elif args.dataset == 'AFLW2000': pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list, transformations) -- Gitblit v1.8.0