from .exec_backends.trt_loader import TrtModel, encode as encode_trt
from transformers import AutoTokenizer
import math

# English sentence-embedding model (paraphrase-mpnet-base-v2), compiled to a TensorRT engine.
tokenizer_en = AutoTokenizer.from_pretrained("tensorRT/models/paraphrase-mpnet-base-v2")
model_en = TrtModel("tensorRT/models/paraphrase-mpnet-base-v2.engine")

# Multilingual model (paraphrase-multilingual-MiniLM-L12-v2), used for every other language.
tokenizer_multilingual = AutoTokenizer.from_pretrained("tensorRT/models/paraphrase-multilingual-MiniLM-L12-v2")
model_multilingual = TrtModel("tensorRT/models/paraphrase-multilingual-MiniLM-L12-v2.engine")

def encode(sentences, lang, batch_size=8):
    """Encode a list of sentences into embeddings, batching the TensorRT calls.

    The English engine is used when lang == 'en'; any other language code is
    routed to the multilingual engine. The batch size is capped at 8.
    """
    batch_size = min(batch_size, 8)
    all_embs = []

    num_batches = math.ceil(len(sentences) / batch_size)
    for j in range(num_batches):
        lst_sen = sentences[j * batch_size: (j + 1) * batch_size]

        if lang == 'en':
            embs = encode_trt(lst_sen, tokenizer=tokenizer_en,
                              trt_model=model_en, use_token_type_ids=False)
        else:
            embs = encode_trt(lst_sen, tokenizer=tokenizer_multilingual,
                              trt_model=model_multilingual, use_token_type_ids=False)
        all_embs.extend(embs)
    return all_embs
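

# Minimal usage sketch (illustrative only): it assumes the model directories and
# TensorRT engine files referenced above exist, and the example sentences and the
# 'vi' language code are hypothetical inputs, not part of the original module.
#
#     en_vectors = encode(["The cat sat on the mat.", "A quick brown fox."], lang='en')
#     vi_vectors = encode(["Xin chào thế giới."], lang='vi')
#
# Each call returns a list with one embedding per input sentence, in input order.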