File size: 5,522 Bytes
deb3471
 
 
bdc22f6
05e2837
deb3471
d352fe2
 
 
d6a1cfd
d352fe2
 
 
 
 
 
 
 
 
deb3471
 
 
05e2837
 
 
 
 
 
 
 
 
d352fe2
 
 
 
aa37823
 
 
 
deb3471
d352fe2
aa37823
deb3471
 
 
d352fe2
deb3471
 
 
 
d352fe2
933f6ff
 
 
 
 
 
05e2837
 
 
 
 
 
 
deb3471
 
 
d352fe2
deb3471
08e0ab7
 
e0174a0
deb3471
d352fe2
deb3471
 
 
 
 
 
 
 
 
 
 
08e0ab7
 
deb3471
 
 
d352fe2
deb3471
 
08e0ab7
 
 
deb3471
 
08e0ab7
deb3471
 
08e0ab7
deb3471
d352fe2
deb3471
d6a1cfd
 
 
 
 
 
 
 
 
 
 
 
 
d352fe2
deb3471
d6a1cfd
 
 
 
 
 
 
 
 
 
 
 
deb3471
d352fe2
deb3471
 
 
aa37823
 
 
05e2837
aa37823
bdc22f6
 
05e2837
 
 
 
 
 
 
bdc22f6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from fastapi import FastAPI, Request
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from optimum.neuron import utils
import logging
import sys
import os
import httpx
# Logging: INFO and above go to stdout with a timestamped format
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(
    level=logging.INFO,
    format=_LOG_FORMAT,
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

app = FastAPI()

# CORS: every origin, method and header is accepted.
# NOTE(review): the FastAPI CORS documentation says wildcard values cannot be
# combined with allow_credentials=True -- confirm whether credentials are
# actually required here before tightening either setting.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Both asset directories live next to this file; resolve that location once
_APP_DIR = os.path.dirname(os.path.abspath(__file__))

# Absolute path of the static assets directory
static_dir = os.path.join(_APP_DIR, "static")
logger.info(f"Static directory path: {static_dir}")

# Absolute path of the Jinja2 templates directory
templates_dir = os.path.join(_APP_DIR, "templates")
logger.info(f"Templates directory path: {templates_dir}")

# Serve static assets under /static and load templates from templates_dir
app.mount("/static", StaticFiles(directory=static_dir), name="static")
templates = Jinja2Templates(directory=templates_dir)

@app.get("/health")
async def health_check():
    """Report that the service is up.

    Returns:
        A dict with the single key ``"status"`` set to ``"healthy"``.
    """
    logger.info("Health check endpoint called")
    status_payload = {"status": "healthy"}
    return status_payload

@app.get("/")
async def home(request: Request):
    """Render ``index.html`` with the base URL the page should use.

    Args:
        request: The incoming request; its ``base_url`` is handed to the
            template, and the request itself is part of the template context.

    Returns:
        The rendered ``index.html`` template response.
    """
    logger.info("Home page requested")

    base_url = str(request.base_url)
    # A set SPACE_ID is taken to mean the app runs in Spaces, where the
    # page must be given an https base URL; elsewhere the request's own
    # scheme is kept.
    if os.getenv("SPACE_ID") is not None:
        base_url = base_url.replace("http://", "https://")

    context = {"request": request, "base_url": base_url}
    return templates.TemplateResponse("index.html", context)

@app.get("/api/models")
async def get_model_list():
    """Return the de-duplicated list of models reported by the hub cache.

    Each entry of ``utils.get_hub_cached_models(mode="inference")`` is
    unpacked as ``(architecture, org, model_id)``; models are identified by
    ``"{org}/{model_id}"`` and only the first occurrence of each is kept.

    Returns:
        A JSON array of ``{"id", "name", "type"}`` objects on success, or a
        500 response with ``{"error", "type"}`` if anything raises.
    """
    # Imported locally so this handler carries its own dependency; fastapi is
    # already imported by this module.
    from fastapi.concurrency import run_in_threadpool

    logger.info("Fetching model list")
    try:
        # Debug aid: only logs whether a token is configured, never its value
        logger.info(f"HF_TOKEN present: {bool(os.getenv('HF_TOKEN'))}")

        # get_hub_cached_models is a synchronous call; running it directly in
        # this coroutine would stall the event loop (and every other request)
        # until it returns, so it is pushed to a worker thread.
        model_list = await run_in_threadpool(
            utils.get_hub_cached_models, mode="inference"
        )
        logger.info(f"Found {len(model_list)} models")

        models = []
        seen_models = set()

        for architecture, org, model_id in model_list:
            full_model_id = f"{org}/{model_id}"
            if full_model_id in seen_models:
                continue
            seen_models.add(full_model_id)
            models.append({
                "id": full_model_id,
                "name": full_model_id,
                "type": architecture,
            })

        logger.info(f"Returning {len(models)} unique models")
        return JSONResponse(content=models)
    except Exception as e:
        # Boundary handler: log the message together with the traceback and
        # turn the failure into a JSON 500 instead of an unhandled error.
        logger.exception(f"Error fetching models: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"error": str(e), "type": type(e).__name__}
        )

@app.get("/api/models/{model_id:path}")
async def get_model_info_endpoint(model_id: str):
    """Fetch the cached configurations for one model from the Hugging Face API.

    Args:
        model_id: Model identifier taken from the request path; it may contain
            slashes (e.g. ``org/name``).

    Returns:
        ``{"configurations": [...]}`` built from the ``cached_configs`` field
        of the upstream response (an empty list when the field is absent).
        Error responses are JSON objects with an ``"error"`` key:
        400 for an id containing ``.``/``..`` path segments, 404 when the
        upstream lookup answers 404, 504 on timeout, and 500 otherwise.
    """
    from urllib.parse import quote

    logger.info(f"Fetching configurations for model: {model_id}")

    # model_id is caller-controlled and becomes part of the upstream URL.
    # Dot segments could climb out of the lookup endpoint, so refuse them.
    if any(segment in (".", "..") for segment in model_id.split("/")):
        logger.warning(f"Rejected model id with dot segments: {model_id}")
        return JSONResponse(
            status_code=400,
            content={"error": "Invalid model id"}
        )

    try:
        # Define the base URL for the HuggingFace API
        base_url = "https://huggingface.co/api/integrations/aws/v1/lookup"
        # Percent-encode everything except "/" so characters such as "?" or
        # "#" stay inside the path instead of starting a query or fragment.
        encoded_model_id = quote(model_id, safe="/")
        api_url = f"{base_url}/{encoded_model_id}"

        # 15s default for each timeout phase, 5s to establish the connection
        timeout = httpx.Timeout(15.0, connect=5.0)
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.get(api_url)
            response.raise_for_status()

            data = response.json()
            configs = data.get("cached_configs", [])

        logger.info(f"Found {len(configs)} configurations for model {model_id}")
        return JSONResponse(content={"configurations": configs})
    except httpx.TimeoutException as e:
        logger.error(f"Timeout while fetching configurations for model {model_id}: {str(e)}", exc_info=True)
        return JSONResponse(
            status_code=504,  # Gateway Timeout
            content={"error": "Request timed out while fetching model configurations"}
        )
    except httpx.HTTPStatusError as e:
        # Raised by raise_for_status(); must be handled before the broader
        # httpx.HTTPError. An unknown model is a client-side "not found",
        # not a server failure, so an upstream 404 is passed through.
        logger.error(f"HTTP error fetching configurations for model {model_id}: {str(e)}", exc_info=True)
        status_code = 404 if e.response.status_code == 404 else 500
        return JSONResponse(
            status_code=status_code,
            content={"error": f"Failed to fetch model configurations: {str(e)}"}
        )
    except httpx.HTTPError as e:
        logger.error(f"HTTP error fetching configurations for model {model_id}: {str(e)}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"error": f"Failed to fetch model configurations: {str(e)}"}
        )
    except Exception as e:
        logger.error(f"Error fetching configurations for model {model_id}: {str(e)}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"error": str(e)}
        )

@app.get("/static/{path:path}")
async def static_files(path: str, request: Request):
    """Serve a single file from the static directory.

    NOTE(review): a mount for "/static" is registered earlier in this module;
    presumably it matches these URLs first, which would leave this route
    unused -- confirm before relying on the content-type overrides below.

    Args:
        path: File path relative to the static directory, from the URL.
        request: The incoming request (unused; kept for the route signature).

    Returns:
        The file as a ``FileResponse`` (with an explicit content type for
        ``.css`` and ``.js``), or a JSON 404 when the path is not a regular
        file located inside the static directory.
    """
    logger.info(f"Static file requested: {path}")

    # Resolve symlinks and ".." before checking containment: a plain join lets
    # "../" segments (or an absolute path) address files outside static_dir.
    static_root = os.path.realpath(static_dir)
    file_path = os.path.realpath(os.path.join(static_root, path))
    inside_static_root = file_path.startswith(static_root + os.sep)

    # isfile (not exists) so that directories are answered with 404 too
    if not inside_static_root or not os.path.isfile(file_path):
        return JSONResponse(status_code=404, content={"error": "File not found"})

    response = FileResponse(file_path)
    # Ensure proper content type
    if path.endswith('.css'):
        response.headers["content-type"] = "text/css"
    elif path.endswith('.js'):
        response.headers["content-type"] = "application/javascript"
    return response