|
from transformers import pipeline |
|
import torch |
|
from typing import Dict, List, Any |
|
|
|
class EndpointHandler:
    """Custom inference handler wrapping a ``text-classification`` pipeline.

    Both the model and the tokenizer are loaded from ``path``. Every request is
    answered with scores for all labels, because ``top_k=None`` is always
    forced on the pipeline call.
    """

    def __init__(self, path: str = ""):
        """Build the text-classification pipeline.

        Args:
            path (:obj:`str`): Passed as both ``model`` and ``tokenizer`` to
                ``transformers.pipeline``.
        """
        # transformers device convention: 0 selects the first CUDA device,
        # -1 runs on CPU.
        device = 0 if torch.cuda.is_available() else -1

        self.pipeline = pipeline(
            "text-classification",
            model=path,
            tokenizer=path,
            device=device,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Any]:
        """Classify the text(s) contained in a request payload.

        Args:
            data (:obj:`Dict[str, Any]`): The request payload. It is only read,
                never modified. Recognised keys:

                - inputs (:obj:`str` or `List[str]`): Text(s) to classify.
                  Defaults to ``[]`` when absent.
                - parameters (:obj:`Dict[str, Any]`, optional): Extra keyword
                  arguments forwarded to the pipeline call. Any ``top_k``
                  entry is overridden with ``None``.

        Return:
            The raw pipeline output. Per the transformers
            ``TextClassificationPipeline`` documentation, when ``top_k`` is
            passed you get:

            - for a single string, a list of ``{"label", "score"}`` dicts (one
              per label);
            - for a list of strings, one such list per input.

            Confirm this against the installed transformers version.
        """
        # .get rather than .pop: the handler must not consume the caller's payload.
        inputs = data.get("inputs", [])
        parameters = data.get("parameters") or {}

        # top_k=None asks the pipeline for every label's score. It is placed
        # last so it deliberately overrides any caller-supplied top_k.
        return self.pipeline(inputs, **{**parameters, "top_k": None})
|
|