Commit · 5df1f2d
1 Parent(s): 3635acb

Move API key middleware to main.py and add startup/shutdown lifecycle management

- api/router.py  +3 -24
- main.py  +66 -5
api/router.py CHANGED

@@ -10,31 +10,10 @@ logger = logging.getLogger(__name__)
 # Main router
 router = APIRouter()
 
-# ---
-
-async def api_key_middleware(request: Request, call_next):
-    # Skip if we're in debug mode or during startup
-    if os.environ.get("DEBUG") == "1":
-        return await call_next(request)
-
-    # For Hugging Face Space routes, Gradio, etc.
-    # (Simplified list, adjust if needed)
-    public_paths = ['/assets/', '/file=', '/queue/', '/startup-logs', '/config', '/info', '/gradio', '/favicon.ico', '/']
-    if any(request.url.path.startswith(p) for p in public_paths):
-        return await call_next(request)
-
-    # Optional: global API key check (if not handled by Depends)
-    # Disabled by default to rely on Depends at the endpoints
-    # try:
-    #     api_key = request.headers.get("X-API-Key")
-    #     await verify_api_key(api_key)  # Will not work directly here
-    # except HTTPException as e:
-    #     return Response(content=e.detail, status_code=e.status_code)
-
-    response = await call_next(request)
-    return response
+# --- Middleware removed ---
+# The middleware must be applied to the FastAPI instance (app) in main.py
 
 # --- Include sub-routers ---
 router.include_router(prediction.router)   # Prediction routes (/batch_predict)
 router.include_router(management.router)   # Management routes (/manage/load_model, /manage/update_model)
-router.include_router(management.public_router)  # Public info routes (/models, /models/{id})
+# router.include_router(management.public_router)  # Public info routes (/models, /models/{id}) - removed by the user
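A note on the comment added above: FastAPI exposes the "http" middleware decorator on the application object only, not on APIRouter, which is why the registration has to move to where app is created. A minimal sketch of that registration pattern (illustrative only; the actual middleware body is in the main.py diff below):

    from fastapi import FastAPI, Request

    app = FastAPI()

    @app.middleware("http")  # registered on the app; APIRouter has no equivalent decorator
    async def example_middleware(request: Request, call_next):
        # Pass the request down the stack; the real key-checking logic lives in main.py below.
        return await call_next(request)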
main.py CHANGED

@@ -1,13 +1,14 @@
 import logging
-
+import os
+from contextlib import asynccontextmanager
+from fastapi import FastAPI, Request, Response
 import gradio as gr
 from gradio.routes import mount_gradio_app
 
 from api.router import router
-from api.dependencies import verify_api_key
 from db.models import fetch_models_for_group
 from models.loader import load_models, model_pipelines
-from config.settings import RESOURCE_GROUP
+from config.settings import RESOURCE_GROUP, DATABASE_URL
 
 # Basic logging configuration
 logging.basicConfig(
@@ -16,15 +17,75 @@ logging.basicConfig(
 )
 logger = logging.getLogger(__name__)
 
+# --- Startup/Shutdown events (Lifespan Manager) ---
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    # Code executed at startup
+    logger.info("Starting up API...")
+    logger.info(f"Attempting to connect to database: {DATABASE_URL[:DATABASE_URL.find('@')] + '@...'}")  # Mask credentials
+    try:
+        models_to_load = await fetch_models_for_group(RESOURCE_GROUP)
+        if models_to_load:
+            await load_models(models_to_load)
+            logger.info("Initial models loaded successfully.")
+        else:
+            logger.warning(f"No models found for resource group '{RESOURCE_GROUP}'. API starting without preloaded models.")
+    except Exception as e:
+        logger.exception(f"Failed to load initial models during startup: {e}")
+        # Decide whether to abort API startup or continue without models
+        # raise RuntimeError("Could not load initial models, API startup aborted.") from e
+
+    yield
+    # Code executed on shutdown
+    logger.info("Shutting down API...")
+    # Add cleanup code here if needed (e.g. close persistent connections)
+
 # Create the FastAPI application
 app = FastAPI(
     title="Tamis AI Inference API",
     description="API pour l'inférence des modèles de classification d'objets",
     version="0.1.0",
+    lifespan=lifespan  # Fix: use the lifespan function defined above
 )
 
-#
-app.middleware("http")
+# Middleware added here
+@app.middleware("http")
+async def api_key_middleware(request: Request, call_next):
+    """Middleware that checks the API key and exempts certain routes."""
+    # Skip if we're in debug mode or during startup
+    if os.environ.get("DEBUG") == "1":
+        logger.debug("DEBUG mode active, skipping API key check.")
+        return await call_next(request)
+
+    # Public or internal paths exempt from the key check
+    public_paths = [
+        '/docs', '/openapi.json',  # Swagger/OpenAPI documentation
+        '/health',                 # Health check endpoint
+        '/',                       # Root (Gradio interface)
+        '/assets/',                # Gradio assets
+        '/file=',                  # Gradio files
+        '/queue/',                 # Gradio queue
+        '/startup-logs',           # HF Space logs
+        '/config',                 # Gradio/HF config
+        '/info',                   # Gradio/HF info
+        '/gradio',                 # Possible Gradio prefix
+        '/favicon.ico'             # Favicon
+    ]
+
+    # Check whether the path is an exact public path or starts with a public prefix
+    is_public = any(request.url.path == p or (p.endswith('/') and request.url.path.startswith(p)) for p in public_paths)
+
+    if is_public:
+        logger.debug(f"Public path accessed: {request.url.path}, skipping API key check.")
+        response = await call_next(request)
+        return response
+    else:
+        # For all other routes, verification happens via Depends() on the endpoint itself.
+        # This middleware therefore no longer performs an active check here;
+        # it only logs and can exempt certain routes if needed.
+        logger.debug(f"Protected path accessed: {request.url.path}. API key verification delegated to endpoint.")
+        response = await call_next(request)
+        return response
 
 # Include the routes
 app.include_router(router)
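Because the middleware above only logs for protected paths, the actual API-key enforcement is expected to happen per endpoint via Depends(). A hypothetical sketch of that pattern, assuming verify_api_key from api/dependencies.py raises an HTTPException on a bad or missing key (its exact signature is not shown in this diff):

    from fastapi import APIRouter, Depends
    from api.dependencies import verify_api_key  # assumed to raise HTTPException on failure

    router = APIRouter()

    # Route-level dependency: the handler only runs if verify_api_key does not raise.
    @router.post("/batch_predict", dependencies=[Depends(verify_api_key)])
    async def batch_predict(payload: dict):
        return {"status": "ok"}  # placeholder response for illustration

With the dependencies=[...] form the check stays out of the handler signature; note that any endpoint which omits such a dependency is effectively public now, since the middleware no longer blocks requests.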
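The new lifespan function only runs when the ASGI server starts and stops the application. One quick way to exercise both phases locally is FastAPI's TestClient used as a context manager; this is a sketch and assumes main.py is importable in your environment (a failed model load is already caught by the try/except inside lifespan above):

    from fastapi.testclient import TestClient

    from main import app

    # Entering the context runs the startup half of lifespan(); exiting runs the shutdown half.
    with TestClient(app) as client:
        response = client.get("/docs")  # /docs is in public_paths, so no API key is needed
        print(response.status_code)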