Spaces:
Sleeping
Sleeping
import torch.nn as nn | |
import timm | |
class Model(nn.Module):
    """Wrap a timm model and give it a single-output linear head.

    The network returned by ``timm.create_model`` must expose its final
    linear layer as ``head.fc`` (the original author's note indicates a
    ConvNeXt variant is expected). That layer is replaced by a freshly
    initialised ``nn.Linear`` with one output feature.
    """

    def __init__(self, model_name, pretrained=True):
        """Create the timm model and swap its last linear layer.

        Args:
            model_name: Model identifier forwarded to ``timm.create_model``.
            pretrained: Forwarded to ``timm.create_model`` as ``pretrained``.
        """
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=pretrained)
        # Reuse the input width of the existing head; only the output size
        # changes (to a single value).
        head_width = backbone.head.fc.in_features
        backbone.head.fc = nn.Linear(head_width, 1)
        self.model = backbone

    def forward(self, x):
        """Pass ``x`` through the wrapped model and return its output."""
        return self.model(x)