Niansuh committed
Commit b2e7248 · verified · 1 Parent(s): 09b8364

Update main.py

Files changed (1)
  1. main.py +133 -215
main.py CHANGED
@@ -1,5 +1,3 @@
- # main.py
-
  import os
  import re
  import random
@@ -9,22 +7,18 @@ import json
  import logging
  import asyncio
  import time
+ import base64
+ from io import BytesIO
  from collections import defaultdict
- from typing import List, Dict, Any, Optional, Union, Tuple, AsyncGenerator
+ from typing import List, Dict, Any, Optional, AsyncGenerator, Union

  from datetime import datetime

  from aiohttp import ClientSession, ClientTimeout, ClientError
  from fastapi import FastAPI, HTTPException, Request, Depends, Header
  from fastapi.responses import StreamingResponse, JSONResponse, RedirectResponse
- from pydantic import BaseModel, validator
- from io import BytesIO
- import base64
-
- from dotenv import load_dotenv
-
- # Load environment variables from .env file
- load_dotenv()
+ from pydantic import BaseModel
+ from PIL import Image

  # Configure logging
  logging.basicConfig(
@@ -105,50 +99,37 @@ class ModelNotWorkingException(Exception):
      self.message = f"The model '{model}' is currently not working. Please try another model or wait for it to be fixed."
      super().__init__(self.message)

- # Image Handling Functions
- ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'webp', 'svg'}
-
- def is_allowed_extension(filename: str) -> bool:
-     """
-     Checks if the given filename has an allowed extension.
-     """
-     return '.' in filename and \
-            filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
-
- def is_data_uri_an_image(data_uri: str) -> bool:
-     """
-     Checks if the given data URI represents an image.
-     """
-     match = re.match(r'data:image/(\w+);base64,', data_uri)
-     if not match:
-         raise ValueError("Invalid data URI image.")
-     image_format = match.group(1).lower()
-     if image_format not in ALLOWED_EXTENSIONS and image_format != "svg+xml":
-         raise ValueError("Invalid image format (from MIME type).")
-     return True
-
- def extract_data_uri(data_uri: str) -> bytes:
-     """
-     Extracts the binary data from the given data URI.
-     """
-     return base64.b64decode(data_uri.split(",")[1])
-
- def to_data_uri(image: str) -> str:
-     """
-     Validates and returns the data URI for an image.
-     """
-     is_data_uri_an_image(image)
-     return image
-
- class ImageResponseCustom:
-     def __init__(self, url: str, alt: str):
-         self.url = url
-         self.alt = alt
-
- # Placeholder for Blackbox AI Integration
+ # Mock implementations for ImageResponse and to_data_uri
+ class ImageResponse:
+     def __init__(self, url: str, alt: str):
+         self.url = url
+         self.alt = alt
+
+ def to_data_uri(image: Any) -> str:
+     return "data:image/png;base64,..."  # Replace with actual base64 data
+
+ # Utility functions for image processing
+ def decode_base64_image(base64_str: str) -> Image.Image:
+     try:
+         image_data = base64.b64decode(base64_str)
+         image = Image.open(BytesIO(image_data))
+         return image
+     except Exception as e:
+         logger.error("Failed to decode base64 image.")
+         raise HTTPException(status_code=400, detail="Invalid base64 image data.") from e
+
+ def analyze_image(image: Image.Image) -> str:
+     """
+     Placeholder for image analysis.
+     Replace this with actual image analysis logic.
+     """
+     # Example: Return image size as analysis
+     width, height = image.size
+     return f"Image analyzed successfully. Width: {width}px, Height: {height}px."

  class Blackbox:
      url = "https://www.blackbox.ai"
-     api_endpoint = "https://www.blackbox.ai/api/chat"  # Placeholder endpoint
+     api_endpoint = "https://www.blackbox.ai/api/chat"
      working = True
      supports_stream = True
      supports_system_message = True
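Note: the new `to_data_uri` above is a stub that returns a fixed placeholder string, so the `imageBase64` field sent downstream will not contain real image data. A minimal sketch of an actual conversion, assuming the caller passes a `PIL.Image` and PNG output is acceptable (illustrative only, not part of this commit):

    import base64
    from io import BytesIO
    from PIL import Image

    def to_data_uri(image: Image.Image) -> str:
        buffer = BytesIO()
        image.save(buffer, format="PNG")  # serialize the PIL image to PNG bytes
        encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
        return f"data:image/png;base64,{encoded}"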
@@ -159,7 +140,6 @@ class Blackbox:
      models = [
          default_model,
          'blackboxai-pro',
-         *image_models,
          "llama-3.1-8b",
          'llama-3.1-70b',
          'llama-3.1-405b',
@@ -180,13 +160,18 @@ class Blackbox:
          'ReactAgent',
          'XcodeAgent',
          'AngularJSAgent',
+         *image_models,
+         'Niansuh',
      ]

+     # Filter models based on AVAILABLE_MODELS
+     if AVAILABLE_MODELS:
+         models = [model for model in models if model in AVAILABLE_MODELS]
+
      agentMode = {
          'ImageGeneration': {'mode': True, 'id': "ImageGenerationLV45LJp", 'name': "Image Generation"},
          'Niansuh': {'mode': True, 'id': "NiansuhAIk1HgESy", 'name': "Niansuh"},
      }
-
      trendingAgentMode = {
          "blackboxai": {},
          "gemini-1.5-flash": {'mode': True, 'id': 'Gemini'},
@@ -266,13 +251,13 @@ class Blackbox:
  async def create_async_generator(
      cls,
      model: str,
-     messages: List[Dict[str, Any]],
+     messages: List[Dict[str, str]],
      proxy: Optional[str] = None,
-     image: Optional[str] = None,
+     image: Any = None,
      image_name: Optional[str] = None,
      webSearchMode: bool = False,
      **kwargs
- ) -> AsyncGenerator[Union[str, ImageResponseCustom], None]:
+ ) -> AsyncGenerator[Any, None]:
      model = cls.get_model(model)
      if model is None:
          logger.error(f"Model {model} is not available.")
@@ -283,7 +268,7 @@ class Blackbox:
  if not cls.working or model not in cls.models:
      logger.error(f"Model {model} is not working or not supported.")
      raise ModelNotWorkingException(model)
-
+
  headers = {
      "accept": "*/*",
      "accept-language": "en-US,en;q=0.9",
@@ -307,7 +292,7 @@ class Blackbox:
  if not messages[0]['content'].startswith(prefix):
      logger.debug(f"Adding prefix '{prefix}' to the first message.")
      messages[0]['content'] = f"{prefix} {messages[0]['content']}"
-
+
  random_id = ''.join(random.choices(string.ascii_letters + string.digits, k=7))
  messages[-1]['id'] = random_id
  messages[-1]['role'] = 'user'
@@ -318,12 +303,12 @@ class Blackbox:
  if image is not None:
      messages[-1]['data'] = {
          'fileText': '',
-         'imageBase64': image,
+         'imageBase64': to_data_uri(image),
          'title': image_name
      }
      messages[-1]['content'] = 'FILE:BB\n$#$\n\n$#$\n' + messages[-1]['content']
      logger.debug("Image data added to the message.")
-
+
  data = {
      "messages": messages,
      "id": random_id,
@@ -365,14 +350,13 @@ class Blackbox:
  async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response:
      response.raise_for_status()
      logger.info(f"Received response with status {response.status}")
-     if model in cls.image_models:
+     if model == 'ImageGeneration':
          response_text = await response.text()
-         # Extract image URL from the response
          url_match = re.search(r'https://storage\.googleapis\.com/[^\s\)]+', response_text)
          if url_match:
              image_url = url_match.group(0)
-             logger.info(f"Image URL found: {image_url}")
-             yield ImageResponseCustom(url=image_url, alt=messages[-1]['content'])
+             logger.info(f"Image URL found.")
+             yield ImageResponse(image_url, alt=messages[-1]['content'])
          else:
              logger.error("Image URL not found in the response.")
              raise Exception("Image URL not found in the response")
@@ -421,7 +405,7 @@ class Blackbox:
  if attempt == retry_attempts - 1:
      raise HTTPException(status_code=500, detail=str(e))

- # Initialize FastAPI app
+ # FastAPI app setup
  app = FastAPI()

  # Add the cleanup task when the app starts
@@ -453,39 +437,10 @@ async def security_middleware(request: Request, call_next):
  response = await call_next(request)
  return response

- # Pydantic Models
-
- class TextContent(BaseModel):
-     type: str = "text"
-     text: str
-
-     @validator('type')
-     def type_must_be_text(cls, v):
-         if v != "text":
-             raise ValueError("Type must be 'text'")
-         return v
-
- class ImageContent(BaseModel):
-     type: str = "image_url"
-     image_url: Dict[str, str]
-
-     @validator('type')
-     def type_must_be_image_url(cls, v):
-         if v != "image_url":
-             raise ValueError("Type must be 'image_url'")
-         return v
-
- ContentItem = Union[TextContent, ImageContent]
-
+ # Request Models
  class Message(BaseModel):
      role: str
-     content: Union[str, List[ContentItem]]
-
-     @validator('role')
-     def role_must_be_valid(cls, v):
-         if v not in {"system", "user", "assistant"}:
-             raise ValueError("Role must be 'system', 'user', or 'assistant'")
-         return v
+     content: Union[str, List[Any]]  # Adjusted to accept list if needed

  class ChatRequest(BaseModel):
      model: str
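This hunk removes the role and content-type validators, so `Message` now accepts any `role` string and any list payload in `content`. If role validation is still wanted alongside the looser `content` type, the dropped Pydantic v1 validator can be restored:

    from pydantic import BaseModel, validator

    class Message(BaseModel):
        role: str
        content: Union[str, List[Any]]

        @validator('role')
        def role_must_be_valid(cls, v):
            if v not in {"system", "user", "assistant"}:
                raise ValueError("Role must be 'system', 'user', or 'assistant'")
            return v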
@@ -501,12 +456,11 @@ class ChatRequest(BaseModel):
  logit_bias: Optional[Dict[str, float]] = None
  user: Optional[str] = None
  webSearchMode: Optional[bool] = False  # Custom parameter
+ image: Optional[str] = None  # Base64-encoded image

  class TokenizerRequest(BaseModel):
      text: str

- # Utility Functions
-
  def calculate_estimated_cost(prompt_tokens: int, completion_tokens: int) -> float:
      """
      Calculate the estimated cost based on the number of tokens.
@@ -516,18 +470,6 @@ def calculate_estimated_cost(prompt_tokens: int, completion_tokens: int) -> floa
      cost_per_token = 0.00000268
      return round((prompt_tokens + completion_tokens) * cost_per_token, 8)

- def count_tokens(text: str) -> int:
-     """
-     Counts the number of tokens in a given text using tiktoken.
-     """
-     try:
-         import tiktoken
-         encoding = tiktoken.get_encoding("cl100k_base")
-         return len(encoding.encode(text))
-     except ImportError:
-         # Fallback if tiktoken is not installed
-         return len(text.split())
-
  def create_response(content: str, model: str, finish_reason: Optional[str] = None) -> Dict[str, Any]:
      return {
          "id": f"chatcmpl-{uuid.uuid4()}",
@@ -547,60 +489,6 @@ def create_response(content: str, model: str, finish_reason: Optional[str] = Non
      "usage": None,  # To be filled in non-streaming responses
  }

- def extract_all_images_from_content(content: Union[str, List[ContentItem]]) -> List[Tuple[str, str]]:
-     """
-     Extracts all images from the content.
-     Returns a list of tuples containing (alt_text, image_data_uri).
-     """
-     images = []
-     if isinstance(content, list):
-         for item in content:
-             if isinstance(item, ImageContent):
-                 alt_text = item.image_url.get('alt', '')  # Optional alt text
-                 image_data_uri = item.image_url.get('url', '')
-                 if image_data_uri:
-                     images.append((alt_text, image_data_uri))
-     return images
-
- # Image Analysis Function (Placeholder)
- async def analyze_image(image_data_uri: str) -> str:
-     """
-     Placeholder function to analyze the image.
-     Replace this with actual image analysis logic or API calls.
-     """
-     try:
-         # Extract base64 data
-         image_data = image_data_uri.split(",")[1]
-         # Decode the image
-         image_bytes = base64.b64decode(image_data)
-
-         # Here, integrate with an image analysis API or implement your own logic
-         # For demonstration, we'll simulate analysis with a dummy response.
-         await asyncio.sleep(1)  # Simulate processing delay
-         return "Image analysis result: The image depicts a beautiful sunset over the mountains."
-     except Exception as e:
-         logger.error(f"Failed to analyze image: {e}")
-         raise HTTPException(status_code=400, detail="Failed to process the provided image.")
-
- # Helper Function for Token Counting
- def count_prompt_tokens(request: ChatRequest) -> int:
-     """
-     Counts the number of tokens in the prompt (input messages).
-     Handles both string and list types for the 'content' field.
-     """
-     total = 0
-     for msg in request.messages:
-         if isinstance(msg.content, str):
-             total += count_tokens(msg.content)
-         elif isinstance(msg.content, list):
-             for item in msg.content:
-                 if isinstance(item, TextContent):
-                     total += count_tokens(item.text)
-                 elif isinstance(item, ImageContent):
-                     total += count_tokens(item.image_url['url'])
-     return total
-
- # Endpoint: POST /v1/chat/completions
  @app.post("/v1/chat/completions", dependencies=[Depends(rate_limiter_per_ip)])
  async def chat_completions(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
      client_ip = req.client.host
@@ -609,61 +497,69 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De

  logger.info(f"Received chat completions request from API key: {api_key} | IP: {client_ip} | Model: {request.model} | Messages: {redacted_messages}")

+ analysis_result = None
+ if request.image:
+     try:
+         image = decode_base64_image(request.image)
+         analysis_result = analyze_image(image)
+         logger.info("Image analysis completed successfully.")
+     except HTTPException as he:
+         logger.error(f"Image analysis failed: {he.detail}")
+         raise he
+     except Exception as e:
+         logger.exception("Unexpected error during image analysis.")
+         raise HTTPException(status_code=500, detail="Image analysis failed.") from e
+
  try:
      # Validate that the requested model is available
      if request.model not in Blackbox.models and request.model not in Blackbox.model_aliases:
          logger.warning(f"Attempt to use unavailable model: {request.model} from IP: {client_ip}")
          raise HTTPException(status_code=400, detail="Requested model is not available.")

-     # Initialize response content
-     assistant_content = ""
-
-     # Iterate through messages to find and process images
-     for msg in request.messages:
-         if msg.role == "user":
-             # Extract all images from the message content
-             images = extract_all_images_from_content(msg.content)
-             for alt_text, image_data_uri in images:
-                 # Analyze the image
-                 analysis_result = await analyze_image(image_data_uri)
-                 assistant_content += analysis_result + "\n"
-
-     # Example response content
-     assistant_content += "Based on the image you provided, here are the insights..."
-
-     # Calculate token usage using the helper function
-     prompt_tokens = count_prompt_tokens(request)
-     completion_tokens = count_tokens(assistant_content)
-     total_tokens = prompt_tokens + completion_tokens
-     estimated_cost = calculate_estimated_cost(prompt_tokens, completion_tokens)
-
-     logger.info(f"Completed response generation for API key: {api_key} | IP: {client_ip}")
+     # Process the request with actual message content and image data
+     async_generator = Blackbox.create_async_generator(
+         model=request.model,
+         messages=[{"role": msg.role, "content": msg.content} for msg in request.messages],
+         image=request.image,
+         image_name="uploaded_image",  # You can modify this as needed
+         webSearchMode=request.webSearchMode
+     )

      if request.stream:
          async def generate():
              try:
-                 for msg in request.messages:
-                     if msg.role == "user":
-                         images = extract_all_images_from_content(msg.content)
-                         for alt_text, image_data_uri in images:
-                             analysis_result = await analyze_image(image_data_uri)
-                             response_chunk = {
-                                 "id": f"chatcmpl-{uuid.uuid4()}",
-                                 "object": "chat.completion.chunk",
-                                 "created": int(datetime.now().timestamp()),
-                                 "model": request.model,
-                                 "choices": [
-                                     {
-                                         "index": 0,
-                                         "delta": {"content": analysis_result + "\n", "role": "assistant"},
-                                         "finish_reason": None,
-                                     }
-                                 ],
-                                 "usage": None,
-                             }
-                             yield f"data: {json.dumps(response_chunk)}\n\n"
-
-                 # Final message
+                 assistant_content = ""
+                 async for chunk in async_generator:
+                     if isinstance(chunk, ImageResponse):
+                         # Handle image responses if necessary
+                         image_markdown = f"![image]({chunk.url})\n"
+                         assistant_content += image_markdown
+                         response_chunk = create_response(image_markdown, request.model, finish_reason=None)
+                     else:
+                         assistant_content += chunk
+                         # Yield the chunk as a partial choice
+                         response_chunk = {
+                             "id": f"chatcmpl-{uuid.uuid4()}",
+                             "object": "chat.completion.chunk",
+                             "created": int(datetime.now().timestamp()),
+                             "model": request.model,
+                             "choices": [
+                                 {
+                                     "index": 0,
+                                     "delta": {"content": chunk, "role": "assistant"},
+                                     "finish_reason": None,
+                                 }
+                             ],
+                             "usage": None,  # Usage can be updated if you track tokens in real-time
+                         }
+                     yield f"data: {json.dumps(response_chunk)}\n\n"
+
+                 # After all chunks are sent, send the final message with finish_reason
+                 prompt_tokens = sum(len(msg.content.split()) for msg in request.messages)
+                 completion_tokens = len(assistant_content.split())
+                 total_tokens = prompt_tokens + completion_tokens
+                 estimated_cost = calculate_estimated_cost(prompt_tokens, completion_tokens)
+
                  final_response = {
                      "id": f"chatcmpl-{uuid.uuid4()}",
                      "object": "chat.completion",
@@ -673,7 +569,7 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
      {
          "message": {
              "role": "assistant",
-             "content": assistant_content.strip()
+             "content": assistant_content
          },
          "finish_reason": "stop",
          "index": 0
@@ -686,6 +582,9 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
              "estimated_cost": estimated_cost
          },
      }
+     if analysis_result:
+         final_response["choices"][0]["message"]["content"] += f"\n\n**Image Analysis:** {analysis_result}"
+
      yield f"data: {json.dumps(final_response)}\n\n"
      yield "data: [DONE]\n\n"
  except HTTPException as he:
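The streaming branch emits OpenAI-style server-sent events: each chunk is a `data: {json}` line followed by a blank line, then a final completion object and the `data: [DONE]` sentinel. A minimal consumer sketch, assuming a local deployment and that `get_api_key` accepts a Bearer token (`localhost:8000` and `sk-test` are placeholders):

    import asyncio
    import json
    from aiohttp import ClientSession

    async def consume() -> None:
        payload = {
            "model": "blackboxai",
            "stream": True,
            "messages": [{"role": "user", "content": "Hello"}],
        }
        headers = {"Authorization": "Bearer sk-test"}  # placeholder credential
        async with ClientSession() as session:
            async with session.post(
                "http://localhost:8000/v1/chat/completions",
                json=payload, headers=headers,
            ) as resp:
                async for raw_line in resp.content:  # one "data: ..." line per event
                    line = raw_line.decode("utf-8").strip()
                    if not line.startswith("data: "):
                        continue  # skip blank separator lines
                    data = line[len("data: "):]
                    if data == "[DONE]":
                        break
                    choice = json.loads(data)["choices"][0]
                    # Chunks carry "delta"; the final object carries "message".
                    print(choice.get("delta", choice.get("message", {})).get("content", ""), end="")

    asyncio.run(consume())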
@@ -698,7 +597,21 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De

      return StreamingResponse(generate(), media_type="text/event-stream")
  else:
-     return {
+     response_content = ""
+     async for chunk in async_generator:
+         if isinstance(chunk, ImageResponse):
+             response_content += f"![image]({chunk.url})\n"
+         else:
+             response_content += chunk
+
+     prompt_tokens = sum(len(msg.content.split()) for msg in request.messages)
+     completion_tokens = len(response_content.split())
+     total_tokens = prompt_tokens + completion_tokens
+     estimated_cost = calculate_estimated_cost(prompt_tokens, completion_tokens)
+
+     logger.info(f"Completed non-streaming response generation for API key: {api_key} | IP: {client_ip}")
+
+     response = {
          "id": f"chatcmpl-{uuid.uuid4()}",
          "object": "chat.completion",
          "created": int(datetime.now().timestamp()),
@@ -707,7 +620,7 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
      {
          "message": {
              "role": "assistant",
-             "content": assistant_content.strip()
+             "content": response_content
          },
          "finish_reason": "stop",
          "index": 0
@@ -720,6 +633,11 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
              "estimated_cost": estimated_cost
          },
      }
+
+     if analysis_result:
+         response["choices"][0]["message"]["content"] += f"\n\n**Image Analysis:** {analysis_result}"
+
+     return response
  except ModelNotWorkingException as e:
      logger.warning(f"Model not working: {e} | IP: {client_ip}")
      raise HTTPException(status_code=503, detail=str(e))
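With the new `image` field on ChatRequest, clients send raw base64 without a data-URI prefix (since `decode_base64_image` passes the string straight to `b64decode`), and the analysis text is appended to the assistant reply. A client sketch using `requests`, with placeholder URL and key:

    import base64
    import requests

    with open("photo.png", "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode("ascii")

    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",  # placeholder deployment URL
        headers={"Authorization": "Bearer sk-test"},  # placeholder credential
        json={
            "model": "blackboxai",
            "messages": [{"role": "user", "content": "Describe this image."}],
            "image": image_b64,  # raw base64, as decode_base64_image() expects
        },
    )
    print(resp.json()["choices"][0]["message"]["content"])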
@@ -732,23 +650,23 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De

  # Endpoint: POST /v1/tokenizer
  @app.post("/v1/tokenizer", dependencies=[Depends(rate_limiter_per_ip)])
- async def tokenizer(request: TokenizerRequest, req: Request, api_key: str = Depends(get_api_key)):
+ async def tokenizer(request: TokenizerRequest, req: Request):
      client_ip = req.client.host
      text = request.text
-     token_count = count_tokens(text)
+     token_count = len(text.split())
      logger.info(f"Tokenizer requested from IP: {client_ip} | Text length: {len(text)}")
      return {"text": text, "tokens": token_count}

  # Endpoint: GET /v1/models
  @app.get("/v1/models", dependencies=[Depends(rate_limiter_per_ip)])
- async def get_models(req: Request, api_key: str = Depends(get_api_key)):
+ async def get_models(req: Request):
      client_ip = req.client.host
      logger.info(f"Fetching available models from IP: {client_ip}")
      return {"data": [{"id": model, "object": "model"} for model in Blackbox.models]}

  # Endpoint: GET /v1/models/{model}/status
  @app.get("/v1/models/{model}/status", dependencies=[Depends(rate_limiter_per_ip)])
- async def model_status(model: str, req: Request, api_key: str = Depends(get_api_key)):
+ async def model_status(model: str, req: Request):
      client_ip = req.client.host
      logger.info(f"Model status requested for '{model}' from IP: {client_ip}")
      if model in Blackbox.models:
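Note that `len(text.split())` counts whitespace-separated words rather than model tokens, so /v1/tokenizer and the usage figures above are now rough estimates. The tiktoken-based helper this commit removed can be restored where accurate counts matter:

    import tiktoken

    def count_tokens(text: str) -> int:
        encoding = tiktoken.get_encoding("cl100k_base")
        return len(encoding.encode(text))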
@@ -762,14 +680,14 @@ async def model_status(model: str, req: Request, api_key: str = Depends(get_api_

  # Endpoint: GET /v1/health
  @app.get("/v1/health", dependencies=[Depends(rate_limiter_per_ip)])
- async def health_check(req: Request, api_key: str = Depends(get_api_key)):
+ async def health_check(req: Request):
      client_ip = req.client.host
      logger.info(f"Health check requested from IP: {client_ip}")
      return {"status": "ok"}

  # Endpoint: GET /v1/chat/completions (GET method)
  @app.get("/v1/chat/completions")
- async def chat_completions_get(req: Request, api_key: str = Depends(get_api_key)):
+ async def chat_completions_get(req: Request):
      client_ip = req.client.host
      logger.info(f"GET request made to /v1/chat/completions from IP: {client_ip}, redirecting to 'about:blank'")
      return RedirectResponse(url='about:blank')
@@ -794,4 +712,4 @@ async def http_exception_handler(request: Request, exc: HTTPException):
  # Run the application
  if __name__ == "__main__":
      import uvicorn
-     uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
+     uvicorn.run(app, host="0.0.0.0", port=8000)
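Passing the app object to `uvicorn.run` drops auto-reload along with the removed `reload=True`: uvicorn supports reload (and multiple workers) only when given an import string, so re-enabling it would mean reverting to the previous form:

    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)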