|
---
language:
- he
pipeline_tag: text-generation
---
|
|
|
### Description

Experiments with an encoder-decoder model, where the encoder is [alephbert-base](https://huggingface.co/onlplab/alephbert-base) and the [decoder is a pruned mT5-base model](https://huggingface.co/imvladikon/het5-base).

It could be useful for generating negative and hard-negative samples for pair-text classification (a sketch is shown after the usage example below).

(For paraphrasing, classical approaches are likely to give better results than this model.)
|
|
|
|
|
### Usage

```bash
git clone https://huggingface.co/imvladikon/alephbert-encoder-t5-decoder
```
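After cloning, the encoder and decoder checkpoints are expected under the `encoder/` and `decoder/` subdirectories of the repository; the snippet below loads both from those local paths.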
|
|
|
```python
import torch
from transformers import AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.modeling_outputs import BaseModelOutput

# Encoder: alephbert-base checkpoint from the cloned repository (requires a CUDA device)
enc_checkpoint = "./alephbert-encoder-t5-decoder/encoder"
enc_tokenizer = AutoTokenizer.from_pretrained(enc_checkpoint)
encoder = AutoModel.from_pretrained(enc_checkpoint).cuda()
encoder.eval()

# Decoder: pruned mT5-base checkpoint from the cloned repository
dec_checkpoint = "./alephbert-encoder-t5-decoder/decoder"
dec_tokenizer = AutoTokenizer.from_pretrained(dec_checkpoint)
decoder = AutoModelForSeq2SeqLM.from_pretrained(dec_checkpoint).cuda()
decoder.eval()


def encode(texts):
    # Tokenize a batch of texts and return normalized pooled sentence embeddings
    encoded_input = enc_tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors='pt')
    with torch.no_grad():
        model_output = encoder(**encoded_input.to(encoder.device))
    embeddings = model_output.pooler_output
    embeddings = torch.nn.functional.normalize(embeddings)
    return embeddings


def decode(embeddings, max_length=256, repetition_penalty=3.0, **kwargs):
    # Treat each pooled embedding as a length-1 encoder hidden-state sequence and generate text from it
    out = decoder.generate(
        encoder_outputs=BaseModelOutput(last_hidden_state=embeddings.unsqueeze(1)),
        max_length=max_length,
        repetition_penalty=repetition_penalty,
        **kwargs,
    )
    return [dec_tokenizer.decode(tokens, skip_special_tokens=True) for tokens in out]


# Example input: a Hebrew weather-forecast sentence
text = """
מחר יוסיף להיות מעונן חלקית ובמהלך היום יתחזקו הרוחות בדרום הארץ וייתכן אובך באזור.
""".strip()

batch = [text]
embeddings = encode(batch)
predictions = decode(embeddings, max_length=512)

for t, p in zip(batch, predictions):
    print(t)
    print(p)
    print('-----------')
```
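As mentioned in the description, one possible use is generating hard negatives for pair-text classification: the reconstruction is close to the source sentence but usually imperfect, so the pair can be labeled as a non-paraphrase. Below is a minimal sketch of this idea, not taken from the model card; it assumes the `encode`/`decode` helpers defined above and a hypothetical list of Hebrew `sentences`.

```python
# Sketch (assumption, not from the original card): pair each source sentence
# with its generated reconstruction and label the pair as a hard negative.
sentences = [
    "מחר יוסיף להיות מעונן חלקית ובמהלך היום יתחזקו הרוחות בדרום הארץ.",  # hypothetical example
]

embeddings = encode(sentences)
hard_negatives = []
for source, generated in zip(sentences, decode(embeddings, max_length=256)):
    # Skip degenerate cases where the model returns an empty string or the input verbatim
    if generated.strip() and generated.strip() != source.strip():
        hard_negatives.append({"text_a": source, "text_b": generated, "label": 0})

print(hard_negatives)
```

The filtering step is only a heuristic; in practice the generated pairs should be inspected before being used as training data.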
|
|