import asyncio
import logging
from io import BytesIO

import httpx
import torch
from fastapi import APIRouter, HTTPException, Depends
from fastapi.responses import JSONResponse
from PIL import Image
from torchvision import transforms

from schemas.requests import BatchPredictRequest
from models.loader import get_model
from steps.preprocess import process_image
from config.settings import IMAGE_SIZE, NUM_THREADS
from api.dependencies import verify_api_key

logger = logging.getLogger(__name__)
router = APIRouter()
# Route path and auth dependency are assumptions inferred from the otherwise
# unused `router`, `Depends`, and `verify_api_key` imports above.
@router.post("/batch-predict", dependencies=[Depends(verify_api_key)])
async def batch_predict(request: BatchPredictRequest):
    """Predict from multiple images in a single request."""
    model_name = request.modelName
    try:
        # Fetch the requested model (raises KeyError if unknown)
        model = get_model(model_name)

        # Limit concurrency to NUM_THREADS simultaneous tasks
        semaphore = asyncio.Semaphore(NUM_THREADS)
        async def process_single_image(image_url):
            async with semaphore:
                try:
                    async with httpx.AsyncClient() as client:
                        response = await client.get(image_url)
                        response.raise_for_status()
                    image = Image.open(BytesIO(response.content))
                except Exception:
                    logger.error(f"Error downloading image from {image_url}")
                    return {"imageUrl": image_url, "error": "Invalid image URL"}
                # Preprocess the image
                processed_image = process_image(image, size=IMAGE_SIZE)

                # Convert to a batched tensor of shape (1, C, H, W)
                image_tensor = transforms.ToTensor()(processed_image).unsqueeze(0)

                # Inference; note this blocks the event loop, so heavy models
                # may warrant asyncio.to_thread or a dedicated worker pool
                with torch.no_grad():
                    outputs = model(image_tensor)
                    probabilities = torch.nn.functional.softmax(outputs, dim=1)
                    predicted_probabilities = probabilities.numpy().tolist()

                # Confidence of the positive class (index 1 of a binary output)
                confidence = round(predicted_probabilities[0][1], 2)

                return {"imageUrl": image_url, "confidence": confidence}
        # Run all downloads and predictions concurrently
        tasks = [process_single_image(url) for url in request.imageUrls]
        results = await asyncio.gather(*tasks)

        return JSONResponse(content={"results": results})

    except KeyError:
        raise HTTPException(status_code=404, detail="Model not found")
    except Exception as e:
        logger.error(f"Batch prediction error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Batch prediction failed")
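
# For reference, a minimal sketch of the request schema implied by the handler
# above (it reads `request.modelName` and `request.imageUrls`). The actual
# definition lives in schemas.requests and may differ; field names here are
# inferred from usage only:
#
#     from pydantic import BaseModel
#
#     class BatchPredictRequest(BaseModel):
#         modelName: str
#         imageUrls: list[str]
#
# Example call (path, port, and auth header are illustrative assumptions; the
# real header depends on what verify_api_key expects):
#
#     curl -X POST http://localhost:8000/batch-predict \
#          -H "Content-Type: application/json" \
#          -H "X-API-Key: <your-key>" \
#          -d '{"modelName": "my-model",
#               "imageUrls": ["https://example.com/cat.jpg"]}'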