From b730bbd6ea565d7689964661c53a6074654b5d3b Mon Sep 17 00:00:00 2001
From: natanielruiz <nataniel777@hotmail.com>
Date: 星期一, 30 十月 2017 05:30:52 +0800
Subject: [PATCH] next

---
 code/hopenet.py |   77 ++++++++++++++++++++------------------
 1 files changed, 41 insertions(+), 36 deletions(-)

diff --git a/code/hopenet.py b/code/hopenet.py
index 63a24cd..129ff63 100644
--- a/code/hopenet.py
+++ b/code/hopenet.py
@@ -4,41 +4,6 @@
 import math
 import torch.nn.functional as F
 
-# CNN Model (2 conv layer)
-class Simple_CNN(nn.Module):
-    def __init__(self):
-        super(Simple_CNN, self).__init__()
-        self.layer1 = nn.Sequential(
-            nn.Conv2d(3, 64, kernel_size=3, padding=0),
-            nn.BatchNorm2d(64),
-            nn.ReLU(),
-            nn.MaxPool2d(2))
-        self.layer2 = nn.Sequential(
-            nn.Conv2d(64, 128, kernel_size=3, padding=0),
-            nn.BatchNorm2d(128),
-            nn.ReLU(),
-            nn.MaxPool2d(2))
-        self.layer3 = nn.Sequential(
-            nn.Conv2d(128, 256, kernel_size=3, padding=0),
-            nn.BatchNorm2d(256),
-            nn.ReLU(),
-            nn.MaxPool2d(2))
-        self.layer4 = nn.Sequential(
-            nn.Conv2d(256, 512, kernel_size=3, padding=0),
-            nn.BatchNorm2d(512),
-            nn.ReLU(),
-            nn.MaxPool2d(2))
-        self.fc = nn.Linear(17*17*512, 3)
-
-    def forward(self, x):
-        out = self.layer1(x)
-        out = self.layer2(out)
-        out = self.layer3(out)
-        out = self.layer4(out)
-        out = out.view(out.size(0), -1)
-        out = self.fc(out)
-        return out
-
 class Hopenet(nn.Module):
     # This is just Hopenet with 3 output layers for yaw, pitch and roll.
     def __init__(self, block, layers, num_bins, iter_ref):
@@ -122,7 +87,7 @@
 
         # angles predicts the residual
         for idx in xrange(self.iter_ref):
-            angles.append(self.fc_finetune(torch.cat((preangles, x), 1)))
+            angles.append(self.fc_finetune(torch.cat((angles[idx], x), 1)))
 
         return pre_yaw, pre_pitch, pre_roll, angles
 
@@ -184,3 +149,43 @@
         x = self.fc_angles(x)
 
         return x
+
+class AlexNet(nn.Module):
+
+    def __init__(self, num_bins):
+        super(AlexNet, self).__init__()
+        self.features = nn.Sequential(
+            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+            nn.Conv2d(64, 192, kernel_size=5, padding=2),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+            nn.Conv2d(192, 384, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(384, 256, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(256, 256, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+        )
+        self.classifier = nn.Sequential(
+            nn.Dropout(),
+            nn.Linear(256 * 6 * 6, 4096),
+            nn.ReLU(inplace=True),
+            nn.Dropout(),
+            nn.Linear(4096, 4096),
+            nn.ReLU(inplace=True),
+        )
+        self.fc_yaw = nn.Linear(4096, num_bins)
+        self.fc_pitch = nn.Linear(4096, num_bins)
+        self.fc_roll = nn.Linear(4096, num_bins)
+
+    def forward(self, x):
+        x = self.features(x)
+        x = x.view(x.size(0), 256 * 6 * 6)
+        x = self.classifier(x)
+        yaw = self.fc_yaw(x)
+        pitch = self.fc_pitch(x)
+        roll = self.fc_roll(x)
+        return yaw, pitch, roll

--
Gitblit v1.8.0