|
import asyncio
import itertools
import json
import logging
import os
import time
import typing
from contextlib import asynccontextmanager

import httpx
from fastapi import FastAPI, Request, HTTPException, Response, Depends
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.security import APIKeyHeader
|
|
|
|
|
|
|
|
|
|
|
# Client-facing API keys: callers of this proxy must present one of these as
# "Authorization: Bearer <key>". Comma-separated in the ALLOWED_API_KEYS env var.
ALLOWED_API_KEYS_STR = os.getenv("ALLOWED_API_KEYS")
if not ALLOWED_API_KEYS_STR:
    raise ValueError("REQUIRED: ALLOWED_API_KEYS environment variable (comma-separated keys for clients) not set.")

# Trim whitespace and drop empty entries (e.g. produced by trailing commas).
ALLOWED_KEYS = {key.strip() for key in ALLOWED_API_KEYS_STR.split(",") if key.strip()}
if not ALLOWED_KEYS:
    raise ValueError("ALLOWED_API_KEYS must contain at least one non-empty key.")

# Configure logging before the first log call. Calling logging.info() on an unconfigured
# root logger implicitly runs basicConfig() at WARNING level, which both hides this
# message and turns any later basicConfig(level=logging.INFO) into a no-op.
logging.basicConfig(level=logging.INFO)
logging.getLogger(__name__).info("Loaded %d allowed client API keys.", len(ALLOWED_KEYS))
|
|
|
|
|
|
|
# Upstream credential sources: a URL serving one key per line, and/or a single fallback key.
UPSTREAM_KEYS_URL = os.getenv("UPSTREAM_KEYS_URL")
DEFAULT_OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# Base URL of the upstream service; the chat-completions endpoint is derived from it.
OPENAI_API_BASE = os.getenv("OPENAI_API_BASE", "https://models.aixplain.com/api/v1")
OPENAI_CHAT_ENDPOINT = "{}/chat/completions".format(OPENAI_API_BASE.rstrip("/"))

# Fail fast at import time when neither upstream credential source is configured.
if not (UPSTREAM_KEYS_URL or DEFAULT_OPENAI_API_KEY):
    raise ValueError("REQUIRED: Either UPSTREAM_KEYS_URL or OPENAI_API_KEY environment variable must be set for upstream authentication.")
|
|
|
|
|
# Configure root logging at INFO. force=True (Python 3.8+) replaces any handler that an
# earlier module-level logging.<level>() call installed implicitly; without it this call
# is silently ignored and every logger.info() message in this module is dropped.
logging.basicConfig(level=logging.INFO, force=True)
logger = logging.getLogger(__name__)

# Reads the raw "Authorization" header. auto_error=False lets a missing header reach
# verify_api_key, which answers with its own 401 message instead of FastAPI's default error.
api_key_header_auth = APIKeyHeader(name="Authorization", auto_error=False)
|
|
|
|
|
|
|
class ChatMessage:
    """Annotation-only description of one chat message.

    NOTE(review): not referenced by the visible code; request messages are handled as
    plain dicts with "role" and "content" keys.
    """

    role: str  # e.g. "system", "user" or "assistant"
    content: str  # the message text
|
|
|
async def verify_api_key(api_key_header: typing.Optional[str] = Depends(api_key_header_auth)):
    """FastAPI dependency that authenticates a client of this proxy.

    The request must carry ``Authorization: Bearer <key>`` with ``<key>`` in ``ALLOWED_KEYS``.

    Returns:
        The validated client key.

    Raises:
        HTTPException: 401 when the header is missing or not of the form
            ``Bearer <key>``; 403 when the key is not an allowed client key.
    """

    def _masked(key: str) -> str:
        # Never log a complete key: keys of 8 chars or fewer are fully hidden.
        return f"{key[:4]}...{key[-4:]}" if len(key) > 8 else "****"

    if not api_key_header:
        logger.warning("Missing Authorization header from client")
        raise HTTPException(status_code=401, detail="Missing Authorization header")

    parts = api_key_header.split()
    if len(parts) != 2 or parts[0].lower() != "bearer":
        logger.warning("Invalid Authorization header format from client.")
        raise HTTPException(status_code=401, detail="Invalid Authorization header format. Use 'Bearer YOUR_KEY'.")

    client_api_key = parts[1]
    if client_api_key not in ALLOWED_KEYS:
        logger.warning("Invalid Client API Key received: %s", _masked(client_api_key))
        raise HTTPException(status_code=403, detail="Invalid API Key provided")

    # Show only the tail of long keys; a short key's "last 4 chars" would reveal all of it.
    key_tail = client_api_key[-4:] if len(client_api_key) > 8 else "****"
    logger.info("Client authenticated successfully (Key ending: ...%s)", key_tail)
    return client_api_key
|
|
|
|
|
async def fetch_upstream_keys(url: str) -> list[str]:
    """Download upstream API keys from ``url`` (plain text, one key per line).

    Surrounding whitespace and blank lines are ignored. Transport errors and non-2xx
    responses are logged and produce an empty list instead of raising, so the caller
    can fall back to another key source.
    """
    try:
        async with httpx.AsyncClient(timeout=15.0) as http:
            logger.info(f"Fetching upstream API keys from: {url}")
            resp = await http.get(url)
            resp.raise_for_status()

            fetched: list[str] = []
            for raw_line in resp.text.splitlines():
                candidate = raw_line.strip()
                if candidate:
                    fetched.append(candidate)

            logger.info(f"Successfully fetched {len(fetched)} upstream API keys.")
            if not fetched:
                logger.warning(f"No valid keys found at {url}. The response was empty or contained only whitespace.")
            return fetched
    except httpx.HTTPStatusError as err:
        # Disjoint from RequestError (both derive from httpx.HTTPError), so clause order is irrelevant.
        logger.error(f"Error fetching upstream keys from {url}: Status {err.response.status_code}")
        logger.error(f"Response body: {err.response.text}")
        return []
    except httpx.RequestError as err:
        logger.error(f"Error fetching upstream keys from {url}: {err}")
        return []
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application startup/shutdown hook.

    Startup: loads upstream keys from ``UPSTREAM_KEYS_URL`` (if set), falling back to
    ``DEFAULT_OPENAI_API_KEY`` when that yields nothing. It then stores on ``app.state``
    the key list (``upstream_api_keys``), a round-robin ``itertools.cycle`` over it
    (``key_iterator``) and one shared ``httpx.AsyncClient`` (``http_client``).

    Shutdown: closes the HTTP client and clears the key state.

    Raises:
        RuntimeError: when no upstream key is available, aborting startup.
    """
    upstream_keys: list[str] = []
    if UPSTREAM_KEYS_URL:
        upstream_keys = await fetch_upstream_keys(UPSTREAM_KEYS_URL)

    if not upstream_keys:
        logger.warning("No upstream keys fetched from URL or URL not provided.")
        if not DEFAULT_OPENAI_API_KEY:
            logger.critical("FATAL: No upstream API keys available (URL fetch failed/empty and no fallback OPENAI_API_KEY). Exiting.")
            raise RuntimeError("Failed to load any upstream API keys. Cannot start service.")
        logger.info("Using fallback OPENAI_API_KEY for upstream authentication.")
        upstream_keys = [DEFAULT_OPENAI_API_KEY]

    app.state.upstream_api_keys = upstream_keys
    app.state.key_iterator = itertools.cycle(upstream_keys)
    logger.info(f"Initialized key rotation with {len(upstream_keys)} keys.")

    logger.info("Initializing main HTTPX client...")
    # Long read timeout for slow model responses; connect, write and pool timeouts stay at 5 s.
    timeout = httpx.Timeout(5.0, read=180.0, write=5.0, connect=5.0)
    app.state.http_client = httpx.AsyncClient(timeout=timeout)
    logger.info("HTTPX client initialized.")

    try:
        yield
    finally:
        # Run cleanup even if the app exits through an exception, so the pool is never leaked.
        logger.info("Closing HTTPX client...")
        await app.state.http_client.aclose()
        logger.info("HTTPX client closed.")
        app.state.upstream_api_keys = []
        app.state.key_iterator = None
        logger.info("Upstream keys cleared.")
|
|
|
|
|
# The ASGI application; ``lifespan`` loads upstream keys and the shared HTTP client at startup.
app: FastAPI = FastAPI(lifespan=lifespan)
|
|
|
|
|
async def yield_openai_chunks(response_body):
    """Re-emit one complete chat-completion JSON document as Server-Sent Events.

    Every ``choices[i]["message"]`` is renamed to ``"delta"``, and an ``object`` of
    ``"chat.completion"`` is relabelled ``"chat.completion.chunk"``. The whole object is
    sent as a single ``data:`` event, followed by the ``data: [DONE]`` sentinel.

    Args:
        response_body: JSON text (``str`` or ``bytes``) containing a ``"choices"`` list.

    By the time this generator runs the HTTP response has already started, so errors
    (invalid JSON, missing ``"choices"``) are logged and simply end the stream.
    """
    logger.info("Starting to stream chunks from upstream...")
    try:
        resp_json = json.loads(response_body)
        for choice in resp_json["choices"]:
            if "message" in choice:
                choice["delta"] = choice.pop("message")
        if resp_json.get("object") == "chat.completion":
            resp_json["object"] = "chat.completion.chunk"

        yield "data: " + json.dumps(resp_json) + "\n\n"
        # Every SSE event must end with a blank line; without it clients may never
        # dispatch the sentinel and keep waiting for the stream to finish.
        yield "data: [DONE]\n\n"
    except Exception as e:
        logger.error(f"Error during streaming upstream response: {e}")
    finally:
        logger.info("Upstream streaming response closed.")
|
|
|
def _message_text(content) -> str:
    """Return the plain text of an OpenAI-style message ``content`` value.

    Accepts a string, ``None`` (e.g. an assistant turn carrying only tool calls), or a
    list of content parts, where only string parts and ``{"type": "text"}`` parts are kept.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and part.get("type") == "text":
                pieces.append(str(part.get("text", "")))
        return "\n".join(pieces)
    return str(content)


def format_messages_for_xplain(messages) -> str:
    """Flatten a list of chat messages into the single text prompt sent to aiXplain.

    Layout:
      * all ``system`` messages first, in their original order, each followed by a blank line;
      * then the conversation: ``user`` turns as ``"Human: ..."``, ``assistant`` turns as
        ``"Assistant: ..."``, any other role as ``"Human (<role>): ..."``;
      * a trailing ``"Human:"`` cue when the last message is not a user turn.

    Args:
        messages: sequence of dicts, each with a ``"role"`` key and an optional
            ``"content"`` (string, ``None``, or a list of content parts).

    Returns:
        The prompt with surrounding whitespace stripped ("" for no messages).
    """
    system_parts = []
    turns = []
    for msg in messages:
        role = msg["role"]
        text = _message_text(msg.get("content"))
        if role == "system":
            # Collected separately so system prompts lead the prompt in their original order.
            system_parts.append(f"{text}\n\n")
        elif role == "user":
            turns.append(f"Human: {text}\n\n")
        elif role == "assistant":
            turns.append(f"Assistant: {text}\n\n")
        else:
            turns.append(f"Human ({role}): {text}\n\n")

    # Kept from the original behaviour: append a "Human:" cue whenever the conversation
    # does not end with a user turn.
    if messages and messages[-1]["role"] != "user":
        turns.append("Human: ")

    return ("".join(system_parts) + "".join(turns)).strip()
|
|
|
|
|
|
|
# aiXplain "execute" endpoint used for the default ("claude") request type; the model id is appended.
_XPLAIN_EXECUTE_URL = "https://models.aixplain.com/api/v1/execute/"
# Placeholder completion id reported to clients for aiXplain-backed responses.
_XPLAIN_COMPLETION_ID = "123456-456789-123456"


def _raise_for_upstream_status(response: httpx.Response, body: bytes, upstream_key: str) -> None:
    """Mirror an upstream error response (status >= 400) to the client as an HTTPException.

    The detail is the parsed JSON body when possible, otherwise the body text. Only the
    last four characters of the upstream key are logged.
    """
    if response.status_code < 400:
        return
    text = body.decode("utf-8", errors="replace")
    logger.error(
        f"Upstream API returned error: {response.status_code} "
        f"Key ending: ...{upstream_key[-4:]} Body: {text}"
    )
    try:
        detail = json.loads(text)
    except json.JSONDecodeError:
        detail = text
    raise HTTPException(status_code=response.status_code, detail=detail)


async def _run_xplain_completion(
    client: httpx.AsyncClient,
    payload: dict,
    upstream_headers: dict,
    upstream_key: str,
    max_polls: int = 100,
    poll_interval: float = 1.0,
):
    """Run one chat request through aiXplain's asynchronous execute API.

    The ``messages`` are flattened into one prompt and submitted to
    ``_XPLAIN_EXECUTE_URL + model``. The URL returned in the submission's ``data`` field is
    then polled until a result reports ``completed``.

    Returns:
        The ``data`` field of the completed poll result (the generated output).

    Raises:
        HTTPException: 400 when ``model`` or ``messages`` is missing; the upstream status for
            upstream error responses; 502 for an unexpected upstream payload; 504 when the
            result is still not ready after ``max_polls`` polls.
    """
    model = payload.get("model")
    messages = payload.get("messages")
    if not model or not isinstance(messages, list):
        raise HTTPException(status_code=400, detail="Request must include 'model' and a 'messages' list.")

    submit = await client.post(
        _XPLAIN_EXECUTE_URL + str(model),
        json={"text": format_messages_for_xplain(messages)},
        headers=upstream_headers,
    )
    _raise_for_upstream_status(submit, submit.content, upstream_key)
    submit_json = submit.json()
    poll_url = submit_json.get("data") if isinstance(submit_json, dict) else None
    if not isinstance(poll_url, str) or not poll_url:
        raise HTTPException(status_code=502, detail="Upstream did not return a result URL to poll.")

    # Forward only the credential headers to the poll URL.
    # NOTE(review): aiXplain's examples also send x-api-key when fetching results; confirm.
    poll_headers = {
        name: value
        for name, value in upstream_headers.items()
        if name.lower() in ("authorization", "x-api-key")
    }
    for attempt in range(1, max_polls + 1):
        poll = await client.get(poll_url, headers=poll_headers)
        _raise_for_upstream_status(poll, poll.content, upstream_key)
        result = poll.json()
        if not isinstance(result, dict):
            raise HTTPException(status_code=502, detail="Unexpected response while polling upstream.")
        if result.get("completed") is True:
            if "data" not in result:
                raise HTTPException(status_code=502, detail=result)
            return result["data"]
        if attempt < max_polls:
            # Never time.sleep() here: it would block the event loop and stall every
            # other request served by this process for the whole polling period.
            await asyncio.sleep(poll_interval)
    raise HTTPException(status_code=504, detail="Upstream model did not finish in time.")


@app.post("/v1/chat/completions")
async def proxy_openai_chat(request: Request, _client_key: str = Depends(verify_api_key)):
    """Proxy an OpenAI-style chat completion request after client authentication.

    Each request uses the next upstream key from the round-robin iterator on
    ``app.state``. A JSON ``type`` of ``"claude"`` (the default) routes the request
    through aiXplain's execute API and converts the result to OpenAI format. Any other
    ``type`` is forwarded unchanged to ``OPENAI_CHAT_ENDPOINT``.

    When ``stream`` is true, the answer is returned as Server-Sent Events. The upstream
    output is fully buffered first.

    Raises:
        HTTPException: 400 for an invalid body; 503 when app state is missing; 500 for
            key-rotation or unexpected errors; the upstream status for upstream error
            responses; 502 or 504 for upstream connection failures and timeouts.
    """
    state = request.app.state
    client: typing.Optional[httpx.AsyncClient] = getattr(state, "http_client", None)
    key_iterator = getattr(state, "key_iterator", None)
    if client is None or key_iterator is None:
        logger.error("HTTPX client or Key Iterator not available (app state issue).")
        raise HTTPException(status_code=503, detail="Service temporarily unavailable")

    try:
        current_upstream_key = next(key_iterator)
    except StopIteration:
        # itertools.cycle over an empty key list stops immediately.
        logger.error("Upstream key iterator exhausted unexpectedly.")
        raise HTTPException(status_code=500, detail="Internal Server Error: Key rotation failed")
    except Exception as e:
        logger.error(f"Unexpected error getting next key: {e}")
        raise HTTPException(status_code=500, detail="Internal Server Error: Key rotation failed")

    try:
        payload = json.loads(await request.body())
    except ValueError:  # JSONDecodeError, or UnicodeDecodeError for non-UTF-8 bytes
        raise HTTPException(status_code=400, detail="Invalid JSON body")
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="JSON body must be an object")

    request_type = payload.get("type", "claude")
    is_streaming = bool(payload.get("stream", False))

    upstream_headers = {
        "Content-Type": request.headers.get("Content-Type", "application/json"),
        "Accept": request.headers.get("Accept", "application/json"),
    }
    if "freeaichatplayground.com" in OPENAI_API_BASE:
        logger.debug("Using payload apiKey for upstream authentication (freeaichatplayground specific).")
        payload["apiKey"] = current_upstream_key
    else:
        logger.debug("Using Authorization header for upstream authentication.")
        upstream_headers["Authorization"] = f"Bearer {current_upstream_key}"
        upstream_headers["x-api-key"] = current_upstream_key

    if is_streaming and "text/event-stream" not in upstream_headers["Accept"]:
        logger.info("Adding 'Accept: text/event-stream' for streaming request")
        upstream_headers["Accept"] = "text/event-stream, application/json"

    response: typing.Optional[httpx.Response] = None
    try:
        if request_type == "claude":
            logger.info(f"Forwarding request to {_XPLAIN_EXECUTE_URL} (Streaming: {is_streaming})")
            output = await _run_xplain_completion(client, payload, upstream_headers, current_upstream_key)
            model = payload.get("model")
            created = int(time.time())
            if is_streaming:
                chunk = {
                    "id": _XPLAIN_COMPLETION_ID,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": model,
                    "choices": [{"delta": {"content": output}, "index": 0, "finish_reason": None}],
                }
                return StreamingResponse(
                    yield_openai_chunks(json.dumps(chunk)),
                    status_code=200,
                    media_type="text/event-stream",
                )
            # Non-streaming clients expect a regular chat.completion object, not aiXplain's poll JSON.
            completion = {
                "id": _XPLAIN_COMPLETION_ID,
                "object": "chat.completion",
                "created": created,
                "model": model,
                "choices": [
                    {
                        "index": 0,
                        "message": {"role": "assistant", "content": output},
                        "finish_reason": "stop",
                    }
                ],
            }
            return JSONResponse(content=completion, status_code=200)

        logger.info(f"Forwarding request to {OPENAI_CHAT_ENDPOINT} (Streaming: {is_streaming})")
        req = client.build_request("POST", OPENAI_CHAT_ENDPOINT, json=payload, headers=upstream_headers)
        response = await client.send(req, stream=True)
        # httpx closes a streamed response once its body has been read to the end.
        response_body = await response.aread()
        _raise_for_upstream_status(response, response_body, current_upstream_key)
        content_type = response.headers.get("content-type", "")

        if is_streaming:
            logger.info(f"Received OK streaming response from upstream (Status: {response.status_code}). Piping to client.")
            status_code = 200 if response.status_code == 201 else response.status_code
            if "text/event-stream" in content_type:
                # Upstream already produced SSE; relay it verbatim rather than parsing it as one JSON document.
                return Response(content=response_body, status_code=status_code, media_type=content_type)
            return StreamingResponse(
                yield_openai_chunks(response_body.decode("utf-8", errors="replace")),
                status_code=status_code,
                # The generator always emits SSE, whatever content type upstream used.
                media_type="text/event-stream",
            )

        logger.info(f"Received OK non-streaming response from upstream (Status: {response.status_code}). Reading full body.")
        return Response(
            content=response_body,
            status_code=response.status_code,
            media_type=content_type or "application/json",
        )

    except httpx.TimeoutException as e:
        logger.error(f"Request to upstream timed out: {e}")
        raise HTTPException(status_code=504, detail="Request to upstream API timed out.") from e
    except httpx.RequestError as e:
        logger.error(f"Error requesting upstream API: {e}")
        raise HTTPException(status_code=502, detail=f"Error contacting upstream API: {e}") from e
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("An unexpected error occurred during response processing.")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e
    finally:
        # Release the upstream connection on every exit path that left it open.
        if response is not None and not response.is_closed:
            await response.aclose()
|
|
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Liveness probe that also reports how many upstream API keys are loaded."""
    state = app.state
    try:
        loaded = state.upstream_api_keys
    except AttributeError:
        # Before startup has run, the attribute does not exist yet.
        key_count = 0
    else:
        key_count = len(loaded)
    return {"status": "ok", "upstream_keys_loaded": key_count > 0, "key_count": key_count}
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Describe where upstream credentials come from.
    if UPSTREAM_KEYS_URL:
        key_source_line = f"Fetching upstream keys from: {UPSTREAM_KEYS_URL}"
    elif DEFAULT_OPENAI_API_KEY:
        key_source_line = "Using single OPENAI_API_KEY for upstream."
    else:
        key_source_line = "ERROR: No upstream key source configured!"

    banner_lines = [
        "--- Starting FastAPI OpenAI Proxy with Custom Auth & Key Rotation ---",
        f"Proxying requests to: {OPENAI_CHAT_ENDPOINT}",
        key_source_line,
        "Clients must provide a valid API key in 'Authorization: Bearer <key>' header.",
        f"Number of allowed client keys configured: {len(ALLOWED_KEYS)}",
        "---",
    ]
    for banner_line in banner_lines:
        print(banner_line)

    uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
|