"""Text-classification API routes.

Exposes three endpoints on ``router``:

* ``POST /predict`` -- classify a single text.
* ``POST /predict_batch`` -- classify a list of texts.
* ``GET /info`` -- return whatever ``model.get_info()`` returns.
"""

import logging
from typing import Dict, List, Union

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

from models.text_classification import TextClassificationModel

logger = logging.getLogger(__name__)

router = APIRouter()

# One model instance is created when this module is imported and is shared
# by every request handled by this router.
model = TextClassificationModel()


class TextInput(BaseModel):
    """Request body for ``/predict``."""

    text: str


class BatchTextInput(BaseModel):
    """Request body for ``/predict_batch``."""

    texts: List[str]


class PredictionResponse(BaseModel):
    """A single classification result."""

    label: str
    confidence: float


class BatchPredictionResponse(BaseModel):
    """Results for ``/predict_batch``, in the same order as the input texts."""

    predictions: List[PredictionResponse]


@router.post("/predict", response_model=PredictionResponse)
async def predict(input_data: TextInput) -> Dict[str, Union[str, float]]:
    """Make a prediction for a single text.

    Args:
        input_data: Request body holding the text to classify.

    Returns:
        The value returned by ``model.predict``. FastAPI validates it against
        ``PredictionResponse``, so it must provide ``label`` and ``confidence``.

    Raises:
        HTTPException: status 500 if ``model.predict`` raises anything.
    """
    try:
        result = await model.predict(input_data.text)
    except Exception as e:
        # API boundary: keep the full traceback in server logs, return a 500.
        # NOTE(review): the exception message is sent to the client verbatim;
        # consider a generic detail if model errors can expose internals.
        logger.exception("Prediction failed")
        raise HTTPException(
            status_code=500, detail=f"Prediction failed: {str(e)}"
        ) from e
    return result


@router.post("/predict_batch", response_model=BatchPredictionResponse)
async def predict_batch(
    input_data: BatchTextInput,
) -> Dict[str, List[Dict[str, Union[str, float]]]]:
    """Make predictions for multiple texts.

    Texts are classified one at a time, in input order, and the results keep
    that order. An empty ``texts`` list yields ``{"predictions": []}``.

    Args:
        input_data: Request body holding the list of texts to classify.

    Returns:
        ``{"predictions": [...]}`` with one ``model.predict`` result per text.

    Raises:
        HTTPException: status 500 if ``model.predict`` raises for any text.
            The whole batch fails; no partial results are returned.
    """
    try:
        # Awaited sequentially on purpose: the model is a single shared
        # instance and nothing here shows it is safe to call concurrently.
        predictions = [await model.predict(text) for text in input_data.texts]
    except Exception as e:
        logger.exception("Batch prediction failed")
        raise HTTPException(
            status_code=500, detail=f"Batch prediction failed: {str(e)}"
        ) from e
    return {"predictions": predictions}


@router.get("/info")
async def get_model_info():
    """Get information about the text classification model.

    Returns ``model.get_info()`` unchanged. Its shape is defined by
    ``TextClassificationModel`` and is not validated here.
    """
    return model.get_info()