Niansuh committed
Commit 121ef6b · verified · 1 Parent(s): 90a29cf

Update main.py

Files changed (1):
  1. main.py +95 -56
main.py CHANGED
@@ -532,41 +532,21 @@ def create_response(content: str, model: str, finish_reason: Optional[str] = None
         "usage": None, # To be filled in non-streaming responses
     }
 
-def extract_image_from_content(content: str) -> Optional[Tuple[str, str]]:
+def extract_all_images_from_content(content_items: List[ContentItem]) -> List[Tuple[str, str]]:
     """
-    Extracts the first image from the content string.
-    Returns a tuple of (alt_text, image_data_uri) if found, else None.
-    """
-    # Regex to match markdown image syntax: ![Alt Text](image_url)
-    match = re.search(r'!\[([^\]]*)\]\((data:image/\w+;base64,[^\)]+)\)', content)
-    if match:
-        alt_text = match.group(1)
-        image_data_uri = match.group(2)
-        return alt_text, image_data_uri
-    return None
-
-def extract_all_images_from_content(content: str) -> List[Tuple[str, str]]:
-    """
-    Extracts all images from the content string.
+    Extracts all images from the content list.
     Returns a list of tuples containing (alt_text, image_data_uri).
     """
-    # Regex to match markdown image syntax: ![Alt Text](image_url)
-    matches = re.findall(r'!\[([^\]]*)\]\((data:image/\w+;base64,[^\)]+)\)', content)
-    return matches if matches else []
-
-async def analyze_image(image_data_uri: str) -> str:
-    """
-    Placeholder function to analyze the image.
-    Replace this with actual image analysis logic or API calls.
-    """
-    # Extract base64 data
-    image_data = image_data_uri.split(",")[1]
-    # Decode and process the image as needed
-    # For example, send it to an external API
-    # Here, we'll return a dummy response
-    await asyncio.sleep(1) # Simulate processing delay
-    return "Image analysis result: The image depicts a beautiful sunset over the mountains."
-
+    images = []
+    for item in content_items:
+        if isinstance(item, ImageContent):
+            alt_text = item.image_url.get('alt', '') # Optional alt text
+            image_data_uri = item.image_url.get('url', '')
+            if image_data_uri:
+                images.append((alt_text, image_data_uri))
+    return images
+
+# Endpoint: POST /v1/chat/completions
 @app.post("/v1/chat/completions", dependencies=[Depends(rate_limiter_per_ip)])
 async def chat_completions(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
     client_ip = req.client.host
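Note: the rewritten extract_all_images_from_content no longer parses markdown image syntax out of a flat string; it walks typed content items and relies on the ContentItem, TextContent and ImageContent models defined elsewhere in main.py. Those definitions are outside this diff, so the following is only a minimal sketch of the shape the new code appears to assume (a Pydantic union with a text field on text items and an image_url dict on image items); the defaults and exact fields are guesses, not the repository's actual models.

from typing import Union
from pydantic import BaseModel

class TextContent(BaseModel):
    # Assumed shape: the old call site read item.type == "text" and item.text.
    type: str = "text"
    text: str

class ImageContent(BaseModel):
    # Assumed shape: the new helper reads item.image_url.get('url') and .get('alt').
    type: str = "image_url"
    image_url: dict

# Assumed union behind the new List[ContentItem] signature.
ContentItem = Union[TextContent, ImageContent]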
@@ -588,7 +568,7 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
         for msg in request.messages:
             if msg.role == "user":
                 # Extract all images from the message content
-                images = extract_all_images_from_content(" ".join([item.text if item.type == "text" else item.image_url['url'] for item in msg.content]))
+                images = extract_all_images_from_content(msg.content)
                 for alt_text, image_data_uri in images:
                     # Analyze the image
                     analysis_result = await analyze_image(image_data_uri)
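For illustration, a hypothetical call against the updated signature, assuming the content models sketched above; the message content and base64 data are placeholders. Note that this call path still awaits analyze_image, whose definition the first hunk removes without a replacement in this commit, so that helper has to be provided elsewhere in main.py for the handler to run.

# Hypothetical user-message content; values are placeholders.
msg_content = [
    TextContent(text="What is in this picture?"),
    ImageContent(image_url={"url": "data:image/png;base64,iVBORw0KGgo...", "alt": "screenshot"}),
]

images = extract_all_images_from_content(msg_content)
# -> [("screenshot", "data:image/png;base64,iVBORw0KGgo...")]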
@@ -598,35 +578,94 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
                     assistant_content += "Based on the image you provided, here are the insights..."
 
         # Calculate token usage (simple approximation)
-        prompt_tokens = sum(len(" ".join([item.text if item.type == "text" else item.image_url['url'] for item in msg.content]).split()) for msg in request.messages)
+        prompt_tokens = sum(len(" ".join([item.text if isinstance(item, TextContent) else item.image_url['url'] for item in msg.content]).split()) for msg in request.messages)
         completion_tokens = len(assistant_content.split())
         total_tokens = prompt_tokens + completion_tokens
         estimated_cost = calculate_estimated_cost(prompt_tokens, completion_tokens)
 
         logger.info(f"Completed response generation for API key: {api_key} | IP: {client_ip}")
 
-        return {
-            "id": f"chatcmpl-{uuid.uuid4()}",
-            "object": "chat.completion",
-            "created": int(datetime.now().timestamp()),
-            "model": request.model,
-            "choices": [
-                {
-                    "message": {
-                        "role": "assistant",
-                        "content": assistant_content.strip()
-                    },
-                    "finish_reason": "stop",
-                    "index": 0
-                }
-            ],
-            "usage": {
-                "prompt_tokens": prompt_tokens,
-                "completion_tokens": completion_tokens,
-                "total_tokens": total_tokens,
-                "estimated_cost": estimated_cost
-            },
-        }
+        if request.stream:
+            async def generate():
+                try:
+                    for msg in request.messages:
+                        if msg.role == "user":
+                            images = extract_all_images_from_content(msg.content)
+                            for alt_text, image_data_uri in images:
+                                analysis_result = await analyze_image(image_data_uri)
+                                response_chunk = {
+                                    "id": f"chatcmpl-{uuid.uuid4()}",
+                                    "object": "chat.completion.chunk",
+                                    "created": int(datetime.now().timestamp()),
+                                    "model": request.model,
+                                    "choices": [
+                                        {
+                                            "index": 0,
+                                            "delta": {"content": analysis_result + "\n", "role": "assistant"},
+                                            "finish_reason": None,
+                                        }
+                                    ],
+                                    "usage": None,
+                                }
+                                yield f"data: {json.dumps(response_chunk)}\n\n"
+
+                    # Final message
+                    final_response = {
+                        "id": f"chatcmpl-{uuid.uuid4()}",
+                        "object": "chat.completion",
+                        "created": int(datetime.now().timestamp()),
+                        "model": request.model,
+                        "choices": [
+                            {
+                                "message": {
+                                    "role": "assistant",
+                                    "content": assistant_content.strip()
+                                },
+                                "finish_reason": "stop",
+                                "index": 0
+                            }
+                        ],
+                        "usage": {
+                            "prompt_tokens": prompt_tokens,
+                            "completion_tokens": completion_tokens,
+                            "total_tokens": total_tokens,
+                            "estimated_cost": estimated_cost
+                        },
+                    }
+                    yield f"data: {json.dumps(final_response)}\n\n"
+                    yield "data: [DONE]\n\n"
+                except HTTPException as he:
+                    error_response = {"error": he.detail}
+                    yield f"data: {json.dumps(error_response)}\n\n"
+                except Exception as e:
+                    logger.exception(f"Error during streaming response generation from IP: {client_ip}.")
+                    error_response = {"error": str(e)}
+                    yield f"data: {json.dumps(error_response)}\n\n"
+
+            return StreamingResponse(generate(), media_type="text/event-stream")
+        else:
+            return {
+                "id": f"chatcmpl-{uuid.uuid4()}",
+                "object": "chat.completion",
+                "created": int(datetime.now().timestamp()),
+                "model": request.model,
+                "choices": [
+                    {
+                        "message": {
+                            "role": "assistant",
+                            "content": assistant_content.strip()
+                        },
+                        "finish_reason": "stop",
+                        "index": 0
+                    }
+                ],
+                "usage": {
+                    "prompt_tokens": prompt_tokens,
+                    "completion_tokens": completion_tokens,
+                    "total_tokens": total_tokens,
+                    "estimated_cost": estimated_cost
+                },
+            }
     except ModelNotWorkingException as e:
         logger.warning(f"Model not working: {e} | IP: {client_ip}")
         raise HTTPException(status_code=503, detail=str(e))
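The usage numbers in both branches are whitespace counts, not tokenizer counts: prompt_tokens joins every content item (full base64 data URIs included) into one string and splits on spaces, and completion_tokens does the same for the assistant text. A minimal sketch of that approximation in isolation; approx_token_count is a hypothetical name, not a helper from main.py.

def approx_token_count(text: str) -> int:
    # Whitespace-split approximation matching the usage block above; a real
    # tokenizer (e.g. tiktoken) counts differently, and base64 image data
    # URIs inflate this estimate considerably.
    return len(text.split())

approx_token_count("Describe this image for me")  # -> 5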
 
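Since the new streaming branch returns a StreamingResponse with media type text/event-stream, emits chat.completion.chunk payloads as data: lines, and terminates with data: [DONE], a client can read it line by line. A minimal consumer sketch using httpx; the host, model name, bearer-token auth scheme and message shape are assumptions for illustration, not details taken from this diff.

import asyncio
import json
import httpx

async def stream_chat(url: str, api_key: str) -> None:
    # Placeholder payload; model name, auth header and content shape are guesses.
    payload = {
        "model": "gpt-4o",
        "stream": True,
        "messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
    }
    headers = {"Authorization": f"Bearer {api_key}"}
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", url, json=payload, headers=headers) as resp:
            async for line in resp.aiter_lines():
                if not line.startswith("data: "):
                    continue
                data = line[len("data: "):]
                if data == "[DONE]":
                    break
                chunk = json.loads(data)
                if "error" in chunk:
                    # The endpoint yields {"error": ...} on failures mid-stream.
                    raise RuntimeError(chunk["error"])
                choice = chunk["choices"][0]
                # Intermediate chunks carry a delta; the final event is a full completion.
                content = choice.get("delta", choice.get("message", {})).get("content", "")
                print(content, end="", flush=True)

# Example (hypothetical local deployment):
# asyncio.run(stream_chat("http://localhost:8000/v1/chat/completions", "YOUR_API_KEY"))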