Niansuh committed on
Commit
e813605
·
verified ·
1 Parent(s): 25032a9

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +28 -54
main.py CHANGED
@@ -39,72 +39,45 @@ if AVAILABLE_MODELS:
39
  else:
40
  AVAILABLE_MODELS = [] # If empty, all models are available
41
 
42
- # Simple in-memory rate limiter
43
  rate_limit_store = defaultdict(lambda: {"count": 0, "timestamp": time.time()})
44
- ip_rate_limit_store = defaultdict(lambda: {"count": 0, "timestamp": time.time()})
45
 
46
  # Define cleanup interval and window
47
  CLEANUP_INTERVAL = 60 # seconds
48
  RATE_LIMIT_WINDOW = 60 # seconds
49
 
50
  async def cleanup_rate_limit_stores():
 
 
 
51
  while True:
52
  current_time = time.time()
53
- # Clean API key rate limit store
54
- keys_to_delete = [key for key, value in rate_limit_store.items() if current_time - value["timestamp"] > RATE_LIMIT_WINDOW * 2]
55
- for key in keys_to_delete:
56
- del rate_limit_store[key]
57
- logger.debug(f"Cleaned up rate_limit_store for API key: {key}")
58
-
59
- # Clean IP rate limit store
60
- ips_to_delete = [ip for ip, value in ip_rate_limit_store.items() if current_time - value["timestamp"] > RATE_LIMIT_WINDOW * 2]
61
  for ip in ips_to_delete:
62
- del ip_rate_limit_store[ip]
63
- logger.debug(f"Cleaned up ip_rate_limit_store for IP: {ip}")
64
-
65
  await asyncio.sleep(CLEANUP_INTERVAL)
66
 
67
- # Define rate limiter for endpoints requiring API key
68
- async def rate_limiter_with_api_key(request: Request, api_key: str = Depends(lambda request: get_api_key(request, authorization=request.headers.get('Authorization')))):
69
- client_ip = request.client.host
70
- current_time = time.time()
71
-
72
- # Rate limiting per API key
73
- window_start = rate_limit_store[api_key]["timestamp"]
74
- if current_time - window_start > RATE_LIMIT_WINDOW:
75
- rate_limit_store[api_key] = {"count": 1, "timestamp": current_time}
76
- else:
77
- if rate_limit_store[api_key]["count"] >= RATE_LIMIT:
78
- logger.warning(f"Rate limit exceeded for API key: {api_key} from IP: {client_ip}")
79
- raise HTTPException(status_code=429, detail='Rate limit exceeded for API key')
80
- rate_limit_store[api_key]["count"] += 1
81
-
82
- # Rate limiting per IP address
83
- window_start_ip = ip_rate_limit_store[client_ip]["timestamp"]
84
- if current_time - window_start_ip > RATE_LIMIT_WINDOW:
85
- ip_rate_limit_store[client_ip] = {"count": 1, "timestamp": current_time}
86
- else:
87
- if ip_rate_limit_store[client_ip]["count"] >= RATE_LIMIT:
88
- logger.warning(f"Rate limit exceeded for IP address: {client_ip}")
89
- raise HTTPException(status_code=429, detail='Rate limit exceeded for IP address')
90
- ip_rate_limit_store[client_ip]["count"] += 1
91
-
92
- # Define rate limiter for endpoints NOT requiring API key
93
  async def rate_limiter_per_ip(request: Request):
 
 
 
94
  client_ip = request.client.host
95
  current_time = time.time()
96
 
97
- # Rate limiting per IP address
98
- window_start_ip = ip_rate_limit_store[client_ip]["timestamp"]
99
- if current_time - window_start_ip > RATE_LIMIT_WINDOW:
100
- ip_rate_limit_store[client_ip] = {"count": 1, "timestamp": current_time}
101
  else:
102
- if ip_rate_limit_store[client_ip]["count"] >= RATE_LIMIT:
103
  logger.warning(f"Rate limit exceeded for IP address: {client_ip}")
104
  raise HTTPException(status_code=429, detail='Rate limit exceeded for IP address')
105
- ip_rate_limit_store[client_ip]["count"] += 1
106
 
107
  async def get_api_key(request: Request, authorization: str = Header(None)) -> str:
 
 
 
108
  client_ip = request.client.host
109
  if authorization is None or not authorization.startswith('Bearer '):
110
  logger.warning(f"Invalid or missing authorization header from IP: {client_ip}")
@@ -412,11 +385,17 @@ class Blackbox:
412
  # FastAPI app setup
413
  app = FastAPI()
414
 
415
- # Middleware to enhance security
 
 
 
 
 
 
416
  @app.middleware("http")
417
  async def security_middleware(request: Request, call_next):
418
  client_ip = request.client.host
419
- # Enforce that POST requests to sensitive endpoints must have a valid Content-Type
420
  if request.method == "POST" and request.url.path == "/v1/chat/completions":
421
  content_type = request.headers.get("Content-Type")
422
  if content_type != "application/json":
@@ -435,6 +414,7 @@ async def security_middleware(request: Request, call_next):
435
  response = await call_next(request)
436
  return response
437
 
 
438
  class Message(BaseModel):
439
  role: str
440
  content: str
@@ -470,7 +450,7 @@ def create_response(content: str, model: str, finish_reason: Optional[str] = Non
470
  "usage": None,
471
  }
472
 
473
- @app.post("/v1/chat/completions", dependencies=[Depends(rate_limiter_with_api_key)])
474
  async def chat_completions(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
475
  client_ip = req.client.host
476
  # Redact user messages only for logging purposes
@@ -621,12 +601,6 @@ async def http_exception_handler(request: Request, exc: HTTPException):
621
  },
622
  )
623
 
624
- # Add the cleanup task when the app starts
625
- @app.on_event("startup")
626
- async def startup_event():
627
- asyncio.create_task(cleanup_rate_limit_stores())
628
- logger.info("Started rate limit store cleanup task.")
629
-
630
  if __name__ == "__main__":
631
  import uvicorn
632
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
39
  else:
40
  AVAILABLE_MODELS = [] # If empty, all models are available
41
 
42
+ # Simple in-memory rate limiter based solely on IP addresses
43
  rate_limit_store = defaultdict(lambda: {"count": 0, "timestamp": time.time()})
 
44
 
45
  # Define cleanup interval and window
46
  CLEANUP_INTERVAL = 60 # seconds
47
  RATE_LIMIT_WINDOW = 60 # seconds
48
 
49
  async def cleanup_rate_limit_stores():
50
+ """
51
+ Periodically cleans up stale entries in the rate_limit_store to prevent memory bloat.
52
+ """
53
  while True:
54
  current_time = time.time()
55
+ ips_to_delete = [ip for ip, value in rate_limit_store.items() if current_time - value["timestamp"] > RATE_LIMIT_WINDOW * 2]
 
 
 
 
 
 
 
56
  for ip in ips_to_delete:
57
+ del rate_limit_store[ip]
58
+ logger.debug(f"Cleaned up rate_limit_store for IP: {ip}")
 
59
  await asyncio.sleep(CLEANUP_INTERVAL)
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  async def rate_limiter_per_ip(request: Request):
62
+ """
63
+ Rate limiter that enforces a limit based on the client's IP address.
64
+ """
65
  client_ip = request.client.host
66
  current_time = time.time()
67
 
68
+ # Initialize or update the count and timestamp
69
+ if current_time - rate_limit_store[client_ip]["timestamp"] > RATE_LIMIT_WINDOW:
70
+ rate_limit_store[client_ip] = {"count": 1, "timestamp": current_time}
 
71
  else:
72
+ if rate_limit_store[client_ip]["count"] >= RATE_LIMIT:
73
  logger.warning(f"Rate limit exceeded for IP address: {client_ip}")
74
  raise HTTPException(status_code=429, detail='Rate limit exceeded for IP address')
75
+ rate_limit_store[client_ip]["count"] += 1
76
 
77
  async def get_api_key(request: Request, authorization: str = Header(None)) -> str:
78
+ """
79
+ Dependency to extract and validate the API key from the Authorization header.
80
+ """
81
  client_ip = request.client.host
82
  if authorization is None or not authorization.startswith('Bearer '):
83
  logger.warning(f"Invalid or missing authorization header from IP: {client_ip}")
 
385
  # FastAPI app setup
386
  app = FastAPI()
387
 
388
+ # Add the cleanup task when the app starts
389
@app.on_event("startup")
async def startup_event():
    """
    Launch the rate-limit store cleanup loop when the application starts.

    The created task is stored on app.state because the event loop holds
    only weak references to tasks; without a strong reference the cleanup
    task could be garbage-collected while it is still running.
    """
    app.state.rate_limit_cleanup_task = asyncio.create_task(cleanup_rate_limit_stores())
    logger.info("Started rate limit store cleanup task.")
393
+
394
+ # Middleware to enhance security and enforce Content-Type for specific endpoints
395
  @app.middleware("http")
396
  async def security_middleware(request: Request, call_next):
397
  client_ip = request.client.host
398
+ # Enforce that POST requests to /v1/chat/completions must have Content-Type: application/json
399
  if request.method == "POST" and request.url.path == "/v1/chat/completions":
400
  content_type = request.headers.get("Content-Type")
401
  if content_type != "application/json":
 
414
  response = await call_next(request)
415
  return response
416
 
417
+ # Request Models
418
  class Message(BaseModel):
419
  role: str
420
  content: str
 
450
  "usage": None,
451
  }
452
 
453
+ @app.post("/v1/chat/completions", dependencies=[Depends(rate_limiter_per_ip)])
454
  async def chat_completions(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
455
  client_ip = req.client.host
456
  # Redact user messages only for logging purposes
 
601
  },
602
  )
603
 
 
 
 
 
 
 
604
  if __name__ == "__main__":
605
  import uvicorn
606
  uvicorn.run(app, host="0.0.0.0", port=8000)