Spaces:
Running
Running
Update main.py
Browse files
main.py
CHANGED
@@ -4,22 +4,28 @@ from fastapi import FastAPI, HTTPException, Request
|
|
4 |
from fastapi.responses import StreamingResponse, HTMLResponse
|
5 |
from pydantic import BaseModel
|
6 |
import httpx
|
7 |
-
from functools import lru_cache
|
8 |
|
9 |
load_dotenv()
|
10 |
|
11 |
app = FastAPI()
|
12 |
|
|
|
13 |
api_keys_str = os.getenv('API_KEYS')
|
14 |
valid_api_keys = api_keys_str.split(',') if api_keys_str else []
|
15 |
secret_api_endpoint = os.getenv('SECRET_API_ENDPOINT')
|
|
|
16 |
|
17 |
-
if
|
18 |
-
|
|
|
|
|
|
|
|
|
19 |
|
20 |
class Payload(BaseModel):
|
21 |
model: str
|
22 |
messages: list
|
|
|
23 |
|
24 |
@app.get("/", response_class=HTMLResponse)
|
25 |
async def root():
|
@@ -65,15 +71,20 @@ async def fetch_models():
|
|
65 |
async def get_completion(payload: Payload, request: Request):
|
66 |
api_key = request.headers.get("Authorization")
|
67 |
|
|
|
68 |
if api_key not in valid_api_keys:
|
69 |
raise HTTPException(status_code=403, detail="Forbidden: Invalid API key. Join dsc.gg/chadgang and DM @mr_leaderyt on Discord for a free API key :)")
|
70 |
|
71 |
-
|
|
|
|
|
|
|
|
|
72 |
|
73 |
async def stream_generator(payload_dict):
|
74 |
async with httpx.AsyncClient() as client:
|
75 |
try:
|
76 |
-
async with client.stream("POST", f"{
|
77 |
response.raise_for_status()
|
78 |
async for line in response.aiter_lines():
|
79 |
if line:
|
|
|
4 |
from fastapi.responses import StreamingResponse, HTMLResponse
|
5 |
from pydantic import BaseModel
|
6 |
import httpx
|
|
|
7 |
|
8 |
load_dotenv()
|
9 |
|
10 |
app = FastAPI()
|
11 |
|
12 |
+
# Get API keys and secret endpoint from environment variables
|
13 |
api_keys_str = os.getenv('API_KEYS')
|
14 |
valid_api_keys = api_keys_str.split(',') if api_keys_str else []
|
15 |
secret_api_endpoint = os.getenv('SECRET_API_ENDPOINT')
|
16 |
+
secret_api_endpoint_2 = os.getenv('SECRET_API_ENDPOINT_2')
|
17 |
|
18 |
+
# Validate that both the primary and secondary secret API endpoints are set
|
19 |
+
if not secret_api_endpoint or not secret_api_endpoint_2:
|
20 |
+
raise HTTPException(status_code=500, detail="API endpoint(s) are not configured in environment variables.")
|
21 |
+
|
22 |
+
# Define models that should use the secondary endpoint
|
23 |
+
alternate_models = {"gpt-4o-mini", "claude-3-haiku", "llama-3.1-70b", "mixtral-8x7b"}
|
24 |
|
25 |
class Payload(BaseModel):
|
26 |
model: str
|
27 |
messages: list
|
28 |
+
stream: bool
|
29 |
|
30 |
@app.get("/", response_class=HTMLResponse)
|
31 |
async def root():
|
|
|
71 |
async def get_completion(payload: Payload, request: Request):
|
72 |
api_key = request.headers.get("Authorization")
|
73 |
|
74 |
+
# Validate API key
|
75 |
if api_key not in valid_api_keys:
|
76 |
raise HTTPException(status_code=403, detail="Forbidden: Invalid API key. Join dsc.gg/chadgang and DM @mr_leaderyt on Discord for a free API key :)")
|
77 |
|
78 |
+
# Determine which endpoint to use based on the model
|
79 |
+
endpoint = secret_api_endpoint_2 if payload.model in alternate_models else secret_api_endpoint
|
80 |
+
|
81 |
+
# Use the payload directly as it includes stream and other user data
|
82 |
+
payload_dict = payload.dict()
|
83 |
|
84 |
async def stream_generator(payload_dict):
|
85 |
async with httpx.AsyncClient() as client:
|
86 |
try:
|
87 |
+
async with client.stream("POST", f"{endpoint}/v1/chat/completions", json=payload_dict, timeout=10) as response:
|
88 |
response.raise_for_status()
|
89 |
async for line in response.aiter_lines():
|
90 |
if line:
|