File size: 2,575 Bytes
e109700
0053356
e109700
 
 
 
 
 
 
 
 
 
 
 
bccef3b
e109700
 
 
 
0053356
e109700
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
from fastapi import APIRouter, HTTPException, Depends
from fastapi.responses import JSONResponse
from PIL import Image
from io import BytesIO
import logging
import httpx
import asyncio
from torchvision import transforms

from schemas.requests import BatchPredictRequest
from models.loader import get_model
from steps.preprocess import process_image
from config.settings import IMAGE_SIZE, NUM_THREADS
from api.dependencies import verify_api_key 

# Router collecting this module's endpoints; mounted by the application elsewhere.
router = APIRouter()
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)

@router.post("/batch_predict", dependencies=[Depends(verify_api_key)])
async def batch_predict(request: BatchPredictRequest):
    """Predict a confidence score for each image URL in the request.

    Every URL in ``request.imageUrls`` is downloaded, decoded with PIL,
    preprocessed with ``process_image`` and passed through the model named by
    ``request.modelName``. The reported confidence is the softmax probability
    at output index 1, rounded to two decimals.

    At most ``NUM_THREADS`` images are handled concurrently, and all downloads
    share a single ``httpx.AsyncClient`` (one connection pool per request
    rather than one per image).

    Returns:
        JSONResponse ``{"results": [...]}`` with one entry per URL, in input
        order: ``{"imageUrl": ..., "confidence": ...}`` on success, or
        ``{"imageUrl": ..., "error": "Invalid image URL"}`` when the image
        could not be fetched or decoded.

    Raises:
        HTTPException: 404 if ``get_model`` raises ``KeyError``; 500 for any
            other failure (model loading, preprocessing or inference).
    """
    model_name = request.modelName

    # Scope the KeyError -> 404 mapping to the model lookup only: a KeyError
    # raised later (e.g. inside preprocessing) is a server-side bug and must
    # not be reported to the client as "Model not found".
    try:
        model = get_model(model_name)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail="Model not found") from exc
    except Exception as exc:
        logger.error("Batch prediction error: %s", exc, exc_info=True)
        raise HTTPException(status_code=500, detail="Batch prediction failed") from exc

    # Caps how many images are processed at the same time.
    semaphore = asyncio.Semaphore(NUM_THREADS)
    # The transform holds no per-image state, so build it once per request.
    to_tensor = transforms.ToTensor()

    async def process_single_image(client, image_url):
        """Fetch, decode and score one image; bad input yields an error entry."""
        async with semaphore:
            # NOTE(review): image_url is client-supplied and fetched
            # server-side; consider a host allow-list and a size limit
            # (SSRF / memory exhaustion).
            try:
                response = await client.get(image_url)
                # Otherwise a 4xx/5xx error page would be handed to PIL.
                response.raise_for_status()
                image = Image.open(BytesIO(response.content))
                # Image.open() is lazy and only parses the header; force a
                # full decode here so truncated/corrupt files are reported for
                # this image instead of aborting the whole batch later on.
                image.load()
            except Exception as exc:
                logger.error("Error downloading image from %s: %s", image_url, exc)
                return {"imageUrl": image_url, "error": "Invalid image URL"}

            processed_image = process_image(image, size=IMAGE_SIZE)

            # Convert to a tensor and prepend a batch dimension of size 1.
            image_tensor = to_tensor(processed_image).unsqueeze(0)

            # NOTE(review): inference runs synchronously on the event-loop
            # thread, blocking other requests while it runs; consider
            # offloading it to a worker thread if the model is thread-safe.
            with torch.no_grad():
                outputs = model(image_tensor)
                probabilities = torch.nn.functional.softmax(outputs, dim=1)
            # .item() reads just the needed value (CPU or CUDA) instead of
            # round-tripping the whole tensor through NumPy.
            confidence = round(probabilities[0, 1].item(), 2)

            return {"imageUrl": image_url, "confidence": confidence}

    try:
        # One shared client (and connection pool) for every download.
        async with httpx.AsyncClient() as client:
            results = await asyncio.gather(
                *(process_single_image(client, url) for url in request.imageUrls)
            )
    except Exception as exc:
        logger.error("Batch prediction error: %s", exc, exc_info=True)
        raise HTTPException(status_code=500, detail="Batch prediction failed") from exc

    return JSONResponse(content={"results": results})