Faezeh Sarlakifar committed on
Commit
1e22daf
·
1 Parent(s): 42f8b52

Update model code for Hugging Face ZeroGPU compatibility

Browse files
Files changed (3) hide show
  1. app.py +12 -2
  2. inference.py +6 -2
  3. requirements.txt +2 -0
app.py CHANGED
@@ -5,6 +5,7 @@ from transformers import T5Tokenizer, T5EncoderModel
5
  import esm
6
  from inference import load_models, predict_ensemble
7
  from transformers import AutoTokenizer, AutoModel
 
8
 
9
  # Load trained models
10
  model_protT5, model_cat = load_models()
@@ -12,12 +13,12 @@ model_protT5, model_cat = load_models()
12
  # Load ProtT5 model
13
  tokenizer_t5 = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_lower_case=False)
14
  model_t5 = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")
15
- model_t5 = model_t5.eval()
16
 
17
  # Load the tokenizer and model
18
  model_name = "facebook/esm2_t33_650M_UR50D"
19
  tokenizer_esm = AutoTokenizer.from_pretrained(model_name)
20
- esm_model = AutoModel.from_pretrained(model_name)
21
 
22
  def extract_prott5_embedding(sequence):
23
  sequence = sequence.replace(" ", "")
@@ -42,6 +43,15 @@ def extract_esm_embedding(sequence):
42
  return torch.mean(token_representations[0, 1:len(sequence)+1], dim=0).unsqueeze(0)
43
 
44
 
 
 
 
 
 
 
 
 
 
45
  def classify(sequence):
46
  protT5_emb = extract_prott5_embedding(sequence)
47
  esm_emb = extract_esm_embedding(sequence)
 
5
  import esm
6
  from inference import load_models, predict_ensemble
7
  from transformers import AutoTokenizer, AutoModel
8
+ import spaces
9
 
10
  # Load trained models
11
  model_protT5, model_cat = load_models()
 
13
  # Load ProtT5 model
14
  tokenizer_t5 = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_lower_case=False)
15
  model_t5 = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")
16
+ model_t5 = model_t5.eval().to("cuda")
17
 
18
  # Load the tokenizer and model
19
  model_name = "facebook/esm2_t33_650M_UR50D"
20
  tokenizer_esm = AutoTokenizer.from_pretrained(model_name)
21
+ esm_model = AutoModel.from_pretrained(model_name).to("cuda")
22
 
23
  def extract_prott5_embedding(sequence):
24
  sequence = sequence.replace(" ", "")
 
43
  return torch.mean(token_representations[0, 1:len(sequence)+1], dim=0).unsqueeze(0)
44
 
45
 
46
+ # def classify(sequence):
47
+ # protT5_emb = extract_prott5_embedding(sequence)
48
+ # esm_emb = extract_esm_embedding(sequence)
49
+ # concat = torch.cat((esm_emb, protT5_emb), dim=1)
50
+ # pred = predict_ensemble(protT5_emb, concat, model_protT5, model_cat)
51
+ # return "Potential Allergen" if pred.item() == 1 else "Non-Allergen"
52
+
53
+
54
+ @spaces.GPU(duration=120)
55
  def classify(sequence):
56
  protT5_emb = extract_prott5_embedding(sequence)
57
  esm_emb = extract_esm_embedding(sequence)
inference.py CHANGED
@@ -25,16 +25,20 @@ class NeuralNet(nn.Module):
25
  def load_models():
26
  model_protT5 = NeuralNet(1024, 200, 100, 50, 2)
27
  model_protT5.load_state_dict(torch.load("checkpoints/model17-protT5.pt", map_location=torch.device("cpu")))
28
- model_protT5.eval()
29
 
30
  model_cat = NeuralNet(2304, 200, 100, 100, 2)
31
  model_cat.load_state_dict(torch.load("checkpoints/model-esm-protT5-5.pt", map_location=torch.device("cpu")))
32
- model_cat.eval()
33
 
34
  return model_protT5, model_cat
35
 
36
 
37
  def predict_ensemble(X_protT5, X_concat, model_protT5, model_cat, weight1=0.60, weight2=0.30):
 
 
 
 
38
  with torch.no_grad():
39
  outputs1 = model_cat(X_concat)
40
  outputs2 = model_protT5(X_protT5)
 
def load_models():
    """Load the two trained classifier heads used by the ensemble.

    Returns:
        tuple: ``(model_protT5, model_cat)`` — both in eval mode, moved to
        CUDA when a GPU is available, otherwise left on CPU.
    """
    # Guard the device move: a bare .to("cuda") raises RuntimeError on
    # CPU-only hosts (e.g. local dev or a Space queued before GPU attach).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Head fed by ProtT5 embeddings (1024-dim input).
    model_protT5 = NeuralNet(1024, 200, 100, 50, 2)
    model_protT5.load_state_dict(
        torch.load("checkpoints/model17-protT5.pt", map_location=torch.device("cpu"))
    )
    # nn.Module.eval() and .to() both operate in place and return self.
    model_protT5.eval().to(device)

    # Head fed by concatenated ESM2 + ProtT5 embeddings (2304-dim input).
    model_cat = NeuralNet(2304, 200, 100, 100, 2)
    model_cat.load_state_dict(
        torch.load("checkpoints/model-esm-protT5-5.pt", map_location=torch.device("cpu"))
    )
    model_cat.eval().to(device)

    return model_protT5, model_cat
35
 
36
 
37
  def predict_ensemble(X_protT5, X_concat, model_protT5, model_cat, weight1=0.60, weight2=0.30):
38
+ device = next(model_protT5.parameters()).device
39
+ X_protT5 = X_protT5.to(device)
40
+ X_concat = X_concat.to(device)
41
+
42
  with torch.no_grad():
43
  outputs1 = model_cat(X_concat)
44
  outputs2 = model_protT5(X_protT5)
requirements.txt CHANGED
@@ -5,4 +5,6 @@ esm
5
  fair-esm # if esm isn't installed via pip
6
  sentencepiece
7
  h5py
 
8
  git+https://github.com/facebookresearch/esm.git
 
 
5
  fair-esm # if esm isn't installed via pip
6
  sentencepiece
7
  h5py
8
+ spaces
9
  git+https://github.com/facebookresearch/esm.git
10
+ git+https://github.com/agemagician/ProtTrans.git