Spaces · Running on Zero

Commit 326d9e6 · Faezeh Sarlakifar committed · Parent(s): b99c772

Initial upload of AllerTrans app

Files changed:
- app.py (+52, -0)
- inference.py (+46, -0)
- requirements.txt (+9, -0)
app.py
ADDED
@@ -0,0 +1,52 @@
import torch
import gradio as gr
import numpy as np
from transformers import T5Tokenizer, T5EncoderModel
import esm
from inference import load_models, predict_ensemble

# Load the trained classifier heads
model_protT5, model_cat = load_models()

# Load the ProtT5 encoder
tokenizer_t5 = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_lower_case=False)
model_t5 = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")
model_t5 = model_t5.eval()

# Load the ESM-2 encoder
esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
esm_model.eval()


def extract_prott5_embedding(sequence):
    # ProtT5 expects residues separated by single spaces
    sequence = sequence.replace(" ", "")
    seq = " ".join(list(sequence))
    ids = tokenizer_t5(seq, return_tensors="pt", padding=True)
    with torch.no_grad():
        embedding = model_t5(**ids).last_hidden_state
    # Mean-pool over the sequence dimension -> shape (1, 1024)
    return torch.mean(embedding, dim=1)


def extract_esm_embedding(sequence):
    batch_labels, batch_strs, batch_tokens = batch_converter([("protein1", sequence)])
    with torch.no_grad():
        results = esm_model(batch_tokens, repr_layers=[33], return_contacts=False)
    token_representations = results["representations"][33]
    # Mean-pool the per-residue representations, skipping the BOS token -> shape (1, 1280)
    return torch.mean(token_representations[0, 1:len(sequence)+1], dim=0).unsqueeze(0)


def classify(sequence):
    protT5_emb = extract_prott5_embedding(sequence)
    esm_emb = extract_esm_embedding(sequence)
    concat = torch.cat((esm_emb, protT5_emb), dim=1)
    pred = predict_ensemble(protT5_emb, concat, model_protT5, model_cat)
    return "Allergen" if pred.item() == 1 else "Non-Allergen"


demo = gr.Interface(fn=classify,
                    inputs=gr.Textbox(lines=3, placeholder="Enter protein sequence..."),
                    outputs=gr.Label(label="Prediction"))

if __name__ == "__main__":
    demo.launch()
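A minimal sketch of calling the classifier directly, bypassing the Gradio UI; the sequence below is an arbitrary illustrative peptide, not one shipped with the Space, and importing app triggers the heavyweight model loading above:

from app import classify

# Illustrative input; any plain amino-acid string works here.
print(classify("MKTIIALSYIFCLVFA"))  # -> "Allergen" or "Non-Allergen"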
inference.py
ADDED
@@ -0,0 +1,46 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class NeuralNet(nn.Module):
    def __init__(self, input_size, hidden_size1, hidden_size2, hidden_size3, num_classes):
        super(NeuralNet, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.fc3 = nn.Linear(hidden_size2, hidden_size3)
        self.fc4 = nn.Linear(hidden_size3, num_classes)
        # Defined but never applied in forward(), so it is inert at inference time
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        out = F.relu(self.fc1(x))
        out = F.relu(self.fc2(out))
        out = F.relu(self.fc3(out))
        out = self.fc4(out)
        return out


def load_models():
    # ProtT5-only head: 1024-d mean-pooled ProtT5 embedding in, 2 classes out
    model_protT5 = NeuralNet(1024, 200, 100, 50, 2)
    model_protT5.load_state_dict(torch.load("checkpoints/model17-protT5.pt", map_location=torch.device("cpu")))
    model_protT5.eval()

    # Concatenation head: 1280-d ESM-2 + 1024-d ProtT5 = 2304-d in, 2 classes out
    model_cat = NeuralNet(2304, 200, 100, 100, 2)
    model_cat.load_state_dict(torch.load("checkpoints/model-esm-protT5-5.pt", map_location=torch.device("cpu")))
    model_cat.eval()

    return model_protT5, model_cat


def predict_ensemble(X_protT5, X_concat, model_protT5, model_cat, weight1=0.60, weight2=0.30):
    # Weighted combination of the two heads' logits, then argmax over classes
    with torch.no_grad():
        outputs1 = model_cat(X_concat)
        outputs2 = model_protT5(X_protT5)
        ensemble_outputs = weight1 * outputs1 + weight2 * outputs2
        _, predicted = torch.max(ensemble_outputs, 1)
    return predicted
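As a sanity check on the expected input shapes, the two heads can be exercised with random tensors in place of real embeddings; this sketch assumes the Space's checkpoints/ directory is available locally:

import torch
from inference import load_models, predict_ensemble

model_protT5, model_cat = load_models()

# Shapes mirror app.py: the ProtT5 mean-pooled embedding is 1024-d, and the
# ESM-2 (1280-d) + ProtT5 (1024-d) concatenation is 2304-d.
x_protT5 = torch.randn(1, 1024)
x_concat = torch.randn(1, 2304)
print(predict_ensemble(x_protT5, x_concat, model_protT5, model_cat))  # tensor([0]) or tensor([1])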
requirements.txt
ADDED
@@ -0,0 +1,9 @@
torch
gradio
transformers
sentencepiece  # required by the ProtT5 tokenizer
h5py
fair-esm  # provides the `esm` module used in app.py
# alternative to fair-esm: git+https://github.com/facebookresearch/esm.git
# not needed: git+https://github.com/agemagician/ProtTrans.git (ProtT5 loads via transformers)