From 8f2b081586161a55388934456a9b0193d02dd267 Mon Sep 17 00:00:00 2001
From: natanielruiz <nataniel777@hotmail.com>
Date: 星期三, 27 九月 2017 10:15:27 +0800
Subject: [PATCH] Random downsample experiment

---
 code/datasets.py        |   10 ++++------
 code/hopenet.py         |    2 +-
 code/train_preangles.py |    2 +-
 3 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/code/datasets.py b/code/datasets.py
index f5941ae..aee5246 100644
--- a/code/datasets.py
+++ b/code/datasets.py
@@ -128,12 +128,10 @@
         yaw = pose[1] * 180 / np.pi
         roll = pose[2] * 180 / np.pi
 
-        rnd = np.random.random_sample()
-        if rnd < 0.5:
-            ds = 10
-            original_size = img.size
-            img = img.resize((img.size[0] / ds, img.size[1] / ds), resample=Image.NEAREST)
-            img = img.resize((original_size[0], original_size[1]), resample=Image.NEAREST)
+        ds = np.random.randint(1,11)
+        original_size = img.size
+        img = img.resize((img.size[0] / ds, img.size[1] / ds), resample=Image.NEAREST)
+        img = img.resize((original_size[0], original_size[1]), resample=Image.NEAREST)
 
         # Flip?
         rnd = np.random.random_sample()
diff --git a/code/hopenet.py b/code/hopenet.py
index de2f4ec..80160d9 100644
--- a/code/hopenet.py
+++ b/code/hopenet.py
@@ -339,7 +339,7 @@
         preangles = torch.cat([yaw, pitch, roll], 1)
         angles.append(preangles)
 
-        return pre_yaw, pre_pitch, pre_roll, angles, sr_output
+        return pre_yaw, pre_pitch, pre_roll, angles, sr_y
 
 class Hopenet_new(nn.Module):
     # This is just Hopenet with 3 output layers for yaw, pitch and roll.
diff --git a/code/train_preangles.py b/code/train_preangles.py
index 4752aef..ffceee2 100644
--- a/code/train_preangles.py
+++ b/code/train_preangles.py
@@ -125,7 +125,7 @@
 
     if args.dataset == 'Pose_300W_LP':
         pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations)
-    if args.dataset == 'Pose_300W_LP_random_ds':
+    elif args.dataset == 'Pose_300W_LP_random_ds':
         pose_dataset = datasets.Pose_300W_LP_random_ds(args.data_dir, args.filename_list, transformations)
     elif args.dataset == 'AFLW2000':
         pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list, transformations)

--
Gitblit v1.8.0