Update main.py
main.py
CHANGED
@@ -532,41 +532,21 @@ def create_response(content: str, model: str, finish_reason: Optional[str] = None):
         "usage": None,  # To be filled in non-streaming responses
     }
 
-def …
+def extract_all_images_from_content(content_items: List[ContentItem]) -> List[Tuple[str, str]]:
     """
-    Extracts …
-    Returns a tuple of (alt_text, image_data_uri) if found, else None.
-    """
-    # Regex to match markdown image syntax: ![alt](data:image/...;base64,...)
-    match = re.search(r'!\[([^\]]*)\]\((data:image/\w+;base64,[^\)]+)\)', content)
-    if match:
-        alt_text = match.group(1)
-        image_data_uri = match.group(2)
-        return alt_text, image_data_uri
-    return None
-
-def extract_all_images_from_content(content: str) -> List[Tuple[str, str]]:
-    """
-    Extracts all images from the content string.
+    Extracts all images from the content list.
     Returns a list of tuples containing (alt_text, image_data_uri).
     """
-    … [removed lines 553-562 not captured in this view]
-    image_data = image_data_uri.split(",")[1]
-    # Decode and process the image as needed
-    # For example, send it to an external API
-    # Here, we'll return a dummy response
-    await asyncio.sleep(1)  # Simulate processing delay
-    return "Image analysis result: The image depicts a beautiful sunset over the mountains."
-
+    images = []
+    for item in content_items:
+        if isinstance(item, ImageContent):
+            alt_text = item.image_url.get('alt', '')  # Optional alt text
+            image_data_uri = item.image_url.get('url', '')
+            if image_data_uri:
+                images.append((alt_text, image_data_uri))
+    return images
+
+# Endpoint: POST /v1/chat/completions
 @app.post("/v1/chat/completions", dependencies=[Depends(rate_limiter_per_ip)])
 async def chat_completions(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
     client_ip = req.client.host
@@ -588,7 +568,7 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
         for msg in request.messages:
             if msg.role == "user":
                 # Extract all images from the message content
-                images = extract_all_images_from_content(…)
+                images = extract_all_images_from_content(msg.content)
                 for alt_text, image_data_uri in images:
                     # Analyze the image
                     analysis_result = await analyze_image(image_data_uri)
@@ -598,35 +578,94 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
                 assistant_content += "Based on the image you provided, here are the insights..."
 
         # Calculate token usage (simple approximation)
-        prompt_tokens = sum(len(" ".join([item.text if item …
+        prompt_tokens = sum(len(" ".join([item.text if isinstance(item, TextContent) else item.image_url['url'] for item in msg.content]).split()) for msg in request.messages)
         completion_tokens = len(assistant_content.split())
         total_tokens = prompt_tokens + completion_tokens
        estimated_cost = calculate_estimated_cost(prompt_tokens, completion_tokens)
 
         logger.info(f"Completed response generation for API key: {api_key} | IP: {client_ip}")
 
-        … [removed lines 608-629 not captured in this view]
+        if request.stream:
+            async def generate():
+                try:
+                    for msg in request.messages:
+                        if msg.role == "user":
+                            images = extract_all_images_from_content(msg.content)
+                            for alt_text, image_data_uri in images:
+                                analysis_result = await analyze_image(image_data_uri)
+                                response_chunk = {
+                                    "id": f"chatcmpl-{uuid.uuid4()}",
+                                    "object": "chat.completion.chunk",
+                                    "created": int(datetime.now().timestamp()),
+                                    "model": request.model,
+                                    "choices": [
+                                        {
+                                            "index": 0,
+                                            "delta": {"content": analysis_result + "\n", "role": "assistant"},
+                                            "finish_reason": None,
+                                        }
+                                    ],
+                                    "usage": None,
+                                }
+                                yield f"data: {json.dumps(response_chunk)}\n\n"
+
+                    # Final message
+                    final_response = {
+                        "id": f"chatcmpl-{uuid.uuid4()}",
+                        "object": "chat.completion",
+                        "created": int(datetime.now().timestamp()),
+                        "model": request.model,
+                        "choices": [
+                            {
+                                "message": {
+                                    "role": "assistant",
+                                    "content": assistant_content.strip()
+                                },
+                                "finish_reason": "stop",
+                                "index": 0
+                            }
+                        ],
+                        "usage": {
+                            "prompt_tokens": prompt_tokens,
+                            "completion_tokens": completion_tokens,
+                            "total_tokens": total_tokens,
+                            "estimated_cost": estimated_cost
+                        },
+                    }
+                    yield f"data: {json.dumps(final_response)}\n\n"
+                    yield "data: [DONE]\n\n"
+                except HTTPException as he:
+                    error_response = {"error": he.detail}
+                    yield f"data: {json.dumps(error_response)}\n\n"
+                except Exception as e:
+                    logger.exception(f"Error during streaming response generation from IP: {client_ip}.")
+                    error_response = {"error": str(e)}
+                    yield f"data: {json.dumps(error_response)}\n\n"
+
+            return StreamingResponse(generate(), media_type="text/event-stream")
+        else:
+            return {
+                "id": f"chatcmpl-{uuid.uuid4()}",
+                "object": "chat.completion",
+                "created": int(datetime.now().timestamp()),
+                "model": request.model,
+                "choices": [
+                    {
+                        "message": {
+                            "role": "assistant",
+                            "content": assistant_content.strip()
+                        },
+                        "finish_reason": "stop",
+                        "index": 0
+                    }
+                ],
+                "usage": {
+                    "prompt_tokens": prompt_tokens,
+                    "completion_tokens": completion_tokens,
+                    "total_tokens": total_tokens,
+                    "estimated_cost": estimated_cost
+                },
+            }
     except ModelNotWorkingException as e:
         logger.warning(f"Model not working: {e} | IP: {client_ip}")
         raise HTTPException(status_code=503, detail=str(e))
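Note: the hunks above reference ContentItem, TextContent, and ImageContent, whose definitions sit outside the changed lines. A minimal sketch of the request models the new helper appears to assume is shown below; the class and field definitions are hypothetical, and the real models in main.py may differ.

# Hypothetical sketch of the request models assumed by the new extract_all_images_from_content;
# the real definitions live elsewhere in main.py and may differ.
from typing import Dict, List, Union
from pydantic import BaseModel

class TextContent(BaseModel):
    type: str = "text"
    text: str

class ImageContent(BaseModel):
    type: str = "image_url"
    # Kept as a plain dict so item.image_url.get('url') / .get('alt') work as in the diff.
    image_url: Dict[str, str]

ContentItem = Union[TextContent, ImageContent]

class Message(BaseModel):
    role: str
    content: List[ContentItem]

class ChatRequest(BaseModel):
    model: str
    messages: List[Message]
    stream: bool = False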
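The updated handler still awaits analyze_image(image_data_uri), but its new definition is not part of the shown hunks; the removed lines above were the old dummy implementation. The placeholder below mirrors that removed dummy and assumes the signature stays (image_data_uri: str) -> str; the actual implementation in the updated main.py may differ.

# Hypothetical placeholder for analyze_image, mirroring the dummy implementation removed above.
# The real definition in the updated main.py sits outside the shown hunks and may differ.
import asyncio

async def analyze_image(image_data_uri: str) -> str:
    # Strip the "data:image/...;base64," prefix to get the raw base64 payload.
    image_data = image_data_uri.split(",")[1]
    # A real implementation would decode image_data and send it to a vision backend here.
    await asyncio.sleep(1)  # Simulate processing delay
    return "Image analysis result: The image depicts a beautiful sunset over the mountains."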
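For reference, a call to the updated endpoint would look roughly like the sketch below. The base URL, the Bearer Authorization header, the model name, and the sunset.png file are placeholders, not part of the commit; get_api_key may expect the key in a different header entirely.

# Hypothetical client for the updated POST /v1/chat/completions endpoint.
# Base URL, auth header, model name, and image file are assumptions; adjust to the deployment.
import base64
import json

import requests

with open("sunset.png", "rb") as f:
    data_uri = "data:image/png;base64," + base64.b64encode(f.read()).decode()

payload = {
    "model": "some-model",
    "stream": True,
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                {"type": "image_url", "image_url": {"url": data_uri, "alt": "a photo"}},
            ],
        }
    ],
}

with requests.post(
    "http://localhost:8000/v1/chat/completions",
    headers={"Authorization": "Bearer <api-key>"},
    json=payload,
    stream=True,
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        # Server-sent events: each chunk arrives as a "data: {...}" line, ending with "data: [DONE]".
        if not line or not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        delta = chunk.get("choices", [{}])[0].get("delta", {})
        print(delta.get("content", ""), end="")

With "stream": false, the else branch in the diff instead returns a single chat.completion object whose usage block carries prompt_tokens, completion_tokens, total_tokens, and estimated_cost.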