# AllerTrans / inference.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class NeuralNet(nn.Module):
    """Feed-forward classifier over pre-computed protein embeddings."""

    def __init__(self, input_size, hidden_size1, hidden_size2, hidden_size3, num_classes):
        super(NeuralNet, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.fc3 = nn.Linear(hidden_size2, hidden_size3)
        self.fc4 = nn.Linear(hidden_size3, num_classes)
        # Dropout module (p=0.1); it is not applied in forward(), so it has no
        # effect at inference time.
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        out = F.relu(self.fc1(x))
        out = F.relu(self.fc2(out))
        out = F.relu(self.fc3(out))
        out = self.fc4(out)
        return out


def load_models():
    # Use the GPU when one is available (e.g. a Hugging Face ZeroGPU slot),
    # otherwise fall back to the CPU instead of failing on .to("cuda").
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Classifier over 1024-d ProtT5 per-protein embeddings.
    model_protT5 = NeuralNet(1024, 200, 100, 50, 2)
    model_protT5.load_state_dict(torch.load("checkpoints/model17-protT5.pt", map_location=torch.device("cpu")))
    model_protT5.eval().to(device)

    # Classifier over 2304-d concatenated ESM + ProtT5 embeddings.
    model_cat = NeuralNet(2304, 200, 100, 100, 2)
    model_cat.load_state_dict(torch.load("checkpoints/model-esm-protT5-5.pt", map_location=torch.device("cpu")))
    model_cat.eval().to(device)

    return model_protT5, model_cat
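

# The two classifiers consume pre-computed per-protein feature vectors: 1024-d ProtT5
# embeddings for model_protT5 and 2304-d vectors for model_cat (from the checkpoint name
# "model-esm-protT5-5.pt", presumably an ESM embedding concatenated with the ProtT5 one).
# The helper below is an illustrative sketch of one common way to obtain a mean-pooled
# ProtT5 embedding with Hugging Face transformers; the function name, the encoder
# checkpoint, the preprocessing, and the pooling scheme are assumptions and not
# necessarily what the AllerTrans checkpoints were trained on.
def embed_protT5_example(sequence):
    """Illustrative sketch: mean-pooled ProtT5 embedding for one protein sequence."""
    import re
    from transformers import T5EncoderModel, T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False)
    encoder = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc").eval()

    # ProtT5 expects space-separated residues with rare amino acids mapped to X.
    seq = " ".join(re.sub(r"[UZOB]", "X", sequence.upper()))
    ids = tokenizer(seq, return_tensors="pt")
    with torch.no_grad():
        hidden = encoder(**ids).last_hidden_state  # (1, L + 1, 1024)
    # Drop the trailing </s> token and mean-pool over residues -> tensor of shape (1024,).
    return hidden[0, :-1].mean(dim=0)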


def predict_ensemble(X_protT5, X_concat, model_protT5, model_cat, weight1=0.60, weight2=0.30):
    # Move the inputs onto the same device as the models before running inference.
    device = next(model_protT5.parameters()).device
    X_protT5 = X_protT5.to(device)
    X_concat = X_concat.to(device)

    with torch.no_grad():
        outputs1 = model_cat(X_concat)      # logits from the concatenated ESM + ProtT5 model
        outputs2 = model_protT5(X_protT5)   # logits from the ProtT5-only model
        # Weighted soft-voting over the two models' logits, then argmax per sample.
        ensemble_outputs = weight1 * outputs1 + weight2 * outputs2
        _, predicted = torch.max(ensemble_outputs, 1)

    return predicted
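

if __name__ == "__main__":
    # Minimal usage sketch with random feature vectors of the expected sizes
    # (1024-d ProtT5 features and 2304-d concatenated features). Real inputs
    # would come from the embedding pipeline, and the checkpoint files must
    # exist under checkpoints/ for load_models() to succeed.
    model_protT5, model_cat = load_models()
    X_protT5 = torch.randn(4, 1024)  # batch of 4 dummy ProtT5 embeddings
    X_concat = torch.randn(4, 2304)  # batch of 4 dummy concatenated embeddings
    preds = predict_ensemble(X_protT5, X_concat, model_protT5, model_cat)
    print(preds)  # tensor of predicted class indices, shape (4,)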