Update main.py
Browse files
main.py
CHANGED
@@ -11,7 +11,7 @@ from fastapi import FastAPI, HTTPException, Request, Depends, Header, status
|
|
11 |
from fastapi.responses import StreamingResponse, JSONResponse
|
12 |
from fastapi.middleware.cors import CORSMiddleware
|
13 |
from pydantic import BaseModel, Field, validator
|
14 |
-
from typing import List, Dict, Any, Optional, Union, AsyncGenerator
|
15 |
from datetime import datetime
|
16 |
from slowapi import Limiter, _rate_limit_exceeded_handler
|
17 |
from slowapi.util import get_remote_address
|
@@ -113,10 +113,11 @@ def count_tokens(messages: List[Dict[str, Any]], model: str) -> int:
|
|
113 |
for message in messages:
|
114 |
if isinstance(message['content'], list):
|
115 |
for content_part in message['content']:
|
116 |
-
if content_part
|
117 |
-
|
118 |
-
|
119 |
-
|
|
|
120 |
else:
|
121 |
tokens += len(encoding.encode(message['content']))
|
122 |
return tokens
|
@@ -377,11 +378,11 @@ class Blackbox:
|
|
377 |
|
378 |
# Pydantic Models
|
379 |
class TextContent(BaseModel):
|
380 |
-
type:
|
381 |
text: str = Field(..., description="The text content.")
|
382 |
|
383 |
class ImageURLContent(BaseModel):
|
384 |
-
type:
|
385 |
image_url: Dict[str, str] = Field(..., description="Dictionary containing the image URL.")
|
386 |
|
387 |
Content = Union[TextContent, ImageURLContent]
|
@@ -393,7 +394,17 @@ class Message(BaseModel):
|
|
393 |
@validator('content', pre=True)
|
394 |
def validate_content(cls, v):
|
395 |
if isinstance(v, list):
|
396 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
397 |
elif isinstance(v, str):
|
398 |
return v
|
399 |
else:
|
@@ -452,10 +463,10 @@ async def chat_completions(
|
|
452 |
# Convert list of content parts to a structured format
|
453 |
combined_content = []
|
454 |
for part in msg.content:
|
455 |
-
if isinstance(part,
|
456 |
-
combined_content.append({"type": part
|
457 |
-
elif isinstance(part,
|
458 |
-
combined_content.append({"type": part
|
459 |
processed_messages.append({"role": msg.role, "content": combined_content})
|
460 |
else:
|
461 |
processed_messages.append({"role": msg.role, "content": msg.content})
|
|
|
11 |
from fastapi.responses import StreamingResponse, JSONResponse
|
12 |
from fastapi.middleware.cors import CORSMiddleware
|
13 |
from pydantic import BaseModel, Field, validator
|
14 |
+
from typing import List, Dict, Any, Optional, Union, AsyncGenerator, Literal
|
15 |
from datetime import datetime
|
16 |
from slowapi import Limiter, _rate_limit_exceeded_handler
|
17 |
from slowapi.util import get_remote_address
|
|
|
113 |
for message in messages:
|
114 |
if isinstance(message['content'], list):
|
115 |
for content_part in message['content']:
|
116 |
+
if isinstance(content_part, dict):
|
117 |
+
if content_part.get('type') == 'text':
|
118 |
+
tokens += len(encoding.encode(content_part['text']))
|
119 |
+
elif content_part.get('type') == 'image_url':
|
120 |
+
tokens += len(encoding.encode(content_part['image_url']['url']))
|
121 |
else:
|
122 |
tokens += len(encoding.encode(message['content']))
|
123 |
return tokens
|
|
|
378 |
|
379 |
# Pydantic Models
|
380 |
class TextContent(BaseModel):
|
381 |
+
type: Literal["text"] = Field(..., description="Type of content, e.g., 'text'.")
|
382 |
text: str = Field(..., description="The text content.")
|
383 |
|
384 |
class ImageURLContent(BaseModel):
|
385 |
+
type: Literal["image_url"] = Field(..., description="Type of content, e.g., 'image_url'.")
|
386 |
image_url: Dict[str, str] = Field(..., description="Dictionary containing the image URL.")
|
387 |
|
388 |
Content = Union[TextContent, ImageURLContent]
|
|
|
394 |
@validator('content', pre=True)
|
395 |
def validate_content(cls, v):
|
396 |
if isinstance(v, list):
|
397 |
+
processed_content = []
|
398 |
+
for item in v:
|
399 |
+
if 'type' not in item:
|
400 |
+
raise ValueError("Each content part must have a 'type' field.")
|
401 |
+
if item['type'] == 'text':
|
402 |
+
processed_content.append(TextContent(**item))
|
403 |
+
elif item['type'] == 'image_url':
|
404 |
+
processed_content.append(ImageURLContent(**item))
|
405 |
+
else:
|
406 |
+
raise ValueError(f"Unsupported content type: {item['type']}")
|
407 |
+
return processed_content
|
408 |
elif isinstance(v, str):
|
409 |
return v
|
410 |
else:
|
|
|
463 |
# Convert list of content parts to a structured format
|
464 |
combined_content = []
|
465 |
for part in msg.content:
|
466 |
+
if isinstance(part, TextContent):
|
467 |
+
combined_content.append({"type": part.type, "text": part.text})
|
468 |
+
elif isinstance(part, ImageURLContent):
|
469 |
+
combined_content.append({"type": part.type, "image_url": part.image_url})
|
470 |
processed_messages.append({"role": msg.role, "content": combined_content})
|
471 |
else:
|
472 |
processed_messages.append({"role": msg.role, "content": msg.content})
|