Update main.py
Browse files
main.py
CHANGED
@@ -64,21 +64,11 @@ async def cleanup_rate_limit_stores():
|
|
64 |
|
65 |
await asyncio.sleep(CLEANUP_INTERVAL)
|
66 |
|
67 |
-
|
68 |
-
|
69 |
-
if authorization is None or not authorization.startswith('Bearer '):
|
70 |
-
logger.warning(f"Invalid or missing authorization header from IP: {client_ip}")
|
71 |
-
raise HTTPException(status_code=401, detail='Invalid authorization header format')
|
72 |
-
api_key = authorization[7:]
|
73 |
-
if api_key not in API_KEYS:
|
74 |
-
logger.warning(f"Invalid API key attempted: {api_key} from IP: {client_ip}")
|
75 |
-
raise HTTPException(status_code=401, detail='Invalid API key')
|
76 |
-
return api_key
|
77 |
-
|
78 |
-
async def rate_limiter(request: Request, api_key: str = Depends(get_api_key)):
|
79 |
client_ip = request.client.host
|
80 |
current_time = time.time()
|
81 |
-
|
82 |
# Rate limiting per API key
|
83 |
window_start = rate_limit_store[api_key]["timestamp"]
|
84 |
if current_time - window_start > RATE_LIMIT_WINDOW:
|
@@ -99,6 +89,32 @@ async def rate_limiter(request: Request, api_key: str = Depends(get_api_key)):
|
|
99 |
raise HTTPException(status_code=429, detail='Rate limit exceeded for IP address')
|
100 |
ip_rate_limit_store[client_ip]["count"] += 1
|
101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
# Custom exception for model not working
|
103 |
class ModelNotWorkingException(Exception):
|
104 |
def __init__(self, model: str):
|
@@ -399,11 +415,11 @@ app = FastAPI()
|
|
399 |
# Middleware to enhance security
|
400 |
@app.middleware("http")
|
401 |
async def security_middleware(request: Request, call_next):
|
|
|
402 |
# Enforce that POST requests to sensitive endpoints must have a valid Content-Type
|
403 |
-
if request.method == "POST" and request.url.path
|
404 |
content_type = request.headers.get("Content-Type")
|
405 |
if content_type != "application/json":
|
406 |
-
client_ip = request.client.host
|
407 |
logger.warning(f"Invalid Content-Type from IP: {client_ip} for path: {request.url.path}")
|
408 |
return JSONResponse(
|
409 |
status_code=400,
|
@@ -454,7 +470,7 @@ def create_response(content: str, model: str, finish_reason: Optional[str] = Non
|
|
454 |
"usage": None,
|
455 |
}
|
456 |
|
457 |
-
@app.post("/v1/chat/completions", dependencies=[Depends(
|
458 |
async def chat_completions(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
|
459 |
client_ip = req.client.host
|
460 |
# Redact user messages only for logging purposes
|
@@ -539,27 +555,29 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
|
|
539 |
logger.exception(f"An unexpected error occurred while processing the chat completions request from IP: {client_ip}.")
|
540 |
raise HTTPException(status_code=500, detail=str(e))
|
541 |
|
542 |
-
#
|
543 |
-
|
544 |
-
|
|
|
|
|
|
|
|
|
|
|
545 |
client_ip = req.client.host
|
546 |
-
|
547 |
-
|
|
|
|
|
548 |
|
549 |
-
|
|
|
550 |
async def get_models(req: Request):
|
551 |
client_ip = req.client.host
|
552 |
logger.info(f"Fetching available models from IP: {client_ip}")
|
553 |
return {"data": [{"id": model, "object": "model"} for model in Blackbox.models]}
|
554 |
|
555 |
-
#
|
556 |
-
@app.get("/v1/
|
557 |
-
async def health_check(req: Request, api_key: str = Depends(get_api_key)):
|
558 |
-
client_ip = req.client.host
|
559 |
-
logger.info(f"Health check requested by API key: {api_key} | IP: {client_ip}")
|
560 |
-
return {"status": "ok"}
|
561 |
-
|
562 |
-
@app.get("/v1/models/{model}/status")
|
563 |
async def model_status(model: str, req: Request):
|
564 |
client_ip = req.client.host
|
565 |
logger.info(f"Model status requested for '{model}' from IP: {client_ip}")
|
@@ -572,6 +590,20 @@ async def model_status(model: str, req: Request):
|
|
572 |
logger.warning(f"Model not found: {model} from IP: {client_ip}")
|
573 |
raise HTTPException(status_code=404, detail="Model not found")
|
574 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
575 |
# Custom exception handler to match OpenAI's error format
|
576 |
@app.exception_handler(HTTPException)
|
577 |
async def http_exception_handler(request: Request, exc: HTTPException):
|
@@ -589,76 +621,6 @@ async def http_exception_handler(request: Request, exc: HTTPException):
|
|
589 |
},
|
590 |
)
|
591 |
|
592 |
-
# New endpoint: /v1/tokenizer to calculate token counts
|
593 |
-
class TokenizerRequest(BaseModel):
|
594 |
-
text: str
|
595 |
-
|
596 |
-
@app.post("/v1/tokenizer")
|
597 |
-
async def tokenizer(request: TokenizerRequest, req: Request, api_key: str = Depends(get_api_key)):
|
598 |
-
client_ip = req.client.host
|
599 |
-
text = request.text
|
600 |
-
token_count = len(text.split())
|
601 |
-
logger.info(f"Tokenizer requested by API key: {api_key} | IP: {client_ip} | Text length: {len(text)}")
|
602 |
-
return {"text": text, "tokens": token_count}
|
603 |
-
|
604 |
-
# New endpoint: /v1/completions to support text completions
|
605 |
-
class CompletionRequest(BaseModel):
|
606 |
-
model: str
|
607 |
-
prompt: str
|
608 |
-
max_tokens: Optional[int] = 16
|
609 |
-
temperature: Optional[float] = 1.0
|
610 |
-
top_p: Optional[float] = 1.0
|
611 |
-
n: Optional[int] = 1
|
612 |
-
stream: Optional[bool] = False
|
613 |
-
stop: Optional[Union[str, List[str]]] = None
|
614 |
-
logprobs: Optional[int] = None
|
615 |
-
echo: Optional[bool] = False
|
616 |
-
presence_penalty: Optional[float] = 0.0
|
617 |
-
frequency_penalty: Optional[float] = 0.0
|
618 |
-
best_of: Optional[int] = 1
|
619 |
-
logit_bias: Optional[Dict[str, float]] = None
|
620 |
-
user: Optional[str] = None
|
621 |
-
|
622 |
-
@app.post("/v1/completions", dependencies=[Depends(rate_limiter)])
|
623 |
-
async def completions(request: CompletionRequest, req: Request, api_key: str = Depends(get_api_key)):
|
624 |
-
client_ip = req.client.host
|
625 |
-
logger.info(f"Received completion request from API key: {api_key} | IP: {client_ip} | Model: {request.model}")
|
626 |
-
|
627 |
-
try:
|
628 |
-
# Validate that the requested model is available
|
629 |
-
if request.model not in Blackbox.models and request.model not in Blackbox.model_aliases:
|
630 |
-
logger.warning(f"Attempt to use unavailable model: {request.model} from IP: {client_ip}")
|
631 |
-
raise HTTPException(status_code=400, detail="Requested model is not available.")
|
632 |
-
|
633 |
-
# Simulate a simple completion by echoing the prompt
|
634 |
-
completion_text = f"{request.prompt} [Completed by {request.model}]"
|
635 |
-
|
636 |
-
return {
|
637 |
-
"id": f"cmpl-{uuid.uuid4()}",
|
638 |
-
"object": "text_completion",
|
639 |
-
"created": int(datetime.now().timestamp()),
|
640 |
-
"model": request.model,
|
641 |
-
"choices": [
|
642 |
-
{
|
643 |
-
"text": completion_text,
|
644 |
-
"index": 0,
|
645 |
-
"logprobs": None,
|
646 |
-
"finish_reason": "length"
|
647 |
-
}
|
648 |
-
],
|
649 |
-
"usage": {
|
650 |
-
"prompt_tokens": len(request.prompt.split()),
|
651 |
-
"completion_tokens": len(completion_text.split()),
|
652 |
-
"total_tokens": len(request.prompt.split()) + len(completion_text.split())
|
653 |
-
}
|
654 |
-
}
|
655 |
-
except HTTPException as he:
|
656 |
-
logger.warning(f"HTTPException: {he.detail} | IP: {client_ip}")
|
657 |
-
raise he
|
658 |
-
except Exception as e:
|
659 |
-
logger.exception(f"An unexpected error occurred while processing the completions request from IP: {client_ip}.")
|
660 |
-
raise HTTPException(status_code=500, detail=str(e))
|
661 |
-
|
662 |
# Add the cleanup task when the app starts
|
663 |
@app.on_event("startup")
|
664 |
async def startup_event():
|
|
|
64 |
|
65 |
await asyncio.sleep(CLEANUP_INTERVAL)
|
66 |
|
67 |
+
# Define rate limiter for endpoints requiring API key
|
68 |
+
async def rate_limiter_with_api_key(request: Request, api_key: str = Depends(lambda request: get_api_key(request, authorization=request.headers.get('Authorization')))):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
client_ip = request.client.host
|
70 |
current_time = time.time()
|
71 |
+
|
72 |
# Rate limiting per API key
|
73 |
window_start = rate_limit_store[api_key]["timestamp"]
|
74 |
if current_time - window_start > RATE_LIMIT_WINDOW:
|
|
|
89 |
raise HTTPException(status_code=429, detail='Rate limit exceeded for IP address')
|
90 |
ip_rate_limit_store[client_ip]["count"] += 1
|
91 |
|
92 |
+
# Define rate limiter for endpoints NOT requiring API key
|
93 |
+
async def rate_limiter_per_ip(request: Request):
|
94 |
+
client_ip = request.client.host
|
95 |
+
current_time = time.time()
|
96 |
+
|
97 |
+
# Rate limiting per IP address
|
98 |
+
window_start_ip = ip_rate_limit_store[client_ip]["timestamp"]
|
99 |
+
if current_time - window_start_ip > RATE_LIMIT_WINDOW:
|
100 |
+
ip_rate_limit_store[client_ip] = {"count": 1, "timestamp": current_time}
|
101 |
+
else:
|
102 |
+
if ip_rate_limit_store[client_ip]["count"] >= RATE_LIMIT:
|
103 |
+
logger.warning(f"Rate limit exceeded for IP address: {client_ip}")
|
104 |
+
raise HTTPException(status_code=429, detail='Rate limit exceeded for IP address')
|
105 |
+
ip_rate_limit_store[client_ip]["count"] += 1
|
106 |
+
|
107 |
+
async def get_api_key(request: Request, authorization: str = Header(None)) -> str:
|
108 |
+
client_ip = request.client.host
|
109 |
+
if authorization is None or not authorization.startswith('Bearer '):
|
110 |
+
logger.warning(f"Invalid or missing authorization header from IP: {client_ip}")
|
111 |
+
raise HTTPException(status_code=401, detail='Invalid authorization header format')
|
112 |
+
api_key = authorization[7:]
|
113 |
+
if api_key not in API_KEYS:
|
114 |
+
logger.warning(f"Invalid API key attempted: {api_key} from IP: {client_ip}")
|
115 |
+
raise HTTPException(status_code=401, detail='Invalid API key')
|
116 |
+
return api_key
|
117 |
+
|
118 |
# Custom exception for model not working
|
119 |
class ModelNotWorkingException(Exception):
|
120 |
def __init__(self, model: str):
|
|
|
415 |
# Middleware to enhance security
|
416 |
@app.middleware("http")
|
417 |
async def security_middleware(request: Request, call_next):
|
418 |
+
client_ip = request.client.host
|
419 |
# Enforce that POST requests to sensitive endpoints must have a valid Content-Type
|
420 |
+
if request.method == "POST" and request.url.path == "/v1/chat/completions":
|
421 |
content_type = request.headers.get("Content-Type")
|
422 |
if content_type != "application/json":
|
|
|
423 |
logger.warning(f"Invalid Content-Type from IP: {client_ip} for path: {request.url.path}")
|
424 |
return JSONResponse(
|
425 |
status_code=400,
|
|
|
470 |
"usage": None,
|
471 |
}
|
472 |
|
473 |
+
@app.post("/v1/chat/completions", dependencies=[Depends(rate_limiter_with_api_key)])
|
474 |
async def chat_completions(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
|
475 |
client_ip = req.client.host
|
476 |
# Redact user messages only for logging purposes
|
|
|
555 |
logger.exception(f"An unexpected error occurred while processing the chat completions request from IP: {client_ip}.")
|
556 |
raise HTTPException(status_code=500, detail=str(e))
|
557 |
|
558 |
+
# Re-added endpoints without API key authentication
|
559 |
+
|
560 |
+
# Endpoint: POST /v1/tokenizer
|
561 |
+
class TokenizerRequest(BaseModel):
|
562 |
+
text: str
|
563 |
+
|
564 |
+
@app.post("/v1/tokenizer", dependencies=[Depends(rate_limiter_per_ip)])
|
565 |
+
async def tokenizer(request: TokenizerRequest, req: Request):
|
566 |
client_ip = req.client.host
|
567 |
+
text = request.text
|
568 |
+
token_count = len(text.split())
|
569 |
+
logger.info(f"Tokenizer requested from IP: {client_ip} | Text length: {len(text)}")
|
570 |
+
return {"text": text, "tokens": token_count}
|
571 |
|
572 |
+
# Endpoint: GET /v1/models
|
573 |
+
@app.get("/v1/models", dependencies=[Depends(rate_limiter_per_ip)])
|
574 |
async def get_models(req: Request):
|
575 |
client_ip = req.client.host
|
576 |
logger.info(f"Fetching available models from IP: {client_ip}")
|
577 |
return {"data": [{"id": model, "object": "model"} for model in Blackbox.models]}
|
578 |
|
579 |
+
# Endpoint: GET /v1/models/{model}/status
|
580 |
+
@app.get("/v1/models/{model}/status", dependencies=[Depends(rate_limiter_per_ip)])
|
|
|
|
|
|
|
|
|
|
|
|
|
581 |
async def model_status(model: str, req: Request):
|
582 |
client_ip = req.client.host
|
583 |
logger.info(f"Model status requested for '{model}' from IP: {client_ip}")
|
|
|
590 |
logger.warning(f"Model not found: {model} from IP: {client_ip}")
|
591 |
raise HTTPException(status_code=404, detail="Model not found")
|
592 |
|
593 |
+
# Endpoint: GET /v1/health
|
594 |
+
@app.get("/v1/health", dependencies=[Depends(rate_limiter_per_ip)])
|
595 |
+
async def health_check(req: Request):
|
596 |
+
client_ip = req.client.host
|
597 |
+
logger.info(f"Health check requested from IP: {client_ip}")
|
598 |
+
return {"status": "ok"}
|
599 |
+
|
600 |
+
# Endpoint: GET /v1/chat/completions (GET method)
|
601 |
+
@app.get("/v1/chat/completions")
|
602 |
+
async def chat_completions_get(req: Request):
|
603 |
+
client_ip = req.client.host
|
604 |
+
logger.info(f"GET request made to /v1/chat/completions from IP: {client_ip}, redirecting to 'about:blank'")
|
605 |
+
return RedirectResponse(url='about:blank')
|
606 |
+
|
607 |
# Custom exception handler to match OpenAI's error format
|
608 |
@app.exception_handler(HTTPException)
|
609 |
async def http_exception_handler(request: Request, exc: HTTPException):
|
|
|
621 |
},
|
622 |
)
|
623 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
624 |
# Add the cleanup task when the app starts
|
625 |
@app.on_event("startup")
|
626 |
async def startup_event():
|