# inference-api-g1/api/prediction.py
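"""Batch image prediction endpoint.

Downloads images from the provided URLs concurrently, preprocesses them,
and runs model inference, returning one confidence score per image.
"""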
import asyncio
import logging
from io import BytesIO

import httpx
import torch
from fastapi import APIRouter, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image
from torchvision import transforms

from config.settings import IMAGE_SIZE, NUM_THREADS
from models.loader import get_model
from schemas.requests import BatchPredictRequest
from steps.preprocess import process_image

logger = logging.getLogger(__name__)
router = APIRouter()

@router.post("/batch_predict")
async def batch_predict(request: BatchPredictRequest):
    """Predict from multiple images in a single call."""
    model_name = request.modelName
    try:
        # Fetch the model
        model = get_model(model_name)

        # Limit concurrency to NUM_THREADS simultaneous tasks
        semaphore = asyncio.Semaphore(NUM_THREADS)

        async def process_single_image(image_url):
            async with semaphore:
                try:
                    async with httpx.AsyncClient() as client:
                        response = await client.get(image_url)
                        response.raise_for_status()
                    image = Image.open(BytesIO(response.content))
                except Exception:
                    logger.error(f"Error downloading or decoding image from {image_url}")
                    return {"imageUrl": image_url, "error": "Invalid image URL"}

                # Preprocess the image
                processed_image = process_image(image, size=IMAGE_SIZE)

                # Convert to a tensor and add a batch dimension
                image_tensor = transforms.ToTensor()(processed_image).unsqueeze(0)

                # Inference (runs synchronously on the event loop)
                with torch.no_grad():
                    outputs = model(image_tensor)
                    probabilities = torch.nn.functional.softmax(outputs, dim=1)

                # Assumes a binary classifier; index 1 is the positive class
                predicted_probabilities = probabilities.numpy().tolist()
                confidence = round(predicted_probabilities[0][1], 2)

                return {"imageUrl": image_url, "confidence": confidence}

        # Run the per-image tasks concurrently
        tasks = [process_single_image(url) for url in request.imageUrls]
        results = await asyncio.gather(*tasks)

        return JSONResponse(content={"results": results})
    except KeyError:
        raise HTTPException(status_code=404, detail="Model not found")
    except Exception as e:
        logger.error(f"Batch prediction error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Batch prediction failed")
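
# Usage sketch (illustrative, not part of this module): a minimal client
# call against this endpoint, assuming the router is mounted at the app
# root on localhost:8000. The host, port, and model name are assumptions
# for the example only; the payload fields match BatchPredictRequest.
#
#   import httpx
#
#   payload = {
#       "modelName": "my-model",  # hypothetical model known to get_model
#       "imageUrls": ["https://example.com/sample.jpg"],
#   }
#   r = httpx.post("http://localhost:8000/batch_predict", json=payload)
#   print(r.json())  # -> {"results": [{"imageUrl": "...", "confidence": 0.87}]}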