Faezeh Sarlakifar committed
Commit 4745b4a · 1 Parent(s): d8f5373

Update esm embedder function

Files changed (1)
  1. app.py +15 -8
app.py CHANGED
@@ -4,6 +4,7 @@ import numpy as np
 from transformers import T5Tokenizer, T5EncoderModel
 import esm
 from inference import load_models, predict_ensemble
+from transformers import AutoTokenizer, AutoModel
 
 # Load trained models
 model_protT5, model_cat = load_models()
@@ -13,11 +14,10 @@ tokenizer_t5 = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_low
 model_t5 = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")
 model_t5 = model_t5.eval()
 
-# Load ESM model
-esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
-batch_converter = alphabet.get_batch_converter()
-esm_model.eval()
-
+# Load the ESM2 tokenizer and model
+model_name = "facebook/esm2_t33_650M_UR50D"
+tokenizer_esm = AutoTokenizer.from_pretrained(model_name)
+esm_model = AutoModel.from_pretrained(model_name)
 
 def extract_prott5_embedding(sequence):
     sequence = sequence.replace(" ", "")
@@ -28,14 +28,21 @@ def extract_prott5_embedding(sequence):
     return torch.mean(embedding, dim=1)
 
 
+# Extract ESM2 embedding
 def extract_esm_embedding(sequence):
-    batch_labels, batch_strs, batch_tokens = batch_converter([("protein1", sequence)])
+    # Tokenize the sequence
+    inputs = tokenizer_esm(sequence, return_tensors="pt", padding=True, truncation=True)
+
+    # Forward pass through the model
     with torch.no_grad():
-        results = esm_model(batch_tokens, repr_layers=[33], return_contacts=False)
-    token_representations = results["representations"][33]
+        outputs = esm_model(**inputs)
+
+    # Extract the final-layer embeddings (layer 33 for this ESM2 checkpoint)
+    token_representations = outputs.last_hidden_state
     return torch.mean(token_representations[0, 1:len(sequence)+1], dim=0).unsqueeze(0)
 
 
+
 def classify(sequence):
     protT5_emb = extract_prott5_embedding(sequence)
     esm_emb = extract_esm_embedding(sequence)
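
For reference, a minimal standalone sketch of the new transformers-based embedding path (not part of this commit; the example sequence, variable names, and the expected 1280-dimensional output for the 650M checkpoint are illustrative assumptions):

# Standalone sketch (assumptions labelled, not taken from app.py): reproduce the
# updated extract_esm_embedding logic with the Hugging Face transformers API.
import torch
from transformers import AutoTokenizer, AutoModel

model_name = "facebook/esm2_t33_650M_UR50D"
tokenizer_esm = AutoTokenizer.from_pretrained(model_name)
esm_model = AutoModel.from_pretrained(model_name)
esm_model.eval()

sequence = "MKTAYIAKQRQISFVKSHFSRQ"  # arbitrary example amino-acid sequence
inputs = tokenizer_esm(sequence, return_tensors="pt", padding=True, truncation=True)
with torch.no_grad():
    outputs = esm_model(**inputs)

# Mean-pool over residue positions, skipping the leading special token,
# exactly as the updated function does.
token_representations = outputs.last_hidden_state
embedding = torch.mean(token_representations[0, 1:len(sequence) + 1], dim=0).unsqueeze(0)
print(embedding.shape)  # expected torch.Size([1, 1280]) for the 650M checkpoint (assumption)

Compared with the previous fair-esm batch_converter path, this keeps the same mean pooling over residue positions: the ESM tokenizer in transformers prepends a cls token and appends an eos token, so the [1:len(sequence)+1] slice still covers only the residues.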