import torch import torch.nn as nn import torchvision.datasets as dsets from torch.autograd import Variable # CNN Model (2 conv layer) class Simple_CNN(nn.Module): def __init__(self): super(Simple_CNN, self).__init__() self.layer1 = nn.Sequential( nn.Conv2d(3, 64, kernel_size=3, padding=0), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2)) self.layer2 = nn.Sequential( nn.Conv2d(64, 128, kernel_size=3, padding=0), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(2)) self.layer3 = nn.Sequential( nn.Conv2d(128, 256, kernel_size=3, padding=0), nn.BatchNorm2d(256), nn.ReLU(), nn.MaxPool2d(2)) self.layer4 = nn.Sequential( nn.Conv2d(256, 512, kernel_size=3, padding=0), nn.BatchNorm2d(512), nn.ReLU(), nn.MaxPool2d(2)) self.fc = nn.Linear(17*17*512, 3) def forward(self, x): out = self.layer1(x) out = self.layer2(out) out = self.layer3(out) out = self.layer4(out) out = out.view(out.size(0), -1) out = self.fc(out) return out