ParthSadaria commited on
Commit
74e8abb
·
verified ·
1 Parent(s): 342885c

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +960 -757
main.py CHANGED
@@ -1,886 +1,1056 @@
1
  import os
2
  import re
3
- from dotenv import load_dotenv
4
- from fastapi import FastAPI, HTTPException, Request, Depends, Security
5
- from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse, FileResponse
6
- from fastapi.security import APIKeyHeader
7
- from pydantic import BaseModel
8
- import httpx
9
- from functools import lru_cache
10
- from pathlib import Path
11
  import json
12
  import datetime
13
  import time
14
- import threading
15
- from typing import Optional, Dict, List, Any, Generator
16
  import asyncio
17
- from starlette.status import HTTP_403_FORBIDDEN
18
- import cloudscraper
 
 
19
  from concurrent.futures import ThreadPoolExecutor
20
- import uvloop
 
 
 
 
 
 
 
 
21
  from fastapi.middleware.gzip import GZipMiddleware
22
  from starlette.middleware.cors import CORSMiddleware
23
- import contextlib
24
- import requests
25
- # Enable uvloop for faster event loop
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
27
 
28
- # Thread pool for CPU-bound operations
29
- executor = ThreadPoolExecutor(max_workers=16) # Increased thread count for better parallelism
 
30
 
31
- # Load environment variables once at startup
 
32
  load_dotenv()
33
 
34
- # API key security scheme
 
 
 
 
35
  api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
36
 
37
- # Initialize usage tracker
38
- from usage_tracker import UsageTracker
39
- usage_tracker = UsageTracker()
40
 
41
- app = FastAPI()
 
 
 
 
42
 
43
- # Add middleware for compression and CORS
44
- app.add_middleware(GZipMiddleware, minimum_size=1000)
45
  app.add_middleware(
46
- CORSMiddleware,
47
- allow_origins=["*"],
48
  allow_credentials=True,
49
  allow_methods=["*"],
50
  allow_headers=["*"],
51
  )
52
 
53
- # Environment variables (cached)
54
- @lru_cache(maxsize=1)
55
- def get_env_vars():
 
 
 
56
  return {
57
- 'api_keys': os.getenv('API_KEYS', '').split(','),
58
  'secret_api_endpoint': os.getenv('SECRET_API_ENDPOINT'),
59
  'secret_api_endpoint_2': os.getenv('SECRET_API_ENDPOINT_2'),
60
- 'secret_api_endpoint_3': os.getenv('SECRET_API_ENDPOINT_3'),
61
- 'secret_api_endpoint_4': "https://text.pollinations.ai/openai",
62
- 'secret_api_endpoint_5': os.getenv('SECRET_API_ENDPOINT_5'), # Added new endpoint
63
- 'mistral_api': "https://api.mistral.ai",
64
  'mistral_key': os.getenv('MISTRAL_KEY'),
65
- 'endpoint_origin': os.getenv('ENDPOINT_ORIGIN')
 
66
  }
67
 
68
- # Configuration for models - use sets for faster lookups
69
- mistral_models = {
70
- "mistral-large-latest",
71
- "pixtral-large-latest",
72
- "mistral-moderation-latest",
73
- "ministral-3b-latest",
74
- "ministral-8b-latest",
75
- "open-mistral-nemo",
76
- "mistral-small-latest",
77
- "mistral-saba-latest",
78
- "codestral-latest"
79
  }
80
 
81
- pollinations_models = {
82
- "openai",
83
- "openai-large",
84
- "openai-xlarge",
85
- "openai-reasoning",
86
- "qwen-coder",
87
- "llama",
88
- "mistral",
89
- "searchgpt",
90
- "deepseek",
91
- "claude-hybridspace",
92
- "deepseek-r1",
93
- "deepseek-reasoner",
94
- "llamalight",
95
- "gemini",
96
- "gemini-thinking",
97
- "hormoz",
98
- "phi",
99
- "phi-mini",
100
- "openai-audio",
101
- "llama-scaleway"
102
  }
103
 
104
- alternate_models = { # heh, should work now
105
- "gpt-4o",
106
- "deepseek-v3",
107
- "llama-3.1-8b-instruct",
108
- "llama-3.1-sonar-small-128k-online",
109
- "deepseek-r1-uncensored",
110
- "tinyswallow1.5b",
111
- "andy-3.5",
112
- "o3-mini-low",
113
- "hermes-3-llama-3.2-3b",
114
- "creitin-r1",
115
- "fluffy.1-chat",
116
- "plutotext-1-text",
117
- "command-a",
118
- "claude-3-7-sonnet-20250219",
119
- "plutogpt-3.5-turbo"
120
  }
121
 
122
- claude_3_models = { # Models for the new endpoint
123
- "claude-3-7-sonnet",
124
- "claude-3-7-sonnet-thinking",
125
- "claude 3.5 haiku",
126
- "claude 3.5 sonnet",
127
- "claude 3.5 haiku",
128
- "o3-mini-medium",
129
- "o3-mini-high",
130
- "grok-3",
131
- "grok-3-thinking",
132
- "grok 2"
133
  }
134
 
135
- # Supported image generation models
136
- supported_image_models = {
137
- "Flux Pro Ultra",
138
- "grok-2-aurora",
139
- "Flux Pro",
140
- "Flux Pro Ultra Raw",
141
- "Flux Dev",
142
- "Flux Schnell",
143
- "stable-diffusion-3-large-turbo",
144
- "Flux Realism",
145
- "stable-diffusion-ultra",
146
- "dall-e-3",
147
- "sdxl-lightning-4step"
148
  }
149
 
 
 
 
 
 
150
 
151
- # Request payload model
152
  class Payload(BaseModel):
153
  model: str
154
- messages: list
155
  stream: bool = False
 
 
 
 
 
156
 
157
-
158
- # Image generation payload model
159
  class ImageGenerationPayload(BaseModel):
160
  model: str
161
  prompt: str
162
- size: int
163
- number: int
 
 
164
 
 
165
 
 
 
166
 
167
- # Server status global variable
168
- server_status = True
169
- available_model_ids: List[str] = []
170
-
171
- # Create a reusable httpx client pool with connection pooling
172
  @lru_cache(maxsize=1)
173
- def get_async_client():
174
- return httpx.AsyncClient(
175
- timeout=60.0,
176
- limits=httpx.Limits(max_keepalive_connections=50, max_connections=200) # Increased limits
177
- )
178
-
179
- # Create a cloudscraper pool
180
- scraper_pool = []
181
- MAX_SCRAPERS = 20 # Increased pool size
182
-
183
-
184
- def get_scraper():
 
 
 
185
  if not scraper_pool:
 
186
  for _ in range(MAX_SCRAPERS):
 
187
  scraper_pool.append(cloudscraper.create_scraper())
 
 
 
188
 
189
- return scraper_pool[int(time.time() * 1000) % MAX_SCRAPERS] # Simple round-robin
190
 
191
- # API key validation - optimized to avoid string operations when possible
192
  async def verify_api_key(
193
  request: Request,
194
- api_key: str = Security(api_key_header)
195
  ) -> bool:
196
- # Allow bypass if the referer is from /playground or /image-playground
 
 
 
 
 
 
197
  referer = request.headers.get("referer", "")
198
- if referer.startswith(("https://parthsadaria-lokiai.hf.space/playground",
199
- "https://parthsadaria-lokiai.hf.space/image-playground")):
200
  return True
201
-
202
  if not api_key:
203
- raise HTTPException(
204
- status_code=HTTP_403_FORBIDDEN,
205
- detail="No API key provided"
206
- )
207
-
208
- # Only clean if needed
209
  if api_key.startswith('Bearer '):
210
- api_key = api_key[7:] # Remove 'Bearer ' prefix
211
-
212
- # Get API keys from environment
213
- valid_api_keys = get_env_vars().get('api_keys', [])
214
- if not valid_api_keys or valid_api_keys == ['']:
215
- raise HTTPException(
216
- status_code=HTTP_403_FORBIDDEN,
217
- detail="API keys not configured on server"
218
- )
219
-
220
- # Fast check with set operation
221
- if api_key not in set(valid_api_keys):
222
- raise HTTPException(
223
- status_code=HTTP_403_FORBIDDEN,
224
- detail="Invalid API key"
225
- )
226
-
227
  return True
228
 
229
- # Pre-load and cache models.json
 
230
  @lru_cache(maxsize=1)
231
- def load_models_data():
 
 
 
 
 
 
232
  try:
233
- file_path = Path(__file__).parent / 'models.json'
234
- with open(file_path, 'r') as f:
235
  return json.load(f)
236
  except (FileNotFoundError, json.JSONDecodeError) as e:
237
- print(f"Error loading models.json: {str(e)}")
238
  return []
239
 
240
- # Async wrapper for models data
241
- async def get_models():
242
  models_data = load_models_data()
243
  if not models_data:
244
  raise HTTPException(status_code=500, detail="Error loading available models")
245
  return models_data
246
 
247
- # Enhanced async streaming - now with real-time SSE support
248
- async def generate_search_async(query: str, systemprompt: Optional[str] = None, stream: bool = True):
249
- # Create a streaming response channel using asyncio.Queue
250
- queue = asyncio.Queue()
251
-
252
- async def _fetch_search_data():
253
- try:
254
- headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
255
-
256
- # Use the provided system prompt, or default to "Be Helpful and Friendly"
257
- system_message = systemprompt or "Be Helpful and Friendly"
258
-
259
- # Create the prompt history
260
- prompt = [
261
- {"role": "user", "content": query},
262
- ]
263
-
264
- prompt.insert(0, {"content": system_message, "role": "system"})
265
-
266
- # Prepare the payload for the API request
267
- payload = {
268
- "is_vscode_extension": True,
269
- "message_history": prompt,
270
- "requested_model": "searchgpt",
271
- "user_input": prompt[-1]["content"],
272
- }
273
-
274
- # Get endpoint from environment
275
- secret_api_endpoint_3 = get_env_vars()['secret_api_endpoint_3']
276
- if not secret_api_endpoint_3:
277
- await queue.put({"error": "Search API endpoint not configured"})
278
- return
279
-
280
- # Use AsyncClient for better performance
281
- async with httpx.AsyncClient(timeout=30.0) as client:
282
- async with client.stream("POST", secret_api_endpoint_3, json=payload, headers=headers) as response:
283
- if response.status_code != 200:
284
- await queue.put({"error": f"Search API returned status code {response.status_code}"})
285
- return
286
-
287
- # Process the streaming response in real-time
288
- buffer = ""
289
- async for line in response.aiter_lines():
290
- if line.startswith("data: "):
291
- try:
292
- json_data = json.loads(line[6:])
293
- content = json_data.get("choices", [{}])[0].get("delta", {}).get("content", "")
294
-
295
- if content.strip():
296
- cleaned_response = {
297
- "created": json_data.get("created"),
298
- "id": json_data.get("id"),
299
- "model": "searchgpt",
300
- "object": "chat.completion",
301
- "choices": [
302
- {
303
- "message": {
304
- "content": content
305
- }
306
- }
307
- ]
308
- }
309
-
310
- # Send to queue immediately for streaming
311
- await queue.put({"data": f"data: {json.dumps(cleaned_response)}\n\n", "text": content})
312
- except json.JSONDecodeError:
313
- continue
314
-
315
- # Signal completion
316
- await queue.put(None)
317
-
318
- except Exception as e:
319
- await queue.put({"error": str(e)})
320
- await queue.put(None)
321
 
322
- # Start the fetch process
323
- asyncio.create_task(_fetch_search_data())
324
-
325
- # Return the queue for consumption
326
- return queue
327
-
328
- # Cache for frequently accessed static files
329
  @lru_cache(maxsize=10)
330
- def read_html_file(file_path):
 
 
 
 
 
331
  try:
332
- with open(file_path, "r") as file:
333
  return file.read()
334
- except FileNotFoundError:
 
335
  return None
336
 
337
- # Basic routes
338
- @app.get("/favicon.ico")
339
- async def favicon():
340
- favicon_path = Path(__file__).parent / "favicon.ico"
341
- return FileResponse(favicon_path, media_type="image/x-icon")
 
342
 
343
- @app.get("/banner.jpg")
344
- async def favicon():
345
- favicon_path = Path(__file__).parent / "banner.jpg"
346
- return FileResponse(favicon_path, media_type="image/x-icon")
347
 
348
- @app.get("/ping")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349
  async def ping():
350
- return {"message": "pong", "response_time": "0.000000 seconds"}
 
351
 
352
- @app.get("/", response_class=HTMLResponse)
353
  async def root():
354
- html_content = read_html_file("index.html")
355
- if html_content is None:
356
- return HTMLResponse(content="<h1>File not found</h1>", status_code=404)
357
- return HTMLResponse(content=html_content)
358
- @app.get("/script.js", response_class=HTMLResponse)
359
- async def root():
360
- html_content = read_html_file("script.js")
361
- if html_content is None:
362
- return HTMLResponse(content="<h1>File not found</h1>", status_code=404)
363
- return HTMLResponse(content=html_content)
364
- @app.get("/style.css", response_class=HTMLResponse)
365
- async def root():
366
- html_content = read_html_file("style.css")
367
- if html_content is None:
368
- return HTMLResponse(content="<h1>File not found</h1>", status_code=404)
369
- return HTMLResponse(content=html_content)
370
- @app.get("/dynamo", response_class=HTMLResponse)
 
 
 
 
 
 
 
 
 
 
 
 
371
  async def dynamic_ai_page(request: Request):
372
- user_agent = request.headers.get('user-agent', 'Unknown User')
373
- client_ip = request.client.host
374
- location = f"IP: {client_ip}"
375
-
 
 
 
 
 
 
 
 
376
  prompt = f"""
377
- Generate a dynamic HTML page for a user with the following details: with name "LOKI.AI"
 
378
  - User-Agent: {user_agent}
379
- - Location: {location}
380
- - Style: Cyberpunk, minimalist, or retro
381
-
382
- Make sure the HTML is clean and includes a heading, also have cool animations a motivational message, and a cool background.
383
- Wrap the generated HTML in triple backticks (```).
384
  """
385
-
386
  payload = {
387
- "model": "mistral-small-latest",
388
- "messages": [{"role": "user", "content": prompt}]
 
 
389
  }
390
-
391
  headers = {
392
- "Authorization": "Bearer playground"
 
 
393
  }
394
-
395
- response = requests.post("https://parthsadaria-lokiai.hf.space/chat/completions", json=payload, headers=headers)
396
- data = response.json()
397
-
398
- # Extract HTML from ``` blocks
399
- html_content = re.search(r"```(.*?)```", data['choices'][0]['message']['content'], re.DOTALL)
400
- if html_content:
401
- html_content = html_content.group(1).strip()
402
-
403
- # Remove the first word
404
- if html_content:
405
- html_content = ' '.join(html_content.split(' ')[1:])
406
-
407
- return HTMLResponse(content=html_content)
408
-
409
- @app.get("/playground", response_class=HTMLResponse)
410
- async def playground():
411
- html_content = read_html_file("playground.html")
412
- if html_content is None:
413
- return HTMLResponse(content="<h1>playground.html not found</h1>", status_code=404)
414
- return HTMLResponse(content=html_content)
415
-
416
- @app.get("/image-playground", response_class=HTMLResponse)
417
- async def playground():
418
- html_content = read_html_file("image-playground.html")
419
- if html_content is None:
420
- return HTMLResponse(content="<h1>image-playground.html not found</h1>", status_code=404)
421
- return HTMLResponse(content=html_content)
422
 
 
 
 
 
 
 
 
 
 
423
 
 
 
 
424
 
 
 
 
425
 
426
- # VETRA
427
- GITHUB_BASE = "https://raw.githubusercontent.com/Parthsadaria/Vetra/main"
428
 
429
- FILES = {
430
- "html": "index.html",
431
- "css": "style.css",
432
- "js": "script.js"
433
- }
 
 
 
 
 
 
 
434
 
435
- async def get_github_file(filename: str) -> str:
 
436
  url = f"{GITHUB_BASE}/{filename}"
437
- async with httpx.AsyncClient() as client:
 
438
  res = await client.get(url)
439
- return res.text if res.status_code == 200 else None
 
 
 
 
 
 
 
440
 
441
- @app.get("/vetra", response_class=HTMLResponse)
442
  async def serve_vetra():
443
- html = await get_github_file(FILES["html"])
444
- css = await get_github_file(FILES["css"])
445
- js = await get_github_file(FILES["js"])
 
 
 
 
 
446
 
447
  if not html:
448
- return HTMLResponse(content="<h1>index.html not found on GitHub</h1>", status_code=404)
 
 
 
 
 
449
 
450
- final_html = html.replace(
451
- "</head>",
452
- f"<style>{css or '/* CSS not found */'}</style></head>"
453
- ).replace(
454
- "</body>",
455
- f"<script>{js or '// JS not found'}</script></body>"
456
- )
457
 
 
458
  return HTMLResponse(content=final_html)
459
 
460
 
 
 
 
 
 
 
 
 
 
461
 
 
 
 
 
 
 
 
 
462
 
 
 
 
 
 
463
 
464
- # Model routes
465
- @app.get("/api/v1/models")
466
- @app.get("/models")
467
- async def return_models():
468
- return await get_models()
469
 
470
- # Search routes with enhanced real-time streaming
471
- @app.get("/searchgpt")
472
- async def search_gpt(q: str, stream: Optional[bool] = False, systemprompt: Optional[str] = None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
473
  if not q:
474
  raise HTTPException(status_code=400, detail="Query parameter 'q' is required")
475
 
 
 
476
  usage_tracker.record_request(endpoint="/searchgpt")
477
 
478
- queue = await generate_search_async(q, systemprompt=systemprompt, stream=True)
479
 
480
  if stream:
481
  async def stream_generator():
482
- collected_text = ""
483
  while True:
484
  item = await queue.get()
485
- if item is None:
486
  break
487
-
488
  if "error" in item:
489
- yield f"data: {json.dumps({'error': item['error']})}\n\n"
 
 
 
 
490
  break
491
-
492
  if "data" in item:
493
  yield item["data"]
494
- collected_text += item.get("text", "")
 
 
495
 
496
  return StreamingResponse(
497
  stream_generator(),
498
- media_type="text/event-stream"
 
 
 
 
 
 
499
  )
500
  else:
501
- # For non-streaming, collect all text and return at once
502
- collected_text = ""
503
  while True:
504
  item = await queue.get()
505
  if item is None:
506
  break
507
-
508
  if "error" in item:
509
- raise HTTPException(status_code=500, detail=item["error"])
510
-
511
- collected_text += item.get("text", "")
512
-
513
- return JSONResponse(content={"response": collected_text})
514
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
515
 
516
 
517
- # Enhanced streaming with direct SSE pass-through for real-time responses
518
- header_url = os.getenv('HEADER_URL')
519
- @app.post("/chat/completions")
520
- @app.post("/api/v1/chat/completions")
521
- async def get_completion(payload: Payload, request: Request, authenticated: bool = Depends(verify_api_key)):
522
- # Check server status
 
 
 
 
 
 
523
  if not server_status:
524
- return JSONResponse(
525
- status_code=503,
526
- content={"message": "Server is under maintenance. Please try again later."}
527
- )
528
 
529
- model_to_use = payload.model or "gpt-4o-mini"
530
 
531
- # Validate model availability - fast lookup with set
532
- if available_model_ids and model_to_use not in set(available_model_ids):
533
- raise HTTPException(
534
- status_code=400,
535
- detail=f"Model '{model_to_use}' is not available. Check /models for the available model list."
536
- )
 
 
 
 
 
 
537
 
538
- # Log request without blocking
 
539
  asyncio.create_task(log_request(request, model_to_use))
540
  usage_tracker.record_request(model=model_to_use, endpoint="/chat/completions")
541
 
542
- # Prepare payload
543
- payload_dict = payload.dict()
544
- payload_dict["model"] = model_to_use
545
-
546
- # Ensure stream is True for real-time streaming (can be overridden by client)
547
- stream_enabled = payload_dict.get("stream", True)
548
 
549
- # Get environment variables
550
  env_vars = get_env_vars()
 
 
 
 
 
551
 
552
- # Select the appropriate endpoint (fast lookup with sets)
553
  if model_to_use in mistral_models:
554
- endpoint = env_vars['mistral_api']
555
- custom_headers = {
556
- "Authorization": f"Bearer {env_vars['mistral_key']}"
557
- }
 
 
 
 
558
  elif model_to_use in pollinations_models:
559
- endpoint = env_vars['secret_api_endpoint_4']
560
- custom_headers = {}
 
 
 
 
561
  elif model_to_use in alternate_models:
562
- endpoint = env_vars['secret_api_endpoint_2']
563
- custom_headers = {}
564
- elif model_to_use in claude_3_models: # Use the new endpoint
565
- endpoint = env_vars['secret_api_endpoint_5']
566
- custom_headers = {}
567
- else:
568
- endpoint = env_vars['secret_api_endpoint']
569
- custom_headers = {
570
- "Origin": header_url,
571
- "Priority": "u=1, i",
572
- "Referer": header_url
573
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
574
 
575
- print(f"Using endpoint: {endpoint} for model: {model_to_use}")
 
576
 
577
- # Improved real-time streaming handler
578
- async def real_time_stream_generator():
 
 
 
579
  try:
580
- async with httpx.AsyncClient(timeout=60.0) as client:
581
- async with client.stream("POST", f"{endpoint}/v1/chat/completions", json=payload_dict, headers=custom_headers) as response:
582
- if response.status_code >= 400:
583
- error_messages = {
584
- 422: "Unprocessable entity. Check your payload.",
585
- 400: "Bad request. Verify input data.",
586
- 403: "Forbidden. You do not have access to this resource.",
587
- 404: "The requested resource was not found.",
588
- }
589
- detail = error_messages.get(response.status_code, f"Error code: {response.status_code}")
590
- raise HTTPException(status_code=response.status_code, detail=detail)
591
-
592
- # Stream the response in real-time with minimal buffering
593
- async for line in response.aiter_lines():
594
- if line:
595
- # Yield immediately for faster streaming
596
- yield line + "\n"
 
 
 
 
 
 
 
 
 
597
  except httpx.TimeoutException:
598
- raise HTTPException(status_code=504, detail="Request timed out")
 
 
599
  except httpx.RequestError as e:
600
- raise HTTPException(status_code=502, detail=f"Failed to connect to upstream API: {str(e)}")
 
 
601
  except Exception as e:
602
- if isinstance(e, HTTPException):
603
- raise e
604
- raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
605
 
606
- # Return streaming response with proper headers
607
- if stream_enabled:
608
  return StreamingResponse(
609
- real_time_stream_generator(),
610
  media_type="text/event-stream",
611
  headers={
612
  "Content-Type": "text/event-stream",
613
  "Cache-Control": "no-cache",
614
  "Connection": "keep-alive",
615
- "X-Accel-Buffering": "no" # Disable proxy buffering for Nginx
616
  }
617
  )
618
  else:
619
- # For non-streaming requests, collect the entire response
620
- response_content = []
621
- async for chunk in real_time_stream_generator():
622
- response_content.append(chunk)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
623
 
624
- return JSONResponse(content=json.loads(''.join(response_content)))
625
 
 
626
 
627
 
628
- # New image generation endpoint
629
- @app.post("/images/generations")
630
- async def create_image(payload: ImageGenerationPayload, authenticated: bool = Depends(verify_api_key)):
 
 
 
631
  """
632
- Endpoint for generating images based on a text prompt.
633
  """
634
- # Check server status
635
  if not server_status:
636
- return JSONResponse(
637
- status_code=503,
638
- content={"message": "Server is under maintenance. Please try again later."}
639
- )
640
 
641
- # Validate model
642
  if payload.model not in supported_image_models:
643
  raise HTTPException(
644
  status_code=400,
645
- detail=f"Model '{payload.model}' is not supported for image generation. Supported models are: {supported_image_models}"
646
  )
647
 
648
- # Log the request
649
  usage_tracker.record_request(model=payload.model, endpoint="/images/generations")
650
 
651
- # Prepare the payload for the external API
 
 
 
 
 
 
 
652
  api_payload = {
653
  "model": payload.model,
654
  "prompt": payload.prompt,
655
- "size": payload.size,
656
- "number": payload.number
657
  }
 
 
658
 
659
- # Target API endpoint
660
- target_api_url = os.getenv('NEW_IMG')
661
 
662
- try:
663
- # Use a timeout for the image generation request
664
- async with httpx.AsyncClient(timeout=60.0) as client:
665
- response = await client.post(target_api_url, json=api_payload)
666
 
667
- if response.status_code != 200:
668
- error_detail = response.json().get("detail", f"Image generation failed with status code: {response.status_code}")
669
- raise HTTPException(status_code=response.status_code, detail=error_detail)
 
 
670
 
671
- # Return the response from the external API
672
  return JSONResponse(content=response.json())
673
 
674
  except httpx.TimeoutException:
 
675
  raise HTTPException(status_code=504, detail="Image generation request timed out.")
 
 
 
 
 
 
 
 
 
676
  except httpx.RequestError as e:
 
677
  raise HTTPException(status_code=502, detail=f"Error connecting to image generation service: {e}")
678
  except Exception as e:
 
679
  raise HTTPException(status_code=500, detail=f"An unexpected error occurred during image generation: {e}")
680
 
681
 
 
682
 
683
- # Asynchronous logging function
684
- async def log_request(request, model):
685
- # Get minimal data for logging
686
- current_time = (datetime.datetime.utcnow() + datetime.timedelta(hours=5, minutes=30)).strftime("%Y-%m-%d %I:%M:%S %p")
687
- ip_hash = hash(request.client.host) % 10000 # Hash the IP for privacy
688
- print(f"Time: {current_time}, IP Hash: {ip_hash}, Model: {model}")
 
 
 
 
 
689
 
690
- # Cache usage statistics
691
- @lru_cache(maxsize=10)
692
- def get_usage_summary(days=7):
693
- return usage_tracker.get_usage_summary(days)
694
 
695
- @app.get("/usage")
696
  async def get_usage(days: int = 7):
697
- """Retrieve usage statistics"""
698
- return get_usage_summary(days)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
699
 
700
- # Generate HTML for usage page
701
- def generate_usage_html(usage_data):
702
- # Model Usage Table Rows
703
  model_usage_rows = "\n".join([
704
  f"""
705
  <tr>
706
  <td>{model}</td>
707
- <td>{model_data['total_requests']}</td>
708
- <td>{model_data['first_used']}</td>
709
- <td>{model_data['last_used']}</td>
710
  </tr>
711
- """ for model, model_data in usage_data['models'].items()
712
- ])
713
 
714
- # API Endpoint Usage Table Rows
715
  api_usage_rows = "\n".join([
716
  f"""
717
  <tr>
718
  <td>{endpoint}</td>
719
- <td>{endpoint_data['total_requests']}</td>
720
- <td>{endpoint_data['first_used']}</td>
721
- <td>{endpoint_data['last_used']}</td>
722
  </tr>
723
- """ for endpoint, endpoint_data in usage_data['api_endpoints'].items()
724
- ])
725
 
726
- # Daily Usage Table Rows
727
  daily_usage_rows = "\n".join([
728
- "\n".join([
729
- f"""
730
- <tr>
731
- <td>{date}</td>
732
- <td>{entity}</td>
733
- <td>{requests}</td>
734
- </tr>
735
- """ for entity, requests in date_data.items()
736
- ]) for date, date_data in usage_data['recent_daily_usage'].items()
737
- ])
738
 
 
 
 
739
  html_content = f"""
740
  <!DOCTYPE html>
741
  <html lang="en">
742
  <head>
743
  <meta charset="UTF-8">
 
744
  <title>Lokiai AI - Usage Statistics</title>
745
  <link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;600&display=swap" rel="stylesheet">
746
  <style>
 
747
  :root {{
748
- --bg-dark: #0f1011;
749
- --bg-darker: #070708;
750
- --text-primary: #e6e6e6;
751
- --text-secondary: #8c8c8c;
752
- --border-color: #2c2c2c;
753
- --accent-color: #3a6ee0;
754
  --accent-hover: #4a7ef0;
755
  }}
756
- body {{
757
- font-family: 'Inter', sans-serif;
758
- background-color: var(--bg-dark);
759
- color: var(--text-primary);
760
- max-width: 1200px;
761
- margin: 0 auto;
762
- padding: 40px 20px;
763
- line-height: 1.6;
764
- }}
765
- .logo {{
766
- display: flex;
767
- align-items: center;
768
- justify-content: center;
769
- margin-bottom: 30px;
770
- }}
771
- .logo h1 {{
772
- font-weight: 600;
773
- font-size: 2.5em;
774
- color: var(--text-primary);
775
- margin-left: 15px;
776
- }}
777
- .logo img {{
778
- width: 60px;
779
- height: 60px;
780
- border-radius: 10px;
781
- }}
782
- .container {{
783
- background-color: var(--bg-darker);
784
- border-radius: 12px;
785
- padding: 30px;
786
- box-shadow: 0 15px 40px rgba(0,0,0,0.3);
787
- border: 1px solid var(--border-color);
788
- }}
789
- h2, h3 {{
790
- color: var(--text-primary);
791
- border-bottom: 2px solid var(--border-color);
792
- padding-bottom: 10px;
793
- font-weight: 500;
794
- }}
795
- .total-requests {{
796
- background-color: var(--accent-color);
797
- color: white;
798
- text-align: center;
799
- padding: 15px;
800
- border-radius: 8px;
801
- margin-bottom: 30px;
802
- font-weight: 600;
803
- letter-spacing: -0.5px;
804
- }}
805
- table {{
806
- width: 100%;
807
- border-collapse: separate;
808
- border-spacing: 0;
809
- margin-bottom: 30px;
810
- background-color: var(--bg-dark);
811
- border-radius: 8px;
812
- overflow: hidden;
813
- }}
814
- th, td {{
815
- border: 1px solid var(--border-color);
816
- padding: 12px;
817
- text-align: left;
818
- transition: background-color 0.3s ease;
819
- }}
820
- th {{
821
- background-color: #1e1e1e;
822
- color: var(--text-primary);
823
- font-weight: 600;
824
- text-transform: uppercase;
825
- font-size: 0.9em;
826
- }}
827
- tr:nth-child(even) {{
828
- background-color: rgba(255,255,255,0.05);
829
- }}
830
- tr:hover {{
831
- background-color: rgba(62,100,255,0.1);
832
- }}
833
- @media (max-width: 768px) {{
834
- .container {{
835
- padding: 15px;
836
- }}
837
- table {{
838
- font-size: 0.9em;
839
- }}
840
- }}
841
  </style>
842
  </head>
843
  <body>
844
  <div class="container">
845
  <div class="logo">
846
  <img src="data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iMjAwIiBoZWlnaHQ9IjIwMCIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj48cGF0aCBkPSJNMTAwIDM1TDUwIDkwaDEwMHoiIGZpbGw9IiMzYTZlZTAiLz48Y2lyY2xlIGN4PSIxMDAiIGN5PSIxNDAiIHI9IjMwIiBmaWxsPSIjM2E2ZWUwIi8+PC9zdmc+" alt="Lokai AI Logo">
847
- <h1>Lokiai AI</h1>
848
  </div>
849
 
850
  <div class="total-requests">
851
- Total API Requests: {usage_data['total_requests']}
852
  </div>
853
 
854
  <h2>Model Usage</h2>
855
  <table>
856
- <tr>
857
- <th>Model</th>
858
- <th>Total Requests</th>
859
- <th>First Used</th>
860
- <th>Last Used</th>
861
- </tr>
862
- {model_usage_rows}
863
  </table>
864
 
865
  <h2>API Endpoint Usage</h2>
866
  <table>
867
- <tr>
868
- <th>Endpoint</th>
869
- <th>Total Requests</th>
870
- <th>First Used</th>
871
- <th>Last Used</th>
872
- </tr>
873
- {api_usage_rows}
874
  </table>
875
 
876
- <h2>Daily Usage (Last 7 Days)</h2>
877
  <table>
878
- <tr>
879
- <th>Date</th>
880
- <th>Entity</th>
881
- <th>Requests</th>
882
- </tr>
883
- {daily_usage_rows}
884
  </table>
885
  </div>
886
  </body>
@@ -888,166 +1058,199 @@ def generate_usage_html(usage_data):
888
  """
889
  return html_content
890
 
891
- # Cache the usage page HTML
892
- @lru_cache(maxsize=1)
893
- def get_usage_page_html():
894
- usage_data = get_usage_summary()
895
- return generate_usage_html(usage_data)
896
 
897
- @app.get("/usage/page", response_class=HTMLResponse)
898
  async def usage_page():
899
- """Serve an HTML page showing usage statistics"""
900
- # Use cached HTML if available, regenerate if not
901
- html_content = get_usage_page_html()
902
- return HTMLResponse(content=html_content)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
903
 
904
- # Meme endpoint with optimized networking
905
- @app.get("/meme")
906
  async def get_meme():
 
 
 
 
907
  try:
908
- # Use the shared client for connection pooling
909
- client = get_async_client()
910
- response = await client.get("https://meme-api.com/gimme")
911
  response_data = response.json()
912
 
913
  meme_url = response_data.get("url")
914
- if not meme_url:
915
- raise HTTPException(status_code=404, detail="No meme found")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
916
 
917
- image_response = await client.get(meme_url, follow_redirects=True)
918
 
919
- # Use larger chunks for streaming
920
- async def stream_with_larger_chunks():
921
- chunks = []
922
- size = 0
923
- async for chunk in image_response.aiter_bytes(chunk_size=16384):
924
- chunks.append(chunk)
925
- size += len(chunk)
926
 
927
- if size >= 65536:
928
- yield b''.join(chunks)
929
- chunks = []
930
- size = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
931
 
932
- if chunks:
933
- yield b''.join(chunks)
 
 
 
 
 
934
 
935
- return StreamingResponse(
936
- stream_with_larger_chunks(),
937
- media_type=image_response.headers.get("content-type", "image/png"),
938
- headers={'Cache-Control': 'max-age=3600'} # Add caching
939
- )
940
- except Exception:
941
- raise HTTPException(status_code=500, detail="Failed to retrieve meme")
942
 
943
- # Utility function for loading model IDs - optimized to run once at startup
944
- def load_model_ids(json_file_path):
945
- try:
946
- with open(json_file_path, 'r') as f:
947
- models_data = json.load(f)
948
- # Extract 'id' from each model object and use a set for fast lookups
949
- return [model['id'] for model in models_data if 'id' in model]
950
- except Exception as e:
951
- print(f"Error loading model IDs: {str(e)}")
952
- return []
953
 
954
  @app.on_event("startup")
955
  async def startup_event():
 
956
  global available_model_ids
957
- available_model_ids = load_model_ids("models.json")
958
- print(f"Loaded {len(available_model_ids)} model IDs")
959
-
960
- # Add all pollinations models to available_model_ids
961
- available_model_ids.extend(list(pollinations_models))
962
- # Add alternate models to available_model_ids
963
- available_model_ids.extend(list(alternate_models))
964
- # Add mistral models to available_model_ids
965
- available_model_ids.extend(list(mistral_models))
966
- # Add claude models
967
- available_model_ids.extend(list(claude_3_models))
968
-
969
- available_model_ids = list(set(available_model_ids)) # Remove duplicates
970
- print(f"Total available models: {len(available_model_ids)}")
971
-
972
- # Preload scrapers
973
- for _ in range(MAX_SCRAPERS):
974
- scraper_pool.append(cloudscraper.create_scraper())
975
-
976
- # Validate critical environment variables
977
  env_vars = get_env_vars()
978
- missing_vars = []
979
-
980
- if not env_vars['api_keys'] or env_vars['api_keys'] == ['']:
981
- missing_vars.append('API_KEYS')
982
- if not env_vars['secret_api_endpoint']:
983
- missing_vars.append('SECRET_API_ENDPOINT')
984
- if not env_vars['secret_api_endpoint_2']:
985
- missing_vars.append('SECRET_API_ENDPOINT_2')
986
- if not env_vars['secret_api_endpoint_3']:
987
- missing_vars.append('SECRET_API_ENDPOINT_3')
988
- if not env_vars['secret_api_endpoint_4']:
989
- missing_vars.append('SECRET_API_ENDPOINT_4')
990
- if not env_vars['secret_api_endpoint_5']: # Check the new endpoint
991
- missing_vars.append('SECRET_API_ENDPOINT_5')
992
- if not env_vars['mistral_api'] and any(model in mistral_models for model in available_model_ids):
993
- missing_vars.append('MISTRAL_API')
994
- if not env_vars['mistral_key'] and any(model in mistral_models for model in available_model_ids):
995
- missing_vars.append('MISTRAL_KEY')
996
-
997
- if missing_vars:
998
- print(f"WARNING: The following environment variables are missing: {', '.join(missing_vars)}")
999
- print("Some functionality may be limited.")
1000
-
1001
- print("Server started successfully!")
1002
 
1003
  @app.on_event("shutdown")
1004
  async def shutdown_event():
1005
- # Close the httpx client
 
 
 
1006
  client = get_async_client()
1007
  await client.aclose()
 
 
 
 
 
1008
 
1009
- # Clear scraper pool
1010
  scraper_pool.clear()
 
1011
 
1012
  # Persist usage data
1013
- usage_tracker.save_data()
1014
-
1015
- print("Server shutdown complete!")
1016
-
1017
- # Health check endpoint
1018
- # Health check endpoint
1019
- @app.get("/health")
1020
- async def health_check():
1021
- """Health check endpoint for monitoring"""
1022
- env_vars = get_env_vars()
1023
- missing_critical_vars = []
1024
 
1025
- # Check critical environment variables
1026
- if not env_vars['api_keys'] or env_vars['api_keys'] == ['']:
1027
- missing_critical_vars.append('API_KEYS')
1028
- if not env_vars['secret_api_endpoint']:
1029
- missing_critical_vars.append('SECRET_API_ENDPOINT')
1030
- if not env_vars['secret_api_endpoint_2']:
1031
- missing_critical_vars.append('SECRET_API_ENDPOINT_2')
1032
- if not env_vars['secret_api_endpoint_3']:
1033
- missing_critical_vars.append('SECRET_API_ENDPOINT_3')
1034
- if not env_vars['secret_api_endpoint_4']:
1035
- missing_critical_vars.append('SECRET_API_ENDPOINT_4')
1036
- if not env_vars['secret_api_endpoint_5']: # Check the new endpoint
1037
- missing_critical_vars.append('SECRET_API_ENDPOINT_5')
1038
- if not env_vars['mistral_api']:
1039
- missing_critical_vars.append('MISTRAL_API')
1040
- if not env_vars['mistral_key']:
1041
- missing_critical_vars.append('MISTRAL_KEY')
1042
 
1043
- health_status = {
1044
- "status": "healthy" if not missing_critical_vars else "unhealthy",
1045
- "missing_env_vars": missing_critical_vars,
1046
- "server_status": server_status,
1047
- "message": "Everything's lit! 🚀" if not missing_critical_vars else "Uh oh, some env vars are missing. 😬"
1048
- }
1049
- return JSONResponse(content=health_status)
1050
 
 
 
 
 
1051
  if __name__ == "__main__":
1052
  import uvicorn
1053
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
1
  import os
2
  import re
 
 
 
 
 
 
 
 
3
  import json
4
  import datetime
5
  import time
 
 
6
  import asyncio
7
+ import logging
8
+ from pathlib import Path
9
+ from functools import lru_cache
10
+ from typing import Optional, Dict, List, Any, Generator, Set
11
  from concurrent.futures import ThreadPoolExecutor
12
+
13
+ # Third-party libraries (ensure these are in requirements.txt)
14
+ from dotenv import load_dotenv
15
+ from fastapi import FastAPI, HTTPException, Request, Depends, Security, Response
16
+ from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse, FileResponse
17
+ from fastapi.security import APIKeyHeader
18
+ from pydantic import BaseModel
19
+ import httpx
20
+ import uvloop # Use uvloop for performance
21
  from fastapi.middleware.gzip import GZipMiddleware
22
  from starlette.middleware.cors import CORSMiddleware
23
+ import cloudscraper # For bypassing Cloudflare, potentially unreliable
24
+ import requests # For synchronous requests like in /dynamo
25
+
26
+ # HF Space Note: Ensure usage_tracker.py is in your repository
27
+ try:
28
+ from usage_tracker import UsageTracker
29
+ usage_tracker = UsageTracker()
30
+ except ImportError:
31
+ print("Warning: usage_tracker.py not found. Usage tracking will be disabled.")
32
+ # Create a dummy tracker if the file is missing
33
+ class DummyUsageTracker:
34
+ def record_request(self, *args, **kwargs): pass
35
+ def get_usage_summary(self, *args, **kwargs): return {}
36
+ def save_data(self, *args, **kwargs): pass
37
+ usage_tracker = DummyUsageTracker()
38
+
39
+
40
+ # --- Configuration & Setup ---
41
+
42
+ # HF Space Note: uvloop can improve performance in I/O bound tasks common in web apps.
43
  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
44
 
45
+ # HF Space Note: Adjust max_workers based on your HF Space resources (CPU).
46
+ # Higher tiers allow more workers. Start lower (e.g., 4) for free tier.
47
+ executor = ThreadPoolExecutor(max_workers=8)
48
 
49
+ # HF Space Note: load_dotenv() is useful for local dev but HF Spaces use Secrets.
50
+ # os.getenv will automatically pick up secrets set in the HF Space settings.
51
  load_dotenv()
52
 
53
+ # Logging setup
54
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
55
+ logger = logging.getLogger(__name__)
56
+
57
+ # API key security
58
  api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
59
 
60
+ # --- FastAPI App Initialization ---
 
 
61
 
62
+ app = FastAPI(
63
+ title="LokiAI API",
64
+ description="API Proxy for various AI models with usage tracking and streaming.",
65
+ version="1.0.0"
66
+ )
67
 
68
+ # Middleware
69
+ app.add_middleware(GZipMiddleware, minimum_size=1000) # Compress large responses
70
  app.add_middleware(
71
+ CORSMiddleware, # Allow cross-origin requests (useful for web playgrounds)
72
+ allow_origins=["*"], # Or restrict to specific origins
73
  allow_credentials=True,
74
  allow_methods=["*"],
75
  allow_headers=["*"],
76
  )
77
 
78
+ # --- Environment Variables & Model Config ---
79
+
80
+ @lru_cache(maxsize=1) # Cache environment variables
81
+ def get_env_vars() -> Dict[str, Any]:
82
+ """Loads and returns essential environment variables."""
83
+ # HF Space Note: Set these as Secrets in your Hugging Face Space settings.
84
  return {
85
+ 'api_keys': set(filter(None, os.getenv('API_KEYS', '').split(','))), # Use set for faster lookup
86
  'secret_api_endpoint': os.getenv('SECRET_API_ENDPOINT'),
87
  'secret_api_endpoint_2': os.getenv('SECRET_API_ENDPOINT_2'),
88
+ 'secret_api_endpoint_3': os.getenv('SECRET_API_ENDPOINT_3'), # Search endpoint
89
+ 'secret_api_endpoint_4': os.getenv('SECRET_API_ENDPOINT_4', "https://text.pollinations.ai/openai"), # Pollinations
90
+ 'secret_api_endpoint_5': os.getenv('SECRET_API_ENDPOINT_5'), # Claude 3 endpoint
91
+ 'mistral_api': os.getenv('MISTRAL_API', "https://api.mistral.ai"),
92
  'mistral_key': os.getenv('MISTRAL_KEY'),
93
+ 'new_img_endpoint': os.getenv('NEW_IMG'), # Image generation endpoint
94
+ 'hf_space_url': os.getenv('HF_SPACE_URL', 'https://your-space-name.hf.space') # HF Space Note: Set this! Used for Referer/Origin checks.
95
  }
96
 
97
+ # Model sets for fast lookups
98
+ # HF Space Note: Consider moving these large sets to a separate config file (e.g., config.py or models_config.json)
99
+ # for better organization if they grow larger.
100
+ mistral_models: Set[str] = {
101
+ "mistral-large-latest", "pixtral-large-latest", "mistral-moderation-latest",
102
+ "ministral-3b-latest", "ministral-8b-latest", "open-mistral-nemo",
103
+ "mistral-small-latest", "mistral-saba-latest", "codestral-latest"
 
 
 
 
104
  }
105
 
106
+ pollinations_models: Set[str] = {
107
+ "openai", "openai-large", "openai-xlarge", "openai-reasoning", "qwen-coder",
108
+ "llama", "mistral", "searchgpt", "deepseek", "claude-hybridspace",
109
+ "deepseek-r1", "deepseek-reasoner", "llamalight", "gemini", "gemini-thinking",
110
+ "hormoz", "phi", "phi-mini", "openai-audio", "llama-scaleway"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  }
112
 
113
+ alternate_models: Set[str] = {
114
+ "gpt-4o", "deepseek-v3", "llama-3.1-8b-instruct", "llama-3.1-sonar-small-128k-online",
115
+ "deepseek-r1-uncensored", "tinyswallow1.5b", "andy-3.5", "o3-mini-low",
116
+ "hermes-3-llama-3.2-3b", "creitin-r1", "fluffy.1-chat", "plutotext-1-text",
117
+ "command-a", "claude-3-7-sonnet-20250219", "plutogpt-3.5-turbo"
 
 
 
 
 
 
 
 
 
 
 
118
  }
119
 
120
+ claude_3_models: Set[str] = {
121
+ "claude-3-7-sonnet", "claude-3-7-sonnet-thinking", "claude 3.5 haiku",
122
+ "claude 3.5 sonnet", "claude 3.5 haiku", "o3-mini-medium", "o3-mini-high",
123
+ "grok-3", "grok-3-thinking", "grok 2"
 
 
 
 
 
 
 
124
  }
125
 
126
+ supported_image_models: Set[str] = {
127
+ "Flux Pro Ultra", "grok-2-aurora", "Flux Pro", "Flux Pro Ultra Raw", "Flux Dev",
128
+ "Flux Schnell", "stable-diffusion-3-large-turbo", "Flux Realism",
129
+ "stable-diffusion-ultra", "dall-e-3", "sdxl-lightning-4step"
 
 
 
 
 
 
 
 
 
130
  }
131
 
132
+ # --- Pydantic Models ---
133
+
134
+ class Message(BaseModel):
135
+ role: str
136
+ content: Any # Allow content to be string or potentially list for multimodal models
137
 
 
138
  class Payload(BaseModel):
139
  model: str
140
+ messages: List[Message]
141
  stream: bool = False
142
+ # Add other potential OpenAI compatible parameters with defaults
143
+ max_tokens: Optional[int] = None
144
+ temperature: Optional[float] = None
145
+ top_p: Optional[float] = None
146
+ # ... add others as needed
147
 
 
 
148
  class ImageGenerationPayload(BaseModel):
149
  model: str
150
  prompt: str
151
+ size: Optional[str] = "1024x1024" # Default size, make optional if API allows
152
+ n: Optional[int] = 1 # Number of images, OpenAI uses 'n'
153
+ # HF Space Note: Ensure these parameter names match the target NEW_IMG endpoint API
154
+ # Renaming from 'number' to 'n' and 'size' type hint correction.
155
 
156
+ # --- Global State & Clients ---
157
 
158
+ server_status: bool = True # For maintenance mode
159
+ available_model_ids: List[str] = [] # Loaded at startup
160
 
161
+ # HF Space Note: Reusable HTTP client with connection pooling is crucial for performance.
162
+ # Adjust limits based on expected load and HF Space resources.
 
 
 
163
  @lru_cache(maxsize=1)
164
+ def get_async_client() -> httpx.AsyncClient:
165
+ """Returns a cached instance of httpx.AsyncClient."""
166
+ # HF Space Note: Timeouts are important to prevent hanging requests.
167
+ # Keepalive connections reduce handshake overhead.
168
+ timeout = httpx.Timeout(30.0, connect=10.0) # 30s total, 10s connect
169
+ limits = httpx.Limits(max_keepalive_connections=20, max_connections=100)
170
+ return httpx.AsyncClient(timeout=timeout, limits=limits, follow_redirects=True)
171
+
172
+ # HF Space Note: cloudscraper pool. Be mindful of potential rate limits or blocks.
173
+ # Consider alternatives if this becomes unreliable.
174
+ scraper_pool: List[cloudscraper.CloudScraper] = []
175
+ MAX_SCRAPERS = 10 # Reduced pool size for potentially lower resource usage
176
+
177
+ def get_scraper() -> cloudscraper.CloudScraper:
178
+ """Gets a cloudscraper instance from the pool."""
179
  if not scraper_pool:
180
+ logger.info(f"Initializing {MAX_SCRAPERS} cloudscraper instances...")
181
  for _ in range(MAX_SCRAPERS):
182
+ # HF Space Note: Scraper creation can be slow, doing it upfront is good.
183
  scraper_pool.append(cloudscraper.create_scraper())
184
+ logger.info("Cloudscraper pool initialized.")
185
+ # Simple round-robin selection
186
+ return scraper_pool[int(time.monotonic() * 1000) % MAX_SCRAPERS]
187
 
188
+ # --- Security & Authentication ---
189
 
 
190
  async def verify_api_key(
191
  request: Request,
192
+ api_key: Optional[str] = Security(api_key_header)
193
  ) -> bool:
194
+ """Verifies the provided API key against environment variables."""
195
+ env_vars = get_env_vars()
196
+ valid_api_keys = env_vars.get('api_keys', set())
197
+ hf_space_url = env_vars.get('hf_space_url', '')
198
+
199
+ # Allow bypass if the referer is from the known HF Space playground URLs
200
+ # HF Space Note: Make HF_SPACE_URL a secret for flexibility.
201
  referer = request.headers.get("referer", "")
202
+ if hf_space_url and referer.startswith((f"{hf_space_url}/playground", f"{hf_space_url}/image-playground")):
203
+ logger.debug(f"API Key check bypassed for referer: {referer}")
204
  return True
205
+
206
  if not api_key:
207
+ logger.warning("API Key missing.")
208
+ raise HTTPException(status_code=403, detail="Not authenticated: No API key provided")
209
+
210
+ # Clean 'Bearer ' prefix if present
 
 
211
  if api_key.startswith('Bearer '):
212
+ api_key = api_key[7:]
213
+
214
+ if not valid_api_keys:
215
+ logger.error("API keys are not configured on the server (API_KEYS secret missing?).")
216
+ raise HTTPException(status_code=500, detail="Server configuration error: API keys not set")
217
+
218
+ if api_key not in valid_api_keys:
219
+ logger.warning(f"Invalid API key received: {api_key[:4]}...") # Log prefix only
220
+ raise HTTPException(status_code=403, detail="Not authenticated: Invalid API key")
221
+
222
+ logger.debug("API Key verified successfully.")
 
 
 
 
 
 
223
  return True
224
 
225
+ # --- Model & File Loading ---
226
+
227
  @lru_cache(maxsize=1)
228
+ def load_models_data() -> List[Dict]:
229
+ """Loads model data from models.json."""
230
+ # HF Space Note: Ensure models.json is in the root of your HF Space repo.
231
+ models_file = Path(__file__).parent / 'models.json'
232
+ if not models_file.is_file():
233
+ logger.error("models.json not found!")
234
+ return []
235
  try:
236
+ with open(models_file, 'r') as f:
 
237
  return json.load(f)
238
  except (FileNotFoundError, json.JSONDecodeError) as e:
239
+ logger.error(f"Error loading models.json: {e}")
240
  return []
241
 
242
+ async def get_models() -> List[Dict]:
243
+ """Async wrapper to get models data."""
244
  models_data = load_models_data()
245
  if not models_data:
246
  raise HTTPException(status_code=500, detail="Error loading available models")
247
  return models_data
248
 
249
+ # --- Static File Serving ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
 
251
+ # HF Space Note: Cache frequently accessed static files in memory.
 
 
 
 
 
 
252
  @lru_cache(maxsize=10)
253
+ def read_static_file(file_path: str) -> Optional[str]:
254
+ """Reads a static file, caching the result."""
255
+ full_path = Path(__file__).parent / file_path
256
+ if not full_path.is_file():
257
+ logger.warning(f"Static file not found: {file_path}")
258
+ return None
259
  try:
260
+ with open(full_path, "r", encoding="utf-8") as file:
261
  return file.read()
262
+ except Exception as e:
263
+ logger.error(f"Error reading static file {file_path}: {e}")
264
  return None
265
 
266
+ async def serve_static_html(file_path: str) -> HTMLResponse:
267
+ """Serves a static HTML file."""
268
+ content = read_static_file(file_path)
269
+ if content is None:
270
+ return HTMLResponse(content=f"<h1>Error: {file_path} not found</h1>", status_code=404)
271
+ return HTMLResponse(content=content)
272
 
273
+ # --- API Endpoints ---
 
 
 
274
 
275
+ # Basic Routes & Static Files
276
+ @app.get("/favicon.ico", include_in_schema=False)
277
+ async def favicon():
278
+ favicon_path = Path(__file__).parent / "favicon.ico"
279
+ if favicon_path.is_file():
280
+ return FileResponse(favicon_path, media_type="image/vnd.microsoft.icon")
281
+ raise HTTPException(status_code=404, detail="favicon.ico not found")
282
+
283
+ @app.get("/banner.jpg", include_in_schema=False)
284
+ async def banner():
285
+ banner_path = Path(__file__).parent / "banner.jpg"
286
+ if banner_path.is_file():
287
+ return FileResponse(banner_path, media_type="image/jpeg") # Assuming JPEG
288
+ raise HTTPException(status_code=404, detail="banner.jpg not found")
289
+
290
+ @app.get("/ping", tags=["Utility"])
291
  async def ping():
292
+ """Simple health check endpoint."""
293
+ return {"message": "pong"}
294
 
295
+ @app.get("/", response_class=HTMLResponse, tags=["Frontend"])
296
  async def root():
297
+ """Serves the main index HTML page."""
298
+ return await serve_static_html("index.html")
299
+
300
+ @app.get("/script.js", response_class=Response, tags=["Frontend"], include_in_schema=False)
301
+ async def script_js():
302
+ content = read_static_file("script.js")
303
+ if content is None:
304
+ return Response(content="/* script.js not found */", status_code=404, media_type="application/javascript")
305
+ return Response(content=content, media_type="application/javascript")
306
+
307
+ @app.get("/style.css", response_class=Response, tags=["Frontend"], include_in_schema=False)
308
+ async def style_css():
309
+ content = read_static_file("style.css")
310
+ if content is None:
311
+ return Response(content="/* style.css not found */", status_code=404, media_type="text/css")
312
+ return Response(content=content, media_type="text/css")
313
+
314
+ @app.get("/playground", response_class=HTMLResponse, tags=["Frontend"])
315
+ async def playground():
316
+ """Serves the chat playground HTML page."""
317
+ return await serve_static_html("playground.html")
318
+
319
+ @app.get("/image-playground", response_class=HTMLResponse, tags=["Frontend"])
320
+ async def image_playground():
321
+ """Serves the image playground HTML page."""
322
+ return await serve_static_html("image-playground.html")
323
+
324
+ # Dynamic Page Example
325
+ @app.get("/dynamo", response_class=HTMLResponse, tags=["Examples"])
326
  async def dynamic_ai_page(request: Request):
327
+ """Generates a dynamic HTML page using an AI model (example)."""
328
+ # HF Space Note: This uses a hardcoded URL to *itself* if running in the space.
329
+ # Ensure the HF_SPACE_URL secret is set correctly.
330
+ env_vars = get_env_vars()
331
+ hf_space_url = env_vars.get('hf_space_url', '')
332
+ if not hf_space_url:
333
+ raise HTTPException(status_code=500, detail="HF_SPACE_URL environment variable not set.")
334
+
335
+ user_agent = request.headers.get('user-agent', 'Unknown')
336
+ client_ip = request.client.host if request.client else "Unknown"
337
+ location = f"IP: {client_ip}" # Basic IP, location requires GeoIP lookup (extra dependency)
338
+
339
  prompt = f"""
340
+ Generate a cool, dynamic HTML page for a user with the following details:
341
+ - App Name: "LokiAI"
342
  - User-Agent: {user_agent}
343
+ - Location Info: {location}
344
+ - Style: Cyberpunk aesthetic, minimalist layout, maybe some retro touches.
345
+ - Content: Include a heading, a short motivational or witty message, and perhaps a subtle animation. Use inline CSS for styling within a <style> tag.
346
+ - Output: Provide ONLY the raw HTML code, starting with <!DOCTYPE html>. Do not wrap it in backticks or add explanations.
 
347
  """
348
+
349
  payload = {
350
+ "model": "mistral-small-latest", # Or another capable model
351
+ "messages": [{"role": "user", "content": prompt}],
352
+ "max_tokens": 1000,
353
+ "temperature": 0.7
354
  }
 
355
  headers = {
356
+ # HF Space Note: Use the space's own URL and a valid API key if required by your setup.
357
+ # Here, we assume the playground key bypass works or use a dedicated internal key.
358
+ "Authorization": f"Bearer {list(env_vars['api_keys'])[0] if env_vars['api_keys'] else 'dummy-key'}" # Use first key or dummy
359
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
360
 
361
+ try:
362
+ # HF Space Note: Use the async client for internal requests too.
363
+ client = get_async_client()
364
+ api_url = f"{hf_space_url}/chat/completions" # Call own endpoint
365
+ response = await client.post(api_url, json=payload, headers=headers)
366
+ response.raise_for_status() # Raise exception for bad status codes
367
+ data = response.json()
368
+
369
+ html_content = data.get('choices', [{}])[0].get('message', {}).get('content', '')
370
 
371
+ # Basic cleanup (remove potential markdown backticks if model adds them)
372
+ html_content = re.sub(r"^```html\s*", "", html_content, flags=re.IGNORECASE)
373
+ html_content = re.sub(r"\s*```$", "", html_content)
374
 
375
+ if not html_content.strip().lower().startswith("<!doctype html"):
376
+ logger.warning("Dynamo page generation might be incomplete or malformed.")
377
+ # Optionally return a fallback static page here
378
 
379
+ return HTMLResponse(content=html_content)
 
380
 
381
+ except httpx.HTTPStatusError as e:
382
+ logger.error(f"Error calling self API for /dynamo: {e.response.status_code} - {e.response.text}")
383
+ raise HTTPException(status_code=502, detail=f"Failed to generate dynamic content: Upstream API error {e.response.status_code}")
384
+ except Exception as e:
385
+ logger.error(f"Unexpected error in /dynamo: {e}", exc_info=True)
386
+ raise HTTPException(status_code=500, detail="Failed to generate dynamic content due to an internal error.")
387
+
388
+
389
+ # Vetra Example (Fetching from GitHub)
390
+ # HF Space Note: Ensure outbound requests to raw.githubusercontent.com are allowed.
391
+ GITHUB_BASE = "https://raw.githubusercontent.com/Parthsadaria/Vetra/main"
392
+ VETRA_FILES = {"html": "index.html", "css": "style.css", "js": "script.js"}
393
 
394
+ async def get_github_file(filename: str) -> Optional[str]:
395
+ """Fetches a file from the Vetra GitHub repo."""
396
  url = f"{GITHUB_BASE}/{filename}"
397
+ try:
398
+ client = get_async_client()
399
  res = await client.get(url)
400
+ res.raise_for_status()
401
+ return res.text
402
+ except httpx.RequestError as e:
403
+ logger.error(f"Error fetching GitHub file {url}: {e}")
404
+ return None
405
+ except httpx.HTTPStatusError as e:
406
+ logger.error(f"GitHub file {url} returned status {e.response.status_code}")
407
+ return None
408
 
409
+ @app.get("/vetra", response_class=HTMLResponse, tags=["Examples"])
410
  async def serve_vetra():
411
+ """Serves the Vetra application by fetching components from GitHub."""
412
+ logger.info("Fetching Vetra files from GitHub...")
413
+ # Fetch files concurrently
414
+ html_task = asyncio.create_task(get_github_file(VETRA_FILES["html"]))
415
+ css_task = asyncio.create_task(get_github_file(VETRA_FILES["css"]))
416
+ js_task = asyncio.create_task(get_github_file(VETRA_FILES["js"]))
417
+
418
+ html, css, js = await asyncio.gather(html_task, css_task, js_task)
419
 
420
  if not html:
421
+ logger.error("Failed to fetch Vetra index.html")
422
+ return HTMLResponse(content="<h1>Error: Could not load Vetra application (HTML missing)</h1>", status_code=502)
423
+
424
+ # Inject CSS and JS into HTML
425
+ css_content = f"<style>{css or '/* CSS failed to load */'}</style>"
426
+ js_content = f"<script>{js or '// JS failed to load'}</script>"
427
 
428
+ # Inject carefully before closing tags
429
+ final_html = html.replace("</head>", f"{css_content}\n</head>", 1)
430
+ final_html = final_html.replace("</body>", f"{js_content}\n</body>", 1)
 
 
 
 
431
 
432
+ logger.info("Successfully served Vetra application.")
433
  return HTMLResponse(content=final_html)
434
 
435
 
436
+ # Model Info Endpoint
437
+ @app.get("/api/v1/models", tags=["Models"])
438
+ @app.get("/models", tags=["Models"])
439
+ async def return_models():
440
+ """Returns the list of available models loaded from models.json."""
441
+ # HF Space Note: This endpoint now relies on models.json being present.
442
+ # It no longer dynamically adds models defined only in the script's sets.
443
+ # Ensure models.json is comprehensive or adjust startup logic if needed.
444
+ return await get_models()
445
 
446
+ # Search Endpoint (using cloudscraper)
447
+ # HF Space Note: This uses cloudscraper which might be blocked or require updates.
448
+ # Consider replacing with a more stable search API if possible.
449
+ async def generate_search_async(query: str, systemprompt: Optional[str] = None) -> asyncio.Queue:
450
+ """Performs search using the configured backend and streams results."""
451
+ queue = asyncio.Queue()
452
+ env_vars = get_env_vars()
453
+ search_endpoint = env_vars.get('secret_api_endpoint_3')
454
 
455
+ async def _fetch_search_data():
456
+ if not search_endpoint:
457
+ await queue.put({"error": "Search API endpoint (SECRET_API_ENDPOINT_3) not configured"})
458
+ await queue.put(None) # Signal end
459
+ return
460
 
461
+ try:
462
+ scraper = get_scraper() # Get a scraper instance from the pool
463
+ loop = asyncio.get_running_loop()
 
 
464
 
465
+ system_message = systemprompt or "You are a helpful search assistant."
466
+ messages = [
467
+ {"role": "system", "content": system_message},
468
+ {"role": "user", "content": query},
469
+ ]
470
+ payload = {
471
+ "model": "searchgpt", # Assuming the endpoint expects this model name
472
+ "messages": messages,
473
+ "stream": True # Explicitly request streaming from backend
474
+ }
475
+ headers = {"User-Agent": "Mozilla/5.0"} # Standard user agent
476
+
477
+ # HF Space Note: Run synchronous scraper call in executor thread
478
+ response = await loop.run_in_executor(
479
+ executor,
480
+ scraper.post,
481
+ search_endpoint,
482
+ json=payload,
483
+ headers=headers,
484
+ stream=True # Request streaming from requests library perspective
485
+ )
486
+
487
+ response.raise_for_status()
488
+
489
+ # Process SSE stream
490
+ # HF Space Note: Iterating lines on the response directly can be blocking if not handled carefully.
491
+ # Using iter_lines with decode_unicode=True is generally safe.
492
+ for line in response.iter_lines(decode_unicode=True):
493
+ if line.startswith("data: "):
494
+ try:
495
+ data_str = line[6:]
496
+ if data_str.strip() == "[DONE]": # Check for OpenAI style completion
497
+ break
498
+ json_data = json.loads(data_str)
499
+ # Assuming OpenAI compatible streaming format
500
+ delta = json_data.get("choices", [{}])[0].get("delta", {})
501
+ content = delta.get("content")
502
+ if content:
503
+ # Reconstruct OpenAI-like SSE chunk
504
+ chunk = {
505
+ "id": json_data.get("id"),
506
+ "object": "chat.completion.chunk",
507
+ "created": json_data.get("created", int(time.time())),
508
+ "model": "searchgpt",
509
+ "choices": [{"index": 0, "delta": {"content": content}, "finish_reason": None}]
510
+ }
511
+ await queue.put({"data": f"data: {json.dumps(chunk)}\n\n", "text": content})
512
+ # Check for finish reason
513
+ finish_reason = json_data.get("choices", [{}])[0].get("finish_reason")
514
+ if finish_reason:
515
+ chunk = {
516
+ "id": json_data.get("id"),
517
+ "object": "chat.completion.chunk",
518
+ "created": json_data.get("created", int(time.time())),
519
+ "model": "searchgpt",
520
+ "choices": [{"index": 0, "delta": {}, "finish_reason": finish_reason}]
521
+ }
522
+ await queue.put({"data": f"data: {json.dumps(chunk)}\n\n", "text": ""})
523
+ break # Stop processing after finish reason
524
+
525
+ except json.JSONDecodeError:
526
+ logger.warning(f"Failed to decode JSON from search stream: {line}")
527
+ continue
528
+ except Exception as e:
529
+ logger.error(f"Error processing search stream chunk: {e}", exc_info=True)
530
+ await queue.put({"error": f"Error processing stream: {e}"})
531
+ break # Stop on processing error
532
+
533
+ except requests.exceptions.RequestException as e:
534
+ logger.error(f"Search request failed: {e}")
535
+ await queue.put({"error": f"Search request failed: {e}"})
536
+ except Exception as e:
537
+ logger.error(f"Unexpected error during search: {e}", exc_info=True)
538
+ await queue.put({"error": f"An unexpected error occurred during search: {e}"})
539
+ finally:
540
+ await queue.put(None) # Signal completion
541
+
542
+
543
+ asyncio.create_task(_fetch_search_data())
544
+ return queue
545
+
546
+ @app.get("/searchgpt", tags=["Search"])
547
+ async def search_gpt(q: str, stream: bool = True, systemprompt: Optional[str] = None):
548
+ """
549
+ Performs a search using the backend search model and streams results.
550
+ Pass `stream=false` to get the full response at once.
551
+ """
552
  if not q:
553
  raise HTTPException(status_code=400, detail="Query parameter 'q' is required")
554
 
555
+ # HF Space Note: Ensure usage_tracker is thread-safe if used across async/sync boundaries.
556
+ # The dummy tracker used when the module isn't found is safe.
557
  usage_tracker.record_request(endpoint="/searchgpt")
558
 
559
+ queue = await generate_search_async(q, systemprompt=systemprompt)
560
 
561
  if stream:
562
  async def stream_generator():
563
+ full_response_text = "" # Keep track for non-streaming case if needed
564
  while True:
565
  item = await queue.get()
566
+ if item is None: # End of stream signal
567
  break
 
568
  if "error" in item:
569
+ # HF Space Note: Log errors server-side, return generic error to client for security.
570
+ logger.error(f"Search stream error: {item['error']}")
571
+ # Send an error event in the stream
572
+ error_event = {"error": {"message": "Search failed.", "code": 500}}
573
+ yield f"data: {json.dumps(error_event)}\n\n"
574
  break
 
575
  if "data" in item:
576
  yield item["data"]
577
+ full_response_text += item.get("text", "")
578
+ # Optionally yield a [DONE] message if backend doesn't guarantee it
579
+ # yield "data: [DONE]\n\n"
580
 
581
  return StreamingResponse(
582
  stream_generator(),
583
+ media_type="text/event-stream",
584
+ headers={
585
+ "Content-Type": "text/event-stream",
586
+ "Cache-Control": "no-cache",
587
+ "Connection": "keep-alive",
588
+ "X-Accel-Buffering": "no" # Crucial for Nginx/proxies in HF Spaces
589
+ }
590
  )
591
  else:
592
+ # Collect full response for non-streaming request
593
+ full_response_text = ""
594
  while True:
595
  item = await queue.get()
596
  if item is None:
597
  break
 
598
  if "error" in item:
599
+ logger.error(f"Search non-stream error: {item['error']}")
600
+ raise HTTPException(status_code=502, detail=f"Search failed: {item['error']}")
601
+ full_response_text += item.get("text", "")
602
+
603
+ # Mimic OpenAI non-streaming response structure
604
+ return JSONResponse(content={
605
+ "id": f"search-{int(time.time())}",
606
+ "object": "chat.completion",
607
+ "created": int(time.time()),
608
+ "model": "searchgpt",
609
+ "choices": [{
610
+ "index": 0,
611
+ "message": {
612
+ "role": "assistant",
613
+ "content": full_response_text,
614
+ },
615
+ "finish_reason": "stop",
616
+ }],
617
+ "usage": { # Note: Token usage is unknown here
618
+ "prompt_tokens": None,
619
+ "completion_tokens": None,
620
+ "total_tokens": None,
621
+ }
622
+ })
623
 
624
 
625
+ # Main Chat Completions Proxy
626
+ @app.post("/api/v1/chat/completions", tags=["Chat Completions"])
627
+ @app.post("/chat/completions", tags=["Chat Completions"])
628
+ async def get_completion(
629
+ payload: Payload,
630
+ request: Request,
631
+ authenticated: bool = Depends(verify_api_key) # Apply authentication
632
+ ):
633
+ """
634
+ Proxies chat completion requests to the appropriate backend API based on the model.
635
+ Supports streaming (SSE).
636
+ """
637
  if not server_status:
638
+ raise HTTPException(status_code=503, detail="Server is under maintenance.")
 
 
 
639
 
640
+ model_to_use = payload.model or "gpt-4o-mini" # Default model
641
 
642
+ # HF Space Note: Check against models loaded at startup.
643
+ if available_model_ids and model_to_use not in available_model_ids:
644
+ logger.warning(f"Requested model '{model_to_use}' not in available list.")
645
+ # Check if it's a known category even if not explicitly in models.json
646
+ known_categories = mistral_models | pollinations_models | alternate_models | claude_3_models
647
+ if model_to_use not in known_categories:
648
+ raise HTTPException(
649
+ status_code=400,
650
+ detail=f"Model '{model_to_use}' is not available or recognized. Check /models."
651
+ )
652
+ else:
653
+ logger.info(f"Allowing known category model '{model_to_use}' despite not being in models.json.")
654
 
655
+
656
+ # Log request asynchronously
657
  asyncio.create_task(log_request(request, model_to_use))
658
  usage_tracker.record_request(model=model_to_use, endpoint="/chat/completions")
659
 
660
+ # Prepare payload for the target API
661
+ payload_dict = payload.dict(exclude_none=True) # Exclude None values
662
+ payload_dict["model"] = model_to_use # Ensure model is set
 
 
 
663
 
 
664
  env_vars = get_env_vars()
665
+ hf_space_url = env_vars.get('hf_space_url', '') # Needed for Referer/Origin
666
+
667
+ # Determine target endpoint and headers
668
+ endpoint = None
669
+ custom_headers = {}
670
 
 
671
  if model_to_use in mistral_models:
672
+ endpoint = env_vars.get('mistral_api')
673
+ api_key = env_vars.get('mistral_key')
674
+ if not endpoint or not api_key:
675
+ raise HTTPException(status_code=500, detail="Mistral API endpoint or key not configured.")
676
+ custom_headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json", "Accept": "application/json"}
677
+ # Mistral specific adjustments if needed
678
+ # payload_dict.pop('system', None) # Example: if Mistral doesn't use 'system' role
679
+
680
  elif model_to_use in pollinations_models:
681
+ endpoint = env_vars.get('secret_api_endpoint_4')
682
+ if not endpoint:
683
+ raise HTTPException(status_code=500, detail="Pollinations API endpoint (SECRET_API_ENDPOINT_4) not configured.")
684
+ # Pollinations might need specific headers? Add them here.
685
+ custom_headers = {"Content-Type": "application/json"}
686
+
687
  elif model_to_use in alternate_models:
688
+ endpoint = env_vars.get('secret_api_endpoint_2')
689
+ if not endpoint:
690
+ raise HTTPException(status_code=500, detail="Alternate API endpoint (SECRET_API_ENDPOINT_2) not configured.")
691
+ custom_headers = {"Content-Type": "application/json"}
692
+
693
+ elif model_to_use in claude_3_models:
694
+ endpoint = env_vars.get('secret_api_endpoint_5')
695
+ if not endpoint:
696
+ raise HTTPException(status_code=500, detail="Claude 3 API endpoint (SECRET_API_ENDPOINT_5) not configured.")
697
+ custom_headers = {"Content-Type": "application/json"}
698
+ # Claude specific headers (like anthropic-version) might be needed
699
+ # custom_headers["anthropic-version"] = "2023-06-01"
700
+
701
+ else: # Default endpoint
702
+ endpoint = env_vars.get('secret_api_endpoint')
703
+ if not endpoint:
704
+ raise HTTPException(status_code=500, detail="Default API endpoint (SECRET_API_ENDPOINT) not configured.")
705
+ # Default endpoint might need Origin/Referer
706
+ if hf_space_url:
707
+ custom_headers = {
708
+ "Origin": hf_space_url,
709
+ "Referer": hf_space_url,
710
+ "Content-Type": "application/json"
711
+ }
712
+ else:
713
+ custom_headers = {"Content-Type": "application/json"}
714
+
715
 
716
+ target_url = f"{endpoint.rstrip('/')}/v1/chat/completions" # Assume OpenAI compatible path
717
+ logger.info(f"Proxying request for model '{model_to_use}' to endpoint: {endpoint}")
718
 
719
+ client = get_async_client()
720
+
721
+ async def stream_generator():
722
+ """Generator for streaming the response."""
723
+ nonlocal target_url # Allow modification if needed
724
  try:
725
+ async with client.stream("POST", target_url, json=payload_dict, headers=custom_headers) as response:
726
+ # Check for initial errors before streaming
727
+ if response.status_code >= 400:
728
+ error_body = await response.aread()
729
+ logger.error(f"Upstream API error: {response.status_code} - {error_body.decode()}")
730
+ # Try to parse error detail from upstream
731
+ detail = f"Upstream API error: {response.status_code}"
732
+ try:
733
+ error_json = json.loads(error_body)
734
+ detail = error_json.get('error', {}).get('message', detail)
735
+ except json.JSONDecodeError:
736
+ pass
737
+ # Send error as SSE event
738
+ error_event = {"error": {"message": detail, "code": response.status_code}}
739
+ yield f"data: {json.dumps(error_event)}\n\n"
740
+ return # Stop generation
741
+
742
+ # Stream the response line by line
743
+ async for line in response.aiter_lines():
744
+ if line:
745
+ # Pass through the data directly
746
+ yield line + "\n"
747
+ # Ensure stream is properly closed, yield [DONE] if backend doesn't
748
+ # Some backends might not send [DONE], uncomment if needed
749
+ # yield "data: [DONE]\n\n"
750
+
751
  except httpx.TimeoutException:
752
+ logger.error(f"Request to {target_url} timed out.")
753
+ error_event = {"error": {"message": "Request timed out", "code": 504}}
754
+ yield f"data: {json.dumps(error_event)}\n\n"
755
  except httpx.RequestError as e:
756
+ logger.error(f"Failed to connect to upstream API {target_url}: {e}")
757
+ error_event = {"error": {"message": f"Upstream connection error: {e}", "code": 502}}
758
+ yield f"data: {json.dumps(error_event)}\n\n"
759
  except Exception as e:
760
+ logger.error(f"An unexpected error occurred during streaming proxy: {e}", exc_info=True)
761
+ error_event = {"error": {"message": f"Internal server error: {e}", "code": 500}}
762
+ yield f"data: {json.dumps(error_event)}\n\n"
763
 
764
+ if payload.stream:
 
765
  return StreamingResponse(
766
+ stream_generator(),
767
  media_type="text/event-stream",
768
  headers={
769
  "Content-Type": "text/event-stream",
770
  "Cache-Control": "no-cache",
771
  "Connection": "keep-alive",
772
+ "X-Accel-Buffering": "no" # Essential for HF Spaces proxying SSE
773
  }
774
  )
775
  else:
776
+ # Handle non-streaming request by collecting the streamed chunks
777
+ full_response_content = ""
778
+ final_json_response = None
779
+ async for line in stream_generator():
780
+ if line.startswith("data: "):
781
+ data_str = line[6:].strip()
782
+ if data_str == "[DONE]":
783
+ break
784
+ try:
785
+ chunk = json.loads(data_str)
786
+ # Check for error chunk
787
+ if "error" in chunk:
788
+ logger.error(f"Received error during non-stream collection: {chunk['error']}")
789
+ raise HTTPException(status_code=chunk['error'].get('code', 502), detail=chunk['error'].get('message', 'Upstream API error'))
790
+
791
+ # Accumulate content from delta
792
+ delta = chunk.get("choices", [{}])[0].get("delta", {})
793
+ content = delta.get("content")
794
+ if content:
795
+ full_response_content += content
796
+
797
+ # Store the last chunk structure to reconstruct the final response
798
+ # We assume the last chunk contains necessary info like id, model, etc.
799
+ # but we overwrite the choices/message part.
800
+ final_json_response = chunk # Keep the structure
801
+ # Check for finish reason
802
+ finish_reason = chunk.get("choices", [{}])[0].get("finish_reason")
803
+ if finish_reason:
804
+ break # Stop collecting
805
+
806
+ except json.JSONDecodeError:
807
+ logger.warning(f"Could not decode JSON chunk in non-stream mode: {data_str}")
808
+ except Exception as e:
809
+ logger.error(f"Error processing chunk in non-stream mode: {e}")
810
+ raise HTTPException(status_code=500, detail="Error processing response stream.")
811
+
812
+ if final_json_response is None:
813
+ # Handle cases where no valid data chunks were received
814
+ logger.error("No valid response chunks received for non-streaming request.")
815
+ raise HTTPException(status_code=502, detail="Failed to get valid response from upstream API.")
816
+
817
+
818
+ # Reconstruct OpenAI-like non-streaming response
819
+ final_response_obj = {
820
+ "id": final_json_response.get("id", f"chatcmpl-{int(time.time())}"),
821
+ "object": "chat.completion",
822
+ "created": final_json_response.get("created", int(time.time())),
823
+ "model": model_to_use, # Use the requested model
824
+ "choices": [{
825
+ "index": 0,
826
+ "message": {
827
+ "role": "assistant",
828
+ "content": full_response_content,
829
+ },
830
+ "finish_reason": final_json_response.get("choices", [{}])[0].get("finish_reason", "stop"), # Get finish reason from last chunk
831
+ }],
832
+ "usage": { # Token usage might be in the last chunk for some APIs, otherwise unknown
833
+ "prompt_tokens": None,
834
+ "completion_tokens": None,
835
+ "total_tokens": None,
836
+ }
837
+ }
838
+ # Attempt to extract usage if present in the (potentially non-standard) final chunk
839
+ usage_data = final_json_response.get("usage")
840
+ if isinstance(usage_data, dict):
841
+ final_response_obj["usage"].update(usage_data)
842
 
 
843
 
844
+ return JSONResponse(content=final_response_obj)
845
 
846
 
847
+ # Image Generation Endpoint
848
+ @app.post("/images/generations", tags=["Image Generation"])
849
+ async def create_image(
850
+ payload: ImageGenerationPayload,
851
+ authenticated: bool = Depends(verify_api_key)
852
+ ):
853
  """
854
+ Generates images based on a text prompt using the configured backend.
855
  """
 
856
  if not server_status:
857
+ raise HTTPException(status_code=503, detail="Server is under maintenance.")
 
 
 
858
 
 
859
  if payload.model not in supported_image_models:
860
  raise HTTPException(
861
  status_code=400,
862
+ detail=f"Model '{payload.model}' is not supported for image generation. Supported: {', '.join(supported_image_models)}"
863
  )
864
 
 
865
  usage_tracker.record_request(model=payload.model, endpoint="/images/generations")
866
 
867
+ env_vars = get_env_vars()
868
+ target_api_url = env_vars.get('new_img_endpoint')
869
+ if not target_api_url:
870
+ raise HTTPException(status_code=500, detail="Image generation endpoint (NEW_IMG) not configured.")
871
+
872
+ # Prepare payload for the target API (adjust keys if needed)
873
+ # HF Space Note: Ensure the keys match the actual API expected by NEW_IMG endpoint.
874
+ # Assuming it's OpenAI compatible here.
875
  api_payload = {
876
  "model": payload.model,
877
  "prompt": payload.prompt,
878
+ "n": payload.n,
879
+ "size": payload.size
880
  }
881
+ # Remove None values the target API might not like
882
+ api_payload = {k: v for k, v in api_payload.items() if v is not None}
883
 
 
 
884
 
885
+ logger.info(f"Requesting image generation for model '{payload.model}' from {target_api_url}")
886
+ client = get_async_client()
 
 
887
 
888
+ try:
889
+ # HF Space Note: Image generation can take time, use a longer timeout if needed.
890
+ # Consider making this truly async if the backend supports webhooks or polling.
891
+ response = await client.post(target_api_url, json=api_payload, timeout=120.0) # 2 min timeout
892
+ response.raise_for_status() # Raise HTTP errors
893
 
894
+ # Return the exact response from the backend
895
  return JSONResponse(content=response.json())
896
 
897
  except httpx.TimeoutException:
898
+ logger.error(f"Image generation request to {target_api_url} timed out.")
899
  raise HTTPException(status_code=504, detail="Image generation request timed out.")
900
+ except httpx.HTTPStatusError as e:
901
+ logger.error(f"Image generation API error: {e.response.status_code} - {e.response.text}")
902
+ detail = f"Image generation failed: Upstream API error {e.response.status_code}"
903
+ try:
904
+ err_json = e.response.json()
905
+ detail = err_json.get('error', {}).get('message', detail)
906
+ except json.JSONDecodeError:
907
+ pass
908
+ raise HTTPException(status_code=e.response.status_code, detail=detail)
909
  except httpx.RequestError as e:
910
+ logger.error(f"Error connecting to image generation service {target_api_url}: {e}")
911
  raise HTTPException(status_code=502, detail=f"Error connecting to image generation service: {e}")
912
  except Exception as e:
913
+ logger.error(f"Unexpected error during image generation: {e}", exc_info=True)
914
  raise HTTPException(status_code=500, detail=f"An unexpected error occurred during image generation: {e}")
915
 
916
 
917
+ # --- Utility & Admin Endpoints ---
918
 
919
+ async def log_request(request: Request, model: Optional[str] = None):
920
+ """Logs basic request information asynchronously."""
921
+ # HF Space Note: Avoid logging sensitive info like full IP or headers unless necessary.
922
+ # Hashing IP provides some privacy.
923
+ client_host = request.client.host if request.client else "unknown"
924
+ ip_hash = hash(client_host) % 10000
925
+ timestamp = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z")
926
+ log_message = f"Timestamp: {timestamp}, IP Hash: {ip_hash}, Method: {request.method}, Path: {request.url.path}"
927
+ if model:
928
+ log_message += f", Model: {model}"
929
+ logger.info(log_message)
930
 
 
 
 
 
931
 
932
+ @app.get("/usage", tags=["Admin"])
933
  async def get_usage(days: int = 7):
934
+ """Retrieves aggregated usage statistics."""
935
+ # HF Space Note: Ensure usage_tracker methods are efficient, especially get_usage_summary.
936
+ # Caching might be needed if it becomes slow.
937
+ if days <= 0:
938
+ raise HTTPException(status_code=400, detail="Number of days must be positive.")
939
+ try:
940
+ # Run potentially CPU-bound summary generation in executor
941
+ loop = asyncio.get_running_loop()
942
+ summary = await loop.run_in_executor(executor, usage_tracker.get_usage_summary, days)
943
+ return summary
944
+ except Exception as e:
945
+ logger.error(f"Error retrieving usage statistics: {e}", exc_info=True)
946
+ raise HTTPException(status_code=500, detail="Failed to retrieve usage statistics.")
947
+
948
+ # HF Space Note: Generating HTML dynamically can be resource-intensive.
949
+ # Consider caching the generated HTML or serving a static page updated periodically.
950
+ def generate_usage_html(usage_data: Dict) -> str:
951
+ """Generates an HTML report from usage data."""
952
+ # (Keep the HTML generation logic as provided in the original file)
953
+ # ... (rest of the HTML generation code from the original file) ...
954
+ # Ensure this function handles potentially missing keys gracefully
955
+ models_usage = usage_data.get('models', {})
956
+ endpoints_usage = usage_data.get('api_endpoints', {})
957
+ daily_usage = usage_data.get('recent_daily_usage', {})
958
+ total_requests = usage_data.get('total_requests', 0)
959
 
 
 
 
960
  model_usage_rows = "\n".join([
961
  f"""
962
  <tr>
963
  <td>{model}</td>
964
+ <td>{model_data.get('total_requests', 'N/A')}</td>
965
+ <td>{model_data.get('first_used', 'N/A')}</td>
966
+ <td>{model_data.get('last_used', 'N/A')}</td>
967
  </tr>
968
+ """ for model, model_data in models_usage.items()
969
+ ]) if models_usage else "<tr><td colspan='4'>No model usage data</td></tr>"
970
 
 
971
  api_usage_rows = "\n".join([
972
  f"""
973
  <tr>
974
  <td>{endpoint}</td>
975
+ <td>{endpoint_data.get('total_requests', 'N/A')}</td>
976
+ <td>{endpoint_data.get('first_used', 'N/A')}</td>
977
+ <td>{endpoint_data.get('last_used', 'N/A')}</td>
978
  </tr>
979
+ """ for endpoint, endpoint_data in endpoints_usage.items()
980
+ ]) if endpoints_usage else "<tr><td colspan='4'>No API endpoint usage data</td></tr>"
981
 
 
982
  daily_usage_rows = "\n".join([
983
+ f"""
984
+ <tr>
985
+ <td>{date}</td>
986
+ <td>{entity}</td>
987
+ <td>{requests}</td>
988
+ </tr>
989
+ """
990
+ for date, date_data in daily_usage.items()
991
+ for entity, requests in date_data.items()
992
+ ]) if daily_usage else "<tr><td colspan='3'>No daily usage data</td></tr>"
993
 
994
+
995
+ # HF Space Note: Using f-string for large HTML is okay, but consider template engines (Jinja2)
996
+ # for more complex pages. Ensure CSS/JS are either inline or served via separate endpoints.
997
  html_content = f"""
998
  <!DOCTYPE html>
999
  <html lang="en">
1000
  <head>
1001
  <meta charset="UTF-8">
1002
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
1003
  <title>Lokiai AI - Usage Statistics</title>
1004
  <link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;600&display=swap" rel="stylesheet">
1005
  <style>
1006
+ /* (Keep the CSS styles as provided in the original file) */
1007
  :root {{
1008
+ --bg-dark: #0f1011; --bg-darker: #070708; --text-primary: #e6e6e6;
1009
+ --text-secondary: #8c8c8c; --border-color: #2c2c2c; --accent-color: #3a6ee0;
 
 
 
 
1010
  --accent-hover: #4a7ef0;
1011
  }}
1012
+ body {{ font-family: 'Inter', sans-serif; background-color: var(--bg-dark); color: var(--text-primary); max-width: 1200px; margin: 0 auto; padding: 40px 20px; line-height: 1.6; }}
1013
+ .logo {{ display: flex; align-items: center; justify-content: center; margin-bottom: 30px; }}
1014
+ .logo h1 {{ font-weight: 600; font-size: 2.5em; color: var(--text-primary); margin-left: 15px; }}
1015
+ .logo img {{ width: 60px; height: 60px; border-radius: 10px; }}
1016
+ .container {{ background-color: var(--bg-darker); border-radius: 12px; padding: 30px; box-shadow: 0 15px 40px rgba(0,0,0,0.3); border: 1px solid var(--border-color); }}
1017
+ h2, h3 {{ color: var(--text-primary); border-bottom: 2px solid var(--border-color); padding-bottom: 10px; font-weight: 500; }}
1018
+ .total-requests {{ background-color: var(--accent-color); color: white; text-align: center; padding: 15px; border-radius: 8px; margin-bottom: 30px; font-weight: 600; letter-spacing: -0.5px; }}
1019
+ table {{ width: 100%; border-collapse: separate; border-spacing: 0; margin-bottom: 30px; background-color: var(--bg-dark); border-radius: 8px; overflow: hidden; }}
1020
+ th, td {{ border: 1px solid var(--border-color); padding: 12px; text-align: left; transition: background-color 0.3s ease; }}
1021
+ th {{ background-color: #1e1e1e; color: var(--text-primary); font-weight: 600; text-transform: uppercase; font-size: 0.9em; }}
1022
+ tr:nth-child(even) {{ background-color: rgba(255,255,255,0.05); }}
1023
+ tr:hover {{ background-color: rgba(62,100,255,0.1); }}
1024
+ @media (max-width: 768px) {{ .container {{ padding: 15px; }} table {{ font-size: 0.9em; }} }}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1025
  </style>
1026
  </head>
1027
  <body>
1028
  <div class="container">
1029
  <div class="logo">
1030
  <img src="data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iMjAwIiBoZWlnaHQ9IjIwMCIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj48cGF0aCBkPSJNMTAwIDM1TDUwIDkwaDEwMHoiIGZpbGw9IiMzYTZlZTAiLz48Y2lyY2xlIGN4PSIxMDAiIGN5PSIxNDAiIHI9IjMwIiBmaWxsPSIjM2E2ZWUwIi8+PC9zdmc+" alt="Lokai AI Logo">
1031
+ <h1>Lokiai AI Usage</h1>
1032
  </div>
1033
 
1034
  <div class="total-requests">
1035
+ Total API Requests Recorded: {total_requests}
1036
  </div>
1037
 
1038
  <h2>Model Usage</h2>
1039
  <table>
1040
+ <thead><tr><th>Model</th><th>Total Requests</th><th>First Used</th><th>Last Used</th></tr></thead>
1041
+ <tbody>{model_usage_rows}</tbody>
 
 
 
 
 
1042
  </table>
1043
 
1044
  <h2>API Endpoint Usage</h2>
1045
  <table>
1046
+ <thead><tr><th>Endpoint</th><th>Total Requests</th><th>First Used</th><th>Last Used</th></tr></thead>
1047
+ <tbody>{api_usage_rows}</tbody>
 
 
 
 
 
1048
  </table>
1049
 
1050
+ <h2>Daily Usage (Last {usage_data.get('days_analyzed', 7)} Days)</h2>
1051
  <table>
1052
+ <thead><tr><th>Date</th><th>Entity (Model/Endpoint)</th><th>Requests</th></tr></thead>
1053
+ <tbody>{daily_usage_rows}</tbody>
 
 
 
 
1054
  </table>
1055
  </div>
1056
  </body>
 
1058
  """
1059
  return html_content
1060
 
1061
+ # HF Space Note: Caching the generated HTML page can save resources.
1062
+ # Invalidate cache periodically or when usage data changes significantly.
1063
+ usage_html_cache = {"content": None, "timestamp": 0}
1064
+ CACHE_DURATION = 300 # Cache usage page for 5 minutes
 
1065
 
1066
+ @app.get("/usage/page", response_class=HTMLResponse, tags=["Admin"])
1067
  async def usage_page():
1068
+ """Serves an HTML page showing usage statistics."""
1069
+ now = time.monotonic()
1070
+ if usage_html_cache["content"] and (now - usage_html_cache["timestamp"] < CACHE_DURATION):
1071
+ logger.info("Serving cached usage page.")
1072
+ return HTMLResponse(content=usage_html_cache["content"])
1073
+
1074
+ logger.info("Generating fresh usage page.")
1075
+ try:
1076
+ # Run potentially slow parts in executor
1077
+ loop = asyncio.get_running_loop()
1078
+ usage_data = await loop.run_in_executor(executor, usage_tracker.get_usage_summary, 7) # Get data for 7 days
1079
+ html_content = await loop.run_in_executor(executor, generate_usage_html, usage_data)
1080
+
1081
+ # Update cache
1082
+ usage_html_cache["content"] = html_content
1083
+ usage_html_cache["timestamp"] = now
1084
+
1085
+ return HTMLResponse(content=html_content)
1086
+ except Exception as e:
1087
+ logger.error(f"Failed to generate usage page: {e}", exc_info=True)
1088
+ # Serve stale cache if available, otherwise error
1089
+ if usage_html_cache["content"]:
1090
+ logger.warning("Serving stale usage page due to generation error.")
1091
+ return HTMLResponse(content=usage_html_cache["content"])
1092
+ else:
1093
+ raise HTTPException(status_code=500, detail="Failed to generate usage statistics page.")
1094
+
1095
 
1096
+ # Meme Endpoint
1097
+ @app.get("/meme", tags=["Fun"])
1098
  async def get_meme():
1099
+ """Fetches a random meme and streams the image."""
1100
+ # HF Space Note: Ensure meme-api.com is accessible from the HF Space network.
1101
+ client = get_async_client()
1102
+ meme_api_url = "https://meme-api.com/gimme"
1103
  try:
1104
+ logger.info("Fetching meme info...")
1105
+ response = await client.get(meme_api_url)
1106
+ response.raise_for_status()
1107
  response_data = response.json()
1108
 
1109
  meme_url = response_data.get("url")
1110
+ if not meme_url or not isinstance(meme_url, str):
1111
+ logger.error(f"Invalid meme URL received from API: {meme_url}")
1112
+ raise HTTPException(status_code=502, detail="Failed to get valid meme URL from API.")
1113
+
1114
+ logger.info(f"Fetching meme image: {meme_url}")
1115
+ # Use streaming request for the image itself
1116
+ async with client.stream("GET", meme_url) as image_response:
1117
+ image_response.raise_for_status() # Check if image URL is valid
1118
+
1119
+ # Get content type, default to image/png
1120
+ media_type = image_response.headers.get("content-type", "image/png")
1121
+ if not media_type.startswith("image/"):
1122
+ logger.warning(f"Unexpected content type '{media_type}' for meme URL: {meme_url}")
1123
+ # You might want to reject non-image types
1124
+ # raise HTTPException(status_code=502, detail="Meme URL did not return an image.")
1125
+
1126
+
1127
+ # Stream the image content directly
1128
+ return StreamingResponse(
1129
+ image_response.aiter_bytes(),
1130
+ media_type=media_type,
1131
+ headers={'Cache-Control': 'no-cache'} # Don't cache the meme itself heavily
1132
+ )
1133
+
1134
+ except httpx.HTTPStatusError as e:
1135
+ logger.error(f"HTTP error fetching meme ({e.request.url}): {e.response.status_code}")
1136
+ raise HTTPException(status_code=502, detail=f"Failed to fetch meme (HTTP {e.response.status_code})")
1137
+ except httpx.RequestError as e:
1138
+ logger.error(f"Network error fetching meme ({e.request.url}): {e}")
1139
+ raise HTTPException(status_code=502, detail="Failed to fetch meme (Network Error)")
1140
+ except Exception as e:
1141
+ logger.error(f"Unexpected error fetching meme: {e}", exc_info=True)
1142
+ raise HTTPException(status_code=500, detail="Failed to retrieve meme due to an internal error.")
1143
 
 
1144
 
1145
+ # Health Check Endpoint
1146
+ @app.get("/health", tags=["Utility"])
1147
+ async def health_check():
1148
+ """Provides a health check status, including missing critical configurations."""
1149
+ env_vars = get_env_vars()
1150
+ missing_critical_vars = []
 
1151
 
1152
+ # Define critical vars needed for core functionality
1153
+ critical_vars = [
1154
+ 'api_keys', 'secret_api_endpoint', 'secret_api_endpoint_2',
1155
+ 'secret_api_endpoint_3', 'secret_api_endpoint_4', 'secret_api_endpoint_5',
1156
+ 'new_img_endpoint', 'hf_space_url'
1157
+ ]
1158
+ # Conditionally critical vars
1159
+ if any(model in mistral_models for model in available_model_ids):
1160
+ critical_vars.extend(['mistral_api', 'mistral_key'])
1161
+
1162
+ for var_name in critical_vars:
1163
+ value = env_vars.get(var_name)
1164
+ # Check for None or empty strings/lists/sets
1165
+ if value is None or (isinstance(value, (str, list, set)) and not value):
1166
+ missing_critical_vars.append(var_name)
1167
+
1168
+ is_healthy = not missing_critical_vars and server_status
1169
+ status_code = 200 if is_healthy else 503 # Service Unavailable if unhealthy
1170
 
1171
+ health_status = {
1172
+ "status": "healthy" if is_healthy else "unhealthy",
1173
+ "server_mode": "online" if server_status else "maintenance",
1174
+ "missing_critical_env_vars": missing_critical_vars,
1175
+ "details": "All critical configurations seem okay. Ready to roll! 🚀" if is_healthy else "Service issues detected. Check missing env vars or server status. 🛠️"
1176
+ }
1177
+ return JSONResponse(content=health_status, status_code=status_code)
1178
 
 
 
 
 
 
 
 
1179
 
1180
+ # --- Startup and Shutdown Events ---
 
 
 
 
 
 
 
 
 
1181
 
1182
  @app.on_event("startup")
1183
  async def startup_event():
1184
+ """Tasks to run when the application starts."""
1185
  global available_model_ids
1186
+ logger.info("Application startup sequence initiated...")
1187
+
1188
+ # Load models from JSON
1189
+ models_from_file = load_models_data()
1190
+ model_ids_from_file = {model['id'] for model in models_from_file if 'id' in model}
1191
+
1192
+ # Combine models from file and predefined sets
1193
+ predefined_model_sets = mistral_models | pollinations_models | alternate_models | claude_3_models
1194
+ all_model_ids = model_ids_from_file.union(predefined_model_sets)
1195
+ available_model_ids = sorted(list(all_model_ids)) # Keep as sorted list
1196
+
1197
+ logger.info(f"Loaded {len(model_ids_from_file)} models from models.json.")
1198
+ logger.info(f"Total {len(available_model_ids)} unique models available.")
1199
+
1200
+ # Initialize scraper pool (can take time)
1201
+ # HF Space Note: Run potentially blocking I/O in executor during startup
1202
+ loop = asyncio.get_running_loop()
1203
+ await loop.run_in_executor(executor, get_scraper) # This initializes the pool
1204
+
1205
+ # Validate critical environment variables and log warnings
1206
  env_vars = get_env_vars()
1207
+ logger.info("Checking critical environment variables (Secrets)...")
1208
+ await health_check() # Run health check logic to log warnings
1209
+
1210
+ # Pre-connect async client? Optional, httpx handles connections on demand.
1211
+ # client = get_async_client()
1212
+ # await client.get("https://www.google.com") # Example warm-up call
1213
+
1214
+ logger.info("Startup complete. Server is ready to accept requests.")
1215
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1216
 
1217
  @app.on_event("shutdown")
1218
  async def shutdown_event():
1219
+ """Tasks to run when the application shuts down."""
1220
+ logger.info("Application shutdown sequence initiated...")
1221
+
1222
+ # Close the httpx client gracefully
1223
  client = get_async_client()
1224
  await client.aclose()
1225
+ logger.info("HTTP client closed.")
1226
+
1227
+ # Shutdown the thread pool executor
1228
+ executor.shutdown(wait=True)
1229
+ logger.info("Thread pool executor shut down.")
1230
 
1231
+ # Clear scraper pool (optional, resources will be reclaimed anyway)
1232
  scraper_pool.clear()
1233
+ logger.info("Scraper pool cleared.")
1234
 
1235
  # Persist usage data
1236
+ # HF Space Note: Ensure file system is writable if saving locally.
1237
+ # Consider using HF Datasets or external DB for persistent storage.
1238
+ try:
1239
+ logger.info("Saving usage data...")
1240
+ usage_tracker.save_data()
1241
+ logger.info("Usage data saved.")
1242
+ except Exception as e:
1243
+ logger.error(f"Failed to save usage data during shutdown: {e}")
 
 
 
1244
 
1245
+ logger.info("Shutdown complete.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1246
 
 
 
 
 
 
 
 
1247
 
1248
+ # --- Main Execution Block ---
1249
+ # HF Space Note: This block is mainly for local testing.
1250
+ # HF Spaces usually run the app using `uvicorn main:app --host 0.0.0.0 --port 7860` (or similar)
1251
+ # defined in the README metadata or a Procfile.
1252
  if __name__ == "__main__":
1253
  import uvicorn
1254
+ logger.info("Starting server locally with uvicorn...")
1255
+ # HF Space Note: Port 7860 is the default for HF Spaces. Host 0.0.0.0 is required.
1256
+ uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")