alexfremont commited on
Commit
bccef3b
·
1 Parent(s): 832b9bb

Refactor API auth and add management endpoints for model loading/updating

Browse files
Files changed (5) hide show
  1. api/dependencies.py +40 -0
  2. api/management.py +80 -0
  3. api/prediction.py +1 -1
  4. api/router.py +28 -131
  5. main.py +2 -1
api/dependencies.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ from typing import Optional
4
+ from fastapi import Header, HTTPException
5
+
6
+ from config.settings import API_KEY, MANAGEMENT_API_KEY
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ async def verify_api_key(api_key: Optional[str] = Header(None, alias="X-API-Key")):
11
+ """Vérifie si la clé API client fournie correspond à celle configurée."""
12
+ if not API_KEY:
13
+ logger.error("API_KEY environment variable is not set. Client endpoints are unsecured!")
14
+ # Bloquer si aucune clé n'est définie pour éviter l'accès non sécurisé
15
+ raise HTTPException(status_code=503, detail="Server configuration error: API Key not set.")
16
+
17
+ # Gestion de plusieurs clés possibles séparées par des virgules
18
+ valid_keys = [key.strip() for key in API_KEY.split(',')]
19
+
20
+ if api_key is None or api_key not in valid_keys:
21
+ logger.warning("Invalid or missing client API key attempt.")
22
+ # Utiliser 403 Forbidden car l'authentification a échoué
23
+ raise HTTPException(
24
+ status_code=403,
25
+ detail="Invalid or missing API Key"
26
+ )
27
+ # logger.debug("Client API Key verified successfully.") # Optionnel: Peut être bruyant
28
+ return True # Clé valide
29
+
30
+ async def verify_management_api_key(x_api_key: Optional[str] = Header(None, alias="X-API-Key")):
31
+ """Vérifie si la clé API de gestion fournie correspond à celle configurée."""
32
+ if not MANAGEMENT_API_KEY:
33
+ logger.warning("MANAGEMENT_API_KEY is not set. Management endpoints are inaccessible!")
34
+ raise HTTPException(status_code=503, detail="Management API key not configured on server.")
35
+
36
+ if not x_api_key or x_api_key != MANAGEMENT_API_KEY:
37
+ logger.warning("Invalid or missing management API key attempt.")
38
+ raise HTTPException(status_code=403, detail="Invalid or missing Management API Key")
39
+ # logger.debug("Management API Key verified successfully.") # Optionnel
40
+ return True # Clé valide
api/management.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Any, Dict, List
3
+ from fastapi import APIRouter, HTTPException, Depends
4
+
5
+ from api.dependencies import verify_management_api_key
6
+ from db.models import fetch_model_by_id
7
+ from models.loader import model_pipelines, _load_single_model_pipeline
8
+ from models.schemas import ModelInfo
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ router = APIRouter(prefix="/manage", tags=["Management"], dependencies=[Depends(verify_management_api_key)])
13
+
14
+ @router.post("/load_model/{model_db_id}", summary="Load a specific model into memory")
15
+ async def load_single_model(model_db_id: Any):
16
+ """Charge un modèle spécifique en mémoire en utilisant son ID de la base de données."""
17
+ if model_db_id in model_pipelines:
18
+ raise HTTPException(status_code=400, detail=f"Model ID {model_db_id} is already loaded.")
19
+
20
+ try:
21
+ model_data = await fetch_model_by_id(model_db_id)
22
+ if not model_data:
23
+ raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} not found in database.")
24
+
25
+ logger.info(f"Loading model {model_db_id} ({model_data['model_name']}) on demand...")
26
+ success = await _load_single_model_pipeline(model_data)
27
+ if success:
28
+ logger.info(f"Model {model_db_id} ({model_data['model_name']}) loaded successfully.")
29
+ return {"message": f"Model {model_db_id} loaded successfully."}
30
+ else:
31
+ # L'erreur spécifique devrait être loggée dans _load_single_model_pipeline
32
+ raise HTTPException(status_code=500, detail=f"Failed to load model {model_db_id}.")
33
+
34
+ except HTTPException as http_exc:
35
+ raise http_exc # Re-lancer les exceptions HTTP déjà formatées
36
+ except Exception as e:
37
+ logger.exception(f"Error loading model {model_db_id}: {e}")
38
+ raise HTTPException(status_code=500, detail=f"Internal server error while loading model {model_db_id}.")
39
+
40
+ @router.post("/update_model/{model_db_id}", summary="Update (reload) an already loaded model")
41
+ async def update_single_model(model_db_id: Any):
42
+ """Met à jour un modèle déjà chargé en le rechargeant depuis le Hub.
43
+ Le modèle doit être actuellement chargé pour pouvoir être mis à jour.
44
+ """
45
+ if model_db_id not in model_pipelines:
46
+ raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} is not currently loaded. Use load_model first.")
47
+
48
+ try:
49
+ model_data = await fetch_model_by_id(model_db_id)
50
+ if not model_data:
51
+ # Devrait être peu probable si le modèle est chargé, mais vérification par sécurité
52
+ raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} not found in database despite being loaded.")
53
+
54
+ logger.info(f"Updating (reloading) model {model_db_id} ({model_data['model_name']})...")
55
+
56
+ # Supprimer l'ancien modèle de la mémoire
57
+ if model_db_id in model_pipelines:
58
+ del model_pipelines[model_db_id]
59
+ logger.info(f"Removed old instance of model {model_db_id} from memory.")
60
+
61
+ # Tenter de recharger
62
+ success = await _load_single_model_pipeline(model_data, force_download=True) # Forcer le re-téléchargement
63
+
64
+ if success:
65
+ logger.info(f"Model {model_db_id} ({model_data['model_name']}) updated successfully.")
66
+ return {"message": f"Model {model_db_id} updated successfully."}
67
+ else:
68
+ # Si le rechargement échoue, le modèle n'est plus en mémoire
69
+ logger.error(f"Failed to update model {model_db_id}. It has been unloaded.")
70
+ raise HTTPException(status_code=500, detail=f"Failed to update model {model_db_id}. Model is now unloaded.")
71
+
72
+ except HTTPException as http_exc:
73
+ raise http_exc
74
+ except Exception as e:
75
+ logger.exception(f"Error updating model {model_db_id}: {e}")
76
+ # Le modèle pourrait être dans un état indéterminé (déchargé ou erreur interne)
77
+ if model_db_id in model_pipelines:
78
+ del model_pipelines[model_db_id] # Assurer le déchargement en cas d'erreur imprévue
79
+ raise HTTPException(status_code=500, detail=f"Internal server error while updating model {model_db_id}. Model has been unloaded.")
80
+
api/prediction.py CHANGED
@@ -12,7 +12,7 @@ from schemas.requests import BatchPredictRequest
12
  from models.loader import get_model
13
  from steps.preprocess import process_image
14
  from config.settings import IMAGE_SIZE, NUM_THREADS
15
- from api.router import verify_api_key
16
 
17
  logger = logging.getLogger(__name__)
18
  router = APIRouter()
 
12
  from models.loader import get_model
13
  from steps.preprocess import process_image
14
  from config.settings import IMAGE_SIZE, NUM_THREADS
15
+ from api.dependencies import verify_api_key
16
 
17
  logger = logging.getLogger(__name__)
18
  router = APIRouter()
api/router.py CHANGED
@@ -1,149 +1,46 @@
1
- from fastapi import APIRouter, Request, HTTPException
2
  import logging
3
  import os
4
-
5
- from api import prediction
6
- from config.settings import API_KEY, MANAGEMENT_API_KEY
 
 
 
 
 
7
  from db.models import fetch_model_by_id
8
  from models.loader import model_pipelines, _load_single_model_pipeline, get_model
9
- from models.schemas import PredictionRequest, PredictionResponse
10
 
11
  logger = logging.getLogger(__name__)
12
 
13
- # Router principal
14
  router = APIRouter()
15
 
16
- # Middleware d'authentification
17
- async def verify_api_key(request: Request, call_next):
18
- """Middleware pour vérifier la clé API dans les en-têtes."""
19
  # Skip if we're in debug mode or during startup
20
  if os.environ.get("DEBUG") == "1":
21
  return await call_next(request)
22
 
23
- # Pour les routes Hugging Face Space (logs, etc.)
24
- # Aussi pour les assets statiques de Gradio
25
- if request.url.path.startswith('/assets/') \
26
- or request.url.path == '/theme.css' \
27
- or request.url.path.startswith('/queue/'):
28
- return await call_next(request)
29
-
30
- if request.url.path == "/" and "logs=container" in request.url.query:
31
  return await call_next(request)
32
 
33
- api_key = request.headers.get("x-api-key")
34
- if api_key is None or api_key not in API_KEY.split(','):
35
- logger.warning(f"Unauthorized API access attempt from {request.client.host}")
36
- raise HTTPException(status_code=403, detail="Unauthorized")
 
 
 
 
37
  response = await call_next(request)
38
  return response
39
 
40
- # Dépendance pour la Sécurité de l'API de Gestion
41
- async def verify_management_api_key(x_api_key: str = Header(...)):
42
- """Vérifie si la clé API fournie correspond à celle configurée."""
43
- if not MANAGEMENT_API_KEY:
44
- logger.warning("MANAGEMENT_API_KEY is not set. Management endpoints are unsecured!")
45
- # Décider si on bloque ou on autorise sans clé définie
46
- # Pour la sécurité, il vaut mieux bloquer par défaut
47
- raise HTTPException(status_code=500, detail="Management API key not configured on server.")
48
- if x_api_key != MANAGEMENT_API_KEY:
49
- logger.warning(f"Invalid or missing API key attempt for management endpoint.")
50
- raise HTTPException(status_code=401, detail="Invalid or missing API Key")
51
- return True # Clé valide
52
-
53
- # Inclure les routes des autres modules
54
- router.include_router(prediction.router, tags=["Prediction"])
55
-
56
- # Nouvel Endpoint de Gestion
57
- @router.post(
58
- "/manage/load_model/{model_db_id}",
59
- summary="Load a specific model into memory",
60
- dependencies=[Depends(verify_management_api_key)] # Sécurise l'endpoint
61
- )
62
- async def load_single_model(model_db_id: Any): # L'ID peut être int ou str
63
- """Charge un modèle spécifique en mémoire en utilisant son ID de la base de données."""
64
- logger.info(f"Received request to load model with DB ID: {model_db_id}")
65
-
66
- # 1. Vérifier si le modèle est déjà chargé
67
- if model_db_id in model_pipelines:
68
- logger.info(f"Model ID {model_db_id} is already loaded.")
69
- return {"status": "success", "message": f"Model ID {model_db_id} is already loaded."}
70
-
71
- # 2. Récupérer les informations du modèle depuis la DB
72
- try:
73
- model_data = await fetch_model_by_id(model_db_id)
74
- if not model_data:
75
- logger.error(f"Model ID {model_db_id} not found in database.")
76
- raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} not found in database.")
77
- except Exception as e:
78
- logger.exception(f"Database error fetching model ID {model_db_id}: {e}")
79
- raise HTTPException(status_code=500, detail=f"Database error checking model ID {model_db_id}.")
80
-
81
- # 3. Charger le modèle
82
- try:
83
- logger.info(f"Attempting to load model ID {model_db_id} ('{model_data.get('name', 'N/A')}') into memory...")
84
- pipeline = await _load_single_model_pipeline(model_data)
85
-
86
- # 4. Ajouter au dictionnaire des modèles chargés
87
- model_pipelines[model_db_id] = pipeline
88
- logger.info(f"Successfully loaded and added model ID {model_db_id} to running pipelines.")
89
- return {"status": "success", "message": f"Model ID {model_db_id} loaded successfully."}
90
-
91
- except Exception as e:
92
- logger.exception(f"Failed to load model ID {model_db_id}: {e}")
93
- # Ne pas laisser un pipeline potentiellement corrompu dans le dictionnaire
94
- if model_db_id in model_pipelines:
95
- del model_pipelines[model_db_id]
96
- raise HTTPException(status_code=500, detail=f"Failed to load model ID {model_db_id}. Check server logs for details.")
97
-
98
- @router.post(
99
- "/manage/update_model/{model_db_id}",
100
- summary="Reload/Update a specific model already in memory",
101
- dependencies=[Depends(verify_management_api_key)] # Sécurise l'endpoint
102
- )
103
- async def update_single_model(model_db_id: Any):
104
- """Retélécharge et met à jour un modèle spécifique qui est déjà chargé en mémoire."""
105
- logger.info(f"Received request to update model with DB ID: {model_db_id}")
106
-
107
- # 1. Vérifier si le modèle est actuellement chargé
108
- if model_db_id not in model_pipelines:
109
- logger.error(f"Attempted to update model ID {model_db_id}, but it is not loaded.")
110
- raise HTTPException(
111
- status_code=404,
112
- detail=f"Model ID {model_db_id} is not currently loaded. Use load_model first."
113
- )
114
-
115
- # 2. Récupérer les informations du modèle depuis la DB (pour s'assurer qu'elles sont à jour si besoin)
116
- try:
117
- model_data = await fetch_model_by_id(model_db_id)
118
- if not model_data:
119
- # Ceci indiquerait une incohérence si le modèle est dans model_pipelines mais pas dans la DB
120
- logger.error(f"Inconsistency: Model ID {model_db_id} loaded but not found in database during update.")
121
- raise HTTPException(status_code=500, detail=f"Inconsistency: Model ID {model_db_id} not found in database.")
122
- except Exception as e:
123
- logger.exception(f"Database error fetching model ID {model_db_id} during update: {e}")
124
- raise HTTPException(status_code=500, detail=f"Database error checking model ID {model_db_id} for update.")
125
-
126
- # 3. Recharger le modèle
127
- try:
128
- logger.info(f"Attempting to reload model ID {model_db_id} ('{model_data.get('name', 'N/A')}') from source...")
129
- # Supprimer l'ancien modèle de la mémoire avant de charger le nouveau pour libérer des ressources GPU/CPU si possible
130
- # Attention : ceci pourrait causer une brève indisponibilité du modèle pendant le rechargement.
131
- # Une stratégie alternative serait de charger le nouveau d'abord, puis de remplacer.
132
- if model_db_id in model_pipelines:
133
- del model_pipelines[model_db_id]
134
- # Potentiellement forcer le nettoyage de la mémoire GPU ici si nécessaire (torch.cuda.empty_cache() - à utiliser avec prudence)
135
- logger.debug(f"Removed old instance of model ID {model_db_id} from memory before update.")
136
-
137
- pipeline = await _load_single_model_pipeline(model_data)
138
-
139
- # 4. Mettre à jour le dictionnaire avec le nouveau pipeline
140
- model_pipelines[model_db_id] = pipeline
141
- logger.info(f"Successfully updated model ID {model_db_id} in running pipelines.")
142
- return {"status": "success", "message": f"Model ID {model_db_id} updated successfully."}
143
-
144
- except Exception as e:
145
- logger.exception(f"Failed to reload model ID {model_db_id}: {e}")
146
- # Si le rechargement échoue, l'ancien modèle a déjà été supprimé.
147
- # Il faut soit tenter de recharger l'ancien, soit le laisser déchargé.
148
- # Pour l'instant, on le laisse déchargé et on signale l'erreur.
149
- raise HTTPException(status_code=500, detail=f"Failed to reload model ID {model_db_id}. Model is now unloaded. Check server logs.")
 
 
1
  import logging
2
  import os
3
+ from typing import Any, Optional, Dict, List
4
+ from fastapi import APIRouter, HTTPException, Depends, Header, Request, Response
5
+
6
+ # Import des dépendances
7
+ from api.dependencies import verify_api_key, verify_management_api_key
8
+ # Import des sous-routeurs
9
+ from api import prediction, management
10
+ from config.settings import MANAGEMENT_API_KEY
11
  from db.models import fetch_model_by_id
12
  from models.loader import model_pipelines, _load_single_model_pipeline, get_model
 
13
 
14
  logger = logging.getLogger(__name__)
15
 
16
+ # Routeur principal
17
  router = APIRouter()
18
 
19
+ # --- Middleware (Laissé ici pour l'instant) ---
20
+ @router.middleware("http")
21
+ async def api_key_middleware(request: Request, call_next):
22
  # Skip if we're in debug mode or during startup
23
  if os.environ.get("DEBUG") == "1":
24
  return await call_next(request)
25
 
26
+ # Pour les routes Hugging Face Space, Gradio, etc.
27
+ # (Liste simplifiée, ajuster si nécessaire)
28
+ public_paths = ['/assets/', '/file=', '/queue/', '/startup-logs', '/config', '/info', '/gradio', '/favicon.ico', '/']
29
+ if any(request.url.path.startswith(p) for p in public_paths):
 
 
 
 
30
  return await call_next(request)
31
 
32
+ # Optionnel: Vérification globale de clé API (si non gérée par Depends)
33
+ # Désactivé par défaut pour se fier aux Depends sur les endpoints
34
+ # try:
35
+ # api_key = request.headers.get("X-API-Key")
36
+ # await verify_api_key(api_key) # Ne fonctionnera pas directement ici
37
+ # except HTTPException as e:
38
+ # return Response(content=e.detail, status_code=e.status_code)
39
+
40
  response = await call_next(request)
41
  return response
42
 
43
+ # --- Inclusion des sous-routeurs ---
44
+ router.include_router(prediction.router) # Routes de prédiction (/batch_predict)
45
+ router.include_router(management.router) # Routes de gestion (/manage/load_model, /manage/update_model)
46
+ router.include_router(management.public_router) # Routes d'info publiques (/models, /models/{id})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main.py CHANGED
@@ -3,7 +3,8 @@ from fastapi import FastAPI
3
  import gradio as gr
4
  from gradio.routes import mount_gradio_app
5
 
6
- from api.router import router, verify_api_key
 
7
  from db.models import fetch_models_for_group
8
  from models.loader import load_models, model_pipelines
9
  from config.settings import RESOURCE_GROUP
 
3
  import gradio as gr
4
  from gradio.routes import mount_gradio_app
5
 
6
+ from api.router import router
7
+ from api.dependencies import verify_api_key
8
  from db.models import fetch_models_for_group
9
  from models.loader import load_models, model_pipelines
10
  from config.settings import RESOURCE_GROUP