ParthSadaria commited on
Commit
8e4491b
·
verified ·
1 Parent(s): 3109050

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +13 -4
main.py CHANGED
@@ -10,12 +10,18 @@ load_dotenv()
10
 
11
  app = FastAPI()
12
 
 
13
  api_keys_str = os.getenv('API_KEYS')
14
  valid_api_keys = api_keys_str.split(',') if api_keys_str else []
15
  secret_api_endpoint = os.getenv('SECRET_API_ENDPOINT')
 
16
 
17
- if not secret_api_endpoint:
18
- raise HTTPException(status_code=500, detail="API endpoint is not configured in environment variables.")
 
 
 
 
19
 
20
  class Payload(BaseModel):
21
  model: str
@@ -65,15 +71,18 @@ async def fetch_models():
65
  async def get_completion(payload: Payload, request: Request):
66
  api_key = request.headers.get("Authorization")
67
 
 
68
  if api_key not in valid_api_keys:
69
  raise HTTPException(status_code=403, detail="Forbidden: Invalid API key. Join dsc.gg/chadgang and DM @mr_leaderyt on Discord for a free API key :)")
70
 
 
 
71
  payload_dict = {**payload.dict(), "stream": True}
72
 
73
  async def stream_generator(payload_dict):
74
  async with httpx.AsyncClient() as client:
75
  try:
76
- async with client.stream("POST", f"{secret_api_endpoint}/v1/chat/completions", json=payload_dict, timeout=10) as response:
77
  response.raise_for_status()
78
  async for line in response.aiter_lines():
79
  if line:
@@ -96,4 +105,4 @@ async def startup_event():
96
 
97
  if __name__ == "__main__":
98
  import uvicorn
99
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
10
 
11
  app = FastAPI()
12
 
13
+ # Get API keys and secret endpoint from environment variables
14
  api_keys_str = os.getenv('API_KEYS')
15
  valid_api_keys = api_keys_str.split(',') if api_keys_str else []
16
  secret_api_endpoint = os.getenv('SECRET_API_ENDPOINT')
17
+ secret_api_endpoint_2 = os.getenv('SECRET_API_ENDPOINT_2')
18
 
19
+ # Validate if the main secret API endpoint is set
20
+ if not secret_api_endpoint or not secret_api_endpoint_2:
21
+ raise HTTPException(status_code=500, detail="API endpoint(s) are not configured in environment variables.")
22
+
23
+ # Define models that should use the secondary endpoint
24
+ alternate_models = {"gpt-4o-mini", "claude-3-haiku", "llama-3.1-70b", "mixtral-8x7b"}
25
 
26
  class Payload(BaseModel):
27
  model: str
 
71
  async def get_completion(payload: Payload, request: Request):
72
  api_key = request.headers.get("Authorization")
73
 
74
+ # Validate API key
75
  if api_key not in valid_api_keys:
76
  raise HTTPException(status_code=403, detail="Forbidden: Invalid API key. Join dsc.gg/chadgang and DM @mr_leaderyt on Discord for a free API key :)")
77
 
78
+ # Determine which endpoint to use based on the model
79
+ endpoint = secret_api_endpoint_2 if payload.model in alternate_models else secret_api_endpoint
80
  payload_dict = {**payload.dict(), "stream": True}
81
 
82
  async def stream_generator(payload_dict):
83
  async with httpx.AsyncClient() as client:
84
  try:
85
+ async with client.stream("POST", f"{endpoint}/v1/chat/completions", json=payload_dict, timeout=10) as response:
86
  response.raise_for_status()
87
  async for line in response.aiter_lines():
88
  if line:
 
105
 
106
  if __name__ == "__main__":
107
  import uvicorn
108
+ uvicorn.run(app, host="0.0.0.0", port=8000)