# handler.py — custom Inference Endpoints handler (commit 84ceec9)
from transformers import pipeline
import torch
from typing import Dict, List, Any
class EndpointHandler:
    """Inference Endpoints handler wrapping a text-classification pipeline.

    Loads the model and tokenizer from the repository files and, on every
    request, returns the scores for *all* labels (``top_k=None``) rather
    than only the best one.
    """

    def __init__(self, path: str = ""):
        """Build the classification pipeline from the repository files.

        Args:
            path: Directory containing the model and tokenizer files.
        """
        # Use the first GPU when available, otherwise fall back to CPU.
        device = 0 if torch.cuda.is_available() else -1
        # model=path and tokenizer=path load directly from the repo files.
        self.pipeline = pipeline(
            "text-classification",
            model=path,
            tokenizer=path,
            device=device,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run classification on the request payload.

        Args:
            data (:obj:`Dict[str, Any]`):
                The payload dictionary includes:
                - inputs (:obj:`str` or `List[str]`): Text(s) to classify.
                - parameters (:obj:`Dict[str, Any]`, optional): Additional
                  keyword arguments for the pipeline.

        Return:
            A :obj:`List[Dict[str, Any]]`: A list of dictionaries containing
            the classification results for each input.
        """
        inputs = data.pop("inputs", [])
        parameters = data.pop("parameters", None)
        # Merge any caller-supplied parameters, but always force top_k=None
        # so the pipeline returns scores for every label (overriding any
        # top_k the caller may have passed) — this collapses the previous
        # duplicated if/else branches into one call.
        final_params = {**(parameters or {}), "top_k": None}
        return self.pipeline(inputs, **final_params)