natanielruiz
2017-08-08 868222967bf310e6c5bc1d6b3af0e9e49d2992c2
Before experiments
2个文件已添加
2个文件已修改
375 ■■■■■ 已修改文件
code/datasets.py 14 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
code/test_resnet_bins_grayscale.py 144 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
code/train_resnet_bins_grayscale.py 159 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
code/utils.py 58 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
code/datasets.py
@@ -7,6 +7,10 @@
import utils
def stack_grayscale_tensor(tensor):
    """Replicate an image tensor three times along dimension 0.

    Lets single-channel (grayscale, PIL mode 'L') images be fed to a network
    that expects 3-channel input. Assumes ``tensor`` is channel-first,
    e.g. the (1, H, W) output of ``transforms.ToTensor()``; the copies are
    always concatenated along dim 0.
    """
    copies = (tensor,) * 3
    return torch.cat(copies, 0)
class Pose_300W_LP(Dataset):
    def __init__(self, data_dir, filename_path, transform, img_ext='.jpg', annot_ext='.mat'):
        self.data_dir = data_dir
@@ -66,7 +70,7 @@
        return self.length
class Pose_300W_LP_binned(Dataset):
    def __init__(self, data_dir, filename_path, transform, img_ext='.jpg', annot_ext='.mat'):
    def __init__(self, data_dir, filename_path, transform, img_ext='.jpg', annot_ext='.mat', image_mode='RGB'):
        self.data_dir = data_dir
        self.transform = transform
        self.img_ext = img_ext
@@ -76,11 +80,12 @@
        self.X_train = filename_list
        self.y_train = filename_list
        self.image_mode = image_mode
        self.length = len(filename_list)
    def __getitem__(self, index):
        img = Image.open(os.path.join(self.data_dir, self.X_train[index] + self.img_ext))
        img = img.convert('RGB')
        img = img.convert(self.image_mode)
        mat_path = os.path.join(self.data_dir, self.y_train[index] + self.annot_ext)
        # Crop the face
@@ -117,7 +122,7 @@
        return self.length
class AFLW2000_binned(Dataset):
    def __init__(self, data_dir, filename_path, transform, img_ext='.jpg', annot_ext='.mat'):
    def __init__(self, data_dir, filename_path, transform, img_ext='.jpg', annot_ext='.mat', image_mode='RGB'):
        self.data_dir = data_dir
        self.transform = transform
        self.img_ext = img_ext
@@ -127,11 +132,12 @@
        self.X_train = filename_list
        self.y_train = filename_list
        self.image_mode = image_mode
        self.length = len(filename_list)
    def __getitem__(self, index):
        img = Image.open(os.path.join(self.data_dir, self.X_train[index] + self.img_ext))
        img = img.convert('RGB')
        img = img.convert(self.image_mode)
        mat_path = os.path.join(self.data_dir, self.y_train[index] + self.annot_ext)
        # Crop the face
code/test_resnet_bins_grayscale.py
New file
@@ -0,0 +1,144 @@
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.backends.cudnn as cudnn
import torchvision
import torch.nn.functional as F
import cv2
import matplotlib.pyplot as plt
import sys
import os
import argparse
import datasets
import hopenet
import utils
def parse_args():
    """Parse command-line arguments for the grayscale Hopenet test script.

    Returns:
        argparse.Namespace with ``gpu_id``, ``data_dir``, ``filename_list``,
        ``snapshot``, ``batch_size`` and ``save_viz``. Note that this script's
        ``__main__`` block hard-codes a batch size of 1.
    """
    def str2bool(value):
        # argparse's ``type=bool`` returns True for every non-empty string
        # (including 'False' and '0'), so parse the text explicitly.
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError('Boolean value expected, got %r.' % value)

    parser = argparse.ArgumentParser(description='Head pose estimation using the Hopenet network.')
    parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]',
            default=0, type=int)
    parser.add_argument('--data_dir', dest='data_dir', help='Directory path for data.',
          default='', type=str)
    parser.add_argument('--filename_list', dest='filename_list', help='Path to text file containing relative paths for every example.',
          default='', type=str)
    parser.add_argument('--snapshot', dest='snapshot', help='Name of model snapshot.',
          default='', type=str)
    parser.add_argument('--batch_size', dest='batch_size', help='Batch size.',
          default=1, type=int)
    # Accepts both "--save_viz True/False" (old usage) and a bare "--save_viz".
    parser.add_argument('--save_viz', dest='save_viz', help='Save images with pose cube.',
          default=False, type=str2bool, nargs='?', const=True)

    args = parser.parse_args()
    return args
if __name__ == '__main__':
    # Evaluate a grayscale-trained Hopenet snapshot on AFLW2000 and report the
    # mean absolute yaw/pitch/roll error; optionally save pose-cube overlays.
    args = parse_args()

    cudnn.enabled = True
    # The loop below only reads element [0] of each batch, so evaluation is
    # done one image at a time regardless of --batch_size.
    batch_size = 1
    gpu = args.gpu_id
    snapshot_path = os.path.join('output/snapshots', args.snapshot + '.pkl')

    # ResNet101 with 3 outputs.
    # model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], 66)
    # ResNet50
    # model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66)
    # ResNet18
    model = hopenet.Hopenet(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], 66)

    print('Loading snapshot.')
    saved_state_dict = torch.load(snapshot_path)
    model.load_state_dict(saved_state_dict)

    print('Loading data.')
    # Evaluation must be deterministic: CenterCrop (not RandomCrop) so every
    # run scores the same pixels. The single grayscale channel is replicated
    # to three channels for the ResNet input.
    transformations = transforms.Compose([
        transforms.Scale(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: datasets.stack_grayscale_tensor(x))])

    pose_dataset = datasets.AFLW2000_binned(args.data_dir, args.filename_list,
                                            transformations, image_mode='L')
    test_loader = torch.utils.data.DataLoader(dataset=pose_dataset,
                                              batch_size=batch_size,
                                              num_workers=2)

    model.cuda(gpu)

    # cv2.imwrite does not create missing directories, so make sure the
    # visualisation output folder exists before the loop writes into it.
    if args.save_viz and not os.path.exists('output/images'):
        os.makedirs('output/images')

    print('Ready to test network.')

    # Test the Model
    model.eval()  # Change model to 'eval' mode (BN uses moving mean/var).
    total = 0

    # Bin indices 0..65; the softmax-weighted sum over them gives a
    # continuous (expected) bin index per angle.
    idx_tensor = torch.FloatTensor(list(range(66))).cuda(gpu)

    yaw_error = .0
    pitch_error = .0
    roll_error = .0

    for i, (images, labels, name) in enumerate(test_loader):
        images = Variable(images).cuda(gpu)
        total += labels.size(0)
        label_yaw = labels[:, 0]
        label_pitch = labels[:, 1]
        label_roll = labels[:, 2]

        yaw, pitch, roll = model(images)

        # Continuous predictions: expected bin index under the softmax.
        yaw_predicted = torch.sum(F.softmax(yaw).data[0] * idx_tensor)
        pitch_predicted = torch.sum(F.softmax(pitch).data[0] * idx_tensor)
        roll_predicted = torch.sum(F.softmax(roll).data[0] * idx_tensor)

        # Absolute error; one bin-index unit corresponds to 3 degrees.
        yaw_error += abs(yaw_predicted - label_yaw[0]) * 3
        pitch_error += abs(pitch_predicted - label_pitch[0]) * 3
        roll_error += abs(roll_predicted - label_roll[0]) * 3

        # Save images with pose cube.
        if args.save_viz:
            img_name = name[0]
            cv2_img = cv2.imread(os.path.join(args.data_dir, img_name + '.jpg'))
            # Map bin index back to degrees (3 degrees per bin, offset -99).
            utils.plot_pose_cube(cv2_img, yaw_predicted * 3 - 99,
                                 pitch_predicted * 3 - 99,
                                 roll_predicted * 3 - 99)
            cv2.imwrite(os.path.join('output/images', img_name + '.jpg'), cv2_img)

    if total == 0:
        print('No test images were loaded; nothing to report.')
    else:
        print('Test error in degrees of the model on the ' + str(total) +
              ' test images. Yaw: %.4f, Pitch: %.4f, Roll: %.4f' % (
                  yaw_error / total, pitch_error / total, roll_error / total))
code/train_resnet_bins_grayscale.py
New file
@@ -0,0 +1,159 @@
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision
import torch.backends.cudnn as cudnn
import cv2
import matplotlib.pyplot as plt
import sys
import os
import argparse
import datasets
import hopenet
import torch.utils.model_zoo as model_zoo
# Download locations of the torchvision ImageNet-pretrained ResNet weights,
# keyed by architecture name.
model_urls = dict(
    resnet18='https://download.pytorch.org/models/resnet18-5c106cde.pth',
    resnet34='https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    resnet50='https://download.pytorch.org/models/resnet50-19c8e357.pth',
    resnet101='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    resnet152='https://download.pytorch.org/models/resnet152-b121ed2d.pth',
)
def parse_args():
    """Build the training command-line interface and return the parsed namespace.

    Options (dest in parentheses): --gpu (gpu_id), --num_epochs, --batch_size,
    --lr, --data_dir, --filename_list.
    """
    parser = argparse.ArgumentParser(
        description='Head pose estimation using the Hopenet network.')
    # (flag, dest, help text, default, type) for each option, in order.
    option_specs = [
        ('--gpu', 'gpu_id', 'GPU device id to use [0]', 0, int),
        ('--num_epochs', 'num_epochs', 'Maximum number of training epochs.', 5, int),
        ('--batch_size', 'batch_size', 'Batch size.', 16, int),
        ('--lr', 'lr', 'Base learning rate.', 0.001, float),
        ('--data_dir', 'data_dir', 'Directory path for data.', '', str),
        ('--filename_list', 'filename_list',
         'Path to text file containing relative paths for every example.', '', str),
    ]
    for flag, dest, help_text, default, kind in option_specs:
        parser.add_argument(flag, dest=dest, help=help_text, default=default, type=kind)
    return parser.parse_args()
def get_ignored_params(model):
    """Yield the parameters of the ResNet backbone, each exactly once.

    Covers ``model.conv1``, ``model.bn1`` and ``model.layer1``..``layer4``.
    The training script places these in the optimizer group that uses the
    base learning rate.

    ``Module.parameters()`` already recurses into sub-modules. Iterating
    ``modules()`` first and then calling ``parameters()`` on every sub-module
    would yield nested parameters several times. Recent PyTorch warns about
    duplicate parameters in a group, and a duplicated tensor can be stepped
    more than once per ``optimizer.step()``.
    """
    backbone = (model.conv1, model.bn1, model.layer1, model.layer2,
                model.layer3, model.layer4)
    seen = set()
    for module in backbone:
        for param in module.parameters():
            if id(param) not in seen:
                seen.add(id(param))
                yield param
def get_non_ignored_params(model):
    """Yield the parameters of the three angle heads, each exactly once.

    Covers ``model.fc_yaw``, ``model.fc_pitch`` and ``model.fc_roll``. The
    training script optimizes these at 10x the base learning rate.

    ``Module.parameters()`` already recurses into sub-modules. Calling it on
    each entry of ``modules()`` would yield nested parameters repeatedly if a
    head were ever a container, e.g. an ``nn.Sequential``.
    """
    heads = (model.fc_yaw, model.fc_pitch, model.fc_roll)
    seen = set()
    for module in heads:
        for param in module.parameters():
            if id(param) not in seen:
                seen.add(id(param))
                yield param
def load_filtered_state_dict(model, snapshot):
    """Load into ``model`` only the ``snapshot`` entries whose keys it already has.

    Keys in ``snapshot`` that are absent from ``model.state_dict()`` are
    ignored. Model entries missing from ``snapshot`` keep their current
    values. Shapes are not checked here, so a matching key with a different
    shape will still fail inside ``load_state_dict``.
    (Approach credited to user apaszke on discuss.pytorch.org.)
    """
    merged = model.state_dict()
    for key, value in snapshot.items():
        if key in merged:
            merged[key] = value
    model.load_state_dict(merged)
if __name__ == '__main__':
    # Fine-tune a ResNet18 Hopenet, initialised from ImageNet weights, on
    # grayscale 300W_LP crops with 66-bin yaw/pitch/roll classification targets.
    args = parse_args()

    cudnn.enabled = True
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    gpu = args.gpu_id

    if not os.path.exists('output/snapshots'):
        os.makedirs('output/snapshots')

    # ResNet101 with 3 outputs
    # model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], 66)
    # ResNet50
    # model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66)
    # ResNet18
    model = hopenet.Hopenet(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], 66)
    load_filtered_state_dict(model, model_zoo.load_url(model_urls['resnet18']))

    print('Loading data.')
    # Images are loaded in PIL mode 'L' and replicated to 3 channels so the
    # RGB-pretrained backbone can consume them.
    transformations = transforms.Compose([
        transforms.Scale(224),
        transforms.RandomCrop(224),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: datasets.stack_grayscale_tensor(x))])

    pose_dataset = datasets.Pose_300W_LP_binned(args.data_dir, args.filename_list,
                                                transformations, image_mode='L')
    train_loader = torch.utils.data.DataLoader(dataset=pose_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=2)

    model.cuda(gpu)
    criterion = nn.CrossEntropyLoss()
    # Backbone at the base LR; the fc_yaw/fc_pitch/fc_roll heads at 10x.
    optimizer = torch.optim.Adam([{'params': get_ignored_params(model), 'lr': args.lr},
                                  {'params': get_non_ignored_params(model), 'lr': args.lr * 10}],
                                 lr=args.lr)
    # optimizer = torch.optim.SGD([{'params': get_ignored_params(model), 'lr': args.lr},
    #                               {'params': get_non_ignored_params(model), 'lr': args.lr}],
    #                               lr = args.lr, momentum=0.9)
    # optimizer = torch.optim.RMSprop([{'params': get_ignored_params(model), 'lr': args.lr},
    #                               {'params': get_non_ignored_params(model), 'lr': args.lr}],
    #                               lr = args.lr)

    print('Ready to train network.')
    for epoch in range(num_epochs):
        for i, (images, labels, name) in enumerate(train_loader):
            images = Variable(images).cuda(gpu)
            label_yaw = Variable(labels[:, 0]).cuda(gpu)
            label_pitch = Variable(labels[:, 1]).cuda(gpu)
            label_roll = Variable(labels[:, 2]).cuda(gpu)

            optimizer.zero_grad()
            yaw, pitch, roll = model(images)

            loss_yaw = criterion(yaw, label_yaw)
            loss_pitch = criterion(pitch, label_pitch)
            loss_roll = criterion(roll, label_roll)

            # Backpropagate all three losses with a unit upstream gradient.
            # Do not pass torch.Tensor(1) as the gradient: it allocates
            # *uninitialised* memory, which would scale every update by an
            # arbitrary value. Summing the losses and calling backward() is
            # equivalent to using a gradient of ones.
            total_loss = loss_yaw + loss_pitch + loss_roll
            total_loss.backward()
            optimizer.step()

            if (i + 1) % 100 == 0:
                print('Epoch [%d/%d], Iter [%d/%d] Losses: Yaw %.4f, Pitch %.4f, Roll %.4f'
                      % (epoch + 1, num_epochs, i + 1, len(pose_dataset) // batch_size,
                         loss_yaw.data[0], loss_pitch.data[0], loss_roll.data[0]))

        # Snapshot after every epoch; the last one is the final trained model.
        # Saving inside the loop also avoids referencing an undefined `epoch`
        # when num_epochs is 0.
        print('Taking snapshot...')
        torch.save(model.state_dict(),
                   'output/snapshots/resnet18_cr_gray_epoch_' + str(epoch + 1) + '.pkl')
code/utils.py
@@ -7,6 +7,35 @@
import math
from math import cos, sin
def get_pose_params_from_mat(mat_path):
    """Read the first five pose parameters from a 300W_LP ``.mat`` annotation.

    The ``Pose_Para`` row is stored as
    [pitch, yaw, roll, tdx, tdy, tdz, scale_factor]; this returns the slice
    [pitch, yaw, roll, tdx, tdy].
    """
    annotation = sio.loadmat(mat_path)
    return annotation['Pose_Para'][0][:5]
def get_ypr_from_mat(mat_path):
    """Return the three rotation values from a ``.mat`` pose annotation.

    Despite the name, the values come back in the file's storage order,
    [pitch, yaw, roll] (the first three entries of ``Pose_Para``), not
    yaw/pitch/roll. The original author's comment states they are in radians.
    """
    pose_para = sio.loadmat(mat_path)['Pose_Para']
    first_row = pose_para[0]
    return first_row[:3]
def get_pt2d_from_mat(mat_path):
    """Return the 2D landmark array stored under the ``pt2d`` key of a ``.mat`` file."""
    return sio.loadmat(mat_path)['pt2d']
def mse_loss(input, target):
    """Return the sum of squared element-wise differences of two tensors.

    NOTE(review): despite the name, this is a *sum*, not a mean. It also
    reads ``.data`` on both arguments, so the result is detached from
    autograd and cannot be backpropagated through. Confirm this is intended
    if it is used as a training loss.
    """
    residual = torch.abs(input.data - target.data)
    squared = residual ** 2
    return torch.sum(squared)
def plot_pose_cube(img, yaw, pitch, roll, tdx=None, tdy=None, size=150.):
    # Input is a cv2 image
    # pose_params: (pitch, yaw, roll, tdx, tdy)
@@ -49,32 +78,3 @@
    cv2.line(img, (int(x3), int(y3)), (int(x3+x2-face_x),int(y3+y2-face_y)),(0,255,0),2)
    return img
def get_pose_params_from_mat(mat_path):
    """Load ``Pose_Para`` from a 300W_LP-style ``.mat`` file and keep five values.

    Stored order: [pitch, yaw, roll, tdx, tdy, tdz, scale_factor].
    Returned:     [pitch, yaw, roll, tdx, tdy].
    """
    mat_contents = sio.loadmat(mat_path)
    pose_row = mat_contents['Pose_Para'][0]
    keep = 5
    return pose_row[:keep]
def get_ypr_from_mat(mat_path):
    """Return the first three ``Pose_Para`` entries of a ``.mat`` annotation.

    The order is [pitch, yaw, roll], i.e. the file's storage order, even
    though the function name suggests yaw/pitch/roll. Per the original
    author's comment, the values are in radians.
    """
    rotation_count = 3
    loaded = sio.loadmat(mat_path)
    return loaded['Pose_Para'][0][:rotation_count]
def get_pt2d_from_mat(mat_path):
    """Load a ``.mat`` file and return its ``pt2d`` (2D landmarks) entry unchanged."""
    contents = sio.loadmat(mat_path)
    landmarks = contents['pt2d']
    return landmarks
def mse_loss(input, target):
    """Return ``sum(|input - target| ** 2)`` computed on the raw ``.data`` tensors.

    NOTE(review): this is a sum rather than a mean, and it is detached from
    autograd because it works on ``.data``. Confirm this matches how callers
    use it.
    """
    return (torch.abs(input.data - target.data) ** 2).sum()