From 63d8126a674b8c3f0adf6ebc978832f548f757ca Mon Sep 17 00:00:00 2001
From: natanielruiz <nataniel777@hotmail.com>
Date: 星期六, 23 九月 2017 02:43:56 +0800
Subject: [PATCH] next

---
 code/hopenet.py       |   40 ++++++++
 code/train_alexnet.py |  214 ++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 254 insertions(+), 0 deletions(-)

diff --git a/code/hopenet.py b/code/hopenet.py
index 63a24cd..7b5f764 100644
--- a/code/hopenet.py
+++ b/code/hopenet.py
@@ -184,3 +184,43 @@
         x = self.fc_angles(x)
 
         return x
+
+class AlexNet(nn.Module):
+
+    def __init__(self, num_bins):
+        super(AlexNet, self).__init__()
+        self.features = nn.Sequential(
+            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+            nn.Conv2d(64, 192, kernel_size=5, padding=2),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+            nn.Conv2d(192, 384, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(384, 256, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(256, 256, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+        )
+        self.classifier = nn.Sequential(
+            nn.Dropout(),
+            nn.Linear(256 * 6 * 6, 4096),
+            nn.ReLU(inplace=True),
+            nn.Dropout(),
+            nn.Linear(4096, 4096),
+            nn.ReLU(inplace=True),
+        )
+        self.fc_yaw = nn.Linear(4096, num_bins)
+        self.fc_pitch = nn.Linear(4096, num_bins)
+        self.fc_roll = nn.Linear(4096, num_bins)
+
+    def forward(self, x):
+        x = self.features(x)
+        x = x.view(x.size(0), 256 * 6 * 6)
+        x = self.classifier(x)
+        yaw = self.fc_yaw(x)
+        pitch = self.fc_pitch(x)
+        roll = self.fc_roll(x)
+        return yaw, pitch, roll
diff --git a/code/train_alexnet.py b/code/train_alexnet.py
new file mode 100644
index 0000000..5f60211
--- /dev/null
+++ b/code/train_alexnet.py
@@ -0,0 +1,214 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+from torch.utils.data import DataLoader
+from torchvision import transforms
+import torchvision
+import torch.backends.cudnn as cudnn
+import torch.nn.functional as F
+
+import cv2
+import matplotlib.pyplot as plt
+import sys
+import os
+import argparse
+
+import datasets
+import hopenet
+import torch.utils.model_zoo as model_zoo
+
+import time
+
+model_urls = {
+    'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
+}
+
+def parse_args():
+    """Parse input arguments."""
+    parser = argparse.ArgumentParser(description='Head pose estimation using the Hopenet network.')
+    parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]',
+            default=0, type=int)
+    parser.add_argument('--num_epochs', dest='num_epochs', help='Maximum number of training epochs.',
+          default=5, type=int)
+    parser.add_argument('--batch_size', dest='batch_size', help='Batch size.',
+          default=16, type=int)
+    parser.add_argument('--lr', dest='lr', help='Base learning rate.',
+          default=0.001, type=float)
+    parser.add_argument('--data_dir', dest='data_dir', help='Directory path for data.',
+          default='', type=str)
+    parser.add_argument('--filename_list', dest='filename_list', help='Path to text file containing relative paths for every example.',
+          default='', type=str)
+    parser.add_argument('--output_string', dest='output_string', help='String appended to output snapshots.', default = '', type=str)
+    parser.add_argument('--alpha', dest='alpha', help='Regression loss coefficient.',
+          default=0.001, type=float)
+    parser.add_argument('--dataset', dest='dataset', help='Dataset type.', default='Pose_300W_LP', type=str)
+
+    args = parser.parse_args()
+    return args
+
+def get_ignored_params(model):
+    # Generator function that yields ignored params.
+    b = []
+    b.append(model.features[0])
+    b.append(model.features[1])
+    b.append(model.features[2])
+    for i in range(len(b)):
+        for module_name, module in b[i].named_modules():
+            if 'bn' in module_name:
+                module.eval()
+            for name, param in module.named_parameters():
+                yield param
+
+def get_non_ignored_params(model):
+    # Generator function that yields params that will be optimized.
+    b = []
+    for idx in xrange(3, len(model.features)):
+        b.append(model.features[idx])
+    for layer in model.classifier:
+        b.append(layer)
+    for i in range(len(b)):
+        for module_name, module in b[i].named_modules():
+            if 'bn' in module_name:
+                module.eval()
+            for name, param in module.named_parameters():
+                yield param
+
+def get_fc_params(model):
+    b = []
+    b.append(model.fc_yaw)
+    b.append(model.fc_pitch)
+    b.append(model.fc_roll)
+    for i in range(len(b)):
+        for module_name, module in b[i].named_modules():
+            for name, param in module.named_parameters():
+                yield param
+
+def load_filtered_state_dict(model, snapshot):
+    # By user apaszke from discuss.pytorch.org
+    model_dict = model.state_dict()
+    # 1. filter out unnecessary keys
+    snapshot = {k: v for k, v in snapshot.items() if k in model_dict}
+    # 2. overwrite entries in the existing state dict
+    model_dict.update(snapshot)
+    # 3. load the new state dict
+    model.load_state_dict(model_dict)
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    cudnn.enabled = True
+    num_epochs = args.num_epochs
+    batch_size = args.batch_size
+    gpu = args.gpu_id
+
+    if not os.path.exists('output/snapshots'):
+        os.makedirs('output/snapshots')
+
+    model = hopenet.AlexNet(66)
+    load_filtered_state_dict(model, model_zoo.load_url(model_urls['alexnet']))
+
+    print 'Loading data.'
+
+    transformations = transforms.Compose([transforms.Scale(240),
+    transforms.RandomCrop(224), transforms.ToTensor(),
+    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
+
+    if args.dataset == 'Pose_300W_LP':
+        pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations)
+    elif args.dataset == 'AFLW2000':
+        pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list, transformations)
+    elif args.dataset == 'BIWI':
+        pose_dataset = datasets.BIWI(args.data_dir, args.filename_list, transformations)
+    elif args.dataset == 'AFLW':
+        pose_dataset = datasets.AFLW(args.data_dir, args.filename_list, transformations)
+    elif args.dataset == 'AFLW_aug':
+        pose_dataset = datasets.AFLW_aug(args.data_dir, args.filename_list, transformations)
+    elif args.dataset == 'AFW':
+        pose_dataset = datasets.AFW(args.data_dir, args.filename_list, transformations)
+    else:
+        print 'Error: not a valid dataset name'
+        sys.exit()
+    train_loader = torch.utils.data.DataLoader(dataset=pose_dataset,
+                                               batch_size=batch_size,
+                                               shuffle=True,
+                                               num_workers=2)
+
+    model.cuda(gpu)
+    softmax = nn.Softmax().cuda(gpu)
+    criterion = nn.CrossEntropyLoss().cuda(gpu)
+    reg_criterion = nn.MSELoss().cuda(gpu)
+    # Regression loss coefficient
+    alpha = args.alpha
+
+    idx_tensor = [idx for idx in xrange(66)]
+    idx_tensor = Variable(torch.FloatTensor(idx_tensor)).cuda(gpu)
+
+    optimizer = torch.optim.Adam([{'params': get_ignored_params(model), 'lr': 0},
+                                  {'params': get_non_ignored_params(model), 'lr': args.lr},
+                                  {'params': get_fc_params(model), 'lr': args.lr * 5}],
+                                   lr = args.lr)
+
+    print 'Ready to train network.'
+    print 'First phase of training.'
+    for epoch in range(num_epochs):
+        # start = time.time()
+        for i, (images, labels, cont_labels, name) in enumerate(train_loader):
+            # print i
+            # print 'start: ', time.time() - start
+            images = Variable(images).cuda(gpu)
+            label_yaw = Variable(labels[:,0]).cuda(gpu)
+            label_pitch = Variable(labels[:,1]).cuda(gpu)
+            label_roll = Variable(labels[:,2]).cuda(gpu)
+
+            label_angles = Variable(cont_labels[:,:3]).cuda(gpu)
+            label_yaw_cont = Variable(cont_labels[:,0]).cuda(gpu)
+            label_pitch_cont = Variable(cont_labels[:,1]).cuda(gpu)
+            label_roll_cont = Variable(cont_labels[:,2]).cuda(gpu)
+
+            optimizer.zero_grad()
+            model.zero_grad()
+
+            pre_yaw, pre_pitch, pre_roll = model(images)
+            # Cross entropy loss
+            loss_yaw = criterion(pre_yaw, label_yaw)
+            loss_pitch = criterion(pre_pitch, label_pitch)
+            loss_roll = criterion(pre_roll, label_roll)
+
+            # MSE loss
+            yaw_predicted = softmax(pre_yaw)
+            pitch_predicted = softmax(pre_pitch)
+            roll_predicted = softmax(pre_roll)
+
+            yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1) * 3 - 99
+            pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1) * 3 - 99
+            roll_predicted = torch.sum(roll_predicted * idx_tensor, 1) * 3 - 99
+
+            loss_reg_yaw = reg_criterion(yaw_predicted, label_yaw_cont)
+            loss_reg_pitch = reg_criterion(pitch_predicted, label_pitch_cont)
+            loss_reg_roll = reg_criterion(roll_predicted, label_roll_cont)
+
+            # Total loss
+            loss_yaw += alpha * loss_reg_yaw
+            loss_pitch += alpha * loss_reg_pitch
+            loss_roll += alpha * loss_reg_roll
+
+            loss_seq = [loss_yaw, loss_pitch, loss_roll]
+            grad_seq = [torch.Tensor(1).cuda(gpu) for _ in range(len(loss_seq))]
+            torch.autograd.backward(loss_seq, grad_seq)
+            optimizer.step()
+
+            # print 'end: ', time.time() - start
+
+            if (i+1) % 100 == 0:
+                print ('Epoch [%d/%d], Iter [%d/%d] Losses: Yaw %.4f, Pitch %.4f, Roll %.4f'
+                       %(epoch+1, num_epochs, i+1, len(pose_dataset)//batch_size, loss_yaw.data[0], loss_pitch.data[0], loss_roll.data[0]))
+                # if epoch == 0:
+                #     torch.save(model.state_dict(),
+                #     'output/snapshots/' + args.output_string + '_iter_'+ str(i+1) + '.pkl')
+
+        # Save models at numbered epochs.
+        if epoch % 1 == 0 and epoch < num_epochs:
+            print 'Taking snapshot...'
+            torch.save(model.state_dict(),
+            'output/snapshots/' + args.output_string + '_epoch_'+ str(epoch+1) + '.pkl')

--
Gitblit v1.8.0