Niansuh commited on
Commit
e97905c
·
verified ·
1 Parent(s): 42492be

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +97 -88
main.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import re
2
  import random
3
  import string
@@ -5,13 +6,14 @@ import uuid
5
  import json
6
  import logging
7
  import asyncio
8
- import base64
9
- from aiohttp import ClientSession, ClientTimeout, ClientError
10
- from fastapi import FastAPI, HTTPException, Request
11
- from pydantic import BaseModel
12
  from typing import List, Dict, Any, Optional, AsyncGenerator
13
- from datetime import datetime
 
 
14
  from fastapi.responses import StreamingResponse
 
15
 
16
  # Configure logging
17
  logging.basicConfig(
@@ -23,6 +25,47 @@ logging.basicConfig(
23
  )
24
  logger = logging.getLogger(__name__)
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  # Custom exception for model not working
27
  class ModelNotWorkingException(Exception):
28
  def __init__(self, model: str):
@@ -30,23 +73,14 @@ class ModelNotWorkingException(Exception):
30
  self.message = f"The model '{model}' is currently not working. Please try another model or wait for it to be fixed."
31
  super().__init__(self.message)
32
 
33
- # Implementation for ImageResponse and to_data_uri
34
  class ImageResponse:
35
- def __init__(self, data_uri: str, alt: str):
36
- self.data_uri = data_uri
37
  self.alt = alt
38
 
39
- def to_data_uri(image: bytes, mime_type: str = "image/png") -> str:
40
- encoded = base64.b64encode(image).decode('utf-8')
41
- return f"data:{mime_type};base64,{encoded}"
42
-
43
- def decode_base64_image(data_uri: str) -> bytes:
44
- try:
45
- header, encoded = data_uri.split(",", 1)
46
- return base64.b64decode(encoded)
47
- except Exception as e:
48
- logger.error(f"Error decoding base64 image: {e}")
49
- raise e
50
 
51
  class Blackbox:
52
  url = "https://www.blackbox.ai"
@@ -158,7 +192,7 @@ class Blackbox:
158
  if model in cls.models:
159
  return model
160
  elif model in cls.userSelectedModel:
161
- return cls.userSelectedModel[model]
162
  elif model in cls.model_aliases:
163
  return cls.model_aliases[model]
164
  else:
@@ -168,9 +202,9 @@ class Blackbox:
168
  async def create_async_generator(
169
  cls,
170
  model: str,
171
- messages: List[Dict[str, Any]],
172
  proxy: Optional[str] = None,
173
- image: Optional[str] = None, # Expecting a base64 string
174
  image_name: Optional[str] = None,
175
  webSearchMode: bool = False,
176
  **kwargs
@@ -200,39 +234,24 @@ class Blackbox:
200
  "user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36",
201
  }
202
 
203
- if model in cls.model_prefixes and messages:
204
  prefix = cls.model_prefixes[model]
205
  if not messages[0]['content'].startswith(prefix):
206
  logger.debug(f"Adding prefix '{prefix}' to the first message.")
207
  messages[0]['content'] = f"{prefix} {messages[0]['content']}"
208
 
209
  random_id = ''.join(random.choices(string.ascii_letters + string.digits, k=7))
210
- user_message = {
211
- "id": random_id,
212
- "role": 'user',
213
- "content": 'Hi' # This should be dynamically set based on input
214
- }
215
  if image is not None:
216
- try:
217
- image_bytes = decode_base64_image(image)
218
- data_uri = to_data_uri(image_bytes)
219
- user_message['data'] = {
220
- 'fileText': '',
221
- 'imageBase64': data_uri,
222
- 'title': image_name or "Uploaded Image"
223
- }
224
- user_message['content'] = 'FILE:BB\n$#$\n\n$#$\n' + user_message['content']
225
- logger.debug("Image data added to the message.")
226
- except Exception as e:
227
- logger.error(f"Failed to decode base64 image: {e}")
228
- raise HTTPException(status_code=400, detail="Invalid image data provided.")
229
 
230
- # Update the last message with user_message
231
- if messages:
232
- messages[-1] = user_message
233
- else:
234
- messages.append(user_message)
235
-
236
  data = {
237
  "messages": messages,
238
  "id": random_id,
@@ -280,15 +299,7 @@ class Blackbox:
280
  if url_match:
281
  image_url = url_match.group(0)
282
  logger.info(f"Image URL found: {image_url}")
283
-
284
- # Fetch the image data
285
- async with session.get(image_url) as img_response:
286
- img_response.raise_for_status()
287
- image_bytes = await img_response.read()
288
- data_uri = to_data_uri(image_bytes)
289
- logger.info("Image converted to base64 data URI.")
290
-
291
- yield ImageResponse(data_uri, alt=messages[-1]['content'])
292
  else:
293
  logger.error("Image URL not found in the response.")
294
  raise Exception("Image URL not found in the response")
@@ -349,7 +360,6 @@ class ChatRequest(BaseModel):
349
  messages: List[Message]
350
  stream: Optional[bool] = False
351
  webSearchMode: Optional[bool] = False
352
- image: Optional[str] = None # Add image field for base64 data
353
 
354
  def create_response(content: str, model: str, finish_reason: Optional[str] = None) -> Dict[str, Any]:
355
  return {
@@ -367,32 +377,25 @@ def create_response(content: str, model: str, finish_reason: Optional[str] = Non
367
  "usage": None,
368
  }
369
 
370
- @app.post("/niansuhai/v1/chat/completions")
371
- async def chat_completions(request: ChatRequest, req: Request):
372
- logger.info(f"Received chat completions request: model='{request.model}' messages={request.messages} stream={request.stream} webSearchMode={request.webSearchMode} image={request.image}")
 
 
 
 
373
  try:
374
- # Validate that all messages have string content
375
- for idx, msg in enumerate(request.messages):
376
- if not isinstance(msg.content, str):
377
- logger.error(f"Message at index {idx} has invalid content type: {type(msg.content)}")
378
- raise HTTPException(
379
- status_code=422,
380
- detail=[{
381
- "type": "string_type",
382
- "loc": ["body", "messages", idx, "content"],
383
- "msg": "Input should be a valid string",
384
- "input": msg.content
385
- }]
386
- )
387
-
388
- # Convert Pydantic messages to dicts
389
  messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
390
 
391
  async_generator = Blackbox.create_async_generator(
392
  model=request.model,
393
  messages=messages,
394
- proxy=None, # Pass proxy if needed
395
- image=request.image, # Pass the base64 image
396
  image_name=None,
397
  webSearchMode=request.webSearchMode
398
  )
@@ -402,8 +405,7 @@ async def chat_completions(request: ChatRequest, req: Request):
402
  try:
403
  async for chunk in async_generator:
404
  if isinstance(chunk, ImageResponse):
405
- # Use the base64 data URI directly
406
- image_markdown = f"![{chunk.alt}]({chunk.data_uri})"
407
  response_chunk = create_response(image_markdown, request.model)
408
  else:
409
  response_chunk = create_response(chunk, request.model)
@@ -426,11 +428,11 @@ async def chat_completions(request: ChatRequest, req: Request):
426
  response_content = ""
427
  async for chunk in async_generator:
428
  if isinstance(chunk, ImageResponse):
429
- response_content += f"![{chunk.alt}]({chunk.data_uri})\n"
430
  else:
431
  response_content += chunk
432
 
433
- logger.info("Completed non-streaming response generation.")
434
  return {
435
  "id": f"chatcmpl-{uuid.uuid4()}",
436
  "object": "chat.completion",
@@ -462,26 +464,33 @@ async def chat_completions(request: ChatRequest, req: Request):
462
  logger.exception("An unexpected error occurred while processing the chat completions request.")
463
  raise HTTPException(status_code=500, detail=str(e))
464
 
465
- @app.get("/niansuhai/v1/models")
466
- async def get_models():
467
- logger.info("Fetching available models.")
 
 
 
 
468
  return {"data": [{"id": model} for model in Blackbox.models]}
469
 
470
  # Additional endpoints for better functionality
471
- @app.get("/niansuhai/v1/health")
472
- async def health_check():
473
  """Health check endpoint to verify the service is running."""
 
474
  return {"status": "ok"}
475
 
476
- @app.get("/niansuhai/v1/models/{model}/status")
477
- async def model_status(model: str):
478
  """Check if a specific model is available."""
 
479
  if model in Blackbox.models:
480
  return {"model": model, "status": "available"}
481
  elif model in Blackbox.model_aliases:
482
  actual_model = Blackbox.model_aliases[model]
483
  return {"model": actual_model, "status": "available via alias"}
484
  else:
 
485
  raise HTTPException(status_code=404, detail="Model not found")
486
 
487
  if __name__ == "__main__":
 
1
+ import os
2
  import re
3
  import random
4
  import string
 
6
  import json
7
  import logging
8
  import asyncio
9
+ import time
10
+ from collections import defaultdict
 
 
11
  from typing import List, Dict, Any, Optional, AsyncGenerator
12
+
13
+ from aiohttp import ClientSession, ClientTimeout, ClientError
14
+ from fastapi import FastAPI, HTTPException, Request, Depends, Header
15
  from fastapi.responses import StreamingResponse
16
+ from pydantic import BaseModel
17
 
18
  # Configure logging
19
  logging.basicConfig(
 
25
  )
26
  logger = logging.getLogger(__name__)
27
 
28
+ # Load environment variables
29
+ API_KEYS = os.getenv('API_KEYS', '').split(',') # Comma-separated API keys
30
+ RATE_LIMIT = int(os.getenv('RATE_LIMIT', '60')) # Requests per minute
31
+
32
+ if not API_KEYS or API_KEYS == ['']:
33
+ logger.error("No API keys found. Please set the API_KEYS environment variable.")
34
+ raise Exception("API_KEYS environment variable not set.")
35
+
36
+ # Simple in-memory rate limiter
37
+ rate_limit_store = defaultdict(lambda: {"count": 0, "timestamp": time.time()})
38
+
39
+ async def get_api_key(authorization: str = Header(...)) -> str:
40
+ """
41
+ Dependency to extract and validate the API key from the Authorization header.
42
+ Expects the header in the format: Authorization: Bearer <API_KEY>
43
+ """
44
+ if not authorization.startswith('Bearer '):
45
+ logger.warning("Invalid authorization header format.")
46
+ raise HTTPException(status_code=401, detail='Invalid authorization header format')
47
+ api_key = authorization[7:]
48
+ if api_key not in API_KEYS:
49
+ logger.warning(f"Invalid API key attempted: {api_key}")
50
+ raise HTTPException(status_code=401, detail='Invalid API key')
51
+ return api_key
52
+
53
+ async def rate_limiter(api_key: str = Depends(get_api_key)):
54
+ """
55
+ Dependency to enforce rate limiting per API key.
56
+ Raises HTTP 429 if the rate limit is exceeded.
57
+ """
58
+ current_time = time.time()
59
+ window_start = rate_limit_store[api_key]["timestamp"]
60
+ if current_time - window_start > 60:
61
+ # Reset the count and timestamp after the time window
62
+ rate_limit_store[api_key] = {"count": 1, "timestamp": current_time}
63
+ else:
64
+ if rate_limit_store[api_key]["count"] >= RATE_LIMIT:
65
+ logger.warning(f"Rate limit exceeded for API key: {api_key}")
66
+ raise HTTPException(status_code=429, detail='Rate limit exceeded')
67
+ rate_limit_store[api_key]["count"] += 1
68
+
69
  # Custom exception for model not working
70
  class ModelNotWorkingException(Exception):
71
  def __init__(self, model: str):
 
73
  self.message = f"The model '{model}' is currently not working. Please try another model or wait for it to be fixed."
74
  super().__init__(self.message)
75
 
76
+ # Mock implementations for ImageResponse and to_data_uri
77
  class ImageResponse:
78
+ def __init__(self, url: str, alt: str):
79
+ self.url = url
80
  self.alt = alt
81
 
82
+ def to_data_uri(image: Any) -> str:
83
+ return "data:image/png;base64,..." # Replace with actual base64 data
 
 
 
 
 
 
 
 
 
84
 
85
  class Blackbox:
86
  url = "https://www.blackbox.ai"
 
192
  if model in cls.models:
193
  return model
194
  elif model in cls.userSelectedModel:
195
+ return model
196
  elif model in cls.model_aliases:
197
  return cls.model_aliases[model]
198
  else:
 
202
  async def create_async_generator(
203
  cls,
204
  model: str,
205
+ messages: List[Dict[str, str]],
206
  proxy: Optional[str] = None,
207
+ image: Any = None,
208
  image_name: Optional[str] = None,
209
  webSearchMode: bool = False,
210
  **kwargs
 
234
  "user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36",
235
  }
236
 
237
+ if model in cls.model_prefixes:
238
  prefix = cls.model_prefixes[model]
239
  if not messages[0]['content'].startswith(prefix):
240
  logger.debug(f"Adding prefix '{prefix}' to the first message.")
241
  messages[0]['content'] = f"{prefix} {messages[0]['content']}"
242
 
243
  random_id = ''.join(random.choices(string.ascii_letters + string.digits, k=7))
244
+ messages[-1]['id'] = random_id
245
+ messages[-1]['role'] = 'user'
 
 
 
246
  if image is not None:
247
+ messages[-1]['data'] = {
248
+ 'fileText': '',
249
+ 'imageBase64': to_data_uri(image),
250
+ 'title': image_name
251
+ }
252
+ messages[-1]['content'] = 'FILE:BB\n$#$\n\n$#$\n' + messages[-1]['content']
253
+ logger.debug("Image data added to the message.")
 
 
 
 
 
 
254
 
 
 
 
 
 
 
255
  data = {
256
  "messages": messages,
257
  "id": random_id,
 
299
  if url_match:
300
  image_url = url_match.group(0)
301
  logger.info(f"Image URL found: {image_url}")
302
+ yield ImageResponse(image_url, alt=messages[-1]['content'])
 
 
 
 
 
 
 
 
303
  else:
304
  logger.error("Image URL not found in the response.")
305
  raise Exception("Image URL not found in the response")
 
360
  messages: List[Message]
361
  stream: Optional[bool] = False
362
  webSearchMode: Optional[bool] = False
 
363
 
364
  def create_response(content: str, model: str, finish_reason: Optional[str] = None) -> Dict[str, Any]:
365
  return {
 
377
  "usage": None,
378
  }
379
 
380
+ @app.post("/niansuhai/v1/chat/completions", dependencies=[Depends(rate_limiter)])
381
+ async def chat_completions(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
382
+ """
383
+ Endpoint to handle chat completions.
384
+ Protected by API key and rate limiter.
385
+ """
386
+ logger.info(f"Received chat completions request from API key: {api_key} | Request: {request}")
387
  try:
388
+ # Validate that the requested model is available
389
+ if request.model not in Blackbox.models and request.model not in Blackbox.model_aliases:
390
+ logger.warning(f"Attempt to use unavailable model: {request.model}")
391
+ raise HTTPException(status_code=400, detail="Requested model is not available.")
392
+
 
 
 
 
 
 
 
 
 
 
393
  messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
394
 
395
  async_generator = Blackbox.create_async_generator(
396
  model=request.model,
397
  messages=messages,
398
+ image=None,
 
399
  image_name=None,
400
  webSearchMode=request.webSearchMode
401
  )
 
405
  try:
406
  async for chunk in async_generator:
407
  if isinstance(chunk, ImageResponse):
408
+ image_markdown = f"![image]({chunk.url})"
 
409
  response_chunk = create_response(image_markdown, request.model)
410
  else:
411
  response_chunk = create_response(chunk, request.model)
 
428
  response_content = ""
429
  async for chunk in async_generator:
430
  if isinstance(chunk, ImageResponse):
431
+ response_content += f"![image]({chunk.url})\n"
432
  else:
433
  response_content += chunk
434
 
435
+ logger.info(f"Completed non-streaming response generation for API key: {api_key}")
436
  return {
437
  "id": f"chatcmpl-{uuid.uuid4()}",
438
  "object": "chat.completion",
 
464
  logger.exception("An unexpected error occurred while processing the chat completions request.")
465
  raise HTTPException(status_code=500, detail=str(e))
466
 
467
+ @app.get("/niansuhai/v1/models", dependencies=[Depends(rate_limiter)])
468
+ async def get_models(api_key: str = Depends(get_api_key)):
469
+ """
470
+ Endpoint to fetch available models.
471
+ Protected by API key and rate limiter.
472
+ """
473
+ logger.info(f"Fetching available models for API key: {api_key}")
474
  return {"data": [{"id": model} for model in Blackbox.models]}
475
 
476
  # Additional endpoints for better functionality
477
+ @app.get("/niansuhai/v1/health", dependencies=[Depends(rate_limiter)])
478
+ async def health_check(api_key: str = Depends(get_api_key)):
479
  """Health check endpoint to verify the service is running."""
480
+ logger.info(f"Health check requested by API key: {api_key}")
481
  return {"status": "ok"}
482
 
483
+ @app.get("/niansuhai/v1/models/{model}/status", dependencies=[Depends(rate_limiter)])
484
+ async def model_status(model: str, api_key: str = Depends(get_api_key)):
485
  """Check if a specific model is available."""
486
+ logger.info(f"Model status requested for '{model}' by API key: {api_key}")
487
  if model in Blackbox.models:
488
  return {"model": model, "status": "available"}
489
  elif model in Blackbox.model_aliases:
490
  actual_model = Blackbox.model_aliases[model]
491
  return {"model": actual_model, "status": "available via alias"}
492
  else:
493
+ logger.warning(f"Model not found: {model}")
494
  raise HTTPException(status_code=404, detail="Model not found")
495
 
496
  if __name__ == "__main__":