Faezeh Sarlakifar committed on
Commit
326d9e6
·
1 Parent(s): b99c772

Initial upload of AllerTrans app

Browse files
Files changed (3) hide show
  1. app.py +52 -0
  2. inference.py +46 -0
  3. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ import numpy as np
4
+ from transformers import T5Tokenizer, T5EncoderModel
5
+ import esm
6
+ from inference import load_models, predict_ensemble
7
+
8
# Trained classifier heads returned by inference.load_models():
# the ProtT5-only network and the ESM+ProtT5 ("cat") network.
model_protT5, model_cat = load_models()

# ProtT5 tokenizer and encoder used to embed input sequences.
tokenizer_t5 = T5Tokenizer.from_pretrained(
    "Rostlab/prot_t5_xl_uniref50",
    do_lower_case=False,
)
# eval() returns the module itself, so loading and switching to inference
# mode can be chained into a single expression.
model_t5 = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50").eval()

# ESM-2 model plus the alphabet whose batch converter tokenizes
# (label, sequence) pairs for it.
esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
esm_model.eval()
batch_converter = alphabet.get_batch_converter()
20
+
21
+
22
def extract_prott5_embedding(sequence):
    """Return the mean-pooled ProtT5 encoder embedding of ``sequence``.

    Space characters are removed from ``sequence`` and the remaining
    characters are re-joined with single spaces before tokenization. The
    encoder's last hidden state is then averaged over dimension 1.

    Uses the module-level ``tokenizer_t5`` and ``model_t5``.

    Args:
        sequence: Protein sequence as a string.

    Returns:
        Tensor holding the mean of the encoder's last hidden state over
        dimension 1 (presumably shape ``(1, hidden_size)`` -- confirm).
    """
    residues = sequence.replace(" ", "")
    spaced = " ".join(residues)
    encoded = tokenizer_t5(spaced, return_tensors="pt", padding=True)
    with torch.no_grad():
        hidden_states = model_t5(**encoded).last_hidden_state
    # NOTE(review): the mean covers every position the tokenizer produced,
    # including any special tokens it appends -- confirm this matches how
    # the training embeddings were computed.
    return hidden_states.mean(dim=1)
29
+
30
+
31
def extract_esm_embedding(sequence):
    """Return the mean-pooled layer-33 ESM representation of ``sequence``.

    Uses the module-level ``batch_converter`` and ``esm_model``.

    Args:
        sequence: Protein sequence as a string.

    Returns:
        Tensor with a leading dimension of size 1 added via ``unsqueeze(0)``,
        holding the mean over the selected token positions.
    """
    _, _, tokens = batch_converter([("protein1", sequence)])
    with torch.no_grad():
        model_out = esm_model(tokens, repr_layers=[33], return_contacts=False)
    layer_repr = model_out["representations"][33]
    # Skip position 0 and keep the next len(sequence) positions of the single
    # batch entry (presumably position 0 is a start-of-sequence token --
    # confirm against the ESM alphabet).
    residue_repr = layer_repr[0, 1:len(sequence) + 1]
    return residue_repr.mean(dim=0).unsqueeze(0)
37
+
38
+
39
def classify(sequence):
    """Classify a protein sequence as "Allergen" or "Non-Allergen".

    The raw text is normalised first: all whitespace (spaces, tabs, newlines
    from the multi-line text box) is removed and letters are upper-cased.
    ProtT5 and ESM embeddings are then computed, concatenated ESM-first along
    dimension 1, and passed to ``predict_ensemble`` together with the
    module-level ``model_protT5`` and ``model_cat``.

    Args:
        sequence: Protein sequence text as entered by the user.

    Returns:
        ``"Allergen"`` when the ensemble predicts class 1, otherwise
        ``"Non-Allergen"``.

    Raises:
        gr.Error: If the input contains no sequence characters.
    """
    # Remove every kind of whitespace, not just spaces: stray newlines would
    # otherwise be tokenized and would also inflate len(sequence) used for
    # slicing in the ESM embedding step.
    # The ProtT5 tokenizer is loaded with do_lower_case=False, so case is
    # normalised here rather than left to the tokenizers.
    cleaned = "".join((sequence or "").split()).upper()
    if not cleaned:
        # Surface a readable message in the UI instead of running the models
        # on an empty sequence.
        raise gr.Error("Please enter a protein sequence.")

    protT5_emb = extract_prott5_embedding(cleaned)
    esm_emb = extract_esm_embedding(cleaned)
    # Order matters: ESM features first, then ProtT5.
    concat = torch.cat((esm_emb, protT5_emb), dim=1)
    pred = predict_ensemble(protT5_emb, concat, model_protT5, model_cat)
    return "Allergen" if pred.item() == 1 else "Non-Allergen"
45
+
46
+
47
# Gradio UI: a three-line text box feeds `classify`, whose string result is
# shown in a label component titled "Prediction".
demo = gr.Interface(
    fn=classify,
    inputs=gr.Textbox(lines=3, placeholder="Enter protein sequence..."),
    outputs=gr.Label(label="Prediction"),
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    demo.launch()
inference.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
class NeuralNet(nn.Module):
    """Feed-forward classifier with three hidden layers.

    Architecture: ``fc1 -> ReLU -> dropout -> fc2 -> ReLU -> dropout ->
    fc3 -> ReLU -> dropout -> fc4``. The output of ``fc4`` is returned as-is
    (no softmax is applied here).

    Args:
        input_size: Number of input features.
        hidden_size1: Width of the first hidden layer.
        hidden_size2: Width of the second hidden layer.
        hidden_size3: Width of the third hidden layer.
        num_classes: Number of output units.
    """

    def __init__(self, input_size, hidden_size1, hidden_size2, hidden_size3, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.fc3 = nn.Linear(hidden_size2, hidden_size3)
        self.fc4 = nn.Linear(hidden_size3, num_classes)
        # A single Dropout module is enough: it has no parameters, so it can
        # be reused after every hidden layer and adds no state_dict keys.
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """Return the ``fc4`` outputs for input ``x``.

        Dropout is active only in training mode; after ``.eval()`` it is the
        identity, so inference results do not depend on it.
        """
        out = self.dropout(F.relu(self.fc1(x)))
        out = self.dropout(F.relu(self.fc2(out)))
        out = self.dropout(F.relu(self.fc3(out)))
        return self.fc4(out)
23
+
24
+
25
def load_models():
    """Build both classifier networks and restore their trained weights.

    Checkpoints are read from the ``checkpoints/`` directory relative to the
    current working directory and mapped onto the CPU. Both networks are put
    into evaluation mode before being returned.

    Returns:
        Tuple ``(model_protT5, model_cat)``: the network taking 1024 input
        features and the network taking 2304 input features.
    """
    cpu = torch.device("cpu")

    model_protT5 = NeuralNet(1024, 200, 100, 50, 2)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot execute arbitrary code on load.
    model_protT5.load_state_dict(
        torch.load("checkpoints/model17-protT5.pt", map_location=cpu, weights_only=True)
    )
    model_protT5.eval()

    model_cat = NeuralNet(2304, 200, 100, 100, 2)
    model_cat.load_state_dict(
        torch.load("checkpoints/model-esm-protT5-5.pt", map_location=cpu, weights_only=True)
    )
    model_cat.eval()

    return model_protT5, model_cat
35
+
36
+
37
def predict_ensemble(X_protT5, X_concat, model_protT5, model_cat, weight1=0.60, weight2=0.30):
    """Predict class indices from a weighted sum of two models' outputs.

    Args:
        X_protT5: Input passed to ``model_protT5``.
        X_concat: Input passed to ``model_cat``.
        model_protT5: Callable producing per-class scores from ``X_protT5``.
        model_cat: Callable producing per-class scores from ``X_concat``.
        weight1: Multiplier applied to the ``model_cat`` scores.
        weight2: Multiplier applied to the ``model_protT5`` scores.

    Returns:
        Tensor of indices of the highest combined score along dimension 1.
    """
    with torch.no_grad():
        cat_scores = model_cat(X_concat)
        prot_scores = model_protT5(X_protT5)
        combined = weight1 * cat_scores + weight2 * prot_scores
        # Only the arg-max positions are needed; the max values are discarded.
        best = torch.max(combined, 1)
    return best.indices
44
+
45
+
46
+
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ gradio
3
+ transformers
4
+ esm
5
+ fair-esm # if esm isn't installed via pip
6
+ sentencepiece
7
+ h5py
8
+ git+https://github.com/facebookresearch/esm.git
9
+ git+https://github.com/agemagician/ProtTrans.git