Skip to content
Snippets Groups Projects
Commit 5fd98cb8 authored by Mathieu Léonardon's avatar Mathieu Léonardon
Browse files

Change resnet definition.

parent 6e2faa51
No related branches found
No related tags found
No related merge requests found
......@@ -10,14 +10,13 @@ import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from models.resnet import ResNet18
# from quick_test import ResNet
import argparse
import wandb
import torchinfo
mixup = v2.MixUp(num_classes=10)


def mixup_collate_fn(batch):
    """Collate *batch* with the default collator, then apply MixUp.

    MixUp blends pairs of samples (and their one-hot labels) across the
    batch, so it must run after collation, on the full batch tensor.
    """
    collated = default_collate(batch)
    return mixup(*collated)
......@@ -52,7 +51,7 @@ def main(args):
test_data = datasets.CIFAR10(data_path, train=False, download=True, transform=preproc)
model = ResNet18()
# declare ResNet-18 model
epochs = 150
torchinfo.summary(model, input_size=(32, 3, 32, 32))
......@@ -85,7 +84,6 @@ def main(args):
scheduler.step()
# test ResNet-18 model
model.eval()
correct, total = 0, 0
test_loss = 0.0
......@@ -127,6 +125,5 @@ if __name__ == "__main__":
# NOTE(review): type was `int`, but the default is the float 5e-4 — passing
# any value on the CLI (e.g. "5e-4" or "0.0005") would raise ValueError or
# silently truncate to 0. Weight decay is a float hyperparameter.
parser.add_argument("-wd", "--weight_decay", type=float, default=5e-4, help="Weight decay")
parser.add_argument("--mixup", action="store_true", help="Use MixUp data augmentation")
args = parser.parse_args()
main(args)
\ No newline at end of file
'''ResNet in PyTorch.
For Pre-activation ResNet, see 'preact_resnet.py'.
Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
class BasicBlock(nn.Module):
class ResNetBlock(nn.Module):
expansion = 1
def __init__(self, in_planes, planes, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, groups=4, stride=stride, padding=1, bias=False)
super(ResNetBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, groups=4, stride=1, padding=1, bias=False)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.shortcut = nn.Sequential()
......@@ -35,63 +27,27 @@ class BasicBlock(nn.Module):
out = F.relu(out)
return out
class Bottleneck(nn.Module):
    """Bottleneck residual block (1x1 reduce -> 3x3 -> 1x1 expand).

    The variant used by ResNet-50/101/152: the last 1x1 conv expands the
    channel count by ``expansion``; a projection shortcut is added whenever
    the spatial size or channel count changes.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        # 1x1 conv: reduce to `planes` channels.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # 3x3 conv: all downsampling (stride) happens here.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # 1x1 conv: expand back out to expansion*planes channels.
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        # Identity shortcut by default; 1x1 projection when shape changes.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes),
            )

    def forward(self, x):
        """Apply the three conv stages, add the shortcut, and activate."""
        y = F.relu(self.bn1(self.conv1(x)))
        y = F.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        y = y + self.shortcut(x)
        return F.relu(y)
class ResNet(nn.Module):
def __init__(self, block, num_blocks, num_classes=100,fmaps_repeat=16):
def __init__(self, blocks, num_classes=100,fmaps_repeat=16):
super(ResNet, self).__init__()
self.in_planes = fmaps_repeat
self.fmaps_repeat = fmaps_repeat
self.conv1 = nn.Conv2d(3, self.in_planes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(self.in_planes)
self.layer1 = self._make_layer(block, self.fmaps_repeat, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 2*self.fmaps_repeat, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, 4*self.fmaps_repeat, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block, 8*self.fmaps_repeat, num_blocks[3], stride=2)
self.linear = nn.Linear((8*self.fmaps_repeat)*block.expansion, num_classes)
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1]*(num_blocks-1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
blocks_array = []
previous_fmaps = blocks[0][1]
for (num_blocks, fmaps, stride) in blocks:
for i in range(num_blocks):
blocks_array.append(ResNetBlock(previous_fmaps, fmaps, stride if i == 0 else 1))
previous_fmaps = fmaps
self.blocks = nn.ModuleList(blocks_array)
self.linear = nn.Linear(blocks[-1][1], num_classes)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
for block in self.blocks:
out = block(out)
out = F.avg_pool2d(out, 4)
out = out.view(out.size(0), -1)
out = self.linear(out)
......@@ -99,16 +55,16 @@ class ResNet(nn.Module):
def ResNet18():
    """Build a ResNet-18 for CIFAR-10.

    Defect fixed: the scraped diff left the old (``BasicBlock``-based) and
    new (tuple-spec) return statements fused into one invalid definition;
    this keeps only the post-commit version matching the new
    ``ResNet.__init__(blocks, ...)`` signature. Each tuple is
    (num_blocks, feature_maps, stride) for one stage.
    """
    return ResNet([(2, 64, 1), (2, 128, 2), (2, 256, 2), (2, 512, 2)], num_classes=10, fmaps_repeat=64)
def ResNet34():
return ResNet(BasicBlock, [3,4,6,3])
# def ResNet34():
# return ResNet(BasicBlock, [3,4,6,3])
def ResNet50():
return ResNet(Bottleneck, [3,4,6,3])
# def ResNet50():
# return ResNet(Bottleneck, [3,4,6,3])
def ResNet101():
return ResNet(Bottleneck, [3,4,23,3])
# def ResNet101():
# return ResNet(Bottleneck, [3,4,23,3])
def ResNet152():
return ResNet(Bottleneck, [3,8,36,3])
\ No newline at end of file
# def ResNet152():
# return ResNet(Bottleneck, [3,8,36,3])
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment