natanielruiz
2017-09-08 0b8e19c1cc8ad03805d4ca68f32df6e4806a36e8
Finetune layer working
2 files added
4 files modified
483 ■■■■ lines changed in total
code/datasets.py 16 ●●●● patch | view | raw | blame | history
code/hopenet.py 36 ●●●● patch | view | raw | blame | history
code/test.py 38 ●●●● patch | view | raw | blame | history
code/test_old.py 149 ●●●●● patch | view | raw | blame | history
code/test_preangles.py 149 ●●●●● patch | view | raw | blame | history
code/train.py 95 ●●●● patch | view | raw | blame | history
code/datasets.py
@@ -60,14 +60,14 @@
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
        # Rotate?
        rnd = np.random.random_sample()
        if rnd < 0.5:
            if roll >= 0:
                img = img.rotate(30)
                roll -= 30
            else:
                img = img.rotate(-30)
                roll += 30
        # rnd = np.random.random_sample()
        # if rnd < 0.5:
        #     if roll >= 0:
        #         img = img.rotate(30)
        #         roll -= 30
        #     else:
        #         img = img.rotate(-30)
        #         roll += 30
        # Bin values
        bins = np.array(range(-99, 102, 3))
code/hopenet.py
@@ -1,8 +1,8 @@
import torch
import torch.nn as nn
import torchvision.datasets as dsets
from torch.autograd import Variable
import math
import torch.nn.functional as F
# CNN Model (2 conv layer)
class Simple_CNN(nn.Module):
@@ -58,6 +58,11 @@
        self.fc_pitch = nn.Linear(512 * block.expansion, num_bins)
        self.fc_roll = nn.Linear(512 * block.expansion, num_bins)
        self.softmax = nn.Softmax()
        self.fc_finetune = nn.Linear(512 * block.expansion + 3, 3)
        self.idx_tensor = Variable(torch.FloatTensor(range(66))).cuda()
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
@@ -83,6 +88,12 @@
        return nn.Sequential(*layers)
    def get_expectation(angle):
        angle_pred = F.softmax(angle)
        angle_pred = torch.sum(angle_pred.data * self.idx_tensor, 1)
        return angle_pred
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
@@ -96,11 +107,26 @@
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        yaw = self.fc_yaw(x)
        pitch = self.fc_pitch(x)
        roll = self.fc_roll(x)
        pre_yaw = self.fc_yaw(x)
        pre_pitch = self.fc_pitch(x)
        pre_roll = self.fc_roll(x)
        return yaw, pitch, roll
        yaw = self.softmax(pre_yaw)
        yaw = Variable(torch.sum(yaw.data * self.idx_tensor.data, 1), requires_grad=True)
        pitch = self.softmax(pre_pitch)
        pitch = Variable(torch.sum(pitch.data * self.idx_tensor.data, 1), requires_grad=True)
        roll = self.softmax(pre_roll)
        roll = Variable(torch.sum(roll.data * self.idx_tensor.data, 1), requires_grad=True)
        yaw = yaw.view(yaw.size(0), 1)
        pitch = pitch.view(pitch.size(0), 1)
        roll = roll.view(roll.size(0), 1)
        angles = []
        angles.append(torch.cat([yaw, pitch, roll], 1))
        for idx in xrange(1):
            angles.append(self.fc_finetune(torch.cat((angles[-1], x), 1)))
        return pre_yaw, pre_pitch, pre_roll, angles
class Hopenet_shape(nn.Module):
    # This is just Hopenet with 3 output layers for yaw, pitch and roll.
code/test.py
@@ -100,44 +100,22 @@
        label_pitch = labels[:,1].float()
        label_roll = labels[:,2].float()
        yaw, pitch, roll = model(images)
        # Binned predictions
        _, yaw_bpred = torch.max(yaw.data, 1)
        _, pitch_bpred = torch.max(pitch.data, 1)
        _, roll_bpred = torch.max(roll.data, 1)
        # Continuous predictions
        yaw_predicted = utils.softmax_temperature(yaw.data, 1)
        pitch_predicted = utils.softmax_temperature(pitch.data, 1)
        roll_predicted = utils.softmax_temperature(roll.data, 1)
        yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1).cpu()
        pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1).cpu()
        roll_predicted = torch.sum(roll_predicted * idx_tensor, 1).cpu()
        pre_yaw, pre_pitch, pre_roll, angles = model(images)
        yaw = angles[:,0].cpu().data
        pitch = angles[:,1].cpu().data
        roll = angles[:,2].cpu().data
        # Mean absolute error
        yaw_error += torch.sum(torch.abs(yaw_predicted - label_yaw) * 3)
        pitch_error += torch.sum(torch.abs(pitch_predicted - label_pitch) * 3)
        roll_error += torch.sum(torch.abs(roll_predicted - label_roll) * 3)
        # Binned Accuracy
        # for er in xrange(n_margins):
        #     yaw_bpred[er] += (label_yaw[0] in range(yaw_bpred[0,0] - er, yaw_bpred[0,0] + er + 1))
        #     pitch_bpred[er] += (label_pitch[0] in range(pitch_bpred[0,0] - er, pitch_bpred[0,0] + er + 1))
        #     roll_bpred[er] += (label_roll[0] in range(roll_bpred[0,0] - er, roll_bpred[0,0] + er + 1))
        # print label_yaw[0], yaw_bpred[0,0]
        yaw_error += torch.sum(torch.abs(yaw - label_yaw) * 3)
        pitch_error += torch.sum(torch.abs(pitch - label_pitch) * 3)
        roll_error += torch.sum(torch.abs(roll - label_roll) * 3)
        # Save images with pose cube.
        # TODO: fix for larger batch size
        if args.save_viz:
            name = name[0]
            cv2_img = cv2.imread(os.path.join(args.data_dir, name + '.jpg'))
            #print os.path.join('output/images', name + '.jpg')
            #print label_yaw[0] * 3 - 99, label_pitch[0] * 3 - 99, label_roll[0] * 3 - 99
            #print yaw_predicted * 3 - 99, pitch_predicted * 3 - 99, roll_predicted * 3 - 99
            utils.plot_pose_cube(cv2_img, yaw_predicted[0] * 3 - 99, pitch_predicted[0] * 3 - 99, roll_predicted[0] * 3 - 99)
            utils.plot_pose_cube(cv2_img, yaw[0] * 3 - 99, pitch[0] * 3 - 99, roll[0] * 3 - 99)
            cv2.imwrite(os.path.join('output/images', name + '.jpg'), cv2_img)
    print('Test error in degrees of the model on the ' + str(total) +
code/test_old.py
New file
@@ -0,0 +1,149 @@
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.backends.cudnn as cudnn
import torchvision
import torch.nn.functional as F
import cv2
import matplotlib.pyplot as plt
import sys
import os
import argparse
import datasets
import hopenet
import utils
def parse_args():
    """Parse command-line arguments for Hopenet evaluation.

    Returns:
        argparse.Namespace with attributes ``gpu_id``, ``data_dir``,
        ``filename_list``, ``snapshot``, ``batch_size`` and ``save_viz``.
    """
    def _str2bool(value):
        # argparse's ``type=bool`` calls bool() on the raw string, so any
        # non-empty value (including 'False' and '0') became True. Parse the
        # usual spellings explicitly; unknown values are reported as CLI errors.
        lowered = str(value).strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n', 'off', ''):
            return False
        raise argparse.ArgumentTypeError('Expected a boolean value, got %r.' % value)

    parser = argparse.ArgumentParser(description='Head pose estimation using the Hopenet network.')
    parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]',
            default=0, type=int)
    parser.add_argument('--data_dir', dest='data_dir', help='Directory path for data.',
          default='', type=str)
    parser.add_argument('--filename_list', dest='filename_list', help='Path to text file containing relative paths for every example.',
          default='', type=str)
    parser.add_argument('--snapshot', dest='snapshot', help='Name of model snapshot.',
          default='', type=str)
    parser.add_argument('--batch_size', dest='batch_size', help='Batch size.',
          default=1, type=int)
    parser.add_argument('--save_viz', dest='save_viz', help='Save images with pose cube.',
          default=False, type=_str2bool)
    args = parser.parse_args()
    return args
if __name__ == '__main__':
    args = parse_args()

    cudnn.enabled = True
    gpu = args.gpu_id
    snapshot_path = os.path.join('output/snapshots', args.snapshot + '.pkl')

    # ResNet101 with 3 outputs.
    # model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], 66)
    # ResNet50
    model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66)
    # ResNet18
    # model = hopenet.Hopenet(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], 66)

    print('Loading snapshot.')
    # Load snapshot
    saved_state_dict = torch.load(snapshot_path)
    model.load_state_dict(saved_state_dict)

    print('Loading data.')

    # Evaluation must be reproducible, so crop from the center rather than at
    # a random offset (RandomCrop is a training-time augmentation).
    transformations = transforms.Compose([transforms.Scale(224),
    transforms.CenterCrop(224), transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

    pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list,
                                transformations)
    test_loader = torch.utils.data.DataLoader(dataset=pose_dataset,
                                               batch_size=args.batch_size,
                                               num_workers=2)

    model.cuda(gpu)

    print('Ready to test network.')

    # Test the Model
    model.eval()  # Change model to 'eval' mode (BN uses moving mean/var).
    total = 0

    # Bin indices 0..65; weighting them by the per-bin probabilities gives a
    # continuous (expected) bin index.
    idx_tensor = torch.FloatTensor(list(range(66))).cuda(gpu)

    yaw_error = .0
    pitch_error = .0
    roll_error = .0

    for images, labels, name in test_loader:
        # Inference only: volatile=True keeps autograd from recording a graph
        # (pre-0.4 PyTorch API; later versions ignore the flag with a warning).
        images = Variable(images, volatile=True).cuda(gpu)
        total += labels.size(0)
        label_yaw = labels[:,0].float()
        label_pitch = labels[:,1].float()
        label_roll = labels[:,2].float()

        # Hopenet.forward in this commit returns (pre_yaw, pre_pitch, pre_roll,
        # angles); older versions returned only the first three. Slicing keeps
        # this script working with either, evaluating just the binned heads.
        yaw, pitch, roll = model(images)[:3]

        # Continuous predictions: probability-weighted bin index
        # (utils.softmax_temperature with temperature 1 -- presumably a plain softmax).
        yaw_predicted = utils.softmax_temperature(yaw.data, 1)
        pitch_predicted = utils.softmax_temperature(pitch.data, 1)
        roll_predicted = utils.softmax_temperature(roll.data, 1)
        yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1).cpu()
        pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1).cpu()
        roll_predicted = torch.sum(roll_predicted * idx_tensor, 1).cpu()

        # Accumulate absolute error; bins are 3 degrees wide, so * 3 converts
        # bin units to degrees. Divided by `total` after the loop.
        yaw_error += torch.sum(torch.abs(yaw_predicted - label_yaw) * 3)
        pitch_error += torch.sum(torch.abs(pitch_predicted - label_pitch) * 3)
        roll_error += torch.sum(torch.abs(roll_predicted - label_roll) * 3)

        # Save images with pose cube.
        # TODO: fix for larger batch size -- only the first image of each batch is drawn.
        if args.save_viz:
            img_name = name[0]
            cv2_img = cv2.imread(os.path.join(args.data_dir, img_name + '.jpg'))
            # Bin index k maps back to -99 + 3 * k degrees.
            utils.plot_pose_cube(cv2_img, yaw_predicted[0] * 3 - 99, pitch_predicted[0] * 3 - 99, roll_predicted[0] * 3 - 99)
            cv2.imwrite(os.path.join('output/images', img_name + '.jpg'), cv2_img)

    if total == 0:
        # Avoid ZeroDivisionError when the file list yields no samples.
        print('No test images were loaded; nothing to evaluate.')
    else:
        print('Test error in degrees of the model on the ' + str(total) +
        ' test images. Yaw: %.4f, Pitch: %.4f, Roll: %.4f' % (yaw_error / total,
        pitch_error / total, roll_error / total))
code/test_preangles.py
New file
@@ -0,0 +1,149 @@
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.backends.cudnn as cudnn
import torchvision
import torch.nn.functional as F
import cv2
import matplotlib.pyplot as plt
import sys
import os
import argparse
import datasets
import hopenet
import utils
def parse_args():
    """Parse command-line arguments for Hopenet evaluation.

    Returns:
        argparse.Namespace with attributes ``gpu_id``, ``data_dir``,
        ``filename_list``, ``snapshot``, ``batch_size`` and ``save_viz``.
    """
    def _str2bool(value):
        # argparse's ``type=bool`` calls bool() on the raw string, so any
        # non-empty value (including 'False' and '0') became True. Parse the
        # usual spellings explicitly; unknown values are reported as CLI errors.
        lowered = str(value).strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n', 'off', ''):
            return False
        raise argparse.ArgumentTypeError('Expected a boolean value, got %r.' % value)

    parser = argparse.ArgumentParser(description='Head pose estimation using the Hopenet network.')
    parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]',
            default=0, type=int)
    parser.add_argument('--data_dir', dest='data_dir', help='Directory path for data.',
          default='', type=str)
    parser.add_argument('--filename_list', dest='filename_list', help='Path to text file containing relative paths for every example.',
          default='', type=str)
    parser.add_argument('--snapshot', dest='snapshot', help='Name of model snapshot.',
          default='', type=str)
    parser.add_argument('--batch_size', dest='batch_size', help='Batch size.',
          default=1, type=int)
    parser.add_argument('--save_viz', dest='save_viz', help='Save images with pose cube.',
          default=False, type=_str2bool)
    args = parser.parse_args()
    return args
if __name__ == '__main__':
    args = parse_args()

    cudnn.enabled = True
    gpu = args.gpu_id
    snapshot_path = os.path.join('output/snapshots', args.snapshot + '.pkl')

    # ResNet101 with 3 outputs.
    # model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], 66)
    # ResNet50
    model = hopenet.Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66)
    # ResNet18
    # model = hopenet.Hopenet(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], 66)

    print('Loading snapshot.')
    # Load snapshot
    saved_state_dict = torch.load(snapshot_path)
    model.load_state_dict(saved_state_dict)

    print('Loading data.')

    # Evaluation must be reproducible, so crop from the center rather than at
    # a random offset (RandomCrop is a training-time augmentation).
    transformations = transforms.Compose([transforms.Scale(224),
    transforms.CenterCrop(224), transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

    pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list,
                                transformations)
    test_loader = torch.utils.data.DataLoader(dataset=pose_dataset,
                                               batch_size=args.batch_size,
                                               num_workers=2)

    model.cuda(gpu)

    print('Ready to test network.')

    # Test the Model
    model.eval()  # Change model to 'eval' mode (BN uses moving mean/var).
    total = 0

    # Bin indices 0..65; weighting them by the per-bin probabilities gives a
    # continuous (expected) bin index.
    idx_tensor = torch.FloatTensor(list(range(66))).cuda(gpu)

    yaw_error = .0
    pitch_error = .0
    roll_error = .0

    for images, labels, name in test_loader:
        # Inference only: volatile=True keeps autograd from recording a graph
        # (pre-0.4 PyTorch API; later versions ignore the flag with a warning).
        images = Variable(images, volatile=True).cuda(gpu)
        total += labels.size(0)
        label_yaw = labels[:,0].float()
        label_pitch = labels[:,1].float()
        label_roll = labels[:,2].float()

        # The model returns the three binned heads plus the finetuned angles;
        # this script evaluates only the binned ("pre") predictions.
        yaw, pitch, roll, _ = model(images)

        # Continuous predictions: probability-weighted bin index
        # (utils.softmax_temperature with temperature 1 -- presumably a plain softmax).
        yaw_predicted = utils.softmax_temperature(yaw.data, 1)
        pitch_predicted = utils.softmax_temperature(pitch.data, 1)
        roll_predicted = utils.softmax_temperature(roll.data, 1)
        yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1).cpu()
        pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1).cpu()
        roll_predicted = torch.sum(roll_predicted * idx_tensor, 1).cpu()

        # Accumulate absolute error; bins are 3 degrees wide, so * 3 converts
        # bin units to degrees. Divided by `total` after the loop.
        yaw_error += torch.sum(torch.abs(yaw_predicted - label_yaw) * 3)
        pitch_error += torch.sum(torch.abs(pitch_predicted - label_pitch) * 3)
        roll_error += torch.sum(torch.abs(roll_predicted - label_roll) * 3)

        # Save images with pose cube.
        # TODO: fix for larger batch size -- only the first image of each batch is drawn.
        if args.save_viz:
            img_name = name[0]
            cv2_img = cv2.imread(os.path.join(args.data_dir, img_name + '.jpg'))
            # Bin index k maps back to -99 + 3 * k degrees.
            utils.plot_pose_cube(cv2_img, yaw_predicted[0] * 3 - 99, pitch_predicted[0] * 3 - 99, roll_predicted[0] * 3 - 99)
            cv2.imwrite(os.path.join('output/images', img_name + '.jpg'), cv2_img)

    if total == 0:
        # Avoid ZeroDivisionError when the file list yields no samples.
        print('No test images were loaded; nothing to evaluate.')
    else:
        print('Test error in degrees of the model on the ' + str(total) +
        ' test images. Yaw: %.4f, Pitch: %.4f, Roll: %.4f' % (yaw_error / total,
        pitch_error / total, roll_error / total))
code/train.py
@@ -33,6 +33,8 @@
            default=0, type=int)
    parser.add_argument('--num_epochs', dest='num_epochs', help='Maximum number of training epochs.',
          default=5, type=int)
    parser.add_argument('--num_epochs_ft', dest='num_epochs_ft', help='Maximum number of finetuning epochs.',
          default=5, type=int)
    parser.add_argument('--batch_size', dest='batch_size', help='Batch size.',
          default=16, type=int)
    parser.add_argument('--lr', dest='lr', help='Base learning rate.',
@@ -41,9 +43,7 @@
          default='', type=str)
    parser.add_argument('--filename_list', dest='filename_list', help='Path to text file containing relative paths for every example.',
          default='', type=str)
    args = parser.parse_args()
    return args
def get_ignored_params(model):
@@ -66,6 +66,7 @@
    b.append(model.fc_yaw)
    b.append(model.fc_pitch)
    b.append(model.fc_roll)
    b.append(model.fc_finetune)
    for i in range(len(b)):
        for j in b[i].modules():
            for k in j.parameters():
@@ -86,6 +87,7 @@
    cudnn.enabled = True
    num_epochs = args.num_epochs
    num_epochs_ft = args.num_epochs_ft
    batch_size = args.batch_size
    gpu = args.gpu_id
@@ -129,13 +131,10 @@
    optimizer = torch.optim.Adam([{'params': get_ignored_params(model), 'lr': args.lr},
                                  {'params': get_non_ignored_params(model), 'lr': args.lr * 10}],
                                  lr = args.lr)
    # optimizer = torch.optim.SGD([{'params': get_ignored_params(model), 'lr': args.lr},
    #                               {'params': get_non_ignored_params(model), 'lr': args.lr}],
    #                               lr = args.lr,
    #                               momentum = 0.9, weight_decay=0.01)
    print 'Ready to train network.'
    print 'First phase of training.'
    for epoch in range(num_epochs):
        for i, (images, labels, name) in enumerate(train_loader):
            images = Variable(images.cuda(gpu))
@@ -146,17 +145,17 @@
            optimizer.zero_grad()
            model.zero_grad()
            yaw, pitch, roll = model(images)
            pre_yaw, pre_pitch, pre_roll, angles = model(images)
            # Cross entropy loss
            loss_yaw = criterion(yaw, label_yaw)
            loss_pitch = criterion(pitch, label_pitch)
            loss_roll = criterion(roll, label_roll)
            loss_yaw = criterion(pre_yaw, label_yaw)
            loss_pitch = criterion(pre_pitch, label_pitch)
            loss_roll = criterion(pre_roll, label_roll)
            # MSE loss
            yaw_predicted = F.softmax(yaw)
            pitch_predicted = F.softmax(pitch)
            roll_predicted = F.softmax(roll)
            yaw_predicted = F.softmax(pre_yaw)
            pitch_predicted = F.softmax(pre_pitch)
            roll_predicted = F.softmax(pre_roll)
            yaw_predicted = torch.sum(yaw_predicted.data * idx_tensor, 1)
            pitch_predicted = torch.sum(pitch_predicted.data * idx_tensor, 1)
@@ -176,21 +175,77 @@
            torch.autograd.backward(loss_seq, grad_seq)
            optimizer.step()
            # print ('Epoch [%d/%d], Iter [%d/%d] Losses: Yaw %.4f, Pitch %.4f, Roll %.4f'
            #        %(epoch+1, num_epochs, i+1, len(pose_dataset)//batch_size, loss_yaw.data[0], loss_pitch.data[0], loss_roll.data[0]))
            if (i+1) % 100 == 0:
                print ('Epoch [%d/%d], Iter [%d/%d] Losses: Yaw %.4f, Pitch %.4f, Roll %.4f'
                       %(epoch+1, num_epochs, i+1, len(pose_dataset)//batch_size, loss_yaw.data[0], loss_pitch.data[0], loss_roll.data[0]))
                # if epoch == 0:
                #     torch.save(model.state_dict(),
                #     'output/snapshots/resnet50_lbatch_iter_'+ str(i+1) + '.pkl')
                #     'output/snapshots/hopenet50_epoch_'+ str(i+1) + '.pkl')
        # Save models at numbered epochs.
        if epoch % 1 == 0 and epoch < num_epochs - 1:
        if epoch % 1 == 0 and epoch < num_epochs:
            print 'Taking snapshot...'
            torch.save(model.state_dict(),
            'output/snapshots/resnet50_norm_30rot_epoch_'+ str(epoch+1) + '.pkl')
            'output/snapshots/hopenet50_epoch_'+ str(epoch+1) + '.pkl')
    print 'Second phase of training (finetuning layer).'
    for epoch in range(num_epochs_ft):
        for i, (images, labels, name) in enumerate(train_loader):
            images = Variable(images.cuda(gpu))
            label_yaw = Variable(labels[:,0].cuda(gpu))
            label_pitch = Variable(labels[:,1].cuda(gpu))
            label_roll = Variable(labels[:,2].cuda(gpu))
            label_angles = Variable(labels[:,:3].cuda(gpu))
            optimizer.zero_grad()
            model.zero_grad()
            pre_yaw, pre_pitch, pre_roll, angles = model(images)
            # Cross entropy loss
            loss_yaw = criterion(pre_yaw, label_yaw)
            loss_pitch = criterion(pre_pitch, label_pitch)
            loss_roll = criterion(pre_roll, label_roll)
            # MSE loss
            yaw_predicted = F.softmax(pre_yaw)
            pitch_predicted = F.softmax(pre_pitch)
            roll_predicted = F.softmax(pre_roll)
            yaw_predicted = torch.sum(yaw_predicted.data * idx_tensor, 1)
            pitch_predicted = torch.sum(pitch_predicted.data * idx_tensor, 1)
            roll_predicted = torch.sum(roll_predicted.data * idx_tensor, 1)
            loss_reg_yaw = reg_criterion(yaw_predicted, label_yaw.float())
            loss_reg_pitch = reg_criterion(pitch_predicted, label_pitch.float())
            loss_reg_roll = reg_criterion(roll_predicted, label_roll.float())
            # Total loss
            loss_yaw += alpha * loss_reg_yaw
            loss_pitch += alpha * loss_reg_pitch
            loss_roll += alpha * loss_reg_roll
            # Finetuning loss
            loss_angles = reg_criterion(angles[0], label_angles.float())
            loss_seq = [loss_yaw, loss_pitch, loss_roll, loss_angles]
            grad_seq = [torch.Tensor(1).cuda(gpu) for _ in range(len(loss_seq))]
            torch.autograd.backward(loss_seq, grad_seq)
            optimizer.step()
            if (i+1) % 100 == 0:
                print ('Epoch [%d/%d], Iter [%d/%d] Losses: pre-yaw %.4f, pre-pitch %.4f, pre-roll %.4f, finetuning %.4f'
                       %(epoch+1, num_epochs_ft, i+1, len(pose_dataset)//batch_size, loss_yaw.data[0], loss_pitch.data[0], loss_roll.data[0], loss_angles.data[0]))
                # if epoch == 0:
                #     torch.save(model.state_dict(),
                #     'output/snapshots/hopenet50_iter_'+ str(i+1) + '.pkl')
        # Save models at numbered epochs.
        if epoch % 1 == 0 and epoch < num_epochs_ft - 1:
            print 'Taking snapshot...'
            torch.save(model.state_dict(),
            'output/snapshots/hopenet50_epoch_'+ str(num_epochs+epoch+1) + '.pkl')
    # Save the final Trained Model
    torch.save(model.state_dict(), 'output/snapshots/resnet50_norm_30rot_epoch_' + str(epoch+1) + '.pkl')
    torch.save(model.state_dict(), 'output/snapshots/hopenet50_epoch_' + str(num_epochs+epoch+1) + '.pkl')