import torch
from fastapi import APIRouter, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image
from io import BytesIO
import logging
import httpx
import asyncio
from torchvision import transforms
from schemas.requests import BatchPredictRequest
from models.loader import get_model
from steps.preprocess import process_image
from config.settings import IMAGE_SIZE, NUM_THREADS
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/batch_predict")
async def batch_predict(request: BatchPredictRequest):
"""Endpoint pour prédire à partir de plusieurs images."""
model_name = request.modelName
try:
# Récupérer le modèle
model = get_model(model_name)
semaphore = asyncio.Semaphore(NUM_THREADS) # Limiter à 8 tâches simultanées
async def process_single_image(image_url):
async with semaphore:
try:
async with httpx.AsyncClient() as client:
response = await client.get(image_url)
image = Image.open(BytesIO(response.content))
except Exception:
logger.error(f"Error downloading image from {image_url}")
return {"imageUrl": image_url, "error": "Invalid image URL"}
# Prétraiter l'image
processed_image = process_image(image, size=IMAGE_SIZE)
# Convertir en tenseur
image_tensor = transforms.ToTensor()(processed_image).unsqueeze(0)
# Inférence
with torch.no_grad():
outputs = model(image_tensor)
probabilities = torch.nn.functional.softmax(outputs, dim=1)
predicted_probabilities = probabilities.numpy().tolist()
confidence = round(predicted_probabilities[0][1], 2)
return {"imageUrl": image_url, "confidence": confidence}
# Lancer les tâches en parallèle
tasks = [process_single_image(url) for url in request.imageUrls]
results = await asyncio.gather(*tasks)
return JSONResponse(content={"results": results})
except KeyError:
raise HTTPException(status_code=404, detail="Model not found")
except Exception as e:
logger.error(f"Batch prediction error: {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Batch prediction failed")
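
# Example usage (a sketch; the URLs and model name below are hypothetical,
# but the field names match the BatchPredictRequest schema used above):
#
#   POST /batch_predict
#   {
#       "modelName": "my_model",
#       "imageUrls": [
#           "https://example.com/cat.jpg",
#           "https://example.com/dog.jpg"
#       ]
#   }
#
# Each entry in the returned "results" list is either
#   {"imageUrl": "...", "confidence": 0.87}
# or, when the download fails,
#   {"imageUrl": "...", "error": "Invalid image URL"}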