Upload modeling_esm_plusplus.py with huggingface_hub
modeling_esm_plusplus.py CHANGED (+7 -3)
@@ -629,6 +629,7 @@ class EmbeddingMixin:
         tokenizer: PreTrainedTokenizerBase,
         batch_size: int = 2,
         max_len: int = 512,
+        truncate: bool = True,
         full_embeddings: bool = False,
         embed_dtype: torch.dtype = torch.float32,
         pooling_types: List[str] = ['mean'],
@@ -680,8 +681,9 @@ class EmbeddingMixin:
         )
         >>> # embedding_dict is a dictionary mapping sequences to their embeddings as tensors for .pth or numpy arrays for sql
         """
-        sequences = list(set([seq[:max_len] for seq in sequences]))
+        sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences]))
         sequences = sorted(sequences, key=len, reverse=True)
+        hidden_size = self.config.hidden_size
         collate_fn = build_collator(tokenizer)
         device = self.device
         pooler = Pooler(pooling_types) if not full_embeddings else None
@@ -712,7 +714,7 @@ class EmbeddingMixin:
                     embeddings = get_embeddings(residue_embeddings, attention_mask).cpu()
                     for seq, emb, mask in zip(seqs, embeddings, attention_mask):
                         if full_embeddings:
-                            emb = emb[mask.bool()]
+                            emb = emb[mask.bool()].reshape(-1, hidden_size)
                         c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
                                   (seq, emb.cpu().numpy().tobytes()))

@@ -742,7 +744,9 @@ class EmbeddingMixin:
                 input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
                 residue_embeddings = self._embed(input_ids, attention_mask)
                 embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype).cpu()
-                for seq, emb in zip(seqs, embeddings):
+                for seq, emb, mask in zip(seqs, embeddings, attention_mask):
+                    if full_embeddings:
+                        emb = emb[mask.bool()].reshape(-1, hidden_size)
                     embeddings_dict[seq] = emb

                 if save:
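Taken together, the change makes full_embeddings=True keep only the real (unpadded) positions with an explicit (num_residues, hidden_size) shape, in both the SQLite path and the in-memory dictionary path. A toy-tensor sketch of the new masking step; the hidden_size value and random tensor are illustrative only, the real value comes from self.config.hidden_size as added in the diff:

import torch

hidden_size = 4  # illustrative; the method reads this from self.config.hidden_size
# One padded item from a batch: 6 token positions, the last 2 are padding.
emb = torch.randn(6, hidden_size)
attention_mask = torch.tensor([1, 1, 1, 1, 0, 0])

# Same operation as in the diff: drop padded positions and make the
# (num_residues, hidden_size) layout explicit before storing/returning.
trimmed = emb[attention_mask.bool()].reshape(-1, hidden_size)
print(trimmed.shape)  # torch.Size([4, 4])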
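For the SQLite path, tobytes() stores a flat buffer with no shape or dtype metadata, which is why a known hidden size matters when reading rows back. A minimal read-back sketch, assuming full_embeddings=True and float32 storage, with placeholder values for the database path and hidden size; only the embeddings table and its two-column (sequence, blob) layout are taken from the INSERT statement in the diff:

import sqlite3
import numpy as np

DB_PATH = "embeddings.db"  # placeholder path, not specified in the diff
HIDDEN_SIZE = 960          # placeholder; use model.config.hidden_size in practice
DTYPE = np.float32         # assumes the stored embeddings were float32

with sqlite3.connect(DB_PATH) as conn:
    for seq, blob in conn.execute("SELECT * FROM embeddings"):
        # Each row holds one flattened (num_residues, hidden_size) matrix.
        emb = np.frombuffer(blob, dtype=DTYPE).reshape(-1, HIDDEN_SIZE)
        print(seq[:20], emb.shape)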