ParthSadaria commited on
Commit
b955cc1
·
verified ·
1 Parent(s): ac4bad0

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +16 -5
main.py CHANGED
@@ -4,22 +4,28 @@ from fastapi import FastAPI, HTTPException, Request
4
  from fastapi.responses import StreamingResponse, HTMLResponse
5
  from pydantic import BaseModel
6
  import httpx
7
- from functools import lru_cache
8
 
9
  load_dotenv()
10
 
11
  app = FastAPI()
12
 
 
13
  api_keys_str = os.getenv('API_KEYS')
14
  valid_api_keys = api_keys_str.split(',') if api_keys_str else []
15
  secret_api_endpoint = os.getenv('SECRET_API_ENDPOINT')
 
16
 
17
- if not secret_api_endpoint:
18
- raise HTTPException(status_code=500, detail="API endpoint is not configured in environment variables.")
 
 
 
 
19
 
20
  class Payload(BaseModel):
21
  model: str
22
  messages: list
 
23
 
24
  @app.get("/", response_class=HTMLResponse)
25
  async def root():
@@ -65,15 +71,20 @@ async def fetch_models():
65
  async def get_completion(payload: Payload, request: Request):
66
  api_key = request.headers.get("Authorization")
67
 
 
68
  if api_key not in valid_api_keys:
69
  raise HTTPException(status_code=403, detail="Forbidden: Invalid API key. Join dsc.gg/chadgang and DM @mr_leaderyt on Discord for a free API key :)")
70
 
71
- payload_dict = {**payload.dict(), "stream": True}
 
 
 
 
72
 
73
  async def stream_generator(payload_dict):
74
  async with httpx.AsyncClient() as client:
75
  try:
76
- async with client.stream("POST", f"{secret_api_endpoint}/v1/chat/completions", json=payload_dict, timeout=10) as response:
77
  response.raise_for_status()
78
  async for line in response.aiter_lines():
79
  if line:
 
4
  from fastapi.responses import StreamingResponse, HTMLResponse
5
  from pydantic import BaseModel
6
  import httpx
 
7
 
8
  load_dotenv()
9
 
10
  app = FastAPI()
11
 
12
+ # Get API keys and secret endpoint from environment variables
13
  api_keys_str = os.getenv('API_KEYS')
14
  valid_api_keys = api_keys_str.split(',') if api_keys_str else []
15
  secret_api_endpoint = os.getenv('SECRET_API_ENDPOINT')
16
+ secret_api_endpoint_2 = os.getenv('SECRET_API_ENDPOINT_2')
17
 
18
+ # Validate if the main secret API endpoint is set
19
+ if not secret_api_endpoint or not secret_api_endpoint_2:
20
+ raise HTTPException(status_code=500, detail="API endpoint(s) are not configured in environment variables.")
21
+
22
+ # Define models that should use the secondary endpoint
23
+ alternate_models = {"gpt-4o-mini", "claude-3-haiku", "llama-3.1-70b", "mixtral-8x7b"}
24
 
25
  class Payload(BaseModel):
26
  model: str
27
  messages: list
28
+ stream: bool
29
 
30
  @app.get("/", response_class=HTMLResponse)
31
  async def root():
 
71
  async def get_completion(payload: Payload, request: Request):
72
  api_key = request.headers.get("Authorization")
73
 
74
+ # Validate API key
75
  if api_key not in valid_api_keys:
76
  raise HTTPException(status_code=403, detail="Forbidden: Invalid API key. Join dsc.gg/chadgang and DM @mr_leaderyt on Discord for a free API key :)")
77
 
78
+ # Determine which endpoint to use based on the model
79
+ endpoint = secret_api_endpoint_2 if payload.model in alternate_models else secret_api_endpoint
80
+
81
+ # Use the payload directly as it includes stream and other user data
82
+ payload_dict = payload.dict()
83
 
84
  async def stream_generator(payload_dict):
85
  async with httpx.AsyncClient() as client:
86
  try:
87
+ async with client.stream("POST", f"{endpoint}/v1/chat/completions", json=payload_dict, timeout=10) as response:
88
  response.raise_for_status()
89
  async for line in response.aiter_lines():
90
  if line: