Niansuh commited on
Commit
9cf0d3b
·
verified ·
1 Parent(s): 018ab76

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +41 -26
main.py CHANGED
@@ -1,6 +1,10 @@
1
  import os
2
  import uuid
3
  import logging
 
 
 
 
4
  from datetime import datetime
5
  from typing import Any, Dict, List, Optional
6
 
@@ -8,7 +12,23 @@ import httpx
8
  from fastapi import FastAPI, HTTPException, Depends
9
  from pydantic import BaseModel
10
  from starlette.middleware.cors import CORSMiddleware
11
- from starlette.responses import StreamingResponse, Response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  # Mock implementations for ImageResponse and to_data_uri
14
  class ImageResponse:
@@ -20,6 +40,7 @@ def to_data_uri(image: Any) -> str:
20
  # Placeholder for actual image encoding
21
  return "data:image/png;base64,..." # Replace with actual base64 data
22
 
 
23
  class AsyncGeneratorProvider:
24
  pass
25
 
@@ -33,7 +54,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
33
  supports_stream = True
34
  supports_system_message = True
35
  supports_message_history = True
36
-
37
  default_model = 'blackbox'
38
  models = [
39
  'blackbox',
@@ -58,13 +79,13 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
58
  'llama-3.1-70b': {'mode': True, 'id': "llama-3.1-70b"},
59
  'llama-3.1-405b': {'mode': True, 'id': "llama-3.1-405b"},
60
  }
61
-
62
  userSelectedModel = {
63
  "gpt-4o": "gpt-4o",
64
  "gemini-pro": "gemini-pro",
65
  'claude-sonnet-3.5': "claude-sonnet-3.5",
66
  }
67
-
68
  model_aliases = {
69
  "gemini-flash": "gemini-1.5-flash",
70
  "flux": "ImageGenerationLV45LJp",
@@ -115,7 +136,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
115
  if not messages[0]['content'].startswith(prefix):
116
  messages[0]['content'] = f"{prefix} {messages[0]['content']}"
117
 
118
- async with ClientSession(headers=headers) as session:
119
  if image is not None:
120
  messages[-1]["data"] = {
121
  "fileText": image_name,
@@ -166,16 +187,14 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
166
  else:
167
  raise Exception("Image URL not found in the response")
168
  else:
169
- async for chunk in response.content.iter_any():
170
  if chunk:
171
  decoded_chunk = chunk.decode()
172
  decoded_chunk = re.sub(r'\$@\$v=[^$]+\$@\$', '', decoded_chunk)
173
  if decoded_chunk.strip():
174
  yield decoded_chunk
175
 
176
- # FastAPI app setup
177
- app = FastAPI()
178
-
179
  class Message(BaseModel):
180
  role: str
181
  content: str
@@ -183,19 +202,22 @@ class Message(BaseModel):
183
  class ChatRequest(BaseModel):
184
  model: str
185
  messages: List[Message]
 
186
 
187
- from fastapi.responses import Response
 
 
 
188
 
189
  @app.post("/v1/chat/completions")
190
- async def chat_completions(
191
- request: ChatRequest, app_secret: str = Depends(verify_app_secret)
192
- ):
193
  logger.info(f"Received chat completion request for model: {request.model}")
194
 
195
- if request.model not in [model['id'] for model in ALLOWED_MODELS]:
 
196
  raise HTTPException(
197
  status_code=400,
198
- detail=f"Model {request.model} is not allowed. Allowed models are: {', '.join(model['id'] for model in ALLOWED_MODELS)}",
199
  )
200
 
201
  # Generate a UUID for the conversation
@@ -205,10 +227,7 @@ async def chat_completions(
205
  "attachments": [],
206
  "conversationId": conversation_id,
207
  "prompt": "\n".join(
208
- [
209
- f"{'User' if msg.role == 'user' else 'Assistant'}: {msg.content}"
210
- for msg in request.messages
211
- ]
212
  ),
213
  }
214
 
@@ -217,13 +236,12 @@ async def chat_completions(
217
  async def generate():
218
  async with httpx.AsyncClient() as client:
219
  try:
220
- async with client.stream('POST', f'{BASE_URL}/api/chat/gpt4o/chat', headers=headers, json=json_data, timeout=120.0) as response:
221
  response.raise_for_status()
222
  async for line in response.aiter_lines():
223
  if line and line != "[DONE]":
224
  content = json.loads(line)["data"]
225
- yield f"data: {json.dumps(create_chat_completion_data(content['message'], request.model))}\n\n"
226
- yield f"data: {json.dumps(create_chat_completion_data('', request.model, 'stop'))}\n\n"
227
  yield "data: [DONE]\n\n"
228
  except httpx.HTTPStatusError as e:
229
  logger.error(f"HTTP error occurred: {e}")
@@ -233,16 +251,13 @@ async def chat_completions(
233
  raise HTTPException(status_code=500, detail=str(e))
234
 
235
  if request.stream:
236
- logger.info("Streaming response")
237
  return StreamingResponse(generate(), media_type="text/event-stream")
238
  else:
239
- logger.info("Non-streaming response")
240
  full_response = ""
241
  async for chunk in generate():
242
  if chunk.startswith("data: ") and not chunk[6:].startswith("[DONE]"):
243
  data = json.loads(chunk[6:])
244
- if data["choices"][0]["delta"].get("content"):
245
- full_response += data["choices"][0]["delta"]["content"]
246
 
247
  return {
248
  "id": f"chatcmpl-{uuid.uuid4()}",
 
1
  import os
2
  import uuid
3
  import logging
4
+ import json
5
+ import re
6
+ import random
7
+ import string
8
  from datetime import datetime
9
  from typing import Any, Dict, List, Optional
10
 
 
12
  from fastapi import FastAPI, HTTPException, Depends
13
  from pydantic import BaseModel
14
  from starlette.middleware.cors import CORSMiddleware
15
+ from starlette.responses import StreamingResponse
16
+
17
+ # Setup logging
18
+ logging.basicConfig(level=logging.INFO)
19
+ logger = logging.getLogger(__name__)
20
+
21
+ # FastAPI app setup
22
+ app = FastAPI()
23
+
24
+ # CORS middleware setup
25
+ app.add_middleware(
26
+ CORSMiddleware,
27
+ allow_origins=["*"],
28
+ allow_credentials=True,
29
+ allow_methods=["*"],
30
+ allow_headers=["*"],
31
+ )
32
 
33
  # Mock implementations for ImageResponse and to_data_uri
34
  class ImageResponse:
 
40
  # Placeholder for actual image encoding
41
  return "data:image/png;base64,..." # Replace with actual base64 data
42
 
43
+ # Define models and providers
44
  class AsyncGeneratorProvider:
45
  pass
46
 
 
54
  supports_stream = True
55
  supports_system_message = True
56
  supports_message_history = True
57
+
58
  default_model = 'blackbox'
59
  models = [
60
  'blackbox',
 
79
  'llama-3.1-70b': {'mode': True, 'id': "llama-3.1-70b"},
80
  'llama-3.1-405b': {'mode': True, 'id': "llama-3.1-405b"},
81
  }
82
+
83
  userSelectedModel = {
84
  "gpt-4o": "gpt-4o",
85
  "gemini-pro": "gemini-pro",
86
  'claude-sonnet-3.5': "claude-sonnet-3.5",
87
  }
88
+
89
  model_aliases = {
90
  "gemini-flash": "gemini-1.5-flash",
91
  "flux": "ImageGenerationLV45LJp",
 
136
  if not messages[0]['content'].startswith(prefix):
137
  messages[0]['content'] = f"{prefix} {messages[0]['content']}"
138
 
139
+ async with httpx.AsyncClient(headers=headers) as session:
140
  if image is not None:
141
  messages[-1]["data"] = {
142
  "fileText": image_name,
 
187
  else:
188
  raise Exception("Image URL not found in the response")
189
  else:
190
+ async for chunk in response.aiter_bytes():
191
  if chunk:
192
  decoded_chunk = chunk.decode()
193
  decoded_chunk = re.sub(r'\$@\$v=[^$]+\$@\$', '', decoded_chunk)
194
  if decoded_chunk.strip():
195
  yield decoded_chunk
196
 
197
+ # Message and chat request models
 
 
198
  class Message(BaseModel):
199
  role: str
200
  content: str
 
202
  class ChatRequest(BaseModel):
203
  model: str
204
  messages: List[Message]
205
+ stream: Optional[bool] = False
206
 
207
+ # Verify app secret (placeholder)
208
+ async def verify_app_secret(app_secret: str):
209
+ if app_secret != os.getenv("APP_SECRET"):
210
+ raise HTTPException(status_code=403, detail="Forbidden")
211
 
212
  @app.post("/v1/chat/completions")
213
+ async def chat_completions(request: ChatRequest, app_secret: str = Depends(verify_app_secret)):
 
 
214
  logger.info(f"Received chat completion request for model: {request.model}")
215
 
216
+ # Validate model
217
+ if request.model not in Blackbox.models:
218
  raise HTTPException(
219
  status_code=400,
220
+ detail=f"Model {request.model} is not allowed. Allowed models are: {', '.join(Blackbox.models)}",
221
  )
222
 
223
  # Generate a UUID for the conversation
 
227
  "attachments": [],
228
  "conversationId": conversation_id,
229
  "prompt": "\n".join(
230
+ [f"{msg.role}: {msg.content}" for msg in request.messages]
 
 
 
231
  ),
232
  }
233
 
 
236
  async def generate():
237
  async with httpx.AsyncClient() as client:
238
  try:
239
+ async with client.stream('POST', f'{Blackbox.api_endpoint}', headers=headers, json=json_data, timeout=120.0) as response:
240
  response.raise_for_status()
241
  async for line in response.aiter_lines():
242
  if line and line != "[DONE]":
243
  content = json.loads(line)["data"]
244
+ yield f"data: {json.dumps(content)}\n\n"
 
245
  yield "data: [DONE]\n\n"
246
  except httpx.HTTPStatusError as e:
247
  logger.error(f"HTTP error occurred: {e}")
 
251
  raise HTTPException(status_code=500, detail=str(e))
252
 
253
  if request.stream:
 
254
  return StreamingResponse(generate(), media_type="text/event-stream")
255
  else:
 
256
  full_response = ""
257
  async for chunk in generate():
258
  if chunk.startswith("data: ") and not chunk[6:].startswith("[DONE]"):
259
  data = json.loads(chunk[6:])
260
+ full_response += data.get("choices", [{}])[0].get("delta", {}).get("content", "")
 
261
 
262
  return {
263
  "id": f"chatcmpl-{uuid.uuid4()}",