File size: 374 Bytes
e72169b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from functools import lru_cache
from typing import Optional

from transformers import pipeline, Pipeline


@lru_cache
def _load_pipeline(task: str, model: Optional[str]) -> Pipeline:
    """Build a ``transformers`` pipeline, memoised on the positional pair ``(task, model)``."""
    return pipeline(task=task, model=model)


def init_model(task: str, model: Optional[str] = None) -> Pipeline:
    """Return a cached ``transformers`` pipeline for ``task`` and ``model``.

    Args:
        task: Task name forwarded to ``transformers.pipeline``.
        model: Optional model identifier forwarded to ``transformers.pipeline``.
            ``None`` is passed through as-is (presumably letting
            ``transformers`` pick its default model for ``task``).

    Returns:
        The pipeline object. Repeated calls with the same ``task`` and
        ``model`` return the same cached instance, regardless of whether
        the arguments were given positionally, by keyword, or defaulted.
    """
    # Normalise to one positional call form before hitting the cache:
    # lru_cache keys on the exact call pattern, so init_model("t"),
    # init_model("t", None) and init_model(task="t") would otherwise each
    # build (and keep alive) a separate copy of the same model.
    return _load_pipeline(task, model)


# Preserve the cache-management API the lru_cache-decorated original exposed.
init_model.cache_clear = _load_pipeline.cache_clear
init_model.cache_info = _load_pipeline.cache_info


def custom_predict(text: str, pipe: Pipeline):
    """Run ``pipe`` on ``text`` and hand back its raw output unchanged.

    Args:
        text: Input string forwarded to the pipeline.
        pipe: A callable pipeline object, e.g. one returned by ``init_model``.

    Returns:
        Exactly what ``pipe(text)`` returns, with no post-processing.
        An example output is ``[{'label': 'POSITIVE', 'score': 0.998}]``.
        The structure presumably depends on the pipeline's task, so confirm
        it per caller.
    """
    return pipe(text)