Update main.py
main.py
CHANGED
@@ -10,19 +10,32 @@ import time
 from collections import defaultdict
 from typing import List, Dict, Any, Optional, AsyncGenerator, Union
 from datetime import datetime
+import contextvars  # Added for context variables
 
 from aiohttp import ClientSession, ClientTimeout, ClientError
 from fastapi import FastAPI, HTTPException, Request, Depends, Header
 from fastapi.responses import StreamingResponse, JSONResponse, RedirectResponse
 from pydantic import BaseModel
 
+# Define a context variable for client_ip
+client_ip_var = contextvars.ContextVar("client_ip", default="N/A")
+
+# Custom logging filter to inject client_ip from context variable
+class ContextFilter(logging.Filter):
+    def filter(self, record):
+        record.client_ip = client_ip_var.get()
+        return True
+
 # Configure logging
 logging.basicConfig(
     level=logging.INFO,
     format="%(asctime)s [%(levelname)s] %(name)s [IP: %(client_ip)s]: %(message)s",
     handlers=[logging.StreamHandler()]
 )
+
+# Get the root logger and add the ContextFilter
 logger = logging.getLogger(__name__)
+logger.addFilter(ContextFilter())
 
 # Load environment variables
 API_KEYS = os.getenv('API_KEYS', '').split(',')  # Comma-separated API keys
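The core of this commit is the ContextVar-plus-logging.Filter pattern in the hunk above. Below is a minimal, self-contained sketch of the same idea (not part of the commit; attaching the filter at the handler level is a variation, since the commit attaches it to the module logger only):

import contextvars
import logging

client_ip_var = contextvars.ContextVar("client_ip", default="N/A")

class ContextFilter(logging.Filter):
    def filter(self, record):
        # Copy the per-request value onto every record; return True so
        # no record is ever dropped by this filter.
        record.client_ip = client_ip_var.get()
        return True

handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("[IP: %(client_ip)s] %(message)s"))
handler.addFilter(ContextFilter())  # handler-level: covers all loggers

log = logging.getLogger("demo")
log.addHandler(handler)
log.setLevel(logging.INFO)

log.info("before any request")   # -> [IP: N/A] before any request
client_ip_var.set("203.0.113.7")
log.info("during a request")     # -> [IP: 203.0.113.7] during a request

One caveat with the committed version: the filter is attached only to the `__name__` logger, so a record emitted through any other logger that propagates to the root handler will lack a `client_ip` attribute, and the `%(client_ip)s` directive in the basicConfig format will raise a formatting error for it.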
@@ -377,9 +390,8 @@ app = FastAPI()
 @app.middleware("http")
 async def security_middleware(request: Request, call_next):
     client_ip = request.client.host
-    #
-
-    logger = logging.LoggerAdapter(logging.getLogger(__name__), extra)
+    # Set the client_ip in the context variable
+    client_ip_var.set(client_ip)
 
     # Enforce that POST requests to sensitive endpoints must have a valid Content-Type
     if request.method == "POST" and request.url.path == "/v1/chat/completions":
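Setting the variable in the middleware, before `call_next`, is what makes the value visible to the route handlers: each task spawned to continue the request copies the current context, and writes made in one request's task are invisible to every other. A hypothetical demo of that isolation (names not from the commit):

import asyncio
import contextvars

client_ip_var = contextvars.ContextVar("client_ip", default="N/A")

async def handle(ip: str) -> str:
    client_ip_var.set(ip)       # like the middleware's client_ip_var.set(...)
    await asyncio.sleep(0)      # yield so the two "requests" interleave
    return client_ip_var.get()  # each task still sees only its own value

async def main():
    print(await asyncio.gather(handle("10.0.0.1"), handle("10.0.0.2")))
    # -> ['10.0.0.1', '10.0.0.2']

asyncio.run(main())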
@@ -397,11 +409,15 @@ async def security_middleware(request: Request, call_next):
             }
         },
     )
+
     # Log the incoming request
     logger.info(f"Incoming request: {request.method} {request.url.path}")
+
     response = await call_next(request)
+
     # Log the response status
     logger.info(f"Response status: {response.status_code}")
+
     return response
 
 class Message(BaseModel):
@@ -441,16 +457,15 @@ def create_response(content: str, model: str, finish_reason: Optional[str] = Non
 
 @app.post("/v1/chat/completions", dependencies=[Depends(rate_limiter)])
 async def chat_completions(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
-    client_ip = req.client.host
     # Redact user messages only for logging purposes
     redacted_messages = [{"role": msg.role, "content": "[redacted]"} for msg in request.messages]
 
-    logger.info(f"Received chat completions request from API key: {api_key} |
+    logger.info(f"Received chat completions request from API key: {api_key} | Model: {request.model} | Messages: {redacted_messages}")
 
     try:
         # Validate that the requested model is available
         if request.model not in Blackbox.models and request.model not in Blackbox.model_aliases:
-            logger.warning(f"Attempt to use unavailable model: {request.model}
+            logger.warning(f"Attempt to use unavailable model: {request.model}")
             raise HTTPException(status_code=400, detail="Requested model is not available. | NiansuhAI")
 
         # Process the request with actual message content, but don't log it
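For reference, a hypothetical client call against this endpoint, assuming a local server on port 8000 and that `get_api_key` (whose implementation is not shown in this diff) reads a Bearer token from the Authorization header; the model name is a placeholder:

import asyncio
from aiohttp import ClientSession

async def main():
    payload = {
        "model": "blackbox",  # placeholder; must be in Blackbox.models
        "messages": [{"role": "user", "content": "Hello"}],
    }
    headers = {"Authorization": "Bearer YOUR_API_KEY"}  # assumed auth scheme
    async with ClientSession() as session:
        async with session.post(
            "http://localhost:8000/v1/chat/completions",
            json=payload,
            headers=headers,
        ) as resp:
            print(resp.status, await resp.json())

asyncio.run(main())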
@@ -492,7 +507,7 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
             else:
                 response_content += chunk
 
-        logger.info(f"Completed non-streaming response generation for API key: {api_key} |
+        logger.info(f"Completed non-streaming response generation for API key: {api_key} | Model: {request.model}")
         return {
             "id": f"chatcmpl-{uuid.uuid4()}",
             "object": "chat.completion",
@@ -515,10 +530,10 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
             },
         }
     except ModelNotWorkingException as e:
-        logger.warning(f"Model not working: {e}
+        logger.warning(f"Model not working: {e}")
         raise HTTPException(status_code=503, detail=str(e))
     except HTTPException as he:
-        logger.warning(f"HTTPException: {he.detail}
+        logger.warning(f"HTTPException: {he.detail}")
         raise he
     except Exception as e:
         logger.exception("An unexpected error occurred while processing the chat completions request.")
@@ -535,27 +550,27 @@ async def chat_completions_get():
 @app.get("/v1/models")
 async def get_models(req: Request):
     client_ip = req.client.host
-    logger.info(f"Fetching available models
+    logger.info(f"Fetching available models")
     return {"data": [{"id": model, "object": "model"} for model in Blackbox.models]}
 
 # Additional endpoints for better functionality
 @app.get("/v1/health")
 async def health_check(req: Request):
     client_ip = req.client.host
-    logger.info(f"Health check requested
+    logger.info(f"Health check requested")
     return {"status": "ok"}
 
 @app.get("/v1/models/{model}/status")
 async def model_status(model: str, req: Request):
     client_ip = req.client.host
-    logger.info(f"Model status requested for '{model}'
+    logger.info(f"Model status requested for '{model}'")
     if model in Blackbox.models:
         return {"model": model, "status": "available"}
     elif model in Blackbox.model_aliases and Blackbox.model_aliases[model] in Blackbox.models:
         actual_model = Blackbox.model_aliases[model]
         return {"model": actual_model, "status": "available via alias"}
     else:
-        logger.warning(f"Model not found: {model}
+        logger.warning(f"Model not found: {model}")
         raise HTTPException(status_code=404, detail="Model not found")
 
 # Custom exception handler to match OpenAI's error format
@@ -584,7 +599,7 @@ async def tokenizer(request: TokenizerRequest, req: Request):
     client_ip = req.client.host
     text = request.text
     token_count = len(text.split())
-    logger.info(f"Tokenizer called |
+    logger.info(f"Tokenizer called | Tokens: {token_count}")
     return {"text": text, "tokens": token_count}
 
 if __name__ == "__main__":
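Note that the tokenizer endpoint's count is a plain whitespace split, not a model tokenizer, so punctuation stays attached to words and the figure will diverge from what an actual BPE tokenizer reports:

text = "Hello, world! How are you?"
print(len(text.split()))  # 5 -- "Hello," and "world!" each count as one token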