ParthSadaria committed on
Commit 1f0a039 · verified · 1 Parent(s): 20b82c3

Update main.py

Files changed (1)
  1. main.py +757 -960
main.py CHANGED
@@ -1,1056 +1,886 @@
  import os
  import re
- import json
- import datetime
- import time
- import asyncio
- import logging
- from pathlib import Path
- from functools import lru_cache
- from typing import Optional, Dict, List, Any, Generator, Set
- from concurrent.futures import ThreadPoolExecutor
-
- # Third-party libraries (ensure these are in requirements.txt)
  from dotenv import load_dotenv
- from fastapi import FastAPI, HTTPException, Request, Depends, Security, Response
  from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse, FileResponse
  from fastapi.security import APIKeyHeader
  from pydantic import BaseModel
  import httpx
- import uvloop # Use uvloop for performance
  from fastapi.middleware.gzip import GZipMiddleware
  from starlette.middleware.cors import CORSMiddleware
- import cloudscraper # For bypassing Cloudflare, potentially unreliable
- import requests # For synchronous requests like in /dynamo
-
- # HF Space Note: Ensure usage_tracker.py is in your repository
- try:
-     from usage_tracker import UsageTracker
-     usage_tracker = UsageTracker()
- except ImportError:
-     print("Warning: usage_tracker.py not found. Usage tracking will be disabled.")
-     # Create a dummy tracker if the file is missing
-     class DummyUsageTracker:
-         def record_request(self, *args, **kwargs): pass
-         def get_usage_summary(self, *args, **kwargs): return {}
-         def save_data(self, *args, **kwargs): pass
-     usage_tracker = DummyUsageTracker()
-
-
- # --- Configuration & Setup ---
-
- # HF Space Note: uvloop can improve performance in I/O bound tasks common in web apps.
  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

- # HF Space Note: Adjust max_workers based on your HF Space resources (CPU).
- # Higher tiers allow more workers. Start lower (e.g., 4) for free tier.
- executor = ThreadPoolExecutor(max_workers=8)

- # HF Space Note: load_dotenv() is useful for local dev but HF Spaces use Secrets.
- # os.getenv will automatically pick up secrets set in the HF Space settings.
  load_dotenv()

- # Logging setup
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
- logger = logging.getLogger(__name__)
-
- # API key security
  api_key_header = APIKeyHeader(name="Authorization", auto_error=False)

- # --- FastAPI App Initialization ---

- app = FastAPI(
-     title="LokiAI API",
-     description="API Proxy for various AI models with usage tracking and streaming.",
-     version="1.0.0"
- )

- # Middleware
- app.add_middleware(GZipMiddleware, minimum_size=1000) # Compress large responses
  app.add_middleware(
-     CORSMiddleware, # Allow cross-origin requests (useful for web playgrounds)
-     allow_origins=["*"], # Or restrict to specific origins
      allow_credentials=True,
      allow_methods=["*"],
      allow_headers=["*"],
  )

- # --- Environment Variables & Model Config ---
-
- @lru_cache(maxsize=1) # Cache environment variables
- def get_env_vars() -> Dict[str, Any]:
-     """Loads and returns essential environment variables."""
-     # HF Space Note: Set these as Secrets in your Hugging Face Space settings.
      return {
-         'api_keys': set(filter(None, os.getenv('API_KEYS', '').split(','))), # Use set for faster lookup
          'secret_api_endpoint': os.getenv('SECRET_API_ENDPOINT'),
          'secret_api_endpoint_2': os.getenv('SECRET_API_ENDPOINT_2'),
-         'secret_api_endpoint_3': os.getenv('SECRET_API_ENDPOINT_3'), # Search endpoint
-         'secret_api_endpoint_4': os.getenv('SECRET_API_ENDPOINT_4', "https://text.pollinations.ai/openai"), # Pollinations
-         'secret_api_endpoint_5': os.getenv('SECRET_API_ENDPOINT_5'), # Claude 3 endpoint
-         'mistral_api': os.getenv('MISTRAL_API', "https://api.mistral.ai"),
          'mistral_key': os.getenv('MISTRAL_KEY'),
-         'new_img_endpoint': os.getenv('NEW_IMG'), # Image generation endpoint
-         'hf_space_url': os.getenv('HF_SPACE_URL', 'https://your-space-name.hf.space') # HF Space Note: Set this! Used for Referer/Origin checks.
      }

- # Model sets for fast lookups
- # HF Space Note: Consider moving these large sets to a separate config file (e.g., config.py or models_config.json)
- # for better organization if they grow larger.
- mistral_models: Set[str] = {
-     "mistral-large-latest", "pixtral-large-latest", "mistral-moderation-latest",
-     "ministral-3b-latest", "ministral-8b-latest", "open-mistral-nemo",
-     "mistral-small-latest", "mistral-saba-latest", "codestral-latest"
  }

- pollinations_models: Set[str] = {
-     "openai", "openai-large", "openai-xlarge", "openai-reasoning", "qwen-coder",
-     "llama", "mistral", "searchgpt", "deepseek", "claude-hybridspace",
-     "deepseek-r1", "deepseek-reasoner", "llamalight", "gemini", "gemini-thinking",
-     "hormoz", "phi", "phi-mini", "openai-audio", "llama-scaleway"
  }

- alternate_models: Set[str] = {
-     "gpt-4o", "deepseek-v3", "llama-3.1-8b-instruct", "llama-3.1-sonar-small-128k-online",
-     "deepseek-r1-uncensored", "tinyswallow1.5b", "andy-3.5", "o3-mini-low",
-     "hermes-3-llama-3.2-3b", "creitin-r1", "fluffy.1-chat", "plutotext-1-text",
-     "command-a", "claude-3-7-sonnet-20250219", "plutogpt-3.5-turbo"
  }

- claude_3_models: Set[str] = {
-     "claude-3-7-sonnet", "claude-3-7-sonnet-thinking", "claude 3.5 haiku",
-     "claude 3.5 sonnet", "claude 3.5 haiku", "o3-mini-medium", "o3-mini-high",
-     "grok-3", "grok-3-thinking", "grok 2"
  }

- supported_image_models: Set[str] = {
-     "Flux Pro Ultra", "grok-2-aurora", "Flux Pro", "Flux Pro Ultra Raw", "Flux Dev",
-     "Flux Schnell", "stable-diffusion-3-large-turbo", "Flux Realism",
-     "stable-diffusion-ultra", "dall-e-3", "sdxl-lightning-4step"
  }

- # --- Pydantic Models ---
-
- class Message(BaseModel):
-     role: str
-     content: Any # Allow content to be string or potentially list for multimodal models

  class Payload(BaseModel):
      model: str
-     messages: List[Message]
      stream: bool = False
-     # Add other potential OpenAI compatible parameters with defaults
-     max_tokens: Optional[int] = None
-     temperature: Optional[float] = None
-     top_p: Optional[float] = None
-     # ... add others as needed

  class ImageGenerationPayload(BaseModel):
      model: str
      prompt: str
-     size: Optional[str] = "1024x1024" # Default size, make optional if API allows
-     n: Optional[int] = 1 # Number of images, OpenAI uses 'n'
-     # HF Space Note: Ensure these parameter names match the target NEW_IMG endpoint API
-     # Renaming from 'number' to 'n' and 'size' type hint correction.

- # --- Global State & Clients ---

- server_status: bool = True # For maintenance mode
- available_model_ids: List[str] = [] # Loaded at startup

- # HF Space Note: Reusable HTTP client with connection pooling is crucial for performance.
- # Adjust limits based on expected load and HF Space resources.
  @lru_cache(maxsize=1)
- def get_async_client() -> httpx.AsyncClient:
-     """Returns a cached instance of httpx.AsyncClient."""
-     # HF Space Note: Timeouts are important to prevent hanging requests.
-     # Keepalive connections reduce handshake overhead.
-     timeout = httpx.Timeout(30.0, connect=10.0) # 30s total, 10s connect
-     limits = httpx.Limits(max_keepalive_connections=20, max_connections=100)
-     return httpx.AsyncClient(timeout=timeout, limits=limits, follow_redirects=True)
-
- # HF Space Note: cloudscraper pool. Be mindful of potential rate limits or blocks.
- # Consider alternatives if this becomes unreliable.
- scraper_pool: List[cloudscraper.CloudScraper] = []
- MAX_SCRAPERS = 10 # Reduced pool size for potentially lower resource usage
-
- def get_scraper() -> cloudscraper.CloudScraper:
-     """Gets a cloudscraper instance from the pool."""
      if not scraper_pool:
-         logger.info(f"Initializing {MAX_SCRAPERS} cloudscraper instances...")
          for _ in range(MAX_SCRAPERS):
-             # HF Space Note: Scraper creation can be slow, doing it upfront is good.
              scraper_pool.append(cloudscraper.create_scraper())
-         logger.info("Cloudscraper pool initialized.")
-     # Simple round-robin selection
-     return scraper_pool[int(time.monotonic() * 1000) % MAX_SCRAPERS]

- # --- Security & Authentication ---

  async def verify_api_key(
      request: Request,
-     api_key: Optional[str] = Security(api_key_header)
  ) -> bool:
-     """Verifies the provided API key against environment variables."""
-     env_vars = get_env_vars()
-     valid_api_keys = env_vars.get('api_keys', set())
-     hf_space_url = env_vars.get('hf_space_url', '')
-
-     # Allow bypass if the referer is from the known HF Space playground URLs
-     # HF Space Note: Make HF_SPACE_URL a secret for flexibility.
      referer = request.headers.get("referer", "")
-     if hf_space_url and referer.startswith((f"{hf_space_url}/playground", f"{hf_space_url}/image-playground")):
-         logger.debug(f"API Key check bypassed for referer: {referer}")
          return True
-
      if not api_key:
-         logger.warning("API Key missing.")
-         raise HTTPException(status_code=403, detail="Not authenticated: No API key provided")
-
-     # Clean 'Bearer ' prefix if present
      if api_key.startswith('Bearer '):
-         api_key = api_key[7:]
-
-     if not valid_api_keys:
-         logger.error("API keys are not configured on the server (API_KEYS secret missing?).")
-         raise HTTPException(status_code=500, detail="Server configuration error: API keys not set")
-
-     if api_key not in valid_api_keys:
-         logger.warning(f"Invalid API key received: {api_key[:4]}...") # Log prefix only
-         raise HTTPException(status_code=403, detail="Not authenticated: Invalid API key")
-
-     logger.debug("API Key verified successfully.")
      return True

- # --- Model & File Loading ---
-
  @lru_cache(maxsize=1)
- def load_models_data() -> List[Dict]:
-     """Loads model data from models.json."""
-     # HF Space Note: Ensure models.json is in the root of your HF Space repo.
-     models_file = Path(__file__).parent / 'models.json'
-     if not models_file.is_file():
-         logger.error("models.json not found!")
-         return []
      try:
-         with open(models_file, 'r') as f:
              return json.load(f)
      except (FileNotFoundError, json.JSONDecodeError) as e:
-         logger.error(f"Error loading models.json: {e}")
          return []

- async def get_models() -> List[Dict]:
-     """Async wrapper to get models data."""
      models_data = load_models_data()
      if not models_data:
          raise HTTPException(status_code=500, detail="Error loading available models")
      return models_data

- # --- Static File Serving ---

- # HF Space Note: Cache frequently accessed static files in memory.
  @lru_cache(maxsize=10)
- def read_static_file(file_path: str) -> Optional[str]:
-     """Reads a static file, caching the result."""
-     full_path = Path(__file__).parent / file_path
-     if not full_path.is_file():
-         logger.warning(f"Static file not found: {file_path}")
-         return None
      try:
-         with open(full_path, "r", encoding="utf-8") as file:
              return file.read()
-     except Exception as e:
-         logger.error(f"Error reading static file {file_path}: {e}")
          return None

- async def serve_static_html(file_path: str) -> HTMLResponse:
-     """Serves a static HTML file."""
-     content = read_static_file(file_path)
-     if content is None:
-         return HTMLResponse(content=f"<h1>Error: {file_path} not found</h1>", status_code=404)
-     return HTMLResponse(content=content)
-
- # --- API Endpoints ---
-
- # Basic Routes & Static Files
- @app.get("/favicon.ico", include_in_schema=False)
  async def favicon():
      favicon_path = Path(__file__).parent / "favicon.ico"
-     if favicon_path.is_file():
-         return FileResponse(favicon_path, media_type="image/vnd.microsoft.icon")
-     raise HTTPException(status_code=404, detail="favicon.ico not found")
-
- @app.get("/banner.jpg", include_in_schema=False)
- async def banner():
-     banner_path = Path(__file__).parent / "banner.jpg"
-     if banner_path.is_file():
-         return FileResponse(banner_path, media_type="image/jpeg") # Assuming JPEG
-     raise HTTPException(status_code=404, detail="banner.jpg not found")
-
- @app.get("/ping", tags=["Utility"])
- async def ping():
-     """Simple health check endpoint."""
-     return {"message": "pong"}

- @app.get("/", response_class=HTMLResponse, tags=["Frontend"])
- async def root():
-     """Serves the main index HTML page."""
-     return await serve_static_html("index.html")
-
- @app.get("/script.js", response_class=Response, tags=["Frontend"], include_in_schema=False)
- async def script_js():
-     content = read_static_file("script.js")
-     if content is None:
-         return Response(content="/* script.js not found */", status_code=404, media_type="application/javascript")
-     return Response(content=content, media_type="application/javascript")
-
- @app.get("/style.css", response_class=Response, tags=["Frontend"], include_in_schema=False)
- async def style_css():
-     content = read_static_file("style.css")
-     if content is None:
-         return Response(content="/* style.css not found */", status_code=404, media_type="text/css")
-     return Response(content=content, media_type="text/css")
-
- @app.get("/playground", response_class=HTMLResponse, tags=["Frontend"])
- async def playground():
-     """Serves the chat playground HTML page."""
-     return await serve_static_html("playground.html")

- @app.get("/image-playground", response_class=HTMLResponse, tags=["Frontend"])
- async def image_playground():
-     """Serves the image playground HTML page."""
-     return await serve_static_html("image-playground.html")

- # Dynamic Page Example
- @app.get("/dynamo", response_class=HTMLResponse, tags=["Examples"])
  async def dynamic_ai_page(request: Request):
-     """Generates a dynamic HTML page using an AI model (example)."""
-     # HF Space Note: This uses a hardcoded URL to *itself* if running in the space.
-     # Ensure the HF_SPACE_URL secret is set correctly.
-     env_vars = get_env_vars()
-     hf_space_url = env_vars.get('hf_space_url', '')
-     if not hf_space_url:
-         raise HTTPException(status_code=500, detail="HF_SPACE_URL environment variable not set.")
-
-     user_agent = request.headers.get('user-agent', 'Unknown')
-     client_ip = request.client.host if request.client else "Unknown"
-     location = f"IP: {client_ip}" # Basic IP, location requires GeoIP lookup (extra dependency)
-
      prompt = f"""
-     Generate a cool, dynamic HTML page for a user with the following details:
-     - App Name: "LokiAI"
      - User-Agent: {user_agent}
-     - Location Info: {location}
-     - Style: Cyberpunk aesthetic, minimalist layout, maybe some retro touches.
-     - Content: Include a heading, a short motivational or witty message, and perhaps a subtle animation. Use inline CSS for styling within a <style> tag.
-     - Output: Provide ONLY the raw HTML code, starting with <!DOCTYPE html>. Do not wrap it in backticks or add explanations.
      """
-
      payload = {
-         "model": "mistral-small-latest", # Or another capable model
-         "messages": [{"role": "user", "content": prompt}],
-         "max_tokens": 1000,
-         "temperature": 0.7
      }
      headers = {
-         # HF Space Note: Use the space's own URL and a valid API key if required by your setup.
-         # Here, we assume the playground key bypass works or use a dedicated internal key.
-         "Authorization": f"Bearer {list(env_vars['api_keys'])[0] if env_vars['api_keys'] else 'dummy-key'}" # Use first key or dummy
      }

-     try:
-         # HF Space Note: Use the async client for internal requests too.
-         client = get_async_client()
-         api_url = f"{hf_space_url}/chat/completions" # Call own endpoint
-         response = await client.post(api_url, json=payload, headers=headers)
-         response.raise_for_status() # Raise exception for bad status codes
-         data = response.json()
-
-         html_content = data.get('choices', [{}])[0].get('message', {}).get('content', '')
-
-         # Basic cleanup (remove potential markdown backticks if model adds them)
-         html_content = re.sub(r"^```html\s*", "", html_content, flags=re.IGNORECASE)
-         html_content = re.sub(r"\s*```$", "", html_content)
-
-         if not html_content.strip().lower().startswith("<!doctype html"):
-             logger.warning("Dynamo page generation might be incomplete or malformed.")
-             # Optionally return a fallback static page here
-
-         return HTMLResponse(content=html_content)

-     except httpx.HTTPStatusError as e:
-         logger.error(f"Error calling self API for /dynamo: {e.response.status_code} - {e.response.text}")
-         raise HTTPException(status_code=502, detail=f"Failed to generate dynamic content: Upstream API error {e.response.status_code}")
-     except Exception as e:
-         logger.error(f"Unexpected error in /dynamo: {e}", exc_info=True)
-         raise HTTPException(status_code=500, detail="Failed to generate dynamic content due to an internal error.")


- # Vetra Example (Fetching from GitHub)
- # HF Space Note: Ensure outbound requests to raw.githubusercontent.com are allowed.
  GITHUB_BASE = "https://raw.githubusercontent.com/Parthsadaria/Vetra/main"
- VETRA_FILES = {"html": "index.html", "css": "style.css", "js": "script.js"}

- async def get_github_file(filename: str) -> Optional[str]:
-     """Fetches a file from the Vetra GitHub repo."""
      url = f"{GITHUB_BASE}/{filename}"
-     try:
-         client = get_async_client()
          res = await client.get(url)
-         res.raise_for_status()
-         return res.text
-     except httpx.RequestError as e:
-         logger.error(f"Error fetching GitHub file {url}: {e}")
-         return None
-     except httpx.HTTPStatusError as e:
-         logger.error(f"GitHub file {url} returned status {e.response.status_code}")
-         return None

- @app.get("/vetra", response_class=HTMLResponse, tags=["Examples"])
  async def serve_vetra():
-     """Serves the Vetra application by fetching components from GitHub."""
-     logger.info("Fetching Vetra files from GitHub...")
-     # Fetch files concurrently
-     html_task = asyncio.create_task(get_github_file(VETRA_FILES["html"]))
-     css_task = asyncio.create_task(get_github_file(VETRA_FILES["css"]))
-     js_task = asyncio.create_task(get_github_file(VETRA_FILES["js"]))
-
-     html, css, js = await asyncio.gather(html_task, css_task, js_task)

      if not html:
-         logger.error("Failed to fetch Vetra index.html")
-         return HTMLResponse(content="<h1>Error: Could not load Vetra application (HTML missing)</h1>", status_code=502)
-
-     # Inject CSS and JS into HTML
-     css_content = f"<style>{css or '/* CSS failed to load */'}</style>"
-     js_content = f"<script>{js or '// JS failed to load'}</script>"

-     # Inject carefully before closing tags
-     final_html = html.replace("</head>", f"{css_content}\n</head>", 1)
-     final_html = final_html.replace("</body>", f"{js_content}\n</body>", 1)

-     logger.info("Successfully served Vetra application.")
      return HTMLResponse(content=final_html)


- # Model Info Endpoint
- @app.get("/api/v1/models", tags=["Models"])
- @app.get("/models", tags=["Models"])
- async def return_models():
-     """Returns the list of available models loaded from models.json."""
-     # HF Space Note: This endpoint now relies on models.json being present.
-     # It no longer dynamically adds models defined only in the script's sets.
-     # Ensure models.json is comprehensive or adjust startup logic if needed.
-     return await get_models()

- # Search Endpoint (using cloudscraper)
- # HF Space Note: This uses cloudscraper which might be blocked or require updates.
- # Consider replacing with a more stable search API if possible.
- async def generate_search_async(query: str, systemprompt: Optional[str] = None) -> asyncio.Queue:
-     """Performs search using the configured backend and streams results."""
-     queue = asyncio.Queue()
-     env_vars = get_env_vars()
-     search_endpoint = env_vars.get('secret_api_endpoint_3')
-
-     async def _fetch_search_data():
-         if not search_endpoint:
-             await queue.put({"error": "Search API endpoint (SECRET_API_ENDPOINT_3) not configured"})
-             await queue.put(None) # Signal end
-             return

-         try:
-             scraper = get_scraper() # Get a scraper instance from the pool
-             loop = asyncio.get_running_loop()

-             system_message = systemprompt or "You are a helpful search assistant."
-             messages = [
-                 {"role": "system", "content": system_message},
-                 {"role": "user", "content": query},
-             ]
-             payload = {
-                 "model": "searchgpt", # Assuming the endpoint expects this model name
-                 "messages": messages,
-                 "stream": True # Explicitly request streaming from backend
-             }
-             headers = {"User-Agent": "Mozilla/5.0"} # Standard user agent
-
-             # HF Space Note: Run synchronous scraper call in executor thread
-             response = await loop.run_in_executor(
-                 executor,
-                 scraper.post,
-                 search_endpoint,
-                 json=payload,
-                 headers=headers,
-                 stream=True # Request streaming from requests library perspective
-             )
-
-             response.raise_for_status()
-
-             # Process SSE stream
-             # HF Space Note: Iterating lines on the response directly can be blocking if not handled carefully.
-             # Using iter_lines with decode_unicode=True is generally safe.
-             for line in response.iter_lines(decode_unicode=True):
-                 if line.startswith("data: "):
-                     try:
-                         data_str = line[6:]
-                         if data_str.strip() == "[DONE]": # Check for OpenAI style completion
-                             break
-                         json_data = json.loads(data_str)
-                         # Assuming OpenAI compatible streaming format
-                         delta = json_data.get("choices", [{}])[0].get("delta", {})
-                         content = delta.get("content")
-                         if content:
-                             # Reconstruct OpenAI-like SSE chunk
-                             chunk = {
-                                 "id": json_data.get("id"),
-                                 "object": "chat.completion.chunk",
-                                 "created": json_data.get("created", int(time.time())),
-                                 "model": "searchgpt",
-                                 "choices": [{"index": 0, "delta": {"content": content}, "finish_reason": None}]
-                             }
-                             await queue.put({"data": f"data: {json.dumps(chunk)}\n\n", "text": content})
-                         # Check for finish reason
-                         finish_reason = json_data.get("choices", [{}])[0].get("finish_reason")
-                         if finish_reason:
-                             chunk = {
-                                 "id": json_data.get("id"),
-                                 "object": "chat.completion.chunk",
-                                 "created": json_data.get("created", int(time.time())),
-                                 "model": "searchgpt",
-                                 "choices": [{"index": 0, "delta": {}, "finish_reason": finish_reason}]
-                             }
-                             await queue.put({"data": f"data: {json.dumps(chunk)}\n\n", "text": ""})
-                             break # Stop processing after finish reason
-
-                     except json.JSONDecodeError:
-                         logger.warning(f"Failed to decode JSON from search stream: {line}")
-                         continue
-                     except Exception as e:
-                         logger.error(f"Error processing search stream chunk: {e}", exc_info=True)
-                         await queue.put({"error": f"Error processing stream: {e}"})
-                         break # Stop on processing error
-
-         except requests.exceptions.RequestException as e:
-             logger.error(f"Search request failed: {e}")
-             await queue.put({"error": f"Search request failed: {e}"})
-         except Exception as e:
-             logger.error(f"Unexpected error during search: {e}", exc_info=True)
-             await queue.put({"error": f"An unexpected error occurred during search: {e}"})
-         finally:
-             await queue.put(None) # Signal completion
-
-
-     asyncio.create_task(_fetch_search_data())
-     return queue

- @app.get("/searchgpt", tags=["Search"])
- async def search_gpt(q: str, stream: bool = True, systemprompt: Optional[str] = None):
-     """
-     Performs a search using the backend search model and streams results.
-     Pass `stream=false` to get the full response at once.
-     """
      if not q:
          raise HTTPException(status_code=400, detail="Query parameter 'q' is required")

-     # HF Space Note: Ensure usage_tracker is thread-safe if used across async/sync boundaries.
-     # The dummy tracker used when the module isn't found is safe.
      usage_tracker.record_request(endpoint="/searchgpt")

-     queue = await generate_search_async(q, systemprompt=systemprompt)

      if stream:
          async def stream_generator():
-             full_response_text = "" # Keep track for non-streaming case if needed
              while True:
                  item = await queue.get()
-                 if item is None: # End of stream signal
                      break
                  if "error" in item:
-                     # HF Space Note: Log errors server-side, return generic error to client for security.
-                     logger.error(f"Search stream error: {item['error']}")
-                     # Send an error event in the stream
-                     error_event = {"error": {"message": "Search failed.", "code": 500}}
-                     yield f"data: {json.dumps(error_event)}\n\n"
                      break
                  if "data" in item:
                      yield item["data"]
-                     full_response_text += item.get("text", "")
-             # Optionally yield a [DONE] message if backend doesn't guarantee it
-             # yield "data: [DONE]\n\n"

          return StreamingResponse(
              stream_generator(),
-             media_type="text/event-stream",
-             headers={
-                 "Content-Type": "text/event-stream",
-                 "Cache-Control": "no-cache",
-                 "Connection": "keep-alive",
-                 "X-Accel-Buffering": "no" # Crucial for Nginx/proxies in HF Spaces
-             }
          )
      else:
-         # Collect full response for non-streaming request
-         full_response_text = ""
          while True:
              item = await queue.get()
              if item is None:
                  break
              if "error" in item:
-                 logger.error(f"Search non-stream error: {item['error']}")
-                 raise HTTPException(status_code=502, detail=f"Search failed: {item['error']}")
-             full_response_text += item.get("text", "")
-
-         # Mimic OpenAI non-streaming response structure
-         return JSONResponse(content={
-             "id": f"search-{int(time.time())}",
-             "object": "chat.completion",
-             "created": int(time.time()),
-             "model": "searchgpt",
-             "choices": [{
-                 "index": 0,
-                 "message": {
-                     "role": "assistant",
-                     "content": full_response_text,
-                 },
-                 "finish_reason": "stop",
-             }],
-             "usage": { # Note: Token usage is unknown here
-                 "prompt_tokens": None,
-                 "completion_tokens": None,
-                 "total_tokens": None,
-             }
-         })


- # Main Chat Completions Proxy
- @app.post("/api/v1/chat/completions", tags=["Chat Completions"])
- @app.post("/chat/completions", tags=["Chat Completions"])
- async def get_completion(
-     payload: Payload,
-     request: Request,
-     authenticated: bool = Depends(verify_api_key) # Apply authentication
- ):
-     """
-     Proxies chat completion requests to the appropriate backend API based on the model.
-     Supports streaming (SSE).
-     """
-     if not server_status:
-         raise HTTPException(status_code=503, detail="Server is under maintenance.")

-     model_to_use = payload.model or "gpt-4o-mini" # Default model

-     # HF Space Note: Check against models loaded at startup.
-     if available_model_ids and model_to_use not in available_model_ids:
-         logger.warning(f"Requested model '{model_to_use}' not in available list.")
-         # Check if it's a known category even if not explicitly in models.json
-         known_categories = mistral_models | pollinations_models | alternate_models | claude_3_models
-         if model_to_use not in known_categories:
-             raise HTTPException(
-                 status_code=400,
-                 detail=f"Model '{model_to_use}' is not available or recognized. Check /models."
-             )
-         else:
-             logger.info(f"Allowing known category model '{model_to_use}' despite not being in models.json.")

-     # Log request asynchronously
      asyncio.create_task(log_request(request, model_to_use))
      usage_tracker.record_request(model=model_to_use, endpoint="/chat/completions")

-     # Prepare payload for the target API
-     payload_dict = payload.dict(exclude_none=True) # Exclude None values
-     payload_dict["model"] = model_to_use # Ensure model is set

-     env_vars = get_env_vars()
-     hf_space_url = env_vars.get('hf_space_url', '') # Needed for Referer/Origin

-     # Determine target endpoint and headers
-     endpoint = None
-     custom_headers = {}

      if model_to_use in mistral_models:
-         endpoint = env_vars.get('mistral_api')
-         api_key = env_vars.get('mistral_key')
-         if not endpoint or not api_key:
-             raise HTTPException(status_code=500, detail="Mistral API endpoint or key not configured.")
-         custom_headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json", "Accept": "application/json"}
-         # Mistral specific adjustments if needed
-         # payload_dict.pop('system', None) # Example: if Mistral doesn't use 'system' role
-
      elif model_to_use in pollinations_models:
-         endpoint = env_vars.get('secret_api_endpoint_4')
-         if not endpoint:
-             raise HTTPException(status_code=500, detail="Pollinations API endpoint (SECRET_API_ENDPOINT_4) not configured.")
-         # Pollinations might need specific headers? Add them here.
-         custom_headers = {"Content-Type": "application/json"}
-
      elif model_to_use in alternate_models:
-         endpoint = env_vars.get('secret_api_endpoint_2')
-         if not endpoint:
-             raise HTTPException(status_code=500, detail="Alternate API endpoint (SECRET_API_ENDPOINT_2) not configured.")
-         custom_headers = {"Content-Type": "application/json"}
-
-     elif model_to_use in claude_3_models:
-         endpoint = env_vars.get('secret_api_endpoint_5')
-         if not endpoint:
-             raise HTTPException(status_code=500, detail="Claude 3 API endpoint (SECRET_API_ENDPOINT_5) not configured.")
-         custom_headers = {"Content-Type": "application/json"}
-         # Claude specific headers (like anthropic-version) might be needed
-         # custom_headers["anthropic-version"] = "2023-06-01"
-
-     else: # Default endpoint
-         endpoint = env_vars.get('secret_api_endpoint')
-         if not endpoint:
-             raise HTTPException(status_code=500, detail="Default API endpoint (SECRET_API_ENDPOINT) not configured.")
-         # Default endpoint might need Origin/Referer
-         if hf_space_url:
-             custom_headers = {
-                 "Origin": hf_space_url,
-                 "Referer": hf_space_url,
-                 "Content-Type": "application/json"
-             }
-         else:
-             custom_headers = {"Content-Type": "application/json"}
-
-
-     target_url = f"{endpoint.rstrip('/')}/v1/chat/completions" # Assume OpenAI compatible path
-     logger.info(f"Proxying request for model '{model_to_use}' to endpoint: {endpoint}")

-     client = get_async_client()

-     async def stream_generator():
-         """Generator for streaming the response."""
-         nonlocal target_url # Allow modification if needed
          try:
-             async with client.stream("POST", target_url, json=payload_dict, headers=custom_headers) as response:
-                 # Check for initial errors before streaming
-                 if response.status_code >= 400:
-                     error_body = await response.aread()
-                     logger.error(f"Upstream API error: {response.status_code} - {error_body.decode()}")
-                     # Try to parse error detail from upstream
-                     detail = f"Upstream API error: {response.status_code}"
-                     try:
-                         error_json = json.loads(error_body)
-                         detail = error_json.get('error', {}).get('message', detail)
-                     except json.JSONDecodeError:
-                         pass
-                     # Send error as SSE event
-                     error_event = {"error": {"message": detail, "code": response.status_code}}
-                     yield f"data: {json.dumps(error_event)}\n\n"
-                     return # Stop generation
-
-                 # Stream the response line by line
-                 async for line in response.aiter_lines():
-                     if line:
-                         # Pass through the data directly
-                         yield line + "\n"
-             # Ensure stream is properly closed, yield [DONE] if backend doesn't
-             # Some backends might not send [DONE], uncomment if needed
-             # yield "data: [DONE]\n\n"
-
          except httpx.TimeoutException:
-             logger.error(f"Request to {target_url} timed out.")
-             error_event = {"error": {"message": "Request timed out", "code": 504}}
-             yield f"data: {json.dumps(error_event)}\n\n"
          except httpx.RequestError as e:
-             logger.error(f"Failed to connect to upstream API {target_url}: {e}")
-             error_event = {"error": {"message": f"Upstream connection error: {e}", "code": 502}}
-             yield f"data: {json.dumps(error_event)}\n\n"
          except Exception as e:
-             logger.error(f"An unexpected error occurred during streaming proxy: {e}", exc_info=True)
-             error_event = {"error": {"message": f"Internal server error: {e}", "code": 500}}
-             yield f"data: {json.dumps(error_event)}\n\n"

-     if payload.stream:
          return StreamingResponse(
-             stream_generator(),
              media_type="text/event-stream",
              headers={
                  "Content-Type": "text/event-stream",
                  "Cache-Control": "no-cache",
                  "Connection": "keep-alive",
-                 "X-Accel-Buffering": "no" # Essential for HF Spaces proxying SSE
              }
          )
      else:
-         # Handle non-streaming request by collecting the streamed chunks
-         full_response_content = ""
-         final_json_response = None
-         async for line in stream_generator():
-             if line.startswith("data: "):
-                 data_str = line[6:].strip()
-                 if data_str == "[DONE]":
-                     break
-                 try:
-                     chunk = json.loads(data_str)
-                     # Check for error chunk
-                     if "error" in chunk:
-                         logger.error(f"Received error during non-stream collection: {chunk['error']}")
-                         raise HTTPException(status_code=chunk['error'].get('code', 502), detail=chunk['error'].get('message', 'Upstream API error'))
-
-                     # Accumulate content from delta
-                     delta = chunk.get("choices", [{}])[0].get("delta", {})
-                     content = delta.get("content")
-                     if content:
-                         full_response_content += content
-
-                     # Store the last chunk structure to reconstruct the final response
-                     # We assume the last chunk contains necessary info like id, model, etc.
-                     # but we overwrite the choices/message part.
-                     final_json_response = chunk # Keep the structure
-                     # Check for finish reason
-                     finish_reason = chunk.get("choices", [{}])[0].get("finish_reason")
-                     if finish_reason:
-                         break # Stop collecting
-
-                 except json.JSONDecodeError:
-                     logger.warning(f"Could not decode JSON chunk in non-stream mode: {data_str}")
-                 except Exception as e:
-                     logger.error(f"Error processing chunk in non-stream mode: {e}")
-                     raise HTTPException(status_code=500, detail="Error processing response stream.")
-
-         if final_json_response is None:
-             # Handle cases where no valid data chunks were received
-             logger.error("No valid response chunks received for non-streaming request.")
-             raise HTTPException(status_code=502, detail="Failed to get valid response from upstream API.")
-
-         # Reconstruct OpenAI-like non-streaming response
-         final_response_obj = {
-             "id": final_json_response.get("id", f"chatcmpl-{int(time.time())}"),
-             "object": "chat.completion",
-             "created": final_json_response.get("created", int(time.time())),
-             "model": model_to_use, # Use the requested model
-             "choices": [{
-                 "index": 0,
-                 "message": {
-                     "role": "assistant",
-                     "content": full_response_content,
-                 },
-                 "finish_reason": final_json_response.get("choices", [{}])[0].get("finish_reason", "stop"), # Get finish reason from last chunk
-             }],
-             "usage": { # Token usage might be in the last chunk for some APIs, otherwise unknown
-                 "prompt_tokens": None,
-                 "completion_tokens": None,
-                 "total_tokens": None,
-             }
-         }
-         # Attempt to extract usage if present in the (potentially non-standard) final chunk
-         usage_data = final_json_response.get("usage")
-         if isinstance(usage_data, dict):
-             final_response_obj["usage"].update(usage_data)

-         return JSONResponse(content=final_response_obj)


- # Image Generation Endpoint
- @app.post("/images/generations", tags=["Image Generation"])
- async def create_image(
-     payload: ImageGenerationPayload,
-     authenticated: bool = Depends(verify_api_key)
- ):
      """
-     Generates images based on a text prompt using the configured backend.
      """
      if not server_status:
-         raise HTTPException(status_code=503, detail="Server is under maintenance.")

      if payload.model not in supported_image_models:
          raise HTTPException(
              status_code=400,
-             detail=f"Model '{payload.model}' is not supported for image generation. Supported: {', '.join(supported_image_models)}"
          )

      usage_tracker.record_request(model=payload.model, endpoint="/images/generations")

-     env_vars = get_env_vars()
-     target_api_url = env_vars.get('new_img_endpoint')
-     if not target_api_url:
-         raise HTTPException(status_code=500, detail="Image generation endpoint (NEW_IMG) not configured.")
-
-     # Prepare payload for the target API (adjust keys if needed)
-     # HF Space Note: Ensure the keys match the actual API expected by NEW_IMG endpoint.
-     # Assuming it's OpenAI compatible here.
      api_payload = {
          "model": payload.model,
          "prompt": payload.prompt,
-         "n": payload.n,
-         "size": payload.size
      }
-     # Remove None values the target API might not like
-     api_payload = {k: v for k, v in api_payload.items() if v is not None}

-     logger.info(f"Requesting image generation for model '{payload.model}' from {target_api_url}")
-     client = get_async_client()

      try:
-         # HF Space Note: Image generation can take time, use a longer timeout if needed.
-         # Consider making this truly async if the backend supports webhooks or polling.
-         response = await client.post(target_api_url, json=api_payload, timeout=120.0) # 2 min timeout
-         response.raise_for_status() # Raise HTTP errors

-         # Return the exact response from the backend
          return JSONResponse(content=response.json())

      except httpx.TimeoutException:
-         logger.error(f"Image generation request to {target_api_url} timed out.")
          raise HTTPException(status_code=504, detail="Image generation request timed out.")
-     except httpx.HTTPStatusError as e:
-         logger.error(f"Image generation API error: {e.response.status_code} - {e.response.text}")
-         detail = f"Image generation failed: Upstream API error {e.response.status_code}"
-         try:
-             err_json = e.response.json()
-             detail = err_json.get('error', {}).get('message', detail)
-         except json.JSONDecodeError:
-             pass
-         raise HTTPException(status_code=e.response.status_code, detail=detail)
      except httpx.RequestError as e:
-         logger.error(f"Error connecting to image generation service {target_api_url}: {e}")
          raise HTTPException(status_code=502, detail=f"Error connecting to image generation service: {e}")
      except Exception as e:
-         logger.error(f"Unexpected error during image generation: {e}", exc_info=True)
          raise HTTPException(status_code=500, detail=f"An unexpected error occurred during image generation: {e}")


- # --- Utility & Admin Endpoints ---

- async def log_request(request: Request, model: Optional[str] = None):
-     """Logs basic request information asynchronously."""
-     # HF Space Note: Avoid logging sensitive info like full IP or headers unless necessary.
-     # Hashing IP provides some privacy.
-     client_host = request.client.host if request.client else "unknown"
-     ip_hash = hash(client_host) % 10000
-     timestamp = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z")
-     log_message = f"Timestamp: {timestamp}, IP Hash: {ip_hash}, Method: {request.method}, Path: {request.url.path}"
-     if model:
-         log_message += f", Model: {model}"
-     logger.info(log_message)

- @app.get("/usage", tags=["Admin"])
  async def get_usage(days: int = 7):
-     """Retrieves aggregated usage statistics."""
-     # HF Space Note: Ensure usage_tracker methods are efficient, especially get_usage_summary.
-     # Caching might be needed if it becomes slow.
-     if days <= 0:
-         raise HTTPException(status_code=400, detail="Number of days must be positive.")
-     try:
-         # Run potentially CPU-bound summary generation in executor
-         loop = asyncio.get_running_loop()
-         summary = await loop.run_in_executor(executor, usage_tracker.get_usage_summary, days)
-         return summary
-     except Exception as e:
-         logger.error(f"Error retrieving usage statistics: {e}", exc_info=True)
-         raise HTTPException(status_code=500, detail="Failed to retrieve usage statistics.")
-
- # HF Space Note: Generating HTML dynamically can be resource-intensive.
- # Consider caching the generated HTML or serving a static page updated periodically.
- def generate_usage_html(usage_data: Dict) -> str:
-     """Generates an HTML report from usage data."""
-     # (Keep the HTML generation logic as provided in the original file)
-     # ... (rest of the HTML generation code from the original file) ...
-     # Ensure this function handles potentially missing keys gracefully
-     models_usage = usage_data.get('models', {})
-     endpoints_usage = usage_data.get('api_endpoints', {})
-     daily_usage = usage_data.get('recent_daily_usage', {})
-     total_requests = usage_data.get('total_requests', 0)

      model_usage_rows = "\n".join([
          f"""
          <tr>
              <td>{model}</td>
-             <td>{model_data.get('total_requests', 'N/A')}</td>
-             <td>{model_data.get('first_used', 'N/A')}</td>
-             <td>{model_data.get('last_used', 'N/A')}</td>
          </tr>
-         """ for model, model_data in models_usage.items()
-     ]) if models_usage else "<tr><td colspan='4'>No model usage data</td></tr>"

      api_usage_rows = "\n".join([
          f"""
          <tr>
              <td>{endpoint}</td>
-             <td>{endpoint_data.get('total_requests', 'N/A')}</td>
-             <td>{endpoint_data.get('first_used', 'N/A')}</td>
-             <td>{endpoint_data.get('last_used', 'N/A')}</td>
          </tr>
-         """ for endpoint, endpoint_data in endpoints_usage.items()
-     ]) if endpoints_usage else "<tr><td colspan='4'>No API endpoint usage data</td></tr>"

      daily_usage_rows = "\n".join([
-         f"""
-         <tr>
-             <td>{date}</td>
-             <td>{entity}</td>
-             <td>{requests}</td>
-         </tr>
-         """
-         for date, date_data in daily_usage.items()
-         for entity, requests in date_data.items()
-     ]) if daily_usage else "<tr><td colspan='3'>No daily usage data</td></tr>"

-     # HF Space Note: Using f-string for large HTML is okay, but consider template engines (Jinja2)
-     # for more complex pages. Ensure CSS/JS are either inline or served via separate endpoints.
      html_content = f"""
      <!DOCTYPE html>
      <html lang="en">
      <head>
          <meta charset="UTF-8">
-         <meta name="viewport" content="width=device-width, initial-scale=1.0">
          <title>Lokiai AI - Usage Statistics</title>
          <link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;600&display=swap" rel="stylesheet">
          <style>
-             /* (Keep the CSS styles as provided in the original file) */
              :root {{
-                 --bg-dark: #0f1011; --bg-darker: #070708; --text-primary: #e6e6e6;
-                 --text-secondary: #8c8c8c; --border-color: #2c2c2c; --accent-color: #3a6ee0;
                  --accent-hover: #4a7ef0;
              }}
-             body {{ font-family: 'Inter', sans-serif; background-color: var(--bg-dark); color: var(--text-primary); max-width: 1200px; margin: 0 auto; padding: 40px 20px; line-height: 1.6; }}
-             .logo {{ display: flex; align-items: center; justify-content: center; margin-bottom: 30px; }}
-             .logo h1 {{ font-weight: 600; font-size: 2.5em; color: var(--text-primary); margin-left: 15px; }}
-             .logo img {{ width: 60px; height: 60px; border-radius: 10px; }}
-             .container {{ background-color: var(--bg-darker); border-radius: 12px; padding: 30px; box-shadow: 0 15px 40px rgba(0,0,0,0.3); border: 1px solid var(--border-color); }}
-             h2, h3 {{ color: var(--text-primary); border-bottom: 2px solid var(--border-color); padding-bottom: 10px; font-weight: 500; }}
-             .total-requests {{ background-color: var(--accent-color); color: white; text-align: center; padding: 15px; border-radius: 8px; margin-bottom: 30px; font-weight: 600; letter-spacing: -0.5px; }}
-             table {{ width: 100%; border-collapse: separate; border-spacing: 0; margin-bottom: 30px; background-color: var(--bg-dark); border-radius: 8px; overflow: hidden; }}
-             th, td {{ border: 1px solid var(--border-color); padding: 12px; text-align: left; transition: background-color 0.3s ease; }}
-             th {{ background-color: #1e1e1e; color: var(--text-primary); font-weight: 600; text-transform: uppercase; font-size: 0.9em; }}
-             tr:nth-child(even) {{ background-color: rgba(255,255,255,0.05); }}
-             tr:hover {{ background-color: rgba(62,100,255,0.1); }}
-             @media (max-width: 768px) {{ .container {{ padding: 15px; }} table {{ font-size: 0.9em; }} }}
          </style>
      </head>
      <body>
          <div class="container">
              <div class="logo">
                  <img src="data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iMjAwIiBoZWlnaHQ9IjIwMCIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj48cGF0aCBkPSJNMTAwIDM1TDUwIDkwaDEwMHoiIGZpbGw9IiMzYTZlZTAiLz48Y2lyY2xlIGN4PSIxMDAiIGN5PSIxNDAiIHI9IjMwIiBmaWxsPSIjM2E2ZWUwIi8+PC9zdmc+" alt="Lokai AI Logo">
-                 <h1>Lokiai AI Usage</h1>
              </div>

              <div class="total-requests">
-                 Total API Requests Recorded: {total_requests}
              </div>

              <h2>Model Usage</h2>
              <table>
-                 <thead><tr><th>Model</th><th>Total Requests</th><th>First Used</th><th>Last Used</th></tr></thead>
-                 <tbody>{model_usage_rows}</tbody>
              </table>

              <h2>API Endpoint Usage</h2>
              <table>
-                 <thead><tr><th>Endpoint</th><th>Total Requests</th><th>First Used</th><th>Last Used</th></tr></thead>
-                 <tbody>{api_usage_rows}</tbody>
              </table>

-             <h2>Daily Usage (Last {usage_data.get('days_analyzed', 7)} Days)</h2>
              <table>
-                 <thead><tr><th>Date</th><th>Entity (Model/Endpoint)</th><th>Requests</th></tr></thead>
-                 <tbody>{daily_usage_rows}</tbody>
              </table>
          </div>
      </body>
@@ -1058,199 +888,166 @@ def generate_usage_html(usage_data: Dict) -> str:
      """
      return html_content

- # HF Space Note: Caching the generated HTML page can save resources.
- # Invalidate cache periodically or when usage data changes significantly.
- usage_html_cache = {"content": None, "timestamp": 0}
- CACHE_DURATION = 300 # Cache usage page for 5 minutes

- @app.get("/usage/page", response_class=HTMLResponse, tags=["Admin"])
  async def usage_page():
-     """Serves an HTML page showing usage statistics."""
-     now = time.monotonic()
-     if usage_html_cache["content"] and (now - usage_html_cache["timestamp"] < CACHE_DURATION):
-         logger.info("Serving cached usage page.")
-         return HTMLResponse(content=usage_html_cache["content"])
-
-     logger.info("Generating fresh usage page.")
-     try:
-         # Run potentially slow parts in executor
-         loop = asyncio.get_running_loop()
-         usage_data = await loop.run_in_executor(executor, usage_tracker.get_usage_summary, 7) # Get data for 7 days
-         html_content = await loop.run_in_executor(executor, generate_usage_html, usage_data)
-
-         # Update cache
-         usage_html_cache["content"] = html_content
-         usage_html_cache["timestamp"] = now
-
-         return HTMLResponse(content=html_content)
-     except Exception as e:
-         logger.error(f"Failed to generate usage page: {e}", exc_info=True)
-         # Serve stale cache if available, otherwise error
-         if usage_html_cache["content"]:
-             logger.warning("Serving stale usage page due to generation error.")
-             return HTMLResponse(content=usage_html_cache["content"])
-         else:
-             raise HTTPException(status_code=500, detail="Failed to generate usage statistics page.")
-

- # Meme Endpoint
- @app.get("/meme", tags=["Fun"])
  async def get_meme():
-     """Fetches a random meme and streams the image."""
-     # HF Space Note: Ensure meme-api.com is accessible from the HF Space network.
-     client = get_async_client()
-     meme_api_url = "https://meme-api.com/gimme"
      try:
-         logger.info("Fetching meme info...")
-         response = await client.get(meme_api_url)
-         response.raise_for_status()
          response_data = response.json()

          meme_url = response_data.get("url")
-         if not meme_url or not isinstance(meme_url, str):
-             logger.error(f"Invalid meme URL received from API: {meme_url}")
-             raise HTTPException(status_code=502, detail="Failed to get valid meme URL from API.")
-
-         logger.info(f"Fetching meme image: {meme_url}")
-         # Use streaming request for the image itself
-         async with client.stream("GET", meme_url) as image_response:
-             image_response.raise_for_status() # Check if image URL is valid
-
-             # Get content type, default to image/png
-             media_type = image_response.headers.get("content-type", "image/png")
-             if not media_type.startswith("image/"):
-                 logger.warning(f"Unexpected content type '{media_type}' for meme URL: {meme_url}")
-                 # You might want to reject non-image types
-                 # raise HTTPException(status_code=502, detail="Meme URL did not return an image.")
-
-             # Stream the image content directly
-             return StreamingResponse(
-                 image_response.aiter_bytes(),
-                 media_type=media_type,
-                 headers={'Cache-Control': 'no-cache'} # Don't cache the meme itself heavily
-             )
-
-     except httpx.HTTPStatusError as e:
-         logger.error(f"HTTP error fetching meme ({e.request.url}): {e.response.status_code}")
-         raise HTTPException(status_code=502, detail=f"Failed to fetch meme (HTTP {e.response.status_code})")
-     except httpx.RequestError as e:
-         logger.error(f"Network error fetching meme ({e.request.url}): {e}")
-         raise HTTPException(status_code=502, detail="Failed to fetch meme (Network Error)")
-     except Exception as e:
-         logger.error(f"Unexpected error fetching meme: {e}", exc_info=True)
-         raise HTTPException(status_code=500, detail="Failed to retrieve meme due to an internal error.")


- # Health Check Endpoint
- @app.get("/health", tags=["Utility"])
- async def health_check():
-     """Provides a health check status, including missing critical configurations."""
-     env_vars = get_env_vars()
-     missing_critical_vars = []

-     # Define critical vars needed for core functionality
-     critical_vars = [
-         'api_keys', 'secret_api_endpoint', 'secret_api_endpoint_2',
-         'secret_api_endpoint_3', 'secret_api_endpoint_4', 'secret_api_endpoint_5',
-         'new_img_endpoint', 'hf_space_url'
-     ]
-     # Conditionally critical vars
-     if any(model in mistral_models for model in available_model_ids):
-         critical_vars.extend(['mistral_api', 'mistral_key'])
-
-     for var_name in critical_vars:
-         value = env_vars.get(var_name)
-         # Check for None or empty strings/lists/sets
-         if value is None or (isinstance(value, (str, list, set)) and not value):
-             missing_critical_vars.append(var_name)
-
-     is_healthy = not missing_critical_vars and server_status
-     status_code = 200 if is_healthy else 503 # Service Unavailable if unhealthy

-     health_status = {
-         "status": "healthy" if is_healthy else "unhealthy",
-         "server_mode": "online" if server_status else "maintenance",
-         "missing_critical_env_vars": missing_critical_vars,
-         "details": "All critical configurations seem okay. Ready to roll! 🚀" if is_healthy else "Service issues detected. Check missing env vars or server status. 🛠️"
-     }
-     return JSONResponse(content=health_status, status_code=status_code)

- # --- Startup and Shutdown Events ---

  @app.on_event("startup")
  async def startup_event():
-     """Tasks to run when the application starts."""
      global available_model_ids
-     logger.info("Application startup sequence initiated...")
-
-     # Load models from JSON
-     models_from_file = load_models_data()
-     model_ids_from_file = {model['id'] for model in models_from_file if 'id' in model}
-
-     # Combine models from file and predefined sets
-     predefined_model_sets = mistral_models | pollinations_models | alternate_models | claude_3_models
-     all_model_ids = model_ids_from_file.union(predefined_model_sets)
-     available_model_ids = sorted(list(all_model_ids)) # Keep as sorted list
-
-     logger.info(f"Loaded {len(model_ids_from_file)} models from models.json.")
-     logger.info(f"Total {len(available_model_ids)} unique models available.")
-
-     # Initialize scraper pool (can take time)
-     # HF Space Note: Run potentially blocking I/O in executor during startup
-     loop = asyncio.get_running_loop()
-     await loop.run_in_executor(executor, get_scraper) # This initializes the pool
-
-     # Validate critical environment variables and log warnings
      env_vars = get_env_vars()
-     logger.info("Checking critical environment variables (Secrets)...")
-     await health_check() # Run health check logic to log warnings
-
-     # Pre-connect async client? Optional, httpx handles connections on demand.
-     # client = get_async_client()
-     # await client.get("https://www.google.com") # Example warm-up call
-
-     logger.info("Startup complete. Server is ready to accept requests.")
-

  @app.on_event("shutdown")
  async def shutdown_event():
-     """Tasks to run when the application shuts down."""
-     logger.info("Application shutdown sequence initiated...")
-
-     # Close the httpx client gracefully
      client = get_async_client()
      await client.aclose()
-     logger.info("HTTP client closed.")
-
-     # Shutdown the thread pool executor
-     executor.shutdown(wait=True)
-     logger.info("Thread pool executor shut down.")

-     # Clear scraper pool (optional, resources will be reclaimed anyway)
      scraper_pool.clear()
-     logger.info("Scraper pool cleared.")

      # Persist usage data
-     # HF Space Note: Ensure file system is writable if saving locally.
-     # Consider using HF Datasets or external DB for persistent storage.
-     try:
-         logger.info("Saving usage data...")
-         usage_tracker.save_data()
-         logger.info("Usage data saved.")
-     except Exception as e:
-         logger.error(f"Failed to save usage data during shutdown: {e}")

-     logger.info("Shutdown complete.")

- # --- Main Execution Block ---
- # HF Space Note: This block is mainly for local testing.
- # HF Spaces usually run the app using `uvicorn main:app --host 0.0.0.0 --port 7860` (or similar)
- # defined in the README metadata or a Procfile.
  if __name__ == "__main__":
      import uvicorn
-     logger.info("Starting server locally with uvicorn...")
-     # HF Space Note: Port 7860 is the default for HF Spaces. Host 0.0.0.0 is required.
-     uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
 
  import os
  import re
  from dotenv import load_dotenv
+ from fastapi import FastAPI, HTTPException, Request, Depends, Security
  from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse, FileResponse
  from fastapi.security import APIKeyHeader
  from pydantic import BaseModel
  import httpx
+ from functools import lru_cache
+ from pathlib import Path
+ import json
+ import datetime
+ import time
+ import threading
+ from typing import Optional, Dict, List, Any, Generator
+ import asyncio
+ from starlette.status import HTTP_403_FORBIDDEN
+ import cloudscraper
+ from concurrent.futures import ThreadPoolExecutor
+ import uvloop
  from fastapi.middleware.gzip import GZipMiddleware
  from starlette.middleware.cors import CORSMiddleware
+ import contextlib
+ import requests
 
+ # Enable uvloop for a faster event loop
  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
+ # Thread pool for CPU-bound operations
+ executor = ThreadPoolExecutor(max_workers=16)  # Increased thread count for better parallelism
 
+ # Load environment variables once at startup
  load_dotenv()
 
+ # API key security scheme
  api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
 
+ # Initialize usage tracker
+ from usage_tracker import UsageTracker
+ usage_tracker = UsageTracker()
 
+ app = FastAPI()
 
+ # Add middleware for compression and CORS
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
  app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
      allow_credentials=True,
      allow_methods=["*"],
      allow_headers=["*"],
  )
 
+ # Environment variables (cached)
+ @lru_cache(maxsize=1)
+ def get_env_vars():
      return {
+         'api_keys': os.getenv('API_KEYS', '').split(','),
          'secret_api_endpoint': os.getenv('SECRET_API_ENDPOINT'),
          'secret_api_endpoint_2': os.getenv('SECRET_API_ENDPOINT_2'),
+         'secret_api_endpoint_3': os.getenv('SECRET_API_ENDPOINT_3'),
+         'secret_api_endpoint_4': "https://text.pollinations.ai/openai",
+         'secret_api_endpoint_5': os.getenv('SECRET_API_ENDPOINT_5'),  # Added new endpoint
+         'mistral_api': "https://api.mistral.ai",
          'mistral_key': os.getenv('MISTRAL_KEY'),
+         'endpoint_origin': os.getenv('ENDPOINT_ORIGIN')
      }
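
All of these lookups resolve against process environment variables (on a Space, from Secrets). For local runs, a `.env` along these lines would satisfy `get_env_vars()`; every value below is a placeholder, not a real endpoint or key:

# Illustrative .env sketch; all values are placeholders
API_KEYS=key-one,key-two
SECRET_API_ENDPOINT=https://example.com/upstream
SECRET_API_ENDPOINT_2=https://example.com/upstream-alt
SECRET_API_ENDPOINT_3=https://example.com/search
SECRET_API_ENDPOINT_5=https://example.com/claude
MISTRAL_KEY=placeholder-key
ENDPOINT_ORIGIN=https://example.com
HEADER_URL=https://example.com
NEW_IMG=https://example.com/images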
 
+ # Configuration for models - use sets for faster lookups
+ mistral_models = {
+     "mistral-large-latest",
+     "pixtral-large-latest",
+     "mistral-moderation-latest",
+     "ministral-3b-latest",
+     "ministral-8b-latest",
+     "open-mistral-nemo",
+     "mistral-small-latest",
+     "mistral-saba-latest",
+     "codestral-latest"
  }
 
+ pollinations_models = {
+     "openai",
+     "openai-large",
+     "openai-xlarge",
+     "openai-reasoning",
+     "qwen-coder",
+     "llama",
+     "mistral",
+     "searchgpt",
+     "deepseek",
+     "claude-hybridspace",
+     "deepseek-r1",
+     "deepseek-reasoner",
+     "llamalight",
+     "gemini",
+     "gemini-thinking",
+     "hormoz",
+     "phi",
+     "phi-mini",
+     "openai-audio",
+     "llama-scaleway"
  }
 
+ alternate_models = {  # heh, should work now
+     "gpt-4o",
+     "deepseek-v3",
+     "llama-3.1-8b-instruct",
+     "llama-3.1-sonar-small-128k-online",
+     "deepseek-r1-uncensored",
+     "tinyswallow1.5b",
+     "andy-3.5",
+     "o3-mini-low",
+     "hermes-3-llama-3.2-3b",
+     "creitin-r1",
+     "fluffy.1-chat",
+     "plutotext-1-text",
+     "command-a",
+     "claude-3-7-sonnet-20250219",
+     "plutogpt-3.5-turbo"
  }
 
+ claude_3_models = {  # Models for the new endpoint
+     "claude-3-7-sonnet",
+     "claude-3-7-sonnet-thinking",
+     "claude 3.5 haiku",
+     "claude 3.5 sonnet",
+     "o3-mini-medium",
+     "o3-mini-high",
+     "grok-3",
+     "grok-3-thinking",
+     "grok 2"
  }
 
+ # Supported image generation models
+ supported_image_models = {
+     "Flux Pro Ultra",
+     "grok-2-aurora",
+     "Flux Pro",
+     "Flux Pro Ultra Raw",
+     "Flux Dev",
+     "Flux Schnell",
+     "stable-diffusion-3-large-turbo",
+     "Flux Realism",
+     "stable-diffusion-ultra",
+     "dall-e-3",
+     "sdxl-lightning-4step"
  }
 
+ # Request payload model
  class Payload(BaseModel):
      model: str
+     messages: list
      stream: bool = False
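
On the wire this model accepts the familiar OpenAI-style body; a minimal example (model id illustrative):

{
    "model": "gpt-4o",
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": true
}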
 
 
 
 
 
+ # Image generation payload model
  class ImageGenerationPayload(BaseModel):
      model: str
      prompt: str
+     size: int
+     number: int
 
+ # Server status global variable
+ server_status = True
+ available_model_ids: List[str] = []
+ 
+ # Create a reusable httpx client with connection pooling
  @lru_cache(maxsize=1)
+ def get_async_client():
+     return httpx.AsyncClient(
+         timeout=60.0,
+         limits=httpx.Limits(max_keepalive_connections=50, max_connections=200)  # Increased limits
+     )
+ 
+ # Create a cloudscraper pool
+ scraper_pool = []
+ MAX_SCRAPERS = 20  # Increased pool size
 
+ def get_scraper():
      if not scraper_pool:
          for _ in range(MAX_SCRAPERS):
              scraper_pool.append(cloudscraper.create_scraper())
+     return scraper_pool[int(time.time() * 1000) % MAX_SCRAPERS]  # Simple time-based round-robin
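
The time-based index can hand two near-simultaneous callers the same scraper. If that ever matters, a locked itertools.cycle is a simple alternative; a sketch only, not part of this commit (get_scraper_cycled is a hypothetical name):

import itertools
import threading

_scraper_cycle = None
_cycle_lock = threading.Lock()

def get_scraper_cycled():
    # Deterministic round-robin over the same module-level scraper_pool.
    global _scraper_cycle
    with _cycle_lock:
        if _scraper_cycle is None:
            if not scraper_pool:
                for _ in range(MAX_SCRAPERS):
                    scraper_pool.append(cloudscraper.create_scraper())
            _scraper_cycle = itertools.cycle(scraper_pool)
        return next(_scraper_cycle)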
 
+ # API key validation - optimized to avoid string operations when possible
  async def verify_api_key(
      request: Request,
+     api_key: str = Security(api_key_header)
  ) -> bool:
+     # Allow bypass if the referer is from /playground or /image-playground
      referer = request.headers.get("referer", "")
+     if referer.startswith(("https://parthsadaria-lokiai.hf.space/playground",
+                            "https://parthsadaria-lokiai.hf.space/image-playground")):
          return True
+ 
      if not api_key:
+         raise HTTPException(
+             status_code=HTTP_403_FORBIDDEN,
+             detail="No API key provided"
+         )
+ 
+     # Only clean if needed
      if api_key.startswith('Bearer '):
+         api_key = api_key[7:]  # Remove 'Bearer ' prefix
+ 
+     # Get API keys from environment
+     valid_api_keys = get_env_vars().get('api_keys', [])
+     if not valid_api_keys or valid_api_keys == ['']:
+         raise HTTPException(
+             status_code=HTTP_403_FORBIDDEN,
+             detail="API keys not configured on server"
+         )
+ 
+     # Fast check with set operation
+     if api_key not in set(valid_api_keys):
+         raise HTTPException(
+             status_code=HTTP_403_FORBIDDEN,
+             detail="Invalid API key"
+         )
+ 
      return True
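
From the outside, the dependency simply means requests need a Bearer token unless they come from the playground pages. A quick check with httpx (the key shown is a placeholder):

import httpx

BASE = "https://parthsadaria-lokiai.hf.space"
body = {"model": "gpt-4o", "messages": [{"role": "user", "content": "hi"}]}

# Without a key the server should answer 403 ("No API key provided")
print(httpx.post(f"{BASE}/chat/completions", json=body).status_code)

# With a configured key the request passes verification
resp = httpx.post(f"{BASE}/chat/completions", json=body,
                  headers={"Authorization": "Bearer my-key"})  # placeholder key
print(resp.status_code)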
 
+ # Pre-load and cache models.json
  @lru_cache(maxsize=1)
+ def load_models_data():
      try:
+         file_path = Path(__file__).parent / 'models.json'
+         with open(file_path, 'r') as f:
              return json.load(f)
      except (FileNotFoundError, json.JSONDecodeError) as e:
+         print(f"Error loading models.json: {str(e)}")
          return []
 
+ # Async wrapper for models data
+ async def get_models():
      models_data = load_models_data()
      if not models_data:
          raise HTTPException(status_code=500, detail="Error loading available models")
      return models_data
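
The loader only assumes each entry carries an `id` field; a models.json shaped like this (entries illustrative) would satisfy it:

[
    {"id": "gpt-4o", "object": "model", "owned_by": "openai"},
    {"id": "mistral-small-latest", "object": "model", "owned_by": "mistralai"}
]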
 
+ # Enhanced async streaming - now with real-time SSE support
+ async def generate_search_async(query: str, systemprompt: Optional[str] = None, stream: bool = True):
+     # Create a streaming response channel using asyncio.Queue
+     queue = asyncio.Queue()
+ 
+     async def _fetch_search_data():
+         try:
+             headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
+ 
+             # Use the provided system prompt, or default to "Be Helpful and Friendly"
+             system_message = systemprompt or "Be Helpful and Friendly"
+ 
+             # Create the prompt history
+             prompt = [
+                 {"role": "user", "content": query},
+             ]
+             prompt.insert(0, {"content": system_message, "role": "system"})
+ 
+             # Prepare the payload for the API request
+             payload = {
+                 "is_vscode_extension": True,
+                 "message_history": prompt,
+                 "requested_model": "searchgpt",
+                 "user_input": prompt[-1]["content"],
+             }
+ 
+             # Get endpoint from environment
+             secret_api_endpoint_3 = get_env_vars()['secret_api_endpoint_3']
+             if not secret_api_endpoint_3:
+                 await queue.put({"error": "Search API endpoint not configured"})
+                 return
+ 
+             # Use AsyncClient for better performance
+             async with httpx.AsyncClient(timeout=30.0) as client:
+                 async with client.stream("POST", secret_api_endpoint_3, json=payload, headers=headers) as response:
+                     if response.status_code != 200:
+                         await queue.put({"error": f"Search API returned status code {response.status_code}"})
+                         return
+ 
+                     # Process the streaming response in real time
+                     async for line in response.aiter_lines():
+                         if line.startswith("data: "):
+                             try:
+                                 json_data = json.loads(line[6:])
+                                 content = json_data.get("choices", [{}])[0].get("delta", {}).get("content", "")
+                                 if content.strip():
+                                     cleaned_response = {
+                                         "created": json_data.get("created"),
+                                         "id": json_data.get("id"),
+                                         "model": "searchgpt",
+                                         "object": "chat.completion",
+                                         "choices": [{"message": {"content": content}}]
+                                     }
+                                     # Send to queue immediately for streaming
+                                     await queue.put({"data": f"data: {json.dumps(cleaned_response)}\n\n", "text": content})
+                             except json.JSONDecodeError:
+                                 continue
+ 
+             # Signal completion
+             await queue.put(None)
+ 
+         except Exception as e:
+             await queue.put({"error": str(e)})
+             await queue.put(None)
+ 
+     # Start the fetch process
+     asyncio.create_task(_fetch_search_data())
+ 
+     # Return the queue for consumption
+     return queue
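
Because the producer task signals completion with a None sentinel, any consumer just drains the queue until it sees it; a minimal collector, mirroring what the /searchgpt route does below:

async def collect_search(query: str) -> str:
    # Drain the queue filled by generate_search_async until the None sentinel.
    queue = await generate_search_async(query)
    parts = []
    while True:
        item = await queue.get()
        if item is None:      # producer signals completion
            break
        if "error" in item:
            raise RuntimeError(item["error"])
        parts.append(item.get("text", ""))
    return "".join(parts)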
 
+ # Cache for frequently accessed static files
  @lru_cache(maxsize=10)
+ def read_html_file(file_path):
      try:
+         with open(file_path, "r") as file:
              return file.read()
+     except FileNotFoundError:
          return None
 
+ # Basic routes
+ @app.get("/favicon.ico")
  async def favicon():
      favicon_path = Path(__file__).parent / "favicon.ico"
+     return FileResponse(favicon_path, media_type="image/x-icon")
 
+ @app.get("/banner.jpg")
+ async def banner():
+     banner_path = Path(__file__).parent / "banner.jpg"
+     return FileResponse(banner_path, media_type="image/jpeg")
 
+ @app.get("/ping")
+ async def ping():
+     # response_time here is a static placeholder, not a measured value
+     return {"message": "pong", "response_time": "0.000000 seconds"}
 
+ @app.get("/", response_class=HTMLResponse)
+ async def serve_index():
+     html_content = read_html_file("index.html")
+     if html_content is None:
+         return HTMLResponse(content="<h1>File not found</h1>", status_code=404)
+     return HTMLResponse(content=html_content)
+ 
+ @app.get("/script.js", response_class=HTMLResponse)
+ async def serve_script():
+     html_content = read_html_file("script.js")
+     if html_content is None:
+         return HTMLResponse(content="<h1>File not found</h1>", status_code=404)
+     return HTMLResponse(content=html_content)
+ 
+ @app.get("/style.css", response_class=HTMLResponse)
+ async def serve_style():
+     html_content = read_html_file("style.css")
+     if html_content is None:
+         return HTMLResponse(content="<h1>File not found</h1>", status_code=404)
+     return HTMLResponse(content=html_content)
+ @app.get("/dynamo", response_class=HTMLResponse)
371
  async def dynamic_ai_page(request: Request):
372
+ user_agent = request.headers.get('user-agent', 'Unknown User')
373
+ client_ip = request.client.host
374
+ location = f"IP: {client_ip}"
375
+
 
 
 
 
 
 
 
 
376
  prompt = f"""
377
+ Generate a dynamic HTML page for a user with the following details: with name "LOKI.AI"
 
378
  - User-Agent: {user_agent}
379
+ - Location: {location}
380
+ - Style: Cyberpunk, minimalist, or retro
381
+
382
+ Make sure the HTML is clean and includes a heading, also have cool animations a motivational message, and a cool background.
383
+ Wrap the generated HTML in triple backticks (```).
384
  """
385
+
386
  payload = {
387
+ "model": "mistral-small-latest",
388
+ "messages": [{"role": "user", "content": prompt}]
 
 
389
  }
390
+
391
  headers = {
392
+ "Authorization": "Bearer playground"
 
 
393
  }
394
+
395
+ response = requests.post("https://parthsadaria-lokiai.hf.space/chat/completions", json=payload, headers=headers)
396
+ data = response.json()
397
+
398
+ # Extract HTML from ``` blocks
399
+ html_content = re.search(r"```(.*?)```", data['choices'][0]['message']['content'], re.DOTALL)
400
+ if html_content:
401
+ html_content = html_content.group(1).strip()
402
+
403
+ # Remove the first word
404
+ if html_content:
405
+ html_content = ' '.join(html_content.split(' ')[1:])
406
+
407
+ return HTMLResponse(content=html_content)
408
+
409
+ @app.get("/playground", response_class=HTMLResponse)
410
+ async def playground():
411
+ html_content = read_html_file("playground.html")
412
+ if html_content is None:
413
+ return HTMLResponse(content="<h1>playground.html not found</h1>", status_code=404)
414
+ return HTMLResponse(content=html_content)
415
+
416
+ @app.get("/image-playground", response_class=HTMLResponse)
417
+ async def playground():
418
+ html_content = read_html_file("image-playground.html")
419
+ if html_content is None:
420
+ return HTMLResponse(content="<h1>image-playground.html not found</h1>", status_code=404)
421
+ return HTMLResponse(content=html_content)
422
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ # VETRA
  GITHUB_BASE = "https://raw.githubusercontent.com/Parthsadaria/Vetra/main"
 
+ FILES = {
+     "html": "index.html",
+     "css": "style.css",
+     "js": "script.js"
+ }
+ 
+ async def get_github_file(filename: str) -> Optional[str]:
      url = f"{GITHUB_BASE}/{filename}"
+     async with httpx.AsyncClient() as client:
          res = await client.get(url)
+     return res.text if res.status_code == 200 else None
 
+ @app.get("/vetra", response_class=HTMLResponse)
  async def serve_vetra():
+     html = await get_github_file(FILES["html"])
+     css = await get_github_file(FILES["css"])
+     js = await get_github_file(FILES["js"])
 
      if not html:
+         return HTMLResponse(content="<h1>index.html not found on GitHub</h1>", status_code=404)
 
+     final_html = html.replace(
+         "</head>",
+         f"<style>{css or '/* CSS not found */'}</style></head>"
+     ).replace(
+         "</body>",
+         f"<script>{js or '// JS not found'}</script></body>"
+     )
 
      return HTMLResponse(content=final_html)
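
The three GitHub fetches above run one after another; since they are independent, an asyncio.gather variant would overlap them. A sketch only, not part of this commit:

import asyncio

async def fetch_vetra_assets():
    # Fetch the three assets concurrently instead of sequentially.
    return await asyncio.gather(
        get_github_file(FILES["html"]),
        get_github_file(FILES["css"]),
        get_github_file(FILES["js"]),
    )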
 
+ # Model routes
+ @app.get("/api/v1/models")
+ @app.get("/models")
+ async def return_models():
+     return await get_models()
 
+ # Search routes with enhanced real-time streaming
+ @app.get("/searchgpt")
+ async def search_gpt(q: str, stream: Optional[bool] = False, systemprompt: Optional[str] = None):
      if not q:
          raise HTTPException(status_code=400, detail="Query parameter 'q' is required")
 
      usage_tracker.record_request(endpoint="/searchgpt")
 
+     queue = await generate_search_async(q, systemprompt=systemprompt, stream=True)
 
      if stream:
          async def stream_generator():
+             collected_text = ""
              while True:
                  item = await queue.get()
+                 if item is None:
                      break
+                 if "error" in item:
+                     yield f"data: {json.dumps({'error': item['error']})}\n\n"
                      break
+                 if "data" in item:
                      yield item["data"]
+                     collected_text += item.get("text", "")
 
          return StreamingResponse(
              stream_generator(),
+             media_type="text/event-stream"
          )
      else:
+         # For non-streaming, collect all text and return at once
+         collected_text = ""
          while True:
              item = await queue.get()
              if item is None:
                  break
+             if "error" in item:
+                 raise HTTPException(status_code=500, detail=item["error"])
+             collected_text += item.get("text", "")
 
+         return JSONResponse(content={"response": collected_text})
 
+ # Enhanced streaming with direct SSE pass-through for real-time responses
+ header_url = os.getenv('HEADER_URL')
+ @app.post("/chat/completions")
+ @app.post("/api/v1/chat/completions")
+ async def get_completion(payload: Payload, request: Request, authenticated: bool = Depends(verify_api_key)):
+     # Check server status
+     if not server_status:
+         return JSONResponse(
+             status_code=503,
+             content={"message": "Server is under maintenance. Please try again later."}
+         )
+ 
+     model_to_use = payload.model or "gpt-4o-mini"
 
+     # Validate model availability - fast lookup with set
+     if available_model_ids and model_to_use not in set(available_model_ids):
+         raise HTTPException(
+             status_code=400,
+             detail=f"Model '{model_to_use}' is not available. Check /models for the available model list."
+         )
+ 
+     # Log request without blocking
      asyncio.create_task(log_request(request, model_to_use))
      usage_tracker.record_request(model=model_to_use, endpoint="/chat/completions")
 
+     # Prepare payload
+     payload_dict = payload.dict()
+     payload_dict["model"] = model_to_use
 
+     # Whether to stream back to the client; the Payload model supplies stream=False when omitted
+     stream_enabled = payload_dict.get("stream", True)
 
+     # Get environment variables
+     env_vars = get_env_vars()
 
+     # Select the appropriate endpoint (fast lookup with sets)
      if model_to_use in mistral_models:
+         endpoint = env_vars['mistral_api']
+         custom_headers = {
+             "Authorization": f"Bearer {env_vars['mistral_key']}"
+         }
      elif model_to_use in pollinations_models:
+         endpoint = env_vars['secret_api_endpoint_4']
+         custom_headers = {}
      elif model_to_use in alternate_models:
+         endpoint = env_vars['secret_api_endpoint_2']
+         custom_headers = {}
+     elif model_to_use in claude_3_models:  # Use the new endpoint
+         endpoint = env_vars['secret_api_endpoint_5']
+         custom_headers = {}
+     else:
+         endpoint = env_vars['secret_api_endpoint']
+         custom_headers = {
+             "Origin": header_url,
+             "Priority": "u=1, i",
+             "Referer": header_url
+         }
 
+     print(f"Using endpoint: {endpoint} for model: {model_to_use}")
 
+     # Improved real-time streaming handler
+     async def real_time_stream_generator():
          try:
+             async with httpx.AsyncClient(timeout=60.0) as client:
+                 async with client.stream("POST", f"{endpoint}/v1/chat/completions", json=payload_dict, headers=custom_headers) as response:
+                     if response.status_code >= 400:
+                         error_messages = {
+                             422: "Unprocessable entity. Check your payload.",
+                             400: "Bad request. Verify input data.",
+                             403: "Forbidden. You do not have access to this resource.",
+                             404: "The requested resource was not found.",
+                         }
+                         detail = error_messages.get(response.status_code, f"Error code: {response.status_code}")
+                         raise HTTPException(status_code=response.status_code, detail=detail)
+ 
+                     # Stream the response in real time with minimal buffering
+                     async for line in response.aiter_lines():
+                         if line:
+                             # Yield immediately for faster streaming
+                             yield line + "\n"
          except httpx.TimeoutException:
+             raise HTTPException(status_code=504, detail="Request timed out")
          except httpx.RequestError as e:
+             raise HTTPException(status_code=502, detail=f"Failed to connect to upstream API: {str(e)}")
          except Exception as e:
+             if isinstance(e, HTTPException):
+                 raise e
+             raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
 
+     # Return streaming response with proper headers
+     if stream_enabled:
          return StreamingResponse(
+             real_time_stream_generator(),
              media_type="text/event-stream",
              headers={
                  "Content-Type": "text/event-stream",
                  "Cache-Control": "no-cache",
                  "Connection": "keep-alive",
+                 "X-Accel-Buffering": "no"  # Disable proxy buffering for Nginx
              }
          )
      else:
+         # For non-streaming requests, collect the entire response body
+         # (assumes the upstream returns plain JSON when stream is disabled)
+         response_content = []
+         async for chunk in real_time_stream_generator():
+             response_content.append(chunk)
 
+         return JSONResponse(content=json.loads(''.join(response_content)))
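
Consuming the endpoint then looks like any OpenAI-compatible SSE stream; one way from Python (URL and key are placeholders):

import httpx

def stream_chat(prompt: str) -> None:
    payload = {
        "model": "gpt-4o",  # any id listed by /models
        "messages": [{"role": "user", "content": prompt}],
        "stream": True,
    }
    headers = {"Authorization": "Bearer my-key"}  # placeholder key
    with httpx.stream(
        "POST",
        "https://parthsadaria-lokiai.hf.space/chat/completions",
        json=payload,
        headers=headers,
        timeout=60.0,
    ) as response:
        for line in response.iter_lines():
            if line:
                print(line)  # raw SSE lines, e.g. 'data: {...}'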
 
+ # New image generation endpoint
+ @app.post("/images/generations")
+ async def create_image(payload: ImageGenerationPayload, authenticated: bool = Depends(verify_api_key)):
      """
+     Endpoint for generating images based on a text prompt.
      """
+     # Check server status
      if not server_status:
+         return JSONResponse(
+             status_code=503,
+             content={"message": "Server is under maintenance. Please try again later."}
+         )
 
+     # Validate model
      if payload.model not in supported_image_models:
          raise HTTPException(
              status_code=400,
+             detail=f"Model '{payload.model}' is not supported for image generation. Supported models are: {supported_image_models}"
          )
 
+     # Log the request
      usage_tracker.record_request(model=payload.model, endpoint="/images/generations")
 
+     # Prepare the payload for the external API
      api_payload = {
          "model": payload.model,
          "prompt": payload.prompt,
+         "size": payload.size,
+         "number": payload.number
      }
 
+     # Target API endpoint
+     target_api_url = os.getenv('NEW_IMG')
 
      try:
+         # Use a timeout for the image generation request
+         async with httpx.AsyncClient(timeout=60.0) as client:
+             response = await client.post(target_api_url, json=api_payload)
+ 
+             if response.status_code != 200:
+                 error_detail = response.json().get("detail", f"Image generation failed with status code: {response.status_code}")
+                 raise HTTPException(status_code=response.status_code, detail=error_detail)
 
+             # Return the response from the external API
              return JSONResponse(content=response.json())
 
+     except HTTPException:
+         # Let the explicit HTTP error above propagate instead of masking it as a 500
+         raise
      except httpx.TimeoutException:
          raise HTTPException(status_code=504, detail="Image generation request timed out.")
      except httpx.RequestError as e:
          raise HTTPException(status_code=502, detail=f"Error connecting to image generation service: {e}")
      except Exception as e:
          raise HTTPException(status_code=500, detail=f"An unexpected error occurred during image generation: {e}")
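
A request matching ImageGenerationPayload would look like this (model, prompt, and key are illustrative; size is an int per the model):

import httpx

resp = httpx.post(
    "https://parthsadaria-lokiai.hf.space/images/generations",
    headers={"Authorization": "Bearer my-key"},  # placeholder key
    json={
        "model": "Flux Pro",
        "prompt": "a lighthouse at dawn, oil painting",
        "size": 1024,   # int, per ImageGenerationPayload
        "number": 1,
    },
    timeout=60.0,
)
print(resp.json())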
 
+ # Asynchronous logging function
+ async def log_request(request, model):
+     # Get minimal data for logging (time shifted to IST, UTC+5:30)
+     current_time = (datetime.datetime.utcnow() + datetime.timedelta(hours=5, minutes=30)).strftime("%Y-%m-%d %I:%M:%S %p")
+     ip_hash = hash(request.client.host) % 10000  # Hash the IP for privacy
+     print(f"Time: {current_time}, IP Hash: {ip_hash}, Model: {model}")
 
+ # Cache usage statistics
+ @lru_cache(maxsize=10)
+ def get_usage_summary(days=7):
+     return usage_tracker.get_usage_summary(days)
 
+ @app.get("/usage")
  async def get_usage(days: int = 7):
+     """Retrieve usage statistics"""
+     return get_usage_summary(days)
 
+ # Generate HTML for usage page
+ def generate_usage_html(usage_data):
+     # Model usage table rows
      model_usage_rows = "\n".join([
          f"""
          <tr>
              <td>{model}</td>
+             <td>{model_data['total_requests']}</td>
+             <td>{model_data['first_used']}</td>
+             <td>{model_data['last_used']}</td>
          </tr>
+         """ for model, model_data in usage_data['models'].items()
+     ])
 
+     # API endpoint usage table rows
      api_usage_rows = "\n".join([
          f"""
          <tr>
              <td>{endpoint}</td>
+             <td>{endpoint_data['total_requests']}</td>
+             <td>{endpoint_data['first_used']}</td>
+             <td>{endpoint_data['last_used']}</td>
          </tr>
+         """ for endpoint, endpoint_data in usage_data['api_endpoints'].items()
+     ])
 
+     # Daily usage table rows
      daily_usage_rows = "\n".join([
+         "\n".join([
+             f"""
+             <tr>
+                 <td>{date}</td>
+                 <td>{entity}</td>
+                 <td>{requests}</td>
+             </tr>
+             """ for entity, requests in date_data.items()
+         ]) for date, date_data in usage_data['recent_daily_usage'].items()
+     ])
 
      html_content = f"""
      <!DOCTYPE html>
      <html lang="en">
      <head>
          <meta charset="UTF-8">
          <title>Lokiai AI - Usage Statistics</title>
          <link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;600&display=swap" rel="stylesheet">
          <style>
              :root {{
+                 --bg-dark: #0f1011;
+                 --bg-darker: #070708;
+                 --text-primary: #e6e6e6;
+                 --text-secondary: #8c8c8c;
+                 --border-color: #2c2c2c;
+                 --accent-color: #3a6ee0;
                  --accent-hover: #4a7ef0;
              }}
+             body {{
+                 font-family: 'Inter', sans-serif;
+                 background-color: var(--bg-dark);
+                 color: var(--text-primary);
+                 max-width: 1200px;
+                 margin: 0 auto;
+                 padding: 40px 20px;
+                 line-height: 1.6;
+             }}
+             .logo {{
+                 display: flex;
+                 align-items: center;
+                 justify-content: center;
+                 margin-bottom: 30px;
+             }}
+             .logo h1 {{
+                 font-weight: 600;
+                 font-size: 2.5em;
+                 color: var(--text-primary);
+                 margin-left: 15px;
+             }}
+             .logo img {{
+                 width: 60px;
+                 height: 60px;
+                 border-radius: 10px;
+             }}
+             .container {{
+                 background-color: var(--bg-darker);
+                 border-radius: 12px;
+                 padding: 30px;
+                 box-shadow: 0 15px 40px rgba(0,0,0,0.3);
+                 border: 1px solid var(--border-color);
+             }}
+             h2, h3 {{
+                 color: var(--text-primary);
+                 border-bottom: 2px solid var(--border-color);
+                 padding-bottom: 10px;
+                 font-weight: 500;
+             }}
+             .total-requests {{
+                 background-color: var(--accent-color);
+                 color: white;
+                 text-align: center;
+                 padding: 15px;
+                 border-radius: 8px;
+                 margin-bottom: 30px;
+                 font-weight: 600;
+                 letter-spacing: -0.5px;
+             }}
+             table {{
+                 width: 100%;
+                 border-collapse: separate;
+                 border-spacing: 0;
+                 margin-bottom: 30px;
+                 background-color: var(--bg-dark);
+                 border-radius: 8px;
+                 overflow: hidden;
+             }}
+             th, td {{
+                 border: 1px solid var(--border-color);
+                 padding: 12px;
+                 text-align: left;
+                 transition: background-color 0.3s ease;
+             }}
+             th {{
+                 background-color: #1e1e1e;
+                 color: var(--text-primary);
+                 font-weight: 600;
+                 text-transform: uppercase;
+                 font-size: 0.9em;
+             }}
+             tr:nth-child(even) {{
+                 background-color: rgba(255,255,255,0.05);
+             }}
+             tr:hover {{
+                 background-color: rgba(62,100,255,0.1);
+             }}
+             @media (max-width: 768px) {{
+                 .container {{
+                     padding: 15px;
+                 }}
+                 table {{
+                     font-size: 0.9em;
+                 }}
+             }}
          </style>
      </head>
      <body>
          <div class="container">
              <div class="logo">
                  <img src="data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iMjAwIiBoZWlnaHQ9IjIwMCIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj48cGF0aCBkPSJNMTAwIDM1TDUwIDkwaDEwMHoiIGZpbGw9IiMzYTZlZTAiLz48Y2lyY2xlIGN4PSIxMDAiIGN5PSIxNDAiIHI9IjMwIiBmaWxsPSIjM2E2ZWUwIi8+PC9zdmc+" alt="Lokai AI Logo">
+                 <h1>Lokiai AI</h1>
              </div>
 
              <div class="total-requests">
+                 Total API Requests: {usage_data['total_requests']}
              </div>
 
              <h2>Model Usage</h2>
              <table>
+                 <tr>
+                     <th>Model</th>
+                     <th>Total Requests</th>
+                     <th>First Used</th>
+                     <th>Last Used</th>
+                 </tr>
+                 {model_usage_rows}
              </table>
 
              <h2>API Endpoint Usage</h2>
              <table>
+                 <tr>
+                     <th>Endpoint</th>
+                     <th>Total Requests</th>
+                     <th>First Used</th>
+                     <th>Last Used</th>
+                 </tr>
+                 {api_usage_rows}
              </table>
 
+             <h2>Daily Usage (Last 7 Days)</h2>
              <table>
+                 <tr>
+                     <th>Date</th>
+                     <th>Entity</th>
+                     <th>Requests</th>
+                 </tr>
+                 {daily_usage_rows}
              </table>
          </div>
      </body>
      </html>
      """
      return html_content
 
+ # Cache the usage page HTML
+ @lru_cache(maxsize=1)
+ def get_usage_page_html():
+     usage_data = get_usage_summary()
+     return generate_usage_html(usage_data)
 
+ @app.get("/usage/page", response_class=HTMLResponse)
  async def usage_page():
+     """Serve an HTML page showing usage statistics"""
+     # Note: lru_cache means this HTML is generated once per process lifetime
+     html_content = get_usage_page_html()
+     return HTMLResponse(content=html_content)
 
+ # Meme endpoint with optimized networking
+ @app.get("/meme")
  async def get_meme():
      try:
+         # Use the shared client for connection pooling
+         client = get_async_client()
+         response = await client.get("https://meme-api.com/gimme")
          response_data = response.json()
 
          meme_url = response_data.get("url")
+         if not meme_url:
+             raise HTTPException(status_code=404, detail="No meme found")
 
+         image_response = await client.get(meme_url, follow_redirects=True)
 
+         # Coalesce small chunks into larger ones before yielding
+         async def stream_with_larger_chunks():
+             chunks = []
+             size = 0
+             async for chunk in image_response.aiter_bytes(chunk_size=16384):
+                 chunks.append(chunk)
+                 size += len(chunk)
+                 if size >= 65536:
+                     yield b''.join(chunks)
+                     chunks = []
+                     size = 0
 
+             if chunks:
+                 yield b''.join(chunks)
 
+         return StreamingResponse(
+             stream_with_larger_chunks(),
+             media_type=image_response.headers.get("content-type", "image/png"),
+             headers={'Cache-Control': 'max-age=3600'}  # Add caching
+         )
+     except HTTPException:
+         # Let the explicit 404 above propagate instead of masking it as a 500
+         raise
+     except Exception:
+         raise HTTPException(status_code=500, detail="Failed to retrieve meme")
 
+ # Utility function for loading model IDs - run once at startup
+ def load_model_ids(json_file_path):
+     try:
+         with open(json_file_path, 'r') as f:
+             models_data = json.load(f)
+         # Extract the 'id' from each model object (deduplicated later at startup)
+         return [model['id'] for model in models_data if 'id' in model]
+     except Exception as e:
+         print(f"Error loading model IDs: {str(e)}")
+         return []
 
  @app.on_event("startup")
  async def startup_event():
      global available_model_ids
+     available_model_ids = load_model_ids("models.json")
+     print(f"Loaded {len(available_model_ids)} model IDs")
+ 
+     # Add all pollinations models to available_model_ids
+     available_model_ids.extend(list(pollinations_models))
+     # Add alternate models to available_model_ids
+     available_model_ids.extend(list(alternate_models))
+     # Add mistral models to available_model_ids
+     available_model_ids.extend(list(mistral_models))
+     # Add claude models
+     available_model_ids.extend(list(claude_3_models))
+ 
+     available_model_ids = list(set(available_model_ids))  # Remove duplicates
+     print(f"Total available models: {len(available_model_ids)}")
+ 
+     # Preload scrapers
+     for _ in range(MAX_SCRAPERS):
+         scraper_pool.append(cloudscraper.create_scraper())
+ 
+     # Validate critical environment variables
      env_vars = get_env_vars()
+     missing_vars = []
+ 
+     if not env_vars['api_keys'] or env_vars['api_keys'] == ['']:
+         missing_vars.append('API_KEYS')
+     if not env_vars['secret_api_endpoint']:
+         missing_vars.append('SECRET_API_ENDPOINT')
+     if not env_vars['secret_api_endpoint_2']:
+         missing_vars.append('SECRET_API_ENDPOINT_2')
+     if not env_vars['secret_api_endpoint_3']:
+         missing_vars.append('SECRET_API_ENDPOINT_3')
+     if not env_vars['secret_api_endpoint_4']:
+         missing_vars.append('SECRET_API_ENDPOINT_4')
+     if not env_vars['secret_api_endpoint_5']:  # Check the new endpoint
+         missing_vars.append('SECRET_API_ENDPOINT_5')
+     if not env_vars['mistral_api'] and any(model in mistral_models for model in available_model_ids):
+         missing_vars.append('MISTRAL_API')
+     if not env_vars['mistral_key'] and any(model in mistral_models for model in available_model_ids):
+         missing_vars.append('MISTRAL_KEY')
+ 
+     if missing_vars:
+         print(f"WARNING: The following environment variables are missing: {', '.join(missing_vars)}")
+         print("Some functionality may be limited.")
+ 
+     print("Server started successfully!")
 
  @app.on_event("shutdown")
  async def shutdown_event():
+     # Close the shared httpx client
      client = get_async_client()
      await client.aclose()
 
+     # Clear scraper pool
      scraper_pool.clear()
 
      # Persist usage data
+     usage_tracker.save_data()
 
+     print("Server shutdown complete!")
 
+ # Health check endpoint
+ @app.get("/health")
+ async def health_check():
+     """Health check endpoint for monitoring"""
+     env_vars = get_env_vars()
+     missing_critical_vars = []
+ 
+     # Check critical environment variables
+     if not env_vars['api_keys'] or env_vars['api_keys'] == ['']:
+         missing_critical_vars.append('API_KEYS')
+     if not env_vars['secret_api_endpoint']:
+         missing_critical_vars.append('SECRET_API_ENDPOINT')
+     if not env_vars['secret_api_endpoint_2']:
+         missing_critical_vars.append('SECRET_API_ENDPOINT_2')
+     if not env_vars['secret_api_endpoint_3']:
+         missing_critical_vars.append('SECRET_API_ENDPOINT_3')
+     if not env_vars['secret_api_endpoint_4']:
+         missing_critical_vars.append('SECRET_API_ENDPOINT_4')
+     if not env_vars['secret_api_endpoint_5']:  # Check the new endpoint
+         missing_critical_vars.append('SECRET_API_ENDPOINT_5')
+     if not env_vars['mistral_api']:
+         missing_critical_vars.append('MISTRAL_API')
+     if not env_vars['mistral_key']:
+         missing_critical_vars.append('MISTRAL_KEY')
+ 
+     health_status = {
+         "status": "healthy" if not missing_critical_vars else "unhealthy",
+         "missing_env_vars": missing_critical_vars,
+         "server_status": server_status,
+         "message": "Everything's lit! 🚀" if not missing_critical_vars else "Uh oh, some env vars are missing. 😬"
+     }
+     return JSONResponse(content=health_status)
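
A monitor only needs to poll the route and inspect the `status` field; for example:

import httpx

health = httpx.get("https://parthsadaria-lokiai.hf.space/health").json()
assert health["status"] == "healthy", health["missing_env_vars"]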
 
  if __name__ == "__main__":
      import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=7860)