Spaces:
Running
Running
import os | |
from dotenv import load_dotenv | |
from fastapi import FastAPI, HTTPException, Request | |
from fastapi.responses import StreamingResponse, HTMLResponse | |
from pydantic import BaseModel | |
import httpx | |
load_dotenv()

app = FastAPI()

# Client API keys accepted by the proxy, read from the comma-separated API_KEYS
# variable. Entries are whitespace-stripped and empty entries dropped, so
# "a, b," yields ["a", "b"]; an empty entry would otherwise match an empty
# Authorization header.
api_keys_str = os.getenv('API_KEYS')
valid_api_keys = [key.strip() for key in (api_keys_str or "").split(",") if key.strip()]

# Upstream base URLs the proxy forwards to.
secret_api_endpoint = os.getenv('SECRET_API_ENDPOINT')
secret_api_endpoint_2 = os.getenv('SECRET_API_ENDPOINT_2')

# Both upstream endpoints are required, so fail fast at import time.
# HTTPException only has meaning inside a request handler, so a plain
# RuntimeError is raised here instead.
if not secret_api_endpoint or not secret_api_endpoint_2:
    raise RuntimeError("API endpoint(s) are not configured in environment variables.")

# Paths are appended as f"{endpoint}/v1/...", so drop any trailing slash to
# avoid "//v1" URLs.
secret_api_endpoint = secret_api_endpoint.rstrip("/")
secret_api_endpoint_2 = secret_api_endpoint_2.rstrip("/")

# Models whose requests are routed to secret_api_endpoint_2 instead of the
# primary endpoint.
alternate_models = {"gpt-4o-mini", "claude-3-haiku", "llama-3.1-70b", "mixtral-8x7b"}
class Payload(BaseModel):
    """Body of a chat-completion request that is forwarded to the upstream API.

    Only ``model``, ``messages`` and ``stream`` are declared. Any other
    fields the client sends (e.g. OpenAI-style ``temperature`` or
    ``max_tokens``) are kept as extras, so ``payload.dict()`` forwards them
    upstream instead of silently dropping them.
    """

    model: str  # model name; also used to choose the upstream endpoint
    messages: list  # chat messages, forwarded as-is
    stream: bool = False  # optional, as in the OpenAI API (default false)

    class Config:
        # Pydantic ignores unknown fields by default; keep them instead.
        extra = "allow"
@app.get("/")
async def root():
    """Serve ``index.html`` from the directory that contains this module.

    Returns:
        HTMLResponse with the file's contents, or a 404 HTML page if the
        file does not exist.
    """
    # Resolve against this file's directory rather than the process's working
    # directory, so the page is found however the server was launched.
    file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html")
    try:
        with open(file_path, "r", encoding="utf-8") as file:
            html_content = file.read()
    except FileNotFoundError:
        return HTMLResponse(content="<h1>File not found</h1>", status_code=404)
    return HTMLResponse(content=html_content)
@app.get("/models")
async def get_models():
    """Return the upstream model list as parsed JSON.

    Sends ``GET {secret_api_endpoint}/v1/models`` with a 3-second timeout.

    Raises:
        HTTPException: with the upstream status code if the upstream answers
            with an error status; 500 if the request itself fails or the
            body is not valid JSON.
    """
    async with httpx.AsyncClient() as client:
        try:
            response = await client.get(f"{secret_api_endpoint}/v1/models", timeout=3)
            response.raise_for_status()
        except httpx.HTTPStatusError as status_err:
            # raise_for_status() raises HTTPStatusError, which is NOT a
            # RequestError and was previously unhandled. Report only the
            # status: str(status_err) contains the secret upstream URL.
            status_code = status_err.response.status_code
            raise HTTPException(
                status_code=status_code,
                detail=f"HTTP error: upstream returned status {status_code}",
            ) from status_err
        except httpx.RequestError as e:
            raise HTTPException(status_code=500, detail=f"Request failed: {e}") from e
        try:
            return response.json()
        except ValueError as e:
            raise HTTPException(status_code=500, detail="Upstream returned invalid JSON") from e
async def fetch_models():
    """Fetch the upstream model list by delegating to ``get_models``.

    Any HTTPException raised by ``get_models`` propagates unchanged.

    NOTE(review): no route decorator is visible for this function; confirm
    whether it is meant to be exposed as an endpoint or is dead code.
    """
    models = await get_models()
    return models
def _is_authorized(auth_header):
    """Return True if *auth_header* carries a key listed in ``valid_api_keys``.

    Accepts the bare key (the format this proxy originally required) as well
    as ``Bearer <key>`` (the format OpenAI-compatible clients send). A missing
    or empty header is always rejected, even if ``valid_api_keys`` happens to
    contain an empty string.
    """
    if not auth_header:
        return False
    if auth_header in valid_api_keys:
        return True
    scheme, _, token = auth_header.partition(" ")
    token = token.strip()
    return scheme.lower() == "bearer" and bool(token) and token in valid_api_keys


async def _open_upstream_stream(url, body):
    """POST *body* as JSON to *url*; return ``(client, response)`` with the body unread.

    The status is checked before returning, so the caller can still answer
    with a proper HTTP error. On any failure the response and client are
    closed before the exception propagates. On success the caller owns both
    and must close them.

    Raises:
        httpx.HTTPStatusError: the upstream answered with a non-success status.
        httpx.RequestError: the upstream could not be reached.
    """
    client = httpx.AsyncClient()
    try:
        upstream_request = client.build_request("POST", url, json=body, timeout=10)
        response = await client.send(upstream_request, stream=True)
        try:
            response.raise_for_status()
        except httpx.HTTPStatusError:
            await response.aclose()
            raise
    except BaseException:
        # Cleanup only; the original exception is re-raised.
        await client.aclose()
        raise
    return client, response


@app.post("/v1/chat/completions")
async def get_completion(payload: Payload, request: Request):
    """Proxy a chat-completion request to the upstream API and relay its output.

    The caller's key is taken from the ``Authorization`` header, either bare
    or as ``Bearer <key>``. Requests for models in ``alternate_models`` go to
    ``secret_api_endpoint_2``; all others go to ``secret_api_endpoint``.
    Non-empty upstream response lines are relayed to the client as they arrive.

    Raises:
        HTTPException: 403 for a missing or unknown key; the upstream's status
            code if it answers with an error; 500 if the upstream cannot be
            reached or another error occurs before streaming starts.
    """
    # Neither the key nor the raw messages are logged: both are sensitive.
    if not _is_authorized(request.headers.get("Authorization")):
        raise HTTPException(status_code=403, detail="Forbidden: Invalid API key. Join dsc.gg/chadgang and DM @mr_leaderyt on Discord for a free API key :)")

    endpoint = secret_api_endpoint_2 if payload.model in alternate_models else secret_api_endpoint
    # Includes stream and any extra fields the Payload model retained.
    payload_dict = payload.dict()
    print(f"Chat completion request: model={payload.model!r} stream={payload.stream}")

    # Open the upstream response BEFORE returning the StreamingResponse: once
    # streaming starts the 200 headers are already sent, so upstream errors
    # must be detected here to become real HTTP error statuses.
    try:
        client, response = await _open_upstream_stream(f"{endpoint}/v1/chat/completions", payload_dict)
    except httpx.HTTPStatusError as status_err:
        status_code = status_err.response.status_code
        # str(status_err) embeds the upstream URL, which must stay secret.
        raise HTTPException(
            status_code=status_code,
            detail=f"HTTP error: upstream returned status {status_code}",
        ) from status_err
    except httpx.RequestError as req_err:
        raise HTTPException(status_code=500, detail=f"Streaming failed: {req_err}") from req_err
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {e}") from e

    async def stream_generator():
        # Errors raised here can no longer change the status code; they
        # propagate and abort the response so the client sees a broken stream.
        try:
            async for line in response.aiter_lines():
                if line:
                    # NOTE(review): blank lines are dropped, which removes SSE
                    # event separators; confirm clients parse line by line.
                    yield f"{line}\n"
        finally:
            await response.aclose()
            await client.aclose()

    return StreamingResponse(stream_generator(), media_type="application/json")
@app.on_event("startup")
async def startup_event():
    """Print the routes this app serves once the server has started."""
    print("API endpoints:")
    for route in ("GET /", "GET /models", "POST /v1/chat/completions"):
        print(route)
if __name__ == "__main__":
    import uvicorn

    # Listen on all interfaces. The port can be overridden with $PORT
    # (e.g. on hosts that assign one) and defaults to 8000 as before.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8000")))