ParthSadaria committed on
Commit
c3d5a54
·
verified ·
1 Parent(s): 62457ce

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +234 -246
main.py CHANGED
@@ -1,111 +1,47 @@
1
  import os
2
  from dotenv import load_dotenv
3
- from fastapi import FastAPI, HTTPException, Request
4
- from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse, FileResponse
 
5
  from pydantic import BaseModel
6
  import httpx
7
- import hashlib
8
  from functools import lru_cache
9
- from pathlib import Path # Import Path from pathlib
10
- import requests
11
- import re
12
- import cloudscraper
13
  import json
14
- from typing import Optional
15
  import datetime
16
  import time
17
- from usage_tracker import UsageTracker
18
- from starlette.middleware.base import BaseHTTPMiddleware
19
- from collections import defaultdict
20
- from fastapi import Security #new
21
- from fastapi import Depends
22
- from fastapi.security import APIKeyHeader
23
- from starlette.exceptions import HTTPException
24
  from starlette.status import HTTP_403_FORBIDDEN
 
25
 
26
- # API key header scheme
27
- api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
28
 
29
- # Function to validate API key
30
- async def verify_api_key(api_key: str = Security(api_key_header)) -> bool:
31
- if not api_key:
32
- raise HTTPException(
33
- status_code=HTTP_403_FORBIDDEN,
34
- detail="No API key provided"
35
- )
36
-
37
- # Clean the API key by removing 'Bearer ' if present
38
- if api_key.startswith('Bearer '):
39
- api_key = api_key[7:] # Remove 'Bearer ' prefix
40
-
41
- # Get API keys from environment
42
- api_keys_str = os.getenv('API_KEYS')
43
- if not api_keys_str:
44
- raise HTTPException(
45
- status_code=HTTP_403_FORBIDDEN,
46
- detail="API keys not configured on server"
47
- )
48
-
49
- valid_api_keys = api_keys_str.split(',')
50
-
51
- # Check if the provided key is valid
52
- if api_key not in valid_api_keys:
53
- raise HTTPException(
54
- status_code=HTTP_403_FORBIDDEN,
55
- detail="Invalid API key"
56
- )
57
-
58
- return True
59
- class RateLimitMiddleware(BaseHTTPMiddleware):
60
- def __init__(self, app, requests_per_second: int = 2):
61
- super().__init__(app)
62
- self.requests_per_second = requests_per_second
63
- self.last_request_time = defaultdict(float)
64
- self.tokens = defaultdict(lambda: requests_per_second)
65
- self.last_update = defaultdict(float)
66
-
67
- async def dispatch(self, request: Request, call_next):
68
- client_ip = request.client.host
69
- current_time = time.time()
70
-
71
- # Update tokens
72
- time_passed = current_time - self.last_update[client_ip]
73
- self.last_update[client_ip] = current_time
74
- self.tokens[client_ip] = min(
75
- self.requests_per_second,
76
- self.tokens[client_ip] + time_passed * self.requests_per_second
77
- )
78
-
79
- # Check if request can be processed
80
- if self.tokens[client_ip] < 1:
81
- return JSONResponse(
82
- status_code=429,
83
- content={
84
- "detail": "Too many requests. Please try again later.",
85
- "retry_after": round((1 - self.tokens[client_ip]) / self.requests_per_second)
86
- }
87
- )
88
-
89
- # Consume a token
90
- self.tokens[client_ip] -= 1
91
-
92
- # Process the request
93
- response = await call_next(request)
94
- return response
95
 
 
 
96
  usage_tracker = UsageTracker()
97
- load_dotenv() #idk why this shi
98
 
99
  app = FastAPI()
100
- app.add_middleware(RateLimitMiddleware, requests_per_second=2)
101
-
102
- # Get API keys and secret endpoint from environment variables
103
- # valid_api_keys = api_keys_str.split(',') if api_keys_str else []
104
- secret_api_endpoint = os.getenv('SECRET_API_ENDPOINT')
105
- secret_api_endpoint_2 = os.getenv('SECRET_API_ENDPOINT_2')
106
- secret_api_endpoint_3 = os.getenv('SECRET_API_ENDPOINT_3') # New endpoint for searchgpt
107
- mistral_api = "https://api.mistral.ai"
108
- mistral_key = os.getenv('MISTRAL_KEY')
 
 
 
 
 
 
 
109
  mistral_models = [
110
  "mistral-large-latest",
111
  "pixtral-large-latest",
@@ -118,14 +54,6 @@ mistral_models = [
118
  "codestral-latest"
119
  ]
120
 
121
- image_endpoint = os.getenv("IMAGE_ENDPOINT")
122
- ENDPOINT_ORIGIN = os.getenv('ENDPOINT_ORIGIN')
123
-
124
- # Validate if the main secret API endpoints are set
125
- if not secret_api_endpoint or not secret_api_endpoint_2 or not secret_api_endpoint_3:
126
- raise HTTPException(status_code=500, detail="API endpoint(s) are not configured in environment variables.")
127
-
128
- # Define models that should use the secondary endpoint
129
  alternate_models = {
130
  "gpt-4o-mini",
131
  "deepseek-v3",
@@ -140,19 +68,64 @@ alternate_models = {
140
  "hermes-3-llama-3.2-3b"
141
  }
142
 
143
- available_model_ids = []
144
  class Payload(BaseModel):
145
  model: str
146
  messages: list
147
  stream: bool = False
148
-
149
- @app.get("/favicon.ico")
150
- async def favicon():
151
- # The favicon.ico file is in the same directory as the app
152
- favicon_path = Path(__file__).parent / "favicon.ico"
153
- return FileResponse(favicon_path, media_type="image/x-icon")
154
-
155
- def generate_search(query: str, systemprompt: Optional[str] = None, stream: bool = True) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  headers = {"User-Agent": ""}
157
 
158
  # Use the provided system prompt, or default to "Be Helpful and Friendly"
@@ -173,8 +146,18 @@ def generate_search(query: str, systemprompt: Optional[str] = None, stream: bool
173
  "user_input": prompt[-1]["content"],
174
  }
175
 
 
 
 
 
 
176
  # Send the request to the chat endpoint
177
- response = requests.post(secret_api_endpoint_3, headers=headers, json=payload, stream=True)
 
 
 
 
 
178
 
179
  streaming_text = ""
180
 
@@ -210,31 +193,21 @@ def generate_search(query: str, systemprompt: Optional[str] = None, stream: bool
210
  if not stream:
211
  yield streaming_text
212
 
 
 
 
 
 
 
213
  @app.get("/ping")
214
  async def ping():
215
  start_time = datetime.datetime.now()
216
  response_time = (datetime.datetime.now() - start_time).total_seconds()
217
  return {"message": "pong", "response_time": f"{response_time:.6f} seconds"}
218
-
219
- @app.get("/searchgpt")
220
- async def search_gpt(q: str, stream: Optional[bool] = False, systemprompt: Optional[str] = None):
221
- if not q:
222
- raise HTTPException(status_code=400, detail="Query parameter 'q' is required")
223
- usage_tracker.record_request(endpoint="/searchgpt")
224
- if stream:
225
- return StreamingResponse(
226
- generate_search(q, systemprompt=systemprompt, stream=True),
227
- media_type="text/event-stream"
228
- )
229
- else:
230
- # For non-streaming, collect the text and return as JSON response
231
- response_text = "".join([chunk for chunk in generate_search(q, systemprompt=systemprompt, stream=False)])
232
- return JSONResponse(content={"response": response_text})
233
  @app.get("/", response_class=HTMLResponse)
234
  async def root():
235
- # Open and read the content of index.html (in the same folder as the app)
236
  file_path = "index.html"
237
-
238
  try:
239
  with open(file_path, "r") as file:
240
  html_content = file.read()
@@ -242,29 +215,55 @@ async def root():
242
  except FileNotFoundError:
243
  return HTMLResponse(content="<h1>File not found</h1>", status_code=404)
244
 
245
- async def get_models():
 
 
246
  try:
247
- # Load the models from models.json in the same folder
248
- file_path = Path(__file__).parent / 'models.json'
249
- with open(file_path, 'r') as f:
250
- return json.load(f)
251
  except FileNotFoundError:
252
- raise HTTPException(status_code=404, detail="models.json not found")
253
- except json.JSONDecodeError:
254
- raise HTTPException(status_code=500, detail="Error decoding models.json")
255
- @app.get("api/v1/models")
256
  @app.get("/models")
257
  async def return_models():
258
  return await get_models()
259
- server_status = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  @app.post("/chat/completions")
261
  @app.post("/api/v1/chat/completions")
262
  async def get_completion(payload: Payload, request: Request, authenticated: bool = Depends(verify_api_key)):
263
  # Check server status
 
 
 
 
 
 
264
  model_to_use = payload.model if payload.model else "gpt-4o-mini"
265
 
266
  # Validate model availability
267
- if model_to_use not in available_model_ids:
268
  raise HTTPException(
269
  status_code=400,
270
  detail=f"Model '{model_to_use}' is not available. Check /models for the available model list."
@@ -276,31 +275,28 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
276
  payload_dict = payload.dict()
277
  payload_dict["model"] = model_to_use
278
 
 
 
 
279
  # Select the appropriate endpoint
280
  if model_to_use in mistral_models:
281
- endpoint = mistral_api
282
  custom_headers = {
283
- "Authorization": f"Bearer {mistral_key}"
284
  }
285
  elif model_to_use in alternate_models:
286
- endpoint = secret_api_endpoint_2
287
  custom_headers = {}
288
  else:
289
- endpoint = secret_api_endpoint
290
  custom_headers = {}
291
 
292
- # Current time and IP logging
293
  current_time = (datetime.datetime.utcnow() + datetime.timedelta(hours=5, minutes=30)).strftime("%Y-%m-%d %I:%M:%S %p")
294
- aaip = request.client.host
295
- print(f"Time: {current_time}, {aaip}, {model_to_use}, server status: {server_status}")
296
- print(payload_dict)
297
-
298
- if not server_status:
299
- return JSONResponse(
300
- status_code=503,
301
- content={"message": "Server is under maintenance. Please try again later."}
302
- )
303
 
 
304
  scraper = cloudscraper.create_scraper()
305
 
306
  async def stream_generator(payload_dict):
@@ -314,55 +310,52 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
314
  )
315
 
316
  # Handle response errors
317
- if response.status_code == 422:
318
- raise HTTPException(status_code=422, detail="Unprocessable entity. Check your payload.")
319
- elif response.status_code == 400:
320
- raise HTTPException(status_code=400, detail="Bad request. Verify input data.")
321
- elif response.status_code == 403:
322
- raise HTTPException(status_code=403, detail="Forbidden. You do not have access to this resource.")
323
- elif response.status_code == 404:
324
- raise HTTPException(status_code=404, detail="The requested resource was not found.")
325
- elif response.status_code >= 500:
326
- raise HTTPException(status_code=500, detail="Server error. Try again later.")
327
 
328
  # Stream response lines to the client
329
  for line in response.iter_lines():
330
  if line:
331
  yield line.decode('utf-8') + "\n"
332
 
333
- except requests.exceptions.RequestException as req_err:
334
- print(response.text)
335
- raise HTTPException(status_code=500, detail=f"Request failed: {req_err}")
336
  except Exception as e:
337
- print(response.text)
338
- raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {e}")
339
 
340
  return StreamingResponse(stream_generator(payload_dict), media_type="application/json")
341
- # Remove the duplicated endpoint and combine the functionality
342
- @app.api_route("/images/generations", methods=["GET", "POST"]) # Support both GET and POST
 
343
  async def generate_image(
344
  prompt: Optional[str] = None,
345
- model: str = "flux", # Default model
346
  seed: Optional[int] = None,
347
  width: Optional[int] = None,
348
  height: Optional[int] = None,
349
  nologo: Optional[bool] = True,
350
  private: Optional[bool] = None,
351
  enhance: Optional[bool] = None,
352
- request: Request = None, # Access raw POST data
353
  authenticated: bool = Depends(verify_api_key)
354
  ):
355
- """
356
- Generate an image using the Image Generation API.
357
- """
358
  # Validate the image endpoint
 
359
  if not image_endpoint:
360
  raise HTTPException(status_code=500, detail="Image endpoint not configured in environment variables.")
 
361
  usage_tracker.record_request(endpoint="/images/generations")
 
362
  # Handle GET and POST prompts
363
  if request.method == "POST":
364
  try:
365
- body = await request.json() # Parse JSON body
366
  prompt = body.get("prompt", "").strip()
367
  if not prompt:
368
  raise HTTPException(status_code=400, detail="Prompt cannot be empty")
@@ -377,12 +370,11 @@ async def generate_image(
377
  encoded_prompt = httpx.QueryParams({'prompt': prompt}).get('prompt')
378
 
379
  # Construct the URL with the encoded prompt
380
- base_url = image_endpoint.rstrip('/') # Remove trailing slash if present
381
  url = f"{base_url}/{encoded_prompt}"
382
 
383
  # Prepare query parameters with validation
384
  params = {}
385
-
386
  if model and isinstance(model, str):
387
  params['model'] = model
388
  if seed is not None and isinstance(seed, int):
@@ -399,29 +391,25 @@ async def generate_image(
399
  params['enhance'] = str(enhance).lower()
400
 
401
  try:
402
- timeout = httpx.Timeout(60.0) # Set a reasonable timeout
403
- async with httpx.AsyncClient(timeout=timeout) as client:
404
  response = await client.get(url, params=params, follow_redirects=True)
405
 
406
  # Check for various error conditions
407
- if response.status_code == 404:
408
- raise HTTPException(status_code=404, detail="Image generation service not found")
409
- elif response.status_code == 400:
410
- raise HTTPException(status_code=400, detail="Invalid parameters provided to image service")
411
- elif response.status_code == 429:
412
- raise HTTPException(status_code=429, detail="Too many requests to image service")
413
- elif response.status_code != 200:
414
- raise HTTPException(
415
- status_code=response.status_code,
416
- detail=f"Image generation failed with status code {response.status_code}"
417
- )
418
 
419
  # Verify content type
420
  content_type = response.headers.get('content-type', '')
421
  if not content_type.startswith('image/'):
422
  raise HTTPException(
423
  status_code=500,
424
- detail=f"Unexpected content type received: {content_type}"
425
  )
426
 
427
  return StreamingResponse(
@@ -435,35 +423,31 @@ async def generate_image(
435
 
436
  except httpx.TimeoutException:
437
  raise HTTPException(status_code=504, detail="Image generation request timed out")
438
- except httpx.RequestError as e:
439
- raise HTTPException(status_code=500, detail=f"Failed to contact image service: {str(e)}")
440
- except Exception as e:
441
- raise HTTPException(status_code=500, detail=f"Unexpected error during image generation: {str(e)}")
442
- @app.get("/playground", response_class=HTMLResponse)
443
- async def playground():
444
- # Open and read the content of playground.html (in the same folder as the app)
445
- file_path = "playground.html"
446
 
 
 
 
447
  try:
448
- with open(file_path, "r") as file:
449
- html_content = file.read()
450
- return HTMLResponse(content=html_content)
451
- except FileNotFoundError:
452
- return HTMLResponse(content="<h1>playground.html not found</h1>", status_code=404)
453
-
454
- def load_model_ids(json_file_path):
455
- try:
456
- with open(json_file_path, 'r') as f:
457
- models_data = json.load(f)
458
- # Extract 'id' from each model object
459
- model_ids = [model['id'] for model in models_data if 'id' in model]
460
- return model_ids
461
- except FileNotFoundError:
462
- print("Error: models.json file not found.")
463
- return []
464
- except json.JSONDecodeError:
465
- print("Error: Invalid JSON format in models.json.")
466
- return []
467
  @app.get("/usage")
468
  async def get_usage(days: int = 7):
469
  """Retrieve usage statistics"""
@@ -474,6 +458,7 @@ async def usage_page():
474
  """Serve an HTML page showing usage statistics"""
475
  # Retrieve usage data
476
  usage_data = usage_tracker.get_usage_summary()
 
477
  # Model Usage Table Rows
478
  model_usage_rows = "\n".join([
479
  f"""
@@ -485,6 +470,7 @@ async def usage_page():
485
  </tr>
486
  """ for model, model_data in usage_data['models'].items()
487
  ])
 
488
  # API Endpoint Usage Table Rows
489
  api_usage_rows = "\n".join([
490
  f"""
@@ -496,6 +482,7 @@ async def usage_page():
496
  </tr>
497
  """ for endpoint, endpoint_data in usage_data['api_endpoints'].items()
498
  ])
 
499
  # Daily Usage Table Rows
500
  daily_usage_rows = "\n".join([
501
  "\n".join([
@@ -660,39 +647,40 @@ async def usage_page():
660
  </html>
661
  """
662
  return HTMLResponse(content=html_content)
663
- @app.get("/meme")
664
- async def get_meme():
665
- try:
666
- response = requests.get("https://meme-api.com/gimme")
667
- response_data = response.json()
668
-
669
- meme_url = response_data.get("url")
670
-
671
- if meme_url:
672
- def stream_image():
673
- with requests.get(meme_url, stream=True) as image_response:
674
- for chunk in image_response.iter_content(chunk_size=1024):
675
- yield chunk
676
-
677
- return StreamingResponse(stream_image(), media_type="image/png")
678
- else:
679
- raise HTTPException(status_code=404, detail="No mimi found :(")
680
 
 
 
 
 
 
 
 
 
681
  except Exception as e:
682
- raise HTTPException(status_code=500, detail=str(e))
683
-
 
684
  @app.on_event("startup")
685
  async def startup_event():
686
  global available_model_ids
687
  available_model_ids = load_model_ids("models.json")
688
- print(f"Loaded model IDs: {available_model_ids}")
689
- print("API endpoints:")
690
- print("GET /")
691
- print("GET /models")
692
- print("GET /searchgpt")
693
- print("POST /chat/completions")
694
- print("GET /images/generations")
695
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
696
  if __name__ == "__main__":
697
  import uvicorn
698
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
  import os
2
  from dotenv import load_dotenv
3
+ from fastapi import FastAPI, HTTPException, Request, Depends, Security
4
+ from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse, FileResponse
5
+ from fastapi.security import APIKeyHeader
6
  from pydantic import BaseModel
7
  import httpx
 
8
  from functools import lru_cache
9
+ from pathlib import Path
 
 
 
10
  import json
 
11
  import datetime
12
  import time
13
+ from typing import Optional, Dict, List, Any, Generator
14
+ import asyncio
 
 
 
 
 
15
  from starlette.status import HTTP_403_FORBIDDEN
16
+ import cloudscraper
17
 
18
+ # Load environment variables once at startup
19
+ load_dotenv()
20
 
21
+ # API key security scheme
22
+ api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
+ # Initialize usage tracker
25
+ from usage_tracker import UsageTracker
26
  usage_tracker = UsageTracker()
 
27
 
28
  app = FastAPI()
29
+
30
+ # Environment variables (cached)
31
@lru_cache(maxsize=1)
def get_env_vars():
    """Read the service configuration from the environment once and cache it.

    The result is memoised by ``lru_cache``, so later changes to the process
    environment are not seen until ``get_env_vars.cache_clear()`` is called.

    Returns:
        dict with the keys ``api_keys`` (list of non-empty, whitespace-trimmed
        keys parsed from the comma-separated ``API_KEYS`` variable; empty list
        when unset), ``secret_api_endpoint``, ``secret_api_endpoint_2``,
        ``secret_api_endpoint_3``, ``mistral_api`` (fixed URL),
        ``mistral_key``, ``image_endpoint`` and ``endpoint_origin``. Values
        read with ``os.getenv`` are ``None`` when the variable is not set.
    """
    raw_keys = os.getenv('API_KEYS', '')
    # Trim each entry and drop blanks: "k1, k2," must not produce " k2" or "".
    api_keys = [key.strip() for key in raw_keys.split(',') if key.strip()]

    return {
        'api_keys': api_keys,
        'secret_api_endpoint': os.getenv('SECRET_API_ENDPOINT'),
        'secret_api_endpoint_2': os.getenv('SECRET_API_ENDPOINT_2'),
        'secret_api_endpoint_3': os.getenv('SECRET_API_ENDPOINT_3'),
        'mistral_api': "https://api.mistral.ai",
        'mistral_key': os.getenv('MISTRAL_KEY'),
        'image_endpoint': os.getenv("IMAGE_ENDPOINT"),
        'endpoint_origin': os.getenv('ENDPOINT_ORIGIN'),
    }
43
+
44
+ # Configuration for models
45
  mistral_models = [
46
  "mistral-large-latest",
47
  "pixtral-large-latest",
 
54
  "codestral-latest"
55
  ]
56
 
 
 
 
 
 
 
 
 
57
  alternate_models = {
58
  "gpt-4o-mini",
59
  "deepseek-v3",
 
68
  "hermes-3-llama-3.2-3b"
69
  }
70
 
71
# Request body accepted by the chat completion routes
class Payload(BaseModel):
    # Name of the model the caller wants to use.
    model: str
    # Conversation messages; the element structure is not validated here.
    messages: list
    # Streaming flag supplied by the caller; off unless explicitly set.
    stream: bool = False
76
+
77
+ # Server status global variable
78
+ server_status = True
79
+ available_model_ids: List[str] = []
80
+
81
+ # Create a reusable httpx client
82
@lru_cache(maxsize=1)
def get_async_client():
    """Return the process-wide ``httpx.AsyncClient`` (60 second timeout).

    ``lru_cache`` ensures the client is constructed on the first call only and
    the same instance is handed back afterwards.
    """
    shared_client = httpx.AsyncClient(timeout=60.0)
    return shared_client
85
+
86
+ # API key validation
87
async def verify_api_key(api_key: str = Security(api_key_header)) -> bool:
    """Validate the API key supplied in the ``Authorization`` header.

    An optional ``Bearer `` prefix is removed before comparison.

    Args:
        api_key: Raw header value injected by the ``api_key_header`` scheme;
            falsy when the header is absent.

    Returns:
        ``True`` when the key matches one of the configured keys.

    Raises:
        HTTPException: 403 when no key is provided, when the server has no
            keys configured, or when the key does not match.
    """
    # Function-scope import so the block does not depend on a file-level import.
    from hmac import compare_digest

    if not api_key:
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail="No API key provided"
        )

    # Clean the API key by removing 'Bearer ' if present
    if api_key.startswith('Bearer '):
        api_key = api_key[7:]  # Remove 'Bearer ' prefix
    api_key = api_key.strip()

    # A bare "Bearer " header leaves an empty string; never let that reach
    # the comparison, otherwise a blank entry in API_KEYS would authenticate it.
    if not api_key:
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail="No API key provided"
        )

    # Get API keys from environment, ignoring blank entries such as those
    # produced by an unset variable or a trailing comma.
    valid_api_keys = [
        key.strip() for key in get_env_vars()['api_keys'] if key and key.strip()
    ]
    if not valid_api_keys:
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail="API keys not configured on server"
        )

    # Compare against every configured key in constant time per key; bytes are
    # used because compare_digest rejects non-ASCII str arguments.
    candidate = api_key.encode('utf-8')
    matches = [compare_digest(candidate, key.encode('utf-8')) for key in valid_api_keys]
    if not any(matches):
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail="Invalid API key"
        )

    return True
114
+
115
+ # Cache for models.json
116
@lru_cache(maxsize=1)
def _load_models_file():
    """Parse models.json located next to this module and cache the result."""
    file_path = Path(__file__).parent / 'models.json'
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)


async def get_models():
    """Return the parsed contents of models.json.

    The cache lives on the synchronous loader rather than on this coroutine
    function: ``lru_cache`` on an ``async def`` stores the coroutine object,
    which can be awaited only once, so every call after the first would fail.
    A failed load is not cached, so a later call retries.

    Raises:
        HTTPException: 500 when the file is missing or is not valid JSON.
    """
    try:
        return _load_models_file()
    except (FileNotFoundError, json.JSONDecodeError) as e:
        # Log the error but don't expose the exact error to users
        print(f"Error loading models.json: {str(e)}")
        raise HTTPException(status_code=500, detail="Error loading available models")
126
+
127
+ # Searcher function with optimized streaming
128
+ def generate_search(query: str, systemprompt: Optional[str] = None, stream: bool = True) -> Generator[str, None, None]:
129
  headers = {"User-Agent": ""}
130
 
131
  # Use the provided system prompt, or default to "Be Helpful and Friendly"
 
146
  "user_input": prompt[-1]["content"],
147
  }
148
 
149
+ # Get endpoint from environment
150
+ secret_api_endpoint_3 = get_env_vars()['secret_api_endpoint_3']
151
+ if not secret_api_endpoint_3:
152
+ raise ValueError("Search API endpoint not configured")
153
+
154
  # Send the request to the chat endpoint
155
+ response = cloudscraper.create_scraper().post(
156
+ secret_api_endpoint_3,
157
+ headers=headers,
158
+ json=payload,
159
+ stream=True
160
+ )
161
 
162
  streaming_text = ""
163
 
 
193
  if not stream:
194
  yield streaming_text
195
 
196
+ # Basic routes
197
@app.get("/favicon.ico")
async def favicon():
    """Serve the favicon.ico file stored alongside this module."""
    module_dir = Path(__file__).parent
    return FileResponse(module_dir / "favicon.ico", media_type="image/x-icon")
201
+
202
  @app.get("/ping")
203
  async def ping():
204
  start_time = datetime.datetime.now()
205
  response_time = (datetime.datetime.now() - start_time).total_seconds()
206
  return {"message": "pong", "response_time": f"{response_time:.6f} seconds"}
207
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  @app.get("/", response_class=HTMLResponse)
209
  async def root():
 
210
  file_path = "index.html"
 
211
  try:
212
  with open(file_path, "r") as file:
213
  html_content = file.read()
 
215
  except FileNotFoundError:
216
  return HTMLResponse(content="<h1>File not found</h1>", status_code=404)
217
 
218
@app.get("/playground", response_class=HTMLResponse)
async def playground():
    """Serve playground.html from the directory that contains this module.

    The path is anchored to the module location (as ``favicon`` does) instead
    of the process working directory, and the file is decoded as UTF-8 rather
    than with the platform default encoding.

    Returns:
        The page as an ``HTMLResponse``, or a 404 ``HTMLResponse`` when the
        file does not exist.
    """
    file_path = Path(__file__).parent / "playground.html"
    try:
        html_content = file_path.read_text(encoding="utf-8")
    except FileNotFoundError:
        return HTMLResponse(content="<h1>playground.html not found</h1>", status_code=404)
    return HTMLResponse(content=html_content)
227
+
228
+ # Model routes
229
@app.get("/api/v1/models")
@app.get("/models")
async def return_models():
    """Expose the model list under both /models and /api/v1/models."""
    models = await get_models()
    return models
233
+
234
+ # Search routes
235
@app.get("/searchgpt")
async def search_gpt(q: str, stream: Optional[bool] = False, systemprompt: Optional[str] = None):
    """Answer a search query, either streamed or as a single JSON document.

    Args:
        q: The user's query; must be non-empty.
        stream: When true, the chunks produced by ``generate_search`` are
            relayed as ``text/event-stream``.
        systemprompt: Optional system prompt forwarded to ``generate_search``.

    Returns:
        A ``StreamingResponse`` when ``stream`` is true, otherwise a
        ``JSONResponse`` of the form ``{"response": <full text>}``.

    Raises:
        HTTPException: 400 when ``q`` is empty.
    """
    if not q:
        raise HTTPException(status_code=400, detail="Query parameter 'q' is required")

    usage_tracker.record_request(endpoint="/searchgpt")

    if stream:
        return StreamingResponse(
            generate_search(q, systemprompt=systemprompt, stream=True),
            media_type="text/event-stream"
        )

    def _collect_text() -> str:
        # generate_search performs a blocking HTTP request, so it is drained
        # in a worker thread instead of on the event loop.
        return "".join(generate_search(q, systemprompt=systemprompt, stream=False))

    # For non-streaming, collect the text and return as JSON response
    loop = asyncio.get_running_loop()
    response_text = await loop.run_in_executor(None, _collect_text)
    return JSONResponse(content={"response": response_text})
251
+
252
+ # Chat completion endpoint
253
  @app.post("/chat/completions")
254
  @app.post("/api/v1/chat/completions")
255
  async def get_completion(payload: Payload, request: Request, authenticated: bool = Depends(verify_api_key)):
256
  # Check server status
257
+ if not server_status:
258
+ return JSONResponse(
259
+ status_code=503,
260
+ content={"message": "Server is under maintenance. Please try again later."}
261
+ )
262
+
263
  model_to_use = payload.model if payload.model else "gpt-4o-mini"
264
 
265
  # Validate model availability
266
+ if available_model_ids and model_to_use not in available_model_ids:
267
  raise HTTPException(
268
  status_code=400,
269
  detail=f"Model '{model_to_use}' is not available. Check /models for the available model list."
 
275
  payload_dict = payload.dict()
276
  payload_dict["model"] = model_to_use
277
 
278
+ # Get environment variables
279
+ env_vars = get_env_vars()
280
+
281
  # Select the appropriate endpoint
282
  if model_to_use in mistral_models:
283
+ endpoint = env_vars['mistral_api']
284
  custom_headers = {
285
+ "Authorization": f"Bearer {env_vars['mistral_key']}"
286
  }
287
  elif model_to_use in alternate_models:
288
+ endpoint = env_vars['secret_api_endpoint_2']
289
  custom_headers = {}
290
  else:
291
+ endpoint = env_vars['secret_api_endpoint']
292
  custom_headers = {}
293
 
294
+ # Current time and IP logging (with minimal data)
295
  current_time = (datetime.datetime.utcnow() + datetime.timedelta(hours=5, minutes=30)).strftime("%Y-%m-%d %I:%M:%S %p")
296
+ ip_hash = hash(request.client.host) % 10000 # Hash the IP for privacy
297
+ print(f"Time: {current_time}, IP Hash: {ip_hash}, Model: {model_to_use}")
 
 
 
 
 
 
 
298
 
299
+ # Create scraper for each connection to avoid concurrency issues
300
  scraper = cloudscraper.create_scraper()
301
 
302
  async def stream_generator(payload_dict):
 
310
  )
311
 
312
  # Handle response errors
313
+ if response.status_code >= 400:
314
+ error_messages = {
315
+ 422: "Unprocessable entity. Check your payload.",
316
+ 400: "Bad request. Verify input data.",
317
+ 403: "Forbidden. You do not have access to this resource.",
318
+ 404: "The requested resource was not found.",
319
+ }
320
+ detail = error_messages.get(response.status_code, f"Error code: {response.status_code}")
321
+ raise HTTPException(status_code=response.status_code, detail=detail)
 
322
 
323
  # Stream response lines to the client
324
  for line in response.iter_lines():
325
  if line:
326
  yield line.decode('utf-8') + "\n"
327
 
 
 
 
328
  except Exception as e:
329
+ # Use a generic error message that doesn't expose internal details
330
+ raise HTTPException(status_code=500, detail="An error occurred while processing your request")
331
 
332
  return StreamingResponse(stream_generator(payload_dict), media_type="application/json")
333
+
334
+ # Image generation endpoint
335
+ @app.api_route("/images/generations", methods=["GET", "POST"])
336
  async def generate_image(
337
  prompt: Optional[str] = None,
338
+ model: str = "flux",
339
  seed: Optional[int] = None,
340
  width: Optional[int] = None,
341
  height: Optional[int] = None,
342
  nologo: Optional[bool] = True,
343
  private: Optional[bool] = None,
344
  enhance: Optional[bool] = None,
345
+ request: Request = None,
346
  authenticated: bool = Depends(verify_api_key)
347
  ):
 
 
 
348
  # Validate the image endpoint
349
+ image_endpoint = get_env_vars()['image_endpoint']
350
  if not image_endpoint:
351
  raise HTTPException(status_code=500, detail="Image endpoint not configured in environment variables.")
352
+
353
  usage_tracker.record_request(endpoint="/images/generations")
354
+
355
  # Handle GET and POST prompts
356
  if request.method == "POST":
357
  try:
358
+ body = await request.json()
359
  prompt = body.get("prompt", "").strip()
360
  if not prompt:
361
  raise HTTPException(status_code=400, detail="Prompt cannot be empty")
 
370
  encoded_prompt = httpx.QueryParams({'prompt': prompt}).get('prompt')
371
 
372
  # Construct the URL with the encoded prompt
373
+ base_url = image_endpoint.rstrip('/')
374
  url = f"{base_url}/{encoded_prompt}"
375
 
376
  # Prepare query parameters with validation
377
  params = {}
 
378
  if model and isinstance(model, str):
379
  params['model'] = model
380
  if seed is not None and isinstance(seed, int):
 
391
  params['enhance'] = str(enhance).lower()
392
 
393
  try:
394
+ async with httpx.AsyncClient(timeout=60.0) as client:
 
395
  response = await client.get(url, params=params, follow_redirects=True)
396
 
397
  # Check for various error conditions
398
+ if response.status_code != 200:
399
+ error_messages = {
400
+ 404: "Image generation service not found",
401
+ 400: "Invalid parameters provided to image service",
402
+ 429: "Too many requests to image service",
403
+ }
404
+ detail = error_messages.get(response.status_code, f"Image generation failed with status code {response.status_code}")
405
+ raise HTTPException(status_code=response.status_code, detail=detail)
 
 
 
406
 
407
  # Verify content type
408
  content_type = response.headers.get('content-type', '')
409
  if not content_type.startswith('image/'):
410
  raise HTTPException(
411
  status_code=500,
412
+ detail="Unexpected content type received from image service"
413
  )
414
 
415
  return StreamingResponse(
 
423
 
424
  except httpx.TimeoutException:
425
  raise HTTPException(status_code=504, detail="Image generation request timed out")
426
+ except httpx.RequestError:
427
+ raise HTTPException(status_code=500, detail="Failed to contact image service")
428
+ except Exception:
429
+ raise HTTPException(status_code=500, detail="Unexpected error during image generation")
 
 
 
 
430
 
431
# Meme route
@app.get("/meme")
async def get_meme():
    """Fetch a random meme from meme-api.com and relay the image bytes.

    Returns:
        A ``StreamingResponse`` carrying the image, using the upstream
        content type when present and ``image/png`` otherwise.

    Raises:
        HTTPException: 404 when the meme API reply has no ``url`` field;
            500 for any other failure while contacting the upstream services.
    """
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get("https://meme-api.com/gimme")
            response_data = response.json()

            meme_url = response_data.get("url")
            if not meme_url:
                raise HTTPException(status_code=404, detail="No meme found")

            image_response = await client.get(meme_url, follow_redirects=True)
            return StreamingResponse(
                image_response.iter_bytes(),
                media_type=image_response.headers.get("content-type", "image/png")
            )
    except HTTPException:
        # Keep deliberate HTTP errors (the 404 above) intact instead of
        # letting the generic handler below rewrite them as 500.
        raise
    except Exception:
        raise HTTPException(status_code=500, detail="Failed to retrieve meme")
450
+
 
 
 
451
  @app.get("/usage")
452
  async def get_usage(days: int = 7):
453
  """Retrieve usage statistics"""
 
458
  """Serve an HTML page showing usage statistics"""
459
  # Retrieve usage data
460
  usage_data = usage_tracker.get_usage_summary()
461
+
462
  # Model Usage Table Rows
463
  model_usage_rows = "\n".join([
464
  f"""
 
470
  </tr>
471
  """ for model, model_data in usage_data['models'].items()
472
  ])
473
+
474
  # API Endpoint Usage Table Rows
475
  api_usage_rows = "\n".join([
476
  f"""
 
482
  </tr>
483
  """ for endpoint, endpoint_data in usage_data['api_endpoints'].items()
484
  ])
485
+
486
  # Daily Usage Table Rows
487
  daily_usage_rows = "\n".join([
488
  "\n".join([
 
647
  </html>
648
  """
649
  return HTMLResponse(content=html_content)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
650
 
651
+ # Utility function for loading model IDs
652
def load_model_ids(json_file_path):
    """Load the ``id`` of every model object listed in a JSON file.

    This is a best-effort loader: problems are printed and an empty list is
    returned rather than raising.

    Args:
        json_file_path: Path to a JSON file expected to hold an array of
            objects, each optionally carrying an ``id`` key.

    Returns:
        The ``id`` values in file order. Entries that are not objects or have
        no ``id`` are skipped. An empty list is returned when the file cannot
        be read or parsed, or when its top-level value is not an array.
    """
    try:
        with open(json_file_path, 'r', encoding='utf-8') as f:
            models_data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Error loading model IDs: {str(e)}")
        return []

    if not isinstance(models_data, list):
        print("Error loading model IDs: expected a JSON array of model objects")
        return []

    # Extract 'id' from each model object; the isinstance guard stops one
    # stray string or number from discarding the whole list.
    return [
        model['id']
        for model in models_data
        if isinstance(model, dict) and 'id' in model
    ]
662
+
663
  @app.on_event("startup")
664
  async def startup_event():
665
  global available_model_ids
666
  available_model_ids = load_model_ids("models.json")
667
+ print(f"Loaded {len(available_model_ids)} model IDs")
668
+ print("API started successfully")
 
 
 
 
 
669
 
670
+ # Validate critical environment variables
671
+ env_vars = get_env_vars()
672
+ missing_vars = []
673
+
674
+ if not env_vars['secret_api_endpoint']:
675
+ missing_vars.append('SECRET_API_ENDPOINT')
676
+ if not env_vars['secret_api_endpoint_2']:
677
+ missing_vars.append('SECRET_API_ENDPOINT_2')
678
+ if not env_vars['secret_api_endpoint_3']:
679
+ missing_vars.append('SECRET_API_ENDPOINT_3')
680
+
681
+ if missing_vars:
682
+ print(f"WARNING: The following required environment variables are missing: {', '.join(missing_vars)}")
683
+
684
  if __name__ == "__main__":
685
  import uvicorn
686
+ uvicorn.run(app, host="0.0.0.0", port=8000)