import torch
from fastapi import APIRouter, HTTPException, Depends
from fastapi.responses import JSONResponse
from PIL import Image
from io import BytesIO
import logging
import httpx
import asyncio
from torchvision import transforms
from schemas.requests import BatchPredictRequest
from models.loader import get_model
from steps.preprocess import process_image
from config.settings import IMAGE_SIZE, NUM_THREADS
from api.dependencies import verify_api_key
logger = logging.getLogger(__name__)
router = APIRouter()
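
# Request/response shape (inferred from the handler below; the exact schema is
# defined in schemas.requests.BatchPredictRequest):
#   POST /batch_predict
#   body:     {"modelName": "<model>", "imageUrls": ["https://...", ...]}
#   response: {"results": [{"imageUrl": ..., "confidence": ...}
#                          or {"imageUrl": ..., "error": ...}]}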
@router.post("/batch_predict", dependencies=[Depends(verify_api_key)])
async def batch_predict(request: BatchPredictRequest):
    """Run predictions on a batch of image URLs."""
    model_name = request.modelName
    try:
        # Look up the model by name (raises KeyError if it is not registered)
        model = get_model(model_name)
        # Cap concurrent downloads/inferences at NUM_THREADS
        semaphore = asyncio.Semaphore(NUM_THREADS)

        async def process_single_image(image_url):
            async with semaphore:
                try:
                    async with httpx.AsyncClient() as client:
                        response = await client.get(image_url)
                        # Surface HTTP errors rather than decoding an error page
                        response.raise_for_status()
                    image = Image.open(BytesIO(response.content))
                except Exception:
                    logger.error(f"Error downloading image from {image_url}")
                    return {"imageUrl": image_url, "error": "Invalid image URL"}
                # Preprocess the image (resize to IMAGE_SIZE, etc.)
                processed_image = process_image(image, size=IMAGE_SIZE)
                # Convert to a tensor and add a batch dimension
                image_tensor = transforms.ToTensor()(processed_image).unsqueeze(0)
                # Inference with gradient tracking disabled
                with torch.no_grad():
                    outputs = model(image_tensor)
                    probabilities = torch.nn.functional.softmax(outputs, dim=1)
                    predicted_probabilities = probabilities.numpy().tolist()
                    # [0][1]: single batch item, probability of class 1
                    # (assumes a binary classifier)
                    confidence = round(predicted_probabilities[0][1], 2)
                return {"imageUrl": image_url, "confidence": confidence}

        # Run all downloads/predictions concurrently; gather preserves input
        # order, so results[i] corresponds to request.imageUrls[i]
        tasks = [process_single_image(url) for url in request.imageUrls]
        results = await asyncio.gather(*tasks)
        return JSONResponse(content={"results": results})
    except KeyError:
        raise HTTPException(status_code=404, detail="Model not found")
    except Exception as e:
        logger.error(f"Batch prediction error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Batch prediction failed")
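
# Example request (a sketch; the host/port and the auth header name expected by
# verify_api_key are assumptions, not taken from this file):
#   curl -X POST http://localhost:8000/batch_predict \
#     -H "Content-Type: application/json" \
#     -H "X-API-Key: <key>" \
#     -d '{"modelName": "my_model", "imageUrls": ["https://example.com/cat.jpg"]}'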