natanielruiz
2017-09-08 0b8e19c1cc8ad03805d4ca68f32df6e4806a36e8
code/hopenet.py
@@ -1,8 +1,8 @@
import torch
import torch.nn as nn
import torchvision.datasets as dsets
from torch.autograd import Variable
import math
import torch.nn.functional as F
# CNN model (2 conv layers)
class Simple_CNN(nn.Module):
@@ -58,6 +58,11 @@
        self.fc_pitch = nn.Linear(512 * block.expansion, num_bins)
        self.fc_roll = nn.Linear(512 * block.expansion, num_bins)
        self.softmax = nn.Softmax()
        self.fc_finetune = nn.Linear(512 * block.expansion + 3, 3)
        self.idx_tensor = Variable(torch.FloatTensor(range(66))).cuda()
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
@@ -83,6 +88,12 @@
        return nn.Sequential(*layers)
    def get_expectation(angle):
        angle_pred = F.softmax(angle)
        angle_pred = torch.sum(angle_pred.data * self.idx_tensor, 1)
        return angle_pred
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
@@ -96,11 +107,26 @@
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        yaw = self.fc_yaw(x)
        pitch = self.fc_pitch(x)
        roll = self.fc_roll(x)
        pre_yaw = self.fc_yaw(x)
        pre_pitch = self.fc_pitch(x)
        pre_roll = self.fc_roll(x)
        return yaw, pitch, roll
        yaw = self.softmax(pre_yaw)
        yaw = Variable(torch.sum(yaw.data * self.idx_tensor.data, 1), requires_grad=True)
        pitch = self.softmax(pre_pitch)
        pitch = Variable(torch.sum(pitch.data * self.idx_tensor.data, 1), requires_grad=True)
        roll = self.softmax(pre_roll)
        roll = Variable(torch.sum(roll.data * self.idx_tensor.data, 1), requires_grad=True)
        yaw = yaw.view(yaw.size(0), 1)
        pitch = pitch.view(pitch.size(0), 1)
        roll = roll.view(roll.size(0), 1)
        angles = []
        angles.append(torch.cat([yaw, pitch, roll], 1))
        for idx in xrange(1):
            angles.append(self.fc_finetune(torch.cat((angles[-1], x), 1)))
        return pre_yaw, pre_pitch, pre_roll, angles
class Hopenet_shape(nn.Module):
    # This is just Hopenet with 3 output layers for yaw, pitch and roll.