Update main.py
main.py CHANGED
--- a/main.py
@@ -1,111 +1,47 @@
 import os
 from dotenv import load_dotenv
-from fastapi import FastAPI, HTTPException, Request
-from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse, FileResponse
 from pydantic import BaseModel
 import httpx
-import hashlib
 from functools import lru_cache
-from pathlib import Path
-import requests
-import re
-import cloudscraper
 import json
-from typing import Optional
 import datetime
 import time
-from
-
-from collections import defaultdict
-from fastapi import Security #new
-from fastapi import Depends
-from fastapi.security import APIKeyHeader
-from starlette.exceptions import HTTPException
 from starlette.status import HTTP_403_FORBIDDEN
 
-#
-
 
-#
-
-    if not api_key:
-        raise HTTPException(
-            status_code=HTTP_403_FORBIDDEN,
-            detail="No API key provided"
-        )
-
-    # Clean the API key by removing 'Bearer ' if present
-    if api_key.startswith('Bearer '):
-        api_key = api_key[7:]  # Remove 'Bearer ' prefix
-
-    # Get API keys from environment
-    api_keys_str = os.getenv('API_KEYS')
-    if not api_keys_str:
-        raise HTTPException(
-            status_code=HTTP_403_FORBIDDEN,
-            detail="API keys not configured on server"
-        )
-
-    valid_api_keys = api_keys_str.split(',')
-
-    # Check if the provided key is valid
-    if api_key not in valid_api_keys:
-        raise HTTPException(
-            status_code=HTTP_403_FORBIDDEN,
-            detail="Invalid API key"
-        )
-
-    return True
-class RateLimitMiddleware(BaseHTTPMiddleware):
-    def __init__(self, app, requests_per_second: int = 2):
-        super().__init__(app)
-        self.requests_per_second = requests_per_second
-        self.last_request_time = defaultdict(float)
-        self.tokens = defaultdict(lambda: requests_per_second)
-        self.last_update = defaultdict(float)
-
-    async def dispatch(self, request: Request, call_next):
-        client_ip = request.client.host
-        current_time = time.time()
-
-        # Update tokens
-        time_passed = current_time - self.last_update[client_ip]
-        self.last_update[client_ip] = current_time
-        self.tokens[client_ip] = min(
-            self.requests_per_second,
-            self.tokens[client_ip] + time_passed * self.requests_per_second
-        )
-
-        # Check if request can be processed
-        if self.tokens[client_ip] < 1:
-            return JSONResponse(
-                status_code=429,
-                content={
-                    "detail": "Too many requests. Please try again later.",
-                    "retry_after": round((1 - self.tokens[client_ip]) / self.requests_per_second)
-                }
-            )
-
-        # Consume a token
-        self.tokens[client_ip] -= 1
-
-        # Process the request
-        response = await call_next(request)
-        return response
 
 usage_tracker = UsageTracker()
-load_dotenv() #idk why this shi
 
 app = FastAPI()
-
-
-
-
-
-
-
-
-
 mistral_models = [
     "mistral-large-latest",
     "pixtral-large-latest",
@@ -118,14 +54,6 @@ mistral_models = [
     "codestral-latest"
 ]
 
-image_endpoint = os.getenv("IMAGE_ENDPOINT")
-ENDPOINT_ORIGIN = os.getenv('ENDPOINT_ORIGIN')
-
-# Validate if the main secret API endpoints are set
-if not secret_api_endpoint or not secret_api_endpoint_2 or not secret_api_endpoint_3:
-    raise HTTPException(status_code=500, detail="API endpoint(s) are not configured in environment variables.")
-
-# Define models that should use the secondary endpoint
 alternate_models = {
     "gpt-4o-mini",
     "deepseek-v3",
@@ -140,19 +68,64 @@ alternate_models = {
     "hermes-3-llama-3.2-3b"
 }
 
-
 class Payload(BaseModel):
     model: str
     messages: list
     stream: bool = False
-
-
-
-
-
-
-
-def
     headers = {"User-Agent": ""}
 
     # Use the provided system prompt, or default to "Be Helpful and Friendly"
@@ -173,8 +146,18 @@ def generate_search(query: str, systemprompt: Optional[str] = None, stream: bool
         "user_input": prompt[-1]["content"],
     }
 
     # Send the request to the chat endpoint
-    response =
 
     streaming_text = ""
 
@@ -210,31 +193,21 @@ def generate_search(query: str, systemprompt: Optional[str] = None, stream: bool
     if not stream:
         yield streaming_text
 
 @app.get("/ping")
 async def ping():
     start_time = datetime.datetime.now()
     response_time = (datetime.datetime.now() - start_time).total_seconds()
     return {"message": "pong", "response_time": f"{response_time:.6f} seconds"}
-
-@app.get("/searchgpt")
-async def search_gpt(q: str, stream: Optional[bool] = False, systemprompt: Optional[str] = None):
-    if not q:
-        raise HTTPException(status_code=400, detail="Query parameter 'q' is required")
-    usage_tracker.record_request(endpoint="/searchgpt")
-    if stream:
-        return StreamingResponse(
-            generate_search(q, systemprompt=systemprompt, stream=True),
-            media_type="text/event-stream"
-        )
-    else:
-        # For non-streaming, collect the text and return as JSON response
-        response_text = "".join([chunk for chunk in generate_search(q, systemprompt=systemprompt, stream=False)])
-        return JSONResponse(content={"response": response_text})
 @app.get("/", response_class=HTMLResponse)
 async def root():
-    # Open and read the content of index.html (in the same folder as the app)
     file_path = "index.html"
-
     try:
         with open(file_path, "r") as file:
             html_content = file.read()
@@ -242,29 +215,55 @@ async def root():
     except FileNotFoundError:
         return HTMLResponse(content="<h1>File not found</h1>", status_code=404)
 
-
     try:
-
-
-
-        return json.load(f)
     except FileNotFoundError:
-
-
-
-@app.get("api/v1/models")
 @app.get("/models")
 async def return_models():
     return await get_models()
-
 @app.post("/chat/completions")
 @app.post("/api/v1/chat/completions")
 async def get_completion(payload: Payload, request: Request, authenticated: bool = Depends(verify_api_key)):
     # Check server status
     model_to_use = payload.model if payload.model else "gpt-4o-mini"
 
     # Validate model availability
-    if model_to_use not in available_model_ids:
         raise HTTPException(
             status_code=400,
             detail=f"Model '{model_to_use}' is not available. Check /models for the available model list."
@@ -276,31 +275,28 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
     payload_dict = payload.dict()
     payload_dict["model"] = model_to_use
 
     # Select the appropriate endpoint
     if model_to_use in mistral_models:
-        endpoint = mistral_api
         custom_headers = {
-            "Authorization": f"Bearer {mistral_key}"
         }
     elif model_to_use in alternate_models:
-        endpoint = secret_api_endpoint_2
         custom_headers = {}
     else:
-        endpoint = secret_api_endpoint
         custom_headers = {}
 
-    # Current time and IP logging
     current_time = (datetime.datetime.utcnow() + datetime.timedelta(hours=5, minutes=30)).strftime("%Y-%m-%d %I:%M:%S %p")
-
-    print(f"Time: {current_time},
-    print(payload_dict)
-
-    if not server_status:
-        return JSONResponse(
-            status_code=503,
-            content={"message": "Server is under maintenance. Please try again later."}
-        )
 
     scraper = cloudscraper.create_scraper()
 
     async def stream_generator(payload_dict):
@@ -314,55 +310,52 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
            )
 
            # Handle response errors
-           if response.status_code
-
-
-
-
-
-
-
-               raise HTTPException(status_code=500, detail="Server error. Try again later.")
 
            # Stream response lines to the client
            for line in response.iter_lines():
                if line:
                    yield line.decode('utf-8') + "\n"
 
-        except requests.exceptions.RequestException as req_err:
-            print(response.text)
-            raise HTTPException(status_code=500, detail=f"Request failed: {req_err}")
        except Exception as e:
-
-            raise HTTPException(status_code=500, detail=
 
    return StreamingResponse(stream_generator(payload_dict), media_type="application/json")
-
-
 async def generate_image(
     prompt: Optional[str] = None,
-    model: str = "flux",
     seed: Optional[int] = None,
     width: Optional[int] = None,
     height: Optional[int] = None,
     nologo: Optional[bool] = True,
     private: Optional[bool] = None,
     enhance: Optional[bool] = None,
-    request: Request = None,
     authenticated: bool = Depends(verify_api_key)
 ):
-    """
-    Generate an image using the Image Generation API.
-    """
     # Validate the image endpoint
     if not image_endpoint:
         raise HTTPException(status_code=500, detail="Image endpoint not configured in environment variables.")
     usage_tracker.record_request(endpoint="/images/generations")
     # Handle GET and POST prompts
     if request.method == "POST":
         try:
-            body = await request.json()
             prompt = body.get("prompt", "").strip()
             if not prompt:
                 raise HTTPException(status_code=400, detail="Prompt cannot be empty")
@@ -377,12 +370,11 @@ async def generate_image(
     encoded_prompt = httpx.QueryParams({'prompt': prompt}).get('prompt')
 
     # Construct the URL with the encoded prompt
-    base_url = image_endpoint.rstrip('/')
     url = f"{base_url}/{encoded_prompt}"
 
     # Prepare query parameters with validation
     params = {}
-
     if model and isinstance(model, str):
         params['model'] = model
     if seed is not None and isinstance(seed, int):
@@ -399,29 +391,25 @@ async def generate_image(
         params['enhance'] = str(enhance).lower()
 
     try:
-
-        async with httpx.AsyncClient(timeout=timeout) as client:
             response = await client.get(url, params=params, follow_redirects=True)
 
             # Check for various error conditions
-            if response.status_code
-
-
-
-
-
-
-                raise HTTPException(
-                    status_code=response.status_code,
-                    detail=f"Image generation failed with status code {response.status_code}"
-                )
 
             # Verify content type
             content_type = response.headers.get('content-type', '')
             if not content_type.startswith('image/'):
                 raise HTTPException(
                     status_code=500,
-                    detail=
                 )
 
             return StreamingResponse(
@@ -435,35 +423,31 @@ async def generate_image(
 
     except httpx.TimeoutException:
         raise HTTPException(status_code=504, detail="Image generation request timed out")
-    except httpx.RequestError
-        raise HTTPException(status_code=500, detail=
-    except Exception
-        raise HTTPException(status_code=500, detail=
-@app.get("/playground", response_class=HTMLResponse)
-async def playground():
-    # Open and read the content of playground.html (in the same folder as the app)
-    file_path = "playground.html"
 
     try:
-        with
-
-
-
-
-
-
-
-
-
-
-
-    except
-
-
-    except json.JSONDecodeError:
-        print("Error: Invalid JSON format in models.json.")
-        return []
 @app.get("/usage")
 async def get_usage(days: int = 7):
     """Retrieve usage statistics"""
@@ -474,6 +458,7 @@ async def usage_page():
     """Serve an HTML page showing usage statistics"""
     # Retrieve usage data
     usage_data = usage_tracker.get_usage_summary()
     # Model Usage Table Rows
     model_usage_rows = "\n".join([
         f"""
@@ -485,6 +470,7 @@ async def usage_page():
         </tr>
         """ for model, model_data in usage_data['models'].items()
     ])
     # API Endpoint Usage Table Rows
     api_usage_rows = "\n".join([
         f"""
@@ -496,6 +482,7 @@ async def usage_page():
         </tr>
         """ for endpoint, endpoint_data in usage_data['api_endpoints'].items()
     ])
     # Daily Usage Table Rows
     daily_usage_rows = "\n".join([
         "\n".join([
@@ -660,39 +647,40 @@ async def usage_page():
     </html>
     """
     return HTMLResponse(content=html_content)
-@app.get("/meme")
-async def get_meme():
-    try:
-        response = requests.get("https://meme-api.com/gimme")
-        response_data = response.json()
-
-        meme_url = response_data.get("url")
-
-        if meme_url:
-            def stream_image():
-                with requests.get(meme_url, stream=True) as image_response:
-                    for chunk in image_response.iter_content(chunk_size=1024):
-                        yield chunk
-
-            return StreamingResponse(stream_image(), media_type="image/png")
-        else:
-            raise HTTPException(status_code=404, detail="No mimi found :(")
 
     except Exception as e:
-
-
 @app.on_event("startup")
 async def startup_event():
     global available_model_ids
     available_model_ids = load_model_ids("models.json")
-    print(f"Loaded model IDs
-    print("API
-    print("GET /")
-    print("GET /models")
-    print("GET /searchgpt")
-    print("POST /chat/completions")
-    print("GET /images/generations")
 
 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=8000)
+++ b/main.py
@@ -1,111 +1,47 @@
 import os
 from dotenv import load_dotenv
+from fastapi import FastAPI, HTTPException, Request, Depends, Security
+from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse, FileResponse
+from fastapi.security import APIKeyHeader
 from pydantic import BaseModel
 import httpx
 from functools import lru_cache
+from pathlib import Path
 import json
 import datetime
 import time
+from typing import Optional, Dict, List, Any, Generator
+import asyncio
 from starlette.status import HTTP_403_FORBIDDEN
+import cloudscraper
 
+# Load environment variables once at startup
+load_dotenv()
 
+# API key security scheme
+api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
 
+# Initialize usage tracker
+from usage_tracker import UsageTracker
 usage_tracker = UsageTracker()
 
 app = FastAPI()
+
+# Environment variables (cached)
+@lru_cache(maxsize=1)
+def get_env_vars():
+    return {
+        'api_keys': os.getenv('API_KEYS', '').split(','),
+        'secret_api_endpoint': os.getenv('SECRET_API_ENDPOINT'),
+        'secret_api_endpoint_2': os.getenv('SECRET_API_ENDPOINT_2'),
+        'secret_api_endpoint_3': os.getenv('SECRET_API_ENDPOINT_3'),
+        'mistral_api': "https://api.mistral.ai",
+        'mistral_key': os.getenv('MISTRAL_KEY'),
+        'image_endpoint': os.getenv("IMAGE_ENDPOINT"),
+        'endpoint_origin': os.getenv('ENDPOINT_ORIGIN')
+    }
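Note on the cached settings just above: `lru_cache(maxsize=1)` memoizes the dict on its first call, so later edits to `API_KEYS` or the endpoint variables are invisible to a running process until restart. A minimal sketch of how a re-read could be forced, assuming one is ever needed (the helper name is illustrative, not part of this commit):

    # Hypothetical helper: drop the memoized settings so the next
    # get_env_vars() call re-reads the current environment.
    def reload_env_vars():
        get_env_vars.cache_clear()  # cache_clear() is provided by functools.lru_cache
        return get_env_vars()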
+
+# Configuration for models
 mistral_models = [
     "mistral-large-latest",
     "pixtral-large-latest",
@@ -118,14 +54,6 @@ mistral_models = [
     "codestral-latest"
 ]
 
 alternate_models = {
     "gpt-4o-mini",
     "deepseek-v3",
@@ -140,19 +68,64 @@ alternate_models = {
     "hermes-3-llama-3.2-3b"
 }
 
+# Request payload model
 class Payload(BaseModel):
     model: str
     messages: list
     stream: bool = False
+
+# Server status global variable
+server_status = True
+available_model_ids: List[str] = []
+
+# Create a reusable httpx client
+@lru_cache(maxsize=1)
+def get_async_client():
+    return httpx.AsyncClient(timeout=60.0)
+
+# API key validation
+async def verify_api_key(api_key: str = Security(api_key_header)) -> bool:
+    if not api_key:
+        raise HTTPException(
+            status_code=HTTP_403_FORBIDDEN,
+            detail="No API key provided"
+        )
+
+    # Clean the API key by removing 'Bearer ' if present
+    if api_key.startswith('Bearer '):
+        api_key = api_key[7:]  # Remove 'Bearer ' prefix
+
+    # Get API keys from environment
+    valid_api_keys = get_env_vars()['api_keys']
+    if not valid_api_keys or valid_api_keys == ['']:
+        raise HTTPException(
+            status_code=HTTP_403_FORBIDDEN,
+            detail="API keys not configured on server"
+        )
+
+    # Check if the provided key is valid
+    if api_key not in valid_api_keys:
+        raise HTTPException(
+            status_code=HTTP_403_FORBIDDEN,
+            detail="Invalid API key"
+        )
+
+    return True
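Since `verify_api_key` reads the raw `Authorization` header through `APIKeyHeader` and strips an optional `Bearer ` prefix, a key can be sent either bare or as a bearer token. A minimal client sketch against a local run ("sk-example" stands in for one of the comma-separated values configured in `API_KEYS`):

    import httpx

    resp = httpx.post(
        "http://localhost:8000/chat/completions",
        headers={"Authorization": "Bearer sk-example"},  # a bare "sk-example" also passes
        json={"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "hi"}]},
        timeout=60.0,
    )
    print(resp.status_code)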
+
+# Cache for models.json
+@lru_cache(maxsize=1)
+async def get_models():
+    try:
+        file_path = Path(__file__).parent / 'models.json'
+        with open(file_path, 'r') as f:
+            return json.load(f)
+    except (FileNotFoundError, json.JSONDecodeError) as e:
+        # Log the error but don't expose the exact error to users
+        print(f"Error loading models.json: {str(e)}")
+        raise HTTPException(status_code=500, detail="Error loading available models")
+
+# Searcher function with optimized streaming
+def generate_search(query: str, systemprompt: Optional[str] = None, stream: bool = True) -> Generator[str, None, None]:
     headers = {"User-Agent": ""}
 
     # Use the provided system prompt, or default to "Be Helpful and Friendly"
@@ -173,8 +146,18 @@ def generate_search(query: str, systemprompt: Optional[str] = None, stream: bool
         "user_input": prompt[-1]["content"],
     }
 
+    # Get endpoint from environment
+    secret_api_endpoint_3 = get_env_vars()['secret_api_endpoint_3']
+    if not secret_api_endpoint_3:
+        raise ValueError("Search API endpoint not configured")
+
     # Send the request to the chat endpoint
+    response = cloudscraper.create_scraper().post(
+        secret_api_endpoint_3,
+        headers=headers,
+        json=payload,
+        stream=True
+    )
 
     streaming_text = ""
 
@@ -210,31 +193,21 @@ def generate_search(query: str, systemprompt: Optional[str] = None, stream: bool
     if not stream:
         yield streaming_text
 
+# Basic routes
+@app.get("/favicon.ico")
+async def favicon():
+    favicon_path = Path(__file__).parent / "favicon.ico"
+    return FileResponse(favicon_path, media_type="image/x-icon")
+
 @app.get("/ping")
 async def ping():
     start_time = datetime.datetime.now()
     response_time = (datetime.datetime.now() - start_time).total_seconds()
     return {"message": "pong", "response_time": f"{response_time:.6f} seconds"}
+
 @app.get("/", response_class=HTMLResponse)
 async def root():
     file_path = "index.html"
     try:
         with open(file_path, "r") as file:
             html_content = file.read()
@@ -242,29 +215,55 @@ async def root():
     except FileNotFoundError:
         return HTMLResponse(content="<h1>File not found</h1>", status_code=404)
 
+@app.get("/playground", response_class=HTMLResponse)
+async def playground():
+    file_path = "playground.html"
     try:
+        with open(file_path, "r") as file:
+            html_content = file.read()
+        return HTMLResponse(content=html_content)
     except FileNotFoundError:
+        return HTMLResponse(content="<h1>playground.html not found</h1>", status_code=404)
+
+# Model routes
+@app.get("/api/v1/models")
 @app.get("/models")
 async def return_models():
     return await get_models()
+
+# Search routes
+@app.get("/searchgpt")
+async def search_gpt(q: str, stream: Optional[bool] = False, systemprompt: Optional[str] = None):
+    if not q:
+        raise HTTPException(status_code=400, detail="Query parameter 'q' is required")
+
+    usage_tracker.record_request(endpoint="/searchgpt")
+
+    if stream:
+        return StreamingResponse(
+            generate_search(q, systemprompt=systemprompt, stream=True),
+            media_type="text/event-stream"
+        )
+    else:
+        # For non-streaming, collect the text and return as JSON response
+        response_text = "".join([chunk for chunk in generate_search(q, systemprompt=systemprompt, stream=False)])
+        return JSONResponse(content={"response": response_text})
+
+# Chat completion endpoint
 @app.post("/chat/completions")
 @app.post("/api/v1/chat/completions")
 async def get_completion(payload: Payload, request: Request, authenticated: bool = Depends(verify_api_key)):
     # Check server status
+    if not server_status:
+        return JSONResponse(
+            status_code=503,
+            content={"message": "Server is under maintenance. Please try again later."}
+        )
+
     model_to_use = payload.model if payload.model else "gpt-4o-mini"
 
     # Validate model availability
+    if available_model_ids and model_to_use not in available_model_ids:
         raise HTTPException(
             status_code=400,
             detail=f"Model '{model_to_use}' is not available. Check /models for the available model list."
@@ -276,31 +275,28 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
     payload_dict = payload.dict()
     payload_dict["model"] = model_to_use
 
+    # Get environment variables
+    env_vars = get_env_vars()
+
     # Select the appropriate endpoint
     if model_to_use in mistral_models:
+        endpoint = env_vars['mistral_api']
         custom_headers = {
+            "Authorization": f"Bearer {env_vars['mistral_key']}"
         }
     elif model_to_use in alternate_models:
+        endpoint = env_vars['secret_api_endpoint_2']
         custom_headers = {}
     else:
+        endpoint = env_vars['secret_api_endpoint']
         custom_headers = {}
 
+    # Current time and IP logging (with minimal data)
     current_time = (datetime.datetime.utcnow() + datetime.timedelta(hours=5, minutes=30)).strftime("%Y-%m-%d %I:%M:%S %p")
+    ip_hash = hash(request.client.host) % 10000  # Hash the IP for privacy
+    print(f"Time: {current_time}, IP Hash: {ip_hash}, Model: {model_to_use}")
 
+    # Create scraper for each connection to avoid concurrency issues
     scraper = cloudscraper.create_scraper()
 
     async def stream_generator(payload_dict):
@@ -314,55 +310,52 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
            )
 
            # Handle response errors
+           if response.status_code >= 400:
+               error_messages = {
+                   422: "Unprocessable entity. Check your payload.",
+                   400: "Bad request. Verify input data.",
+                   403: "Forbidden. You do not have access to this resource.",
+                   404: "The requested resource was not found.",
+               }
+               detail = error_messages.get(response.status_code, f"Error code: {response.status_code}")
+               raise HTTPException(status_code=response.status_code, detail=detail)
 
            # Stream response lines to the client
            for line in response.iter_lines():
                if line:
                    yield line.decode('utf-8') + "\n"
 
        except Exception as e:
+            # Use a generic error message that doesn't expose internal details
+            raise HTTPException(status_code=500, detail="An error occurred while processing your request")
 
    return StreamingResponse(stream_generator(payload_dict), media_type="application/json")
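The completions route always replies with a `StreamingResponse`, so callers read it line by line even when `stream` is false in the payload. A consumption sketch with httpx (URL and key are placeholders):

    import httpx

    with httpx.stream(
        "POST",
        "http://localhost:8000/chat/completions",
        headers={"Authorization": "Bearer sk-example"},
        json={"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "hi"}], "stream": True},
        timeout=60.0,
    ) as resp:
        for line in resp.iter_lines():  # the server yields newline-delimited chunks
            if line:
                print(line)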
+
+# Image generation endpoint
+@app.api_route("/images/generations", methods=["GET", "POST"])
 async def generate_image(
     prompt: Optional[str] = None,
+    model: str = "flux",
     seed: Optional[int] = None,
     width: Optional[int] = None,
     height: Optional[int] = None,
     nologo: Optional[bool] = True,
     private: Optional[bool] = None,
     enhance: Optional[bool] = None,
+    request: Request = None,
     authenticated: bool = Depends(verify_api_key)
 ):
     # Validate the image endpoint
+    image_endpoint = get_env_vars()['image_endpoint']
     if not image_endpoint:
         raise HTTPException(status_code=500, detail="Image endpoint not configured in environment variables.")
+
     usage_tracker.record_request(endpoint="/images/generations")
+
     # Handle GET and POST prompts
     if request.method == "POST":
         try:
+            body = await request.json()
             prompt = body.get("prompt", "").strip()
             if not prompt:
                 raise HTTPException(status_code=400, detail="Prompt cannot be empty")
@@ -377,12 +370,11 @@ async def generate_image(
     encoded_prompt = httpx.QueryParams({'prompt': prompt}).get('prompt')
 
     # Construct the URL with the encoded prompt
+    base_url = image_endpoint.rstrip('/')
     url = f"{base_url}/{encoded_prompt}"
 
     # Prepare query parameters with validation
     params = {}
     if model and isinstance(model, str):
         params['model'] = model
     if seed is not None and isinstance(seed, int):
@@ -399,29 +391,25 @@ async def generate_image(
         params['enhance'] = str(enhance).lower()
 
     try:
+        async with httpx.AsyncClient(timeout=60.0) as client:
             response = await client.get(url, params=params, follow_redirects=True)
 
             # Check for various error conditions
+            if response.status_code != 200:
+                error_messages = {
+                    404: "Image generation service not found",
+                    400: "Invalid parameters provided to image service",
+                    429: "Too many requests to image service",
+                }
+                detail = error_messages.get(response.status_code, f"Image generation failed with status code {response.status_code}")
+                raise HTTPException(status_code=response.status_code, detail=detail)
 
             # Verify content type
             content_type = response.headers.get('content-type', '')
             if not content_type.startswith('image/'):
                 raise HTTPException(
                     status_code=500,
+                    detail="Unexpected content type received from image service"
                 )
 
             return StreamingResponse(
@@ -435,35 +423,31 @@ async def generate_image(
 
     except httpx.TimeoutException:
         raise HTTPException(status_code=504, detail="Image generation request timed out")
+    except httpx.RequestError:
+        raise HTTPException(status_code=500, detail="Failed to contact image service")
+    except Exception:
+        raise HTTPException(status_code=500, detail="Unexpected error during image generation")
 
+# Meme endpoint
+@app.get("/meme")
+async def get_meme():
     try:
+        async with httpx.AsyncClient() as client:
+            response = await client.get("https://meme-api.com/gimme")
+            response_data = response.json()
+
+            meme_url = response_data.get("url")
+            if not meme_url:
+                raise HTTPException(status_code=404, detail="No meme found")
+
+            image_response = await client.get(meme_url, follow_redirects=True)
+            return StreamingResponse(
+                image_response.iter_bytes(),
+                media_type=image_response.headers.get("content-type", "image/png")
+            )
+    except Exception:
+        raise HTTPException(status_code=500, detail="Failed to retrieve meme")
+
 @app.get("/usage")
 async def get_usage(days: int = 7):
     """Retrieve usage statistics"""
@@ -474,6 +458,7 @@ async def usage_page():
     """Serve an HTML page showing usage statistics"""
     # Retrieve usage data
     usage_data = usage_tracker.get_usage_summary()
+
     # Model Usage Table Rows
     model_usage_rows = "\n".join([
         f"""
@@ -485,6 +470,7 @@ async def usage_page():
         </tr>
         """ for model, model_data in usage_data['models'].items()
     ])
+
     # API Endpoint Usage Table Rows
     api_usage_rows = "\n".join([
         f"""
@@ -496,6 +482,7 @@ async def usage_page():
         </tr>
         """ for endpoint, endpoint_data in usage_data['api_endpoints'].items()
     ])
+
     # Daily Usage Table Rows
     daily_usage_rows = "\n".join([
         "\n".join([
@@ -660,39 +647,40 @@ async def usage_page():
     </html>
     """
     return HTMLResponse(content=html_content)
 
+# Utility function for loading model IDs
+def load_model_ids(json_file_path):
+    try:
+        with open(json_file_path, 'r') as f:
+            models_data = json.load(f)
+        # Extract 'id' from each model object
+        model_ids = [model['id'] for model in models_data if 'id' in model]
+        return model_ids
     except Exception as e:
+        print(f"Error loading model IDs: {str(e)}")
+        return []
+
 @app.on_event("startup")
 async def startup_event():
     global available_model_ids
     available_model_ids = load_model_ids("models.json")
+    print(f"Loaded {len(available_model_ids)} model IDs")
+    print("API started successfully")
 
+    # Validate critical environment variables
+    env_vars = get_env_vars()
+    missing_vars = []
+
+    if not env_vars['secret_api_endpoint']:
+        missing_vars.append('SECRET_API_ENDPOINT')
+    if not env_vars['secret_api_endpoint_2']:
+        missing_vars.append('SECRET_API_ENDPOINT_2')
+    if not env_vars['secret_api_endpoint_3']:
+        missing_vars.append('SECRET_API_ENDPOINT_3')
+
+    if missing_vars:
+        print(f"WARNING: The following required environment variables are missing: {', '.join(missing_vars)}")
+
 if __name__ == "__main__":
     import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)