Niansuh committed
Commit c815e1f · verified · 1 Parent(s): 5c13b7b

Update main.py

Files changed (1)
  1. main.py +44 -314
main.py CHANGED
@@ -1,53 +1,22 @@
 from __future__ import annotations
 
-import os
-import re
+import asyncio
+import aiohttp
 import random
 import string
-import uuid
 import json
-import logging
-import asyncio
-import time
-from collections import defaultdict
-from typing import List, Dict, Any, Optional, Union, AsyncGenerator
+import uuid
+import re
+from typing import Optional, AsyncGenerator, Union
 
 from aiohttp import ClientSession, ClientResponseError
-from fastapi import FastAPI, HTTPException, Request, Depends, Header
-from fastapi.responses import JSONResponse, StreamingResponse
-from pydantic import BaseModel
-from datetime import datetime
-
-# Configure logging
-logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
-    handlers=[logging.StreamHandler()]
-)
-logger = logging.getLogger(__name__)
-
-# Load environment variables
-API_KEYS = os.getenv('API_KEYS', '').split(',') # Comma-separated API keys
-RATE_LIMIT = int(os.getenv('RATE_LIMIT', '60')) # Requests per minute
-
-if not API_KEYS or API_KEYS == ['']:
-    logger.error("No API keys found. Please set the API_KEYS environment variable.")
-    raise Exception("API_KEYS environment variable not set.")
-
-# Simple in-memory rate limiter based solely on IP addresses
-rate_limit_store = defaultdict(lambda: {"count": 0, "timestamp": time.time()})
-
-# Define cleanup interval and window
-CLEANUP_INTERVAL = 60 # seconds
-RATE_LIMIT_WINDOW = 60 # seconds
-
-# Define the ImageResponse model
-class ImageResponseModel(BaseModel):
-    images: str
-    alt: str
-
-# Updated Blackbox Provider
-class Blackbox:
+
+from ..typing import AsyncResult, Messages
+from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
+from ..image import ImageResponse
+
+
+class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
     label = "Blackbox AI"
     url = "https://www.blackbox.ai"
     api_endpoint = "https://www.blackbox.ai/api/chat"
@@ -200,23 +169,23 @@ class Blackbox:
     async def create_async_generator(
         cls,
         model: str,
-        messages: List[Dict[str, str]],
+        messages: Messages,
         proxy: Optional[str] = None,
         websearch: bool = False,
         **kwargs
-    ) -> AsyncGenerator[Union[str, ImageResponseModel], None]:
+    ) -> AsyncGenerator[Union[str, ImageResponse], None]:
         """
         Creates an asynchronous generator for streaming responses from Blackbox AI.
 
         Parameters:
             model (str): Model to use for generating responses.
-            messages (List[Dict[str, str]]): Message history.
+            messages (Messages): Message history.
             proxy (Optional[str]): Proxy URL, if needed.
             websearch (bool): Enables or disables web search mode.
             **kwargs: Additional keyword arguments.
 
         Yields:
-            Union[str, ImageResponseModel]: Segments of the generated response or ImageResponseModel objects.
+            Union[str, ImageResponse]: Segments of the generated response or ImageResponse objects.
         """
         model = cls.get_model(model)
 
@@ -317,26 +286,43 @@ class Blackbox:
                     proxy=proxy
                 ) as response_api_chat:
                     response_api_chat.raise_for_status()
-                    # Simulate streaming by breaking the response into chunks
-                    # Since the external API may not support streaming, we'll simulate it
                     text = await response_api_chat.text()
                     cleaned_response = cls.clean_response(text)
 
-                    # For demonstration, we'll split the response into words and yield them with delays
-                    words = cleaned_response.split()
-                    for word in words:
-                        await asyncio.sleep(0.1) # Simulate delay
-                        yield word + ' '
-
-                    # If the model is an image model, handle accordingly
                     if model in cls.image_models:
                         match = re.search(r'!\[.*?\]\((https?://[^\)]+)\)', cleaned_response)
                         if match:
                             image_url = match.group(1)
-                            image_response = ImageResponseModel(images=image_url, alt="Generated Image")
+                            image_response = ImageResponse(images=image_url, alt="Generated Image")
                             yield image_response
                         else:
                             yield cleaned_response
+                    else:
+                        if websearch:
+                            match = re.search(r'\$~~~\$(.*?)\$~~~\$', cleaned_response, re.DOTALL)
+                            if match:
+                                source_part = match.group(1).strip()
+                                answer_part = cleaned_response[match.end():].strip()
+                                try:
+                                    sources = json.loads(source_part)
+                                    source_formatted = "**Source:**\n"
+                                    for item in sources:
+                                        title = item.get('title', 'No Title')
+                                        link = item.get('link', '#')
+                                        position = item.get('position', '')
+                                        source_formatted += f"{position}. [{title}]({link})\n"
+                                    final_response = f"{answer_part}\n\n{source_formatted}"
+                                except json.JSONDecodeError:
+                                    final_response = f"{answer_part}\n\nSource information is unavailable."
+                            else:
+                                final_response = cleaned_response
+                        else:
+                            if '$~~~$' in cleaned_response:
+                                final_response = cleaned_response.split('$~~~$')[0].strip()
+                            else:
+                                final_response = cleaned_response
+
+                        yield final_response
             except ClientResponseError as e:
                 error_text = f"Error {e.status}: {e.message}"
                 try:
@@ -349,7 +335,6 @@ class Blackbox:
            except Exception as e:
                 yield f"Unexpected error during /api/chat request: {str(e)}"
 
-            # Simulate the second API call to /chat/{chat_id}
             chat_url = f'{cls.url}/chat/{chat_id}?model={model}'
 
             try:
@@ -360,7 +345,7 @@ class Blackbox:
                     proxy=proxy
                 ) as response_chat:
                     response_chat.raise_for_status()
-                    # Assuming no streaming from this endpoint
+                    pass
             except ClientResponseError as e:
                 error_text = f"Error {e.status}: {e.message}"
                 try:
@@ -372,258 +357,3 @@ class Blackbox:
                 yield error_text
             except Exception as e:
                 yield f"Unexpected error during /chat/{chat_id} request: {str(e)}"
-
-# Custom exception for model not working
-class ModelNotWorkingException(Exception):
-    def __init__(self, model: str):
-        self.model = model
-        self.message = f"The model '{model}' is currently not working. Please try another model or wait for it to be fixed."
-        super().__init__(self.message)
-
-async def cleanup_rate_limit_stores():
-    """
-    Periodically cleans up stale entries in the rate_limit_store to prevent memory bloat.
-    """
-    while True:
-        current_time = time.time()
-        ips_to_delete = [ip for ip, value in rate_limit_store.items() if current_time - value["timestamp"] > RATE_LIMIT_WINDOW * 2]
-        for ip in ips_to_delete:
-            del rate_limit_store[ip]
-            logger.debug(f"Cleaned up rate_limit_store for IP: {ip}")
-        await asyncio.sleep(CLEANUP_INTERVAL)
-
-async def rate_limiter_per_ip(request: Request):
-    """
-    Rate limiter that enforces a limit based on the client's IP address.
-    """
-    client_ip = request.client.host
-    current_time = time.time()
-
-    # Initialize or update the count and timestamp
-    if current_time - rate_limit_store[client_ip]["timestamp"] > RATE_LIMIT_WINDOW:
-        rate_limit_store[client_ip] = {"count": 1, "timestamp": current_time}
-    else:
-        if rate_limit_store[client_ip]["count"] >= RATE_LIMIT:
-            logger.warning(f"Rate limit exceeded for IP address: {client_ip}")
-            raise HTTPException(status_code=429, detail='Rate limit exceeded for IP address | NiansuhAI')
-        rate_limit_store[client_ip]["count"] += 1
-
-async def get_api_key(request: Request, authorization: str = Header(None)) -> str:
-    """
-    Dependency to extract and validate the API key from the Authorization header.
-    """
-    client_ip = request.client.host
-    if authorization is None or not authorization.startswith('Bearer '):
-        logger.warning(f"Invalid or missing authorization header from IP: {client_ip}")
-        raise HTTPException(status_code=401, detail='Invalid authorization header format')
-    api_key = authorization[7:]
-    if api_key not in API_KEYS:
-        logger.warning(f"Invalid API key attempted: {api_key} from IP: {client_ip}")
-        raise HTTPException(status_code=401, detail='Invalid API key')
-    return api_key
-
-# FastAPI app setup
-app = FastAPI()
-
-# Add the cleanup task when the app starts
-@app.on_event("startup")
-async def startup_event():
-    asyncio.create_task(cleanup_rate_limit_stores())
-    logger.info("Started rate limit store cleanup task.")
-
-# Middleware to enhance security and enforce Content-Type for specific endpoints
-@app.middleware("http")
-async def security_middleware(request: Request, call_next):
-    client_ip = request.client.host
-    # Enforce that POST requests to /v1/chat/completions must have Content-Type: application/json
-    if request.method == "POST" and request.url.path == "/v1/chat/completions":
-        content_type = request.headers.get("Content-Type")
-        if content_type != "application/json":
-            logger.warning(f"Invalid Content-Type from IP: {client_ip} for path: {request.url.path}")
-            return JSONResponse(
-                status_code=400,
-                content={
-                    "error": {
-                        "message": "Content-Type must be application/json",
-                        "type": "invalid_request_error",
-                        "param": None,
-                        "code": None
-                    }
-                },
-            )
-    response = await call_next(request)
-    return response
-
-# Request Models
-class Message(BaseModel):
-    role: str
-    content: str
-
-class ChatRequest(BaseModel):
-    model: str
-    messages: List[Message]
-    temperature: Optional[float] = 1.0
-    top_p: Optional[float] = 1.0
-    n: Optional[int] = 1
-    max_tokens: Optional[int] = None
-    presence_penalty: Optional[float] = 0.0
-    frequency_penalty: Optional[float] = 0.0
-    logit_bias: Optional[Dict[str, float]] = None
-    user: Optional[str] = None
-    stream: Optional[bool] = False # Added stream parameter
-
-@app.post("/v1/chat/completions", dependencies=[Depends(rate_limiter_per_ip)])
-async def chat_completions(
-    request: ChatRequest,
-    req: Request,
-    api_key: str = Depends(get_api_key)
-):
-    client_ip = req.client.host
-    # Redact user messages only for logging purposes
-    redacted_messages = [{"role": msg.role, "content": "[redacted]"} for msg in request.messages]
-
-    logger.info(f"Received chat completions request from API key: {api_key} | IP: {client_ip} | Model: {request.model} | Messages: {redacted_messages} | Stream: {request.stream}")
-
-    try:
-        # Validate that the requested model is available
-        if request.model not in Blackbox.models and request.model not in Blackbox.model_aliases:
-            logger.warning(f"Attempt to use unavailable model: {request.model} from IP: {client_ip}")
-            raise HTTPException(status_code=400, detail="Requested model is not available.")
-
-        if request.stream:
-            # Create the asynchronous generator for streaming responses
-            async_generator = Blackbox.create_async_generator(
-                model=request.model,
-                messages=[{"role": msg.role, "content": msg.content} for msg in request.messages],
-                temperature=request.temperature,
-                max_tokens=request.max_tokens
-            )
-
-            logger.info(f"Started streaming response for API key: {api_key} | IP: {client_ip}")
-
-            # Define a generator function to yield the streamed data in the OpenAI format
-            async def stream_response():
-                try:
-                    async for chunk in async_generator:
-                        if isinstance(chunk, ImageResponseModel):
-                            # Handle image responses by sending markdown image syntax
-                            data = json.dumps({
-                                "choices": [{
-                                    "delta": {
-                                        "content": f"![Image]({chunk.images})"
-                                    }
-                                }]
-                            })
-                        else:
-                            # Assuming chunk is a string
-                            data = json.dumps({
-                                "choices": [{
-                                    "delta": {
-                                        "content": chunk
-                                    }
-                                }]
-                            })
-                        yield f"data: {data}\n\n"
-                except Exception as e:
-                    # If an error occurs during streaming, send it as an SSE error event
-                    error_data = json.dumps({
-                        "choices": [{
-                            "delta": {
-                                "content": f"Error: {str(e)}"
-                            }
-                        }]
-                    })
-                    yield f"data: {error_data}\n\n"
-
-            return StreamingResponse(stream_response(), media_type="text/event-stream")
-        else:
-            # Non-streaming: Collect all chunks and assemble the final response
-            all_chunks = []
-            async for chunk in Blackbox.create_async_generator(
-                model=request.model,
-                messages=[{"role": msg.role, "content": msg.content} for msg in request.messages],
-                temperature=request.temperature,
-                max_tokens=request.max_tokens
-            ):
-                if isinstance(chunk, ImageResponseModel):
-                    # Convert ImageResponseModel to markdown image syntax
-                    content = f"![Image]({chunk.images})"
-                else:
-                    content = chunk
-                all_chunks.append(content)
-
-            # Assemble the full response content
-            full_content = "\n".join(all_chunks)
-
-            # Calculate token usage (approximation)
-            prompt_tokens = sum(len(msg.content.split()) for msg in request.messages)
-            completion_tokens = len(full_content.split())
-            total_tokens = prompt_tokens + completion_tokens
-
-            logger.info(f"Completed non-stream response generation for API key: {api_key} | IP: {client_ip}")
-
-            return {
-                "id": f"chatcmpl-{uuid.uuid4()}",
-                "object": "chat.completion",
-                "created": int(datetime.now().timestamp()),
-                "model": request.model,
-                "choices": [
-                    {
-                        "index": 0,
-                        "message": {
-                            "role": "assistant",
-                            "content": full_content
-                        },
-                        "finish_reason": "stop"
-                    }
-                ],
-                "usage": {
-                    "prompt_tokens": prompt_tokens,
-                    "completion_tokens": completion_tokens,
-                    "total_tokens": total_tokens
-                },
-            }
-    except ModelNotWorkingException as e:
-        logger.warning(f"Model not working: {e} | IP: {client_ip}")
-        raise HTTPException(status_code=503, detail=str(e))
-    except HTTPException as he:
-        logger.warning(f"HTTPException: {he.detail} | IP: {client_ip}")
-        raise he
-    except Exception as e:
-        logger.exception(f"An unexpected error occurred while processing the chat completions request from IP: {client_ip}.")
-        raise HTTPException(status_code=500, detail=str(e))
-
-# Endpoint: GET /v1/models
-@app.get("/v1/models", dependencies=[Depends(rate_limiter_per_ip)])
-async def get_models(req: Request):
-    client_ip = req.client.host
-    logger.info(f"Fetching available models from IP: {client_ip}")
-    return {"data": [{"id": model, "object": "model"} for model in Blackbox.models]}
-
-# Endpoint: GET /v1/health
-@app.get("/v1/health", dependencies=[Depends(rate_limiter_per_ip)])
-async def health_check(req: Request):
-    client_ip = req.client.host
-    logger.info(f"Health check requested from IP: {client_ip}")
-    return {"status": "ok"}
-
-# Custom exception handler to match OpenAI's error format
-@app.exception_handler(HTTPException)
-async def http_exception_handler(request: Request, exc: HTTPException):
-    client_ip = request.client.host
-    logger.error(f"HTTPException: {exc.detail} | Path: {request.url.path} | IP: {client_ip}")
-    return JSONResponse(
-        status_code=exc.status_code,
-        content={
-            "error": {
-                "message": exc.detail,
-                "type": "invalid_request_error",
-                "param": None,
-                "code": None
-            }
-        },
-    )
-
-if __name__ == "__main__":
-    import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=8000)