From f111cb002b9c6065fdf6bb274ce5857a9e875e8c Mon Sep 17 00:00:00 2001
From: chenshijun <csj_sky@126.com>
Date: 星期三, 05 六月 2019 15:38:49 +0800
Subject: [PATCH] face rectangle

---
 code/hopenet.py |  161 ++---------------------------------------------------
 1 files changed, 8 insertions(+), 153 deletions(-)

diff --git a/code/hopenet.py b/code/hopenet.py
index cd810e3..c9e0b74 100644
--- a/code/hopenet.py
+++ b/code/hopenet.py
@@ -4,44 +4,10 @@
 import math
 import torch.nn.functional as F
 
-# CNN Model (2 conv layer)
-class Simple_CNN(nn.Module):
-    def __init__(self):
-        super(Simple_CNN, self).__init__()
-        self.layer1 = nn.Sequential(
-            nn.Conv2d(3, 64, kernel_size=3, padding=0),
-            nn.BatchNorm2d(64),
-            nn.ReLU(),
-            nn.MaxPool2d(2))
-        self.layer2 = nn.Sequential(
-            nn.Conv2d(64, 128, kernel_size=3, padding=0),
-            nn.BatchNorm2d(128),
-            nn.ReLU(),
-            nn.MaxPool2d(2))
-        self.layer3 = nn.Sequential(
-            nn.Conv2d(128, 256, kernel_size=3, padding=0),
-            nn.BatchNorm2d(256),
-            nn.ReLU(),
-            nn.MaxPool2d(2))
-        self.layer4 = nn.Sequential(
-            nn.Conv2d(256, 512, kernel_size=3, padding=0),
-            nn.BatchNorm2d(512),
-            nn.ReLU(),
-            nn.MaxPool2d(2))
-        self.fc = nn.Linear(17*17*512, 3)
-
-    def forward(self, x):
-        out = self.layer1(x)
-        out = self.layer2(out)
-        out = self.layer3(out)
-        out = self.layer4(out)
-        out = out.view(out.size(0), -1)
-        out = self.fc(out)
-        return out
-
 class Hopenet(nn.Module):
-    # This is just Hopenet with 3 output layers for yaw, pitch and roll.
-    def __init__(self, block, layers, num_bins, iter_ref):
+    # Hopenet with 3 output layers for yaw, pitch and roll
+    # Predicts Euler angles by binning and regression with the expected value
+    def __init__(self, block, layers, num_bins):
         self.inplanes = 64
         super(Hopenet, self).__init__()
         self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
@@ -58,12 +24,8 @@
         self.fc_pitch = nn.Linear(512 * block.expansion, num_bins)
         self.fc_roll = nn.Linear(512 * block.expansion, num_bins)
 
-        self.softmax = nn.Softmax()
+        # Vestigial layer from previous experiments
         self.fc_finetune = nn.Linear(512 * block.expansion + 3, 3)
-
-        self.idx_tensor = Variable(torch.FloatTensor(range(66))).cuda()
-
-        self.iter_ref = iter_ref
 
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
@@ -107,27 +69,10 @@
         pre_pitch = self.fc_pitch(x)
         pre_roll = self.fc_roll(x)
 
-        yaw = self.softmax(pre_yaw)
-        yaw = Variable(torch.sum(yaw.data * self.idx_tensor.data, 1), requires_grad=True)
-        pitch = self.softmax(pre_pitch)
-        pitch = Variable(torch.sum(pitch.data * self.idx_tensor.data, 1), requires_grad=True)
-        roll = self.softmax(pre_roll)
-        roll = Variable(torch.sum(roll.data * self.idx_tensor.data, 1), requires_grad=True)
-        yaw = yaw.view(yaw.size(0), 1)
-        pitch = pitch.view(pitch.size(0), 1)
-        roll = roll.view(roll.size(0), 1)
-        angles = []
-        preangles = torch.cat([yaw, pitch, roll], 1)
-        angles.append(preangles)
-
-        # angles predicts the residual
-        for idx in xrange(self.iter_ref):
-            angles.append(self.fc_finetune(torch.cat((angles[idx], x), 1)))
-
-        return pre_yaw, pre_pitch, pre_roll, angles
+        return pre_yaw, pre_pitch, pre_roll
 
 class ResNet(nn.Module):
-
+    # ResNet for regression of 3 Euler angles.
     def __init__(self, block, layers, num_classes=1000):
         self.inplanes = 64
         super(ResNet, self).__init__()
@@ -182,11 +127,11 @@
         x = self.avgpool(x)
         x = x.view(x.size(0), -1)
         x = self.fc_angles(x)
-
         return x
 
 class AlexNet(nn.Module):
-
+    # AlexNet laid out as a Hopenet - classify Euler angles in bins and
+    # regress the expected value.
     def __init__(self, num_bins):
         super(AlexNet, self).__init__()
         self.features = nn.Sequential(
@@ -224,93 +169,3 @@
         pitch = self.fc_pitch(x)
         roll = self.fc_roll(x)
         return yaw, pitch, roll
-
-class Hopenet_new(nn.Module):
-    # This is just Hopenet with 3 output layers for yaw, pitch and roll.
-    def __init__(self, block, layers, num_bins):
-        self.inplanes = 64
-        super(Hopenet_new, self).__init__()
-        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
-                               bias=False)
-        self.bn1 = nn.BatchNorm2d(64)
-        self.relu = nn.ReLU(inplace=True)
-        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
-        self.layer1 = self._make_layer(block, 64, layers[0])
-        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
-        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
-        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
-        self.avgpool = nn.AvgPool2d(7)
-        self.fc_yaw = nn.Linear(512 * block.expansion, num_bins)
-        self.fc_pitch = nn.Linear(512 * block.expansion, num_bins)
-        self.fc_roll = nn.Linear(512 * block.expansion, num_bins)
-
-        self.softmax = nn.Softmax()
-        self.fc_finetune_new = nn.Linear(512 * block.expansion + 256 * block.expansion + 3, 3)
-        self.conv1x1 = nn.Conv2d(1024, 64, kernel_size = 1, stride = 1, bias=False)
-        self.maxpool_interm = nn.MaxPool2d(kernel_size=5, stride=3, padding=1)
-
-        self.idx_tensor = Variable(torch.FloatTensor(range(66))).cuda()
-
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d):
-                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-                m.weight.data.normal_(0, math.sqrt(2. / n))
-            elif isinstance(m, nn.BatchNorm2d):
-                m.weight.data.fill_(1)
-                m.bias.data.zero_()
-
-    def _make_layer(self, block, planes, blocks, stride=1):
-        downsample = None
-        if stride != 1 or self.inplanes != planes * block.expansion:
-            downsample = nn.Sequential(
-                nn.Conv2d(self.inplanes, planes * block.expansion,
-                          kernel_size=1, stride=stride, bias=False),
-                nn.BatchNorm2d(planes * block.expansion),
-            )
-
-        layers = []
-        layers.append(block(self.inplanes, planes, stride, downsample))
-        self.inplanes = planes * block.expansion
-        for i in range(1, blocks):
-            layers.append(block(self.inplanes, planes))
-
-        return nn.Sequential(*layers)
-
-    def forward(self, x):
-        x = self.conv1(x)
-        x = self.bn1(x)
-        x = self.relu(x)
-        x = self.maxpool(x)
-
-        x = self.layer1(x)
-        x = self.layer2(x)
-        x = self.layer3(x)
-        x_interm = self.conv1x1(x)
-        x_interm = self.relu(x_interm)
-        x_interm = self.maxpool_interm(x_interm)
-        x_interm = x_interm.view(x_interm.size(0), -1)
-
-        x = self.layer4(x)
-
-        x = self.avgpool(x)
-        x = x.view(x.size(0), -1)
-        pre_yaw = self.fc_yaw(x)
-        pre_pitch = self.fc_pitch(x)
-        pre_roll = self.fc_roll(x)
-
-        yaw = self.softmax(pre_yaw)
-        yaw = Variable(torch.sum(yaw.data * self.idx_tensor.data, 1), requires_grad=True) * 3 - 99
-        pitch = self.softmax(pre_pitch)
-        pitch = Variable(torch.sum(pitch.data * self.idx_tensor.data, 1), requires_grad=True) * 3 - 99
-        roll = self.softmax(pre_roll)
-        roll = Variable(torch.sum(roll.data * self.idx_tensor.data, 1), requires_grad=True) * 3 - 99
-        yaw = yaw.view(yaw.size(0), 1)
-        pitch = pitch.view(pitch.size(0), 1)
-        roll = roll.view(roll.size(0), 1)
-        preangles = torch.cat([yaw, pitch, roll], 1)
-
-        # angles predicts the residual
-        residuals = self.fc_finetune_new(torch.cat((preangles, x_interm, x), 1))
-        final_angles = preangles + residuals
-
-        return pre_yaw, pre_pitch, pre_roll, preangles, final_angles

--
Gitblit v1.8.0