Niansuh committed
Commit 6b5328d · verified · 1 Parent(s): 2edde86

Update main.py

Files changed (1):
  1. main.py +80 -52
main.py CHANGED
@@ -10,8 +10,8 @@ from aiohttp import ClientSession, ClientTimeout, ClientError
 from fastapi import FastAPI, HTTPException, Request, Depends, Header, status
 from fastapi.responses import StreamingResponse, JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel, Field
-from typing import List, Dict, Any, Optional, AsyncGenerator
+from pydantic import BaseModel, Field, validator
+from typing import List, Dict, Any, Optional, Union, AsyncGenerator
 from datetime import datetime
 from slowapi import Limiter, _rate_limit_exceeded_handler
 from slowapi.util import get_remote_address
@@ -100,7 +100,7 @@ def to_data_uri(image: Any) -> str:
     return "data:image/png;base64,..." # Replace with actual base64 data if needed
 
 # Token Counting using tiktoken
-def count_tokens(messages: List[Dict[str, str]], model: str) -> int:
+def count_tokens(messages: List[Dict[str, Any]], model: str) -> int:
     """
     Counts the number of tokens in the messages using tiktoken.
     Adjust the encoding based on the model.
@@ -111,7 +111,14 @@ def count_tokens(messages: List[Dict[str, str]], model: str) -> int:
         encoding = tiktoken.get_encoding("cl100k_base") # Default encoding
     tokens = 0
     for message in messages:
-        tokens += len(encoding.encode(message['content']))
+        if isinstance(message['content'], list):
+            for content_part in message['content']:
+                if content_part.get('type') == 'text':
+                    tokens += len(encoding.encode(content_part['text']))
+                elif content_part.get('type') == 'image_url':
+                    tokens += len(encoding.encode(content_part['image_url']['url']))
+        else:
+            tokens += len(encoding.encode(message['content']))
     return tokens
 
 # Blackbox Class: Handles interaction with the external AI service
@@ -235,7 +242,7 @@ class Blackbox:
     async def create_async_generator(
         cls,
         model: str,
-        messages: List[Dict[str, str]],
+        messages: List[Dict[str, Any]],
         proxy: Optional[str] = None,
        image: Any = None,
         image_name: Optional[str] = None,
@@ -269,22 +276,33 @@ class Blackbox:
 
         if model in cls.model_prefixes:
             prefix = cls.model_prefixes[model]
-            if not messages[0]['content'].startswith(prefix):
-                logger.debug(f"Adding prefix '{prefix}' to the first message.")
+            if messages and isinstance(messages[0]['content'], list):
+                # Prepend prefix to the first text message
+                for content_part in messages[0]['content']:
+                    if content_part.get('type') == 'text' and not content_part['text'].startswith(prefix):
+                        logger.debug(f"Adding prefix '{prefix}' to the first text message.")
+                        content_part['text'] = f"{prefix} {content_part['text']}"
+                        break
+            elif messages and isinstance(messages[0]['content'], str) and not messages[0]['content'].startswith(prefix):
                 messages[0]['content'] = f"{prefix} {messages[0]['content']}"
-
+
         random_id = ''.join(random.choices(string.ascii_letters + string.digits, k=7))
-        messages[-1]['id'] = random_id
-        messages[-1]['role'] = 'user'
-        if image is not None:
-            messages[-1]['data'] = {
-                'fileText': '',
-                'imageBase64': to_data_uri(image),
-                'title': image_name
-            }
-            messages[-1]['content'] = 'FILE:BB\n$#$\n\n$#$\n' + messages[-1]['content']
-            logger.debug("Image data added to the message.")
+        # Assuming the last message is from the user
+        if messages:
+            last_message = messages[-1]
+            if isinstance(last_message['content'], list):
+                for content_part in last_message['content']:
+                    if content_part.get('type') == 'text':
+                        content_part['role'] = 'user'
+            else:
+                last_message['id'] = random_id
+                last_message['role'] = 'user'
 
+        if image is not None:
+            # Process image if required
+            # This implementation assumes that image URLs are handled by the external service
+            pass # Implement as needed
+
         data = {
             "messages": messages,
             "id": random_id,
@@ -337,36 +355,13 @@ class Blackbox:
                             logger.error("Image URL not found in the response.")
                             raise Exception("Image URL not found in the response")
                         else:
-                            full_response = ""
-                            search_results_json = ""
-                            try:
-                                async for chunk, _ in response.content.iter_chunks():
-                                    if chunk:
-                                        decoded_chunk = chunk.decode(errors='ignore')
-                                        decoded_chunk = re.sub(r'\$@\$v=[^$]+\$@\$', '', decoded_chunk)
-                                        if decoded_chunk.strip():
-                                            if '$~~~$' in decoded_chunk:
-                                                search_results_json += decoded_chunk
-                                            else:
-                                                full_response += decoded_chunk
-                                                yield decoded_chunk
-                                logger.info("Finished streaming response chunks.")
-                            except Exception as e:
-                                logger.exception("Error while iterating over response chunks.")
-                                raise e
-                            if data["webSearchMode"] and search_results_json:
-                                match = re.search(r'\$~~~\$(.*?)\$~~~\$', search_results_json, re.DOTALL)
-                                if match:
-                                    try:
-                                        search_results = json.loads(match.group(1))
-                                        formatted_results = "\n\n**Sources:**\n"
-                                        for i, result in enumerate(search_results[:5], 1):
-                                            formatted_results += f"{i}. [{result['title']}]({result['link']})\n"
-                                        logger.info("Formatted search results.")
-                                        yield formatted_results
-                                    except json.JSONDecodeError as je:
-                                        logger.error("Failed to parse search results JSON.")
-                                        raise je
+                            async for chunk in response.content.iter_chunks():
+                                if chunk:
+                                    decoded_chunk = chunk.decode(errors='ignore')
+                                    decoded_chunk = re.sub(r'\$@\$v=[^$]+\$@\$', '', decoded_chunk)
+                                    if decoded_chunk.strip():
+                                        yield decoded_chunk
+                break # Exit the retry loop if successful
             except ClientError as ce:
                 logger.error(f"Client error occurred: {ce}. Retrying attempt {attempt + 1}/{retry_attempts}")
                 if attempt == retry_attempts - 1:
@@ -381,9 +376,28 @@ class Blackbox:
             raise HTTPException(status_code=500, detail=str(e))
 
 # Pydantic Models
+class TextContent(BaseModel):
+    type: str = Field(..., description="Type of content, e.g., 'text'.")
+    text: str = Field(..., description="The text content.")
+
+class ImageURLContent(BaseModel):
+    type: str = Field(..., description="Type of content, e.g., 'image_url'.")
+    image_url: Dict[str, str] = Field(..., description="Dictionary containing the image URL.")
+
+Content = Union[TextContent, ImageURLContent]
+
 class Message(BaseModel):
     role: str = Field(..., description="The role of the message author.")
-    content: str = Field(..., description="The content of the message.")
+    content: Union[str, List[Content]] = Field(..., description="The content of the message. Can be a string or a list of content parts.")
+
+    @validator('content', pre=True)
+    def validate_content(cls, v):
+        if isinstance(v, list):
+            return [Content(**item) for item in v]
+        elif isinstance(v, str):
+            return v
+        else:
+            raise ValueError("Content must be either a string or a list of content parts.")
 
 class ChatRequest(BaseModel):
     model: str = Field(..., description="ID of the model to use.")
@@ -431,12 +445,26 @@ async def chat_completions(
 ):
     logger.info(f"Received chat completions request: {chat_request}")
     try:
-        messages = [{"role": msg.role, "content": msg.content} for msg in chat_request.messages]
-        prompt_tokens = count_tokens(messages, chat_request.model)
+        # Process messages for token counting and sending to Blackbox
+        processed_messages = []
+        for msg in chat_request.messages:
+            if isinstance(msg.content, list):
+                # Convert list of content parts to a structured format
+                combined_content = []
+                for part in msg.content:
+                    if isinstance(part, TextContent):
+                        combined_content.append({"type": part.type, "text": part.text})
+                    elif isinstance(part, ImageURLContent):
+                        combined_content.append({"type": part.type, "image_url": part.image_url})
+                processed_messages.append({"role": msg.role, "content": combined_content})
+            else:
+                processed_messages.append({"role": msg.role, "content": msg.content})
+
+        prompt_tokens = count_tokens(processed_messages, chat_request.model)
 
         async_generator = Blackbox.create_async_generator(
             model=chat_request.model,
-            messages=messages,
+            messages=processed_messages,
             image=None, # Adjust if image handling is required
             image_name=None,
             webSearchMode=chat_request.webSearchMode
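
With this change, Message.content may be either a plain string or an OpenAI-style list of content parts, and count_tokens walks both shapes. The snippet below is a minimal, standalone sketch of that branch logic: the encoding selection is simplified to the cl100k_base fallback shown in the diff context, and the sample messages, model name, and image URL are illustrative only, not taken from the repository.

# Standalone sketch of the token-counting branch added in this commit.
# Assumes tiktoken is installed; sample messages below are illustrative.
from typing import Any, Dict, List

import tiktoken


def count_tokens(messages: List[Dict[str, Any]], model: str) -> int:
    # The real function chooses an encoding per model; cl100k_base is its fallback.
    encoding = tiktoken.get_encoding("cl100k_base")
    tokens = 0
    for message in messages:
        if isinstance(message['content'], list):
            # New list-form content: count text parts and image URLs.
            for content_part in message['content']:
                if content_part.get('type') == 'text':
                    tokens += len(encoding.encode(content_part['text']))
                elif content_part.get('type') == 'image_url':
                    tokens += len(encoding.encode(content_part['image_url']['url']))
        else:
            # Plain string content, handled as before.
            tokens += len(encoding.encode(message['content']))
    return tokens


# Both message shapes now accepted by the Message model:
messages = [
    {"role": "user", "content": "Describe this image."},
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this picture?"},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
        ],
    },
]

print(count_tokens(messages, "some-model"))

Requests that still send content as a plain string keep working unchanged, since validate_content passes strings through as-is.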