import os
import logging
from io import BytesIO

import requests
import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel
from torchvision import transforms
from huggingface_hub import hf_hub_download

from architecture.resnet import ResNet
from steps.preprocess import process_image

app = FastAPI()
image_size = 256

# Secrets injected through environment variables.
hf_token = os.environ.get("api_read")
# "api_key" is assumed to hold one or more comma-separated keys; parse it into a set
# so the membership check below matches whole keys rather than substrings.
VALID_API_KEYS = set(filter(None, (os.environ.get("api_key") or "").split(",")))
@app.middleware("http")
async def verify_api_key(request: Request, call_next):
    logging.info(f"Received request: {request.method} {request.url}")
    api_key = request.headers.get("x-api-key")
    if api_key is None or api_key not in VALID_API_KEYS:
        logging.warning("Unauthorized access attempt.")
        # Return the error response directly: an HTTPException raised inside
        # middleware is not handled by FastAPI's exception handlers.
        return JSONResponse(status_code=403, content={"detail": "Unauthorized"})
    response = await call_next(request)
    return response
# Log only whether the token is present; never print the secret itself.
logging.info(f"Hugging Face token loaded: {hf_token is not None}")
models_locations = [
    {
        "repo_id": "TamisAI/category-lamp",
        "subfolder": "maison-jansen/palmtree-152-0005-32-256",
        "filename": "palmtree-jansen.pth",
    },
    {
        "repo_id": "TamisAI/category-lamp",
        "subfolder": "maison-charles/corail-152-0001-32-256",
        "filename": "maison-charles-corail.pth",
    },
]
device = torch.device("cpu")

# Data model for prediction requests
class PredictRequest(BaseModel):
    imageUrl: str
    modelName: str

# Dictionary holding the loaded models, keyed by checkpoint filename
model_pipelines = {}

# Shared ResNet-152 architecture; each checkpoint below is loaded into its own copy
base_model = ResNet("resnet152", num_output_neurons=2).to(device)
@app.on_event("startup")
async def load_models():
    # Load the models at application startup
    print(f"Loading models...{len(models_locations)}")
    for model_location in models_locations:
        try:
            print(f"Loading model: {model_location['filename']}")
            model_weight = hf_hub_download(
                repo_id=model_location["repo_id"],
                subfolder=model_location["subfolder"],
                filename=model_location["filename"],
                token=hf_token,
            )
            # Fresh instance per checkpoint so weights do not overwrite each other
            model = ResNet("resnet152", num_output_neurons=2).to(device)
            model.load_state_dict(
                torch.load(model_weight, weights_only=True, map_location=device)
            )
            model.eval()
            model_pipelines[model_location["filename"]] = model
        except Exception as e:
            print(f"Error loading model {model_location['filename']}: {e}")
    print(f"Models loaded. {len(model_pipelines)}")
# Route path assumed to be /predict; adjust if the client expects a different path.
@app.post("/predict")
async def predict(request: PredictRequest):
    image_url = request.imageUrl
    model_name = request.modelName

    # Download the image from the URL
    try:
        response = requests.get(image_url, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
    except Exception as e:
        raise HTTPException(status_code=400, detail="Invalid image URL")

    # Check that the requested model is loaded
    if model_name not in model_pipelines:
        raise HTTPException(status_code=404, detail="Model not found")

    # Preprocess the image
    processed_image = process_image(image, size=image_size)

    # Convert to tensor and add a batch dimension
    image_tensor = transforms.ToTensor()(processed_image).unsqueeze(0).to(device)

    model = model_pipelines[model_name]

    # Perform inference
    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        predicted_probabilities = probabilities.numpy().tolist()
        confidence = round(predicted_probabilities[0][1], 2)

    # Return the positive-class probability as JSON
    return JSONResponse(content={"confidence": confidence})
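# Example request (illustrative sketch, not part of the app): the host and port below
# assume a local run on the Hugging Face Spaces default port 7860, and the /predict
# route matches the decorator above. The x-api-key value must be one of the keys set
# in the "api_key" environment variable; the image URL is a placeholder.
#
#   curl -X POST http://localhost:7860/predict \
#     -H "Content-Type: application/json" \
#     -H "x-api-key: <your-api-key>" \
#     -d '{"imageUrl": "https://example.com/lamp.jpg", "modelName": "palmtree-jansen.pth"}'
#
# Expected response shape: {"confidence": <float between 0 and 1>}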