Spaces: Running on CPU Upgrade

Commit bccef3b · Parent(s): 832b9bb
Refactor API auth and add management endpoints for model loading/updating

Files changed:
- api/dependencies.py +40 -0
- api/management.py +80 -0
- api/prediction.py +1 -1
- api/router.py +28 -131
- main.py +2 -1
api/dependencies.py
ADDED
@@ -0,0 +1,40 @@
+import os
+import logging
+from typing import Optional
+from fastapi import Header, HTTPException
+
+from config.settings import API_KEY, MANAGEMENT_API_KEY
+
+logger = logging.getLogger(__name__)
+
+async def verify_api_key(api_key: Optional[str] = Header(None, alias="X-API-Key")):
+    """Check whether the provided client API key matches the configured one."""
+    if not API_KEY:
+        logger.error("API_KEY environment variable is not set. Client endpoints are unsecured!")
+        # Block when no key is defined, to avoid unsecured access
+        raise HTTPException(status_code=503, detail="Server configuration error: API Key not set.")
+
+    # Support several possible keys separated by commas
+    valid_keys = [key.strip() for key in API_KEY.split(',')]
+
+    if api_key is None or api_key not in valid_keys:
+        logger.warning("Invalid or missing client API key attempt.")
+        # Use 403 Forbidden because authentication failed
+        raise HTTPException(
+            status_code=403,
+            detail="Invalid or missing API Key"
+        )
+    # logger.debug("Client API Key verified successfully.")  # Optional: can be noisy
+    return True  # Key is valid
+
+async def verify_management_api_key(x_api_key: Optional[str] = Header(None, alias="X-API-Key")):
+    """Check whether the provided management API key matches the configured one."""
+    if not MANAGEMENT_API_KEY:
+        logger.warning("MANAGEMENT_API_KEY is not set. Management endpoints are inaccessible!")
+        raise HTTPException(status_code=503, detail="Management API key not configured on server.")
+
+    if not x_api_key or x_api_key != MANAGEMENT_API_KEY:
+        logger.warning("Invalid or missing management API key attempt.")
+        raise HTTPException(status_code=403, detail="Invalid or missing Management API Key")
+    # logger.debug("Management API Key verified successfully.")  # Optional
+    return True  # Key is valid
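Note: a minimal sketch of how these dependencies are meant to be consumed via FastAPI's Depends (the endpoint name and handler body are illustrative, not part of this commit):

    # Hypothetical endpoint: attaching verify_api_key via Depends means the
    # request must carry a valid X-API-Key header before the handler runs.
    from fastapi import APIRouter, Depends
    from api.dependencies import verify_api_key

    example_router = APIRouter()

    @example_router.get("/example", dependencies=[Depends(verify_api_key)])
    async def example_endpoint():
        return {"ok": True}

Because verify_api_key splits API_KEY on commas, setting API_KEY="key1,key2" accepts either key.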
api/management.py
ADDED
@@ -0,0 +1,80 @@
+import logging
+from typing import Any, Dict, List
+from fastapi import APIRouter, HTTPException, Depends
+
+from api.dependencies import verify_management_api_key
+from db.models import fetch_model_by_id
+from models.loader import model_pipelines, _load_single_model_pipeline
+from models.schemas import ModelInfo
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter(prefix="/manage", tags=["Management"], dependencies=[Depends(verify_management_api_key)])
+
+@router.post("/load_model/{model_db_id}", summary="Load a specific model into memory")
+async def load_single_model(model_db_id: Any):
+    """Load a specific model into memory using its database ID."""
+    if model_db_id in model_pipelines:
+        raise HTTPException(status_code=400, detail=f"Model ID {model_db_id} is already loaded.")
+
+    try:
+        model_data = await fetch_model_by_id(model_db_id)
+        if not model_data:
+            raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} not found in database.")
+
+        logger.info(f"Loading model {model_db_id} ({model_data['model_name']}) on demand...")
+        success = await _load_single_model_pipeline(model_data)
+        if success:
+            logger.info(f"Model {model_db_id} ({model_data['model_name']}) loaded successfully.")
+            return {"message": f"Model {model_db_id} loaded successfully."}
+        else:
+            # The specific error should be logged inside _load_single_model_pipeline
+            raise HTTPException(status_code=500, detail=f"Failed to load model {model_db_id}.")
+
+    except HTTPException as http_exc:
+        raise http_exc  # Re-raise HTTP exceptions that are already formatted
+    except Exception as e:
+        logger.exception(f"Error loading model {model_db_id}: {e}")
+        raise HTTPException(status_code=500, detail=f"Internal server error while loading model {model_db_id}.")
+
+@router.post("/update_model/{model_db_id}", summary="Update (reload) an already loaded model")
+async def update_single_model(model_db_id: Any):
+    """Update an already loaded model by reloading it from the Hub.
+    The model must currently be loaded in order to be updated.
+    """
+    if model_db_id not in model_pipelines:
+        raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} is not currently loaded. Use load_model first.")
+
+    try:
+        model_data = await fetch_model_by_id(model_db_id)
+        if not model_data:
+            # Unlikely if the model is loaded, but check as a safeguard
+            raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} not found in database despite being loaded.")
+
+        logger.info(f"Updating (reloading) model {model_db_id} ({model_data['model_name']})...")
+
+        # Remove the old model from memory
+        if model_db_id in model_pipelines:
+            del model_pipelines[model_db_id]
+            logger.info(f"Removed old instance of model {model_db_id} from memory.")
+
+        # Attempt the reload
+        success = await _load_single_model_pipeline(model_data, force_download=True)  # Force a re-download
+
+        if success:
+            logger.info(f"Model {model_db_id} ({model_data['model_name']}) updated successfully.")
+            return {"message": f"Model {model_db_id} updated successfully."}
+        else:
+            # If the reload fails, the model is no longer in memory
+            logger.error(f"Failed to update model {model_db_id}. It has been unloaded.")
+            raise HTTPException(status_code=500, detail=f"Failed to update model {model_db_id}. Model is now unloaded.")
+
+    except HTTPException as http_exc:
+        raise http_exc
+    except Exception as e:
+        logger.exception(f"Error updating model {model_db_id}: {e}")
+        # The model may be in an indeterminate state (unloaded or internal error)
+        if model_db_id in model_pipelines:
+            del model_pipelines[model_db_id]  # Ensure it is unloaded on unexpected errors
+        raise HTTPException(status_code=500, detail=f"Internal server error while updating model {model_db_id}. Model has been unloaded.")
+
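Note: with this router mounted under the /manage prefix and guarded by verify_management_api_key, loading a model by database ID would look roughly like this (host, port, and model ID are illustrative):

    # Sketch of a management client call; requires the requests package.
    import requests

    resp = requests.post(
        "http://localhost:7860/manage/load_model/42",
        headers={"X-API-Key": "<MANAGEMENT_API_KEY value>"},  # must match the server's key
    )
    print(resp.status_code, resp.json())

A successful call returns {"message": "Model 42 loaded successfully."}; /manage/update_model/42 re-downloads and reloads a model that is already in memory.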
api/prediction.py
CHANGED
@@ -12,7 +12,7 @@ from schemas.requests import BatchPredictRequest
 from models.loader import get_model
 from steps.preprocess import process_image
 from config.settings import IMAGE_SIZE, NUM_THREADS
-from api.
+from api.dependencies import verify_api_key

 logger = logging.getLogger(__name__)
 router = APIRouter()
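Note: prediction routes now authenticate through the shared verify_api_key dependency, so clients must send an X-API-Key header. An illustrative call (the /batch_predict path comes from the router comments below; the payload shape is an assumption, not from this commit):

    import requests

    resp = requests.post(
        "http://localhost:7860/batch_predict",
        headers={"X-API-Key": "<one of the comma-separated API_KEY values>"},
        json={"images": ["<base64-encoded image>"]},  # hypothetical payload shape
    )
    print(resp.status_code)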
api/router.py
CHANGED
@@ -1,149 +1,46 @@
-from fastapi import APIRouter, Request, HTTPException
 import logging
 import os
-
-from
-
+from typing import Any, Optional, Dict, List
+from fastapi import APIRouter, HTTPException, Depends, Header, Request, Response
+
+# Import the dependencies
+from api.dependencies import verify_api_key, verify_management_api_key
+# Import the sub-routers
+from api import prediction, management
+from config.settings import MANAGEMENT_API_KEY
 from db.models import fetch_model_by_id
 from models.loader import model_pipelines, _load_single_model_pipeline, get_model
-from models.schemas import PredictionRequest, PredictionResponse

 logger = logging.getLogger(__name__)

-#
+# Main router
 router = APIRouter()

-# Middleware
-
-
+# --- Middleware (left here for now) ---
+@router.middleware("http")
+async def api_key_middleware(request: Request, call_next):
     # Skip if we're in debug mode or during startup
     if os.environ.get("DEBUG") == "1":
         return await call_next(request)

-    # For Hugging Face Space routes
-    #
-
-
-        or request.url.path.startswith('/queue/'):
-        return await call_next(request)
-
-    if request.url.path == "/" and "logs=container" in request.url.query:
+    # For Hugging Face Space, Gradio, etc. routes
+    # (Simplified list, adjust if necessary)
+    public_paths = ['/assets/', '/file=', '/queue/', '/startup-logs', '/config', '/info', '/gradio', '/favicon.ico', '/']
+    if any(request.url.path.startswith(p) for p in public_paths):
         return await call_next(request)

-
-
-
-
+    # Optional: global API key check (if not handled by Depends)
+    # Disabled by default, relying on the Depends attached to the endpoints
+    # try:
+    #     api_key = request.headers.get("X-API-Key")
+    #     await verify_api_key(api_key)  # Will not work directly here
+    # except HTTPException as e:
+    #     return Response(content=e.detail, status_code=e.status_code)
+
     response = await call_next(request)
     return response

-#
-
-
-
-        logger.warning("MANAGEMENT_API_KEY is not set. Management endpoints are unsecured!")
-        # Decide whether to block or to allow when no key is defined
-        # For security, it is better to block by default
-        raise HTTPException(status_code=500, detail="Management API key not configured on server.")
-    if x_api_key != MANAGEMENT_API_KEY:
-        logger.warning(f"Invalid or missing API key attempt for management endpoint.")
-        raise HTTPException(status_code=401, detail="Invalid or missing API Key")
-    return True  # Key is valid
-
-# Include the routes from the other modules
-router.include_router(prediction.router, tags=["Prediction"])
-
-# New management endpoint
-@router.post(
-    "/manage/load_model/{model_db_id}",
-    summary="Load a specific model into memory",
-    dependencies=[Depends(verify_management_api_key)]  # Secures the endpoint
-)
-async def load_single_model(model_db_id: Any):  # The ID can be int or str
-    """Load a specific model into memory using its database ID."""
-    logger.info(f"Received request to load model with DB ID: {model_db_id}")
-
-    # 1. Check whether the model is already loaded
-    if model_db_id in model_pipelines:
-        logger.info(f"Model ID {model_db_id} is already loaded.")
-        return {"status": "success", "message": f"Model ID {model_db_id} is already loaded."}
-
-    # 2. Fetch the model information from the DB
-    try:
-        model_data = await fetch_model_by_id(model_db_id)
-        if not model_data:
-            logger.error(f"Model ID {model_db_id} not found in database.")
-            raise HTTPException(status_code=404, detail=f"Model ID {model_db_id} not found in database.")
-    except Exception as e:
-        logger.exception(f"Database error fetching model ID {model_db_id}: {e}")
-        raise HTTPException(status_code=500, detail=f"Database error checking model ID {model_db_id}.")
-
-    # 3. Load the model
-    try:
-        logger.info(f"Attempting to load model ID {model_db_id} ('{model_data.get('name', 'N/A')}') into memory...")
-        pipeline = await _load_single_model_pipeline(model_data)
-
-        # 4. Add it to the dictionary of loaded models
-        model_pipelines[model_db_id] = pipeline
-        logger.info(f"Successfully loaded and added model ID {model_db_id} to running pipelines.")
-        return {"status": "success", "message": f"Model ID {model_db_id} loaded successfully."}
-
-    except Exception as e:
-        logger.exception(f"Failed to load model ID {model_db_id}: {e}")
-        # Do not leave a potentially corrupted pipeline in the dictionary
-        if model_db_id in model_pipelines:
-            del model_pipelines[model_db_id]
-        raise HTTPException(status_code=500, detail=f"Failed to load model ID {model_db_id}. Check server logs for details.")
-
-@router.post(
-    "/manage/update_model/{model_db_id}",
-    summary="Reload/Update a specific model already in memory",
-    dependencies=[Depends(verify_management_api_key)]  # Secures the endpoint
-)
-async def update_single_model(model_db_id: Any):
-    """Re-downloads and updates a specific model that is already loaded in memory."""
-    logger.info(f"Received request to update model with DB ID: {model_db_id}")
-
-    # 1. Check whether the model is currently loaded
-    if model_db_id not in model_pipelines:
-        logger.error(f"Attempted to update model ID {model_db_id}, but it is not loaded.")
-        raise HTTPException(
-            status_code=404,
-            detail=f"Model ID {model_db_id} is not currently loaded. Use load_model first."
-        )
-
-    # 2. Fetch the model information from the DB (to make sure it is up to date if needed)
-    try:
-        model_data = await fetch_model_by_id(model_db_id)
-        if not model_data:
-            # This would indicate an inconsistency: the model is in model_pipelines but not in the DB
-            logger.error(f"Inconsistency: Model ID {model_db_id} loaded but not found in database during update.")
-            raise HTTPException(status_code=500, detail=f"Inconsistency: Model ID {model_db_id} not found in database.")
-    except Exception as e:
-        logger.exception(f"Database error fetching model ID {model_db_id} during update: {e}")
-        raise HTTPException(status_code=500, detail=f"Database error checking model ID {model_db_id} for update.")
-
-    # 3. Reload the model
-    try:
-        logger.info(f"Attempting to reload model ID {model_db_id} ('{model_data.get('name', 'N/A')}') from source...")
-        # Remove the old model from memory before loading the new one, to free GPU/CPU resources if possible
-        # Warning: this may cause a brief unavailability of the model during the reload.
-        # An alternative strategy would be to load the new one first, then swap.
-        if model_db_id in model_pipelines:
-            del model_pipelines[model_db_id]
-            # Potentially force GPU memory cleanup here if needed (torch.cuda.empty_cache() - use with caution)
-            logger.debug(f"Removed old instance of model ID {model_db_id} from memory before update.")
-
-        pipeline = await _load_single_model_pipeline(model_data)
-
-        # 4. Update the dictionary with the new pipeline
-        model_pipelines[model_db_id] = pipeline
-        logger.info(f"Successfully updated model ID {model_db_id} in running pipelines.")
-        return {"status": "success", "message": f"Model ID {model_db_id} updated successfully."}
-
-    except Exception as e:
-        logger.exception(f"Failed to reload model ID {model_db_id}: {e}")
-        # If the reload fails, the old model has already been removed.
-        # Either try to reload the old one, or leave it unloaded.
-        # For now, we leave it unloaded and report the error.
-        raise HTTPException(status_code=500, detail=f"Failed to reload model ID {model_db_id}. Model is now unloaded. Check server logs.")
+# --- Include the sub-routers ---
+router.include_router(prediction.router)  # Prediction routes (/batch_predict)
+router.include_router(management.router)  # Management routes (/manage/load_model, /manage/update_model)
+router.include_router(management.public_router)  # Public info routes (/models, /models/{id})
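Note: FastAPI's APIRouter does not itself expose a .middleware() decorator, so the @router.middleware("http") registration above may fail at import time; HTTP middleware is normally attached to the application object. A minimal sketch of the equivalent app-level registration (assuming the FastAPI app instance created in main.py):

    # Sketch: register the same middleware on the FastAPI app instead of the router.
    from fastapi import FastAPI, Request

    app = FastAPI()

    @app.middleware("http")
    async def api_key_middleware(request: Request, call_next):
        # Same body as in api/router.py
        return await call_next(request)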
main.py
CHANGED
@@ -3,7 +3,8 @@ from fastapi import FastAPI
 import gradio as gr
 from gradio.routes import mount_gradio_app

-from api.router import router
+from api.router import router
+from api.dependencies import verify_api_key
 from db.models import fetch_models_for_group
 from models.loader import load_models, model_pipelines
 from config.settings import RESOURCE_GROUP
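Note: for completeness, a minimal sketch of the config.settings values these modules rely on. The names come from the imports above, and dependencies.py's error message confirms API_KEY is read from the environment; the rest of the file is an assumption:

    # config/settings.py (sketch)
    import os

    API_KEY = os.environ.get("API_KEY")                        # may hold several comma-separated client keys
    MANAGEMENT_API_KEY = os.environ.get("MANAGEMENT_API_KEY")  # single key guarding the /manage endpoints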