Skip to content
Snippets Groups Projects
Commit b6536178 authored by Mathieu Léonardon's avatar Mathieu Léonardon
Browse files

Change resnet definition.

parent 5fd98cb8
No related branches found
No related tags found
No related merge requests found
......@@ -9,8 +9,7 @@ import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from models.resnet import ResNet18
# from quick_test import ResNet
from models.resnet import ResNet
import argparse
import wandb
import torchinfo
......@@ -35,6 +34,9 @@ def main(args):
momentum = wandb.config['momentum']
weight_decay = wandb.config['weight_decay']
mixup = wandb.config['mixup']
depth = wandb.config['depth']
width = wandb.config['width']
groups = wandb.config['groups']
preproc = v2.Compose([
v2.PILToTensor(),
......@@ -50,11 +52,22 @@ def main(args):
train_data = datasets.CIFAR10(data_path, train=True, download=True, transform=preproc)
test_data = datasets.CIFAR10(data_path, train=False, download=True, transform=preproc)
model = ResNet18()
if depth == 18:
model = ResNet([(width, 1, [groups, groups]), (width, 1, [groups, groups]), (width*2, 2, [groups, groups]), (width*2, 1, [groups, groups]), (width*4, 2, [groups, groups]), (width*4, 1, [groups, groups]), (width*8, 2, [groups, groups]), (width*8, 1, [groups, groups])])
elif depth == 14:
model = ResNet([(width, 1, [groups, groups]), (width, 1, [groups, groups]), (width*2, 2, [groups, groups]), (width*2, 1, [groups, groups]), (width*4, 2, [groups, groups]), (width*4, 1, [groups, groups])])
elif depth == 8:
model = ResNet([(width, 1, [groups, groups]), (width*2, 2, [groups, groups]), (width*4, 2, [groups, groups])])
else:
raise ValueError('Invalid depth')
epochs = 150
torchinfo.summary(model, input_size=(32, 3, 32, 32))
summary = torchinfo.summary(model, input_size=(32, 3, 32, 32))
run.config['total_params'] = summary.total_params
run.config['mult_add'] = summary.total_mult_adds
collate_fn = conf_collate_fn(mixup)
......@@ -110,8 +123,7 @@ def main(args):
print(f'Epoch: {epoch}, Test Accuracy: {correct / total}')
model.train()
# save ResNet-18 model
torch.save(model.state_dict(), 'resnet18.pth')
torch.save(model.state_dict(), run.id + '.pt')
if __name__ == "__main__":
parser = argparse.ArgumentParser(
......@@ -124,6 +136,9 @@ if __name__ == "__main__":
# Optimizer hyper-parameters. argparse applies `type` to the raw command-line
# string, so fractional values such as 0.9 or 5e-4 must be parsed with float;
# with type=int, passing "-m 0.9" is rejected as an invalid int value.
parser.add_argument("-m", "--momentum", type=float, default=0.9, help="Momentum")
parser.add_argument("-wd", "--weight_decay", type=float, default=5e-4, help="Weight decay")
# Boolean flag: MixUp is enabled only when --mixup is present.
parser.add_argument("--mixup", action="store_true", help="Use MixUp data augmentation")
# Architecture options used to build the ResNet block specification.
parser.add_argument("--depth", type=int, default=18, help="ResNet depth")
parser.add_argument("--width", type=int, default=64, help="ResNet width")
parser.add_argument("--groups", type=int, default=1, help="ResNet groups")
args = parser.parse_args()
main(args)
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
class ResNetBlock(nn.Module):
expansion = 1
def __init__(self, ifm, ofm, stride=1, groups=(1, 1)):
    """Build the layers of a residual block made of two 3x3 convolutions.

    Args:
        ifm: number of input feature maps (channels).
        ofm: number of output feature maps produced by both convolutions.
        stride: stride of the first convolution and of the projection
            shortcut; the second convolution always uses stride 1.
        groups: two-element sequence giving the ``groups`` argument of the
            first and of the second convolution respectively.
    """
    super(ResNetBlock, self).__init__()
    self.conv1 = nn.Conv2d(ifm, ofm, kernel_size=3, stride=stride, padding=1,
                           groups=groups[0], bias=False)
    self.bn1 = nn.BatchNorm2d(ofm)
    # The second convolution takes its own entry of `groups`; using
    # groups[0] here would leave the second element silently ignored.
    self.conv2 = nn.Conv2d(ofm, ofm, kernel_size=3, stride=1, padding=1,
                           groups=groups[1], bias=False)
    self.bn2 = nn.BatchNorm2d(ofm)

    # Identity shortcut by default; a 1x1 convolution + batch norm projection
    # is used whenever the spatial size or the channel count changes.
    self.shortcut = nn.Sequential()
    if stride != 1 or ifm != ofm:
        self.shortcut = nn.Sequential(
            nn.Conv2d(ifm, ofm, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(ofm),
        )
def forward(self, x):
......@@ -28,43 +27,25 @@ class ResNetBlock(nn.Module):
return out
class ResNet(nn.Module):
    """ResNet assembled from an explicit per-block specification.

    ``blocks`` is a non-empty sequence of ``(fmaps, stride, groups)`` tuples,
    one per residual block, in network order. ``fmaps`` is the number of
    output feature maps of the block, ``stride`` its stride, and ``groups``
    is forwarded unchanged to ``ResNetBlock``.
    """

    def __init__(self, blocks, num_classes=10):
        """Create the stem convolution, the residual blocks and the classifier.

        Args:
            blocks: sequence of ``(fmaps, stride, groups)`` tuples.
            num_classes: size of the output of the final linear layer.

        Raises:
            ValueError: if ``blocks`` is empty.
        """
        super(ResNet, self).__init__()
        if not blocks:
            raise ValueError("blocks must contain at least one (fmaps, stride, groups) entry")

        # The stem produces as many feature maps as the first block outputs.
        self.ifm = blocks[0][0]
        self.conv1 = nn.Conv2d(3, self.ifm, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.ifm)

        blocks_array = []
        previous_fmaps = self.ifm
        for fmaps, stride, groups in blocks:
            blocks_array.append(ResNetBlock(previous_fmaps, fmaps, stride, groups))
            previous_fmaps = fmaps
        self.blocks = nn.ModuleList(blocks_array)

        # After the loop previous_fmaps equals the width of the last block.
        self.linear = nn.Linear(previous_fmaps, num_classes)

    def forward(self, x):
        """Run the stem, every residual block, global pooling and the classifier."""
        out = F.relu(self.bn1(self.conv1(x)))
        for block in self.blocks:
            out = block(out)
        # Global average pooling down to 1x1, whatever the spatial size of the
        # final feature map (also valid when height and width differ).
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        return self.linear(out)
def ResNet18():
    """Return a 10-class ResNet with eight residual blocks of widths 64 to 512.

    The specification uses the ``(fmaps, stride, groups)`` tuple layout: each
    of the four widths contributes two blocks, the first one carrying the
    stage stride and the second one stride 1, all with ungrouped convolutions.
    """
    stages = [(64, 1), (128, 2), (256, 2), (512, 2)]
    spec = []
    for fmaps, stride in stages:
        spec.append((fmaps, stride, [1, 1]))
        spec.append((fmaps, 1, [1, 1]))
    return ResNet(spec, num_classes=10)
# def ResNet34():
# return ResNet(BasicBlock, [3,4,6,3])
# def ResNet50():
# return ResNet(Bottleneck, [3,4,6,3])
# def ResNet101():
# return ResNet(Bottleneck, [3,4,23,3])
# def ResNet152():
# return ResNet(Bottleneck, [3,8,36,3])
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment