Update main.py
Browse files
main.py
CHANGED
@@ -10,8 +10,8 @@ from aiohttp import ClientSession, ClientTimeout, ClientError
|
|
10 |
from fastapi import FastAPI, HTTPException, Request, Depends, Header, status
|
11 |
from fastapi.responses import StreamingResponse, JSONResponse
|
12 |
from fastapi.middleware.cors import CORSMiddleware
|
13 |
-
from pydantic import BaseModel, Field
|
14 |
-
from typing import List, Dict, Any, Optional, AsyncGenerator
|
15 |
from datetime import datetime
|
16 |
from slowapi import Limiter, _rate_limit_exceeded_handler
|
17 |
from slowapi.util import get_remote_address
|
@@ -100,7 +100,7 @@ def to_data_uri(image: Any) -> str:
|
|
100 |
return "data:image/png;base64,..." # Replace with actual base64 data if needed
|
101 |
|
102 |
# Token Counting using tiktoken
|
103 |
-
def count_tokens(messages: List[Dict[str,
|
104 |
"""
|
105 |
Counts the number of tokens in the messages using tiktoken.
|
106 |
Adjust the encoding based on the model.
|
@@ -111,7 +111,14 @@ def count_tokens(messages: List[Dict[str, str]], model: str) -> int:
|
|
111 |
encoding = tiktoken.get_encoding("cl100k_base") # Default encoding
|
112 |
tokens = 0
|
113 |
for message in messages:
|
114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
return tokens
|
116 |
|
117 |
# Blackbox Class: Handles interaction with the external AI service
|
@@ -235,7 +242,7 @@ class Blackbox:
|
|
235 |
async def create_async_generator(
|
236 |
cls,
|
237 |
model: str,
|
238 |
-
messages: List[Dict[str,
|
239 |
proxy: Optional[str] = None,
|
240 |
image: Any = None,
|
241 |
image_name: Optional[str] = None,
|
@@ -269,22 +276,33 @@ class Blackbox:
|
|
269 |
|
270 |
if model in cls.model_prefixes:
|
271 |
prefix = cls.model_prefixes[model]
|
272 |
-
if
|
273 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
274 |
messages[0]['content'] = f"{prefix} {messages[0]['content']}"
|
275 |
-
|
276 |
random_id = ''.join(random.choices(string.ascii_letters + string.digits, k=7))
|
277 |
-
|
278 |
-
messages
|
279 |
-
|
280 |
-
|
281 |
-
'
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
|
|
|
|
|
|
|
|
|
|
|
288 |
data = {
|
289 |
"messages": messages,
|
290 |
"id": random_id,
|
@@ -337,36 +355,13 @@ class Blackbox:
|
|
337 |
logger.error("Image URL not found in the response.")
|
338 |
raise Exception("Image URL not found in the response")
|
339 |
else:
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
if
|
345 |
-
decoded_chunk
|
346 |
-
|
347 |
-
if decoded_chunk.strip():
|
348 |
-
if '$~~~$' in decoded_chunk:
|
349 |
-
search_results_json += decoded_chunk
|
350 |
-
else:
|
351 |
-
full_response += decoded_chunk
|
352 |
-
yield decoded_chunk
|
353 |
-
logger.info("Finished streaming response chunks.")
|
354 |
-
except Exception as e:
|
355 |
-
logger.exception("Error while iterating over response chunks.")
|
356 |
-
raise e
|
357 |
-
if data["webSearchMode"] and search_results_json:
|
358 |
-
match = re.search(r'\$~~~\$(.*?)\$~~~\$', search_results_json, re.DOTALL)
|
359 |
-
if match:
|
360 |
-
try:
|
361 |
-
search_results = json.loads(match.group(1))
|
362 |
-
formatted_results = "\n\n**Sources:**\n"
|
363 |
-
for i, result in enumerate(search_results[:5], 1):
|
364 |
-
formatted_results += f"{i}. [{result['title']}]({result['link']})\n"
|
365 |
-
logger.info("Formatted search results.")
|
366 |
-
yield formatted_results
|
367 |
-
except json.JSONDecodeError as je:
|
368 |
-
logger.error("Failed to parse search results JSON.")
|
369 |
-
raise je
|
370 |
except ClientError as ce:
|
371 |
logger.error(f"Client error occurred: {ce}. Retrying attempt {attempt + 1}/{retry_attempts}")
|
372 |
if attempt == retry_attempts - 1:
|
@@ -381,9 +376,28 @@ class Blackbox:
|
|
381 |
raise HTTPException(status_code=500, detail=str(e))
|
382 |
|
383 |
# Pydantic Models
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
384 |
class Message(BaseModel):
|
385 |
role: str = Field(..., description="The role of the message author.")
|
386 |
-
content: str = Field(..., description="The content of the message.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
387 |
|
388 |
class ChatRequest(BaseModel):
|
389 |
model: str = Field(..., description="ID of the model to use.")
|
@@ -431,12 +445,26 @@ async def chat_completions(
|
|
431 |
):
|
432 |
logger.info(f"Received chat completions request: {chat_request}")
|
433 |
try:
|
434 |
-
|
435 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
436 |
|
437 |
async_generator = Blackbox.create_async_generator(
|
438 |
model=chat_request.model,
|
439 |
-
messages=
|
440 |
image=None, # Adjust if image handling is required
|
441 |
image_name=None,
|
442 |
webSearchMode=chat_request.webSearchMode
|
|
|
10 |
from fastapi import FastAPI, HTTPException, Request, Depends, Header, status
|
11 |
from fastapi.responses import StreamingResponse, JSONResponse
|
12 |
from fastapi.middleware.cors import CORSMiddleware
|
13 |
+
from pydantic import BaseModel, Field, validator
|
14 |
+
from typing import List, Dict, Any, Optional, Union, AsyncGenerator
|
15 |
from datetime import datetime
|
16 |
from slowapi import Limiter, _rate_limit_exceeded_handler
|
17 |
from slowapi.util import get_remote_address
|
|
|
100 |
return "data:image/png;base64,..." # Replace with actual base64 data if needed
|
101 |
|
102 |
# Token Counting using tiktoken
|
103 |
+
def count_tokens(messages: List[Dict[str, Any]], model: str) -> int:
|
104 |
"""
|
105 |
Counts the number of tokens in the messages using tiktoken.
|
106 |
Adjust the encoding based on the model.
|
|
|
111 |
encoding = tiktoken.get_encoding("cl100k_base") # Default encoding
|
112 |
tokens = 0
|
113 |
for message in messages:
|
114 |
+
if isinstance(message['content'], list):
|
115 |
+
for content_part in message['content']:
|
116 |
+
if content_part.get('type') == 'text':
|
117 |
+
tokens += len(encoding.encode(content_part['text']))
|
118 |
+
elif content_part.get('type') == 'image_url':
|
119 |
+
tokens += len(encoding.encode(content_part['image_url']['url']))
|
120 |
+
else:
|
121 |
+
tokens += len(encoding.encode(message['content']))
|
122 |
return tokens
|
123 |
|
124 |
# Blackbox Class: Handles interaction with the external AI service
|
|
|
242 |
async def create_async_generator(
|
243 |
cls,
|
244 |
model: str,
|
245 |
+
messages: List[Dict[str, Any]],
|
246 |
proxy: Optional[str] = None,
|
247 |
image: Any = None,
|
248 |
image_name: Optional[str] = None,
|
|
|
276 |
|
277 |
if model in cls.model_prefixes:
|
278 |
prefix = cls.model_prefixes[model]
|
279 |
+
if messages and isinstance(messages[0]['content'], list):
|
280 |
+
# Prepend prefix to the first text message
|
281 |
+
for content_part in messages[0]['content']:
|
282 |
+
if content_part.get('type') == 'text' and not content_part['text'].startswith(prefix):
|
283 |
+
logger.debug(f"Adding prefix '{prefix}' to the first text message.")
|
284 |
+
content_part['text'] = f"{prefix} {content_part['text']}"
|
285 |
+
break
|
286 |
+
elif messages and isinstance(messages[0]['content'], str) and not messages[0]['content'].startswith(prefix):
|
287 |
messages[0]['content'] = f"{prefix} {messages[0]['content']}"
|
288 |
+
|
289 |
random_id = ''.join(random.choices(string.ascii_letters + string.digits, k=7))
|
290 |
+
# Assuming the last message is from the user
|
291 |
+
if messages:
|
292 |
+
last_message = messages[-1]
|
293 |
+
if isinstance(last_message['content'], list):
|
294 |
+
for content_part in last_message['content']:
|
295 |
+
if content_part.get('type') == 'text':
|
296 |
+
content_part['role'] = 'user'
|
297 |
+
else:
|
298 |
+
last_message['id'] = random_id
|
299 |
+
last_message['role'] = 'user'
|
300 |
|
301 |
+
if image is not None:
|
302 |
+
# Process image if required
|
303 |
+
# This implementation assumes that image URLs are handled by the external service
|
304 |
+
pass # Implement as needed
|
305 |
+
|
306 |
data = {
|
307 |
"messages": messages,
|
308 |
"id": random_id,
|
|
|
355 |
logger.error("Image URL not found in the response.")
|
356 |
raise Exception("Image URL not found in the response")
|
357 |
else:
|
358 |
+
async for chunk in response.content.iter_chunks():
|
359 |
+
if chunk:
|
360 |
+
decoded_chunk = chunk.decode(errors='ignore')
|
361 |
+
decoded_chunk = re.sub(r'\$@\$v=[^$]+\$@\$', '', decoded_chunk)
|
362 |
+
if decoded_chunk.strip():
|
363 |
+
yield decoded_chunk
|
364 |
+
break # Exit the retry loop if successful
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
365 |
except ClientError as ce:
|
366 |
logger.error(f"Client error occurred: {ce}. Retrying attempt {attempt + 1}/{retry_attempts}")
|
367 |
if attempt == retry_attempts - 1:
|
|
|
376 |
raise HTTPException(status_code=500, detail=str(e))
|
377 |
|
378 |
# Pydantic Models
|
379 |
+
class TextContent(BaseModel):
|
380 |
+
type: str = Field(..., description="Type of content, e.g., 'text'.")
|
381 |
+
text: str = Field(..., description="The text content.")
|
382 |
+
|
383 |
+
class ImageURLContent(BaseModel):
|
384 |
+
type: str = Field(..., description="Type of content, e.g., 'image_url'.")
|
385 |
+
image_url: Dict[str, str] = Field(..., description="Dictionary containing the image URL.")
|
386 |
+
|
387 |
+
Content = Union[TextContent, ImageURLContent]
|
388 |
+
|
389 |
class Message(BaseModel):
|
390 |
role: str = Field(..., description="The role of the message author.")
|
391 |
+
content: Union[str, List[Content]] = Field(..., description="The content of the message. Can be a string or a list of content parts.")
|
392 |
+
|
393 |
+
@validator('content', pre=True)
|
394 |
+
def validate_content(cls, v):
|
395 |
+
if isinstance(v, list):
|
396 |
+
return [Content(**item) for item in v]
|
397 |
+
elif isinstance(v, str):
|
398 |
+
return v
|
399 |
+
else:
|
400 |
+
raise ValueError("Content must be either a string or a list of content parts.")
|
401 |
|
402 |
class ChatRequest(BaseModel):
|
403 |
model: str = Field(..., description="ID of the model to use.")
|
|
|
445 |
):
|
446 |
logger.info(f"Received chat completions request: {chat_request}")
|
447 |
try:
|
448 |
+
# Process messages for token counting and sending to Blackbox
|
449 |
+
processed_messages = []
|
450 |
+
for msg in chat_request.messages:
|
451 |
+
if isinstance(msg.content, list):
|
452 |
+
# Convert list of content parts to a structured format
|
453 |
+
combined_content = []
|
454 |
+
for part in msg.content:
|
455 |
+
if isinstance(part, TextContent):
|
456 |
+
combined_content.append({"type": part.type, "text": part.text})
|
457 |
+
elif isinstance(part, ImageURLContent):
|
458 |
+
combined_content.append({"type": part.type, "image_url": part.image_url})
|
459 |
+
processed_messages.append({"role": msg.role, "content": combined_content})
|
460 |
+
else:
|
461 |
+
processed_messages.append({"role": msg.role, "content": msg.content})
|
462 |
+
|
463 |
+
prompt_tokens = count_tokens(processed_messages, chat_request.model)
|
464 |
|
465 |
async_generator = Blackbox.create_async_generator(
|
466 |
model=chat_request.model,
|
467 |
+
messages=processed_messages,
|
468 |
image=None, # Adjust if image handling is required
|
469 |
image_name=None,
|
470 |
webSearchMode=chat_request.webSearchMode
|