Niansuh committed (verified)
Commit a2672b1 · 1 Parent(s): f626b99

Update main.py

Files changed (1):
  1. main.py +130 -97
main.py CHANGED
@@ -9,16 +9,19 @@ import asyncio
 import time
 from collections import defaultdict
 from typing import List, Dict, Any, Optional, Union, AsyncGenerator
-from datetime import datetime  # <-- Added import
+
+from datetime import datetime  # Required for timestamping
 
 from aiohttp import ClientSession, ClientResponseError
 from fastapi import FastAPI, HTTPException, Request, Depends, Header
-from fastapi.responses import JSONResponse
+from fastapi.responses import JSONResponse, StreamingResponse
 from pydantic import BaseModel
 
+# ----------------------------- Configuration -----------------------------
+
 # Configure logging
 logging.basicConfig(
-    level=logging.DEBUG,  # Changed to DEBUG for detailed logs
+    level=logging.INFO,  # Set to DEBUG for more detailed logs
     format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
     handlers=[logging.StreamHandler()]
 )
@@ -39,19 +42,43 @@ rate_limit_store = defaultdict(lambda: {"count": 0, "timestamp": time.time()})
 CLEANUP_INTERVAL = 60  # seconds
 RATE_LIMIT_WINDOW = 60  # seconds
 
-# Define the ImageResponse model (as used in the new Blackbox class)
+# ----------------------------- Pydantic Models -----------------------------
+
 class ImageResponseModel(BaseModel):
     images: str  # URL of the generated image
     alt: str
 
-# Custom exception for model not working
-class ModelNotWorkingException(Exception):
-    def __init__(self, model: str):
-        self.model = model
-        self.message = f"The model '{model}' is currently not working. Please try another model or wait for it to be fixed."
-        super().__init__(self.message)
+class Message(BaseModel):
+    role: str
+    content: str
+
+class ChatRequest(BaseModel):
+    model: str
+    messages: List[Message]
+    temperature: Optional[float] = 1.0
+    top_p: Optional[float] = 1.0
+    n: Optional[int] = 1
+    max_tokens: Optional[int] = None
+    presence_penalty: Optional[float] = 0.0
+    frequency_penalty: Optional[float] = 0.0
+    logit_bias: Optional[Dict[str, float]] = None
+    user: Optional[str] = None
+    stream: Optional[bool] = False
+    webSearchMode: Optional[bool] = False  # Added based on old code
+
+# ----------------------------- Helper Functions -----------------------------
+
+def create_response(content: str, model: str) -> Dict[str, Any]:
+    """
+    Formats the response chunk.
+    """
+    return {
+        "model": model,
+        "content": content
+    }
+
+# ----------------------------- Blackbox Class -----------------------------
 
-# Updated Blackbox Class with New Models and Functionality
 class Blackbox:
     label = "Blackbox AI"
     url = "https://www.blackbox.ai"
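The new ChatRequest model mirrors the OpenAI chat-completions schema. A minimal, self-contained sketch of a payload it validates (the model id and field values are illustrative, not taken from the repo):

    from typing import Dict, List, Optional
    from pydantic import BaseModel

    class Message(BaseModel):
        role: str
        content: str

    class ChatRequest(BaseModel):  # trimmed copy of the model added above
        model: str
        messages: List[Message]
        temperature: Optional[float] = 1.0
        max_tokens: Optional[int] = None
        stream: Optional[bool] = False
        webSearchMode: Optional[bool] = False

    req = ChatRequest(
        model="blackboxai",  # illustrative model id
        messages=[{"role": "user", "content": "Hello!"}],
        stream=True,
    )
    print(req.model, req.stream)  # raises ValidationError on bad input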
@@ -364,6 +391,8 @@ class Blackbox:
             logger.exception(f"Unexpected error during /api/chat request: {str(e)}")  # Added logging
             return f"Unexpected error during /api/chat request: {str(e)}"
 
+        # The following block is unreachable due to the return statements above
+        # It can be removed unless it's intended for future use
         chat_url = f'{cls.url}/chat/{chat_id}?model={model}'
 
         try:
@@ -562,6 +591,8 @@ class Blackbox:
         except Exception as e:
             yield f"Unexpected error during /api/chat request: {str(e)}"
 
+        # The following block is unreachable due to the yield statements above
+        # It can be removed unless it's intended for future use
        chat_url = f'{cls.url}/chat/{chat_id}?model={model}'
 
         try:
@@ -585,10 +616,12 @@
         except Exception as e:
             yield f"Unexpected error during /chat/{chat_id} request: {str(e)}"
 
-# FastAPI app setup
+# ----------------------------- FastAPI App Setup -----------------------------
+
 app = FastAPI()
 
-# Rate Limiter Cleanup Task
+# ----------------------------- Rate Limiter Cleanup Task -----------------------------
+
 async def cleanup_rate_limit_stores():
     """
     Periodically cleans up stale entries in the rate_limit_store to prevent memory bloat.
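The body of cleanup_rate_limit_stores() lies outside this hunk. A plausible sketch of such a task, assuming only the rate_limit_store shape and the CLEANUP_INTERVAL / RATE_LIMIT_WINDOW constants defined near the top of main.py:

    import asyncio
    import time
    from collections import defaultdict

    CLEANUP_INTERVAL = 60   # seconds, as in main.py
    RATE_LIMIT_WINDOW = 60  # seconds, as in main.py
    rate_limit_store = defaultdict(lambda: {"count": 0, "timestamp": time.time()})

    async def cleanup_rate_limit_stores():
        # Drop entries whose window expired long ago so the dict cannot grow forever.
        while True:
            now = time.time()
            stale = [ip for ip, entry in rate_limit_store.items()
                     if now - entry["timestamp"] > RATE_LIMIT_WINDOW * 2]
            for ip in stale:
                del rate_limit_store[ip]
            await asyncio.sleep(CLEANUP_INTERVAL)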
@@ -607,7 +640,8 @@ async def startup_event():
     asyncio.create_task(cleanup_rate_limit_stores())
     logger.info("Started rate limit store cleanup task.")
 
-# Middleware to enhance security and enforce Content-Type for specific endpoints
+# ----------------------------- Security Middleware -----------------------------
+
 @app.middleware("http")
 async def security_middleware(request: Request, call_next):
     client_ip = request.client.host
@@ -630,24 +664,24 @@ async def security_middleware(request: Request, call_next):
     response = await call_next(request)
     return response
 
-# Request Models
-class Message(BaseModel):
-    role: str
-    content: str
+# ----------------------------- API Key Dependency -----------------------------
 
-class ChatRequest(BaseModel):
-    model: str
-    messages: List[Message]
-    temperature: Optional[float] = 1.0
-    top_p: Optional[float] = 1.0
-    n: Optional[int] = 1
-    max_tokens: Optional[int] = None
-    presence_penalty: Optional[float] = 0.0
-    frequency_penalty: Optional[float] = 0.0
-    logit_bias: Optional[Dict[str, float]] = None
-    user: Optional[str] = None
+async def get_api_key(request: Request, authorization: str = Header(None)) -> str:
+    """
+    Dependency to extract and validate the API key from the Authorization header.
+    """
+    client_ip = request.client.host
+    if authorization is None or not authorization.startswith('Bearer '):
+        logger.warning(f"Invalid or missing authorization header from IP: {client_ip}")
+        raise HTTPException(status_code=401, detail='Invalid authorization header format')
+    api_key = authorization[7:]
+    if api_key not in API_KEYS:
+        logger.warning(f"Invalid API key attempted: {api_key} from IP: {client_ip}")
+        raise HTTPException(status_code=401, detail='Invalid API key')
+    return api_key
+
+# ----------------------------- Rate Limiter Dependency -----------------------------
 
-# Rate Limiter Dependency
 async def rate_limiter_per_ip(request: Request):
     """
     Rate limiter that enforces a limit based on the client's IP address.
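get_api_key, relocated in this hunk, slices off the 'Bearer ' prefix (authorization[7:]) and checks membership in API_KEYS, so clients must send exactly that header shape. A minimal aiohttp client sketch, with placeholder URL and key:

    import asyncio
    from aiohttp import ClientSession

    async def call_api():
        headers = {"Authorization": "Bearer sk-example-key"}  # placeholder key
        payload = {
            "model": "blackboxai",  # illustrative model id
            "messages": [{"role": "user", "content": "Hi"}],
        }
        async with ClientSession() as session:
            async with session.post("http://localhost:8000/v1/chat/completions",
                                    json=payload, headers=headers) as resp:
                print(resp.status, await resp.json())

    asyncio.run(call_api())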
@@ -664,22 +698,8 @@ async def rate_limiter_per_ip(request: Request):
         raise HTTPException(status_code=429, detail='Rate limit exceeded for IP address | NiansuhAI')
     rate_limit_store[client_ip]["count"] += 1
 
-# API Key Dependency
-async def get_api_key(request: Request, authorization: str = Header(None)) -> str:
-    """
-    Dependency to extract and validate the API key from the Authorization header.
-    """
-    client_ip = request.client.host
-    if authorization is None or not authorization.startswith('Bearer '):
-        logger.warning(f"Invalid or missing authorization header from IP: {client_ip}")
-        raise HTTPException(status_code=401, detail='Invalid authorization header format')
-    api_key = authorization[7:]
-    if api_key not in API_KEYS:
-        logger.warning(f"Invalid API key attempted: {api_key} from IP: {client_ip}")
-        raise HTTPException(status_code=401, detail='Invalid API key')
-    return api_key
+# ----------------------------- Chat Completions Endpoint -----------------------------
 
-# Endpoint: POST /v1/chat/completions
 @app.post("/v1/chat/completions", dependencies=[Depends(rate_limiter_per_ip)])
 async def chat_completions(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
     client_ip = req.client.host
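The middle of rate_limiter_per_ip is elided by the diff context; the visible lines are consistent with a fixed-window counter per IP, sketched here under that assumption (RATE_LIMIT is an assumed constant not shown in the diff):

    import time
    from collections import defaultdict
    from fastapi import HTTPException, Request

    RATE_LIMIT = 10          # assumed requests allowed per window
    RATE_LIMIT_WINDOW = 60   # seconds, as in main.py
    rate_limit_store = defaultdict(lambda: {"count": 0, "timestamp": time.time()})

    async def rate_limiter_per_ip(request: Request):
        client_ip = request.client.host
        entry = rate_limit_store[client_ip]
        if time.time() - entry["timestamp"] > RATE_LIMIT_WINDOW:
            entry["count"] = 0               # window rolled over: reset
            entry["timestamp"] = time.time()
        if entry["count"] >= RATE_LIMIT:
            raise HTTPException(status_code=429,
                                detail='Rate limit exceeded for IP address | NiansuhAI')
        entry["count"] += 1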
@@ -697,61 +717,72 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
         # Check if the model is an image generation model
         is_image_model = request.model in Blackbox.image_models
 
-        # Generate response
-        response_content = await Blackbox.generate_response(
-            model=request.model,
-            messages=[{"role": msg.role, "content": msg.content} for msg in request.messages],
-            temperature=request.temperature,
-            max_tokens=request.max_tokens
-        )
+        if request.stream:
+            # Streaming response
+            async_generator = Blackbox.create_async_generator(
+                model=request.model,
+                messages=[{"role": msg.role, "content": msg.content} for msg in request.messages],
+                websearch=request.webSearchMode
+            )
+
+            async def generate():
+                try:
+                    async for chunk in async_generator:
+                        if isinstance(chunk, ImageResponseModel):
+                            image_markdown = f"![image]({chunk.images})"
+                            response_chunk = create_response(image_markdown, request.model)
+                        else:
+                            response_chunk = create_response(chunk, request.model)
+
+                        yield f"data: {json.dumps(response_chunk)}\n\n"
+
+                    yield "data: [DONE]\n\n"
+                except HTTPException as he:
+                    error_response = {"error": he.detail}
+                    yield f"data: {json.dumps(error_response)}\n\n"
+                except Exception as e:
+                    logger.exception(f"Error during streaming response generation from IP: {client_ip}.")
+                    error_response = {"error": str(e)}
+                    yield f"data: {json.dumps(error_response)}\n\n"
+
+            return StreamingResponse(generate(), media_type="text/event-stream")
+        else:
+            # Non-streaming response
+            async_generator = Blackbox.create_async_generator(
+                model=request.model,
+                messages=[{"role": msg.role, "content": msg.content} for msg in request.messages],
+                websearch=request.webSearchMode
+            )
+
+            response_content = ""
+            async for chunk in async_generator:
+                if isinstance(chunk, ImageResponseModel):
+                    response_content += f"![image]({chunk.images})\n"
+                else:
+                    response_content += chunk
 
-        # If the model is for image generation, handle accordingly
-        if is_image_model and isinstance(response_content, ImageResponseModel):
-            logger.info(f"Completed image generation for API key: {api_key} | IP: {client_ip}")
+            logger.info(f"Completed non-streaming response generation for API key: {api_key} | IP: {client_ip}")
             return {
                 "id": f"chatcmpl-{uuid.uuid4()}",
                 "object": "chat.completion",
-                "created": int(datetime.now().timestamp()),  # Fixed: datetime is now imported
+                "created": int(datetime.now().timestamp()),
                 "model": request.model,
                 "choices": [
                     {
-                        "index": 0,
                         "message": {
                             "role": "assistant",
-                            "content": response_content.images  # Return the image URL
+                            "content": response_content
                         },
-                        "finish_reason": "stop"
+                        "finish_reason": "stop",
+                        "index": 0
                     }
                 ],
                 "usage": {
                     "prompt_tokens": sum(len(msg.content.split()) for msg in request.messages),
-                    "completion_tokens": len(response_content.images.split()),
-                    "total_tokens": sum(len(msg.content.split()) for msg in request.messages) + len(response_content.images.split())
+                    "completion_tokens": len(response_content.split()),
+                    "total_tokens": sum(len(msg.content.split()) for msg in request.messages) + len(response_content.split())
                 },
             }
-
-        logger.info(f"Completed response generation for API key: {api_key} | IP: {client_ip}")
-        return {
-            "id": f"chatcmpl-{uuid.uuid4()}",
-            "object": "chat.completion",
-            "created": int(datetime.now().timestamp()),  # Fixed: datetime is now imported
-            "model": request.model,
-            "choices": [
-                {
-                    "index": 0,
-                    "message": {
-                        "role": "assistant",
-                        "content": response_content
-                    },
-                    "finish_reason": "stop"
-                }
-            ],
-            "usage": {
-                "prompt_tokens": sum(len(msg.content.split()) for msg in request.messages),
-                "completion_tokens": len(response_content.split()),
-                "total_tokens": sum(len(msg.content.split()) for msg in request.messages) + len(response_content.split())
-            },
-        }
     except ModelNotWorkingException as e:
         logger.warning(f"Model not working: {e} | IP: {client_ip}")
         raise HTTPException(status_code=503, detail=str(e))
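With "stream": true, the endpoint now emits Server-Sent Events: one "data: <json>" frame per chunk, each shaped by create_response as {"model", "content"}, then a final "data: [DONE]". A minimal consumer sketch, with placeholder URL and key:

    import asyncio
    import json
    from aiohttp import ClientSession

    async def consume_stream():
        payload = {"model": "blackboxai", "stream": True,
                   "messages": [{"role": "user", "content": "Hi"}]}
        headers = {"Authorization": "Bearer sk-example-key"}  # placeholder key
        async with ClientSession() as session:
            async with session.post("http://localhost:8000/v1/chat/completions",
                                    json=payload, headers=headers) as resp:
                async for raw in resp.content:        # yields one line per iteration
                    line = raw.decode("utf-8").strip()
                    if not line.startswith("data: "):
                        continue
                    data = line[len("data: "):]
                    if data == "[DONE]":
                        break
                    chunk = json.loads(data)          # {"model": ..., "content": ...}
                    print(chunk["content"], end="", flush=True)

    asyncio.run(consume_stream())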
@@ -762,12 +793,13 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
         logger.exception(f"An unexpected error occurred while processing the chat completions request from IP: {client_ip}.")
         raise HTTPException(status_code=500, detail=str(e))
 
-# Optional: Endpoint for Streaming Responses (Requires Client Support)
-# If you wish to support streaming, you can implement an endpoint that leverages the asynchronous generator.
-# This requires clients to handle streaming responses appropriately.
+# ----------------------------- Streaming Endpoint (Optional) -----------------------------
 
 @app.post("/v1/chat/completions/stream", dependencies=[Depends(rate_limiter_per_ip)])
 async def chat_completions_stream(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
+    """
+    Optional endpoint for streaming responses. Can be removed if not needed.
+    """
     client_ip = req.client.host
     redacted_messages = [{"role": msg.role, "content": "[redacted]"} for msg in request.messages]
 
@@ -786,14 +818,13 @@ async def chat_completions_stream(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
         async_gen = Blackbox.create_async_generator(
             model=request.model,
             messages=[{"role": msg.role, "content": msg.content} for msg in request.messages],
-            temperature=request.temperature,
-            max_tokens=request.max_tokens
+            websearch=request.webSearchMode
         )
 
         async def stream_response() -> AsyncGenerator[bytes, None]:
             async for chunk in async_gen:
                 if isinstance(chunk, ImageResponseModel):
-                    # For image responses, you might want to send the URL directly
+                    # For image responses, send the URL directly
                     yield json.dumps({
                         "role": "assistant",
                         "content": chunk.images
@@ -805,10 +836,9 @@ async def chat_completions_stream(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
                     }).encode('utf-8') + b'\n'
 
         logger.info(f"Streaming response started for API key: {api_key} | IP: {client_ip}")
-        return JSONResponse(
-            content=None,  # The actual streaming is handled by the generator
-            media_type='text/event-stream',
-            background=stream_response()
+        return StreamingResponse(
+            stream_response(),
+            media_type='text/event-stream'
         )
     except ModelNotWorkingException as e:
         logger.warning(f"Model not working: {e} | IP: {client_ip}")
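The hunk above is a genuine fix: Starlette's background= expects a BackgroundTask that runs only after the response body has been sent, and JSONResponse(content=None) serializes to a literal null, so the old endpoint never streamed anything to the client. StreamingResponse makes the generator the body itself. A minimal illustration of the corrected pattern (toy route, not the repo's):

    from fastapi import FastAPI
    from fastapi.responses import StreamingResponse

    demo = FastAPI()  # toy app for illustration only

    @demo.get("/ticks")
    async def ticks():
        async def gen():
            for i in range(3):
                yield f"data: tick {i}\n\n".encode("utf-8")
        # The async generator is consumed as the response body.
        return StreamingResponse(gen(), media_type="text/event-stream")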
@@ -820,21 +850,22 @@ async def chat_completions_stream(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
         logger.exception(f"An unexpected error occurred while processing the streaming chat completions request from IP: {client_ip}.")
         raise HTTPException(status_code=500, detail=str(e))
 
-# Endpoint: GET /v1/models
+# ----------------------------- Additional Endpoints -----------------------------
+
 @app.get("/v1/models", dependencies=[Depends(rate_limiter_per_ip)])
 async def get_models(req: Request):
     client_ip = req.client.host
     logger.info(f"Fetching available models from IP: {client_ip}")
     return {"data": [{"id": model, "object": "model"} for model in Blackbox.models]}
 
-# Endpoint: GET /v1/health
 @app.get("/v1/health", dependencies=[Depends(rate_limiter_per_ip)])
 async def health_check(req: Request):
     client_ip = req.client.host
     logger.info(f"Health check requested from IP: {client_ip}")
     return {"status": "ok"}
 
-# Custom exception handler to match OpenAI's error format
+# ----------------------------- Exception Handlers -----------------------------
+
 @app.exception_handler(HTTPException)
 async def http_exception_handler(request: Request, exc: HTTPException):
     client_ip = request.client.host
@@ -851,6 +882,8 @@ async def http_exception_handler(request: Request, exc: HTTPException):
         },
     )
 
+# ----------------------------- Main Entry Point -----------------------------
+
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=8000)
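A quick smoke test of the committed server: python main.py serves on port 8000 per the __main__ block, and /v1/health is rate-limited but needs no API key:

    import asyncio
    from aiohttp import ClientSession

    async def smoke_test():
        async with ClientSession() as session:
            async with session.get("http://localhost:8000/v1/health") as resp:
                print(resp.status, await resp.json())  # expect: 200 {'status': 'ok'}

    asyncio.run(smoke_test())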
 