File size: 8,283 Bytes
1ef2263
5df1f2d
 
68a1cf9
77f06e1
b26a6dc
f85309c
 
b26a6dc
5df1f2d
38a3c61
e109700
 
 
 
 
 
f73ccd7
5df1f2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e109700
 
 
 
 
5df1f2d
e109700
f73ccd7
5df1f2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38a3c61
e109700
 
38a3c61
b26a6dc
 
 
 
 
 
 
 
 
 
 
 
 
 
df3bf97
b26a6dc
dfb1c84
 
 
b26a6dc
df3bf97
 
 
 
 
 
 
1445bc9
 
 
 
 
df3bf97
 
 
b26a6dc
df3bf97
b26a6dc
df3bf97
b26a6dc
df3bf97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b26a6dc
 
 
 
 
 
 
 
38a3c61
e109700
1445bc9
b26a6dc
e109700
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import logging
import os
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
import gradio as gr
from gradio.routes import mount_gradio_app
from api.router import router    
from db.models import fetch_models_for_group
from models.loader import load_models, model_pipelines
from config.settings import RESOURCE_GROUP, DATABASE_URL

# Configuration de base des logs
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# --- Événements Startup/Shutdown (Lifespan Manager) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Code exécuté au démarrage
    logger.info("Starting up API...")
    logger.info(f"Attempting to connect to database: {DATABASE_URL[:DATABASE_URL.find('@')] + '@...'}") # Masquer les crédentiels
    try:
        models_to_load = await fetch_models_for_group(RESOURCE_GROUP)
        if models_to_load:
            await load_models(models_to_load)
            logger.info("Initial models loaded successfully.")
        else:
            logger.warning(f"No models found for resource group '{RESOURCE_GROUP}'. API starting without preloaded models.")
    except Exception as e:
        logger.exception(f"Failed to load initial models during startup: {e}")
        # Décider s'il faut empêcher le démarrage de l'API ou continuer sans modèles
        # raise RuntimeError("Could not load initial models, API startup aborted.") from e
    
    yield
    # Code exécuté à l'arrêt
    logger.info("Shutting down API...")
    # Ajouter ici le code de nettoyage si nécessaire (ex: fermer connexions persistantes)

# Créer l'application FastAPI
app = FastAPI(
    title="Tamis AI Inference API",
    description="API pour l'inférence des modèles de classification d'objets",
    version="0.1.0",
    lifespan=lifespan # Correction: Utilisation de la fonction lifespan définie ci-dessus
)

# Ajout du Middleware ici
@app.middleware("http")
async def api_key_middleware(request: Request, call_next):
    """Middleware pour vérifier la clé API et exempter certaines routes."""
    # Skip if we're in debug mode or during startup
    if os.environ.get("DEBUG") == "1":
        logger.debug("DEBUG mode active, skipping API key check.")
        return await call_next(request)

    # Liste des chemins publics ou internes à exempter de la vérification de clé
    public_paths = [
        '/docs', '/openapi.json', # Documentation Swagger/OpenAPI
        '/health',                # Health check endpoint
        '/',                      # Racine (Interface Gradio)
        '/assets/',               # Assets Gradio
        '/file=',                 # Fichiers Gradio
        '/queue/',                # Queue Gradio
        '/startup-logs',          # Logs HF Space
        '/config',                # Config Gradio/HF
        '/info',                  # Info Gradio/HF
        '/gradio',                # Potentiel préfixe Gradio
        '/favicon.ico'            # Favicon
    ]
    
    # Vérifie si le chemin commence par un des préfixes publics
    is_public = any(request.url.path == p or (p.endswith('/') and request.url.path.startswith(p)) for p in public_paths)

    if is_public:
        logger.debug(f"Public path accessed: {request.url.path}, skipping API key check.")
        response = await call_next(request)
        return response
    else:
        # Pour toutes les autres routes, la vérification se fait via Depends() sur l'endpoint lui-même.
        # Ce middleware ne fait donc plus de vérification active ici,
        # il sert juste à logger et potentiellement à exempter certaines routes si besoin.
        logger.debug(f"Protected path accessed: {request.url.path}. API key verification delegated to endpoint.")
        response = await call_next(request)
        return response

# Inclure les routes
app.include_router(router)

async def init_models():
    """Charger les modèles au démarrage pour Gradio et FastAPI."""
    logger.info("Initializing models for Gradio and FastAPI...")
    try:
        models_data = await fetch_models_for_group(RESOURCE_GROUP)
        await load_models(models_data)
        logger.info("Models loaded successfully.")
    except Exception as e:
        logger.error(f"Failed to initialize models: {e}", exc_info=True)
        # Decide if the app should fail to start or continue without models
        # raise RuntimeError("Model initialization failed.")
        # For now, let's log and continue, Gradio will show an empty list
        pass

# Définir les fonctions pour Gradio qui récupèrent les modèles chargés
def get_loaded_models_list():
    """Retourne la liste des métadonnées des modèles actuellement chargés."""
    # Extraire les métadonnées de chaque entrée dans model_pipelines
    return [item['metadata'] for item in model_pipelines.values()]

def format_model_info(metadata_list):
    """Formate les informations des modèles pour un affichage plus convivial."""
    if not metadata_list:
        return "Aucun modèle chargé actuellement."
    
    # Créer une version formatée des informations des modèles
    formatted_info = ""
    for model in metadata_list:
        # Utiliser le nom du fichier comme titre principal
        formatted_info += f"### Modèle : {model.get('hf_filename', 'N/A')}\n"
        # Mettre l'ID et la dernière mise à jour sur une même ligne
        formatted_info += f"**ID:** {model.get('model_id', 'N/A')} | **Dernière mise à jour:** {model.get('updated_at', 'N/A')}\n\n"
    
    return formatted_info

# Créer l'interface Gradio
gradio_app = gr.Blocks(title="Tamis AI - Modèles Chargés", theme=gr.themes.Soft())
with gradio_app:
    gr.Markdown("# 🤖 Tamis AI - Interface d'administration")
    gr.Markdown("## Modèles actuellement chargés dans l'API")
    
    with gr.Row():
        with gr.Column(scale=2):
            # Visualisation des modèles avec des cartes
            with gr.Tab("Vue détaillée"):
                markdown_output = gr.Markdown(value="Chargement des modèles...", elem_id="model_details")
            
            # Affichage en tableau
            with gr.Tab("Vue tableau"):
                model_table = gr.Dataframe(
                    headers=["ID", "Fichier", "Dernière mise à jour"],
                    datatype=["str", "str", "str"],
                    elem_id="model_table"
                )
                
            # Vue JSON (pour référence)
            with gr.Tab("Vue JSON (debug)"):
                json_output = gr.JSON(label="Données brutes")
        
        with gr.Column(scale=1):
            refresh_btn = gr.Button("Rafraîchir", variant="primary")
            status = gr.Textbox(label="Statut", value="Prêt", interactive=False)
    
    # Fonction pour mettre à jour tous les composants d'affichage
    def update_all_displays():
        models = get_loaded_models_list()
        formatted_text = format_model_info(models)
        
        # Préparer les données pour le tableau
        table_data = []
        for model in models:
            table_data.append([
                model.get('model_id', 'N/A'),
                model.get('hf_filename', 'N/A'),
                model.get('updated_at', 'N/A')
            ])
        
        return formatted_text, table_data, models, "Modèles mis à jour"
    
    # Connecter les événements
    refresh_btn.click(
        fn=update_all_displays,
        outputs=[markdown_output, model_table, json_output, status]
    )
    
    # Initialiser les affichages au chargement
    gradio_app.load(fn=update_all_displays, outputs=[markdown_output, model_table, json_output, status])

# Monter l'application Gradio à la racine dans FastAPI
app = mount_gradio_app(
    app, gradio_app, path="/"
)

# Event startup to load models (ensure it runs *after* Gradio is mounted if needed)
# We call init_models inside startup
@app.on_event("startup")
async def startup():
    """Initialiser l'API : charger les modèles depuis la base de données."""  
    await init_models() # Call the consolidated init function

@app.get("/health")
async def health_check():
    """Point d'entrée pour vérifier l'état de l'API."""
    return {"status": "healthy"}