import asyncio
import logging
from io import BytesIO

import httpx
import torch
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image
from torchvision import transforms

from schemas.requests import BatchPredictRequest
from models.loader import get_model
from steps.preprocess import process_image
from config.settings import IMAGE_SIZE, NUM_THREADS
from api.dependencies import verify_api_key

logger = logging.getLogger(__name__)

router = APIRouter()

# ToTensor holds no per-call state, so a single instance is shared by all images.
_TO_TENSOR = transforms.ToTensor()


async def _download_image(client: httpx.AsyncClient, image_url: str) -> bytes:
    """Fetch *image_url* with the shared *client* and return the raw body.

    Raises:
        httpx.HTTPStatusError: if the server answers with an error status.
        Any other exception httpx raises for transport or URL problems is
        propagated unchanged.
    """
    # SECURITY(review): the URL comes straight from the request body and is
    # fetched as-is; consider an allow-list / private-network block to avoid SSRF.
    response = await client.get(image_url)
    response.raise_for_status()
    return response.content


def _decode_image(data: bytes) -> Image.Image:
    """Decode *data* into a fully loaded PIL image."""
    image = Image.open(BytesIO(data))
    # Image.open only parses the header; force the full decode here so a
    # truncated/corrupt file is reported as a per-image error rather than
    # failing later in preprocessing and aborting the whole batch.
    image.load()
    return image


def _predict_confidence(model, image: Image.Image) -> float:
    """Preprocess *image*, run *model* on it and return P(class 1), rounded to 2 decimals.

    Intended to run in a worker thread. ``torch.no_grad`` is entered here
    because grad mode is thread-local and would not carry over from the
    event-loop thread.
    """
    processed_image = process_image(image, size=IMAGE_SIZE)
    # Add a leading batch dimension of size 1.
    image_tensor = _TO_TENSOR(processed_image).unsqueeze(0)
    # NOTE(review): assumes the model is in eval mode and safe to call
    # concurrently from several threads (up to NUM_THREADS) — confirm.
    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
    # NOTE(review): assumes a classifier with >= 2 classes where index 1 is
    # the class whose confidence is reported — confirm against the model.
    return round(probabilities[0, 1].item(), 2)


async def _process_single_image(
    image_url: str,
    *,
    client: httpx.AsyncClient,
    model,
    semaphore: asyncio.Semaphore,
) -> dict:
    """Download, decode and score one image.

    Download/decode failures are logged and returned as an error entry so
    one bad URL does not fail the batch. Preprocessing/inference errors
    propagate to the caller.
    """
    async with semaphore:
        loop = asyncio.get_running_loop()
        try:
            data = await _download_image(client, image_url)
            # Decoding is CPU work; keep it off the event loop.
            image = await loop.run_in_executor(None, _decode_image, data)
        except Exception as exc:
            logger.error("Error downloading image from %s: %s", image_url, exc)
            return {"imageUrl": image_url, "error": "Invalid image URL"}
        # Inference is CPU-bound; running it in the executor keeps the
        # server responsive while the model works.
        confidence = await loop.run_in_executor(
            None, _predict_confidence, model, image
        )
    return {"imageUrl": image_url, "confidence": confidence}


@router.post("/batch_predict", dependencies=[Depends(verify_api_key)])
async def batch_predict(request: BatchPredictRequest):
    """Predict a confidence score for each image URL in the request.

    \f
    Args:
        request: Batch request carrying ``modelName`` and ``imageUrls``.

    Returns:
        JSONResponse ``{"results": [...]}`` with one entry per URL, in request
        order: ``{"imageUrl", "confidence"}`` on success, or
        ``{"imageUrl", "error": "Invalid image URL"}`` when the image could
        not be downloaded or decoded.

    Raises:
        HTTPException: 404 if ``get_model`` raises ``KeyError``; 500 for any
            other failure (logged with traceback).
    """
    model_name = request.modelName
    try:
        try:
            model = get_model(model_name)
        except KeyError as exc:
            # Only a lookup miss in get_model means "model not found"; a
            # KeyError raised later in the pipeline is a server error.
            raise HTTPException(status_code=404, detail="Model not found") from exc

        # At most NUM_THREADS images are downloaded/scored at the same time.
        semaphore = asyncio.Semaphore(NUM_THREADS)

        # One client (and connection pool) for the whole batch.
        async with httpx.AsyncClient() as client:
            # return_exceptions=True makes gather wait for every task, so no
            # task is still using the client after it is closed below.
            outcomes = await asyncio.gather(
                *(
                    _process_single_image(
                        url, client=client, model=model, semaphore=semaphore
                    )
                    for url in request.imageUrls
                ),
                return_exceptions=True,
            )

        for outcome in outcomes:
            if isinstance(outcome, BaseException):
                raise outcome

        return JSONResponse(content={"results": outcomes})
    except HTTPException:
        raise
    except Exception as exc:
        logger.error("Batch prediction error: %s", exc, exc_info=True)
        raise HTTPException(
            status_code=500, detail="Batch prediction failed"
        ) from exc