ParthSadaria commited on
Commit
dc21031
·
verified ·
1 Parent(s): 005b429

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +43 -6
main.py CHANGED
@@ -17,7 +17,44 @@ import time
17
  from usage_tracker import UsageTracker
18
  from starlette.middleware.base import BaseHTTPMiddleware
19
  from collections import defaultdict
 
 
 
 
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  class RateLimitMiddleware(BaseHTTPMiddleware):
22
  def __init__(self, app, requests_per_second: int = 2):
23
  super().__init__(app)
@@ -62,7 +99,6 @@ app = FastAPI()
62
  app.add_middleware(RateLimitMiddleware, requests_per_second=2)
63
 
64
  # Get API keys and secret endpoint from environment variables
65
- api_keys_str = os.getenv('API_KEYS') #deprecated -_-
66
  valid_api_keys = api_keys_str.split(',') if api_keys_str else []
67
  secret_api_endpoint = os.getenv('SECRET_API_ENDPOINT')
68
  secret_api_endpoint_2 = os.getenv('SECRET_API_ENDPOINT_2')
@@ -75,7 +111,7 @@ if not secret_api_endpoint or not secret_api_endpoint_2 or not secret_api_endpoi
75
  raise HTTPException(status_code=500, detail="API endpoint(s) are not configured in environment variables.")
76
 
77
  # Define models that should use the secondary endpoint
78
- alternate_models = {"gpt-4o-mini", "claude-3-haiku", "llama-3.1-70b", "mixtral-8x7b"}
79
 
80
  available_model_ids = []
81
  class Payload(BaseModel):
@@ -154,7 +190,7 @@ async def ping():
154
  return {"message": "pong", "response_time": f"{response_time:.6f} seconds"}
155
 
156
  @app.get("/searchgpt")
157
- async def search_gpt(q: str, stream: Optional[bool] = False, systemprompt: Optional[str] = None):
158
  if not q:
159
  raise HTTPException(status_code=400, detail="Query parameter 'q' is required")
160
  usage_tracker.record_request(endpoint="/searchgpt")
@@ -191,12 +227,12 @@ async def get_models():
191
  raise HTTPException(status_code=500, detail="Error decoding models.json")
192
  @app.get("api/v1/models")
193
  @app.get("/models")
194
- async def fetch_models():
195
  return await get_models()
196
  server_status = True
197
  @app.post("/chat/completions")
198
  @app.post("api/v1/chat/completions")
199
- async def get_completion(payload: Payload, request: Request):
200
  # Check server status
201
 
202
 
@@ -216,7 +252,7 @@ async def get_completion(payload: Payload, request: Request):
216
  payload_dict["model"] = model_to_use
217
  # payload_dict["stream"] = payload_dict.get("stream", False)
218
  # Select the appropriate endpoint
219
- endpoint = secret_api_endpoint_2 if model_to_use in alternate_models else secret_api_endpoint
220
 
221
  # Current time and IP logging
222
  current_time = (datetime.datetime.utcnow() + datetime.timedelta(hours=5, minutes=30)).strftime("%Y-%m-%d %I:%M:%S %p")
@@ -286,6 +322,7 @@ async def generate_image(
286
  private: Optional[bool] = None,
287
  enhance: Optional[bool] = None,
288
  request: Request = None, # Access raw POST data
 
289
  ):
290
  """
291
  Generate an image using the Image Generation API.
 
17
  from usage_tracker import UsageTracker
18
  from starlette.middleware.base import BaseHTTPMiddleware
19
  from collections import defaultdict
20
+ from fastapi import Security #new
21
+ from fastapi.security import APIKeyHeader
22
+ from starlette.exceptions import HTTPException
23
+ from starlette.status import HTTP_403_FORBIDDEN
24
 
25
+ # API key header scheme
26
+ api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
27
+
28
+ # Function to validate API key
29
+ async def verify_api_key(api_key: str = Security(api_key_header)) -> bool:
30
+ if not api_key:
31
+ raise HTTPException(
32
+ status_code=HTTP_403_FORBIDDEN,
33
+ detail="No API key provided"
34
+ )
35
+
36
+ # Clean the API key by removing 'Bearer ' if present
37
+ if api_key.startswith('Bearer '):
38
+ api_key = api_key[7:] # Remove 'Bearer ' prefix
39
+
40
+ # Get API keys from environment
41
+ api_keys_str = os.getenv('API_KEYS')
42
+ if not api_keys_str:
43
+ raise HTTPException(
44
+ status_code=HTTP_403_FORBIDDEN,
45
+ detail="API keys not configured on server"
46
+ )
47
+
48
+ valid_api_keys = api_keys_str.split(',')
49
+
50
+ # Check if the provided key is valid
51
+ if api_key not in valid_api_keys:
52
+ raise HTTPException(
53
+ status_code=HTTP_403_FORBIDDEN,
54
+ detail="Invalid API key"
55
+ )
56
+
57
+ return True
58
  class RateLimitMiddleware(BaseHTTPMiddleware):
59
  def __init__(self, app, requests_per_second: int = 2):
60
  super().__init__(app)
 
99
  app.add_middleware(RateLimitMiddleware, requests_per_second=2)
100
 
101
  # Get API keys and secret endpoint from environment variables
 
102
  valid_api_keys = api_keys_str.split(',') if api_keys_str else []
103
  secret_api_endpoint = os.getenv('SECRET_API_ENDPOINT')
104
  secret_api_endpoint_2 = os.getenv('SECRET_API_ENDPOINT_2')
 
111
  raise HTTPException(status_code=500, detail="API endpoint(s) are not configured in environment variables.")
112
 
113
  # Define models that should use the secondary endpoint
114
+ # alternate_models = {"gpt-4o-mini", "claude-3-haiku", "llama-3.1-70b", "mixtral-8x7b"}
115
 
116
  available_model_ids = []
117
  class Payload(BaseModel):
 
190
  return {"message": "pong", "response_time": f"{response_time:.6f} seconds"}
191
 
192
  @app.get("/searchgpt")
193
+ async def search_gpt(q: str, stream: Optional[bool] = False, systemprompt: Optional[str] = None,authenticated: bool = Depends(verify_api_key)):
194
  if not q:
195
  raise HTTPException(status_code=400, detail="Query parameter 'q' is required")
196
  usage_tracker.record_request(endpoint="/searchgpt")
 
227
  raise HTTPException(status_code=500, detail="Error decoding models.json")
228
  @app.get("api/v1/models")
229
  @app.get("/models")
230
+ async def fetch_models(authenticated: bool = Depends(verify_api_key)):
231
  return await get_models()
232
  server_status = True
233
  @app.post("/chat/completions")
234
  @app.post("api/v1/chat/completions")
235
+ async def get_completion(payload: Payload, request: Request,authenticated: bool = Depends(verify_api_key)):
236
  # Check server status
237
 
238
 
 
252
  payload_dict["model"] = model_to_use
253
  # payload_dict["stream"] = payload_dict.get("stream", False)
254
  # Select the appropriate endpoint
255
+ endpoint = secret_api_endpoint
256
 
257
  # Current time and IP logging
258
  current_time = (datetime.datetime.utcnow() + datetime.timedelta(hours=5, minutes=30)).strftime("%Y-%m-%d %I:%M:%S %p")
 
322
  private: Optional[bool] = None,
323
  enhance: Optional[bool] = None,
324
  request: Request = None, # Access raw POST data
325
+ authenticated: bool = Depends(verify_api_key)
326
  ):
327
  """
328
  Generate an image using the Image Generation API.