Faezeh Sarlakifar commited on
Commit
27bd3e3
·
1 Parent(s): d41b207

Make webfiles compatible with ZeroGPU processor

Browse files
Files changed (2) hide show
  1. app.py +9 -11
  2. requirements.txt +2 -2
app.py CHANGED
@@ -13,12 +13,12 @@ model_protT5, model_cat = load_models()
13
  # Load ProtT5 model
14
  tokenizer_t5 = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_lower_case=False)
15
  model_t5 = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")
16
- model_t5 = model_t5.eval().to("cuda")
17
 
18
  # Load the tokenizer and model
19
  model_name = "facebook/esm2_t33_650M_UR50D"
20
  tokenizer_esm = AutoTokenizer.from_pretrained(model_name)
21
- esm_model = AutoModel.from_pretrained(model_name).to("cuda")
22
 
23
  def extract_prott5_embedding(sequence):
24
  sequence = sequence.replace(" ", "")
@@ -42,14 +42,12 @@ def extract_esm_embedding(sequence):
42
  token_representations = outputs.last_hidden_state # This is the default layer
43
  return torch.mean(token_representations[0, 1:len(sequence)+1], dim=0).unsqueeze(0)
44
 
45
-
46
- # def classify(sequence):
47
- # protT5_emb = extract_prott5_embedding(sequence)
48
- # esm_emb = extract_esm_embedding(sequence)
49
- # concat = torch.cat((esm_emb, protT5_emb), dim=1)
50
- # pred = predict_ensemble(protT5_emb, concat, model_protT5, model_cat)
51
- # return "Potential Allergen" if pred.item() == 1 else "Non-Allergen"
52
-
53
 
54
  @spaces.GPU(duration=120)
55
  def classify(sequence):
@@ -65,4 +63,4 @@ demo = gr.Interface(fn=classify,
65
  outputs=gr.Label(label="Prediction"))
66
 
67
  if __name__ == "__main__":
68
- demo.launch()
 
13
  # Load ProtT5 model
14
  tokenizer_t5 = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_lower_case=False)
15
  model_t5 = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")
16
+ model_t5 = model_t5.eval()
17
 
18
  # Load the tokenizer and model
19
  model_name = "facebook/esm2_t33_650M_UR50D"
20
  tokenizer_esm = AutoTokenizer.from_pretrained(model_name)
21
+ esm_model = AutoModel.from_pretrained(model_name)
22
 
23
  def extract_prott5_embedding(sequence):
24
  sequence = sequence.replace(" ", "")
 
42
  token_representations = outputs.last_hidden_state # This is the default layer
43
  return torch.mean(token_representations[0, 1:len(sequence)+1], dim=0).unsqueeze(0)
44
 
45
+ def estimate_duration(sequence):
46
+ # Estimate duration based on sequence length
47
+ base_time = 30 # Base time in seconds
48
+ time_per_residue = 0.5 # Estimated time per residue
49
+ estimated_time = base_time + len(sequence) * time_per_residue
50
+ return min(int(estimated_time), 300) # Cap at 300 seconds
 
 
51
 
52
  @spaces.GPU(duration=120)
53
  def classify(sequence):
 
63
  outputs=gr.Label(label="Prediction"))
64
 
65
  if __name__ == "__main__":
66
+ demo.launch()
requirements.txt CHANGED
@@ -5,5 +5,5 @@ esm
5
  fair-esm # if esm isn't installed via pip
6
  sentencepiece
7
  h5py
8
- git+https://github.com/facebookresearch/esm.git
9
- git+https://github.com/agemagician/ProtTrans.git
 
5
  fair-esm # if esm isn't installed via pip
6
  sentencepiece
7
  h5py
8
+ spaces
9
+ git+https://github.com/facebookresearch/esm.git