|
import os
|
|
import httpx
|
|
import json
|
|
from fastapi import FastAPI, Request, HTTPException, Response, Depends
|
|
from fastapi.security import APIKeyHeader
|
|
from fastapi.responses import StreamingResponse, JSONResponse
|
|
import logging
|
|
from contextlib import asynccontextmanager
|
|
import typing
|
|
import itertools
|
|
import asyncio
|
|
|
|
|
|
|
|
|
|
|
|
# --- Logging ------------------------------------------------------------------
# Configure logging BEFORE anything logs. A module-level ``logging.info(...)``
# call on an unconfigured root logger implicitly runs ``basicConfig()`` with the
# default WARNING level; a later ``basicConfig(level=logging.INFO)`` is then a
# no-op and every ``logger.info`` in this service is silently dropped.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Client authentication ----------------------------------------------------
# Comma-separated list of keys that clients must present to this proxy.
ALLOWED_API_KEYS_STR = os.getenv("ALLOWED_API_KEYS")
if not ALLOWED_API_KEYS_STR:
    raise ValueError("REQUIRED: ALLOWED_API_KEYS environment variable (comma-separated keys for clients) not set.")

# Whitespace around each entry is ignored and empty entries are dropped.
ALLOWED_KEYS = {key.strip() for key in ALLOWED_API_KEYS_STR.split(",") if key.strip()}
if not ALLOWED_KEYS:
    raise ValueError("ALLOWED_API_KEYS must contain at least one non-empty key.")

logger.info("Loaded %d allowed client API keys.", len(ALLOWED_KEYS))

# --- Upstream configuration ---------------------------------------------------
# Optional URL of a plain-text list of upstream keys (one per line).
UPSTREAM_KEYS_URL = os.getenv("UPSTREAM_KEYS_URL")

# Single upstream key, used when UPSTREAM_KEYS_URL is unset or yields no keys.
DEFAULT_OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

OPENAI_API_BASE = os.getenv("OPENAI_API_BASE", "https://models.aixplain.com/api/v1")
OPENAI_CHAT_ENDPOINT = f"{OPENAI_API_BASE.rstrip('/')}/chat/completions"

# Fail fast at import time: without at least one key source the service can
# never authenticate upstream.
if not UPSTREAM_KEYS_URL and not DEFAULT_OPENAI_API_KEY:
    raise ValueError("REQUIRED: Either UPSTREAM_KEYS_URL or OPENAI_API_KEY environment variable must be set for upstream authentication.")

# Reads the raw ``Authorization`` header. auto_error=False makes a missing header
# arrive as None, so verify_api_key can produce its own 401 response.
api_key_header_auth = APIKeyHeader(name="Authorization", auto_error=False)
|
|
def _redact_key(key: str, *, with_prefix: bool = True) -> str:
    """Return a log-safe rendering of *key* that never exposes the whole secret.

    Keys longer than 8 characters are shown as ``abcd...wxyz`` (or ``...wxyz``
    when ``with_prefix`` is False). Shorter keys are fully masked as ``****``,
    because a 4+4 character preview of such a key would reveal most or all of it.
    """
    if len(key) <= 8:
        return "****"
    tail = key[-4:]
    return f"{key[:4]}...{tail}" if with_prefix else f"...{tail}"


async def verify_api_key(api_key_header: typing.Optional[str] = Depends(api_key_header_auth)):
    """FastAPI dependency that authenticates the client calling this proxy.

    The client must send ``Authorization: Bearer <key>`` where ``<key>`` is a
    member of ``ALLOWED_KEYS``. The scheme name is matched case-insensitively.

    Args:
        api_key_header: Raw ``Authorization`` header value, or ``None`` when the
            header is absent.

    Returns:
        The validated client key.

    Raises:
        HTTPException: 401 if the header is missing or not exactly two
            whitespace-separated parts starting with ``Bearer``; 403 if the key
            is not in ``ALLOWED_KEYS``.
    """
    if not api_key_header:
        logger.warning("Missing Authorization header from client")
        raise HTTPException(status_code=401, detail="Missing Authorization header")

    parts = api_key_header.split()
    if len(parts) != 2 or parts[0].lower() != "bearer":
        logger.warning("Invalid Authorization header format from client.")
        raise HTTPException(status_code=401, detail="Invalid Authorization header format. Use 'Bearer YOUR_KEY'.")

    client_api_key = parts[1]
    if client_api_key not in ALLOWED_KEYS:
        # Redacted: short keys used to be logged verbatim.
        logger.warning("Invalid Client API Key received: %s", _redact_key(client_api_key))
        raise HTTPException(status_code=403, detail="Invalid API Key provided")

    logger.info(
        "Client authenticated successfully (Key ending: %s)",
        _redact_key(client_api_key, with_prefix=False),
    )
    return client_api_key
|
|
|
|
|
|
async def fetch_upstream_keys(url: str) -> list[str]:
    """Download the upstream key list from *url* (plain text, one key per line).

    Each line is stripped of surrounding whitespace, and blank lines are skipped.
    Transport failures and error statuses are logged and reported as an empty
    list rather than raised, so the caller can decide whether a fallback key
    exists.

    Args:
        url: Location of the key list.

    Returns:
        The keys in file order. Returns ``[]`` if the request failed, the
        server answered with an error status, or the body held no keys.
    """
    try:
        async with httpx.AsyncClient(timeout=15.0) as http:
            logger.info(f"Fetching upstream API keys from: {url}")
            reply = await http.get(url)
            reply.raise_for_status()

            stripped_lines = (raw.strip() for raw in reply.text.splitlines())
            fetched = [entry for entry in stripped_lines if entry]

            logger.info(f"Successfully fetched {len(fetched)} upstream API keys.")
            if len(fetched) == 0:
                logger.warning(f"No valid keys found at {url}. The response was empty or contained only whitespace.")
            return fetched
    except httpx.HTTPStatusError as err:
        logger.error(f"Error fetching upstream keys from {url}: Status {err.response.status_code}")
        logger.error(f"Response body: {err.response.text}")
        return []
    except httpx.RequestError as err:
        logger.error(f"Error fetching upstream keys from {url}: {err}")
        return []
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Set up upstream keys and the shared HTTP client for the app's lifetime.

    Startup:
      * Fetches upstream keys from ``UPSTREAM_KEYS_URL`` (when set). Falls back
        to ``[DEFAULT_OPENAI_API_KEY]`` if that yields nothing.
      * Stores the keys on ``app.state.upstream_api_keys`` and a round-robin
        iterator over them on ``app.state.key_iterator``.
      * Creates one ``httpx.AsyncClient`` on ``app.state.http_client``.

    Shutdown closes the client and clears the key state. Cleanup runs in a
    ``finally`` so it also happens when the app exits with an error.

    Raises:
        RuntimeError: if neither source provides an upstream key.
    """
    upstream_keys: list[str] = []
    if UPSTREAM_KEYS_URL:
        upstream_keys = await fetch_upstream_keys(UPSTREAM_KEYS_URL)

    if not upstream_keys:
        logger.warning("No upstream keys fetched from URL or URL not provided.")
        if not DEFAULT_OPENAI_API_KEY:
            # Refuse to start rather than accept requests that cannot be
            # authenticated upstream.
            logger.critical("FATAL: No upstream API keys available (URL fetch failed/empty and no fallback OPENAI_API_KEY). Exiting.")
            raise RuntimeError("Failed to load any upstream API keys. Cannot start service.")
        logger.info("Using fallback OPENAI_API_KEY for upstream authentication.")
        upstream_keys = [DEFAULT_OPENAI_API_KEY]

    app.state.upstream_api_keys = upstream_keys
    # Round-robin: each next() call on this iterator hands out the following key.
    app.state.key_iterator = itertools.cycle(upstream_keys)
    logger.info("Initialized key rotation with %d keys.", len(upstream_keys))

    logger.info("Initializing main HTTPX client...")
    # 5 s default (connect/write/pool); 180 s for reads, presumably to allow
    # slow model responses.
    timeout = httpx.Timeout(5.0, read=180.0, write=5.0, connect=5.0)
    client = httpx.AsyncClient(timeout=timeout)
    app.state.http_client = client
    logger.info("HTTPX client initialized.")

    try:
        yield
    finally:
        # Without finally, an exception thrown into the generator at ``yield``
        # would skip this cleanup and leak the client's connection pool.
        logger.info("Closing HTTPX client...")
        await client.aclose()
        logger.info("HTTPX client closed.")
        app.state.upstream_api_keys = []
        app.state.key_iterator = None
        logger.info("Upstream keys cleared.")
|
|
|
|
|
|
# Application instance. ``lifespan`` loads upstream keys and creates the shared
# HTTP client on startup, and closes the client on shutdown.
app = FastAPI(
    lifespan=lifespan,
)
|
|
|
|
|
|
async def yield_openai_chunks(response_body):
    """Re-frame a buffered, non-streaming Chat Completions reply as an SSE stream.

    The body is parsed as one JSON document. In every entry of ``choices``,
    ``message`` is renamed to ``delta`` (the field streaming chunks use).
    Entries that already carry ``delta`` pass through unchanged. The whole
    object is emitted as a single ``data:`` event, followed by the
    ``data: [DONE]`` sentinel.

    Args:
        response_body: Raw upstream response bytes (or str) holding a JSON
            object with a ``choices`` list.

    Yields:
        SSE event strings, each terminated by a blank line.

    If the body cannot be parsed or lacks ``choices``, the error is logged and
    the stream ends without emitting any event.
    """
    logger.info("Starting to stream chunks from upstream...")
    try:
        resp_json = json.loads(response_body)
        for choice in resp_json["choices"]:
            if "message" in choice:
                choice["delta"] = choice.pop("message")
        yield "data: " + json.dumps(resp_json) + "\n\n"
        # An SSE event is only dispatched once a blank line ends it; without the
        # trailing "\n\n", spec-compliant clients discard the final [DONE].
        yield "data: [DONE]\n\n"
    except Exception as e:
        logger.error(f"Error during streaming upstream response: {e}")
    finally:
        logger.info("Upstream streaming response closed.")
|
|
|
|
|
|
|
|
@app.post("/v1/chat/completions")
async def proxy_openai_chat(request: Request, _client_key: str = Depends(verify_api_key)):
    """Forward a Chat Completions request upstream after client authentication.

    Flow:
      1. Take the next upstream key from the round-robin iterator on app.state.
      2. Parse the client body, which must be a JSON object (400 otherwise).
      3. Authenticate upstream. For freeaichatplayground.com bases this sets
         ``payload["apiKey"]``; otherwise it sends an ``Authorization: Bearer``
         header.
      4. Buffer the full upstream reply. An upstream status >= 400 is returned
         to the client with the same status code.
      5. For ``stream`` requests, re-frame the buffered body as SSE via
         ``yield_openai_chunks``. Otherwise return the body unchanged.

    Raises:
        HTTPException: 400 for a bad body, 500 for key-rotation or unexpected
            errors, 502 for transport errors, 503 when app state is missing,
            504 on upstream timeout, or the upstream's own status on upstream
            errors.
    """
    # getattr with a default: if the lifespan never populated app.state, answer
    # 503 instead of crashing with AttributeError.
    client: typing.Optional[httpx.AsyncClient] = getattr(request.app.state, "http_client", None)
    key_iterator = getattr(request.app.state, "key_iterator", None)

    if not client or not key_iterator:
        logger.error("HTTPX client or Key Iterator not available (app state issue).")
        raise HTTPException(status_code=503, detail="Service temporarily unavailable")

    try:
        current_upstream_key = next(key_iterator)
    except StopIteration:
        # itertools.cycle only stops when built from an empty list, which the
        # lifespan refuses to do.
        logger.error("Upstream key iterator exhausted unexpectedly.")
        raise HTTPException(status_code=500, detail="Internal Server Error: Key rotation failed")
    except Exception as e:
        logger.error("Unexpected error getting next key: %s", e)
        raise HTTPException(status_code=500, detail="Internal Server Error: Key rotation failed")

    request_body = await request.body()
    try:
        payload = json.loads(request_body)
    except ValueError:
        # Covers json.JSONDecodeError and UnicodeDecodeError (non-UTF body);
        # both are ValueError subclasses.
        raise HTTPException(status_code=400, detail="Invalid JSON body")
    if not isinstance(payload, dict):
        # A JSON array/string/number would otherwise crash on payload.get() (500).
        raise HTTPException(status_code=400, detail="Invalid JSON body: expected a JSON object")

    is_streaming = payload.get("stream", False)

    upstream_headers = {
        "Content-Type": request.headers.get("Content-Type", "application/json"),
        "Accept": request.headers.get("Accept", "application/json"),
    }

    if "freeaichatplayground.com" in OPENAI_API_BASE:
        # This upstream takes the key from the JSON body, not from a header.
        logger.debug("Using payload apiKey for upstream authentication (freeaichatplayground specific).")
        payload["apiKey"] = current_upstream_key
    else:
        logger.debug("Using Authorization header for upstream authentication.")
        upstream_headers["Authorization"] = f"Bearer {current_upstream_key}"

    if is_streaming and "text/event-stream" not in upstream_headers["Accept"]:
        logger.info("Adding 'Accept: text/event-stream' for streaming request")
        upstream_headers["Accept"] = "text/event-stream, application/json"

    logger.info("Forwarding request to %s (Streaming: %s)", OPENAI_CHAT_ENDPOINT, is_streaming)

    try:
        req = client.build_request(
            "POST",
            OPENAI_CHAT_ENDPOINT,
            json=payload,
            headers=upstream_headers,
        )
        response = await client.send(req, stream=True)
        try:
            # The reply is always buffered in full. The explicit close releases
            # the connection even if reading fails part-way.
            response_body = await response.aread()
        finally:
            await response.aclose()

        if response.status_code >= 400:
            # errors="replace": a non-UTF-8 error page must not turn the
            # upstream's 4xx/5xx into a 500 from this proxy.
            error_text = response_body.decode(errors="replace")
            logger.error(
                "Upstream API returned error: %s Key ending: ...%s Body: %s",
                response.status_code,
                current_upstream_key[-4:],
                error_text,
            )
            try:
                detail = json.loads(response_body)
            except ValueError:
                detail = error_text
            raise HTTPException(status_code=response.status_code, detail=detail)

        if is_streaming:
            logger.info(
                "Received OK streaming response from upstream (Status: %s). Piping to client.",
                response.status_code,
            )
            status_code = 200 if response.status_code == 201 else response.status_code
            # yield_openai_chunks always emits SSE frames, so advertise SSE
            # regardless of the content type the upstream reported.
            return StreamingResponse(
                yield_openai_chunks(response_body),
                status_code=status_code,
                media_type="text/event-stream",
            )

        logger.info(
            "Received OK non-streaming response from upstream (Status: %s). Reading full body.",
            response.status_code,
        )
        return Response(
            content=response_body,
            status_code=response.status_code,
            media_type=response.headers.get("content-type", "application/json"),
        )

    except httpx.TimeoutException as e:
        logger.error("Request to upstream timed out: %s", e)
        raise HTTPException(status_code=504, detail="Request to upstream API timed out.")
    except httpx.RequestError as e:
        logger.error("Error requesting upstream API: %s", e)
        raise HTTPException(status_code=502, detail=f"Error contacting upstream API: {e}")
    except HTTPException:
        # Already shaped for the client (e.g. the upstream-error branch above).
        raise
    except Exception as e:
        logger.exception("An unexpected error occurred during response processing.")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
|
|
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Liveness probe that also reports how many upstream keys are loaded."""
    # EAFP: app.state has no ``upstream_api_keys`` until the lifespan sets it.
    try:
        key_count = len(app.state.upstream_api_keys)
    except AttributeError:
        key_count = 0
    keys_loaded = key_count > 0
    return {
        "status": "ok",
        "upstream_keys_loaded": keys_loaded,
        "key_count": key_count,
    }
|
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # PORT lets the hosting platform pick the listen port; 7860 remains the default.
    port = int(os.getenv("PORT", "7860"))

    print("--- Starting FastAPI OpenAI Proxy with Custom Auth & Key Rotation ---")
    print(f"Proxying requests to: {OPENAI_CHAT_ENDPOINT}")
    if UPSTREAM_KEYS_URL:
        print(f"Fetching upstream keys from: {UPSTREAM_KEYS_URL}")
    elif DEFAULT_OPENAI_API_KEY:
        print("Using single OPENAI_API_KEY for upstream.")
    else:
        # Defensive only: module import already raises ValueError when neither
        # key source is configured, so this branch is not normally reached.
        print("ERROR: No upstream key source configured!")
    print("Clients must provide a valid API key in 'Authorization: Bearer <key>' header.")
    print(f"Number of allowed client keys configured: {len(ALLOWED_KEYS)}")
    print(f"Listening on 0.0.0.0:{port}")
    print("---")

    uvicorn.run(app, host="0.0.0.0", port=port)
|
|
|
|
|