Update main.py
Browse files
main.py
CHANGED
@@ -7,8 +7,6 @@ import json
|
|
7 |
import logging
|
8 |
import asyncio
|
9 |
import time
|
10 |
-
import base64
|
11 |
-
from io import BytesIO
|
12 |
from collections import defaultdict
|
13 |
from typing import List, Dict, Any, Optional, AsyncGenerator, Union
|
14 |
|
@@ -18,7 +16,10 @@ from aiohttp import ClientSession, ClientTimeout, ClientError
|
|
18 |
from fastapi import FastAPI, HTTPException, Request, Depends, Header
|
19 |
from fastapi.responses import StreamingResponse, JSONResponse, RedirectResponse
|
20 |
from pydantic import BaseModel
|
|
|
21 |
from PIL import Image
|
|
|
|
|
22 |
|
23 |
# Configure logging
|
24 |
logging.basicConfig(
|
@@ -108,25 +109,6 @@ class ImageResponse:
|
|
108 |
def to_data_uri(image: Any) -> str:
|
109 |
return "data:image/png;base64,..." # Replace with actual base64 data
|
110 |
|
111 |
-
# Utility functions for image processing
|
112 |
-
def decode_base64_image(base64_str: str) -> Image.Image:
|
113 |
-
try:
|
114 |
-
image_data = base64.b64decode(base64_str)
|
115 |
-
image = Image.open(BytesIO(image_data))
|
116 |
-
return image
|
117 |
-
except Exception as e:
|
118 |
-
logger.error("Failed to decode base64 image.")
|
119 |
-
raise HTTPException(status_code=400, detail="Invalid base64 image data.") from e
|
120 |
-
|
121 |
-
def analyze_image(image: Image.Image) -> str:
|
122 |
-
"""
|
123 |
-
Placeholder for image analysis.
|
124 |
-
Replace this with actual image analysis logic.
|
125 |
-
"""
|
126 |
-
# Example: Return image size as analysis
|
127 |
-
width, height = image.size
|
128 |
-
return f"Image analyzed successfully. Width: {width}px, Height: {height}px."
|
129 |
-
|
130 |
class Blackbox:
|
131 |
url = "https://www.blackbox.ai"
|
132 |
api_endpoint = "https://www.blackbox.ai/api/chat"
|
@@ -440,7 +422,7 @@ async def security_middleware(request: Request, call_next):
|
|
440 |
# Request Models
|
441 |
class Message(BaseModel):
|
442 |
role: str
|
443 |
-
content: Union[str, List[Any]] #
|
444 |
|
445 |
class ChatRequest(BaseModel):
|
446 |
model: str
|
@@ -510,31 +492,59 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
|
|
510 |
logger.exception("Unexpected error during image analysis.")
|
511 |
raise HTTPException(status_code=500, detail="Image analysis failed.") from e
|
512 |
|
513 |
-
|
514 |
-
|
515 |
-
|
516 |
-
|
517 |
-
|
518 |
-
|
519 |
-
|
520 |
-
|
521 |
-
|
522 |
-
|
523 |
-
|
524 |
-
|
525 |
-
|
526 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
527 |
|
|
|
528 |
if request.stream:
|
529 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
530 |
try:
|
531 |
-
|
532 |
-
async for chunk in async_generator:
|
533 |
if isinstance(chunk, ImageResponse):
|
534 |
# Handle image responses if necessary
|
535 |
image_markdown = f"\n"
|
536 |
assistant_content += image_markdown
|
537 |
-
response_chunk = create_response(image_markdown,
|
538 |
else:
|
539 |
assistant_content += chunk
|
540 |
# Yield the chunk as a partial choice
|
@@ -542,7 +552,7 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
|
|
542 |
"id": f"chatcmpl-{uuid.uuid4()}",
|
543 |
"object": "chat.completion.chunk",
|
544 |
"created": int(datetime.now().timestamp()),
|
545 |
-
"model":
|
546 |
"choices": [
|
547 |
{
|
548 |
"index": 0,
|
@@ -555,7 +565,7 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
|
|
555 |
yield f"data: {json.dumps(response_chunk)}\n\n"
|
556 |
|
557 |
# After all chunks are sent, send the final message with finish_reason
|
558 |
-
prompt_tokens = sum(len(msg
|
559 |
completion_tokens = len(assistant_content.split())
|
560 |
total_tokens = prompt_tokens + completion_tokens
|
561 |
estimated_cost = calculate_estimated_cost(prompt_tokens, completion_tokens)
|
@@ -564,7 +574,7 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
|
|
564 |
"id": f"chatcmpl-{uuid.uuid4()}",
|
565 |
"object": "chat.completion",
|
566 |
"created": int(datetime.now().timestamp()),
|
567 |
-
"model":
|
568 |
"choices": [
|
569 |
{
|
570 |
"message": {
|
@@ -596,16 +606,26 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
|
|
596 |
error_response = {"error": str(e)}
|
597 |
yield f"data: {json.dumps(error_response)}\n\n"
|
598 |
|
599 |
-
return StreamingResponse(
|
600 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
601 |
response_content = ""
|
602 |
-
async for chunk in
|
603 |
if isinstance(chunk, ImageResponse):
|
604 |
response_content += f"\n"
|
605 |
else:
|
606 |
response_content += chunk
|
607 |
|
608 |
-
prompt_tokens = sum(len(msg
|
609 |
completion_tokens = len(response_content.split())
|
610 |
total_tokens = prompt_tokens + completion_tokens
|
611 |
estimated_cost = calculate_estimated_cost(prompt_tokens, completion_tokens)
|
@@ -616,7 +636,7 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
|
|
616 |
"id": f"chatcmpl-{uuid.uuid4()}",
|
617 |
"object": "chat.completion",
|
618 |
"created": int(datetime.now().timestamp()),
|
619 |
-
"model":
|
620 |
"choices": [
|
621 |
{
|
622 |
"message": {
|
@@ -710,6 +730,25 @@ async def http_exception_handler(request: Request, exc: HTTPException):
|
|
710 |
},
|
711 |
)
|
712 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
713 |
# Run the application
|
714 |
if __name__ == "__main__":
|
715 |
import uvicorn
|
|
|
7 |
import logging
|
8 |
import asyncio
|
9 |
import time
|
|
|
|
|
10 |
from collections import defaultdict
|
11 |
from typing import List, Dict, Any, Optional, AsyncGenerator, Union
|
12 |
|
|
|
16 |
from fastapi import FastAPI, HTTPException, Request, Depends, Header
|
17 |
from fastapi.responses import StreamingResponse, JSONResponse, RedirectResponse
|
18 |
from pydantic import BaseModel
|
19 |
+
|
20 |
from PIL import Image
|
21 |
+
import base64
|
22 |
+
from io import BytesIO
|
23 |
|
24 |
# Configure logging
|
25 |
logging.basicConfig(
|
|
|
109 |
def to_data_uri(image: Any) -> str:
|
110 |
return "data:image/png;base64,..." # Replace with actual base64 data
|
111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
class Blackbox:
|
113 |
url = "https://www.blackbox.ai"
|
114 |
api_endpoint = "https://www.blackbox.ai/api/chat"
|
|
|
422 |
# Request Models
|
423 |
class Message(BaseModel):
|
424 |
role: str
|
425 |
+
content: Union[str, List[Any]] # content can be a string or a list (for images)
|
426 |
|
427 |
class ChatRequest(BaseModel):
|
428 |
model: str
|
|
|
492 |
logger.exception("Unexpected error during image analysis.")
|
493 |
raise HTTPException(status_code=500, detail="Image analysis failed.") from e
|
494 |
|
495 |
+
# Prepare messages to send to the external API, excluding image data
|
496 |
+
processed_messages = []
|
497 |
+
for msg in request.messages:
|
498 |
+
if isinstance(msg.content, list) and len(msg.content) == 2:
|
499 |
+
# Assume the second item is image data, skip it
|
500 |
+
processed_messages.append({
|
501 |
+
"role": msg.role,
|
502 |
+
"content": msg.content[0]["text"] # Only include the text part
|
503 |
+
})
|
504 |
+
else:
|
505 |
+
processed_messages.append({
|
506 |
+
"role": msg.role,
|
507 |
+
"content": msg.content
|
508 |
+
})
|
509 |
+
|
510 |
+
# Create a modified ChatRequest without the image
|
511 |
+
modified_request = ChatRequest(
|
512 |
+
model=request.model,
|
513 |
+
messages=[msg for msg in processed_messages],
|
514 |
+
stream=request.stream,
|
515 |
+
temperature=request.temperature,
|
516 |
+
top_p=request.top_p,
|
517 |
+
max_tokens=request.max_tokens,
|
518 |
+
presence_penalty=request.presence_penalty,
|
519 |
+
frequency_penalty=request.frequency_penalty,
|
520 |
+
logit_bias=request.logit_bias,
|
521 |
+
user=request.user,
|
522 |
+
webSearchMode=request.webSearchMode,
|
523 |
+
image=None # Exclude image from external API
|
524 |
+
)
|
525 |
|
526 |
+
try:
|
527 |
if request.stream:
|
528 |
+
logger.info("Streaming response")
|
529 |
+
streaming_response = await Blackbox.create_async_generator(
|
530 |
+
model=modified_request.model,
|
531 |
+
messages=[{"role": msg["role"], "content": msg["content"]} for msg in modified_request.messages],
|
532 |
+
proxy=None,
|
533 |
+
image=None,
|
534 |
+
image_name=None,
|
535 |
+
webSearchMode=modified_request.webSearchMode
|
536 |
+
)
|
537 |
+
|
538 |
+
# Wrap the streaming generator to include image analysis at the end
|
539 |
+
async def generate_with_analysis():
|
540 |
+
assistant_content = ""
|
541 |
try:
|
542 |
+
async for chunk in streaming_response:
|
|
|
543 |
if isinstance(chunk, ImageResponse):
|
544 |
# Handle image responses if necessary
|
545 |
image_markdown = f"\n"
|
546 |
assistant_content += image_markdown
|
547 |
+
response_chunk = create_response(image_markdown, modified_request.model, finish_reason=None)
|
548 |
else:
|
549 |
assistant_content += chunk
|
550 |
# Yield the chunk as a partial choice
|
|
|
552 |
"id": f"chatcmpl-{uuid.uuid4()}",
|
553 |
"object": "chat.completion.chunk",
|
554 |
"created": int(datetime.now().timestamp()),
|
555 |
+
"model": modified_request.model,
|
556 |
"choices": [
|
557 |
{
|
558 |
"index": 0,
|
|
|
565 |
yield f"data: {json.dumps(response_chunk)}\n\n"
|
566 |
|
567 |
# After all chunks are sent, send the final message with finish_reason
|
568 |
+
prompt_tokens = sum(len(msg["content"].split()) for msg in modified_request.messages)
|
569 |
completion_tokens = len(assistant_content.split())
|
570 |
total_tokens = prompt_tokens + completion_tokens
|
571 |
estimated_cost = calculate_estimated_cost(prompt_tokens, completion_tokens)
|
|
|
574 |
"id": f"chatcmpl-{uuid.uuid4()}",
|
575 |
"object": "chat.completion",
|
576 |
"created": int(datetime.now().timestamp()),
|
577 |
+
"model": modified_request.model,
|
578 |
"choices": [
|
579 |
{
|
580 |
"message": {
|
|
|
606 |
error_response = {"error": str(e)}
|
607 |
yield f"data: {json.dumps(error_response)}\n\n"
|
608 |
|
609 |
+
return StreamingResponse(generate_with_analysis(), media_type="text/event-stream")
|
610 |
else:
|
611 |
+
logger.info("Non-streaming response")
|
612 |
+
streaming_response = await Blackbox.create_async_generator(
|
613 |
+
model=modified_request.model,
|
614 |
+
messages=[{"role": msg["role"], "content": msg["content"]} for msg in modified_request.messages],
|
615 |
+
proxy=None,
|
616 |
+
image=None,
|
617 |
+
image_name=None,
|
618 |
+
webSearchMode=modified_request.webSearchMode
|
619 |
+
)
|
620 |
+
|
621 |
response_content = ""
|
622 |
+
async for chunk in streaming_response:
|
623 |
if isinstance(chunk, ImageResponse):
|
624 |
response_content += f"\n"
|
625 |
else:
|
626 |
response_content += chunk
|
627 |
|
628 |
+
prompt_tokens = sum(len(msg["content"].split()) for msg in modified_request.messages)
|
629 |
completion_tokens = len(response_content.split())
|
630 |
total_tokens = prompt_tokens + completion_tokens
|
631 |
estimated_cost = calculate_estimated_cost(prompt_tokens, completion_tokens)
|
|
|
636 |
"id": f"chatcmpl-{uuid.uuid4()}",
|
637 |
"object": "chat.completion",
|
638 |
"created": int(datetime.now().timestamp()),
|
639 |
+
"model": modified_request.model,
|
640 |
"choices": [
|
641 |
{
|
642 |
"message": {
|
|
|
730 |
},
|
731 |
)
|
732 |
|
733 |
+
# Image Processing Utilities
|
734 |
+
def decode_base64_image(base64_str: str) -> Image.Image:
|
735 |
+
try:
|
736 |
+
image_data = base64.b64decode(base64_str)
|
737 |
+
image = Image.open(BytesIO(image_data))
|
738 |
+
return image
|
739 |
+
except Exception as e:
|
740 |
+
logger.error("Failed to decode base64 image.")
|
741 |
+
raise HTTPException(status_code=400, detail="Invalid base64 image data.") from e
|
742 |
+
|
743 |
+
def analyze_image(image: Image.Image) -> str:
|
744 |
+
"""
|
745 |
+
Placeholder for image analysis.
|
746 |
+
Replace this with actual image analysis logic.
|
747 |
+
"""
|
748 |
+
# Example: Return image size as analysis
|
749 |
+
width, height = image.size
|
750 |
+
return f"Image analyzed successfully. Width: {width}px, Height: {height}px."
|
751 |
+
|
752 |
# Run the application
|
753 |
if __name__ == "__main__":
|
754 |
import uvicorn
|