Update main.py
main.py CHANGED
@@ -14,6 +14,16 @@ from typing import Optional, Dict, List, Any, Generator
 import asyncio
 from starlette.status import HTTP_403_FORBIDDEN
 import cloudscraper
+from concurrent.futures import ThreadPoolExecutor
+import uvloop
+from fastapi.middleware.gzip import GZipMiddleware
+from starlette.middleware.cors import CORSMiddleware
+
+# Enable uvloop for faster event loop
+asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
+
+# Thread pool for CPU-bound operations
+executor = ThreadPoolExecutor(max_workers=8)

 # Load environment variables once at startup
 load_dotenv()
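The new imports feed two changes used later in the file: uvloop replaces the default asyncio event loop, and the ThreadPoolExecutor is the pool that loop.run_in_executor hands blocking work to so the event loop stays responsive. A minimal sketch of that pattern, with blocking_fetch as a hypothetical stand-in for a synchronous call such as a cloudscraper request:

    import asyncio
    from concurrent.futures import ThreadPoolExecutor

    executor = ThreadPoolExecutor(max_workers=8)

    def blocking_fetch() -> str:
        # placeholder for slow, synchronous work (network call, file read, ...)
        return "result"

    async def handler() -> str:
        loop = asyncio.get_running_loop()
        # the blocking call runs on a worker thread; the event loop keeps serving
        return await loop.run_in_executor(executor, blocking_fetch)

    asyncio.run(handler())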
@@ -27,6 +37,16 @@ usage_tracker = UsageTracker()

 app = FastAPI()

+# Add middleware for compression and CORS
+app.add_middleware(GZipMiddleware, minimum_size=1000)
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
 # Environment variables (cached)
 @lru_cache(maxsize=1)
 def get_env_vars():
@@ -41,8 +61,8 @@ def get_env_vars():
         'endpoint_origin': os.getenv('ENDPOINT_ORIGIN')
     }

-# Configuration for models
-mistral_models = [
+# Configuration for models - use sets for faster lookups
+mistral_models = {
     "mistral-large-latest",
     "pixtral-large-latest",
     "mistral-moderation-latest",
@@ -52,7 +72,7 @@ mistral_models = [
     "mistral-small-latest",
     "mistral-saba-latest",
     "codestral-latest"
-]
+}

 alternate_models = {
     "gpt-4o-mini",
@@ -78,23 +98,37 @@ class Payload(BaseModel):
 server_status = True
 available_model_ids: List[str] = []

-# Create a reusable httpx client
+# Create a reusable httpx client pool with connection pooling
 @lru_cache(maxsize=1)
 def get_async_client():
-    return httpx.AsyncClient(
+    return httpx.AsyncClient(
+        timeout=60.0,
+        limits=httpx.Limits(max_keepalive_connections=20, max_connections=100)
+    )
+
+# Create a cloudscraper pool
+scraper_pool = []
+MAX_SCRAPERS = 10

+def get_scraper():
+    if not scraper_pool:
+        for _ in range(MAX_SCRAPERS):
+            scraper_pool.append(cloudscraper.create_scraper())
+
+    return scraper_pool[int(time.time() * 1000) % MAX_SCRAPERS]  # Simple round-robin
+
+# API key validation - optimized to avoid string operations when possible
 async def verify_api_key(api_key: str = Security(api_key_header)) -> bool:
     if not api_key:
         raise HTTPException(
             status_code=HTTP_403_FORBIDDEN,
             detail="No API key provided"
         )
+
+    # Only clean if needed
     if api_key.startswith('Bearer '):
         api_key = api_key[7:]  # Remove 'Bearer ' prefix
+
     # Get API keys from environment
     valid_api_keys = get_env_vars()['api_keys']
     if not valid_api_keys or valid_api_keys == ['']:
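get_scraper() picks a pool slot from the current time in milliseconds modulo MAX_SCRAPERS, so two requests arriving in the same millisecond share an instance; it approximates round-robin rather than guaranteeing it. A deterministic alternative, sketched under the assumption that cloudscraper instances may be reused across requests (the same assumption the pool above makes):

    import itertools
    import cloudscraper

    MAX_SCRAPERS = 10
    _scrapers = [cloudscraper.create_scraper() for _ in range(MAX_SCRAPERS)]
    _next_scraper = itertools.cycle(_scrapers)

    def get_scraper():
        # next() walks the pool in order and wraps around at the end
        return next(_next_scraper)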
@@ -102,96 +136,120 @@ async def verify_api_key(api_key: str = Security(api_key_header)) -> bool:
             status_code=HTTP_403_FORBIDDEN,
             detail="API keys not configured on server"
         )
-    if api_key not in valid_api_keys:
+
+    # Fast check with set operation
+    if api_key not in set(valid_api_keys):
         raise HTTPException(
             status_code=HTTP_403_FORBIDDEN,
             detail="Invalid API key"
         )
+
     return True

+# Pre-load and cache models.json
 @lru_cache(maxsize=1)
+def load_models_data():
     try:
         file_path = Path(__file__).parent / 'models.json'
         with open(file_path, 'r') as f:
             return json.load(f)
     except (FileNotFoundError, json.JSONDecodeError) as e:
-        # Log the error but don't expose the exact error to users
         print(f"Error loading models.json: {str(e)}")
+        return []
+
+# Async wrapper for models data
+async def get_models():
+    models_data = load_models_data()
+    if not models_data:
         raise HTTPException(status_code=500, detail="Error loading available models")
+    return models_data

-# Searcher function with optimized streaming
-    # Use the provided system prompt, or default to "Be Helpful and Friendly"
-    system_message = systemprompt or "Be Helpful and Friendly"
-    # Create the prompt history with the user query and system message
-    prompt = [
-        {"role": "user", "content": query},
-    ]
-    prompt.insert(0, {"content": system_message, "role": "system"})
-    # Prepare the payload for the API request
-    payload = {
-        "is_vscode_extension": True,
-        "message_history": prompt,
-        "requested_model": "searchgpt",
-        "user_input": prompt[-1]["content"],
-    }
+# Searcher function with optimized streaming - moved to a separate thread
+async def generate_search_async(query: str, systemprompt: Optional[str] = None, stream: bool = True):
+    loop = asyncio.get_running_loop()

+    def _generate_search():
+        headers = {"User-Agent": ""}
+
+        # Use the provided system prompt, or default to "Be Helpful and Friendly"
+        system_message = systemprompt or "Be Helpful and Friendly"
+
+        # Create the prompt history with the user query and system message
+        prompt = [
+            {"role": "user", "content": query},
+        ]
+
+        prompt.insert(0, {"content": system_message, "role": "system"})
+
+        # Prepare the payload for the API request
+        payload = {
+            "is_vscode_extension": True,
+            "message_history": prompt,
+            "requested_model": "searchgpt",
+            "user_input": prompt[-1]["content"],
+        }
+
+        # Get endpoint from environment
+        secret_api_endpoint_3 = get_env_vars()['secret_api_endpoint_3']
+        if not secret_api_endpoint_3:
+            raise ValueError("Search API endpoint not configured")
+
+        # Send the request to the chat endpoint using a scraper from the pool
+        response = get_scraper().post(
+            secret_api_endpoint_3,
+            headers=headers,
+            json=payload,
+            stream=True
+        )
+
+        result = []
+        streaming_text = ""
+
+        # Process the streaming response
+        for value in response.iter_lines(decode_unicode=True):
+            if value.startswith("data: "):
+                try:
+                    json_modified_value = json.loads(value[6:])
+                    content = json_modified_value.get("choices", [{}])[0].get("delta", {}).get("content", "")
+
+                    if content.strip():  # Only process non-empty content
+                        cleaned_response = {
+                            "created": json_modified_value.get("created"),
+                            "id": json_modified_value.get("id"),
+                            "model": "searchgpt",
+                            "object": "chat.completion",
+                            "choices": [
+                                {
+                                    "message": {
+                                        "content": content
+                                    }
+                                }
+                            ]
+                        }
+
+                        if stream:
+                            result.append(f"data: {json.dumps(cleaned_response)}\n\n")
+
+                        streaming_text += content
+                except json.JSONDecodeError:
+                    continue
+
+        if not stream:
+            result.append(streaming_text)
+
+        return result

+    # Run in thread pool to avoid blocking the event loop
+    return await loop.run_in_executor(executor, _generate_search)
+
+# Cache for frequently accessed static files
+@lru_cache(maxsize=10)
+def read_html_file(file_path):
+    try:
+        with open(file_path, "r") as file:
+            return file.read()
+    except FileNotFoundError:
+        return None

 # Basic routes
 @app.get("/favicon.ico")
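_generate_search runs to completion inside the worker thread and returns a list, so every "data:" chunk is available only after the upstream response has finished; the StreamingResponse built from that list replays a buffered result rather than relaying tokens as they arrive. A rough sketch of live relaying through an asyncio.Queue, with blocking_chunks as a hypothetical callable yielding the same SSE strings (not part of this commit):

    import asyncio
    from concurrent.futures import ThreadPoolExecutor

    executor = ThreadPoolExecutor(max_workers=8)
    _DONE = object()

    async def relay(blocking_chunks):
        loop = asyncio.get_running_loop()
        queue: asyncio.Queue = asyncio.Queue()

        def producer():
            try:
                for chunk in blocking_chunks():
                    loop.call_soon_threadsafe(queue.put_nowait, chunk)
            finally:
                loop.call_soon_threadsafe(queue.put_nowait, _DONE)

        loop.run_in_executor(executor, producer)
        while True:
            item = await queue.get()
            if item is _DONE:
                break
            yield item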
@@ -201,29 +259,21 @@ async def favicon():

 @app.get("/ping")
 async def ping():
-    response_time = (datetime.datetime.now() - start_time).total_seconds()
-    return {"message": "pong", "response_time": f"{response_time:.6f} seconds"}
+    return {"message": "pong", "response_time": "0.000000 seconds"}

 @app.get("/", response_class=HTMLResponse)
 async def root():
-        with open(file_path, "r") as file:
-            html_content = file.read()
-        return HTMLResponse(content=html_content)
-    except FileNotFoundError:
+    html_content = read_html_file("index.html")
+    if html_content is None:
         return HTMLResponse(content="<h1>File not found</h1>", status_code=404)
+    return HTMLResponse(content=html_content)

 @app.get("/playground", response_class=HTMLResponse)
 async def playground():
-        with open(file_path, "r") as file:
-            html_content = file.read()
-        return HTMLResponse(content=html_content)
-    except FileNotFoundError:
+    html_content = read_html_file("playground.html")
+    if html_content is None:
         return HTMLResponse(content="<h1>playground.html not found</h1>", status_code=404)
+    return HTMLResponse(content=html_content)

 # Model routes
 @app.get("/api/v1/models")
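read_html_file is wrapped in lru_cache, so index.html and playground.html are read from disk once per process and later edits to those files are not served until the process restarts or the cache is dropped. functools exposes helpers on the wrapped function for exactly that:

    # read_html_file is the lru_cache-wrapped helper defined above
    read_html_file.cache_clear()        # forget cached pages; next request re-reads the file
    print(read_html_file.cache_info())  # hits, misses and current size, handy when checking the cache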
@@ -239,15 +289,20 @@ async def search_gpt(q: str, stream: Optional[bool] = False, systemprompt: Optio

     usage_tracker.record_request(endpoint="/searchgpt")

+    result = await generate_search_async(q, systemprompt=systemprompt, stream=stream)
+
     if stream:
+        async def stream_generator():
+            for chunk in result:
+                yield chunk
+
         return StreamingResponse(
+            stream_generator(),
             media_type="text/event-stream"
         )
     else:
-        # For non-streaming,
-        return JSONResponse(content={"response": response_text})
+        # For non-streaming, return the collected text
+        return JSONResponse(content={"response": result[0] if result else ""})

 # Chat completion endpoint
 @app.post("/chat/completions")
@@ -260,15 +315,17 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
             content={"message": "Server is under maintenance. Please try again later."}
         )

-    model_to_use = payload.model
+    model_to_use = payload.model or "gpt-4o-mini"

-    # Validate model availability
-    if available_model_ids and model_to_use not in available_model_ids:
+    # Validate model availability - fast lookup with set
+    if available_model_ids and model_to_use not in set(available_model_ids):
         raise HTTPException(
             status_code=400,
             detail=f"Model '{model_to_use}' is not available. Check /models for the available model list."
         )

+    # Log request without blocking
+    asyncio.create_task(log_request(request, model_to_use))
     usage_tracker.record_request(model=model_to_use, endpoint="/chat/completions")

     # Prepare payload
@@ -278,7 +335,7 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
     # Get environment variables
     env_vars = get_env_vars()

-    # Select the appropriate endpoint
+    # Select the appropriate endpoint (fast lookup with sets)
     if model_to_use in mistral_models:
         endpoint = env_vars['mistral_api']
         custom_headers = {
@@ -291,13 +348,8 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
         endpoint = env_vars['secret_api_endpoint']
         custom_headers = {}

-    ip_hash = hash(request.client.host) % 10000  # Hash the IP for privacy
-    print(f"Time: {current_time}, IP Hash: {ip_hash}, Model: {model_to_use}")
-
-    # Create scraper for each connection to avoid concurrency issues
-    scraper = cloudscraper.create_scraper()
+    # Get a scraper from the pool
+    scraper = get_scraper()

     async def stream_generator(payload_dict):
         try:
@@ -320,10 +372,25 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
                 detail = error_messages.get(response.status_code, f"Error code: {response.status_code}")
                 raise HTTPException(status_code=response.status_code, detail=detail)

-            # Stream response lines to the client
+            # Stream response lines to the client - use buffer for efficiency
+            buffer = []
+            buffer_size = 0
+            max_buffer = 8192  # 8KB buffer
+
             for line in response.iter_lines():
                 if line:
+                    decoded = line.decode('utf-8') + "\n"
+                    buffer.append(decoded)
+                    buffer_size += len(decoded)
+
+                    if buffer_size >= max_buffer:
+                        yield ''.join(buffer)
+                        buffer = []
+                        buffer_size = 0
+
+            # Flush remaining buffer
+            if buffer:
+                yield ''.join(buffer)

         except Exception as e:
             # Use a generic error message that doesn't expose internal details
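The 8 KB buffer batches many small upstream lines into fewer, larger writes, which lowers per-chunk overhead but delays delivery of the first bytes to the client. The batching itself is independent of FastAPI; the same technique as a standalone helper:

    from typing import Iterable, Iterator

    def batch_lines(lines: Iterable[str], max_buffer: int = 8192) -> Iterator[str]:
        """Join incoming lines until roughly max_buffer characters, then emit one chunk."""
        buffer, size = [], 0
        for line in lines:
            buffer.append(line)
            size += len(line)
            if size >= max_buffer:
                yield ''.join(buffer)
                buffer, size = [], 0
        if buffer:
            yield ''.join(buffer)

    # three tiny lines come out as a single chunk
    assert list(batch_lines(["a\n", "b\n", "c\n"])) == ["a\nb\nc\n"]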
@@ -331,7 +398,14 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool

     return StreamingResponse(stream_generator(payload_dict), media_type="application/json")

+# Asynchronous logging function
+async def log_request(request, model):
+    # Get minimal data for logging
+    current_time = (datetime.datetime.utcnow() + datetime.timedelta(hours=5, minutes=30)).strftime("%Y-%m-%d %I:%M:%S %p")
+    ip_hash = hash(request.client.host) % 10000  # Hash the IP for privacy
+    print(f"Time: {current_time}, IP Hash: {ip_hash}, Model: {model}")
+
+# Image generation endpoint - optimized to use connection pool
 @app.api_route("/images/generations", methods=["GET", "POST"])
 async def generate_image(
     prompt: Optional[str] = None,
@@ -391,35 +465,53 @@ async def generate_image(
         params['enhance'] = str(enhance).lower()

     try:
-            response.iter_bytes(),
-            media_type=content_type,
-            headers={
-                'Cache-Control': 'no-cache',
-                'Pragma': 'no-cache'
-            }
-        )
+        # Use the shared httpx client for connection pooling
+        client = get_async_client()
+        response = await client.get(url, params=params, follow_redirects=True)
+
+        # Check for various error conditions
+        if response.status_code != 200:
+            error_messages = {
+                404: "Image generation service not found",
+                400: "Invalid parameters provided to image service",
+                429: "Too many requests to image service",
+            }
+            detail = error_messages.get(response.status_code, f"Image generation failed with status code {response.status_code}")
+            raise HTTPException(status_code=response.status_code, detail=detail)
+
+        # Verify content type
+        content_type = response.headers.get('content-type', '')
+        if not content_type.startswith('image/'):
+            raise HTTPException(
+                status_code=500,
+                detail="Unexpected content type received from image service"
+            )

+        # Use larger chunks for streaming for better performance
+        async def stream_with_larger_chunks():
+            chunks = []
+            size = 0
+            async for chunk in response.aiter_bytes(chunk_size=16384):  # Use 16KB chunks
+                chunks.append(chunk)
+                size += len(chunk)
+
+                if size >= 65536:  # Yield every 64KB
+                    yield b''.join(chunks)
+                    chunks = []
+                    size = 0
+
+            if chunks:
+                yield b''.join(chunks)

+        return StreamingResponse(
+            stream_with_larger_chunks(),
+            media_type=content_type,
+            headers={
+                'Cache-Control': 'no-cache, no-store, must-revalidate',
+                'Pragma': 'no-cache',
+                'Expires': '0'
+            }
+        )

     except httpx.TimeoutException:
         raise HTTPException(status_code=504, detail="Image generation request timed out")
@@ -428,37 +520,57 @@ async def generate_image(
     except Exception:
         raise HTTPException(status_code=500, detail="Unexpected error during image generation")

+# Meme endpoint with optimized networking
 @app.get("/meme")
 async def get_meme():
     try:
+        # Use the shared client for connection pooling
+        client = get_async_client()
+        response = await client.get("https://meme-api.com/gimme")
+        response_data = response.json()
+
+        meme_url = response_data.get("url")
+        if not meme_url:
+            raise HTTPException(status_code=404, detail="No meme found")
+
+        image_response = await client.get(meme_url, follow_redirects=True)
+
+        # Use larger chunks for streaming
+        async def stream_with_larger_chunks():
+            chunks = []
+            size = 0
+            async for chunk in image_response.aiter_bytes(chunk_size=16384):
+                chunks.append(chunk)
+                size += len(chunk)
+
+                if size >= 65536:
+                    yield b''.join(chunks)
+                    chunks = []
+                    size = 0
+
+            if chunks:
+                yield b''.join(chunks)
+
+        return StreamingResponse(
+            stream_with_larger_chunks(),
+            media_type=image_response.headers.get("content-type", "image/png"),
+            headers={'Cache-Control': 'max-age=3600'}  # Add caching
+        )
     except Exception:
         raise HTTPException(status_code=500, detail="Failed to retrieve meme")

+# Cache usage statistics
+@lru_cache(maxsize=10)
+def get_usage_summary(days=7):
+    return usage_tracker.get_usage_summary(days)
+
 @app.get("/usage")
 async def get_usage(days: int = 7):
     """Retrieve usage statistics"""
+    return get_usage_summary(days)

-    """Serve an HTML page showing usage statistics"""
-    # Retrieve usage data
-    usage_data = usage_tracker.get_usage_summary()
+# Generate HTML for usage page
+def generate_usage_html(usage_data):
     # Model Usage Table Rows
     model_usage_rows = "\n".join([
         f"""
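get_usage_summary and, below, get_usage_page_html are lru_cache-wrapped, so /usage and /usage/page keep serving the numbers computed on their first call for the lifetime of the process. If fresher data is wanted without clearing the cache by hand, one common workaround is a time-bucketed key; a sketch, assuming the usage_tracker object defined earlier in this file:

    import time
    from functools import lru_cache

    @lru_cache(maxsize=32)
    def _summary_for_bucket(days: int, bucket: int):
        # a new bucket value every 60 seconds means stale entries stop being hit after a minute
        return usage_tracker.get_usage_summary(days)

    def get_usage_summary(days: int = 7):
        return _summary_for_bucket(days, int(time.time() // 60))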
@@ -646,16 +758,28 @@ async def usage_page():
     </body>
     </html>
     """
+    return html_content
+
+# Cache the usage page HTML
+@lru_cache(maxsize=1)
+def get_usage_page_html():
+    usage_data = get_usage_summary()
+    return generate_usage_html(usage_data)
+
+@app.get("/usage/page", response_class=HTMLResponse)
+async def usage_page():
+    """Serve an HTML page showing usage statistics"""
+    # Use cached HTML if available, regenerate if not
+    html_content = get_usage_page_html()
     return HTMLResponse(content=html_content)

-# Utility function for loading model IDs
+# Utility function for loading model IDs - optimized to run once at startup
 def load_model_ids(json_file_path):
     try:
         with open(json_file_path, 'r') as f:
             models_data = json.load(f)
-            # Extract 'id' from each model object
-            return model_ids
+            # Extract 'id' from each model object and use a set for fast lookups
+            return [model['id'] for model in models_data if 'id' in model]
     except Exception as e:
         print(f"Error loading model IDs: {str(e)}")
         return []
@@ -665,7 +789,10 @@ async def startup_event():
     global available_model_ids
     available_model_ids = load_model_ids("models.json")
     print(f"Loaded {len(available_model_ids)} model IDs")
+
+    # Preload scrapers
+    for _ in range(MAX_SCRAPERS):
+        scraper_pool.append(cloudscraper.create_scraper())

     # Validate critical environment variables
     env_vars = get_env_vars()
@@ -680,7 +807,19 @@ async def startup_event():

     if missing_vars:
         print(f"WARNING: The following required environment variables are missing: {', '.join(missing_vars)}")
+
+    print("API started successfully with high-performance optimizations")

 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run(
+    uvicorn.run(
+        app,
+        host="0.0.0.0",
+        port=8000,
+        workers=4,  # Multiple workers for better CPU utilization
+        loop="uvloop",  # Use uvloop for faster async operations
+        http="httptools",  # Faster HTTP parsing
+        log_level="warning",  # Reduce logging overhead
+        limit_concurrency=100,  # Limit concurrent connections
+        timeout_keep_alive=5  # Reduce idle connection time
+    )
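One caveat in the uvicorn.run call: uvicorn only spawns multiple workers when the application is given as an import string, so passing the app object together with workers=4 triggers uvicorn's warning about needing an import string rather than forking four processes. A sketch of the import-string form, assuming the file is named main.py and the FastAPI instance is called app:

    import uvicorn

    if __name__ == "__main__":
        uvicorn.run(
            "main:app",            # import string enables the workers option
            host="0.0.0.0",
            port=8000,
            workers=4,
            loop="uvloop",
            http="httptools",
            log_level="warning",
            limit_concurrency=100,
            timeout_keep_alive=5,
        )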