Niansuh committed on
Commit b0b613c · verified · 1 Parent(s): ebc0716

Update main.py

Files changed (1):
  main.py +368 -0
main.py CHANGED
@@ -386,3 +386,371 @@ class Blackbox:
                 logger.error(f"Unexpected error: {e}. Retrying attempt {attempt + 1}/{retry_attempts}")
                 if attempt == retry_attempts - 1:
                     raise HTTPException(status_code=500, detail=str(e))
+
+# FastAPI app setup
+app = FastAPI()
+
+# Add the cleanup task when the app starts
+@app.on_event("startup")
+async def startup_event():
+    asyncio.create_task(cleanup_rate_limit_stores())
+    logger.info("Started rate limit store cleanup task.")
+
+# Middleware to enhance security and enforce Content-Type for specific endpoints
+@app.middleware("http")
+async def security_middleware(request: Request, call_next):
+    client_ip = request.client.host
+    # Enforce that POST requests to /v1/chat/completions must have Content-Type: application/json
+    if request.method == "POST" and request.url.path == "/v1/chat/completions":
+        content_type = request.headers.get("Content-Type")
+        if content_type != "application/json":
+            logger.warning(f"Invalid Content-Type from IP: {client_ip} for path: {request.url.path}")
+            return JSONResponse(
+                status_code=400,
+                content={
+                    "error": {
+                        "message": "Content-Type must be application/json",
+                        "type": "invalid_request_error",
+                        "param": None,
+                        "code": None
+                    }
+                },
+            )
+    response = await call_next(request)
+    return response
+
+# Request Models
+class Message(BaseModel):
+    role: str
+    content: Union[str, List[Any]]  # content can be a string or a list (for images)
+
+class ChatRequest(BaseModel):
+    model: str
+    messages: List[Message]
+    temperature: Optional[float] = 1.0
+    top_p: Optional[float] = 1.0
+    n: Optional[int] = 1
+    stream: Optional[bool] = False
+    stop: Optional[Union[str, List[str]]] = None
+    max_tokens: Optional[int] = None
+    presence_penalty: Optional[float] = 0.0
+    frequency_penalty: Optional[float] = 0.0
+    logit_bias: Optional[Dict[str, float]] = None
+    user: Optional[str] = None
+    webSearchMode: Optional[bool] = False  # Custom parameter
+    image: Optional[str] = None  # Base64-encoded image
+
+class TokenizerRequest(BaseModel):
+    text: str
+
+def calculate_estimated_cost(prompt_tokens: int, completion_tokens: int) -> float:
+    """
+    Calculate the estimated cost based on the number of tokens.
+    Replace the pricing below with your actual pricing model.
+    """
+    # Example pricing: $0.00000268 per token
+    cost_per_token = 0.00000268
+    return round((prompt_tokens + completion_tokens) * cost_per_token, 8)
+
+def create_response(content: str, model: str, finish_reason: Optional[str] = None) -> Dict[str, Any]:
+    return {
+        "id": f"chatcmpl-{uuid.uuid4()}",
+        "object": "chat.completion",
+        "created": int(datetime.now().timestamp()),
+        "model": model,
+        "choices": [
+            {
+                "index": 0,
+                "message": {
+                    "role": "assistant",
+                    "content": content
+                },
+                "finish_reason": finish_reason
+            }
+        ],
+        "usage": None,  # To be filled in non-streaming responses
+    }
+
+@app.post("/v1/chat/completions", dependencies=[Depends(rate_limiter_per_ip)])
+async def chat_completions(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
+    client_ip = req.client.host
+    # Redact user messages only for logging purposes
+    redacted_messages = [{"role": msg.role, "content": "[redacted]"} for msg in request.messages]
+
+    logger.info(f"Received chat completions request from API key: {api_key} | IP: {client_ip} | Model: {request.model} | Messages: {redacted_messages}")
+
+    analysis_result = None
+    if request.image:
+        try:
+            image = decode_base64_image(request.image)
+            analysis_result = analyze_image(image)
+            logger.info("Image analysis completed successfully.")
+        except HTTPException as he:
+            logger.error(f"Image analysis failed: {he.detail}")
+            raise he
+        except Exception as e:
+            logger.exception("Unexpected error during image analysis.")
+            raise HTTPException(status_code=500, detail="Image analysis failed.") from e
+
+    # Prepare messages to send to the external API, excluding image data
+    processed_messages = []
+    for msg in request.messages:
+        if isinstance(msg.content, list) and len(msg.content) == 2:
+            # Assume the second item is image data, skip it
+            processed_messages.append({
+                "role": msg.role,
+                "content": msg.content[0]["text"]  # Only include the text part
+            })
+        else:
+            processed_messages.append({
+                "role": msg.role,
+                "content": msg.content
+            })
+
+    # Create a modified ChatRequest without the image
+    modified_request = ChatRequest(
+        model=request.model,
+        messages=processed_messages,
+        stream=request.stream,
+        temperature=request.temperature,
+        top_p=request.top_p,
+        max_tokens=request.max_tokens,
+        presence_penalty=request.presence_penalty,
+        frequency_penalty=request.frequency_penalty,
+        logit_bias=request.logit_bias,
+        user=request.user,
+        webSearchMode=request.webSearchMode,
+        image=None  # Exclude image from external API
+    )
+
+    try:
+        if request.stream:
+            logger.info("Streaming response")
+            streaming_response = await Blackbox.create_async_generator(
+                model=modified_request.model,
+                messages=[{"role": msg.role, "content": msg.content} for msg in modified_request.messages],
+                proxy=None,
+                image=None,
+                image_name=None,
+                webSearchMode=modified_request.webSearchMode
+            )
+
+            # Wrap the streaming generator to include image analysis at the end
+            async def generate_with_analysis():
+                assistant_content = ""
+                try:
+                    async for chunk in streaming_response:
+                        if isinstance(chunk, ImageResponse):
+                            # Handle image responses if necessary
+                            image_markdown = f"![image]({chunk.url})\n"
+                            assistant_content += image_markdown
+                            response_chunk = create_response(image_markdown, modified_request.model, finish_reason=None)
+                        else:
+                            assistant_content += chunk
+                            # Yield the chunk as a partial choice
+                            response_chunk = {
+                                "id": f"chatcmpl-{uuid.uuid4()}",
+                                "object": "chat.completion.chunk",
+                                "created": int(datetime.now().timestamp()),
+                                "model": modified_request.model,
+                                "choices": [
+                                    {
+                                        "index": 0,
+                                        "delta": {"content": chunk, "role": "assistant"},
+                                        "finish_reason": None,
+                                    }
+                                ],
+                                "usage": None,  # Usage can be updated if you track tokens in real-time
+                            }
+                        yield f"data: {json.dumps(response_chunk)}\n\n"
+
+                    # After all chunks are sent, send the final message with finish_reason
+                    prompt_tokens = sum(len(msg.content.split()) for msg in modified_request.messages)
+                    completion_tokens = len(assistant_content.split())
+                    total_tokens = prompt_tokens + completion_tokens
+                    estimated_cost = calculate_estimated_cost(prompt_tokens, completion_tokens)
+
+                    final_content = assistant_content
+                    if analysis_result:
+                        final_content += f"\n\n**Image Analysis:** {analysis_result}"
+
+                    final_response = {
+                        "id": f"chatcmpl-{uuid.uuid4()}",
+                        "object": "chat.completion",
+                        "created": int(datetime.now().timestamp()),
+                        "model": modified_request.model,
+                        "choices": [
+                            {
+                                "message": {
+                                    "role": "assistant",
+                                    "content": final_content
+                                },
+                                "finish_reason": "stop",
+                                "index": 0
+                            }
+                        ],
+                        "usage": {
+                            "prompt_tokens": prompt_tokens,
+                            "completion_tokens": completion_tokens,
+                            "total_tokens": total_tokens,
+                            "estimated_cost": estimated_cost
+                        },
+                    }
+
+                    yield f"data: {json.dumps(final_response)}\n\n"
+                    yield "data: [DONE]\n\n"
+                except HTTPException as he:
+                    error_response = {"error": he.detail}
+                    yield f"data: {json.dumps(error_response)}\n\n"
+                except Exception as e:
+                    logger.exception(f"Error during streaming response generation from IP: {client_ip}.")
+                    error_response = {"error": str(e)}
+                    yield f"data: {json.dumps(error_response)}\n\n"
+
+            return StreamingResponse(generate_with_analysis(), media_type="text/event-stream")
+        else:
+            logger.info("Non-streaming response")
+            streaming_response = await Blackbox.create_async_generator(
+                model=modified_request.model,
+                messages=[{"role": msg.role, "content": msg.content} for msg in modified_request.messages],
+                proxy=None,
+                image=None,
+                image_name=None,
+                webSearchMode=modified_request.webSearchMode
+            )
+
+            response_content = ""
+            async for chunk in streaming_response:
+                if isinstance(chunk, ImageResponse):
+                    response_content += f"![image]({chunk.url})\n"
+                else:
+                    response_content += chunk
+
+            prompt_tokens = sum(len(msg.content.split()) for msg in modified_request.messages)
+            completion_tokens = len(response_content.split())
+            total_tokens = prompt_tokens + completion_tokens
+            estimated_cost = calculate_estimated_cost(prompt_tokens, completion_tokens)
+
+            if analysis_result:
+                response_content += f"\n\n**Image Analysis:** {analysis_result}"
+
+            logger.info(f"Completed non-streaming response generation for API key: {api_key} | IP: {client_ip}")
+
+            response = {
+                "id": f"chatcmpl-{uuid.uuid4()}",
+                "object": "chat.completion",
+                "created": int(datetime.now().timestamp()),
+                "model": modified_request.model,
+                "choices": [
+                    {
+                        "message": {
+                            "role": "assistant",
+                            "content": response_content
+                        },
+                        "finish_reason": "stop",
+                        "index": 0
+                    }
+                ],
+                "usage": {
+                    "prompt_tokens": prompt_tokens,
+                    "completion_tokens": completion_tokens,
+                    "total_tokens": total_tokens,
+                    "estimated_cost": estimated_cost
+                },
+            }
+
+            return response
+    except ModelNotWorkingException as e:
+        logger.warning(f"Model not working: {e} | IP: {client_ip}")
+        raise HTTPException(status_code=503, detail=str(e))
+    except HTTPException as he:
+        logger.warning(f"HTTPException: {he.detail} | IP: {client_ip}")
+        raise he
+    except Exception as e:
+        logger.exception(f"An unexpected error occurred while processing the chat completions request from IP: {client_ip}.")
+        raise HTTPException(status_code=500, detail=str(e))
+
+# Endpoint: POST /v1/tokenizer
+@app.post("/v1/tokenizer", dependencies=[Depends(rate_limiter_per_ip)])
+async def tokenizer(request: TokenizerRequest, req: Request):
+    client_ip = req.client.host
+    text = request.text
+    token_count = len(text.split())
+    logger.info(f"Tokenizer requested from IP: {client_ip} | Text length: {len(text)}")
+    return {"text": text, "tokens": token_count}
+
+# Endpoint: GET /v1/models
+@app.get("/v1/models", dependencies=[Depends(rate_limiter_per_ip)])
+async def get_models(req: Request):
+    client_ip = req.client.host
+    logger.info(f"Fetching available models from IP: {client_ip}")
+    return {"data": [{"id": model, "object": "model"} for model in Blackbox.models]}
+
+# Endpoint: GET /v1/models/{model}/status
+@app.get("/v1/models/{model}/status", dependencies=[Depends(rate_limiter_per_ip)])
+async def model_status(model: str, req: Request):
+    client_ip = req.client.host
+    logger.info(f"Model status requested for '{model}' from IP: {client_ip}")
+    if model in Blackbox.models:
+        return {"model": model, "status": "available"}
+    elif model in Blackbox.model_aliases and Blackbox.model_aliases[model] in Blackbox.models:
+        actual_model = Blackbox.model_aliases[model]
+        return {"model": actual_model, "status": "available via alias"}
+    else:
+        logger.warning(f"Model not found: {model} from IP: {client_ip}")
+        raise HTTPException(status_code=404, detail="Model not found")
+
+# Endpoint: GET /v1/health
+@app.get("/v1/health", dependencies=[Depends(rate_limiter_per_ip)])
+async def health_check(req: Request):
+    client_ip = req.client.host
+    logger.info(f"Health check requested from IP: {client_ip}")
+    return {"status": "ok"}
+
+# Endpoint: GET /v1/chat/completions (GET method)
+@app.get("/v1/chat/completions")
+async def chat_completions_get(req: Request):
+    client_ip = req.client.host
+    logger.info(f"GET request made to /v1/chat/completions from IP: {client_ip}, redirecting to 'about:blank'")
+    return RedirectResponse(url='about:blank')
+
+# Custom exception handler to match OpenAI's error format
+@app.exception_handler(HTTPException)
+async def http_exception_handler(request: Request, exc: HTTPException):
+    client_ip = request.client.host
+    logger.error(f"HTTPException: {exc.detail} | Path: {request.url.path} | IP: {client_ip}")
+    return JSONResponse(
+        status_code=exc.status_code,
+        content={
+            "error": {
+                "message": exc.detail,
+                "type": "invalid_request_error",
+                "param": None,
+                "code": None
+            }
+        },
+    )
+
+# Image Processing Utilities
+def decode_base64_image(base64_str: str) -> Image.Image:
+    try:
+        image_data = base64.b64decode(base64_str)
+        image = Image.open(BytesIO(image_data))
+        return image
+    except Exception as e:
+        logger.error("Failed to decode base64 image.")
+        raise HTTPException(status_code=400, detail="Invalid base64 image data.") from e
+
+def analyze_image(image: Image.Image) -> str:
+    """
+    Placeholder for image analysis.
+    Replace this with actual image analysis logic.
+    """
+    # Example: Return image size as analysis
+    width, height = image.size
+    return f"Image analyzed successfully. Width: {width}px, Height: {height}px."
+
+# Run the application
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)
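
For quick manual verification of the new endpoint, a minimal client-side sketch (illustrative, not part of the commit), assuming the server runs locally on port 8000 as configured above and that get_api_key accepts a Bearer token; the API key and model name below are placeholders:

    import requests  # third-party HTTP client

    payload = {
        "model": "blackboxai",  # placeholder; GET /v1/models returns the real IDs
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": False,
    }
    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",
        json=payload,  # sends Content-Type: application/json, which the middleware enforces
        headers={"Authorization": "Bearer YOUR_API_KEY"},  # placeholder credential
        timeout=60,
    )
    resp.raise_for_status()
    print(resp.json()["choices"][0]["message"]["content"])

Note that prompt_tokens, completion_tokens, and estimated_cost in the response are derived from whitespace word counts (as is /v1/tokenizer), not real model tokens, so treat them as approximations.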