|
import torch
|
|
from torch import nn
|
|
|
|
import vietocr.model.backbone.vgg as vgg
|
|
from vietocr.model.backbone.resnet import Resnet50
|
|
|
|
class CNN(nn.Module):
    """Thin wrapper that builds a CNN backbone by name and delegates to it.

    Supported ``backbone`` names are ``'vgg11_bn'``, ``'vgg19_bn'`` and
    ``'resnet50'``; any extra keyword arguments are forwarded unchanged to the
    selected backbone constructor.
    """

    def __init__(self, backbone: str, **kwargs):
        """Instantiate the requested backbone and store it as ``self.model``.

        Args:
            backbone: Name of the backbone to build.
            **kwargs: Forwarded verbatim to the backbone constructor.

        Raises:
            ValueError: If ``backbone`` is not one of the supported names.
        """
        super().__init__()

        if backbone == 'vgg11_bn':
            self.model = vgg.vgg11_bn(**kwargs)
        elif backbone == 'vgg19_bn':
            self.model = vgg.vgg19_bn(**kwargs)
        elif backbone == 'resnet50':
            self.model = Resnet50(**kwargs)
        else:
            # Fail fast: otherwise ``self.model`` is never set and the error
            # only surfaces later as an opaque AttributeError in forward().
            raise ValueError(
                f"Unsupported backbone {backbone!r}; "
                "expected one of 'vgg11_bn', 'vgg19_bn', 'resnet50'"
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped backbone on ``x`` and return its output unchanged."""
        return self.model(x)

    def freeze(self) -> None:
        """Disable gradients for the parameters of the backbone's ``features``.

        Parameters belonging to a submodule named ``last_conv_1x1`` are
        skipped (their ``requires_grad`` flag is left as-is).

        NOTE(review): assumes every backbone exposes a ``features`` submodule;
        confirm this holds for ``Resnet50``.
        """
        for name, param in self.model.features.named_parameters():
            # named_parameters() yields dotted paths such as
            # 'last_conv_1x1.weight', so an exact comparison with
            # 'last_conv_1x1' could never match; check the path components.
            if 'last_conv_1x1' not in name.split('.'):
                param.requires_grad = False

    def unfreeze(self) -> None:
        """Re-enable gradients for every parameter of the backbone's ``features``."""
        for param in self.model.features.parameters():
            param.requires_grad = True
|
|
|