Spaces:
Running
Running
Update main.py
Browse files
main.py
CHANGED
@@ -1,886 +1,1056 @@
|
|
1 |
import os
|
2 |
import re
|
3 |
-
from dotenv import load_dotenv
|
4 |
-
from fastapi import FastAPI, HTTPException, Request, Depends, Security
|
5 |
-
from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse, FileResponse
|
6 |
-
from fastapi.security import APIKeyHeader
|
7 |
-
from pydantic import BaseModel
|
8 |
-
import httpx
|
9 |
-
from functools import lru_cache
|
10 |
-
from pathlib import Path
|
11 |
import json
|
12 |
import datetime
|
13 |
import time
|
14 |
-
import threading
|
15 |
-
from typing import Optional, Dict, List, Any, Generator
|
16 |
import asyncio
|
17 |
-
|
18 |
-
import
|
|
|
|
|
19 |
from concurrent.futures import ThreadPoolExecutor
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
from fastapi.middleware.gzip import GZipMiddleware
|
22 |
from starlette.middleware.cors import CORSMiddleware
|
23 |
-
import
|
24 |
-
import requests
|
25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
27 |
|
28 |
-
#
|
29 |
-
|
|
|
30 |
|
31 |
-
#
|
|
|
32 |
load_dotenv()
|
33 |
|
34 |
-
#
|
|
|
|
|
|
|
|
|
35 |
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
|
36 |
|
37 |
-
#
|
38 |
-
from usage_tracker import UsageTracker
|
39 |
-
usage_tracker = UsageTracker()
|
40 |
|
41 |
-
app = FastAPI(
|
|
|
|
|
|
|
|
|
42 |
|
43 |
-
#
|
44 |
-
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
45 |
app.add_middleware(
|
46 |
-
CORSMiddleware,
|
47 |
-
allow_origins=["*"],
|
48 |
allow_credentials=True,
|
49 |
allow_methods=["*"],
|
50 |
allow_headers=["*"],
|
51 |
)
|
52 |
|
53 |
-
# Environment
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
56 |
return {
|
57 |
-
'api_keys': os.getenv('API_KEYS', '').split(','),
|
58 |
'secret_api_endpoint': os.getenv('SECRET_API_ENDPOINT'),
|
59 |
'secret_api_endpoint_2': os.getenv('SECRET_API_ENDPOINT_2'),
|
60 |
-
'secret_api_endpoint_3': os.getenv('SECRET_API_ENDPOINT_3'),
|
61 |
-
'secret_api_endpoint_4': "https://text.pollinations.ai/openai",
|
62 |
-
'secret_api_endpoint_5': os.getenv('SECRET_API_ENDPOINT_5'),
|
63 |
-
'mistral_api': "https://api.mistral.ai",
|
64 |
'mistral_key': os.getenv('MISTRAL_KEY'),
|
65 |
-
'
|
|
|
66 |
}
|
67 |
|
68 |
-
#
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
"mistral-moderation-latest",
|
73 |
-
"ministral-3b-latest",
|
74 |
-
"
|
75 |
-
"open-mistral-nemo",
|
76 |
-
"mistral-small-latest",
|
77 |
-
"mistral-saba-latest",
|
78 |
-
"codestral-latest"
|
79 |
}
|
80 |
|
81 |
-
pollinations_models = {
|
82 |
-
"openai",
|
83 |
-
"
|
84 |
-
"
|
85 |
-
"openai-
|
86 |
-
"qwen-coder",
|
87 |
-
"llama",
|
88 |
-
"mistral",
|
89 |
-
"searchgpt",
|
90 |
-
"deepseek",
|
91 |
-
"claude-hybridspace",
|
92 |
-
"deepseek-r1",
|
93 |
-
"deepseek-reasoner",
|
94 |
-
"llamalight",
|
95 |
-
"gemini",
|
96 |
-
"gemini-thinking",
|
97 |
-
"hormoz",
|
98 |
-
"phi",
|
99 |
-
"phi-mini",
|
100 |
-
"openai-audio",
|
101 |
-
"llama-scaleway"
|
102 |
}
|
103 |
|
104 |
-
alternate_models = {
|
105 |
-
"gpt-4o",
|
106 |
-
"deepseek-
|
107 |
-
"llama-3.1-
|
108 |
-
"
|
109 |
-
"deepseek-r1-uncensored",
|
110 |
-
"tinyswallow1.5b",
|
111 |
-
"andy-3.5",
|
112 |
-
"o3-mini-low",
|
113 |
-
"hermes-3-llama-3.2-3b",
|
114 |
-
"creitin-r1",
|
115 |
-
"fluffy.1-chat",
|
116 |
-
"plutotext-1-text",
|
117 |
-
"command-a",
|
118 |
-
"claude-3-7-sonnet-20250219",
|
119 |
-
"plutogpt-3.5-turbo"
|
120 |
}
|
121 |
|
122 |
-
claude_3_models = {
|
123 |
-
"claude-3-7-sonnet",
|
124 |
-
"claude
|
125 |
-
"
|
126 |
-
"claude 3.5 sonnet",
|
127 |
-
"claude 3.5 haiku",
|
128 |
-
"o3-mini-medium",
|
129 |
-
"o3-mini-high",
|
130 |
-
"grok-3",
|
131 |
-
"grok-3-thinking",
|
132 |
-
"grok 2"
|
133 |
}
|
134 |
|
135 |
-
|
136 |
-
|
137 |
-
"Flux
|
138 |
-
"
|
139 |
-
"Flux Pro",
|
140 |
-
"Flux Pro Ultra Raw",
|
141 |
-
"Flux Dev",
|
142 |
-
"Flux Schnell",
|
143 |
-
"stable-diffusion-3-large-turbo",
|
144 |
-
"Flux Realism",
|
145 |
-
"stable-diffusion-ultra",
|
146 |
-
"dall-e-3",
|
147 |
-
"sdxl-lightning-4step"
|
148 |
}
|
149 |
|
|
|
|
|
|
|
|
|
|
|
150 |
|
151 |
-
# Request payload model
|
152 |
class Payload(BaseModel):
|
153 |
model: str
|
154 |
-
messages:
|
155 |
stream: bool = False
|
|
|
|
|
|
|
|
|
|
|
156 |
|
157 |
-
|
158 |
-
# Image generation payload model
|
159 |
class ImageGenerationPayload(BaseModel):
|
160 |
model: str
|
161 |
prompt: str
|
162 |
-
size:
|
163 |
-
|
|
|
|
|
164 |
|
|
|
165 |
|
|
|
|
|
166 |
|
167 |
-
#
|
168 |
-
|
169 |
-
available_model_ids: List[str] = []
|
170 |
-
|
171 |
-
# Create a reusable httpx client pool with connection pooling
|
172 |
@lru_cache(maxsize=1)
|
173 |
-
def get_async_client():
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
)
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
|
|
|
|
|
|
185 |
if not scraper_pool:
|
|
|
186 |
for _ in range(MAX_SCRAPERS):
|
|
|
187 |
scraper_pool.append(cloudscraper.create_scraper())
|
|
|
|
|
|
|
188 |
|
189 |
-
|
190 |
|
191 |
-
# API key validation - optimized to avoid string operations when possible
|
192 |
async def verify_api_key(
|
193 |
request: Request,
|
194 |
-
api_key: str = Security(api_key_header)
|
195 |
) -> bool:
|
196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
197 |
referer = request.headers.get("referer", "")
|
198 |
-
if referer.startswith(("
|
199 |
-
|
200 |
return True
|
201 |
-
|
202 |
if not api_key:
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
# Only clean if needed
|
209 |
if api_key.startswith('Bearer '):
|
210 |
-
api_key = api_key[7:]
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
)
|
219 |
-
|
220 |
-
|
221 |
-
if api_key not in set(valid_api_keys):
|
222 |
-
raise HTTPException(
|
223 |
-
status_code=HTTP_403_FORBIDDEN,
|
224 |
-
detail="Invalid API key"
|
225 |
-
)
|
226 |
-
|
227 |
return True
|
228 |
|
229 |
-
#
|
|
|
230 |
@lru_cache(maxsize=1)
|
231 |
-
def load_models_data():
|
|
|
|
|
|
|
|
|
|
|
|
|
232 |
try:
|
233 |
-
|
234 |
-
with open(file_path, 'r') as f:
|
235 |
return json.load(f)
|
236 |
except (FileNotFoundError, json.JSONDecodeError) as e:
|
237 |
-
|
238 |
return []
|
239 |
|
240 |
-
|
241 |
-
|
242 |
models_data = load_models_data()
|
243 |
if not models_data:
|
244 |
raise HTTPException(status_code=500, detail="Error loading available models")
|
245 |
return models_data
|
246 |
|
247 |
-
#
|
248 |
-
async def generate_search_async(query: str, systemprompt: Optional[str] = None, stream: bool = True):
|
249 |
-
# Create a streaming response channel using asyncio.Queue
|
250 |
-
queue = asyncio.Queue()
|
251 |
-
|
252 |
-
async def _fetch_search_data():
|
253 |
-
try:
|
254 |
-
headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
|
255 |
-
|
256 |
-
# Use the provided system prompt, or default to "Be Helpful and Friendly"
|
257 |
-
system_message = systemprompt or "Be Helpful and Friendly"
|
258 |
-
|
259 |
-
# Create the prompt history
|
260 |
-
prompt = [
|
261 |
-
{"role": "user", "content": query},
|
262 |
-
]
|
263 |
-
|
264 |
-
prompt.insert(0, {"content": system_message, "role": "system"})
|
265 |
-
|
266 |
-
# Prepare the payload for the API request
|
267 |
-
payload = {
|
268 |
-
"is_vscode_extension": True,
|
269 |
-
"message_history": prompt,
|
270 |
-
"requested_model": "searchgpt",
|
271 |
-
"user_input": prompt[-1]["content"],
|
272 |
-
}
|
273 |
-
|
274 |
-
# Get endpoint from environment
|
275 |
-
secret_api_endpoint_3 = get_env_vars()['secret_api_endpoint_3']
|
276 |
-
if not secret_api_endpoint_3:
|
277 |
-
await queue.put({"error": "Search API endpoint not configured"})
|
278 |
-
return
|
279 |
-
|
280 |
-
# Use AsyncClient for better performance
|
281 |
-
async with httpx.AsyncClient(timeout=30.0) as client:
|
282 |
-
async with client.stream("POST", secret_api_endpoint_3, json=payload, headers=headers) as response:
|
283 |
-
if response.status_code != 200:
|
284 |
-
await queue.put({"error": f"Search API returned status code {response.status_code}"})
|
285 |
-
return
|
286 |
-
|
287 |
-
# Process the streaming response in real-time
|
288 |
-
buffer = ""
|
289 |
-
async for line in response.aiter_lines():
|
290 |
-
if line.startswith("data: "):
|
291 |
-
try:
|
292 |
-
json_data = json.loads(line[6:])
|
293 |
-
content = json_data.get("choices", [{}])[0].get("delta", {}).get("content", "")
|
294 |
-
|
295 |
-
if content.strip():
|
296 |
-
cleaned_response = {
|
297 |
-
"created": json_data.get("created"),
|
298 |
-
"id": json_data.get("id"),
|
299 |
-
"model": "searchgpt",
|
300 |
-
"object": "chat.completion",
|
301 |
-
"choices": [
|
302 |
-
{
|
303 |
-
"message": {
|
304 |
-
"content": content
|
305 |
-
}
|
306 |
-
}
|
307 |
-
]
|
308 |
-
}
|
309 |
-
|
310 |
-
# Send to queue immediately for streaming
|
311 |
-
await queue.put({"data": f"data: {json.dumps(cleaned_response)}\n\n", "text": content})
|
312 |
-
except json.JSONDecodeError:
|
313 |
-
continue
|
314 |
-
|
315 |
-
# Signal completion
|
316 |
-
await queue.put(None)
|
317 |
-
|
318 |
-
except Exception as e:
|
319 |
-
await queue.put({"error": str(e)})
|
320 |
-
await queue.put(None)
|
321 |
|
322 |
-
|
323 |
-
asyncio.create_task(_fetch_search_data())
|
324 |
-
|
325 |
-
# Return the queue for consumption
|
326 |
-
return queue
|
327 |
-
|
328 |
-
# Cache for frequently accessed static files
|
329 |
@lru_cache(maxsize=10)
|
330 |
-
def
|
|
|
|
|
|
|
|
|
|
|
331 |
try:
|
332 |
-
with open(
|
333 |
return file.read()
|
334 |
-
except
|
|
|
335 |
return None
|
336 |
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
|
|
342 |
|
343 |
-
|
344 |
-
async def favicon():
|
345 |
-
favicon_path = Path(__file__).parent / "banner.jpg"
|
346 |
-
return FileResponse(favicon_path, media_type="image/x-icon")
|
347 |
|
348 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
349 |
async def ping():
|
350 |
-
|
|
|
351 |
|
352 |
-
@app.get("/", response_class=HTMLResponse)
|
353 |
async def root():
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
@app.get("/style.css", response_class=
|
365 |
-
async def
|
366 |
-
|
367 |
-
if
|
368 |
-
return
|
369 |
-
return
|
370 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
371 |
async def dynamic_ai_page(request: Request):
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
prompt = f"""
|
377 |
-
Generate a dynamic HTML page for a user with the following details:
|
|
|
378 |
- User-Agent: {user_agent}
|
379 |
-
- Location: {location}
|
380 |
-
- Style: Cyberpunk, minimalist,
|
381 |
-
|
382 |
-
|
383 |
-
Wrap the generated HTML in triple backticks (```).
|
384 |
"""
|
385 |
-
|
386 |
payload = {
|
387 |
-
"model": "mistral-small-latest",
|
388 |
-
"messages": [{"role": "user", "content": prompt}]
|
|
|
|
|
389 |
}
|
390 |
-
|
391 |
headers = {
|
392 |
-
|
|
|
|
|
393 |
}
|
394 |
-
|
395 |
-
response = requests.post("https://parthsadaria-lokiai.hf.space/chat/completions", json=payload, headers=headers)
|
396 |
-
data = response.json()
|
397 |
-
|
398 |
-
# Extract HTML from ``` blocks
|
399 |
-
html_content = re.search(r"```(.*?)```", data['choices'][0]['message']['content'], re.DOTALL)
|
400 |
-
if html_content:
|
401 |
-
html_content = html_content.group(1).strip()
|
402 |
-
|
403 |
-
# Remove the first word
|
404 |
-
if html_content:
|
405 |
-
html_content = ' '.join(html_content.split(' ')[1:])
|
406 |
-
|
407 |
-
return HTMLResponse(content=html_content)
|
408 |
-
|
409 |
-
@app.get("/playground", response_class=HTMLResponse)
|
410 |
-
async def playground():
|
411 |
-
html_content = read_html_file("playground.html")
|
412 |
-
if html_content is None:
|
413 |
-
return HTMLResponse(content="<h1>playground.html not found</h1>", status_code=404)
|
414 |
-
return HTMLResponse(content=html_content)
|
415 |
-
|
416 |
-
@app.get("/image-playground", response_class=HTMLResponse)
|
417 |
-
async def playground():
|
418 |
-
html_content = read_html_file("image-playground.html")
|
419 |
-
if html_content is None:
|
420 |
-
return HTMLResponse(content="<h1>image-playground.html not found</h1>", status_code=404)
|
421 |
-
return HTMLResponse(content=html_content)
|
422 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
423 |
|
|
|
|
|
|
|
424 |
|
|
|
|
|
|
|
425 |
|
426 |
-
|
427 |
-
GITHUB_BASE = "https://raw.githubusercontent.com/Parthsadaria/Vetra/main"
|
428 |
|
429 |
-
|
430 |
-
|
431 |
-
|
432 |
-
|
433 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
434 |
|
435 |
-
async def get_github_file(filename: str) -> str:
|
|
|
436 |
url = f"{GITHUB_BASE}/{filename}"
|
437 |
-
|
|
|
438 |
res = await client.get(url)
|
439 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
440 |
|
441 |
-
@app.get("/vetra", response_class=HTMLResponse)
|
442 |
async def serve_vetra():
|
443 |
-
|
444 |
-
|
445 |
-
|
|
|
|
|
|
|
|
|
|
|
446 |
|
447 |
if not html:
|
448 |
-
|
|
|
|
|
|
|
|
|
|
|
449 |
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
).replace(
|
454 |
-
"</body>",
|
455 |
-
f"<script>{js or '// JS not found'}</script></body>"
|
456 |
-
)
|
457 |
|
|
|
458 |
return HTMLResponse(content=final_html)
|
459 |
|
460 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
461 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
462 |
|
|
|
|
|
|
|
|
|
|
|
463 |
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
-
async def return_models():
|
468 |
-
return await get_models()
|
469 |
|
470 |
-
|
471 |
-
|
472 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
473 |
if not q:
|
474 |
raise HTTPException(status_code=400, detail="Query parameter 'q' is required")
|
475 |
|
|
|
|
|
476 |
usage_tracker.record_request(endpoint="/searchgpt")
|
477 |
|
478 |
-
queue = await generate_search_async(q, systemprompt=systemprompt
|
479 |
|
480 |
if stream:
|
481 |
async def stream_generator():
|
482 |
-
|
483 |
while True:
|
484 |
item = await queue.get()
|
485 |
-
if item is None:
|
486 |
break
|
487 |
-
|
488 |
if "error" in item:
|
489 |
-
|
|
|
|
|
|
|
|
|
490 |
break
|
491 |
-
|
492 |
if "data" in item:
|
493 |
yield item["data"]
|
494 |
-
|
|
|
|
|
495 |
|
496 |
return StreamingResponse(
|
497 |
stream_generator(),
|
498 |
-
media_type="text/event-stream"
|
|
|
|
|
|
|
|
|
|
|
|
|
499 |
)
|
500 |
else:
|
501 |
-
#
|
502 |
-
|
503 |
while True:
|
504 |
item = await queue.get()
|
505 |
if item is None:
|
506 |
break
|
507 |
-
|
508 |
if "error" in item:
|
509 |
-
|
510 |
-
|
511 |
-
|
512 |
-
|
513 |
-
|
514 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
515 |
|
516 |
|
517 |
-
#
|
518 |
-
|
519 |
-
@app.post("/chat/completions")
|
520 |
-
|
521 |
-
|
522 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
523 |
if not server_status:
|
524 |
-
|
525 |
-
status_code=503,
|
526 |
-
content={"message": "Server is under maintenance. Please try again later."}
|
527 |
-
)
|
528 |
|
529 |
-
model_to_use = payload.model or "gpt-4o-mini"
|
530 |
|
531 |
-
#
|
532 |
-
if available_model_ids and model_to_use not in
|
533 |
-
|
534 |
-
|
535 |
-
|
536 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
537 |
|
538 |
-
|
|
|
539 |
asyncio.create_task(log_request(request, model_to_use))
|
540 |
usage_tracker.record_request(model=model_to_use, endpoint="/chat/completions")
|
541 |
|
542 |
-
# Prepare payload
|
543 |
-
payload_dict = payload.dict()
|
544 |
-
payload_dict["model"] = model_to_use
|
545 |
-
|
546 |
-
# Ensure stream is True for real-time streaming (can be overridden by client)
|
547 |
-
stream_enabled = payload_dict.get("stream", True)
|
548 |
|
549 |
-
# Get environment variables
|
550 |
env_vars = get_env_vars()
|
|
|
|
|
|
|
|
|
|
|
551 |
|
552 |
-
# Select the appropriate endpoint (fast lookup with sets)
|
553 |
if model_to_use in mistral_models:
|
554 |
-
endpoint = env_vars
|
555 |
-
|
556 |
-
|
557 |
-
|
|
|
|
|
|
|
|
|
558 |
elif model_to_use in pollinations_models:
|
559 |
-
endpoint = env_vars
|
560 |
-
|
|
|
|
|
|
|
|
|
561 |
elif model_to_use in alternate_models:
|
562 |
-
endpoint = env_vars
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
endpoint = env_vars
|
569 |
-
|
570 |
-
|
571 |
-
|
572 |
-
|
573 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
574 |
|
575 |
-
|
|
|
576 |
|
577 |
-
|
578 |
-
|
|
|
|
|
|
|
579 |
try:
|
580 |
-
async with
|
581 |
-
|
582 |
-
|
583 |
-
|
584 |
-
|
585 |
-
|
586 |
-
|
587 |
-
|
588 |
-
|
589 |
-
|
590 |
-
|
591 |
-
|
592 |
-
|
593 |
-
|
594 |
-
|
595 |
-
|
596 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
597 |
except httpx.TimeoutException:
|
598 |
-
|
|
|
|
|
599 |
except httpx.RequestError as e:
|
600 |
-
|
|
|
|
|
601 |
except Exception as e:
|
602 |
-
|
603 |
-
|
604 |
-
|
605 |
|
606 |
-
|
607 |
-
if stream_enabled:
|
608 |
return StreamingResponse(
|
609 |
-
|
610 |
media_type="text/event-stream",
|
611 |
headers={
|
612 |
"Content-Type": "text/event-stream",
|
613 |
"Cache-Control": "no-cache",
|
614 |
"Connection": "keep-alive",
|
615 |
-
"X-Accel-Buffering": "no"
|
616 |
}
|
617 |
)
|
618 |
else:
|
619 |
-
#
|
620 |
-
|
621 |
-
|
622 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
623 |
|
624 |
-
return JSONResponse(content=json.loads(''.join(response_content)))
|
625 |
|
|
|
626 |
|
627 |
|
628 |
-
#
|
629 |
-
@app.post("/images/generations")
|
630 |
-
async def create_image(
|
|
|
|
|
|
|
631 |
"""
|
632 |
-
|
633 |
"""
|
634 |
-
# Check server status
|
635 |
if not server_status:
|
636 |
-
|
637 |
-
status_code=503,
|
638 |
-
content={"message": "Server is under maintenance. Please try again later."}
|
639 |
-
)
|
640 |
|
641 |
-
# Validate model
|
642 |
if payload.model not in supported_image_models:
|
643 |
raise HTTPException(
|
644 |
status_code=400,
|
645 |
-
detail=f"Model '{payload.model}' is not supported for image generation.
|
646 |
)
|
647 |
|
648 |
-
# Log the request
|
649 |
usage_tracker.record_request(model=payload.model, endpoint="/images/generations")
|
650 |
|
651 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
652 |
api_payload = {
|
653 |
"model": payload.model,
|
654 |
"prompt": payload.prompt,
|
655 |
-
"
|
656 |
-
"
|
657 |
}
|
|
|
|
|
658 |
|
659 |
-
# Target API endpoint
|
660 |
-
target_api_url = os.getenv('NEW_IMG')
|
661 |
|
662 |
-
|
663 |
-
|
664 |
-
async with httpx.AsyncClient(timeout=60.0) as client:
|
665 |
-
response = await client.post(target_api_url, json=api_payload)
|
666 |
|
667 |
-
|
668 |
-
|
669 |
-
|
|
|
|
|
670 |
|
671 |
-
# Return the response from the
|
672 |
return JSONResponse(content=response.json())
|
673 |
|
674 |
except httpx.TimeoutException:
|
|
|
675 |
raise HTTPException(status_code=504, detail="Image generation request timed out.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
676 |
except httpx.RequestError as e:
|
|
|
677 |
raise HTTPException(status_code=502, detail=f"Error connecting to image generation service: {e}")
|
678 |
except Exception as e:
|
|
|
679 |
raise HTTPException(status_code=500, detail=f"An unexpected error occurred during image generation: {e}")
|
680 |
|
681 |
|
|
|
682 |
|
683 |
-
|
684 |
-
|
685 |
-
#
|
686 |
-
|
687 |
-
|
688 |
-
|
|
|
|
|
|
|
|
|
|
|
689 |
|
690 |
-
# Cache usage statistics
|
691 |
-
@lru_cache(maxsize=10)
|
692 |
-
def get_usage_summary(days=7):
|
693 |
-
return usage_tracker.get_usage_summary(days)
|
694 |
|
695 |
-
@app.get("/usage")
|
696 |
async def get_usage(days: int = 7):
|
697 |
-
"""
|
698 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
699 |
|
700 |
-
# Generate HTML for usage page
|
701 |
-
def generate_usage_html(usage_data):
|
702 |
-
# Model Usage Table Rows
|
703 |
model_usage_rows = "\n".join([
|
704 |
f"""
|
705 |
<tr>
|
706 |
<td>{model}</td>
|
707 |
-
<td>{model_data
|
708 |
-
<td>{model_data
|
709 |
-
<td>{model_data
|
710 |
</tr>
|
711 |
-
""" for model, model_data in
|
712 |
-
])
|
713 |
|
714 |
-
# API Endpoint Usage Table Rows
|
715 |
api_usage_rows = "\n".join([
|
716 |
f"""
|
717 |
<tr>
|
718 |
<td>{endpoint}</td>
|
719 |
-
<td>{endpoint_data
|
720 |
-
<td>{endpoint_data
|
721 |
-
<td>{endpoint_data
|
722 |
</tr>
|
723 |
-
""" for endpoint, endpoint_data in
|
724 |
-
])
|
725 |
|
726 |
-
# Daily Usage Table Rows
|
727 |
daily_usage_rows = "\n".join([
|
728 |
-
"
|
729 |
-
|
730 |
-
<
|
731 |
-
|
732 |
-
|
733 |
-
|
734 |
-
|
735 |
-
|
736 |
-
|
737 |
-
])
|
738 |
|
|
|
|
|
|
|
739 |
html_content = f"""
|
740 |
<!DOCTYPE html>
|
741 |
<html lang="en">
|
742 |
<head>
|
743 |
<meta charset="UTF-8">
|
|
|
744 |
<title>Lokiai AI - Usage Statistics</title>
|
745 |
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;600&display=swap" rel="stylesheet">
|
746 |
<style>
|
|
|
747 |
:root {{
|
748 |
-
--bg-dark: #0f1011;
|
749 |
-
--
|
750 |
-
--text-primary: #e6e6e6;
|
751 |
-
--text-secondary: #8c8c8c;
|
752 |
-
--border-color: #2c2c2c;
|
753 |
-
--accent-color: #3a6ee0;
|
754 |
--accent-hover: #4a7ef0;
|
755 |
}}
|
756 |
-
body {{
|
757 |
-
|
758 |
-
|
759 |
-
|
760 |
-
|
761 |
-
|
762 |
-
|
763 |
-
|
764 |
-
}}
|
765 |
-
|
766 |
-
|
767 |
-
|
768 |
-
|
769 |
-
margin-bottom: 30px;
|
770 |
-
}}
|
771 |
-
.logo h1 {{
|
772 |
-
font-weight: 600;
|
773 |
-
font-size: 2.5em;
|
774 |
-
color: var(--text-primary);
|
775 |
-
margin-left: 15px;
|
776 |
-
}}
|
777 |
-
.logo img {{
|
778 |
-
width: 60px;
|
779 |
-
height: 60px;
|
780 |
-
border-radius: 10px;
|
781 |
-
}}
|
782 |
-
.container {{
|
783 |
-
background-color: var(--bg-darker);
|
784 |
-
border-radius: 12px;
|
785 |
-
padding: 30px;
|
786 |
-
box-shadow: 0 15px 40px rgba(0,0,0,0.3);
|
787 |
-
border: 1px solid var(--border-color);
|
788 |
-
}}
|
789 |
-
h2, h3 {{
|
790 |
-
color: var(--text-primary);
|
791 |
-
border-bottom: 2px solid var(--border-color);
|
792 |
-
padding-bottom: 10px;
|
793 |
-
font-weight: 500;
|
794 |
-
}}
|
795 |
-
.total-requests {{
|
796 |
-
background-color: var(--accent-color);
|
797 |
-
color: white;
|
798 |
-
text-align: center;
|
799 |
-
padding: 15px;
|
800 |
-
border-radius: 8px;
|
801 |
-
margin-bottom: 30px;
|
802 |
-
font-weight: 600;
|
803 |
-
letter-spacing: -0.5px;
|
804 |
-
}}
|
805 |
-
table {{
|
806 |
-
width: 100%;
|
807 |
-
border-collapse: separate;
|
808 |
-
border-spacing: 0;
|
809 |
-
margin-bottom: 30px;
|
810 |
-
background-color: var(--bg-dark);
|
811 |
-
border-radius: 8px;
|
812 |
-
overflow: hidden;
|
813 |
-
}}
|
814 |
-
th, td {{
|
815 |
-
border: 1px solid var(--border-color);
|
816 |
-
padding: 12px;
|
817 |
-
text-align: left;
|
818 |
-
transition: background-color 0.3s ease;
|
819 |
-
}}
|
820 |
-
th {{
|
821 |
-
background-color: #1e1e1e;
|
822 |
-
color: var(--text-primary);
|
823 |
-
font-weight: 600;
|
824 |
-
text-transform: uppercase;
|
825 |
-
font-size: 0.9em;
|
826 |
-
}}
|
827 |
-
tr:nth-child(even) {{
|
828 |
-
background-color: rgba(255,255,255,0.05);
|
829 |
-
}}
|
830 |
-
tr:hover {{
|
831 |
-
background-color: rgba(62,100,255,0.1);
|
832 |
-
}}
|
833 |
-
@media (max-width: 768px) {{
|
834 |
-
.container {{
|
835 |
-
padding: 15px;
|
836 |
-
}}
|
837 |
-
table {{
|
838 |
-
font-size: 0.9em;
|
839 |
-
}}
|
840 |
-
}}
|
841 |
</style>
|
842 |
</head>
|
843 |
<body>
|
844 |
<div class="container">
|
845 |
<div class="logo">
|
846 |
<img src="data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iMjAwIiBoZWlnaHQ9IjIwMCIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj48cGF0aCBkPSJNMTAwIDM1TDUwIDkwaDEwMHoiIGZpbGw9IiMzYTZlZTAiLz48Y2lyY2xlIGN4PSIxMDAiIGN5PSIxNDAiIHI9IjMwIiBmaWxsPSIjM2E2ZWUwIi8+PC9zdmc+" alt="Lokai AI Logo">
|
847 |
-
<h1>Lokiai AI</h1>
|
848 |
</div>
|
849 |
|
850 |
<div class="total-requests">
|
851 |
-
Total API Requests: {
|
852 |
</div>
|
853 |
|
854 |
<h2>Model Usage</h2>
|
855 |
<table>
|
856 |
-
<tr>
|
857 |
-
|
858 |
-
<th>Total Requests</th>
|
859 |
-
<th>First Used</th>
|
860 |
-
<th>Last Used</th>
|
861 |
-
</tr>
|
862 |
-
{model_usage_rows}
|
863 |
</table>
|
864 |
|
865 |
<h2>API Endpoint Usage</h2>
|
866 |
<table>
|
867 |
-
<tr>
|
868 |
-
|
869 |
-
<th>Total Requests</th>
|
870 |
-
<th>First Used</th>
|
871 |
-
<th>Last Used</th>
|
872 |
-
</tr>
|
873 |
-
{api_usage_rows}
|
874 |
</table>
|
875 |
|
876 |
-
<h2>Daily Usage (Last 7 Days)</h2>
|
877 |
<table>
|
878 |
-
|
879 |
-
|
880 |
-
<th>Entity</th>
|
881 |
-
<th>Requests</th>
|
882 |
-
</tr>
|
883 |
-
{daily_usage_rows}
|
884 |
</table>
|
885 |
</div>
|
886 |
</body>
|
@@ -888,166 +1058,199 @@ def generate_usage_html(usage_data):
|
|
888 |
"""
|
889 |
return html_content
|
890 |
|
891 |
-
#
|
892 |
-
|
893 |
-
|
894 |
-
|
895 |
-
return generate_usage_html(usage_data)
|
896 |
|
897 |
-
@app.get("/usage/page", response_class=HTMLResponse)
|
898 |
async def usage_page():
|
899 |
-
"""
|
900 |
-
|
901 |
-
|
902 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
903 |
|
904 |
-
# Meme
|
905 |
-
@app.get("/meme")
|
906 |
async def get_meme():
|
|
|
|
|
|
|
|
|
907 |
try:
|
908 |
-
|
909 |
-
|
910 |
-
response
|
911 |
response_data = response.json()
|
912 |
|
913 |
meme_url = response_data.get("url")
|
914 |
-
if not meme_url:
|
915 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
916 |
|
917 |
-
image_response = await client.get(meme_url, follow_redirects=True)
|
918 |
|
919 |
-
|
920 |
-
|
921 |
-
|
922 |
-
|
923 |
-
|
924 |
-
|
925 |
-
size += len(chunk)
|
926 |
|
927 |
-
|
928 |
-
|
929 |
-
|
930 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
931 |
|
932 |
-
|
933 |
-
|
|
|
|
|
|
|
|
|
|
|
934 |
|
935 |
-
return StreamingResponse(
|
936 |
-
stream_with_larger_chunks(),
|
937 |
-
media_type=image_response.headers.get("content-type", "image/png"),
|
938 |
-
headers={'Cache-Control': 'max-age=3600'} # Add caching
|
939 |
-
)
|
940 |
-
except Exception:
|
941 |
-
raise HTTPException(status_code=500, detail="Failed to retrieve meme")
|
942 |
|
943 |
-
#
|
944 |
-
def load_model_ids(json_file_path):
|
945 |
-
try:
|
946 |
-
with open(json_file_path, 'r') as f:
|
947 |
-
models_data = json.load(f)
|
948 |
-
# Extract 'id' from each model object and use a set for fast lookups
|
949 |
-
return [model['id'] for model in models_data if 'id' in model]
|
950 |
-
except Exception as e:
|
951 |
-
print(f"Error loading model IDs: {str(e)}")
|
952 |
-
return []
|
953 |
|
954 |
@app.on_event("startup")
|
955 |
async def startup_event():
|
|
|
956 |
global available_model_ids
|
957 |
-
|
958 |
-
|
959 |
-
|
960 |
-
|
961 |
-
|
962 |
-
|
963 |
-
|
964 |
-
|
965 |
-
|
966 |
-
|
967 |
-
|
968 |
-
|
969 |
-
|
970 |
-
|
971 |
-
|
972 |
-
#
|
973 |
-
|
974 |
-
|
975 |
-
|
976 |
-
# Validate critical environment variables
|
977 |
env_vars = get_env_vars()
|
978 |
-
|
979 |
-
|
980 |
-
|
981 |
-
|
982 |
-
|
983 |
-
|
984 |
-
|
985 |
-
|
986 |
-
|
987 |
-
missing_vars.append('SECRET_API_ENDPOINT_3')
|
988 |
-
if not env_vars['secret_api_endpoint_4']:
|
989 |
-
missing_vars.append('SECRET_API_ENDPOINT_4')
|
990 |
-
if not env_vars['secret_api_endpoint_5']: # Check the new endpoint
|
991 |
-
missing_vars.append('SECRET_API_ENDPOINT_5')
|
992 |
-
if not env_vars['mistral_api'] and any(model in mistral_models for model in available_model_ids):
|
993 |
-
missing_vars.append('MISTRAL_API')
|
994 |
-
if not env_vars['mistral_key'] and any(model in mistral_models for model in available_model_ids):
|
995 |
-
missing_vars.append('MISTRAL_KEY')
|
996 |
-
|
997 |
-
if missing_vars:
|
998 |
-
print(f"WARNING: The following environment variables are missing: {', '.join(missing_vars)}")
|
999 |
-
print("Some functionality may be limited.")
|
1000 |
-
|
1001 |
-
print("Server started successfully!")
|
1002 |
|
1003 |
@app.on_event("shutdown")
|
1004 |
async def shutdown_event():
|
1005 |
-
|
|
|
|
|
|
|
1006 |
client = get_async_client()
|
1007 |
await client.aclose()
|
|
|
|
|
|
|
|
|
|
|
1008 |
|
1009 |
-
# Clear scraper pool
|
1010 |
scraper_pool.clear()
|
|
|
1011 |
|
1012 |
# Persist usage data
|
1013 |
-
|
1014 |
-
|
1015 |
-
|
1016 |
-
|
1017 |
-
|
1018 |
-
|
1019 |
-
|
1020 |
-
|
1021 |
-
"""Health check endpoint for monitoring"""
|
1022 |
-
env_vars = get_env_vars()
|
1023 |
-
missing_critical_vars = []
|
1024 |
|
1025 |
-
|
1026 |
-
if not env_vars['api_keys'] or env_vars['api_keys'] == ['']:
|
1027 |
-
missing_critical_vars.append('API_KEYS')
|
1028 |
-
if not env_vars['secret_api_endpoint']:
|
1029 |
-
missing_critical_vars.append('SECRET_API_ENDPOINT')
|
1030 |
-
if not env_vars['secret_api_endpoint_2']:
|
1031 |
-
missing_critical_vars.append('SECRET_API_ENDPOINT_2')
|
1032 |
-
if not env_vars['secret_api_endpoint_3']:
|
1033 |
-
missing_critical_vars.append('SECRET_API_ENDPOINT_3')
|
1034 |
-
if not env_vars['secret_api_endpoint_4']:
|
1035 |
-
missing_critical_vars.append('SECRET_API_ENDPOINT_4')
|
1036 |
-
if not env_vars['secret_api_endpoint_5']: # Check the new endpoint
|
1037 |
-
missing_critical_vars.append('SECRET_API_ENDPOINT_5')
|
1038 |
-
if not env_vars['mistral_api']:
|
1039 |
-
missing_critical_vars.append('MISTRAL_API')
|
1040 |
-
if not env_vars['mistral_key']:
|
1041 |
-
missing_critical_vars.append('MISTRAL_KEY')
|
1042 |
|
1043 |
-
health_status = {
|
1044 |
-
"status": "healthy" if not missing_critical_vars else "unhealthy",
|
1045 |
-
"missing_env_vars": missing_critical_vars,
|
1046 |
-
"server_status": server_status,
|
1047 |
-
"message": "Everything's lit! 🚀" if not missing_critical_vars else "Uh oh, some env vars are missing. 😬"
|
1048 |
-
}
|
1049 |
-
return JSONResponse(content=health_status)
|
1050 |
|
|
|
|
|
|
|
|
|
1051 |
if __name__ == "__main__":
|
1052 |
import uvicorn
|
1053 |
-
|
|
|
|
|
|
1 |
import os
|
2 |
import re
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
import json
|
4 |
import datetime
|
5 |
import time
|
|
|
|
|
6 |
import asyncio
|
7 |
+
import logging
|
8 |
+
from pathlib import Path
|
9 |
+
from functools import lru_cache
|
10 |
+
from typing import Optional, Dict, List, Any, Generator, Set
|
11 |
from concurrent.futures import ThreadPoolExecutor
|
12 |
+
|
13 |
+
# Third-party libraries (ensure these are in requirements.txt)
|
14 |
+
from dotenv import load_dotenv
|
15 |
+
from fastapi import FastAPI, HTTPException, Request, Depends, Security, Response
|
16 |
+
from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse, FileResponse
|
17 |
+
from fastapi.security import APIKeyHeader
|
18 |
+
from pydantic import BaseModel
|
19 |
+
import httpx
|
20 |
+
import uvloop # Use uvloop for performance
|
21 |
from fastapi.middleware.gzip import GZipMiddleware
|
22 |
from starlette.middleware.cors import CORSMiddleware
|
23 |
+
import cloudscraper # For bypassing Cloudflare, potentially unreliable
|
24 |
+
import requests # For synchronous requests like in /dynamo
|
25 |
+
|
26 |
+
# HF Space Note: Ensure usage_tracker.py is in your repository
|
27 |
+
try:
|
28 |
+
from usage_tracker import UsageTracker
|
29 |
+
usage_tracker = UsageTracker()
|
30 |
+
except ImportError:
|
31 |
+
print("Warning: usage_tracker.py not found. Usage tracking will be disabled.")
|
32 |
+
# Create a dummy tracker if the file is missing
|
33 |
+
class DummyUsageTracker:
|
34 |
+
def record_request(self, *args, **kwargs): pass
|
35 |
+
def get_usage_summary(self, *args, **kwargs): return {}
|
36 |
+
def save_data(self, *args, **kwargs): pass
|
37 |
+
usage_tracker = DummyUsageTracker()
|
38 |
+
|
39 |
+
|
40 |
+
# --- Configuration & Setup ---
|
41 |
+
|
42 |
+
# HF Space Note: uvloop can improve performance in I/O bound tasks common in web apps.
|
43 |
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
44 |
|
45 |
+
# HF Space Note: Adjust max_workers based on your HF Space resources (CPU).
|
46 |
+
# Higher tiers allow more workers. Start lower (e.g., 4) for free tier.
|
47 |
+
executor = ThreadPoolExecutor(max_workers=8)
|
48 |
|
49 |
+
# HF Space Note: load_dotenv() is useful for local dev but HF Spaces use Secrets.
|
50 |
+
# os.getenv will automatically pick up secrets set in the HF Space settings.
|
51 |
load_dotenv()
|
52 |
|
53 |
+
# Logging setup
|
54 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
55 |
+
logger = logging.getLogger(__name__)
|
56 |
+
|
57 |
+
# API key security
|
58 |
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
|
59 |
|
60 |
+
# --- FastAPI App Initialization ---
|
|
|
|
|
61 |
|
62 |
+
app = FastAPI(
|
63 |
+
title="LokiAI API",
|
64 |
+
description="API Proxy for various AI models with usage tracking and streaming.",
|
65 |
+
version="1.0.0"
|
66 |
+
)
|
67 |
|
68 |
+
# Middleware
|
69 |
+
app.add_middleware(GZipMiddleware, minimum_size=1000) # Compress large responses
|
70 |
app.add_middleware(
|
71 |
+
CORSMiddleware, # Allow cross-origin requests (useful for web playgrounds)
|
72 |
+
allow_origins=["*"], # Or restrict to specific origins
|
73 |
allow_credentials=True,
|
74 |
allow_methods=["*"],
|
75 |
allow_headers=["*"],
|
76 |
)
|
77 |
|
78 |
+
# --- Environment Variables & Model Config ---
|
79 |
+
|
80 |
+
@lru_cache(maxsize=1) # Cache environment variables
|
81 |
+
def get_env_vars() -> Dict[str, Any]:
|
82 |
+
"""Loads and returns essential environment variables."""
|
83 |
+
# HF Space Note: Set these as Secrets in your Hugging Face Space settings.
|
84 |
return {
|
85 |
+
'api_keys': set(filter(None, os.getenv('API_KEYS', '').split(','))), # Use set for faster lookup
|
86 |
'secret_api_endpoint': os.getenv('SECRET_API_ENDPOINT'),
|
87 |
'secret_api_endpoint_2': os.getenv('SECRET_API_ENDPOINT_2'),
|
88 |
+
'secret_api_endpoint_3': os.getenv('SECRET_API_ENDPOINT_3'), # Search endpoint
|
89 |
+
'secret_api_endpoint_4': os.getenv('SECRET_API_ENDPOINT_4', "https://text.pollinations.ai/openai"), # Pollinations
|
90 |
+
'secret_api_endpoint_5': os.getenv('SECRET_API_ENDPOINT_5'), # Claude 3 endpoint
|
91 |
+
'mistral_api': os.getenv('MISTRAL_API', "https://api.mistral.ai"),
|
92 |
'mistral_key': os.getenv('MISTRAL_KEY'),
|
93 |
+
'new_img_endpoint': os.getenv('NEW_IMG'), # Image generation endpoint
|
94 |
+
'hf_space_url': os.getenv('HF_SPACE_URL', 'https://your-space-name.hf.space') # HF Space Note: Set this! Used for Referer/Origin checks.
|
95 |
}
|
96 |
|
97 |
+
# Model sets for fast lookups
|
98 |
+
# HF Space Note: Consider moving these large sets to a separate config file (e.g., config.py or models_config.json)
|
99 |
+
# for better organization if they grow larger.
|
100 |
+
mistral_models: Set[str] = {
|
101 |
+
"mistral-large-latest", "pixtral-large-latest", "mistral-moderation-latest",
|
102 |
+
"ministral-3b-latest", "ministral-8b-latest", "open-mistral-nemo",
|
103 |
+
"mistral-small-latest", "mistral-saba-latest", "codestral-latest"
|
|
|
|
|
|
|
|
|
104 |
}
|
105 |
|
106 |
+
pollinations_models: Set[str] = {
|
107 |
+
"openai", "openai-large", "openai-xlarge", "openai-reasoning", "qwen-coder",
|
108 |
+
"llama", "mistral", "searchgpt", "deepseek", "claude-hybridspace",
|
109 |
+
"deepseek-r1", "deepseek-reasoner", "llamalight", "gemini", "gemini-thinking",
|
110 |
+
"hormoz", "phi", "phi-mini", "openai-audio", "llama-scaleway"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
}
|
112 |
|
113 |
+
alternate_models: Set[str] = {
|
114 |
+
"gpt-4o", "deepseek-v3", "llama-3.1-8b-instruct", "llama-3.1-sonar-small-128k-online",
|
115 |
+
"deepseek-r1-uncensored", "tinyswallow1.5b", "andy-3.5", "o3-mini-low",
|
116 |
+
"hermes-3-llama-3.2-3b", "creitin-r1", "fluffy.1-chat", "plutotext-1-text",
|
117 |
+
"command-a", "claude-3-7-sonnet-20250219", "plutogpt-3.5-turbo"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
}
|
119 |
|
120 |
+
claude_3_models: Set[str] = {
|
121 |
+
"claude-3-7-sonnet", "claude-3-7-sonnet-thinking", "claude 3.5 haiku",
|
122 |
+
"claude 3.5 sonnet", "claude 3.5 haiku", "o3-mini-medium", "o3-mini-high",
|
123 |
+
"grok-3", "grok-3-thinking", "grok 2"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
}
|
125 |
|
126 |
+
supported_image_models: Set[str] = {
|
127 |
+
"Flux Pro Ultra", "grok-2-aurora", "Flux Pro", "Flux Pro Ultra Raw", "Flux Dev",
|
128 |
+
"Flux Schnell", "stable-diffusion-3-large-turbo", "Flux Realism",
|
129 |
+
"stable-diffusion-ultra", "dall-e-3", "sdxl-lightning-4step"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
}
|
131 |
|
132 |
+
# --- Pydantic Models ---
|
133 |
+
|
134 |
+
class Message(BaseModel):
|
135 |
+
role: str
|
136 |
+
content: Any # Allow content to be string or potentially list for multimodal models
|
137 |
|
|
|
138 |
class Payload(BaseModel):
|
139 |
model: str
|
140 |
+
messages: List[Message]
|
141 |
stream: bool = False
|
142 |
+
# Add other potential OpenAI compatible parameters with defaults
|
143 |
+
max_tokens: Optional[int] = None
|
144 |
+
temperature: Optional[float] = None
|
145 |
+
top_p: Optional[float] = None
|
146 |
+
# ... add others as needed
|
147 |
|
|
|
|
|
148 |
class ImageGenerationPayload(BaseModel):
|
149 |
model: str
|
150 |
prompt: str
|
151 |
+
size: Optional[str] = "1024x1024" # Default size, make optional if API allows
|
152 |
+
n: Optional[int] = 1 # Number of images, OpenAI uses 'n'
|
153 |
+
# HF Space Note: Ensure these parameter names match the target NEW_IMG endpoint API
|
154 |
+
# Renaming from 'number' to 'n' and 'size' type hint correction.
|
155 |
|
156 |
+
# --- Global State & Clients ---
|
157 |
|
158 |
+
server_status: bool = True # For maintenance mode
|
159 |
+
available_model_ids: List[str] = [] # Loaded at startup
|
160 |
|
161 |
+
# HF Space Note: Reusable HTTP client with connection pooling is crucial for performance.
|
162 |
+
# Adjust limits based on expected load and HF Space resources.
|
|
|
|
|
|
|
163 |
@lru_cache(maxsize=1)
|
164 |
+
def get_async_client() -> httpx.AsyncClient:
|
165 |
+
"""Returns a cached instance of httpx.AsyncClient."""
|
166 |
+
# HF Space Note: Timeouts are important to prevent hanging requests.
|
167 |
+
# Keepalive connections reduce handshake overhead.
|
168 |
+
timeout = httpx.Timeout(30.0, connect=10.0) # 30s total, 10s connect
|
169 |
+
limits = httpx.Limits(max_keepalive_connections=20, max_connections=100)
|
170 |
+
return httpx.AsyncClient(timeout=timeout, limits=limits, follow_redirects=True)
|
171 |
+
|
172 |
+
# HF Space Note: cloudscraper pool. Be mindful of potential rate limits or blocks.
|
173 |
+
# Consider alternatives if this becomes unreliable.
|
174 |
+
scraper_pool: List[cloudscraper.CloudScraper] = []
|
175 |
+
MAX_SCRAPERS = 10 # Reduced pool size for potentially lower resource usage
|
176 |
+
|
177 |
+
def get_scraper() -> cloudscraper.CloudScraper:
|
178 |
+
"""Gets a cloudscraper instance from the pool."""
|
179 |
if not scraper_pool:
|
180 |
+
logger.info(f"Initializing {MAX_SCRAPERS} cloudscraper instances...")
|
181 |
for _ in range(MAX_SCRAPERS):
|
182 |
+
# HF Space Note: Scraper creation can be slow, doing it upfront is good.
|
183 |
scraper_pool.append(cloudscraper.create_scraper())
|
184 |
+
logger.info("Cloudscraper pool initialized.")
|
185 |
+
# Simple round-robin selection
|
186 |
+
return scraper_pool[int(time.monotonic() * 1000) % MAX_SCRAPERS]
|
187 |
|
188 |
+
# --- Security & Authentication ---
|
189 |
|
|
|
190 |
async def verify_api_key(
|
191 |
request: Request,
|
192 |
+
api_key: Optional[str] = Security(api_key_header)
|
193 |
) -> bool:
|
194 |
+
"""Verifies the provided API key against environment variables."""
|
195 |
+
env_vars = get_env_vars()
|
196 |
+
valid_api_keys = env_vars.get('api_keys', set())
|
197 |
+
hf_space_url = env_vars.get('hf_space_url', '')
|
198 |
+
|
199 |
+
# Allow bypass if the referer is from the known HF Space playground URLs
|
200 |
+
# HF Space Note: Make HF_SPACE_URL a secret for flexibility.
|
201 |
referer = request.headers.get("referer", "")
|
202 |
+
if hf_space_url and referer.startswith((f"{hf_space_url}/playground", f"{hf_space_url}/image-playground")):
|
203 |
+
logger.debug(f"API Key check bypassed for referer: {referer}")
|
204 |
return True
|
205 |
+
|
206 |
if not api_key:
|
207 |
+
logger.warning("API Key missing.")
|
208 |
+
raise HTTPException(status_code=403, detail="Not authenticated: No API key provided")
|
209 |
+
|
210 |
+
# Clean 'Bearer ' prefix if present
|
|
|
|
|
211 |
if api_key.startswith('Bearer '):
|
212 |
+
api_key = api_key[7:]
|
213 |
+
|
214 |
+
if not valid_api_keys:
|
215 |
+
logger.error("API keys are not configured on the server (API_KEYS secret missing?).")
|
216 |
+
raise HTTPException(status_code=500, detail="Server configuration error: API keys not set")
|
217 |
+
|
218 |
+
if api_key not in valid_api_keys:
|
219 |
+
logger.warning(f"Invalid API key received: {api_key[:4]}...") # Log prefix only
|
220 |
+
raise HTTPException(status_code=403, detail="Not authenticated: Invalid API key")
|
221 |
+
|
222 |
+
logger.debug("API Key verified successfully.")
|
|
|
|
|
|
|
|
|
|
|
|
|
223 |
return True
|
224 |
|
225 |
+
# --- Model & File Loading ---
|
226 |
+
|
227 |
@lru_cache(maxsize=1)
|
228 |
+
def load_models_data() -> List[Dict]:
|
229 |
+
"""Loads model data from models.json."""
|
230 |
+
# HF Space Note: Ensure models.json is in the root of your HF Space repo.
|
231 |
+
models_file = Path(__file__).parent / 'models.json'
|
232 |
+
if not models_file.is_file():
|
233 |
+
logger.error("models.json not found!")
|
234 |
+
return []
|
235 |
try:
|
236 |
+
with open(models_file, 'r') as f:
|
|
|
237 |
return json.load(f)
|
238 |
except (FileNotFoundError, json.JSONDecodeError) as e:
|
239 |
+
logger.error(f"Error loading models.json: {e}")
|
240 |
return []
|
241 |
|
242 |
+
async def get_models() -> List[Dict]:
|
243 |
+
"""Async wrapper to get models data."""
|
244 |
models_data = load_models_data()
|
245 |
if not models_data:
|
246 |
raise HTTPException(status_code=500, detail="Error loading available models")
|
247 |
return models_data
|
248 |
|
249 |
+
# --- Static File Serving ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
250 |
|
251 |
+
# HF Space Note: Cache frequently accessed static files in memory.
|
|
|
|
|
|
|
|
|
|
|
|
|
252 |
@lru_cache(maxsize=10)
|
253 |
+
def read_static_file(file_path: str) -> Optional[str]:
|
254 |
+
"""Reads a static file, caching the result."""
|
255 |
+
full_path = Path(__file__).parent / file_path
|
256 |
+
if not full_path.is_file():
|
257 |
+
logger.warning(f"Static file not found: {file_path}")
|
258 |
+
return None
|
259 |
try:
|
260 |
+
with open(full_path, "r", encoding="utf-8") as file:
|
261 |
return file.read()
|
262 |
+
except Exception as e:
|
263 |
+
logger.error(f"Error reading static file {file_path}: {e}")
|
264 |
return None
|
265 |
|
266 |
+
async def serve_static_html(file_path: str) -> HTMLResponse:
|
267 |
+
"""Serves a static HTML file."""
|
268 |
+
content = read_static_file(file_path)
|
269 |
+
if content is None:
|
270 |
+
return HTMLResponse(content=f"<h1>Error: {file_path} not found</h1>", status_code=404)
|
271 |
+
return HTMLResponse(content=content)
|
272 |
|
273 |
+
# --- API Endpoints ---
|
|
|
|
|
|
|
274 |
|
275 |
+
# Basic Routes & Static Files
|
276 |
+
@app.get("/favicon.ico", include_in_schema=False)
|
277 |
+
async def favicon():
|
278 |
+
favicon_path = Path(__file__).parent / "favicon.ico"
|
279 |
+
if favicon_path.is_file():
|
280 |
+
return FileResponse(favicon_path, media_type="image/vnd.microsoft.icon")
|
281 |
+
raise HTTPException(status_code=404, detail="favicon.ico not found")
|
282 |
+
|
283 |
+
@app.get("/banner.jpg", include_in_schema=False)
|
284 |
+
async def banner():
|
285 |
+
banner_path = Path(__file__).parent / "banner.jpg"
|
286 |
+
if banner_path.is_file():
|
287 |
+
return FileResponse(banner_path, media_type="image/jpeg") # Assuming JPEG
|
288 |
+
raise HTTPException(status_code=404, detail="banner.jpg not found")
|
289 |
+
|
290 |
+
@app.get("/ping", tags=["Utility"])
|
291 |
async def ping():
|
292 |
+
"""Simple health check endpoint."""
|
293 |
+
return {"message": "pong"}
|
294 |
|
295 |
+
@app.get("/", response_class=HTMLResponse, tags=["Frontend"])
|
296 |
async def root():
|
297 |
+
"""Serves the main index HTML page."""
|
298 |
+
return await serve_static_html("index.html")
|
299 |
+
|
300 |
+
@app.get("/script.js", response_class=Response, tags=["Frontend"], include_in_schema=False)
|
301 |
+
async def script_js():
|
302 |
+
content = read_static_file("script.js")
|
303 |
+
if content is None:
|
304 |
+
return Response(content="/* script.js not found */", status_code=404, media_type="application/javascript")
|
305 |
+
return Response(content=content, media_type="application/javascript")
|
306 |
+
|
307 |
+
@app.get("/style.css", response_class=Response, tags=["Frontend"], include_in_schema=False)
|
308 |
+
async def style_css():
|
309 |
+
content = read_static_file("style.css")
|
310 |
+
if content is None:
|
311 |
+
return Response(content="/* style.css not found */", status_code=404, media_type="text/css")
|
312 |
+
return Response(content=content, media_type="text/css")
|
313 |
+
|
314 |
+
@app.get("/playground", response_class=HTMLResponse, tags=["Frontend"])
|
315 |
+
async def playground():
|
316 |
+
"""Serves the chat playground HTML page."""
|
317 |
+
return await serve_static_html("playground.html")
|
318 |
+
|
319 |
+
@app.get("/image-playground", response_class=HTMLResponse, tags=["Frontend"])
|
320 |
+
async def image_playground():
|
321 |
+
"""Serves the image playground HTML page."""
|
322 |
+
return await serve_static_html("image-playground.html")
|
323 |
+
|
324 |
+
# Dynamic Page Example
|
325 |
+
@app.get("/dynamo", response_class=HTMLResponse, tags=["Examples"])
|
326 |
async def dynamic_ai_page(request: Request):
|
327 |
+
"""Generates a dynamic HTML page using an AI model (example)."""
|
328 |
+
# HF Space Note: This uses a hardcoded URL to *itself* if running in the space.
|
329 |
+
# Ensure the HF_SPACE_URL secret is set correctly.
|
330 |
+
env_vars = get_env_vars()
|
331 |
+
hf_space_url = env_vars.get('hf_space_url', '')
|
332 |
+
if not hf_space_url:
|
333 |
+
raise HTTPException(status_code=500, detail="HF_SPACE_URL environment variable not set.")
|
334 |
+
|
335 |
+
user_agent = request.headers.get('user-agent', 'Unknown')
|
336 |
+
client_ip = request.client.host if request.client else "Unknown"
|
337 |
+
location = f"IP: {client_ip}" # Basic IP, location requires GeoIP lookup (extra dependency)
|
338 |
+
|
339 |
prompt = f"""
|
340 |
+
Generate a cool, dynamic HTML page for a user with the following details:
|
341 |
+
- App Name: "LokiAI"
|
342 |
- User-Agent: {user_agent}
|
343 |
+
- Location Info: {location}
|
344 |
+
- Style: Cyberpunk aesthetic, minimalist layout, maybe some retro touches.
|
345 |
+
- Content: Include a heading, a short motivational or witty message, and perhaps a subtle animation. Use inline CSS for styling within a <style> tag.
|
346 |
+
- Output: Provide ONLY the raw HTML code, starting with <!DOCTYPE html>. Do not wrap it in backticks or add explanations.
|
|
|
347 |
"""
|
348 |
+
|
349 |
payload = {
|
350 |
+
"model": "mistral-small-latest", # Or another capable model
|
351 |
+
"messages": [{"role": "user", "content": prompt}],
|
352 |
+
"max_tokens": 1000,
|
353 |
+
"temperature": 0.7
|
354 |
}
|
|
|
355 |
headers = {
|
356 |
+
# HF Space Note: Use the space's own URL and a valid API key if required by your setup.
|
357 |
+
# Here, we assume the playground key bypass works or use a dedicated internal key.
|
358 |
+
"Authorization": f"Bearer {list(env_vars['api_keys'])[0] if env_vars['api_keys'] else 'dummy-key'}" # Use first key or dummy
|
359 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
360 |
|
361 |
+
try:
|
362 |
+
# HF Space Note: Use the async client for internal requests too.
|
363 |
+
client = get_async_client()
|
364 |
+
api_url = f"{hf_space_url}/chat/completions" # Call own endpoint
|
365 |
+
response = await client.post(api_url, json=payload, headers=headers)
|
366 |
+
response.raise_for_status() # Raise exception for bad status codes
|
367 |
+
data = response.json()
|
368 |
+
|
369 |
+
html_content = data.get('choices', [{}])[0].get('message', {}).get('content', '')
|
370 |
|
371 |
+
# Basic cleanup (remove potential markdown backticks if model adds them)
|
372 |
+
html_content = re.sub(r"^```html\s*", "", html_content, flags=re.IGNORECASE)
|
373 |
+
html_content = re.sub(r"\s*```$", "", html_content)
|
374 |
|
375 |
+
if not html_content.strip().lower().startswith("<!doctype html"):
|
376 |
+
logger.warning("Dynamo page generation might be incomplete or malformed.")
|
377 |
+
# Optionally return a fallback static page here
|
378 |
|
379 |
+
return HTMLResponse(content=html_content)
|
|
|
380 |
|
381 |
+
except httpx.HTTPStatusError as e:
|
382 |
+
logger.error(f"Error calling self API for /dynamo: {e.response.status_code} - {e.response.text}")
|
383 |
+
raise HTTPException(status_code=502, detail=f"Failed to generate dynamic content: Upstream API error {e.response.status_code}")
|
384 |
+
except Exception as e:
|
385 |
+
logger.error(f"Unexpected error in /dynamo: {e}", exc_info=True)
|
386 |
+
raise HTTPException(status_code=500, detail="Failed to generate dynamic content due to an internal error.")
|
387 |
+
|
388 |
+
|
389 |
+
# Vetra Example (Fetching from GitHub)
|
390 |
+
# HF Space Note: Ensure outbound requests to raw.githubusercontent.com are allowed.
|
391 |
+
GITHUB_BASE = "https://raw.githubusercontent.com/Parthsadaria/Vetra/main"
|
392 |
+
VETRA_FILES = {"html": "index.html", "css": "style.css", "js": "script.js"}
|
393 |
|
394 |
+
async def get_github_file(filename: str) -> Optional[str]:
|
395 |
+
"""Fetches a file from the Vetra GitHub repo."""
|
396 |
url = f"{GITHUB_BASE}/{filename}"
|
397 |
+
try:
|
398 |
+
client = get_async_client()
|
399 |
res = await client.get(url)
|
400 |
+
res.raise_for_status()
|
401 |
+
return res.text
|
402 |
+
except httpx.RequestError as e:
|
403 |
+
logger.error(f"Error fetching GitHub file {url}: {e}")
|
404 |
+
return None
|
405 |
+
except httpx.HTTPStatusError as e:
|
406 |
+
logger.error(f"GitHub file {url} returned status {e.response.status_code}")
|
407 |
+
return None
|
408 |
|
409 |
+
@app.get("/vetra", response_class=HTMLResponse, tags=["Examples"])
|
410 |
async def serve_vetra():
|
411 |
+
"""Serves the Vetra application by fetching components from GitHub."""
|
412 |
+
logger.info("Fetching Vetra files from GitHub...")
|
413 |
+
# Fetch files concurrently
|
414 |
+
html_task = asyncio.create_task(get_github_file(VETRA_FILES["html"]))
|
415 |
+
css_task = asyncio.create_task(get_github_file(VETRA_FILES["css"]))
|
416 |
+
js_task = asyncio.create_task(get_github_file(VETRA_FILES["js"]))
|
417 |
+
|
418 |
+
html, css, js = await asyncio.gather(html_task, css_task, js_task)
|
419 |
|
420 |
if not html:
|
421 |
+
logger.error("Failed to fetch Vetra index.html")
|
422 |
+
return HTMLResponse(content="<h1>Error: Could not load Vetra application (HTML missing)</h1>", status_code=502)
|
423 |
+
|
424 |
+
# Inject CSS and JS into HTML
|
425 |
+
css_content = f"<style>{css or '/* CSS failed to load */'}</style>"
|
426 |
+
js_content = f"<script>{js or '// JS failed to load'}</script>"
|
427 |
|
428 |
+
# Inject carefully before closing tags
|
429 |
+
final_html = html.replace("</head>", f"{css_content}\n</head>", 1)
|
430 |
+
final_html = final_html.replace("</body>", f"{js_content}\n</body>", 1)
|
|
|
|
|
|
|
|
|
431 |
|
432 |
+
logger.info("Successfully served Vetra application.")
|
433 |
return HTMLResponse(content=final_html)
|
434 |
|
435 |
|
436 |
+
# Model Info Endpoint
|
437 |
+
@app.get("/api/v1/models", tags=["Models"])
|
438 |
+
@app.get("/models", tags=["Models"])
|
439 |
+
async def return_models():
|
440 |
+
"""Returns the list of available models loaded from models.json."""
|
441 |
+
# HF Space Note: This endpoint now relies on models.json being present.
|
442 |
+
# It no longer dynamically adds models defined only in the script's sets.
|
443 |
+
# Ensure models.json is comprehensive or adjust startup logic if needed.
|
444 |
+
return await get_models()
|
445 |
|
446 |
+
# Search Endpoint (using cloudscraper)
|
447 |
+
# HF Space Note: This uses cloudscraper which might be blocked or require updates.
|
448 |
+
# Consider replacing with a more stable search API if possible.
|
449 |
+
async def generate_search_async(query: str, systemprompt: Optional[str] = None) -> asyncio.Queue:
|
450 |
+
"""Performs search using the configured backend and streams results."""
|
451 |
+
queue = asyncio.Queue()
|
452 |
+
env_vars = get_env_vars()
|
453 |
+
search_endpoint = env_vars.get('secret_api_endpoint_3')
|
454 |
|
455 |
+
async def _fetch_search_data():
|
456 |
+
if not search_endpoint:
|
457 |
+
await queue.put({"error": "Search API endpoint (SECRET_API_ENDPOINT_3) not configured"})
|
458 |
+
await queue.put(None) # Signal end
|
459 |
+
return
|
460 |
|
461 |
+
try:
|
462 |
+
scraper = get_scraper() # Get a scraper instance from the pool
|
463 |
+
loop = asyncio.get_running_loop()
|
|
|
|
|
464 |
|
465 |
+
system_message = systemprompt or "You are a helpful search assistant."
|
466 |
+
messages = [
|
467 |
+
{"role": "system", "content": system_message},
|
468 |
+
{"role": "user", "content": query},
|
469 |
+
]
|
470 |
+
payload = {
|
471 |
+
"model": "searchgpt", # Assuming the endpoint expects this model name
|
472 |
+
"messages": messages,
|
473 |
+
"stream": True # Explicitly request streaming from backend
|
474 |
+
}
|
475 |
+
headers = {"User-Agent": "Mozilla/5.0"} # Standard user agent
|
476 |
+
|
477 |
+
# HF Space Note: Run synchronous scraper call in executor thread
|
478 |
+
response = await loop.run_in_executor(
|
479 |
+
executor,
|
480 |
+
scraper.post,
|
481 |
+
search_endpoint,
|
482 |
+
json=payload,
|
483 |
+
headers=headers,
|
484 |
+
stream=True # Request streaming from requests library perspective
|
485 |
+
)
|
486 |
+
|
487 |
+
response.raise_for_status()
|
488 |
+
|
489 |
+
# Process SSE stream
|
490 |
+
# HF Space Note: Iterating lines on the response directly can be blocking if not handled carefully.
|
491 |
+
# Using iter_lines with decode_unicode=True is generally safe.
|
492 |
+
for line in response.iter_lines(decode_unicode=True):
|
493 |
+
if line.startswith("data: "):
|
494 |
+
try:
|
495 |
+
data_str = line[6:]
|
496 |
+
if data_str.strip() == "[DONE]": # Check for OpenAI style completion
|
497 |
+
break
|
498 |
+
json_data = json.loads(data_str)
|
499 |
+
# Assuming OpenAI compatible streaming format
|
500 |
+
delta = json_data.get("choices", [{}])[0].get("delta", {})
|
501 |
+
content = delta.get("content")
|
502 |
+
if content:
|
503 |
+
# Reconstruct OpenAI-like SSE chunk
|
504 |
+
chunk = {
|
505 |
+
"id": json_data.get("id"),
|
506 |
+
"object": "chat.completion.chunk",
|
507 |
+
"created": json_data.get("created", int(time.time())),
|
508 |
+
"model": "searchgpt",
|
509 |
+
"choices": [{"index": 0, "delta": {"content": content}, "finish_reason": None}]
|
510 |
+
}
|
511 |
+
await queue.put({"data": f"data: {json.dumps(chunk)}\n\n", "text": content})
|
512 |
+
# Check for finish reason
|
513 |
+
finish_reason = json_data.get("choices", [{}])[0].get("finish_reason")
|
514 |
+
if finish_reason:
|
515 |
+
chunk = {
|
516 |
+
"id": json_data.get("id"),
|
517 |
+
"object": "chat.completion.chunk",
|
518 |
+
"created": json_data.get("created", int(time.time())),
|
519 |
+
"model": "searchgpt",
|
520 |
+
"choices": [{"index": 0, "delta": {}, "finish_reason": finish_reason}]
|
521 |
+
}
|
522 |
+
await queue.put({"data": f"data: {json.dumps(chunk)}\n\n", "text": ""})
|
523 |
+
break # Stop processing after finish reason
|
524 |
+
|
525 |
+
except json.JSONDecodeError:
|
526 |
+
logger.warning(f"Failed to decode JSON from search stream: {line}")
|
527 |
+
continue
|
528 |
+
except Exception as e:
|
529 |
+
logger.error(f"Error processing search stream chunk: {e}", exc_info=True)
|
530 |
+
await queue.put({"error": f"Error processing stream: {e}"})
|
531 |
+
break # Stop on processing error
|
532 |
+
|
533 |
+
except requests.exceptions.RequestException as e:
|
534 |
+
logger.error(f"Search request failed: {e}")
|
535 |
+
await queue.put({"error": f"Search request failed: {e}"})
|
536 |
+
except Exception as e:
|
537 |
+
logger.error(f"Unexpected error during search: {e}", exc_info=True)
|
538 |
+
await queue.put({"error": f"An unexpected error occurred during search: {e}"})
|
539 |
+
finally:
|
540 |
+
await queue.put(None) # Signal completion
|
541 |
+
|
542 |
+
|
543 |
+
asyncio.create_task(_fetch_search_data())
|
544 |
+
return queue
|
545 |
+
|
546 |
+
@app.get("/searchgpt", tags=["Search"])
|
547 |
+
async def search_gpt(q: str, stream: bool = True, systemprompt: Optional[str] = None):
|
548 |
+
"""
|
549 |
+
Performs a search using the backend search model and streams results.
|
550 |
+
Pass `stream=false` to get the full response at once.
|
551 |
+
"""
|
552 |
if not q:
|
553 |
raise HTTPException(status_code=400, detail="Query parameter 'q' is required")
|
554 |
|
555 |
+
# HF Space Note: Ensure usage_tracker is thread-safe if used across async/sync boundaries.
|
556 |
+
# The dummy tracker used when the module isn't found is safe.
|
557 |
usage_tracker.record_request(endpoint="/searchgpt")
|
558 |
|
559 |
+
queue = await generate_search_async(q, systemprompt=systemprompt)
|
560 |
|
561 |
if stream:
|
562 |
async def stream_generator():
|
563 |
+
full_response_text = "" # Keep track for non-streaming case if needed
|
564 |
while True:
|
565 |
item = await queue.get()
|
566 |
+
if item is None: # End of stream signal
|
567 |
break
|
|
|
568 |
if "error" in item:
|
569 |
+
# HF Space Note: Log errors server-side, return generic error to client for security.
|
570 |
+
logger.error(f"Search stream error: {item['error']}")
|
571 |
+
# Send an error event in the stream
|
572 |
+
error_event = {"error": {"message": "Search failed.", "code": 500}}
|
573 |
+
yield f"data: {json.dumps(error_event)}\n\n"
|
574 |
break
|
|
|
575 |
if "data" in item:
|
576 |
yield item["data"]
|
577 |
+
full_response_text += item.get("text", "")
|
578 |
+
# Optionally yield a [DONE] message if backend doesn't guarantee it
|
579 |
+
# yield "data: [DONE]\n\n"
|
580 |
|
581 |
return StreamingResponse(
|
582 |
stream_generator(),
|
583 |
+
media_type="text/event-stream",
|
584 |
+
headers={
|
585 |
+
"Content-Type": "text/event-stream",
|
586 |
+
"Cache-Control": "no-cache",
|
587 |
+
"Connection": "keep-alive",
|
588 |
+
"X-Accel-Buffering": "no" # Crucial for Nginx/proxies in HF Spaces
|
589 |
+
}
|
590 |
)
|
591 |
else:
|
592 |
+
# Collect full response for non-streaming request
|
593 |
+
full_response_text = ""
|
594 |
while True:
|
595 |
item = await queue.get()
|
596 |
if item is None:
|
597 |
break
|
|
|
598 |
if "error" in item:
|
599 |
+
logger.error(f"Search non-stream error: {item['error']}")
|
600 |
+
raise HTTPException(status_code=502, detail=f"Search failed: {item['error']}")
|
601 |
+
full_response_text += item.get("text", "")
|
602 |
+
|
603 |
+
# Mimic OpenAI non-streaming response structure
|
604 |
+
return JSONResponse(content={
|
605 |
+
"id": f"search-{int(time.time())}",
|
606 |
+
"object": "chat.completion",
|
607 |
+
"created": int(time.time()),
|
608 |
+
"model": "searchgpt",
|
609 |
+
"choices": [{
|
610 |
+
"index": 0,
|
611 |
+
"message": {
|
612 |
+
"role": "assistant",
|
613 |
+
"content": full_response_text,
|
614 |
+
},
|
615 |
+
"finish_reason": "stop",
|
616 |
+
}],
|
617 |
+
"usage": { # Note: Token usage is unknown here
|
618 |
+
"prompt_tokens": None,
|
619 |
+
"completion_tokens": None,
|
620 |
+
"total_tokens": None,
|
621 |
+
}
|
622 |
+
})
|
623 |
|
624 |
|
625 |
+
# Main Chat Completions Proxy
|
626 |
+
@app.post("/api/v1/chat/completions", tags=["Chat Completions"])
|
627 |
+
@app.post("/chat/completions", tags=["Chat Completions"])
|
628 |
+
async def get_completion(
|
629 |
+
payload: Payload,
|
630 |
+
request: Request,
|
631 |
+
authenticated: bool = Depends(verify_api_key) # Apply authentication
|
632 |
+
):
|
633 |
+
"""
|
634 |
+
Proxies chat completion requests to the appropriate backend API based on the model.
|
635 |
+
Supports streaming (SSE).
|
636 |
+
"""
|
637 |
if not server_status:
|
638 |
+
raise HTTPException(status_code=503, detail="Server is under maintenance.")
|
|
|
|
|
|
|
639 |
|
640 |
+
model_to_use = payload.model or "gpt-4o-mini" # Default model
|
641 |
|
642 |
+
# HF Space Note: Check against models loaded at startup.
|
643 |
+
if available_model_ids and model_to_use not in available_model_ids:
|
644 |
+
logger.warning(f"Requested model '{model_to_use}' not in available list.")
|
645 |
+
# Check if it's a known category even if not explicitly in models.json
|
646 |
+
known_categories = mistral_models | pollinations_models | alternate_models | claude_3_models
|
647 |
+
if model_to_use not in known_categories:
|
648 |
+
raise HTTPException(
|
649 |
+
status_code=400,
|
650 |
+
detail=f"Model '{model_to_use}' is not available or recognized. Check /models."
|
651 |
+
)
|
652 |
+
else:
|
653 |
+
logger.info(f"Allowing known category model '{model_to_use}' despite not being in models.json.")
|
654 |
|
655 |
+
|
656 |
+
# Log request asynchronously
|
657 |
asyncio.create_task(log_request(request, model_to_use))
|
658 |
usage_tracker.record_request(model=model_to_use, endpoint="/chat/completions")
|
659 |
|
660 |
+
# Prepare payload for the target API
|
661 |
+
payload_dict = payload.dict(exclude_none=True) # Exclude None values
|
662 |
+
payload_dict["model"] = model_to_use # Ensure model is set
|
|
|
|
|
|
|
663 |
|
|
|
664 |
env_vars = get_env_vars()
|
665 |
+
hf_space_url = env_vars.get('hf_space_url', '') # Needed for Referer/Origin
|
666 |
+
|
667 |
+
# Determine target endpoint and headers
|
668 |
+
endpoint = None
|
669 |
+
custom_headers = {}
|
670 |
|
|
|
671 |
if model_to_use in mistral_models:
|
672 |
+
endpoint = env_vars.get('mistral_api')
|
673 |
+
api_key = env_vars.get('mistral_key')
|
674 |
+
if not endpoint or not api_key:
|
675 |
+
raise HTTPException(status_code=500, detail="Mistral API endpoint or key not configured.")
|
676 |
+
custom_headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json", "Accept": "application/json"}
|
677 |
+
# Mistral specific adjustments if needed
|
678 |
+
# payload_dict.pop('system', None) # Example: if Mistral doesn't use 'system' role
|
679 |
+
|
680 |
elif model_to_use in pollinations_models:
|
681 |
+
endpoint = env_vars.get('secret_api_endpoint_4')
|
682 |
+
if not endpoint:
|
683 |
+
raise HTTPException(status_code=500, detail="Pollinations API endpoint (SECRET_API_ENDPOINT_4) not configured.")
|
684 |
+
# Pollinations might need specific headers? Add them here.
|
685 |
+
custom_headers = {"Content-Type": "application/json"}
|
686 |
+
|
687 |
elif model_to_use in alternate_models:
|
688 |
+
endpoint = env_vars.get('secret_api_endpoint_2')
|
689 |
+
if not endpoint:
|
690 |
+
raise HTTPException(status_code=500, detail="Alternate API endpoint (SECRET_API_ENDPOINT_2) not configured.")
|
691 |
+
custom_headers = {"Content-Type": "application/json"}
|
692 |
+
|
693 |
+
elif model_to_use in claude_3_models:
|
694 |
+
endpoint = env_vars.get('secret_api_endpoint_5')
|
695 |
+
if not endpoint:
|
696 |
+
raise HTTPException(status_code=500, detail="Claude 3 API endpoint (SECRET_API_ENDPOINT_5) not configured.")
|
697 |
+
custom_headers = {"Content-Type": "application/json"}
|
698 |
+
# Claude specific headers (like anthropic-version) might be needed
|
699 |
+
# custom_headers["anthropic-version"] = "2023-06-01"
|
700 |
+
|
701 |
+
else: # Default endpoint
|
702 |
+
endpoint = env_vars.get('secret_api_endpoint')
|
703 |
+
if not endpoint:
|
704 |
+
raise HTTPException(status_code=500, detail="Default API endpoint (SECRET_API_ENDPOINT) not configured.")
|
705 |
+
# Default endpoint might need Origin/Referer
|
706 |
+
if hf_space_url:
|
707 |
+
custom_headers = {
|
708 |
+
"Origin": hf_space_url,
|
709 |
+
"Referer": hf_space_url,
|
710 |
+
"Content-Type": "application/json"
|
711 |
+
}
|
712 |
+
else:
|
713 |
+
custom_headers = {"Content-Type": "application/json"}
|
714 |
+
|
715 |
|
716 |
+
target_url = f"{endpoint.rstrip('/')}/v1/chat/completions" # Assume OpenAI compatible path
|
717 |
+
logger.info(f"Proxying request for model '{model_to_use}' to endpoint: {endpoint}")
|
718 |
|
719 |
+
client = get_async_client()
|
720 |
+
|
721 |
+
async def stream_generator():
|
722 |
+
"""Generator for streaming the response."""
|
723 |
+
nonlocal target_url # Allow modification if needed
|
724 |
try:
|
725 |
+
async with client.stream("POST", target_url, json=payload_dict, headers=custom_headers) as response:
|
726 |
+
# Check for initial errors before streaming
|
727 |
+
if response.status_code >= 400:
|
728 |
+
error_body = await response.aread()
|
729 |
+
logger.error(f"Upstream API error: {response.status_code} - {error_body.decode()}")
|
730 |
+
# Try to parse error detail from upstream
|
731 |
+
detail = f"Upstream API error: {response.status_code}"
|
732 |
+
try:
|
733 |
+
error_json = json.loads(error_body)
|
734 |
+
detail = error_json.get('error', {}).get('message', detail)
|
735 |
+
except json.JSONDecodeError:
|
736 |
+
pass
|
737 |
+
# Send error as SSE event
|
738 |
+
error_event = {"error": {"message": detail, "code": response.status_code}}
|
739 |
+
yield f"data: {json.dumps(error_event)}\n\n"
|
740 |
+
return # Stop generation
|
741 |
+
|
742 |
+
# Stream the response line by line
|
743 |
+
async for line in response.aiter_lines():
|
744 |
+
if line:
|
745 |
+
# Pass through the data directly
|
746 |
+
yield line + "\n"
|
747 |
+
# Ensure stream is properly closed, yield [DONE] if backend doesn't
|
748 |
+
# Some backends might not send [DONE], uncomment if needed
|
749 |
+
# yield "data: [DONE]\n\n"
|
750 |
+
|
751 |
except httpx.TimeoutException:
|
752 |
+
logger.error(f"Request to {target_url} timed out.")
|
753 |
+
error_event = {"error": {"message": "Request timed out", "code": 504}}
|
754 |
+
yield f"data: {json.dumps(error_event)}\n\n"
|
755 |
except httpx.RequestError as e:
|
756 |
+
logger.error(f"Failed to connect to upstream API {target_url}: {e}")
|
757 |
+
error_event = {"error": {"message": f"Upstream connection error: {e}", "code": 502}}
|
758 |
+
yield f"data: {json.dumps(error_event)}\n\n"
|
759 |
except Exception as e:
|
760 |
+
logger.error(f"An unexpected error occurred during streaming proxy: {e}", exc_info=True)
|
761 |
+
error_event = {"error": {"message": f"Internal server error: {e}", "code": 500}}
|
762 |
+
yield f"data: {json.dumps(error_event)}\n\n"
|
763 |
|
764 |
+
if payload.stream:
|
|
|
765 |
return StreamingResponse(
|
766 |
+
stream_generator(),
|
767 |
media_type="text/event-stream",
|
768 |
headers={
|
769 |
"Content-Type": "text/event-stream",
|
770 |
"Cache-Control": "no-cache",
|
771 |
"Connection": "keep-alive",
|
772 |
+
"X-Accel-Buffering": "no" # Essential for HF Spaces proxying SSE
|
773 |
}
|
774 |
)
|
775 |
else:
|
776 |
+
# Handle non-streaming request by collecting the streamed chunks
|
777 |
+
full_response_content = ""
|
778 |
+
final_json_response = None
|
779 |
+
async for line in stream_generator():
|
780 |
+
if line.startswith("data: "):
|
781 |
+
data_str = line[6:].strip()
|
782 |
+
if data_str == "[DONE]":
|
783 |
+
break
|
784 |
+
try:
|
785 |
+
chunk = json.loads(data_str)
|
786 |
+
# Check for error chunk
|
787 |
+
if "error" in chunk:
|
788 |
+
logger.error(f"Received error during non-stream collection: {chunk['error']}")
|
789 |
+
raise HTTPException(status_code=chunk['error'].get('code', 502), detail=chunk['error'].get('message', 'Upstream API error'))
|
790 |
+
|
791 |
+
# Accumulate content from delta
|
792 |
+
delta = chunk.get("choices", [{}])[0].get("delta", {})
|
793 |
+
content = delta.get("content")
|
794 |
+
if content:
|
795 |
+
full_response_content += content
|
796 |
+
|
797 |
+
# Store the last chunk structure to reconstruct the final response
|
798 |
+
# We assume the last chunk contains necessary info like id, model, etc.
|
799 |
+
# but we overwrite the choices/message part.
|
800 |
+
final_json_response = chunk # Keep the structure
|
801 |
+
# Check for finish reason
|
802 |
+
finish_reason = chunk.get("choices", [{}])[0].get("finish_reason")
|
803 |
+
if finish_reason:
|
804 |
+
break # Stop collecting
|
805 |
+
|
806 |
+
except json.JSONDecodeError:
|
807 |
+
logger.warning(f"Could not decode JSON chunk in non-stream mode: {data_str}")
|
808 |
+
except Exception as e:
|
809 |
+
logger.error(f"Error processing chunk in non-stream mode: {e}")
|
810 |
+
raise HTTPException(status_code=500, detail="Error processing response stream.")
|
811 |
+
|
812 |
+
if final_json_response is None:
|
813 |
+
# Handle cases where no valid data chunks were received
|
814 |
+
logger.error("No valid response chunks received for non-streaming request.")
|
815 |
+
raise HTTPException(status_code=502, detail="Failed to get valid response from upstream API.")
|
816 |
+
|
817 |
+
|
818 |
+
# Reconstruct OpenAI-like non-streaming response
|
819 |
+
final_response_obj = {
|
820 |
+
"id": final_json_response.get("id", f"chatcmpl-{int(time.time())}"),
|
821 |
+
"object": "chat.completion",
|
822 |
+
"created": final_json_response.get("created", int(time.time())),
|
823 |
+
"model": model_to_use, # Use the requested model
|
824 |
+
"choices": [{
|
825 |
+
"index": 0,
|
826 |
+
"message": {
|
827 |
+
"role": "assistant",
|
828 |
+
"content": full_response_content,
|
829 |
+
},
|
830 |
+
"finish_reason": final_json_response.get("choices", [{}])[0].get("finish_reason", "stop"), # Get finish reason from last chunk
|
831 |
+
}],
|
832 |
+
"usage": { # Token usage might be in the last chunk for some APIs, otherwise unknown
|
833 |
+
"prompt_tokens": None,
|
834 |
+
"completion_tokens": None,
|
835 |
+
"total_tokens": None,
|
836 |
+
}
|
837 |
+
}
|
838 |
+
# Attempt to extract usage if present in the (potentially non-standard) final chunk
|
839 |
+
usage_data = final_json_response.get("usage")
|
840 |
+
if isinstance(usage_data, dict):
|
841 |
+
final_response_obj["usage"].update(usage_data)
|
842 |
|
|
|
843 |
|
844 |
+
return JSONResponse(content=final_response_obj)
|
845 |
|
846 |
|
847 |
+
# Image Generation Endpoint
|
848 |
+
@app.post("/images/generations", tags=["Image Generation"])
|
849 |
+
async def create_image(
|
850 |
+
payload: ImageGenerationPayload,
|
851 |
+
authenticated: bool = Depends(verify_api_key)
|
852 |
+
):
|
853 |
"""
|
854 |
+
Generates images based on a text prompt using the configured backend.
|
855 |
"""
|
|
|
856 |
if not server_status:
|
857 |
+
raise HTTPException(status_code=503, detail="Server is under maintenance.")
|
|
|
|
|
|
|
858 |
|
|
|
859 |
if payload.model not in supported_image_models:
|
860 |
raise HTTPException(
|
861 |
status_code=400,
|
862 |
+
detail=f"Model '{payload.model}' is not supported for image generation. Supported: {', '.join(supported_image_models)}"
|
863 |
)
|
864 |
|
|
|
865 |
usage_tracker.record_request(model=payload.model, endpoint="/images/generations")
|
866 |
|
867 |
+
env_vars = get_env_vars()
|
868 |
+
target_api_url = env_vars.get('new_img_endpoint')
|
869 |
+
if not target_api_url:
|
870 |
+
raise HTTPException(status_code=500, detail="Image generation endpoint (NEW_IMG) not configured.")
|
871 |
+
|
872 |
+
# Prepare payload for the target API (adjust keys if needed)
|
873 |
+
# HF Space Note: Ensure the keys match the actual API expected by NEW_IMG endpoint.
|
874 |
+
# Assuming it's OpenAI compatible here.
|
875 |
api_payload = {
|
876 |
"model": payload.model,
|
877 |
"prompt": payload.prompt,
|
878 |
+
"n": payload.n,
|
879 |
+
"size": payload.size
|
880 |
}
|
881 |
+
# Remove None values the target API might not like
|
882 |
+
api_payload = {k: v for k, v in api_payload.items() if v is not None}
|
883 |
|
|
|
|
|
884 |
|
885 |
+
logger.info(f"Requesting image generation for model '{payload.model}' from {target_api_url}")
|
886 |
+
client = get_async_client()
|
|
|
|
|
887 |
|
888 |
+
try:
|
889 |
+
# HF Space Note: Image generation can take time, use a longer timeout if needed.
|
890 |
+
# Consider making this truly async if the backend supports webhooks or polling.
|
891 |
+
response = await client.post(target_api_url, json=api_payload, timeout=120.0) # 2 min timeout
|
892 |
+
response.raise_for_status() # Raise HTTP errors
|
893 |
|
894 |
+
# Return the exact response from the backend
|
895 |
return JSONResponse(content=response.json())
|
896 |
|
897 |
except httpx.TimeoutException:
|
898 |
+
logger.error(f"Image generation request to {target_api_url} timed out.")
|
899 |
raise HTTPException(status_code=504, detail="Image generation request timed out.")
|
900 |
+
except httpx.HTTPStatusError as e:
|
901 |
+
logger.error(f"Image generation API error: {e.response.status_code} - {e.response.text}")
|
902 |
+
detail = f"Image generation failed: Upstream API error {e.response.status_code}"
|
903 |
+
try:
|
904 |
+
err_json = e.response.json()
|
905 |
+
detail = err_json.get('error', {}).get('message', detail)
|
906 |
+
except json.JSONDecodeError:
|
907 |
+
pass
|
908 |
+
raise HTTPException(status_code=e.response.status_code, detail=detail)
|
909 |
except httpx.RequestError as e:
|
910 |
+
logger.error(f"Error connecting to image generation service {target_api_url}: {e}")
|
911 |
raise HTTPException(status_code=502, detail=f"Error connecting to image generation service: {e}")
|
912 |
except Exception as e:
|
913 |
+
logger.error(f"Unexpected error during image generation: {e}", exc_info=True)
|
914 |
raise HTTPException(status_code=500, detail=f"An unexpected error occurred during image generation: {e}")
|
915 |
|
916 |
|
917 |
+
# --- Utility & Admin Endpoints ---
|
918 |
|
919 |
+
async def log_request(request: Request, model: Optional[str] = None):
|
920 |
+
"""Logs basic request information asynchronously."""
|
921 |
+
# HF Space Note: Avoid logging sensitive info like full IP or headers unless necessary.
|
922 |
+
# Hashing IP provides some privacy.
|
923 |
+
client_host = request.client.host if request.client else "unknown"
|
924 |
+
ip_hash = hash(client_host) % 10000
|
925 |
+
timestamp = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z")
|
926 |
+
log_message = f"Timestamp: {timestamp}, IP Hash: {ip_hash}, Method: {request.method}, Path: {request.url.path}"
|
927 |
+
if model:
|
928 |
+
log_message += f", Model: {model}"
|
929 |
+
logger.info(log_message)
|
930 |
|
|
|
|
|
|
|
|
|
931 |
|
932 |
+
@app.get("/usage", tags=["Admin"])
|
933 |
async def get_usage(days: int = 7):
|
934 |
+
"""Retrieves aggregated usage statistics."""
|
935 |
+
# HF Space Note: Ensure usage_tracker methods are efficient, especially get_usage_summary.
|
936 |
+
# Caching might be needed if it becomes slow.
|
937 |
+
if days <= 0:
|
938 |
+
raise HTTPException(status_code=400, detail="Number of days must be positive.")
|
939 |
+
try:
|
940 |
+
# Run potentially CPU-bound summary generation in executor
|
941 |
+
loop = asyncio.get_running_loop()
|
942 |
+
summary = await loop.run_in_executor(executor, usage_tracker.get_usage_summary, days)
|
943 |
+
return summary
|
944 |
+
except Exception as e:
|
945 |
+
logger.error(f"Error retrieving usage statistics: {e}", exc_info=True)
|
946 |
+
raise HTTPException(status_code=500, detail="Failed to retrieve usage statistics.")
|
947 |
+
|
948 |
+
# HF Space Note: Generating HTML dynamically can be resource-intensive.
|
949 |
+
# Consider caching the generated HTML or serving a static page updated periodically.
|
950 |
+
def generate_usage_html(usage_data: Dict) -> str:
|
951 |
+
"""Generates an HTML report from usage data."""
|
952 |
+
# (Keep the HTML generation logic as provided in the original file)
|
953 |
+
# ... (rest of the HTML generation code from the original file) ...
|
954 |
+
# Ensure this function handles potentially missing keys gracefully
|
955 |
+
models_usage = usage_data.get('models', {})
|
956 |
+
endpoints_usage = usage_data.get('api_endpoints', {})
|
957 |
+
daily_usage = usage_data.get('recent_daily_usage', {})
|
958 |
+
total_requests = usage_data.get('total_requests', 0)
|
959 |
|
|
|
|
|
|
|
960 |
model_usage_rows = "\n".join([
|
961 |
f"""
|
962 |
<tr>
|
963 |
<td>{model}</td>
|
964 |
+
<td>{model_data.get('total_requests', 'N/A')}</td>
|
965 |
+
<td>{model_data.get('first_used', 'N/A')}</td>
|
966 |
+
<td>{model_data.get('last_used', 'N/A')}</td>
|
967 |
</tr>
|
968 |
+
""" for model, model_data in models_usage.items()
|
969 |
+
]) if models_usage else "<tr><td colspan='4'>No model usage data</td></tr>"
|
970 |
|
|
|
971 |
api_usage_rows = "\n".join([
|
972 |
f"""
|
973 |
<tr>
|
974 |
<td>{endpoint}</td>
|
975 |
+
<td>{endpoint_data.get('total_requests', 'N/A')}</td>
|
976 |
+
<td>{endpoint_data.get('first_used', 'N/A')}</td>
|
977 |
+
<td>{endpoint_data.get('last_used', 'N/A')}</td>
|
978 |
</tr>
|
979 |
+
""" for endpoint, endpoint_data in endpoints_usage.items()
|
980 |
+
]) if endpoints_usage else "<tr><td colspan='4'>No API endpoint usage data</td></tr>"
|
981 |
|
|
|
982 |
daily_usage_rows = "\n".join([
|
983 |
+
f"""
|
984 |
+
<tr>
|
985 |
+
<td>{date}</td>
|
986 |
+
<td>{entity}</td>
|
987 |
+
<td>{requests}</td>
|
988 |
+
</tr>
|
989 |
+
"""
|
990 |
+
for date, date_data in daily_usage.items()
|
991 |
+
for entity, requests in date_data.items()
|
992 |
+
]) if daily_usage else "<tr><td colspan='3'>No daily usage data</td></tr>"
|
993 |
|
994 |
+
|
995 |
+
# HF Space Note: Using f-string for large HTML is okay, but consider template engines (Jinja2)
|
996 |
+
# for more complex pages. Ensure CSS/JS are either inline or served via separate endpoints.
|
997 |
html_content = f"""
|
998 |
<!DOCTYPE html>
|
999 |
<html lang="en">
|
1000 |
<head>
|
1001 |
<meta charset="UTF-8">
|
1002 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
1003 |
<title>Lokiai AI - Usage Statistics</title>
|
1004 |
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;600&display=swap" rel="stylesheet">
|
1005 |
<style>
|
1006 |
+
/* (Keep the CSS styles as provided in the original file) */
|
1007 |
:root {{
|
1008 |
+
--bg-dark: #0f1011; --bg-darker: #070708; --text-primary: #e6e6e6;
|
1009 |
+
--text-secondary: #8c8c8c; --border-color: #2c2c2c; --accent-color: #3a6ee0;
|
|
|
|
|
|
|
|
|
1010 |
--accent-hover: #4a7ef0;
|
1011 |
}}
|
1012 |
+
body {{ font-family: 'Inter', sans-serif; background-color: var(--bg-dark); color: var(--text-primary); max-width: 1200px; margin: 0 auto; padding: 40px 20px; line-height: 1.6; }}
|
1013 |
+
.logo {{ display: flex; align-items: center; justify-content: center; margin-bottom: 30px; }}
|
1014 |
+
.logo h1 {{ font-weight: 600; font-size: 2.5em; color: var(--text-primary); margin-left: 15px; }}
|
1015 |
+
.logo img {{ width: 60px; height: 60px; border-radius: 10px; }}
|
1016 |
+
.container {{ background-color: var(--bg-darker); border-radius: 12px; padding: 30px; box-shadow: 0 15px 40px rgba(0,0,0,0.3); border: 1px solid var(--border-color); }}
|
1017 |
+
h2, h3 {{ color: var(--text-primary); border-bottom: 2px solid var(--border-color); padding-bottom: 10px; font-weight: 500; }}
|
1018 |
+
.total-requests {{ background-color: var(--accent-color); color: white; text-align: center; padding: 15px; border-radius: 8px; margin-bottom: 30px; font-weight: 600; letter-spacing: -0.5px; }}
|
1019 |
+
table {{ width: 100%; border-collapse: separate; border-spacing: 0; margin-bottom: 30px; background-color: var(--bg-dark); border-radius: 8px; overflow: hidden; }}
|
1020 |
+
th, td {{ border: 1px solid var(--border-color); padding: 12px; text-align: left; transition: background-color 0.3s ease; }}
|
1021 |
+
th {{ background-color: #1e1e1e; color: var(--text-primary); font-weight: 600; text-transform: uppercase; font-size: 0.9em; }}
|
1022 |
+
tr:nth-child(even) {{ background-color: rgba(255,255,255,0.05); }}
|
1023 |
+
tr:hover {{ background-color: rgba(62,100,255,0.1); }}
|
1024 |
+
@media (max-width: 768px) {{ .container {{ padding: 15px; }} table {{ font-size: 0.9em; }} }}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1025 |
</style>
|
1026 |
</head>
|
1027 |
<body>
|
1028 |
<div class="container">
|
1029 |
<div class="logo">
|
1030 |
<img src="data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iMjAwIiBoZWlnaHQ9IjIwMCIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj48cGF0aCBkPSJNMTAwIDM1TDUwIDkwaDEwMHoiIGZpbGw9IiMzYTZlZTAiLz48Y2lyY2xlIGN4PSIxMDAiIGN5PSIxNDAiIHI9IjMwIiBmaWxsPSIjM2E2ZWUwIi8+PC9zdmc+" alt="Lokai AI Logo">
|
1031 |
+
<h1>Lokiai AI Usage</h1>
|
1032 |
</div>
|
1033 |
|
1034 |
<div class="total-requests">
|
1035 |
+
Total API Requests Recorded: {total_requests}
|
1036 |
</div>
|
1037 |
|
1038 |
<h2>Model Usage</h2>
|
1039 |
<table>
|
1040 |
+
<thead><tr><th>Model</th><th>Total Requests</th><th>First Used</th><th>Last Used</th></tr></thead>
|
1041 |
+
<tbody>{model_usage_rows}</tbody>
|
|
|
|
|
|
|
|
|
|
|
1042 |
</table>
|
1043 |
|
1044 |
<h2>API Endpoint Usage</h2>
|
1045 |
<table>
|
1046 |
+
<thead><tr><th>Endpoint</th><th>Total Requests</th><th>First Used</th><th>Last Used</th></tr></thead>
|
1047 |
+
<tbody>{api_usage_rows}</tbody>
|
|
|
|
|
|
|
|
|
|
|
1048 |
</table>
|
1049 |
|
1050 |
+
<h2>Daily Usage (Last {usage_data.get('days_analyzed', 7)} Days)</h2>
|
1051 |
<table>
|
1052 |
+
<thead><tr><th>Date</th><th>Entity (Model/Endpoint)</th><th>Requests</th></tr></thead>
|
1053 |
+
<tbody>{daily_usage_rows}</tbody>
|
|
|
|
|
|
|
|
|
1054 |
</table>
|
1055 |
</div>
|
1056 |
</body>
|
|
|
1058 |
"""
|
1059 |
return html_content
|
1060 |
|
1061 |
+
# HF Space Note: Caching the generated HTML page can save resources.
|
1062 |
+
# Invalidate cache periodically or when usage data changes significantly.
|
1063 |
+
usage_html_cache = {"content": None, "timestamp": 0}
|
1064 |
+
CACHE_DURATION = 300 # Cache usage page for 5 minutes
|
|
|
1065 |
|
1066 |
+
@app.get("/usage/page", response_class=HTMLResponse, tags=["Admin"])
|
1067 |
async def usage_page():
|
1068 |
+
"""Serves an HTML page showing usage statistics."""
|
1069 |
+
now = time.monotonic()
|
1070 |
+
if usage_html_cache["content"] and (now - usage_html_cache["timestamp"] < CACHE_DURATION):
|
1071 |
+
logger.info("Serving cached usage page.")
|
1072 |
+
return HTMLResponse(content=usage_html_cache["content"])
|
1073 |
+
|
1074 |
+
logger.info("Generating fresh usage page.")
|
1075 |
+
try:
|
1076 |
+
# Run potentially slow parts in executor
|
1077 |
+
loop = asyncio.get_running_loop()
|
1078 |
+
usage_data = await loop.run_in_executor(executor, usage_tracker.get_usage_summary, 7) # Get data for 7 days
|
1079 |
+
html_content = await loop.run_in_executor(executor, generate_usage_html, usage_data)
|
1080 |
+
|
1081 |
+
# Update cache
|
1082 |
+
usage_html_cache["content"] = html_content
|
1083 |
+
usage_html_cache["timestamp"] = now
|
1084 |
+
|
1085 |
+
return HTMLResponse(content=html_content)
|
1086 |
+
except Exception as e:
|
1087 |
+
logger.error(f"Failed to generate usage page: {e}", exc_info=True)
|
1088 |
+
# Serve stale cache if available, otherwise error
|
1089 |
+
if usage_html_cache["content"]:
|
1090 |
+
logger.warning("Serving stale usage page due to generation error.")
|
1091 |
+
return HTMLResponse(content=usage_html_cache["content"])
|
1092 |
+
else:
|
1093 |
+
raise HTTPException(status_code=500, detail="Failed to generate usage statistics page.")
|
1094 |
+
|
1095 |
|
1096 |
+
# Meme Endpoint
|
1097 |
+
@app.get("/meme", tags=["Fun"])
|
1098 |
async def get_meme():
|
1099 |
+
"""Fetches a random meme and streams the image."""
|
1100 |
+
# HF Space Note: Ensure meme-api.com is accessible from the HF Space network.
|
1101 |
+
client = get_async_client()
|
1102 |
+
meme_api_url = "https://meme-api.com/gimme"
|
1103 |
try:
|
1104 |
+
logger.info("Fetching meme info...")
|
1105 |
+
response = await client.get(meme_api_url)
|
1106 |
+
response.raise_for_status()
|
1107 |
response_data = response.json()
|
1108 |
|
1109 |
meme_url = response_data.get("url")
|
1110 |
+
if not meme_url or not isinstance(meme_url, str):
|
1111 |
+
logger.error(f"Invalid meme URL received from API: {meme_url}")
|
1112 |
+
raise HTTPException(status_code=502, detail="Failed to get valid meme URL from API.")
|
1113 |
+
|
1114 |
+
logger.info(f"Fetching meme image: {meme_url}")
|
1115 |
+
# Use streaming request for the image itself
|
1116 |
+
async with client.stream("GET", meme_url) as image_response:
|
1117 |
+
image_response.raise_for_status() # Check if image URL is valid
|
1118 |
+
|
1119 |
+
# Get content type, default to image/png
|
1120 |
+
media_type = image_response.headers.get("content-type", "image/png")
|
1121 |
+
if not media_type.startswith("image/"):
|
1122 |
+
logger.warning(f"Unexpected content type '{media_type}' for meme URL: {meme_url}")
|
1123 |
+
# You might want to reject non-image types
|
1124 |
+
# raise HTTPException(status_code=502, detail="Meme URL did not return an image.")
|
1125 |
+
|
1126 |
+
|
1127 |
+
# Stream the image content directly
|
1128 |
+
return StreamingResponse(
|
1129 |
+
image_response.aiter_bytes(),
|
1130 |
+
media_type=media_type,
|
1131 |
+
headers={'Cache-Control': 'no-cache'} # Don't cache the meme itself heavily
|
1132 |
+
)
|
1133 |
+
|
1134 |
+
except httpx.HTTPStatusError as e:
|
1135 |
+
logger.error(f"HTTP error fetching meme ({e.request.url}): {e.response.status_code}")
|
1136 |
+
raise HTTPException(status_code=502, detail=f"Failed to fetch meme (HTTP {e.response.status_code})")
|
1137 |
+
except httpx.RequestError as e:
|
1138 |
+
logger.error(f"Network error fetching meme ({e.request.url}): {e}")
|
1139 |
+
raise HTTPException(status_code=502, detail="Failed to fetch meme (Network Error)")
|
1140 |
+
except Exception as e:
|
1141 |
+
logger.error(f"Unexpected error fetching meme: {e}", exc_info=True)
|
1142 |
+
raise HTTPException(status_code=500, detail="Failed to retrieve meme due to an internal error.")
|
1143 |
|
|
|
1144 |
|
1145 |
+
# Health Check Endpoint
|
1146 |
+
@app.get("/health", tags=["Utility"])
|
1147 |
+
async def health_check():
|
1148 |
+
"""Provides a health check status, including missing critical configurations."""
|
1149 |
+
env_vars = get_env_vars()
|
1150 |
+
missing_critical_vars = []
|
|
|
1151 |
|
1152 |
+
# Define critical vars needed for core functionality
|
1153 |
+
critical_vars = [
|
1154 |
+
'api_keys', 'secret_api_endpoint', 'secret_api_endpoint_2',
|
1155 |
+
'secret_api_endpoint_3', 'secret_api_endpoint_4', 'secret_api_endpoint_5',
|
1156 |
+
'new_img_endpoint', 'hf_space_url'
|
1157 |
+
]
|
1158 |
+
# Conditionally critical vars
|
1159 |
+
if any(model in mistral_models for model in available_model_ids):
|
1160 |
+
critical_vars.extend(['mistral_api', 'mistral_key'])
|
1161 |
+
|
1162 |
+
for var_name in critical_vars:
|
1163 |
+
value = env_vars.get(var_name)
|
1164 |
+
# Check for None or empty strings/lists/sets
|
1165 |
+
if value is None or (isinstance(value, (str, list, set)) and not value):
|
1166 |
+
missing_critical_vars.append(var_name)
|
1167 |
+
|
1168 |
+
is_healthy = not missing_critical_vars and server_status
|
1169 |
+
status_code = 200 if is_healthy else 503 # Service Unavailable if unhealthy
|
1170 |
|
1171 |
+
health_status = {
|
1172 |
+
"status": "healthy" if is_healthy else "unhealthy",
|
1173 |
+
"server_mode": "online" if server_status else "maintenance",
|
1174 |
+
"missing_critical_env_vars": missing_critical_vars,
|
1175 |
+
"details": "All critical configurations seem okay. Ready to roll! 🚀" if is_healthy else "Service issues detected. Check missing env vars or server status. 🛠️"
|
1176 |
+
}
|
1177 |
+
return JSONResponse(content=health_status, status_code=status_code)
|
1178 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1179 |
|
1180 |
+
# --- Startup and Shutdown Events ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1181 |
|
1182 |
@app.on_event("startup")
|
1183 |
async def startup_event():
|
1184 |
+
"""Tasks to run when the application starts."""
|
1185 |
global available_model_ids
|
1186 |
+
logger.info("Application startup sequence initiated...")
|
1187 |
+
|
1188 |
+
# Load models from JSON
|
1189 |
+
models_from_file = load_models_data()
|
1190 |
+
model_ids_from_file = {model['id'] for model in models_from_file if 'id' in model}
|
1191 |
+
|
1192 |
+
# Combine models from file and predefined sets
|
1193 |
+
predefined_model_sets = mistral_models | pollinations_models | alternate_models | claude_3_models
|
1194 |
+
all_model_ids = model_ids_from_file.union(predefined_model_sets)
|
1195 |
+
available_model_ids = sorted(list(all_model_ids)) # Keep as sorted list
|
1196 |
+
|
1197 |
+
logger.info(f"Loaded {len(model_ids_from_file)} models from models.json.")
|
1198 |
+
logger.info(f"Total {len(available_model_ids)} unique models available.")
|
1199 |
+
|
1200 |
+
# Initialize scraper pool (can take time)
|
1201 |
+
# HF Space Note: Run potentially blocking I/O in executor during startup
|
1202 |
+
loop = asyncio.get_running_loop()
|
1203 |
+
await loop.run_in_executor(executor, get_scraper) # This initializes the pool
|
1204 |
+
|
1205 |
+
# Validate critical environment variables and log warnings
|
1206 |
env_vars = get_env_vars()
|
1207 |
+
logger.info("Checking critical environment variables (Secrets)...")
|
1208 |
+
await health_check() # Run health check logic to log warnings
|
1209 |
+
|
1210 |
+
# Pre-connect async client? Optional, httpx handles connections on demand.
|
1211 |
+
# client = get_async_client()
|
1212 |
+
# await client.get("https://www.google.com") # Example warm-up call
|
1213 |
+
|
1214 |
+
logger.info("Startup complete. Server is ready to accept requests.")
|
1215 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1216 |
|
1217 |
@app.on_event("shutdown")
|
1218 |
async def shutdown_event():
|
1219 |
+
"""Tasks to run when the application shuts down."""
|
1220 |
+
logger.info("Application shutdown sequence initiated...")
|
1221 |
+
|
1222 |
+
# Close the httpx client gracefully
|
1223 |
client = get_async_client()
|
1224 |
await client.aclose()
|
1225 |
+
logger.info("HTTP client closed.")
|
1226 |
+
|
1227 |
+
# Shutdown the thread pool executor
|
1228 |
+
executor.shutdown(wait=True)
|
1229 |
+
logger.info("Thread pool executor shut down.")
|
1230 |
|
1231 |
+
# Clear scraper pool (optional, resources will be reclaimed anyway)
|
1232 |
scraper_pool.clear()
|
1233 |
+
logger.info("Scraper pool cleared.")
|
1234 |
|
1235 |
# Persist usage data
|
1236 |
+
# HF Space Note: Ensure file system is writable if saving locally.
|
1237 |
+
# Consider using HF Datasets or external DB for persistent storage.
|
1238 |
+
try:
|
1239 |
+
logger.info("Saving usage data...")
|
1240 |
+
usage_tracker.save_data()
|
1241 |
+
logger.info("Usage data saved.")
|
1242 |
+
except Exception as e:
|
1243 |
+
logger.error(f"Failed to save usage data during shutdown: {e}")
|
|
|
|
|
|
|
1244 |
|
1245 |
+
logger.info("Shutdown complete.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1246 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1247 |
|
1248 |
+
# --- Main Execution Block ---
|
1249 |
+
# HF Space Note: This block is mainly for local testing.
|
1250 |
+
# HF Spaces usually run the app using `uvicorn main:app --host 0.0.0.0 --port 7860` (or similar)
|
1251 |
+
# defined in the README metadata or a Procfile.
|
1252 |
if __name__ == "__main__":
|
1253 |
import uvicorn
|
1254 |
+
logger.info("Starting server locally with uvicorn...")
|
1255 |
+
# HF Space Note: Port 7860 is the default for HF Spaces. Host 0.0.0.0 is required.
|
1256 |
+
uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
|