Niansuh committed on
Commit
53aa71f
·
verified ·
1 Parent(s): bb184cb

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +86 -47
main.py CHANGED
@@ -7,8 +7,6 @@ import json
7
  import logging
8
  import asyncio
9
  import time
10
- import base64
11
- from io import BytesIO
12
  from collections import defaultdict
13
  from typing import List, Dict, Any, Optional, AsyncGenerator, Union
14
 
@@ -18,7 +16,10 @@ from aiohttp import ClientSession, ClientTimeout, ClientError
18
  from fastapi import FastAPI, HTTPException, Request, Depends, Header
19
  from fastapi.responses import StreamingResponse, JSONResponse, RedirectResponse
20
  from pydantic import BaseModel
 
21
  from PIL import Image
 
 
22
 
23
  # Configure logging
24
  logging.basicConfig(
@@ -108,25 +109,6 @@ class ImageResponse:
108
  def to_data_uri(image: Any) -> str:
109
  return "data:image/png;base64,..." # Replace with actual base64 data
110
 
111
- # Utility functions for image processing
112
- def decode_base64_image(base64_str: str) -> Image.Image:
113
- try:
114
- image_data = base64.b64decode(base64_str)
115
- image = Image.open(BytesIO(image_data))
116
- return image
117
- except Exception as e:
118
- logger.error("Failed to decode base64 image.")
119
- raise HTTPException(status_code=400, detail="Invalid base64 image data.") from e
120
-
121
- def analyze_image(image: Image.Image) -> str:
122
- """
123
- Placeholder for image analysis.
124
- Replace this with actual image analysis logic.
125
- """
126
- # Example: Return image size as analysis
127
- width, height = image.size
128
- return f"Image analyzed successfully. Width: {width}px, Height: {height}px."
129
-
130
  class Blackbox:
131
  url = "https://www.blackbox.ai"
132
  api_endpoint = "https://www.blackbox.ai/api/chat"
@@ -440,7 +422,7 @@ async def security_middleware(request: Request, call_next):
440
  # Request Models
441
  class Message(BaseModel):
442
  role: str
443
- content: Union[str, List[Any]] # Adjusted to accept list if needed
444
 
445
  class ChatRequest(BaseModel):
446
  model: str
@@ -510,31 +492,59 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
510
  logger.exception("Unexpected error during image analysis.")
511
  raise HTTPException(status_code=500, detail="Image analysis failed.") from e
512
 
513
- try:
514
- # Validate that the requested model is available
515
- if request.model not in Blackbox.models and request.model not in Blackbox.model_aliases:
516
- logger.warning(f"Attempt to use unavailable model: {request.model} from IP: {client_ip}")
517
- raise HTTPException(status_code=400, detail="Requested model is not available.")
518
-
519
- # Process the request with actual message content and image data
520
- async_generator = Blackbox.create_async_generator(
521
- model=request.model,
522
- messages=[{"role": msg.role, "content": msg.content} for msg in request.messages],
523
- image=request.image,
524
- image_name="uploaded_image", # You can modify this as needed
525
- webSearchMode=request.webSearchMode
526
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
527
 
 
528
  if request.stream:
529
- async def generate():
 
 
 
 
 
 
 
 
 
 
 
 
530
  try:
531
- assistant_content = ""
532
- async for chunk in async_generator:
533
  if isinstance(chunk, ImageResponse):
534
  # Handle image responses if necessary
535
  image_markdown = f"![image]({chunk.url})\n"
536
  assistant_content += image_markdown
537
- response_chunk = create_response(image_markdown, request.model, finish_reason=None)
538
  else:
539
  assistant_content += chunk
540
  # Yield the chunk as a partial choice
@@ -542,7 +552,7 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
542
  "id": f"chatcmpl-{uuid.uuid4()}",
543
  "object": "chat.completion.chunk",
544
  "created": int(datetime.now().timestamp()),
545
- "model": request.model,
546
  "choices": [
547
  {
548
  "index": 0,
@@ -555,7 +565,7 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
555
  yield f"data: {json.dumps(response_chunk)}\n\n"
556
 
557
  # After all chunks are sent, send the final message with finish_reason
558
- prompt_tokens = sum(len(msg.content.split()) for msg in request.messages)
559
  completion_tokens = len(assistant_content.split())
560
  total_tokens = prompt_tokens + completion_tokens
561
  estimated_cost = calculate_estimated_cost(prompt_tokens, completion_tokens)
@@ -564,7 +574,7 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
564
  "id": f"chatcmpl-{uuid.uuid4()}",
565
  "object": "chat.completion",
566
  "created": int(datetime.now().timestamp()),
567
- "model": request.model,
568
  "choices": [
569
  {
570
  "message": {
@@ -596,16 +606,26 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
596
  error_response = {"error": str(e)}
597
  yield f"data: {json.dumps(error_response)}\n\n"
598
 
599
- return StreamingResponse(generate(), media_type="text/event-stream")
600
  else:
 
 
 
 
 
 
 
 
 
 
601
  response_content = ""
602
- async for chunk in async_generator:
603
  if isinstance(chunk, ImageResponse):
604
  response_content += f"![image]({chunk.url})\n"
605
  else:
606
  response_content += chunk
607
 
608
- prompt_tokens = sum(len(msg.content.split()) for msg in request.messages)
609
  completion_tokens = len(response_content.split())
610
  total_tokens = prompt_tokens + completion_tokens
611
  estimated_cost = calculate_estimated_cost(prompt_tokens, completion_tokens)
@@ -616,7 +636,7 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
616
  "id": f"chatcmpl-{uuid.uuid4()}",
617
  "object": "chat.completion",
618
  "created": int(datetime.now().timestamp()),
619
- "model": request.model,
620
  "choices": [
621
  {
622
  "message": {
@@ -710,6 +730,25 @@ async def http_exception_handler(request: Request, exc: HTTPException):
710
  },
711
  )
712
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
713
  # Run the application
714
  if __name__ == "__main__":
715
  import uvicorn
 
7
  import logging
8
  import asyncio
9
  import time
 
 
10
  from collections import defaultdict
11
  from typing import List, Dict, Any, Optional, AsyncGenerator, Union
12
 
 
16
  from fastapi import FastAPI, HTTPException, Request, Depends, Header
17
  from fastapi.responses import StreamingResponse, JSONResponse, RedirectResponse
18
  from pydantic import BaseModel
19
+
20
  from PIL import Image
21
+ import base64
22
+ from io import BytesIO
23
 
24
  # Configure logging
25
  logging.basicConfig(
 
109
  def to_data_uri(image: Any) -> str:
110
  return "data:image/png;base64,..." # Replace with actual base64 data
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  class Blackbox:
113
  url = "https://www.blackbox.ai"
114
  api_endpoint = "https://www.blackbox.ai/api/chat"
 
422
  # Request Models
423
  class Message(BaseModel):
424
  role: str
425
+ content: Union[str, List[Any]] # content can be a string or a list (for images)
426
 
427
  class ChatRequest(BaseModel):
428
  model: str
 
492
  logger.exception("Unexpected error during image analysis.")
493
  raise HTTPException(status_code=500, detail="Image analysis failed.") from e
494
 
495
+ # Prepare messages to send to the external API, excluding image data
496
+ processed_messages = []
497
+ for msg in request.messages:
498
+ if isinstance(msg.content, list) and len(msg.content) == 2:
499
+ # Assume the second item is image data, skip it
500
+ processed_messages.append({
501
+ "role": msg.role,
502
+ "content": msg.content[0]["text"] # Only include the text part
503
+ })
504
+ else:
505
+ processed_messages.append({
506
+ "role": msg.role,
507
+ "content": msg.content
508
+ })
509
+
510
+ # Create a modified ChatRequest without the image
511
+ modified_request = ChatRequest(
512
+ model=request.model,
513
+ messages=[msg for msg in processed_messages],
514
+ stream=request.stream,
515
+ temperature=request.temperature,
516
+ top_p=request.top_p,
517
+ max_tokens=request.max_tokens,
518
+ presence_penalty=request.presence_penalty,
519
+ frequency_penalty=request.frequency_penalty,
520
+ logit_bias=request.logit_bias,
521
+ user=request.user,
522
+ webSearchMode=request.webSearchMode,
523
+ image=None # Exclude image from external API
524
+ )
525
 
526
+ try:
527
  if request.stream:
528
+ logger.info("Streaming response")
529
+ streaming_response = await Blackbox.create_async_generator(
530
+ model=modified_request.model,
531
+ messages=[{"role": msg["role"], "content": msg["content"]} for msg in modified_request.messages],
532
+ proxy=None,
533
+ image=None,
534
+ image_name=None,
535
+ webSearchMode=modified_request.webSearchMode
536
+ )
537
+
538
+ # Wrap the streaming generator to include image analysis at the end
539
+ async def generate_with_analysis():
540
+ assistant_content = ""
541
  try:
542
+ async for chunk in streaming_response:
 
543
  if isinstance(chunk, ImageResponse):
544
  # Handle image responses if necessary
545
  image_markdown = f"![image]({chunk.url})\n"
546
  assistant_content += image_markdown
547
+ response_chunk = create_response(image_markdown, modified_request.model, finish_reason=None)
548
  else:
549
  assistant_content += chunk
550
  # Yield the chunk as a partial choice
 
552
  "id": f"chatcmpl-{uuid.uuid4()}",
553
  "object": "chat.completion.chunk",
554
  "created": int(datetime.now().timestamp()),
555
+ "model": modified_request.model,
556
  "choices": [
557
  {
558
  "index": 0,
 
565
  yield f"data: {json.dumps(response_chunk)}\n\n"
566
 
567
  # After all chunks are sent, send the final message with finish_reason
568
+ prompt_tokens = sum(len(msg["content"].split()) for msg in modified_request.messages)
569
  completion_tokens = len(assistant_content.split())
570
  total_tokens = prompt_tokens + completion_tokens
571
  estimated_cost = calculate_estimated_cost(prompt_tokens, completion_tokens)
 
574
  "id": f"chatcmpl-{uuid.uuid4()}",
575
  "object": "chat.completion",
576
  "created": int(datetime.now().timestamp()),
577
+ "model": modified_request.model,
578
  "choices": [
579
  {
580
  "message": {
 
606
  error_response = {"error": str(e)}
607
  yield f"data: {json.dumps(error_response)}\n\n"
608
 
609
+ return StreamingResponse(generate_with_analysis(), media_type="text/event-stream")
610
  else:
611
+ logger.info("Non-streaming response")
612
+ streaming_response = await Blackbox.create_async_generator(
613
+ model=modified_request.model,
614
+ messages=[{"role": msg["role"], "content": msg["content"]} for msg in modified_request.messages],
615
+ proxy=None,
616
+ image=None,
617
+ image_name=None,
618
+ webSearchMode=modified_request.webSearchMode
619
+ )
620
+
621
  response_content = ""
622
+ async for chunk in streaming_response:
623
  if isinstance(chunk, ImageResponse):
624
  response_content += f"![image]({chunk.url})\n"
625
  else:
626
  response_content += chunk
627
 
628
+ prompt_tokens = sum(len(msg["content"].split()) for msg in modified_request.messages)
629
  completion_tokens = len(response_content.split())
630
  total_tokens = prompt_tokens + completion_tokens
631
  estimated_cost = calculate_estimated_cost(prompt_tokens, completion_tokens)
 
636
  "id": f"chatcmpl-{uuid.uuid4()}",
637
  "object": "chat.completion",
638
  "created": int(datetime.now().timestamp()),
639
+ "model": modified_request.model,
640
  "choices": [
641
  {
642
  "message": {
 
730
  },
731
  )
732
 
733
+ # Image Processing Utilities
734
+ def decode_base64_image(base64_str: str) -> Image.Image:
735
+ try:
736
+ image_data = base64.b64decode(base64_str)
737
+ image = Image.open(BytesIO(image_data))
738
+ return image
739
+ except Exception as e:
740
+ logger.error("Failed to decode base64 image.")
741
+ raise HTTPException(status_code=400, detail="Invalid base64 image data.") from e
742
+
743
+ def analyze_image(image: Image.Image) -> str:
744
+ """
745
+ Placeholder for image analysis.
746
+ Replace this with actual image analysis logic.
747
+ """
748
+ # Example: Return image size as analysis
749
+ width, height = image.size
750
+ return f"Image analyzed successfully. Width: {width}px, Height: {height}px."
751
+
752
  # Run the application
753
  if __name__ == "__main__":
754
  import uvicorn