Update main.py
Browse files
main.py
CHANGED
@@ -39,72 +39,45 @@ if AVAILABLE_MODELS:
|
|
39 |
else:
|
40 |
AVAILABLE_MODELS = [] # If empty, all models are available
|
41 |
|
42 |
-
# Simple in-memory rate limiter
|
43 |
rate_limit_store = defaultdict(lambda: {"count": 0, "timestamp": time.time()})
|
44 |
-
ip_rate_limit_store = defaultdict(lambda: {"count": 0, "timestamp": time.time()})
|
45 |
|
46 |
# Define cleanup interval and window
|
47 |
CLEANUP_INTERVAL = 60 # seconds
|
48 |
RATE_LIMIT_WINDOW = 60 # seconds
|
49 |
|
50 |
async def cleanup_rate_limit_stores():
|
|
|
|
|
|
|
51 |
while True:
|
52 |
current_time = time.time()
|
53 |
-
|
54 |
-
keys_to_delete = [key for key, value in rate_limit_store.items() if current_time - value["timestamp"] > RATE_LIMIT_WINDOW * 2]
|
55 |
-
for key in keys_to_delete:
|
56 |
-
del rate_limit_store[key]
|
57 |
-
logger.debug(f"Cleaned up rate_limit_store for API key: {key}")
|
58 |
-
|
59 |
-
# Clean IP rate limit store
|
60 |
-
ips_to_delete = [ip for ip, value in ip_rate_limit_store.items() if current_time - value["timestamp"] > RATE_LIMIT_WINDOW * 2]
|
61 |
for ip in ips_to_delete:
|
62 |
-
del
|
63 |
-
logger.debug(f"Cleaned up
|
64 |
-
|
65 |
await asyncio.sleep(CLEANUP_INTERVAL)
|
66 |
|
67 |
-
# Define rate limiter for endpoints requiring API key
|
68 |
-
async def rate_limiter_with_api_key(request: Request, api_key: str = Depends(lambda request: get_api_key(request, authorization=request.headers.get('Authorization')))):
|
69 |
-
client_ip = request.client.host
|
70 |
-
current_time = time.time()
|
71 |
-
|
72 |
-
# Rate limiting per API key
|
73 |
-
window_start = rate_limit_store[api_key]["timestamp"]
|
74 |
-
if current_time - window_start > RATE_LIMIT_WINDOW:
|
75 |
-
rate_limit_store[api_key] = {"count": 1, "timestamp": current_time}
|
76 |
-
else:
|
77 |
-
if rate_limit_store[api_key]["count"] >= RATE_LIMIT:
|
78 |
-
logger.warning(f"Rate limit exceeded for API key: {api_key} from IP: {client_ip}")
|
79 |
-
raise HTTPException(status_code=429, detail='Rate limit exceeded for API key')
|
80 |
-
rate_limit_store[api_key]["count"] += 1
|
81 |
-
|
82 |
-
# Rate limiting per IP address
|
83 |
-
window_start_ip = ip_rate_limit_store[client_ip]["timestamp"]
|
84 |
-
if current_time - window_start_ip > RATE_LIMIT_WINDOW:
|
85 |
-
ip_rate_limit_store[client_ip] = {"count": 1, "timestamp": current_time}
|
86 |
-
else:
|
87 |
-
if ip_rate_limit_store[client_ip]["count"] >= RATE_LIMIT:
|
88 |
-
logger.warning(f"Rate limit exceeded for IP address: {client_ip}")
|
89 |
-
raise HTTPException(status_code=429, detail='Rate limit exceeded for IP address')
|
90 |
-
ip_rate_limit_store[client_ip]["count"] += 1
|
91 |
-
|
92 |
-
# Define rate limiter for endpoints NOT requiring API key
|
93 |
async def rate_limiter_per_ip(request: Request):
|
|
|
|
|
|
|
94 |
client_ip = request.client.host
|
95 |
current_time = time.time()
|
96 |
|
97 |
-
#
|
98 |
-
|
99 |
-
|
100 |
-
ip_rate_limit_store[client_ip] = {"count": 1, "timestamp": current_time}
|
101 |
else:
|
102 |
-
if
|
103 |
logger.warning(f"Rate limit exceeded for IP address: {client_ip}")
|
104 |
raise HTTPException(status_code=429, detail='Rate limit exceeded for IP address')
|
105 |
-
|
106 |
|
107 |
async def get_api_key(request: Request, authorization: str = Header(None)) -> str:
|
|
|
|
|
|
|
108 |
client_ip = request.client.host
|
109 |
if authorization is None or not authorization.startswith('Bearer '):
|
110 |
logger.warning(f"Invalid or missing authorization header from IP: {client_ip}")
|
@@ -412,11 +385,17 @@ class Blackbox:
|
|
412 |
# FastAPI app setup
|
413 |
app = FastAPI()
|
414 |
|
415 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
416 |
@app.middleware("http")
|
417 |
async def security_middleware(request: Request, call_next):
|
418 |
client_ip = request.client.host
|
419 |
-
# Enforce that POST requests to
|
420 |
if request.method == "POST" and request.url.path == "/v1/chat/completions":
|
421 |
content_type = request.headers.get("Content-Type")
|
422 |
if content_type != "application/json":
|
@@ -435,6 +414,7 @@ async def security_middleware(request: Request, call_next):
|
|
435 |
response = await call_next(request)
|
436 |
return response
|
437 |
|
|
|
438 |
class Message(BaseModel):
|
439 |
role: str
|
440 |
content: str
|
@@ -470,7 +450,7 @@ def create_response(content: str, model: str, finish_reason: Optional[str] = Non
|
|
470 |
"usage": None,
|
471 |
}
|
472 |
|
473 |
-
@app.post("/v1/chat/completions", dependencies=[Depends(
|
474 |
async def chat_completions(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
|
475 |
client_ip = req.client.host
|
476 |
# Redact user messages only for logging purposes
|
@@ -621,12 +601,6 @@ async def http_exception_handler(request: Request, exc: HTTPException):
|
|
621 |
},
|
622 |
)
|
623 |
|
624 |
-
# Add the cleanup task when the app starts
|
625 |
-
@app.on_event("startup")
|
626 |
-
async def startup_event():
|
627 |
-
asyncio.create_task(cleanup_rate_limit_stores())
|
628 |
-
logger.info("Started rate limit store cleanup task.")
|
629 |
-
|
630 |
if __name__ == "__main__":
|
631 |
import uvicorn
|
632 |
uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|
39 |
else:
|
40 |
AVAILABLE_MODELS = [] # If empty, all models are available
|
41 |
|
42 |
+
# Simple in-memory rate limiter based solely on IP addresses
|
43 |
rate_limit_store = defaultdict(lambda: {"count": 0, "timestamp": time.time()})
|
|
|
44 |
|
45 |
# Define cleanup interval and window
|
46 |
CLEANUP_INTERVAL = 60 # seconds
|
47 |
RATE_LIMIT_WINDOW = 60 # seconds
|
48 |
|
49 |
async def cleanup_rate_limit_stores():
|
50 |
+
"""
|
51 |
+
Periodically cleans up stale entries in the rate_limit_store to prevent memory bloat.
|
52 |
+
"""
|
53 |
while True:
|
54 |
current_time = time.time()
|
55 |
+
ips_to_delete = [ip for ip, value in rate_limit_store.items() if current_time - value["timestamp"] > RATE_LIMIT_WINDOW * 2]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
for ip in ips_to_delete:
|
57 |
+
del rate_limit_store[ip]
|
58 |
+
logger.debug(f"Cleaned up rate_limit_store for IP: {ip}")
|
|
|
59 |
await asyncio.sleep(CLEANUP_INTERVAL)
|
60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
async def rate_limiter_per_ip(request: Request):
|
62 |
+
"""
|
63 |
+
Rate limiter that enforces a limit based on the client's IP address.
|
64 |
+
"""
|
65 |
client_ip = request.client.host
|
66 |
current_time = time.time()
|
67 |
|
68 |
+
# Initialize or update the count and timestamp
|
69 |
+
if current_time - rate_limit_store[client_ip]["timestamp"] > RATE_LIMIT_WINDOW:
|
70 |
+
rate_limit_store[client_ip] = {"count": 1, "timestamp": current_time}
|
|
|
71 |
else:
|
72 |
+
if rate_limit_store[client_ip]["count"] >= RATE_LIMIT:
|
73 |
logger.warning(f"Rate limit exceeded for IP address: {client_ip}")
|
74 |
raise HTTPException(status_code=429, detail='Rate limit exceeded for IP address')
|
75 |
+
rate_limit_store[client_ip]["count"] += 1
|
76 |
|
77 |
async def get_api_key(request: Request, authorization: str = Header(None)) -> str:
|
78 |
+
"""
|
79 |
+
Dependency to extract and validate the API key from the Authorization header.
|
80 |
+
"""
|
81 |
client_ip = request.client.host
|
82 |
if authorization is None or not authorization.startswith('Bearer '):
|
83 |
logger.warning(f"Invalid or missing authorization header from IP: {client_ip}")
|
|
|
385 |
# FastAPI app setup
|
386 |
app = FastAPI()
|
387 |
|
388 |
+
# Add the cleanup task when the app starts
|
389 |
+
@app.on_event("startup")
|
390 |
+
async def startup_event():
|
391 |
+
asyncio.create_task(cleanup_rate_limit_stores())
|
392 |
+
logger.info("Started rate limit store cleanup task.")
|
393 |
+
|
394 |
+
# Middleware to enhance security and enforce Content-Type for specific endpoints
|
395 |
@app.middleware("http")
|
396 |
async def security_middleware(request: Request, call_next):
|
397 |
client_ip = request.client.host
|
398 |
+
# Enforce that POST requests to /v1/chat/completions must have Content-Type: application/json
|
399 |
if request.method == "POST" and request.url.path == "/v1/chat/completions":
|
400 |
content_type = request.headers.get("Content-Type")
|
401 |
if content_type != "application/json":
|
|
|
414 |
response = await call_next(request)
|
415 |
return response
|
416 |
|
417 |
+
# Request Models
|
418 |
class Message(BaseModel):
|
419 |
role: str
|
420 |
content: str
|
|
|
450 |
"usage": None,
|
451 |
}
|
452 |
|
453 |
+
@app.post("/v1/chat/completions", dependencies=[Depends(rate_limiter_per_ip)])
|
454 |
async def chat_completions(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
|
455 |
client_ip = req.client.host
|
456 |
# Redact user messages only for logging purposes
|
|
|
601 |
},
|
602 |
)
|
603 |
|
|
|
|
|
|
|
|
|
|
|
|
|
604 |
if __name__ == "__main__":
|
605 |
import uvicorn
|
606 |
uvicorn.run(app, host="0.0.0.0", port=8000)
|