Niansuh committed
Commit 4efca8f · verified · 1 Parent(s): 36aebd6

Update main.py

Files changed (1):
  1. main.py +325 -199
main.py CHANGED
@@ -1,7 +1,5 @@
import os
import re
- import random
- import string
import uuid
import json
import logging
@@ -9,10 +7,11 @@ import asyncio
import time
from collections import defaultdict
from typing import List, Dict, Any, Optional, AsyncGenerator
from datetime import datetime

from aiohttp import ClientSession, ClientTimeout, ClientError
- from fastapi import FastAPI, HTTPException, Request, Depends, Header, UploadFile, File
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
 
@@ -26,35 +25,26 @@ logger = logging.getLogger(__name__)

# Load environment variables
API_KEYS = os.getenv('API_KEYS', '').split(',')  # Comma-separated API keys
- RATE_LIMIT = int(os.getenv('RATE_LIMIT', '60'))  # Requests per minute

if not API_KEYS or API_KEYS == ['']:
-     logger.error("No API keys found. Please set the API_KEYS environment variable. | NiansuhAI")
-     raise Exception("API_KEYS environment variable not set. | NiansuhAI")
-
- # Simple in-memory rate limiter
- rate_limit_store = defaultdict(lambda: {"count": 0, "timestamp": time.time()})
-
- async def get_api_key(authorization: str = Header(...)) -> str:
-     if not authorization.startswith('Bearer '):
-         logger.warning("Invalid authorization header format.")
-         raise HTTPException(status_code=401, detail='Invalid authorization header format | NiansuhAI')
-     api_key = authorization[7:]
-     if api_key not in API_KEYS:
-         logger.warning(f"Invalid API key attempted: {api_key}")
-         raise HTTPException(status_code=401, detail='Invalid API key | NiansuhAI')
-     return api_key
-
- async def rate_limiter(api_key: str = Depends(get_api_key)):
    current_time = time.time()
-     window_start = rate_limit_store[api_key]["timestamp"]
    if current_time - window_start > 60:
-         rate_limit_store[api_key] = {"count": 1, "timestamp": current_time}
    else:
-         if rate_limit_store[api_key]["count"] >= RATE_LIMIT:
-             logger.warning(f"Rate limit exceeded for API key: {api_key}")
-             raise HTTPException(status_code=429, detail='Rate limit exceeded | NiansuhAI')
-         rate_limit_store[api_key]["count"] += 1

# Custom exception for model not working
class ModelNotWorkingException(Exception):
@@ -132,7 +122,7 @@ class Blackbox:
        'PyTorchAgent': {'mode': True, 'id': "PyTorch Agent"},
        'ReactAgent': {'mode': True, 'id': "React Agent"},
        'XcodeAgent': {'mode': True, 'id': "Xcode Agent"},
-         'AngularJSAgent': {'mode': True, 'id': "AngularJS Agent"},
    }

    userSelectedModel = {
@@ -188,225 +178,361 @@ class Blackbox:
        else:
            return cls.default_model

-     # (Rest of the Blackbox class remains unchanged)

# FastAPI app setup
app = FastAPI()

class Message(BaseModel):
    role: str
    content: str

- class ChatRequest(BaseModel):
    model: str
    messages: List[Message]
    stream: Optional[bool] = False
-     webSearchMode: Optional[bool] = False
-
- def create_response(content: str, model: str, finish_reason: Optional[str] = None) -> Dict[str, Any]:
    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
-         "object": "chat.completion.chunk",
        "created": int(datetime.now().timestamp()),
        "model": model,
        "choices": [
            {
                "index": 0,
-                 "delta": {"content": content, "role": "assistant"},
-                 "finish_reason": finish_reason,
            }
        ],
-         "usage": None,
    }

- @app.post("/niansuhai/v1/chat/completions", dependencies=[Depends(rate_limiter)])
- async def chat_completions(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
-     # Redact user messages only for logging purposes
-     redacted_messages = [{"role": msg.role, "content": "[redacted]"} for msg in request.messages]

-     logger.info(f"Received chat completions request from API key: {api_key} | Model: {request.model} | Messages: {redacted_messages}")

    try:
-         # Validate that the requested model is available
-         if request.model not in Blackbox.models and request.model not in Blackbox.model_aliases:
-             logger.warning(f"Attempt to use unavailable model: {request.model}")
-             raise HTTPException(status_code=400, detail="Requested model is not available. | NiansuhAI")
-
-         # Process the request with actual message content, but don't log it
-         async_generator = Blackbox.create_async_generator(
-             model=request.model,
-             messages=[{"role": msg.role, "content": msg.content} for msg in request.messages],  # Actual message content used here
-             image=None,
-             image_name=None,
-             webSearchMode=request.webSearchMode
        )

        if request.stream:
            async def generate():
                try:
                    async for chunk in async_generator:
-                         if isinstance(chunk, ImageResponse):
-                             image_markdown = f"![image]({chunk.url})"
-                             response_chunk = create_response(image_markdown, request.model)
                        else:
-                             response_chunk = create_response(chunk, request.model)
-
-                         yield f"data: {json.dumps(response_chunk)}\n\n"
-
                    yield "data: [DONE]\n\n"
-                 except HTTPException as he:
-                     error_response = {"error": he.detail}
-                     yield f"data: {json.dumps(error_response)}\n\n"
                except Exception as e:
-                     logger.exception("Error during streaming response generation. | NiansuhAI")
-                     error_response = {"error": str(e)}
-                     yield f"data: {json.dumps(error_response)}\n\n"
-
            return StreamingResponse(generate(), media_type="text/event-stream")
        else:
            response_content = ""
            async for chunk in async_generator:
-                 if isinstance(chunk, ImageResponse):
-                     response_content += f"![image]({chunk.url})\n"
-                 else:
                    response_content += chunk
-
-             logger.info(f"Completed non-streaming response generation for API key: {api_key}")
-             return {
-                 "id": f"chatcmpl-{uuid.uuid4()}",
-                 "object": "chat.completion",
-                 "created": int(datetime.now().timestamp()),
-                 "model": request.model,
-                 "choices": [
-                     {
-                         "message": {
-                             "role": "assistant",
-                             "content": response_content
-                         },
-                         "finish_reason": "stop",
-                         "index": 0
-                     }
-                 ],
-                 "usage": {
-                     "prompt_tokens": sum(len(msg.content.split()) for msg in request.messages),
-                     "completion_tokens": len(response_content.split()),
-                     "total_tokens": sum(len(msg.content.split()) for msg in request.messages) + len(response_content.split())
-                 },
            }
    except ModelNotWorkingException as e:
        logger.warning(f"Model not working: {e}")
        raise HTTPException(status_code=503, detail=str(e))
-     except HTTPException as he:
-         logger.warning(f"HTTPException: {he.detail}")
-         raise he
    except Exception as e:
-         logger.exception("An unexpected error occurred while processing the chat completions request. | NiansuhAI")
        raise HTTPException(status_code=500, detail=str(e))

- @app.get("/niansuhai/v1/models", dependencies=[Depends(rate_limiter)])
- async def get_models(api_key: str = Depends(get_api_key)):
-     logger.info(f"Fetching available models for API key: {api_key}")
-     return {"data": [{"id": model} for model in Blackbox.models]}
-
- # Additional endpoints for better functionality
-
- @app.get("/niansuhai/v1/health", dependencies=[Depends(rate_limiter)])
- async def health_check(api_key: str = Depends(get_api_key)):
-     logger.info(f"Health check requested by API key: {api_key}")
-     return {"status": "ok"}
-
- @app.get("/niansuhai/v1/models/{model}/status", dependencies=[Depends(rate_limiter)])
- async def model_status(model: str, api_key: str = Depends(get_api_key)):
-     logger.info(f"Model status requested for '{model}' by API key: {api_key}")
-     if model in Blackbox.models:
-         return {"model": model, "status": "available | NiansuhAI"}
-     elif model in Blackbox.model_aliases:
-         actual_model = Blackbox.model_aliases[model]
-         return {"model": actual_model, "status": "available via alias | NiansuhAI"}
-     else:
-         logger.warning(f"Model not found: {model}")
-         raise HTTPException(status_code=404, detail="Model not found | NiansuhAI")
-
- # New endpoint to get model details
- @app.get("/niansuhai/v1/models/{model}/details", dependencies=[Depends(rate_limiter)])
- async def get_model_details(model: str, api_key: str = Depends(get_api_key)):
-     logger.info(f"Model details requested for '{model}' by API key: {api_key}")
-     actual_model = Blackbox.get_model(model)
-     if actual_model not in Blackbox.models:
-         logger.warning(f"Model not found: {model}")
-         raise HTTPException(status_code=404, detail="Model not found | NiansuhAI")
-     # For demonstration, we'll return mock details
-     model_details = {
-         "id": actual_model,
-         "description": f"Details about model {actual_model}",
-         "capabilities": ["chat", "completion", "image generation"] if actual_model in Blackbox.image_models else ["chat", "completion"],
-         "status": "available",
-     }
-     return {"data": model_details}
-
- # Session history endpoints
- session_histories = defaultdict(list)  # In-memory storage for session histories
-
- @app.post("/niansuhai/v1/sessions/{session_id}/messages", dependencies=[Depends(rate_limiter)])
- async def add_message_to_session(session_id: str, message: Message, api_key: str = Depends(get_api_key)):
-     logger.info(f"Adding message to session '{session_id}' by API key: {api_key}")
-     session_histories[session_id].append({"role": message.role, "content": message.content})
-     return {"status": "message added"}
-
- @app.get("/niansuhai/v1/sessions/{session_id}/messages", dependencies=[Depends(rate_limiter)])
- async def get_session_messages(session_id: str, api_key: str = Depends(get_api_key)):
-     logger.info(f"Fetching messages for session '{session_id}' by API key: {api_key}")
-     messages = session_histories.get(session_id)
-     if messages is None:
-         raise HTTPException(status_code=404, detail="Session not found | NiansuhAI")
-     return {"data": messages}
-
- # User preferences endpoints
- user_preferences = defaultdict(dict)  # In-memory storage for user preferences
-
- class UserPreferences(BaseModel):
-     theme: Optional[str] = "light"
-     notifications_enabled: Optional[bool] = True
-
- @app.post("/niansuhai/v1/users/{user_id}/preferences", dependencies=[Depends(rate_limiter)])
- async def update_user_preferences(user_id: str, preferences: UserPreferences, api_key: str = Depends(get_api_key)):
-     logger.info(f"Updating preferences for user '{user_id}' by API key: {api_key}")
-     user_preferences[user_id] = preferences.dict()
-     return {"status": "preferences updated"}
-
- @app.get("/niansuhai/v1/users/{user_id}/preferences", dependencies=[Depends(rate_limiter)])
- async def get_user_preferences(user_id: str, api_key: str = Depends(get_api_key)):
-     logger.info(f"Fetching preferences for user '{user_id}' by API key: {api_key}")
-     preferences = user_preferences.get(user_id)
-     if preferences is None:
-         raise HTTPException(status_code=404, detail="User not found | NiansuhAI")
-     return {"data": preferences}
-
- # Image upload endpoint
- @app.post("/niansuhai/v1/images/upload", dependencies=[Depends(rate_limiter)])
- async def upload_image(image: UploadFile = File(...), api_key: str = Depends(get_api_key)):
-     logger.info(f"Image upload requested by API key: {api_key}")
-     if not image.content_type.startswith('image/'):
-         logger.warning("Uploaded file is not an image.")
-         raise HTTPException(status_code=400, detail="Uploaded file is not an image | NiansuhAI")
-     # For demonstration, we'll just return the filename
-     return {"filename": image.filename, "status": "image uploaded"}
-
- # Component health check endpoint
- @app.get("/niansuhai/v1/health/{component}", dependencies=[Depends(rate_limiter)])
- async def component_health_check(component: str, api_key: str = Depends(get_api_key)):
-     logger.info(f"Health check for component '{component}' requested by API key: {api_key}")
-     # Mock health status for components
-     components_status = {
-         "database": "healthy",
-         "message_queue": "healthy",
-         "cache": "healthy",
-     }
-     status = components_status.get(component)
-     if status is None:
-         logger.warning(f"Component not found: {component}")
-         raise HTTPException(status_code=404, detail="Component not found | NiansuhAI")
-     return {"component": component, "status": status}

if __name__ == "__main__":
    import uvicorn

main.py (after)

import os
import re
+ import random  # re-added: Blackbox.create_async_generator below uses random.choices
+ import string  # re-added: used with random to build the 7-character message IDs
import uuid
import json
import logging
import asyncio
import time
from collections import defaultdict
from typing import List, Dict, Any, Optional, AsyncGenerator
+
from datetime import datetime

from aiohttp import ClientSession, ClientTimeout, ClientError
+ from fastapi import FastAPI, HTTPException, Request, Depends, Header
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
 
# Load environment variables
API_KEYS = os.getenv('API_KEYS', '').split(',')  # Comma-separated API keys
+ RATE_LIMIT_PER_MINUTE = int(os.getenv('RATE_LIMIT_PER_MINUTE', '60'))  # Requests per minute per IP

if not API_KEYS or API_KEYS == ['']:
+     logger.error("No API keys found. Please set the API_KEYS environment variable.")
+     raise Exception("API_KEYS environment variable not set.")

+ # Simple in-memory rate limiter per IP
+ rate_limit_store_ip = defaultdict(lambda: {"count": 0, "timestamp": time.time()})

+ async def rate_limiter(request: Request):
+     client_host = request.client.host
    current_time = time.time()
+     window_start = rate_limit_store_ip[client_host]["timestamp"]
    if current_time - window_start > 60:
+         rate_limit_store_ip[client_host] = {"count": 1, "timestamp": current_time}
    else:
+         if rate_limit_store_ip[client_host]["count"] >= RATE_LIMIT_PER_MINUTE:
+             logger.warning(f"Rate limit exceeded for IP: {client_host}")
+             raise HTTPException(status_code=429, detail='Rate limit exceeded.')
+         rate_limit_store_ip[client_host]["count"] += 1
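
# Editor's sketch (not part of this commit): a rough client-side probe of the
# fixed-window limiter above. It assumes the app is served on localhost:8000
# and "sk-test" stands in for a real entry in API_KEYS. Each client IP gets
# RATE_LIMIT_PER_MINUTE requests per 60-second window, counted across all
# routes by the middleware registered further down.
def _probe_rate_limit(base_url: str = "http://localhost:8000") -> int:
    import httpx  # client-side dependency, assumed available for this sketch
    sent = 0
    with httpx.Client() as client:
        while True:
            r = client.get(f"{base_url}/v1/models",
                           headers={"Authorization": "Bearer sk-test"})
            if r.status_code == 429:  # the middleware rejected: window exhausted
                return sent
            sent += 1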
 
# Custom exception for model not working
  class ModelNotWorkingException(Exception):
 
        'PyTorchAgent': {'mode': True, 'id': "PyTorch Agent"},
        'ReactAgent': {'mode': True, 'id': "React Agent"},
        'XcodeAgent': {'mode': True, 'id': "Xcode Agent"},
+         'AngularJSAgent': {'mode': True, 'id': "AngularJS Agent"},
    }

    userSelectedModel = {
 
        else:
            return cls.default_model

+     @classmethod
+     async def create_async_generator(
+         cls,
+         model: str,
+         messages: List[Dict[str, str]],
+         proxy: Optional[str] = None,
+         image: Any = None,
+         image_name: Optional[str] = None,
+         webSearchMode: bool = False,
+         **kwargs
+     ) -> AsyncGenerator[Any, None]:
+         model = cls.get_model(model)
+         logger.info(f"Selected model: {model}")
+
+         if not cls.working or model not in cls.models:
+             logger.error(f"Model {model} is not working or not supported.")
+             raise ModelNotWorkingException(model)
+
+         headers = {
+             "accept": "*/*",
+             "accept-language": "en-US,en;q=0.9",
+             "cache-control": "no-cache",
+             "content-type": "application/json",
+             "origin": cls.url,
+             "pragma": "no-cache",
+             "priority": "u=1, i",
+             "referer": cls.model_referers.get(model, cls.url),
+             "sec-ch-ua": '"Chromium";v="129", "Not=A?Brand";v="8"',
+             "sec-ch-ua-mobile": "?0",
+             "sec-ch-ua-platform": '"Linux"',
+             "sec-fetch-dest": "empty",
+             "sec-fetch-mode": "cors",
+             "sec-fetch-site": "same-origin",
+             "user-agent": "Mozilla/5.0 (X11; Linux x86_64)",
+         }
+
+         if model in cls.model_prefixes:
+             prefix = cls.model_prefixes[model]
+             if not messages[0]['content'].startswith(prefix):
+                 logger.debug(f"Adding prefix '{prefix}' to the first message.")
+                 messages[0]['content'] = f"{prefix} {messages[0]['content']}"
+
+         random_id = ''.join(random.choices(string.ascii_letters + string.digits, k=7))
+         messages[-1]['id'] = random_id
+         messages[-1]['role'] = 'user'
+
+         # Don't log the full message content for privacy
+         logger.debug(f"Generated message ID: {random_id} for model: {model}")
+
+         if image is not None:
+             messages[-1]['data'] = {
+                 'fileText': '',
+                 'imageBase64': to_data_uri(image),
+                 'title': image_name
+             }
+             messages[-1]['content'] = 'FILE:BB\n$#$\n\n$#$\n' + messages[-1]['content']
+             logger.debug("Image data added to the message.")
+
+         data = {
+             "messages": messages,
+             "id": random_id,
+             "previewToken": None,
+             "userId": None,
+             "codeModelMode": True,
+             "agentMode": {},
+             "trendingAgentMode": {},
+             "isMicMode": False,
+             "userSystemPrompt": None,
+             "maxTokens": 99999999,
+             "playgroundTopP": 0.9,
+             "playgroundTemperature": 0.5,
+             "isChromeExt": False,
+             "githubToken": None,
+             "clickedAnswer2": False,
+             "clickedAnswer3": False,
+             "clickedForceWebSearch": False,
+             "visitFromDelta": False,
+             "mobileClient": False,
+             "userSelectedModel": None,
+             "webSearchMode": webSearchMode,
+         }
+
+         if model in cls.agentMode:
+             data["agentMode"] = cls.agentMode[model]
+         elif model in cls.trendingAgentMode:
+             data["trendingAgentMode"] = cls.trendingAgentMode[model]
+         elif model in cls.userSelectedModel:
+             data["userSelectedModel"] = cls.userSelectedModel[model]
+         logger.info(f"Sending request to {cls.api_endpoint} with data (excluding messages).")
+
+         timeout = ClientTimeout(total=30)  # Reduced timeout for faster response
+         retry_attempts = 3  # Reduced retry attempts for faster failure handling
+
+         for attempt in range(retry_attempts):
+             try:
+                 async with ClientSession(headers=headers, timeout=timeout) as session:
+                     async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response:
+                         response.raise_for_status()
+                         logger.info(f"Received response with status {response.status}")
+                         if model == 'ImageGeneration':
+                             response_text = await response.text()
+                             url_match = re.search(r'https://storage\.googleapis\.com/[^\s\)]+', response_text)
+                             if url_match:
+                                 image_url = url_match.group(0)
+                                 logger.info("Image URL found.")
+                                 yield ImageResponse(image_url, alt=messages[-1]['content'])
+                             else:
+                                 logger.error("Image URL not found in the response.")
+                                 raise Exception("Image URL not found in the response")
+                         else:
+                             full_response = ""
+                             search_results_json = ""
+                             try:
+                                 async for chunk, _ in response.content.iter_chunks():
+                                     if chunk:
+                                         decoded_chunk = chunk.decode(errors='ignore')
+                                         decoded_chunk = re.sub(r'\$@\$v=[^$]+\$@\$', '', decoded_chunk)
+                                         if decoded_chunk.strip():
+                                             if '$~~~$' in decoded_chunk:
+                                                 search_results_json += decoded_chunk
+                                             else:
+                                                 full_response += decoded_chunk
+                                                 yield decoded_chunk
+                                 logger.info("Finished streaming response chunks.")
+                             except Exception as e:
+                                 logger.exception("Error while iterating over response chunks.")
+                                 raise e
+                             if data["webSearchMode"] and search_results_json:
+                                 match = re.search(r'\$~~~\$(.*?)\$~~~\$', search_results_json, re.DOTALL)
+                                 if match:
+                                     try:
+                                         search_results = json.loads(match.group(1))
+                                         formatted_results = "\n\n**Sources:**\n"
+                                         for i, result in enumerate(search_results[:5], 1):
+                                             formatted_results += f"{i}. [{result['title']}]({result['link']})\n"
+                                         logger.info("Formatted search results.")
+                                         yield formatted_results
+                                     except json.JSONDecodeError as je:
+                                         logger.error("Failed to parse search results JSON.")
+                                         raise je
+                 break  # Exit the retry loop if successful
+             except ClientError as ce:
+                 logger.error(f"Client error occurred: {ce}. Retrying attempt {attempt + 1}/{retry_attempts}")
+                 if attempt == retry_attempts - 1:
+                     raise HTTPException(status_code=502, detail="Error communicating with the external API.")
+             except asyncio.TimeoutError:
+                 logger.error(f"Request timed out. Retrying attempt {attempt + 1}/{retry_attempts}")
+                 if attempt == retry_attempts - 1:
+                     raise HTTPException(status_code=504, detail="External API request timed out.")
+             except Exception as e:
+                 logger.error(f"Unexpected error: {e}. Retrying attempt {attempt + 1}/{retry_attempts}")
+                 if attempt == retry_attempts - 1:
+                     raise HTTPException(status_code=500, detail=str(e))

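# Editor's sketch (not part of this commit): consuming the generator directly
# from a standalone script; "blackbox" is a placeholder for any entry of
# Blackbox.models, and ImageResponse comes from the unchanged part of the file.
async def _demo_generator() -> None:
    async for part in Blackbox.create_async_generator(
        model="blackbox",
        messages=[{"role": "user", "content": "Hello"}],
    ):
        if isinstance(part, ImageResponse):
            print(f"[image] {part.url}")  # image results carry a URL
        else:
            print(part, end="")  # text arrives as incremental chunks
# e.g. asyncio.run(_demo_generator())
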
  # FastAPI app setup
app = FastAPI()

+ # Implement per-IP rate limiting middleware
+ @app.middleware("http")
+ async def rate_limit_middleware(request: Request, call_next):
+     await rate_limiter(request)
+     response = await call_next(request)
+     return response
+
+ # Pydantic models for OpenAI API
class Message(BaseModel):
    role: str
    content: str

+ class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Message]
+     temperature: Optional[float] = 1.0
+     top_p: Optional[float] = 1.0
+     n: Optional[int] = 1
    stream: Optional[bool] = False
+     stop: Optional[Any] = None  # Can be a string or a list of strings
+     max_tokens: Optional[int] = None
+     presence_penalty: Optional[float] = 0.0
+     frequency_penalty: Optional[float] = 0.0
+     logit_bias: Optional[Dict[str, float]] = None
+     user: Optional[str] = None
+
+ def create_chat_completion_response(content: str, model: str, usage: Dict[str, int]) -> Dict[str, Any]:
    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
+         "object": "chat.completion",
        "created": int(datetime.now().timestamp()),
        "model": model,
        "choices": [
            {
                "index": 0,
+                 "message": {
+                     "role": "assistant",
+                     "content": content
+                 },
+                 "finish_reason": "stop"
            }
        ],
+         "usage": usage
    }

+ def create_stream_response_chunk(content: Optional[str], role: Optional[str] = None, finish_reason: Optional[str] = None):
+     delta = {}
+     if role:
+         delta['role'] = role
+     if content:
+         delta['content'] = content
+     return {
+         "object": "chat.completion.chunk",
+         "created": int(datetime.now().timestamp()),
+         "model": "",  # Model name can be added if necessary
+         "choices": [
+             {
+                 "delta": delta,
+                 "index": 0,
+                 "finish_reason": finish_reason
+             }
+         ]
+     }

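# Editor's note: with stream=True the endpoint below emits these chunks as
# Server-Sent Events, e.g.:
#   data: {"object": "chat.completion.chunk", ..., "choices": [{"delta": {"role": "assistant"}, "index": 0, "finish_reason": null}]}
#   data: {"object": "chat.completion.chunk", ..., "choices": [{"delta": {"content": "Hello"}, "index": 0, "finish_reason": null}]}
#   data: {"object": "chat.completion.chunk", ..., "choices": [{"delta": {}, "index": 0, "finish_reason": "stop"}]}
#   data: [DONE]
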
+ @app.post("/v1/chat/completions")
+ async def chat_completions(request: ChatCompletionRequest, authorization: str = Header(None)):
+     # Verify API key
+     if not authorization or not authorization.startswith('Bearer '):
+         logger.warning("Invalid authorization header format.")
+         raise HTTPException(status_code=401, detail='Invalid authorization header format.')
+     api_key = authorization[7:]
+     if api_key not in API_KEYS:
+         logger.warning(f"Invalid API key attempted: {api_key}")
+         raise HTTPException(status_code=401, detail='Invalid API key.')
+
+     logger.info(f"Received chat completion request for model: {request.model}")
+
+     # Validate model
+     if request.model not in Blackbox.models and request.model not in Blackbox.model_aliases:
+         logger.warning(f"Attempt to use unavailable model: {request.model}")
+         raise HTTPException(status_code=400, detail="The model is not available.")

+     # Process the request
    try:
+         # Convert messages to dicts
+         messages = [msg.dict() for msg in request.messages]
+
+         # Check if the user is requesting image generation
+         image_generation_requested = any(
+             re.search(r'\b(generate|create|draw)\b.*\b(image|picture|art)\b', msg['content'], re.IGNORECASE)
+             for msg in messages if msg['role'] == 'user'
        )
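        # Editor's note: e.g. "please generate an image of a sunset" or
        # "draw me a picture" match this check; plain chat prompts do not.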
 
+         if image_generation_requested:
+             model = 'ImageGeneration'
+             # For image generation, use the last message as prompt
+             prompt = messages[-1]['content']
+             # Build messages for the Blackbox.create_async_generator
+             messages = [{"role": "user", "content": prompt}]
+             async_generator = Blackbox.create_async_generator(
+                 model=model,
+                 messages=messages,
+                 image=None,
+                 image_name=None,
+                 webSearchMode=False
+             )
+
+             # Collect images
+             images = []
+             count = 0
+             async for response in async_generator:
+                 if isinstance(response, ImageResponse):
+                     images.append(response.url)
+                     count += 1
+                     if count >= request.n:
+                         break
+
+             # Build response content with image URLs
+             response_content = "\n".join(f"![Generated Image]({url})" for url in images)
+             completion_tokens = len(response_content.split())
+         else:
+             # Use the requested model
+             async_generator = Blackbox.create_async_generator(
+                 model=request.model,
+                 messages=messages,
+                 image=None,
+                 image_name=None,
+                 webSearchMode=False
+             )
+             # Usage tracking
+             completion_tokens = 0  # Will be updated as we process the response
+
+         prompt_tokens = sum(len(msg['content'].split()) for msg in messages)
+
        if request.stream:
            async def generate():
+                 nonlocal completion_tokens
                try:
+                     # Initial delta with role
+                     initial_chunk = create_stream_response_chunk(content=None, role="assistant")
+                     yield f"data: {json.dumps(initial_chunk)}\n\n"
+
                    async for chunk in async_generator:
+                         if isinstance(chunk, str):
+                             completion_tokens += len(chunk.split())
+                             response_chunk = create_stream_response_chunk(content=chunk)
+                             yield f"data: {json.dumps(response_chunk)}\n\n"
+                         elif isinstance(chunk, ImageResponse):
+                             content = f"![Generated Image]({chunk.url})"
+                             completion_tokens += len(content.split())
+                             response_chunk = create_stream_response_chunk(content=content)
+                             yield f"data: {json.dumps(response_chunk)}\n\n"
                        else:
+                             pass  # Handle other types if necessary
+                     # Finish reason
+                     final_chunk = create_stream_response_chunk(content=None, finish_reason="stop")
+                     yield f"data: {json.dumps(final_chunk)}\n\n"
                    yield "data: [DONE]\n\n"
                except Exception as e:
+                     logger.exception("Error during streaming response generation.")
+                     yield f"data: {json.dumps({'error': str(e)})}\n\n"
            return StreamingResponse(generate(), media_type="text/event-stream")
        else:
            response_content = ""
            async for chunk in async_generator:
+                 if isinstance(chunk, str):
                    response_content += chunk
+                 elif isinstance(chunk, ImageResponse):
+                     response_content += f"![Generated Image]({chunk.url})\n"
+             completion_tokens = len(response_content.split())
+             usage = {
+                 "prompt_tokens": prompt_tokens,
+                 "completion_tokens": completion_tokens,
+                 "total_tokens": prompt_tokens + completion_tokens
            }
+             return create_chat_completion_response(response_content, request.model, usage)
    except ModelNotWorkingException as e:
        logger.warning(f"Model not working: {e}")
        raise HTTPException(status_code=503, detail=str(e))
    except Exception as e:
+         logger.exception("An unexpected error occurred while processing the chat completions request.")
        raise HTTPException(status_code=500, detail=str(e))

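# Editor's sketch (not part of this commit): because the route above mirrors
# the OpenAI chat completions API, the official SDK can be pointed at it; the
# base_url, port, model id, and "sk-test" key are illustrative assumptions.
def _demo_openai_client() -> None:
    from openai import OpenAI  # client-side dependency, assumed installed
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="sk-test")
    out = client.chat.completions.create(
        model="blackbox",  # stand-in: use any id returned by GET /v1/models
        messages=[{"role": "user", "content": "Say hi"}],
    )
    print(out.choices[0].message.content)
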
+ @app.get("/v1/models")
+ async def get_models(authorization: str = Header(None)):
+     # Verify API key
+     if not authorization or not authorization.startswith('Bearer '):
+         logger.warning("Invalid authorization header format.")
+         raise HTTPException(status_code=401, detail='Invalid authorization header format.')
+     api_key = authorization[7:]
+     if api_key not in API_KEYS:
+         logger.warning(f"Invalid API key attempted: {api_key}")
+         raise HTTPException(status_code=401, detail='Invalid API key.')
+
+     logger.info("Fetching available models.")
+     # Return models in OpenAI format
+     models_data = [{"id": model, "object": "model", "owned_by": "organization-owner", "permission": []} for model in Blackbox.models]
+     return {"data": models_data, "object": "list"}
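# Editor's note: the resulting payload mirrors OpenAI's list format:
#   {"object": "list", "data": [{"id": "<model>", "object": "model",
#    "owned_by": "organization-owner", "permission": []}, ...]}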
  if __name__ == "__main__":
    import uvicorn