Niansuh commited on
Commit
05f6d1c
·
verified ·
1 Parent(s): 5b74647

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +23 -12
main.py CHANGED
@@ -11,7 +11,7 @@ from fastapi import FastAPI, HTTPException, Request, Depends, Header, status
11
  from fastapi.responses import StreamingResponse, JSONResponse
12
  from fastapi.middleware.cors import CORSMiddleware
13
  from pydantic import BaseModel, Field, validator
14
- from typing import List, Dict, Any, Optional, Union, AsyncGenerator
15
  from datetime import datetime
16
  from slowapi import Limiter, _rate_limit_exceeded_handler
17
  from slowapi.util import get_remote_address
@@ -113,10 +113,11 @@ def count_tokens(messages: List[Dict[str, Any]], model: str) -> int:
113
  for message in messages:
114
  if isinstance(message['content'], list):
115
  for content_part in message['content']:
116
- if content_part.get('type') == 'text':
117
- tokens += len(encoding.encode(content_part['text']))
118
- elif content_part.get('type') == 'image_url':
119
- tokens += len(encoding.encode(content_part['image_url']['url']))
 
120
  else:
121
  tokens += len(encoding.encode(message['content']))
122
  return tokens
@@ -377,11 +378,11 @@ class Blackbox:
377
 
378
  # Pydantic Models
379
  class TextContent(BaseModel):
380
- type: str = Field(..., description="Type of content, e.g., 'text'.")
381
  text: str = Field(..., description="The text content.")
382
 
383
  class ImageURLContent(BaseModel):
384
- type: str = Field(..., description="Type of content, e.g., 'image_url'.")
385
  image_url: Dict[str, str] = Field(..., description="Dictionary containing the image URL.")
386
 
387
  Content = Union[TextContent, ImageURLContent]
@@ -393,7 +394,17 @@ class Message(BaseModel):
393
  @validator('content', pre=True)
394
  def validate_content(cls, v):
395
  if isinstance(v, list):
396
- return [Content(**item) for item in v]
 
 
 
 
 
 
 
 
 
 
397
  elif isinstance(v, str):
398
  return v
399
  else:
@@ -452,10 +463,10 @@ async def chat_completions(
452
  # Convert list of content parts to a structured format
453
  combined_content = []
454
  for part in msg.content:
455
- if isinstance(part, dict) and part.get('type') == 'text':
456
- combined_content.append({"type": part['type'], "text": part['text']})
457
- elif isinstance(part, dict) and part.get('type') == 'image_url':
458
- combined_content.append({"type": part['type'], "image_url": part['image_url']})
459
  processed_messages.append({"role": msg.role, "content": combined_content})
460
  else:
461
  processed_messages.append({"role": msg.role, "content": msg.content})
 
11
  from fastapi.responses import StreamingResponse, JSONResponse
12
  from fastapi.middleware.cors import CORSMiddleware
13
  from pydantic import BaseModel, Field, validator
14
+ from typing import List, Dict, Any, Optional, Union, AsyncGenerator, Literal
15
  from datetime import datetime
16
  from slowapi import Limiter, _rate_limit_exceeded_handler
17
  from slowapi.util import get_remote_address
 
113
  for message in messages:
114
  if isinstance(message['content'], list):
115
  for content_part in message['content']:
116
+ if isinstance(content_part, dict):
117
+ if content_part.get('type') == 'text':
118
+ tokens += len(encoding.encode(content_part['text']))
119
+ elif content_part.get('type') == 'image_url':
120
+ tokens += len(encoding.encode(content_part['image_url']['url']))
121
  else:
122
  tokens += len(encoding.encode(message['content']))
123
  return tokens
 
378
 
379
  # Pydantic Models
380
  class TextContent(BaseModel):
381
+ type: Literal["text"] = Field(..., description="Type of content, e.g., 'text'.")
382
  text: str = Field(..., description="The text content.")
383
 
384
  class ImageURLContent(BaseModel):
385
+ type: Literal["image_url"] = Field(..., description="Type of content, e.g., 'image_url'.")
386
  image_url: Dict[str, str] = Field(..., description="Dictionary containing the image URL.")
387
 
388
  Content = Union[TextContent, ImageURLContent]
 
394
  @validator('content', pre=True)
395
  def validate_content(cls, v):
396
  if isinstance(v, list):
397
+ processed_content = []
398
+ for item in v:
399
+ if 'type' not in item:
400
+ raise ValueError("Each content part must have a 'type' field.")
401
+ if item['type'] == 'text':
402
+ processed_content.append(TextContent(**item))
403
+ elif item['type'] == 'image_url':
404
+ processed_content.append(ImageURLContent(**item))
405
+ else:
406
+ raise ValueError(f"Unsupported content type: {item['type']}")
407
+ return processed_content
408
  elif isinstance(v, str):
409
  return v
410
  else:
 
463
  # Convert list of content parts to a structured format
464
  combined_content = []
465
  for part in msg.content:
466
+ if isinstance(part, TextContent):
467
+ combined_content.append({"type": part.type, "text": part.text})
468
+ elif isinstance(part, ImageURLContent):
469
+ combined_content.append({"type": part.type, "image_url": part.image_url})
470
  processed_messages.append({"role": msg.role, "content": combined_content})
471
  else:
472
  processed_messages.append({"role": msg.role, "content": msg.content})