From 868222967bf310e6c5bc1d6b3af0e9e49d2992c2 Mon Sep 17 00:00:00 2001 From: natanielruiz <nataniel777@hotmail.com> Date: 星期二, 08 八月 2017 10:30:30 +0800 Subject: [PATCH] Before experiments --- code/datasets.py | 14 ++++++++++---- 1 files changed, 10 insertions(+), 4 deletions(-) diff --git a/code/datasets.py b/code/datasets.py index 06cd433..4d1f71f 100644 --- a/code/datasets.py +++ b/code/datasets.py @@ -7,6 +7,10 @@ import utils +def stack_grayscale_tensor(tensor): + tensor = torch.cat([tensor, tensor, tensor], 0) + return tensor + class Pose_300W_LP(Dataset): def __init__(self, data_dir, filename_path, transform, img_ext='.jpg', annot_ext='.mat'): self.data_dir = data_dir @@ -66,7 +70,7 @@ return self.length class Pose_300W_LP_binned(Dataset): - def __init__(self, data_dir, filename_path, transform, img_ext='.jpg', annot_ext='.mat'): + def __init__(self, data_dir, filename_path, transform, img_ext='.jpg', annot_ext='.mat', image_mode='RGB'): self.data_dir = data_dir self.transform = transform self.img_ext = img_ext @@ -76,11 +80,12 @@ self.X_train = filename_list self.y_train = filename_list + self.image_mode = image_mode self.length = len(filename_list) def __getitem__(self, index): img = Image.open(os.path.join(self.data_dir, self.X_train[index] + self.img_ext)) - img = img.convert('RGB') + img = img.convert(self.image_mode) mat_path = os.path.join(self.data_dir, self.y_train[index] + self.annot_ext) # Crop the face @@ -117,7 +122,7 @@ return self.length class AFLW2000_binned(Dataset): - def __init__(self, data_dir, filename_path, transform, img_ext='.jpg', annot_ext='.mat'): + def __init__(self, data_dir, filename_path, transform, img_ext='.jpg', annot_ext='.mat', image_mode='RGB'): self.data_dir = data_dir self.transform = transform self.img_ext = img_ext @@ -127,11 +132,12 @@ self.X_train = filename_list self.y_train = filename_list + self.image_mode = image_mode self.length = len(filename_list) def __getitem__(self, index): img = Image.open(os.path.join(self.data_dir, self.X_train[index] + self.img_ext)) - img = img.convert('RGB') + img = img.convert(self.image_mode) mat_path = os.path.join(self.data_dir, self.y_train[index] + self.annot_ext) # Crop the face -- Gitblit v1.8.0