| | |
| | | import os |
| | | import numpy as np |
| | | import cv2 |
| | | import pandas as pd |
| | | |
| | | import torch |
| | | from torch.utils.data.dataset import Dataset |
| | |
| | | lines = f.read().splitlines() |
| | | return lines |
| | | |
class Synhead(Dataset):
    """Head-pose dataset driven by a CSV of image paths, face boxes and angles.

    The CSV is read without a header row; its columns are, in order:
    ``path, bbox_x_min, bbox_y_min, bbox_x_max, bbox_y_max, yaw, pitch, roll``.
    Images are loaded as ``.png`` files located under ``data_dir``.
    """

    def __init__(self, data_dir, csv_path, transform, test=False):
        """Read the annotation CSV and remember the dataset configuration.

        Args:
            data_dir: Directory that the relative image paths are joined to.
            csv_path: Path of the header-less, comma-separated annotation file.
            transform: Callable applied to the cropped image, or ``None``.
            test: Stored on the instance as ``self.test``; not read by the
                methods of this class.
        """
        column_names = ['path', 'bbox_x_min', 'bbox_y_min', 'bbox_x_max', 'bbox_y_max', 'yaw', 'pitch', 'roll']
        tmp_df = pd.read_csv(csv_path, sep=',', names=column_names, index_col=False, encoding="utf-8-sig")
        self.data_dir = data_dir
        self.transform = transform
        self.X_train = tmp_df['path']
        self.y_train = tmp_df[['bbox_x_min', 'bbox_y_min', 'bbox_x_max', 'bbox_y_max', 'yaw', 'pitch', 'roll']]
        self.length = len(tmp_df)
        self.test = test

    def _image_path(self, index):
        """Return the on-disk ``.png`` path for the sample at ``index``.

        A trailing ``.jpg`` in the CSV entry is replaced by ``.png``; entries
        without that suffix simply get ``.png`` appended.
        """
        rel_path = self.X_train.iloc[index]
        # str.strip('.jpg') removes any of the characters '.', 'j', 'p', 'g'
        # from BOTH ends of the string (mangling e.g. './data/a.jpg' or a stem
        # ending in 'g'), so cut the exact suffix instead.
        if rel_path.endswith('.jpg'):
            rel_path = rel_path[:-len('.jpg')]
        return os.path.join(self.data_dir, rel_path) + '.png'

    def __getitem__(self, index):
        """Load, crop and randomly augment one sample.

        Returns:
            A tuple ``(img, labels, cont_labels, name)`` where ``img`` is the
            cropped (and, if ``transform`` is set, transformed) image,
            ``labels`` is a ``LongTensor`` with the binned yaw/pitch/roll,
            ``cont_labels`` is a ``FloatTensor`` with the same angles as
            continuous values, and ``name`` is the path entry from the CSV.
        """
        img = Image.open(self._image_path(index))
        img = img.convert('RGB')

        x_min, y_min, x_max, y_max, yaw, pitch, roll = self.y_train.iloc[index]
        x_min = float(x_min)
        x_max = float(x_max)
        y_min = float(y_min)
        y_max = float(y_max)
        # The stored yaw is negated on load (presumably to match the sign
        # convention of the other datasets - confirm against the annotations).
        yaw = -float(yaw)
        pitch = float(pitch)
        roll = float(roll)

        # Randomly loosen the face box; k is drawn uniformly from [0.2, 0.4).
        k = np.random.random_sample() * 0.2 + 0.2
        # NOTE(review): the x_max / y_max padding is computed from the already
        # enlarged box (x_min / y_min were moved on the two lines before), so
        # the expansion is not symmetric. Kept as-is to preserve the existing
        # augmentation statistics.
        x_min -= 0.6 * k * abs(x_max - x_min)
        y_min -= 2 * k * abs(y_max - y_min)
        x_max += 0.6 * k * abs(x_max - x_min)
        y_max += 0.6 * k * abs(y_max - y_min)

        # Crop the face
        img = img.crop((int(x_min), int(y_min), int(x_max), int(y_max)))

        # Horizontal flip with probability 0.5; mirroring the image reverses
        # the sign of yaw and roll while pitch is unaffected.
        rnd = np.random.random_sample()
        if rnd < 0.5:
            yaw = -yaw
            roll = -roll
            img = img.transpose(Image.FLIP_LEFT_RIGHT)

        # Blur with probability 0.05
        rnd = np.random.random_sample()
        if rnd < 0.05:
            img = img.filter(ImageFilter.BLUR)

        # Bin values: edges every 3 units from -99 to 99, so angles inside
        # [-99, 99) map to class indices 0..65.
        bins = np.array(range(-99, 102, 3))
        binned_pose = np.digitize([yaw, pitch, roll], bins) - 1

        labels = torch.LongTensor(binned_pose)
        cont_labels = torch.FloatTensor([yaw, pitch, roll])

        if self.transform is not None:
            img = self.transform(img)

        # Positional lookup, consistent with the .iloc accesses above.
        return img, labels, cont_labels, self.X_train.iloc[index]

    def __len__(self):
        """Return the number of rows read from the annotation CSV."""
        return self.length
| | | |
| | | class Pose_300W_LP(Dataset): |
| | | # Head pose from 300W-LP dataset |
| | | def __init__(self, data_dir, filename_path, transform, img_ext='.jpg', annot_ext='.mat', image_mode='RGB'): |
| | |
| | | default=16, type=int) |
| | | parser.add_argument('--lr', dest='lr', help='Base learning rate.', |
| | | default=0.001, type=float) |
| | | parser.add_argument('--dataset', dest='dataset', help='Dataset type.', default='Pose_300W_LP', type=str) |
| | | parser.add_argument('--data_dir', dest='data_dir', help='Directory path for data.', |
| | | default='', type=str) |
| | | parser.add_argument('--filename_list', dest='filename_list', help='Path to text file containing relative paths for every example.', |
| | |
| | | parser.add_argument('--output_string', dest='output_string', help='String appended to output snapshots.', default = '', type=str) |
| | | parser.add_argument('--alpha', dest='alpha', help='Regression loss coefficient.', |
| | | default=0.001, type=float) |
| | | parser.add_argument('--dataset', dest='dataset', help='Dataset type.', default='Pose_300W_LP', type=str) |
| | | parser.add_argument('--snapshot', dest='snapshot', help='Path of model snapshot.', |
| | | default='', type=str) |
| | | |
| | | args = parser.parse_args() |
| | | return args |
| | |
| | | |
| | | # ResNet50 structure |
| | | model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66) |
| | | load_filtered_state_dict(model, model_zoo.load_url('https://download.pytorch.org/models/resnet50-19c8e357.pth')) |
| | | |
| | | if args.snapshot == '': |
| | | load_filtered_state_dict(model, model_zoo.load_url('https://download.pytorch.org/models/resnet50-19c8e357.pth')) |
| | | else: |
| | | saved_state_dict = torch.load(args.snapshot) |
| | | model.load_state_dict(saved_state_dict) |
| | | |
| | | print 'Loading data.' |
| | | |
| | |
| | | pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations) |
| | | elif args.dataset == 'Pose_300W_LP_random_ds': |
| | | pose_dataset = datasets.Pose_300W_LP_random_ds(args.data_dir, args.filename_list, transformations) |
| | | elif args.dataset == 'Synhead': |
| | | pose_dataset = datasets.Synhead(args.data_dir, args.filename_list, transformations) |
| | | elif args.dataset == 'AFLW2000': |
| | | pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list, transformations) |
| | | elif args.dataset == 'BIWI': |
| | |
| | | label_roll_cont = Variable(cont_labels[:,2]).cuda(gpu) |
| | | |
| | | # Forward pass |
| | | yaw, pitch, roll, angles = model(images) |
| | | yaw, pitch, roll = model(images) |
| | | |
| | | # Cross entropy loss |
| | | loss_yaw = criterion(yaw, label_yaw) |