from .exec_backends.trt_loader import TrtModel, encode as encode_trt
from transformers import AutoTokenizer
import math

# English-only encoder: paraphrase-mpnet-base-v2 tokenizer plus its TensorRT engine.
tokenizer_en = AutoTokenizer.from_pretrained("tensorRT/models/paraphrase-mpnet-base-v2")
model_en = TrtModel("tensorRT/models/paraphrase-mpnet-base-v2.engine")

# Multilingual encoder: paraphrase-multilingual-MiniLM-L12-v2 tokenizer plus its TensorRT engine.
tokenizer_multilingual = AutoTokenizer.from_pretrained("tensorRT/models/paraphrase-multilingual-MiniLM-L12-v2")
model_multilingual = TrtModel("tensorRT/models/paraphrase-multilingual-MiniLM-L12-v2.engine")
def encode(sentences, lang, batch_size=8):
    """Encode a list of sentences into embeddings, batching the TensorRT calls."""
    # Cap the batch size at 8, the maximum this module is set up to feed the engines.
    if batch_size >= 8:
        batch_size = 8
    all_embs = []
    num_batches = math.ceil(len(sentences) / batch_size)
    for j in range(num_batches):
        batch = sentences[j * batch_size : (j + 1) * batch_size]
        if lang == 'en':
            # English sentences go through the mpnet engine.
            embs = encode_trt(batch, tokenizer=tokenizer_en, trt_model=model_en, use_token_type_ids=False)
        else:
            # Any other language goes through the multilingual MiniLM engine.
            embs = encode_trt(batch, tokenizer=tokenizer_multilingual, trt_model=model_multilingual, use_token_type_ids=False)
        all_embs.extend(embs)
    return all_embs
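

# Minimal usage sketch (an assumption, not part of the original module): the sample
# sentences below are illustrative, and this only runs if the tokenizer directories
# and .engine files above exist on disk. Because the module uses a relative import,
# run it as `python -m <package>.<this_module>` rather than as a bare script.
if __name__ == "__main__":
    english_sentences = [
        "TensorRT speeds up transformer inference.",
        "Sentence embeddings power semantic search.",
    ]
    other_sentences = ["Les plongements de phrases sont utiles."]

    en_embs = encode(english_sentences, lang='en')
    # Any lang value other than 'en' is routed to the multilingual engine.
    multi_embs = encode(other_sentences, lang='fr')

    # encode() returns one embedding per input sentence.
    print(len(en_embs), len(multi_embs))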