alexfremont commited on
Commit
5df1f2d
·
1 Parent(s): 3635acb

Move API key middleware to main.py and add startup/shutdown lifecycle management

Browse files
Files changed (2) hide show
  1. api/router.py +3 -24
  2. main.py +66 -5
api/router.py CHANGED
@@ -10,31 +10,10 @@ logger = logging.getLogger(__name__)
10
  # Routeur principal
11
  router = APIRouter()
12
 
13
- # --- Middleware (Laissé ici pour l'instant) ---
14
- @router.middleware("http")
15
- async def api_key_middleware(request: Request, call_next):
16
- # Skip if we're in debug mode or during startup
17
- if os.environ.get("DEBUG") == "1":
18
- return await call_next(request)
19
-
20
- # Pour les routes Hugging Face Space, Gradio, etc.
21
- # (Liste simplifiée, ajuster si nécessaire)
22
- public_paths = ['/assets/', '/file=', '/queue/', '/startup-logs', '/config', '/info', '/gradio', '/favicon.ico', '/']
23
- if any(request.url.path.startswith(p) for p in public_paths):
24
- return await call_next(request)
25
-
26
- # Optionnel: Vérification globale de clé API (si non gérée par Depends)
27
- # Désactivé par défaut pour se fier aux Depends sur les endpoints
28
- # try:
29
- # api_key = request.headers.get("X-API-Key")
30
- # await verify_api_key(api_key) # Ne fonctionnera pas directement ici
31
- # except HTTPException as e:
32
- # return Response(content=e.detail, status_code=e.status_code)
33
-
34
- response = await call_next(request)
35
- return response
36
 
37
  # --- Inclusion des sous-routeurs ---
38
  router.include_router(prediction.router) # Routes de prédiction (/batch_predict)
39
  router.include_router(management.router) # Routes de gestion (/manage/load_model, /manage/update_model)
40
- router.include_router(management.public_router) # Routes d'info publiques (/models, /models/{id})
 
10
  # Routeur principal
11
  router = APIRouter()
12
 
13
+ # --- Suppression du Middleware ---
14
+ # Le middleware doit être appliqué à l'instance FastAPI (app) dans main.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  # --- Inclusion des sous-routeurs ---
17
  router.include_router(prediction.router) # Routes de prédiction (/batch_predict)
18
  router.include_router(management.router) # Routes de gestion (/manage/load_model, /manage/update_model)
19
+ # router.include_router(management.public_router) # Routes d'info publiques (/models, /models/{id}) - Supprimé par l'utilisateur
main.py CHANGED
@@ -1,13 +1,14 @@
1
  import logging
2
- from fastapi import FastAPI
 
 
3
  import gradio as gr
4
  from gradio.routes import mount_gradio_app
5
 
6
  from api.router import router
7
- from api.dependencies import verify_api_key
8
  from db.models import fetch_models_for_group
9
  from models.loader import load_models, model_pipelines
10
- from config.settings import RESOURCE_GROUP
11
 
12
  # Configuration de base des logs
13
  logging.basicConfig(
@@ -16,15 +17,75 @@ logging.basicConfig(
16
  )
17
  logger = logging.getLogger(__name__)
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  # Créer l'application FastAPI
20
  app = FastAPI(
21
  title="Tamis AI Inference API",
22
  description="API pour l'inférence des modèles de classification d'objets",
23
  version="0.1.0",
 
24
  )
25
 
26
- # Ajouter middleware d'authentification
27
- app.middleware("http")(verify_api_key)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  # Inclure les routes
30
  app.include_router(router)
 
1
  import logging
2
+ import os
3
+ from contextlib import asynccontextmanager
4
+ from fastapi import FastAPI, Request, Response
5
  import gradio as gr
6
  from gradio.routes import mount_gradio_app
7
 
8
  from api.router import router
 
9
  from db.models import fetch_models_for_group
10
  from models.loader import load_models, model_pipelines
11
+ from config.settings import RESOURCE_GROUP, DATABASE_URL
12
 
13
  # Configuration de base des logs
14
  logging.basicConfig(
 
17
  )
18
  logger = logging.getLogger(__name__)
19
 
20
+ # --- Événements Startup/Shutdown (Lifespan Manager) ---
21
+ @asynccontextmanager
22
+ async def lifespan(app: FastAPI):
23
+ # Code exécuté au démarrage
24
+ logger.info("Starting up API...")
25
+ logger.info(f"Attempting to connect to database: {DATABASE_URL[:DATABASE_URL.find('@')] + '@...'}") # Masquer les crédentiels
26
+ try:
27
+ models_to_load = await fetch_models_for_group(RESOURCE_GROUP)
28
+ if models_to_load:
29
+ await load_models(models_to_load)
30
+ logger.info("Initial models loaded successfully.")
31
+ else:
32
+ logger.warning(f"No models found for resource group '{RESOURCE_GROUP}'. API starting without preloaded models.")
33
+ except Exception as e:
34
+ logger.exception(f"Failed to load initial models during startup: {e}")
35
+ # Décider s'il faut empêcher le démarrage de l'API ou continuer sans modèles
36
+ # raise RuntimeError("Could not load initial models, API startup aborted.") from e
37
+
38
+ yield
39
+ # Code exécuté à l'arrêt
40
+ logger.info("Shutting down API...")
41
+ # Ajouter ici le code de nettoyage si nécessaire (ex: fermer connexions persistantes)
42
+
43
  # Créer l'application FastAPI
44
  app = FastAPI(
45
  title="Tamis AI Inference API",
46
  description="API pour l'inférence des modèles de classification d'objets",
47
  version="0.1.0",
48
+ lifespan=lifespan # Correction: Utilisation de la fonction lifespan définie ci-dessus
49
  )
50
 
51
+ # Ajout du Middleware ici
52
+ @app.middleware("http")
53
+ async def api_key_middleware(request: Request, call_next):
54
+ """Middleware pour vérifier la clé API et exempter certaines routes."""
55
+ # Skip if we're in debug mode or during startup
56
+ if os.environ.get("DEBUG") == "1":
57
+ logger.debug("DEBUG mode active, skipping API key check.")
58
+ return await call_next(request)
59
+
60
+ # Liste des chemins publics ou internes à exempter de la vérification de clé
61
+ public_paths = [
62
+ '/docs', '/openapi.json', # Documentation Swagger/OpenAPI
63
+ '/health', # Health check endpoint
64
+ '/', # Racine (Interface Gradio)
65
+ '/assets/', # Assets Gradio
66
+ '/file=', # Fichiers Gradio
67
+ '/queue/', # Queue Gradio
68
+ '/startup-logs', # Logs HF Space
69
+ '/config', # Config Gradio/HF
70
+ '/info', # Info Gradio/HF
71
+ '/gradio', # Potentiel préfixe Gradio
72
+ '/favicon.ico' # Favicon
73
+ ]
74
+
75
+ # Vérifie si le chemin commence par un des préfixes publics
76
+ is_public = any(request.url.path == p or (p.endswith('/') and request.url.path.startswith(p)) for p in public_paths)
77
+
78
+ if is_public:
79
+ logger.debug(f"Public path accessed: {request.url.path}, skipping API key check.")
80
+ response = await call_next(request)
81
+ return response
82
+ else:
83
+ # Pour toutes les autres routes, la vérification se fait via Depends() sur l'endpoint lui-même.
84
+ # Ce middleware ne fait donc plus de vérification active ici,
85
+ # il sert juste à logger et potentiellement à exempter certaines routes si besoin.
86
+ logger.debug(f"Protected path accessed: {request.url.path}. API key verification delegated to endpoint.")
87
+ response = await call_next(request)
88
+ return response
89
 
90
  # Inclure les routes
91
  app.include_router(router)