Niansuh committed on
Commit 432ff41 · verified · 1 Parent(s): 4a435c4

Update main.py

Files changed (1)
  1. main.py +29 -14
main.py CHANGED
@@ -10,19 +10,32 @@ import time
from collections import defaultdict
from typing import List, Dict, Any, Optional, AsyncGenerator, Union
from datetime import datetime
+ import contextvars  # Added for context variables

from aiohttp import ClientSession, ClientTimeout, ClientError
from fastapi import FastAPI, HTTPException, Request, Depends, Header
from fastapi.responses import StreamingResponse, JSONResponse, RedirectResponse
from pydantic import BaseModel

+ # Define a context variable for client_ip
+ client_ip_var = contextvars.ContextVar("client_ip", default="N/A")
+
+ # Custom logging filter to inject client_ip from the context variable
+ class ContextFilter(logging.Filter):
+     def filter(self, record):
+         record.client_ip = client_ip_var.get()
+         return True
+
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s [IP: %(client_ip)s]: %(message)s",
    handlers=[logging.StreamHandler()]
)
+
+ # Get the module logger and add the ContextFilter
logger = logging.getLogger(__name__)
+ logger.addFilter(ContextFilter())

# Load environment variables
API_KEYS = os.getenv('API_KEYS', '').split(',')  # Comma-separated API keys
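The filter-plus-ContextVar pair above is the core of the commit. A minimal, self-contained sketch of the same pattern (the names mirror the commit; the sample IP is invented) shows how every record logged through the filtered logger picks up whatever the current context holds:

import contextvars
import logging

client_ip_var = contextvars.ContextVar("client_ip", default="N/A")

class ContextFilter(logging.Filter):
    def filter(self, record):
        # Stamp each record with the client_ip from the current context
        record.client_ip = client_ip_var.get()
        return True

logging.basicConfig(format="[IP: %(client_ip)s] %(message)s")
log = logging.getLogger("demo")
log.addFilter(ContextFilter())

log.warning("before set")         # -> [IP: N/A] before set
client_ip_var.set("203.0.113.7")
log.warning("after set")          # -> [IP: 203.0.113.7] after set

One caveat: the filter is attached to this one logger, so records from any other logger that propagate to the same root handler will lack a client_ip attribute, and the %(client_ip)s directive will fail to format them.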
@@ -377,9 +390,8 @@ app = FastAPI()
@app.middleware("http")
async def security_middleware(request: Request, call_next):
    client_ip = request.client.host
-     # Enrich the logger with client_ip
-     extra = {"client_ip": client_ip}
-     logger = logging.LoggerAdapter(logging.getLogger(__name__), extra)
+     # Set the client_ip in the context variable
+     client_ip_var.set(client_ip)

    # Enforce that POST requests to sensitive endpoints must have a valid Content-Type
    if request.method == "POST" and request.url.path == "/v1/chat/completions":
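The removed LoggerAdapter was rebound to a local name on every request and never reached the module-level logger the handlers actually use; a ContextVar fixes both problems, since asyncio gives each task its own copy of the context. A tiny demo of that isolation (illustrative names, made-up IPs):

import asyncio
import contextvars

request_ip = contextvars.ContextVar("request_ip", default="N/A")

async def handle(ip: str) -> None:
    request_ip.set(ip)             # set inside this task's private context
    await asyncio.sleep(0.01)      # yield so the two tasks interleave
    assert request_ip.get() == ip  # each task still sees only its own value

async def main() -> None:
    await asyncio.gather(handle("198.51.100.1"), handle("198.51.100.2"))

asyncio.run(main())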
@@ -397,11 +409,15 @@ async def security_middleware(request: Request, call_next):
                }
            },
        )
+
    # Log the incoming request
    logger.info(f"Incoming request: {request.method} {request.url.path}")
+
    response = await call_next(request)
+
    # Log the response status
    logger.info(f"Response status: {response.status_code}")
+
    return response

class Message(BaseModel):
@@ -441,16 +457,15 @@ def create_response(content: str, model: str, finish_reason: Optional[str] = None):

@app.post("/v1/chat/completions", dependencies=[Depends(rate_limiter)])
async def chat_completions(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
-     client_ip = req.client.host
    # Redact user messages only for logging purposes
    redacted_messages = [{"role": msg.role, "content": "[redacted]"} for msg in request.messages]

-     logger.info(f"Received chat completions request from API key: {api_key} | Client IP: {client_ip} | Model: {request.model} | Messages: {redacted_messages}")
+     logger.info(f"Received chat completions request from API key: {api_key} | Model: {request.model} | Messages: {redacted_messages}")

    try:
        # Validate that the requested model is available
        if request.model not in Blackbox.models and request.model not in Blackbox.model_aliases:
-             logger.warning(f"Attempt to use unavailable model: {request.model} | Client IP: {client_ip}")
+             logger.warning(f"Attempt to use unavailable model: {request.model}")
            raise HTTPException(status_code=400, detail="Requested model is not available. | NiansuhAI")

        # Process the request with actual message content, but don't log it
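The remaining hunks are mechanical: each handler drops its `| Client IP: {client_ip}` message suffix (and, in this handler, the now-redundant local lookup), since the filter already renders the IP in the record prefix; note the other handlers below keep their now-unused `client_ip = req.client.host` locals. Illustratively, with the configured format (timestamp and IP invented, logger name abbreviated), such a line now renders as:

2025-01-01 12:00:00,000 [INFO] main [IP: 198.51.100.1]: Attempt to use unavailable model: some-model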
@@ -492,7 +507,7 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
        else:
            response_content += chunk

-         logger.info(f"Completed non-streaming response generation for API key: {api_key} | Client IP: {client_ip}")
+         logger.info(f"Completed non-streaming response generation for API key: {api_key} | Model: {request.model}")
        return {
            "id": f"chatcmpl-{uuid.uuid4()}",
            "object": "chat.completion",
@@ -515,10 +530,10 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
            },
        }
    except ModelNotWorkingException as e:
-         logger.warning(f"Model not working: {e} | Client IP: {client_ip}")
+         logger.warning(f"Model not working: {e}")
        raise HTTPException(status_code=503, detail=str(e))
    except HTTPException as he:
-         logger.warning(f"HTTPException: {he.detail} | Client IP: {client_ip}")
+         logger.warning(f"HTTPException: {he.detail}")
        raise he
    except Exception as e:
        logger.exception("An unexpected error occurred while processing the chat completions request.")
@@ -535,27 +550,27 @@ async def chat_completions_get():
@app.get("/v1/models")
async def get_models(req: Request):
    client_ip = req.client.host
-     logger.info(f"Fetching available models | Client IP: {client_ip}")
+     logger.info(f"Fetching available models")
    return {"data": [{"id": model, "object": "model"} for model in Blackbox.models]}

# Additional endpoints for better functionality
@app.get("/v1/health")
async def health_check(req: Request):
    client_ip = req.client.host
-     logger.info(f"Health check requested | Client IP: {client_ip}")
+     logger.info(f"Health check requested")
    return {"status": "ok"}

@app.get("/v1/models/{model}/status")
async def model_status(model: str, req: Request):
    client_ip = req.client.host
-     logger.info(f"Model status requested for '{model}' | Client IP: {client_ip}")
+     logger.info(f"Model status requested for '{model}'")
    if model in Blackbox.models:
        return {"model": model, "status": "available"}
    elif model in Blackbox.model_aliases and Blackbox.model_aliases[model] in Blackbox.models:
        actual_model = Blackbox.model_aliases[model]
        return {"model": actual_model, "status": "available via alias"}
    else:
-         logger.warning(f"Model not found: {model} | Client IP: {client_ip}")
+         logger.warning(f"Model not found: {model}")
        raise HTTPException(status_code=404, detail="Model not found")

# Custom exception handler to match OpenAI's error format
@@ -584,7 +599,7 @@ async def tokenizer(request: TokenizerRequest, req: Request):
    client_ip = req.client.host
    text = request.text
    token_count = len(text.split())
-     logger.info(f"Tokenizer called | Client IP: {client_ip} | Tokens: {token_count}")
+     logger.info(f"Tokenizer called | Tokens: {token_count}")
    return {"text": text, "tokens": token_count}

if __name__ == "__main__":
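As a quick end-to-end check, a hypothetical client against the two GET endpoints touched here (neither takes the API-key dependency in the code shown; the base URL is an assumption):

import asyncio
from aiohttp import ClientSession  # already a dependency of main.py

async def main() -> None:
    base = "http://localhost:8000"  # assumed host and port
    async with ClientSession() as session:
        async with session.get(f"{base}/v1/health") as resp:
            print(await resp.json())  # expected: {"status": "ok"}
        async with session.get(f"{base}/v1/models") as resp:
            print(await resp.json())  # {"data": [{"id": ..., "object": "model"}, ...]}

asyncio.run(main())

Each server-side log line these calls produce should carry the caller's IP in its prefix, which is an easy way to eyeball the change.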
 