lhallee committed · verified
Commit 32d5094 · 1 parent: 688ced4

Upload modeling_esm_plusplus.py with huggingface_hub

Files changed (1)
  1. modeling_esm_plusplus.py +7 -3
modeling_esm_plusplus.py CHANGED
@@ -629,6 +629,7 @@ class EmbeddingMixin:
         tokenizer: PreTrainedTokenizerBase,
         batch_size: int = 2,
         max_len: int = 512,
+        truncate: bool = True,
         full_embeddings: bool = False,
         embed_dtype: torch.dtype = torch.float32,
         pooling_types: List[str] = ['mean'],
@@ -680,8 +681,9 @@
         )
         >>> # embedding_dict is a dictionary mapping sequences to their embeddings as tensors for .pth or numpy arrays for sql
         """
-        sequences = list(set([seq[:max_len] for seq in sequences]))
+        sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences]))
         sequences = sorted(sequences, key=len, reverse=True)
+        hidden_size = self.config.hidden_size
         collate_fn = build_collator(tokenizer)
         device = self.device
         pooler = Pooler(pooling_types) if not full_embeddings else None
@@ -712,7 +714,7 @@
                     embeddings = get_embeddings(residue_embeddings, attention_mask).cpu()
                     for seq, emb, mask in zip(seqs, embeddings, attention_mask):
                         if full_embeddings:
-                            emb = emb[mask.bool()]
+                            emb = emb[mask.bool()].reshape(-1, hidden_size)
                         c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
                                   (seq, emb.cpu().numpy().tobytes()))

@@ -742,7 +744,9 @@
                 input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
                 residue_embeddings = self._embed(input_ids, attention_mask)
                 embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype).cpu()
-                for seq, emb in zip(seqs, embeddings):
+                for seq, emb, mask in zip(seqs, embeddings, attention_mask):
+                    if full_embeddings:
+                        emb = emb[mask.bool()].reshape(-1, hidden_size)
                     embeddings_dict[seq] = emb

             if save:
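
In practice, the new truncate flag controls whether sequences longer than max_len are clipped before embedding, and full_embeddings=True now returns mask-trimmed per-residue matrices in both the in-memory and SQLite paths. A minimal usage sketch follows; the checkpoint id, the embed_dataset method name, the sequences keyword, and model.tokenizer are assumptions not visible in this diff, while the remaining parameters come from the hunks above.

import torch
from transformers import AutoModel

# Hedged sketch: checkpoint id, method name, and the sequences/tokenizer
# arguments are assumptions; truncate, max_len, full_embeddings,
# embed_dtype, and pooling_types are the parameters shown in the diff.
model = AutoModel.from_pretrained("Synthyra/ESMplusplus_small", trust_remote_code=True)
sequences = ["MKTAYIAKQR", "MSEEDELKAVEE"]

embedding_dict = model.embed_dataset(
    sequences=sequences,
    tokenizer=model.tokenizer,
    batch_size=2,
    max_len=512,
    truncate=True,            # new flag: clip each sequence to max_len before embedding
    full_embeddings=True,     # per-residue embeddings, trimmed by the attention mask
    embed_dtype=torch.float32,
    pooling_types=['mean'],   # skipped when full_embeddings=True (pooler is None)
)

for seq, emb in embedding_dict.items():
    print(seq, emb.shape)     # roughly [num_valid_tokens, hidden_size] per sequence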
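
Because the sql branch stores each embedding with .tobytes(), the blob carries no shape or dtype metadata; reading it back requires np.frombuffer plus a reshape with the model's hidden size, mirroring the .reshape(-1, hidden_size) now applied on the write side. A hedged readback sketch, assuming the model ran in float32, a local embeddings.db file, and a hidden size of 960; none of those values are fixed by the diff.

import sqlite3
import numpy as np

# Hedged sketch: db path, dtype, and hidden_size are illustration-only
# assumptions; only the two-column `embeddings` table and the tobytes()
# storage are confirmed by the diff above.
hidden_size = 960             # use model.config.hidden_size in practice
conn = sqlite3.connect("embeddings.db")

restored = {}
for seq, blob in conn.execute("SELECT * FROM embeddings"):
    arr = np.frombuffer(blob, dtype=np.float32)
    # Full embeddings restore to [num_valid_tokens, hidden_size].
    restored[seq] = arr.reshape(-1, hidden_size)
conn.close()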