Niansuh commited on
Commit
2722c48
·
verified ·
1 Parent(s): 80a3863

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +51 -91
main.py CHANGED
@@ -1,34 +1,14 @@
1
- import os
2
- import uuid
3
- import logging
4
- import json
5
  import re
6
  import random
7
  import string
8
- from datetime import datetime
9
- from typing import Any, Dict, List, Optional
10
-
11
- import httpx
12
- from fastapi import FastAPI, HTTPException, Depends
13
  from pydantic import BaseModel
14
- from starlette.middleware.cors import CORSMiddleware
15
- from starlette.responses import StreamingResponse
16
-
17
- # Setup logging
18
- logging.basicConfig(level=logging.INFO)
19
- logger = logging.getLogger(__name__)
20
-
21
- # FastAPI app setup
22
- app = FastAPI()
23
-
24
- # CORS middleware setup
25
- app.add_middleware(
26
- CORSMiddleware,
27
- allow_origins=["*"],
28
- allow_credentials=True,
29
- allow_methods=["*"],
30
- allow_headers=["*"],
31
- )
32
 
33
  # Mock implementations for ImageResponse and to_data_uri
34
  class ImageResponse:
@@ -40,7 +20,6 @@ def to_data_uri(image: Any) -> str:
40
  # Placeholder for actual image encoding
41
  return "data:image/png;base64,..." # Replace with actual base64 data
42
 
43
- # Define models and providers
44
  class AsyncGeneratorProvider:
45
  pass
46
 
@@ -54,7 +33,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
54
  supports_stream = True
55
  supports_system_message = True
56
  supports_message_history = True
57
-
58
  default_model = 'blackbox'
59
  models = [
60
  'blackbox',
@@ -79,13 +58,13 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
79
  'llama-3.1-70b': {'mode': True, 'id': "llama-3.1-70b"},
80
  'llama-3.1-405b': {'mode': True, 'id': "llama-3.1-405b"},
81
  }
82
-
83
  userSelectedModel = {
84
  "gpt-4o": "gpt-4o",
85
  "gemini-pro": "gemini-pro",
86
  'claude-sonnet-3.5': "claude-sonnet-3.5",
87
  }
88
-
89
  model_aliases = {
90
  "gemini-flash": "gemini-1.5-flash",
91
  "flux": "ImageGenerationLV45LJp",
@@ -136,7 +115,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
136
  if not messages[0]['content'].startswith(prefix):
137
  messages[0]['content'] = f"{prefix} {messages[0]['content']}"
138
 
139
- async with httpx.AsyncClient(headers=headers) as session:
140
  if image is not None:
141
  messages[-1]["data"] = {
142
  "fileText": image_name,
@@ -187,14 +166,16 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
187
  else:
188
  raise Exception("Image URL not found in the response")
189
  else:
190
- async for chunk in response.aiter_bytes():
191
  if chunk:
192
  decoded_chunk = chunk.decode()
193
  decoded_chunk = re.sub(r'\$@\$v=[^$]+\$@\$', '', decoded_chunk)
194
  if decoded_chunk.strip():
195
  yield decoded_chunk
196
 
197
- # Message and chat request models
 
 
198
  class Message(BaseModel):
199
  role: str
200
  content: str
@@ -202,68 +183,44 @@ class Message(BaseModel):
202
  class ChatRequest(BaseModel):
203
  model: str
204
  messages: List[Message]
205
- stream: Optional[bool] = False
206
-
207
- # Verify app secret (placeholder)
208
- async def verify_app_secret(app_secret: str):
209
- if app_secret != os.getenv("APP_SECRET"):
210
- raise HTTPException(status_code=403, detail="Forbidden")
211
-
212
- @app.post("/v1/chat/completions")
213
- async def chat_completions(
214
- request: ChatRequest,
215
- app_secret: str = Depends(lambda: os.getenv("APP_SECRET"))
216
- ):
217
- # Validate model
218
- if request.model not in Blackbox.models:
219
- raise HTTPException(
220
- status_code=400,
221
- detail=f"Model {request.model} is not allowed. Allowed models are: {', '.join(Blackbox.models)}",
222
- )
223
-
224
- # Generate a UUID for the conversation
225
- conversation_id = str(uuid.uuid4()).replace("-", "")
226
-
227
- # Define headers
228
- headers = {
229
- "Authorization": f"Bearer {app_secret}",
230
- "Content-Type": "application/json",
231
- "uniqueid": conversation_id,
232
  }
233
 
234
- json_data = {
235
- "attachments": [],
236
- "conversationId": conversation_id,
237
- "prompt": "\n".join(
238
- [f"{msg.role}: {msg.content}" for msg in request.messages]
239
- ),
240
- }
241
 
242
- async def generate():
243
- async with httpx.AsyncClient() as client:
244
- try:
245
- async with client.stream('POST', Blackbox.api_endpoint, headers=headers, json=json_data) as response:
246
- response.raise_for_status()
247
- async for line in response.aiter_lines():
248
- if line and line != "[DONE]":
249
- content = json.loads(line)["data"]
250
- yield f"data: {json.dumps(content)}\n\n"
251
- yield "data: [DONE]\n\n"
252
- except httpx.HTTPStatusError as e:
253
- logger.error(f"HTTP error occurred: {e}")
254
- raise HTTPException(status_code=e.response.status_code, detail=str(e))
255
- except httpx.RequestError as e:
256
- logger.error(f"An error occurred while requesting: {e}")
257
- raise HTTPException(status_code=500, detail=str(e))
258
 
259
  if request.stream:
 
 
 
 
 
260
  return StreamingResponse(generate(), media_type="text/event-stream")
261
  else:
262
- full_response = ""
263
- async for chunk in generate():
264
- if chunk.startswith("data: ") and not chunk[6:].startswith("[DONE]"):
265
- data = json.loads(chunk[6:])
266
- full_response += data.get("choices", [{}])[0].get("delta", {}).get("content", "")
267
 
268
  return {
269
  "id": f"chatcmpl-{uuid.uuid4()}",
@@ -272,9 +229,12 @@ async def chat_completions(
272
  "model": request.model,
273
  "choices": [
274
  {
275
- "index": 0,
276
- "message": {"role": "assistant", "content": full_response},
 
 
277
  "finish_reason": "stop",
 
278
  }
279
  ],
280
  "usage": None,
 
1
+ from __future__ import annotations
2
+
 
 
3
  import re
4
  import random
5
  import string
6
+ import uuid
7
+ from aiohttp import ClientSession
8
+ from fastapi import FastAPI, HTTPException
 
 
9
  from pydantic import BaseModel
10
+ from typing import List, Dict, Any, Optional
11
+ from datetime import datetime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  # Mock implementations for ImageResponse and to_data_uri
14
  class ImageResponse:
 
20
  # Placeholder for actual image encoding
21
  return "data:image/png;base64,..." # Replace with actual base64 data
22
 
 
23
  class AsyncGeneratorProvider:
24
  pass
25
 
 
33
  supports_stream = True
34
  supports_system_message = True
35
  supports_message_history = True
36
+
37
  default_model = 'blackbox'
38
  models = [
39
  'blackbox',
 
58
  'llama-3.1-70b': {'mode': True, 'id': "llama-3.1-70b"},
59
  'llama-3.1-405b': {'mode': True, 'id': "llama-3.1-405b"},
60
  }
61
+
62
  userSelectedModel = {
63
  "gpt-4o": "gpt-4o",
64
  "gemini-pro": "gemini-pro",
65
  'claude-sonnet-3.5': "claude-sonnet-3.5",
66
  }
67
+
68
  model_aliases = {
69
  "gemini-flash": "gemini-1.5-flash",
70
  "flux": "ImageGenerationLV45LJp",
 
115
  if not messages[0]['content'].startswith(prefix):
116
  messages[0]['content'] = f"{prefix} {messages[0]['content']}"
117
 
118
+ async with ClientSession(headers=headers) as session:
119
  if image is not None:
120
  messages[-1]["data"] = {
121
  "fileText": image_name,
 
166
  else:
167
  raise Exception("Image URL not found in the response")
168
  else:
169
+ async for chunk in response.content.iter_any():
170
  if chunk:
171
  decoded_chunk = chunk.decode()
172
  decoded_chunk = re.sub(r'\$@\$v=[^$]+\$@\$', '', decoded_chunk)
173
  if decoded_chunk.strip():
174
  yield decoded_chunk
175
 
176
+ # FastAPI app setup
177
+ app = FastAPI()
178
+
179
  class Message(BaseModel):
180
  role: str
181
  content: str
 
183
  class ChatRequest(BaseModel):
184
  model: str
185
  messages: List[Message]
186
+ stream: Optional[bool] = False # Add this for streaming
187
+
188
+ def create_response(content: str, model: str, finish_reason: Optional[str] = None) -> Dict[str, Any]:
189
+ return {
190
+ "id": f"chatcmpl-{uuid.uuid4()}",
191
+ "object": "chat.completion.chunk",
192
+ "created": int(datetime.now().timestamp()),
193
+ "model": model,
194
+ "choices": [
195
+ {
196
+ "index": 0,
197
+ "delta": {"content": content, "role": "assistant"},
198
+ "finish_reason": finish_reason,
199
+ }
200
+ ],
201
+ "usage": None,
 
 
 
 
 
 
 
 
 
 
 
202
  }
203
 
204
+ @app.post("/v1/chat/completions")
205
+ async def chat_completions(request: ChatRequest):
206
+ messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
 
 
 
 
207
 
208
+ async_generator = Blackbox.create_async_generator(
209
+ model=request.model,
210
+ messages=messages
211
+ )
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
  if request.stream:
214
+ async def generate():
215
+ async for chunk in async_generator:
216
+ yield f"data: {json.dumps(create_response(chunk, request.model))}\n\n"
217
+ yield "data: [DONE]\n\n"
218
+
219
  return StreamingResponse(generate(), media_type="text/event-stream")
220
  else:
221
+ response_content = ""
222
+ async for chunk in async_generator:
223
+ response_content += chunk if isinstance(chunk, str) else chunk.content # Concatenate response
 
 
224
 
225
  return {
226
  "id": f"chatcmpl-{uuid.uuid4()}",
 
229
  "model": request.model,
230
  "choices": [
231
  {
232
+ "message": {
233
+ "role": "assistant",
234
+ "content": response_content
235
+ },
236
  "finish_reason": "stop",
237
+ "index": 0
238
  }
239
  ],
240
  "usage": None,