ParthSadaria committed (verified)
Commit 396b35b · Parent: c3d5a54

Update main.py

Files changed (1): main.py (+307, -168)

main.py CHANGED
@@ -14,6 +14,16 @@ from typing import Optional, Dict, List, Any, Generator
 import asyncio
 from starlette.status import HTTP_403_FORBIDDEN
 import cloudscraper
+from concurrent.futures import ThreadPoolExecutor
+import uvloop
+from fastapi.middleware.gzip import GZipMiddleware
+from starlette.middleware.cors import CORSMiddleware
+
+# Enable uvloop for faster event loop
+asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
+
+# Thread pool for CPU-bound operations
+executor = ThreadPoolExecutor(max_workers=8)
 
 # Load environment variables once at startup
 load_dotenv()
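The new imports above wire in uvloop as the event-loop implementation and create a module-level thread pool. For reference, a minimal sketch of the pattern this enables, offloading blocking work so the loop stays responsive (assumes uvloop is installed; `blocking_io` is a hypothetical stand-in for any blocking call):

```python
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor

import uvloop

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
executor = ThreadPoolExecutor(max_workers=8)

def blocking_io() -> str:
    time.sleep(1)  # stand-in for a blocking HTTP call
    return "done"

async def main() -> None:
    loop = asyncio.get_running_loop()
    # run_in_executor returns an awaitable future; the loop keeps
    # serving other tasks while the thread does the blocking work
    result = await loop.run_in_executor(executor, blocking_io)
    print(result)

asyncio.run(main())
```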
@@ -27,6 +37,16 @@ usage_tracker = UsageTracker()
 
 app = FastAPI()
 
+# Add middleware for compression and CORS
+app.add_middleware(GZipMiddleware, minimum_size=1000)
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
 # Environment variables (cached)
 @lru_cache(maxsize=1)
 def get_env_vars():
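The middleware block gzips responses above 1000 bytes and opens CORS to every origin. A quick client-side check of both, assuming the server is running on localhost:8000:

```python
import httpx

r = httpx.get(
    "http://localhost:8000/",
    headers={"Accept-Encoding": "gzip", "Origin": "https://example.com"},
)
# GZipMiddleware only compresses bodies >= minimum_size (1000 bytes here)
print(r.headers.get("content-encoding"))
# CORSMiddleware adds this header for cross-origin requests
print(r.headers.get("access-control-allow-origin"))
```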
@@ -41,8 +61,8 @@ def get_env_vars():
         'endpoint_origin': os.getenv('ENDPOINT_ORIGIN')
     }
 
-# Configuration for models
-mistral_models = [
+# Configuration for models - use sets for faster lookups
+mistral_models = {
     "mistral-large-latest",
     "pixtral-large-latest",
     "mistral-moderation-latest",
@@ -52,7 +72,7 @@ mistral_models = [
     "mistral-small-latest",
     "mistral-saba-latest",
     "codestral-latest"
-]
+}
 
 alternate_models = {
     "gpt-4o-mini",
@@ -78,23 +98,37 @@ class Payload(BaseModel):
 server_status = True
 available_model_ids: List[str] = []
 
-# Create a reusable httpx client
+# Create a reusable httpx client pool with connection pooling
 @lru_cache(maxsize=1)
 def get_async_client():
-    return httpx.AsyncClient(timeout=60.0)
+    return httpx.AsyncClient(
+        timeout=60.0,
+        limits=httpx.Limits(max_keepalive_connections=20, max_connections=100)
+    )
+
+# Create a cloudscraper pool
+scraper_pool = []
+MAX_SCRAPERS = 10
 
-# API key validation
+def get_scraper():
+    if not scraper_pool:
+        for _ in range(MAX_SCRAPERS):
+            scraper_pool.append(cloudscraper.create_scraper())
+
+    return scraper_pool[int(time.time() * 1000) % MAX_SCRAPERS]  # Simple round-robin
+
+# API key validation - optimized to avoid string operations when possible
 async def verify_api_key(api_key: str = Security(api_key_header)) -> bool:
     if not api_key:
         raise HTTPException(
             status_code=HTTP_403_FORBIDDEN,
             detail="No API key provided"
         )
-
-    # Clean the API key by removing 'Bearer ' if present
+
+    # Only clean if needed
     if api_key.startswith('Bearer '):
         api_key = api_key[7:]  # Remove 'Bearer ' prefix
-
+
     # Get API keys from environment
     valid_api_keys = get_env_vars()['api_keys']
     if not valid_api_keys or valid_api_keys == ['']:
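Because the factory is wrapped in `lru_cache`, every caller shares one `AsyncClient`, so keep-alive connections are reused rather than re-established per request. A sketch of the reuse pattern (the URL is illustrative):

```python
import asyncio
from functools import lru_cache

import httpx

@lru_cache(maxsize=1)
def get_async_client() -> httpx.AsyncClient:
    return httpx.AsyncClient(
        timeout=60.0,
        limits=httpx.Limits(max_keepalive_connections=20, max_connections=100),
    )

async def fetch_many(urls: list[str]) -> list[int]:
    client = get_async_client()  # same instance on every call
    responses = await asyncio.gather(*(client.get(u) for u in urls))
    return [r.status_code for r in responses]

# asyncio.run(fetch_many(["https://example.com"] * 5))  # URL illustrative
```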
@@ -102,96 +136,120 @@ async def verify_api_key(api_key: str = Security(api_key_header)) -> bool:
             status_code=HTTP_403_FORBIDDEN,
             detail="API keys not configured on server"
         )
-
-    # Check if the provided key is valid
-    if api_key not in valid_api_keys:
+
+    # Fast check with set operation
+    if api_key not in set(valid_api_keys):
         raise HTTPException(
             status_code=HTTP_403_FORBIDDEN,
             detail="Invalid API key"
         )
-
+
     return True
 
-# Cache for models.json
+# Pre-load and cache models.json
 @lru_cache(maxsize=1)
-async def get_models():
+def load_models_data():
     try:
         file_path = Path(__file__).parent / 'models.json'
         with open(file_path, 'r') as f:
             return json.load(f)
     except (FileNotFoundError, json.JSONDecodeError) as e:
-        # Log the error but don't expose the exact error to users
         print(f"Error loading models.json: {str(e)}")
+        return []
+
+# Async wrapper for models data
+async def get_models():
+    models_data = load_models_data()
+    if not models_data:
         raise HTTPException(status_code=500, detail="Error loading available models")
+    return models_data
 
-# Searcher function with optimized streaming
-def generate_search(query: str, systemprompt: Optional[str] = None, stream: bool = True) -> Generator[str, None, None]:
-    headers = {"User-Agent": ""}
-
-    # Use the provided system prompt, or default to "Be Helpful and Friendly"
-    system_message = systemprompt or "Be Helpful and Friendly"
-
-    # Create the prompt history with the user query and system message
-    prompt = [
-        {"role": "user", "content": query},
-    ]
-
-    prompt.insert(0, {"content": system_message, "role": "system"})
-
-    # Prepare the payload for the API request
-    payload = {
-        "is_vscode_extension": True,
-        "message_history": prompt,
-        "requested_model": "searchgpt",
-        "user_input": prompt[-1]["content"],
-    }
+# Searcher function with optimized streaming - moved to a separate thread
+async def generate_search_async(query: str, systemprompt: Optional[str] = None, stream: bool = True):
+    loop = asyncio.get_running_loop()
 
-    # Get endpoint from environment
-    secret_api_endpoint_3 = get_env_vars()['secret_api_endpoint_3']
-    if not secret_api_endpoint_3:
-        raise ValueError("Search API endpoint not configured")
-
-    # Send the request to the chat endpoint
-    response = cloudscraper.create_scraper().post(
-        secret_api_endpoint_3,
-        headers=headers,
-        json=payload,
-        stream=True
-    )
-
-    streaming_text = ""
-
-    # Process the streaming response
-    for value in response.iter_lines(decode_unicode=True):
-        if value.startswith("data: "):
-            try:
-                json_modified_value = json.loads(value[6:])
-                content = json_modified_value.get("choices", [{}])[0].get("delta", {}).get("content", "")
-
-                if content.strip():  # Only process non-empty content
-                    cleaned_response = {
-                        "created": json_modified_value.get("created"),
-                        "id": json_modified_value.get("id"),
-                        "model": "searchgpt",
-                        "object": "chat.completion",
-                        "choices": [
-                            {
-                                "message": {
-                                    "content": content
-                                }
-                            }
-                        ]
-                    }
-
-                    if stream:
-                        yield f"data: {json.dumps(cleaned_response)}\n\n"
-
-                    streaming_text += content
-            except json.JSONDecodeError:
-                continue
-
-    if not stream:
-        yield streaming_text
+    def _generate_search():
+        headers = {"User-Agent": ""}
+
+        # Use the provided system prompt, or default to "Be Helpful and Friendly"
+        system_message = systemprompt or "Be Helpful and Friendly"
+
+        # Create the prompt history with the user query and system message
+        prompt = [
+            {"role": "user", "content": query},
+        ]
+
+        prompt.insert(0, {"content": system_message, "role": "system"})
+
+        # Prepare the payload for the API request
+        payload = {
+            "is_vscode_extension": True,
+            "message_history": prompt,
+            "requested_model": "searchgpt",
+            "user_input": prompt[-1]["content"],
+        }
+
+        # Get endpoint from environment
+        secret_api_endpoint_3 = get_env_vars()['secret_api_endpoint_3']
+        if not secret_api_endpoint_3:
+            raise ValueError("Search API endpoint not configured")
+
+        # Send the request to the chat endpoint using a scraper from the pool
+        response = get_scraper().post(
+            secret_api_endpoint_3,
+            headers=headers,
+            json=payload,
+            stream=True
+        )
+
+        result = []
+        streaming_text = ""
+
+        # Process the streaming response
+        for value in response.iter_lines(decode_unicode=True):
+            if value.startswith("data: "):
+                try:
+                    json_modified_value = json.loads(value[6:])
+                    content = json_modified_value.get("choices", [{}])[0].get("delta", {}).get("content", "")
+
+                    if content.strip():  # Only process non-empty content
+                        cleaned_response = {
+                            "created": json_modified_value.get("created"),
+                            "id": json_modified_value.get("id"),
+                            "model": "searchgpt",
+                            "object": "chat.completion",
+                            "choices": [
+                                {
+                                    "message": {
+                                        "content": content
+                                    }
+                                }
+                            ]
+                        }
+
+                        if stream:
+                            result.append(f"data: {json.dumps(cleaned_response)}\n\n")
+
+                        streaming_text += content
+                except json.JSONDecodeError:
+                    continue
+
+        if not stream:
+            result.append(streaming_text)
+
+        return result
+
+    # Run in thread pool to avoid blocking the event loop
+    return await loop.run_in_executor(executor, _generate_search)
+
+# Cache for frequently accessed static files
+@lru_cache(maxsize=10)
+def read_html_file(file_path):
+    try:
+        with open(file_path, "r") as file:
+            return file.read()
+    except FileNotFoundError:
+        return None
 
 # Basic routes
 @app.get("/favicon.ico")
@@ -201,29 +259,21 @@ async def favicon():
 
 @app.get("/ping")
 async def ping():
-    start_time = datetime.datetime.now()
-    response_time = (datetime.datetime.now() - start_time).total_seconds()
-    return {"message": "pong", "response_time": f"{response_time:.6f} seconds"}
+    return {"message": "pong", "response_time": "0.000000 seconds"}
 
 @app.get("/", response_class=HTMLResponse)
 async def root():
-    file_path = "index.html"
-    try:
-        with open(file_path, "r") as file:
-            html_content = file.read()
-        return HTMLResponse(content=html_content)
-    except FileNotFoundError:
+    html_content = read_html_file("index.html")
+    if html_content is None:
         return HTMLResponse(content="<h1>File not found</h1>", status_code=404)
+    return HTMLResponse(content=html_content)
 
 @app.get("/playground", response_class=HTMLResponse)
 async def playground():
-    file_path = "playground.html"
-    try:
-        with open(file_path, "r") as file:
-            html_content = file.read()
-        return HTMLResponse(content=html_content)
-    except FileNotFoundError:
+    html_content = read_html_file("playground.html")
+    if html_content is None:
         return HTMLResponse(content="<h1>playground.html not found</h1>", status_code=404)
+    return HTMLResponse(content=html_content)
 
 # Model routes
 @app.get("/api/v1/models")
@@ -239,15 +289,20 @@ async def search_gpt(q: str, stream: Optional[bool] = False, systemprompt: Optio
 
     usage_tracker.record_request(endpoint="/searchgpt")
 
+    result = await generate_search_async(q, systemprompt=systemprompt, stream=stream)
+
     if stream:
+        async def stream_generator():
+            for chunk in result:
+                yield chunk
+
         return StreamingResponse(
-            generate_search(q, systemprompt=systemprompt, stream=True),
+            stream_generator(),
            media_type="text/event-stream"
        )
    else:
-        # For non-streaming, collect the text and return as JSON response
-        response_text = "".join([chunk for chunk in generate_search(q, systemprompt=systemprompt, stream=False)])
-        return JSONResponse(content={"response": response_text})
+        # For non-streaming, return the collected text
+        return JSONResponse(content={"response": result[0] if result else ""})
 
 # Chat completion endpoint
 @app.post("/chat/completions")
@@ -260,15 +315,17 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
             content={"message": "Server is under maintenance. Please try again later."}
         )
 
-    model_to_use = payload.model if payload.model else "gpt-4o-mini"
+    model_to_use = payload.model or "gpt-4o-mini"
 
-    # Validate model availability
-    if available_model_ids and model_to_use not in available_model_ids:
+    # Validate model availability - fast lookup with set
+    if available_model_ids and model_to_use not in set(available_model_ids):
         raise HTTPException(
             status_code=400,
             detail=f"Model '{model_to_use}' is not available. Check /models for the available model list."
         )
 
+    # Log request without blocking
+    asyncio.create_task(log_request(request, model_to_use))
     usage_tracker.record_request(model=model_to_use, endpoint="/chat/completions")
 
     # Prepare payload
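One caveat with `asyncio.create_task(log_request(...))` as used above: the task handle is dropped, and the asyncio documentation recommends keeping a reference so a pending task is not garbage-collected before it runs. An illustrative helper (not part of this commit):

```python
import asyncio

_background_tasks: set = set()

def fire_and_forget(coro) -> asyncio.Task:
    task = asyncio.create_task(coro)
    _background_tasks.add(task)                        # hold a strong reference
    task.add_done_callback(_background_tasks.discard)  # release when done
    return task
```

In the endpoint above, `fire_and_forget(log_request(request, model_to_use))` would play the same role as the bare `create_task` call.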
@@ -278,7 +335,7 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
     # Get environment variables
     env_vars = get_env_vars()
 
-    # Select the appropriate endpoint
+    # Select the appropriate endpoint (fast lookup with sets)
     if model_to_use in mistral_models:
         endpoint = env_vars['mistral_api']
         custom_headers = {
@@ -291,13 +348,8 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
         endpoint = env_vars['secret_api_endpoint']
         custom_headers = {}
 
-    # Current time and IP logging (with minimal data)
-    current_time = (datetime.datetime.utcnow() + datetime.timedelta(hours=5, minutes=30)).strftime("%Y-%m-%d %I:%M:%S %p")
-    ip_hash = hash(request.client.host) % 10000  # Hash the IP for privacy
-    print(f"Time: {current_time}, IP Hash: {ip_hash}, Model: {model_to_use}")
-
-    # Create scraper for each connection to avoid concurrency issues
-    scraper = cloudscraper.create_scraper()
+    # Get a scraper from the pool
+    scraper = get_scraper()
 
     async def stream_generator(payload_dict):
         try:
@@ -320,10 +372,25 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
                 detail = error_messages.get(response.status_code, f"Error code: {response.status_code}")
                 raise HTTPException(status_code=response.status_code, detail=detail)
 
-            # Stream response lines to the client
+            # Stream response lines to the client - use buffer for efficiency
+            buffer = []
+            buffer_size = 0
+            max_buffer = 8192  # 8KB buffer
+
             for line in response.iter_lines():
                 if line:
-                    yield line.decode('utf-8') + "\n"
+                    decoded = line.decode('utf-8') + "\n"
+                    buffer.append(decoded)
+                    buffer_size += len(decoded)
+
+                    if buffer_size >= max_buffer:
+                        yield ''.join(buffer)
+                        buffer = []
+                        buffer_size = 0
+
+            # Flush remaining buffer
+            if buffer:
+                yield ''.join(buffer)
 
         except Exception as e:
             # Use a generic error message that doesn't expose internal details
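The buffering above coalesces many small SSE lines into roughly 8 KB writes, cutting per-message overhead at the cost of burstier delivery to the client. The same pattern in isolation, for any iterable of small strings:

```python
from typing import Iterable, Iterator

def buffered(lines: Iterable[str], max_buffer: int = 8192) -> Iterator[str]:
    buf, size = [], 0
    for line in lines:
        buf.append(line)
        size += len(line)
        if size >= max_buffer:
            yield "".join(buf)  # one large write instead of many tiny ones
            buf, size = [], 0
    if buf:
        yield "".join(buf)      # flush the tail

# Example: 10_000 tiny lines collapse into a handful of ~8 KB chunks
chunks = list(buffered(f"line {i}\n" for i in range(10_000)))
print(len(chunks))
```

The trade-off is that consumers reading token by token see output arrive in bursts rather than line by line.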
@@ -331,7 +398,14 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
 
     return StreamingResponse(stream_generator(payload_dict), media_type="application/json")
 
-# Image generation endpoint
+# Asynchronous logging function
+async def log_request(request, model):
+    # Get minimal data for logging
+    current_time = (datetime.datetime.utcnow() + datetime.timedelta(hours=5, minutes=30)).strftime("%Y-%m-%d %I:%M:%S %p")
+    ip_hash = hash(request.client.host) % 10000  # Hash the IP for privacy
+    print(f"Time: {current_time}, IP Hash: {ip_hash}, Model: {model}")
+
+# Image generation endpoint - optimized to use connection pool
 @app.api_route("/images/generations", methods=["GET", "POST"])
 async def generate_image(
     prompt: Optional[str] = None,
@@ -391,35 +465,53 @@ async def generate_image(
         params['enhance'] = str(enhance).lower()
 
     try:
-        async with httpx.AsyncClient(timeout=60.0) as client:
-            response = await client.get(url, params=params, follow_redirects=True)
+        # Use the shared httpx client for connection pooling
+        client = get_async_client()
+        response = await client.get(url, params=params, follow_redirects=True)
+
+        # Check for various error conditions
+        if response.status_code != 200:
+            error_messages = {
+                404: "Image generation service not found",
+                400: "Invalid parameters provided to image service",
+                429: "Too many requests to image service",
+            }
+            detail = error_messages.get(response.status_code, f"Image generation failed with status code {response.status_code}")
+            raise HTTPException(status_code=response.status_code, detail=detail)
+
+        # Verify content type
+        content_type = response.headers.get('content-type', '')
+        if not content_type.startswith('image/'):
+            raise HTTPException(
+                status_code=500,
+                detail="Unexpected content type received from image service"
+            )
 
-            # Check for various error conditions
-            if response.status_code != 200:
-                error_messages = {
-                    404: "Image generation service not found",
-                    400: "Invalid parameters provided to image service",
-                    429: "Too many requests to image service",
-                }
-                detail = error_messages.get(response.status_code, f"Image generation failed with status code {response.status_code}")
-                raise HTTPException(status_code=response.status_code, detail=detail)
+        # Use larger chunks for streaming for better performance
+        async def stream_with_larger_chunks():
+            chunks = []
+            size = 0
+            async for chunk in response.aiter_bytes(chunk_size=16384):  # Use 16KB chunks
+                chunks.append(chunk)
+                size += len(chunk)
+
+                if size >= 65536:  # Yield every 64KB
+                    yield b''.join(chunks)
+                    chunks = []
+                    size = 0
+
+            if chunks:
+                yield b''.join(chunks)
 
-            # Verify content type
-            content_type = response.headers.get('content-type', '')
-            if not content_type.startswith('image/'):
-                raise HTTPException(
-                    status_code=500,
-                    detail="Unexpected content type received from image service"
-                )
-
-            return StreamingResponse(
-                response.iter_bytes(),
-                media_type=content_type,
-                headers={
-                    'Cache-Control': 'no-cache',
-                    'Pragma': 'no-cache'
-                }
-            )
+        return StreamingResponse(
+            stream_with_larger_chunks(),
+            media_type=content_type,
+            headers={
+                'Cache-Control': 'no-cache, no-store, must-revalidate',
+                'Pragma': 'no-cache',
+                'Expires': '0'
+            }
+        )
 
     except httpx.TimeoutException:
         raise HTTPException(status_code=504, detail="Image generation request timed out")
@@ -428,37 +520,57 @@ async def generate_image(
     except Exception:
         raise HTTPException(status_code=500, detail="Unexpected error during image generation")
 
-# Usage statistics
+# Meme endpoint with optimized networking
 @app.get("/meme")
 async def get_meme():
     try:
-        async with httpx.AsyncClient() as client:
-            response = await client.get("https://meme-api.com/gimme")
-            response_data = response.json()
-
-            meme_url = response_data.get("url")
-            if not meme_url:
-                raise HTTPException(status_code=404, detail="No meme found")
-
-            image_response = await client.get(meme_url, follow_redirects=True)
-            return StreamingResponse(
-                image_response.iter_bytes(),
-                media_type=image_response.headers.get("content-type", "image/png")
-            )
+        # Use the shared client for connection pooling
+        client = get_async_client()
+        response = await client.get("https://meme-api.com/gimme")
+        response_data = response.json()
+
+        meme_url = response_data.get("url")
+        if not meme_url:
+            raise HTTPException(status_code=404, detail="No meme found")
+
+        image_response = await client.get(meme_url, follow_redirects=True)
+
+        # Use larger chunks for streaming
+        async def stream_with_larger_chunks():
+            chunks = []
+            size = 0
+            async for chunk in image_response.aiter_bytes(chunk_size=16384):
+                chunks.append(chunk)
+                size += len(chunk)
+
+                if size >= 65536:
+                    yield b''.join(chunks)
+                    chunks = []
+                    size = 0
+
+            if chunks:
+                yield b''.join(chunks)
+
+        return StreamingResponse(
+            stream_with_larger_chunks(),
+            media_type=image_response.headers.get("content-type", "image/png"),
+            headers={'Cache-Control': 'max-age=3600'}  # Add caching
+        )
     except Exception:
         raise HTTPException(status_code=500, detail="Failed to retrieve meme")
 
+# Cache usage statistics
+@lru_cache(maxsize=10)
+def get_usage_summary(days=7):
+    return usage_tracker.get_usage_summary(days)
+
 @app.get("/usage")
 async def get_usage(days: int = 7):
     """Retrieve usage statistics"""
-    return usage_tracker.get_usage_summary(days)
+    return get_usage_summary(days)
 
-@app.get("/usage/page", response_class=HTMLResponse)
-async def usage_page():
-    """Serve an HTML page showing usage statistics"""
-    # Retrieve usage data
-    usage_data = usage_tracker.get_usage_summary()
-
+# Generate HTML for usage page
+def generate_usage_html(usage_data):
     # Model Usage Table Rows
     model_usage_rows = "\n".join([
         f"""
@@ -646,16 +758,28 @@ async def usage_page():
         </body>
     </html>
     """
+    return html_content
+
+# Cache the usage page HTML
+@lru_cache(maxsize=1)
+def get_usage_page_html():
+    usage_data = get_usage_summary()
+    return generate_usage_html(usage_data)
+
+@app.get("/usage/page", response_class=HTMLResponse)
+async def usage_page():
+    """Serve an HTML page showing usage statistics"""
+    # Use cached HTML if available, regenerate if not
+    html_content = get_usage_page_html()
     return HTMLResponse(content=html_content)
 
-# Utility function for loading model IDs
+# Utility function for loading model IDs - optimized to run once at startup
 def load_model_ids(json_file_path):
     try:
         with open(json_file_path, 'r') as f:
             models_data = json.load(f)
-            # Extract 'id' from each model object
-            model_ids = [model['id'] for model in models_data if 'id' in model]
-            return model_ids
+            # Extract 'id' from each model object and use a set for fast lookups
+            return [model['id'] for model in models_data if 'id' in model]
     except Exception as e:
         print(f"Error loading model IDs: {str(e)}")
         return []
@@ -665,7 +789,10 @@ async def startup_event():
     global available_model_ids
     available_model_ids = load_model_ids("models.json")
     print(f"Loaded {len(available_model_ids)} model IDs")
-    print("API started successfully")
+
+    # Preload scrapers
+    for _ in range(MAX_SCRAPERS):
+        scraper_pool.append(cloudscraper.create_scraper())
 
     # Validate critical environment variables
     env_vars = get_env_vars()
@@ -680,7 +807,19 @@ async def startup_event():
 
     if missing_vars:
         print(f"WARNING: The following required environment variables are missing: {', '.join(missing_vars)}")
+
+    print("API started successfully with high-performance optimizations")
 
 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=8000)
+    uvicorn.run(
+        app,
+        host="0.0.0.0",
+        port=8000,
+        workers=4,  # Multiple workers for better CPU utilization
+        loop="uvloop",  # Use uvloop for faster async operations
+        http="httptools",  # Faster HTTP parsing
+        log_level="warning",  # Reduce logging overhead
+        limit_concurrency=100,  # Limit concurrent connections
+        timeout_keep_alive=5  # Reduce idle connection time
+    )
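A note on the new `uvicorn.run` call: uvicorn only honors `workers` (or `reload`) when the application is passed as an import string; given an app object it logs a warning instead of spawning workers. If multi-worker mode is the goal, the import-string form below should work, assuming this module is `main.py` and the instance is named `app`. Keep in mind that each worker process then gets its own scraper pool and caches:

```python
import uvicorn

if __name__ == "__main__":
    uvicorn.run(
        "main:app",  # import string: required for workers/reload
        host="0.0.0.0",
        port=8000,
        workers=4,
        loop="uvloop",
        http="httptools",
        log_level="warning",
        limit_concurrency=100,
        timeout_keep_alive=5,
    )
```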