import os
import time

from fastapi import APIRouter, HTTPException

from scripts.data_model import ClassificationInput, ClassificationOutput
from utils.pipeline import load_model

router = APIRouter()

# Resolve the model directory relative to the project root.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
MODEL_PATH = os.path.join(BASE_DIR, "ml-models", "tinybert-disaster-tweet/")


def disaster_classification(input: ClassificationInput) -> ClassificationOutput:
    """
    Classify whether a given text describes a disaster using a pre-trained model.

    Args:
        input (ClassificationInput): The input data containing the user_id and text.

    Returns:
        ClassificationOutput: The output data containing the user_id, text,
        model_name, sentiment, score, and prediction_time.
    """
    try:
        pipe = load_model(MODEL_PATH)

        # Time the inference call and report it in milliseconds.
        start = time.time()
        output = pipe(input.text)
        end = time.time()
        prediction_time = int((end - start) * 1000)

        labels = [x["label"] for x in output]
        scores = [x["score"] for x in output]

        return ClassificationOutput(
            user_id=input.user_id,
            text=input.text,
            model_name="tinybert-disaster-tweet",
            sentiment=labels,
            score=scores,
            prediction_time=prediction_time,
        )
    except Exception as e:
        # Raise an HTTP 500 rather than returning a bare (dict, status) tuple,
        # which FastAPI would not interpret as an error response.
        raise HTTPException(
            status_code=500,
            detail=f"Failed to process text classification: {str(e)}",
        )
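The router is created above but the handler is not registered on it in this snippet. Below is a minimal sketch of how the pieces could be wired together in a separate application module. The module name classification, the route path /classify, the /api/v1 prefix, and the uvicorn command are assumptions for illustration only; it also assumes load_model returns a Hugging Face style text-classification pipeline whose output is a list of {label, score} dicts, which is what the list comprehensions above expect.

# Hypothetical main.py: register the handler on the router and mount the
# router on a FastAPI app. Names and paths are illustrative assumptions.
from fastapi import FastAPI

# Assumes the code above lives in a module named `classification`.
from classification import router, disaster_classification
from scripts.data_model import ClassificationOutput

# Register the plain function as a POST endpoint; ClassificationInput becomes
# the request body model because it annotates the only parameter.
router.add_api_route(
    "/classify",                      # assumed path
    disaster_classification,
    methods=["POST"],
    response_model=ClassificationOutput,
)

app = FastAPI()
app.include_router(router, prefix="/api/v1")  # assumed prefix

# Run with: uvicorn main:app --reload
# Example request body: {"user_id": "u1", "text": "Forest fire near La Ronge"}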