lokiai / main.py
ParthSadaria's picture
Update main.py
2c1c62a verified
raw
history blame
3.76 kB
import os
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse, HTMLResponse
from pydantic import BaseModel
import httpx
load_dotenv()
app = FastAPI()

# API keys and upstream endpoints are read from the environment, after
# load_dotenv() has merged in any .env file.
api_keys_str = os.getenv('API_KEYS')
# API_KEYS is a comma-separated list. Entries are whitespace-stripped and empty
# entries are dropped: otherwise a trailing comma ("k1,k2,") would make the
# empty string a valid key, authorizing requests with an empty Authorization
# header.
valid_api_keys = [key.strip() for key in (api_keys_str or '').split(',') if key.strip()]
# Trailing slashes are removed because request URLs are built as
# f"{endpoint}/v1/...", which would otherwise yield "//v1/...".
secret_api_endpoint = (os.getenv('SECRET_API_ENDPOINT') or '').rstrip('/')
secret_api_endpoint_2 = (os.getenv('SECRET_API_ENDPOINT_2') or '').rstrip('/')
# Both endpoints are required; fail fast at import time. HTTPException is only
# meaningful inside a request handler, so a plain RuntimeError is raised here.
if not secret_api_endpoint or not secret_api_endpoint_2:
    raise RuntimeError("API endpoint(s) are not configured in environment variables.")
# Requests whose model is in this set are routed to SECRET_API_ENDPOINT_2;
# every other model is routed to SECRET_API_ENDPOINT.
alternate_models = {
    "claude-3-haiku",
    "gpt-4o-mini",
    "llama-3.1-70b",
    "mixtral-8x7b",
}
class Payload(BaseModel):
    """Request body accepted by POST /chat/completions.

    Only ``model``, ``messages`` and ``stream`` are validated. Any other
    fields the client sends (e.g. sampling parameters) are retained, so that
    ``payload.dict()`` forwards them upstream unchanged.
    """

    model: str  # model name; also selects the upstream endpoint
    messages: list  # chat messages, forwarded as-is (not validated further)
    stream: bool  # forwarded upstream; also picks the fallback response media type

    class Config:
        # Pydantic ignores unknown fields by default, which silently dropped
        # e.g. temperature / max_tokens before the request was proxied.
        # Class-based Config is accepted by both pydantic v1 and v2.
        extra = "allow"
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve ``index.html`` from the directory containing this module.

    Returns a small 404 HTML page when the file does not exist.
    """
    # Resolve next to this file rather than against the process CWD, so the
    # page is found regardless of the directory the server was launched from.
    file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html")
    try:
        # Explicit encoding: the platform default is not guaranteed to be UTF-8.
        with open(file_path, "r", encoding="utf-8") as file:
            html_content = file.read()
    except FileNotFoundError:
        return HTMLResponse(content="<h1>File not found</h1>", status_code=404)
    return HTMLResponse(content=html_content)
async def get_models():
    """Fetch the model list from the primary upstream endpoint.

    Returns:
        The JSON-decoded body of the upstream ``/v1/models`` response.

    Raises:
        HTTPException: carrying the upstream status code when the upstream
            answers with an error status, or 500 when the request itself fails.
    """
    async with httpx.AsyncClient() as client:
        try:
            response = await client.get(f"{secret_api_endpoint}/v1/models", timeout=3)
            response.raise_for_status()
        except httpx.HTTPStatusError as status_err:
            # HTTPStatusError is not a RequestError subclass, so it needs its own
            # handler. The detail omits str(status_err) because that string
            # contains the upstream URL, i.e. the secret endpoint.
            status_code = status_err.response.status_code
            raise HTTPException(
                status_code=status_code,
                detail=f"HTTP error: upstream responded with status {status_code}",
            ) from status_err
        except httpx.RequestError as e:
            raise HTTPException(status_code=500, detail=f"Request failed: {e}") from e
        return response.json()
@app.get("/models")
async def fetch_models():
    """Expose the upstream model list at GET /models."""
    models = await get_models()
    return models
@app.post("/chat/completions")
async def get_completion(payload: Payload, request: Request):
    """Proxy a chat completion request to one of the upstream endpoints.

    The ``Authorization`` header must exactly match an entry of
    ``valid_api_keys``. Models in ``alternate_models`` are sent to
    ``secret_api_endpoint_2``; all others go to ``secret_api_endpoint``.
    The upstream response body is relayed to the client byte-for-byte.

    Raises:
        HTTPException: 403 for an unknown API key; the upstream status code
            when the upstream answers with an error; 500 when the upstream
            request fails or an unexpected error occurs before streaming.
    """
    api_key = request.headers.get("Authorization")
    payload_dict = payload.dict()
    # The API key is deliberately not printed: it is a credential.
    print(f"Raw Payload: {payload_dict}")

    if api_key not in valid_api_keys:
        raise HTTPException(status_code=403, detail="Forbidden: Invalid API key. Join dsc.gg/chadgang and DM @mr_leaderyt on Discord for a free API key :)")

    endpoint = secret_api_endpoint_2 if payload.model in alternate_models else secret_api_endpoint

    # Open and check the upstream stream BEFORE creating the StreamingResponse.
    # Once StreamingResponse starts sending, the 200 status line is already on
    # the wire, so an HTTPException raised inside the body generator can no
    # longer reach the client as an error response.
    client = httpx.AsyncClient(timeout=10)
    try:
        upstream_request = client.build_request(
            "POST", f"{endpoint}/v1/chat/completions", json=payload_dict
        )
        response = await client.send(upstream_request, stream=True)
    except httpx.RequestError as req_err:
        await client.aclose()
        raise HTTPException(status_code=500, detail=f"Streaming failed: {req_err}") from req_err
    except Exception as e:
        await client.aclose()
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {e}") from e

    if not response.is_success:
        status_code = response.status_code
        await response.aclose()
        await client.aclose()
        # The upstream URL is kept out of the detail so the secret endpoint
        # is not disclosed to clients.
        raise HTTPException(
            status_code=status_code,
            detail=f"HTTP error: upstream responded with status {status_code}",
        )

    async def relay_body():
        # Relay bytes verbatim: re-emitting only non-empty lines dropped the
        # blank lines that delimit SSE events. Errors raised here happen after
        # the response has started and can only abort the connection.
        try:
            async for chunk in response.aiter_bytes():
                yield chunk
        finally:
            await response.aclose()
            await client.aclose()

    # Prefer the upstream Content-Type (e.g. text/event-stream for streams).
    fallback_type = "text/event-stream" if payload.stream else "application/json"
    media_type = response.headers.get("content-type", fallback_type)
    return StreamingResponse(relay_body(), media_type=media_type)
@app.on_event("startup")
async def startup_event():
    """Print the routes served by this app when the server starts."""
    print("API endpoints:")
    # Keep in sync with the route decorators in this module. The chat route
    # is registered as /chat/completions, not /v1/chat/completions.
    for route in ("GET /", "GET /models", "POST /chat/completions"):
        print(route)
if __name__ == "__main__":
    # Direct-execution entry point: serve the app with uvicorn on all
    # interfaces, port 8000.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")