"""FastAPI proxy for OpenAI-style chat completions.

Authenticates clients against a static allow-list of API keys, then forwards
requests upstream using a rotating pool of upstream API keys (fetched from a
URL, one key per line, with an optional single-key fallback).
"""

import asyncio
import itertools  # For upstream key rotation
import json
import logging
import os
import time  # Kept for backward compatibility; polling now uses asyncio.sleep
import typing
from contextlib import asynccontextmanager

import httpx
from fastapi import Depends, FastAPI, HTTPException, Request, Response
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import APIKeyHeader

# --- Logging ---
# Configured before any module-level log call so startup messages honor the level.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Configuration ---

# --- Client Authentication (Proxy Access) ---
# Comma-separated keys that clients must present to talk to THIS proxy.
ALLOWED_API_KEYS_STR = os.getenv("ALLOWED_API_KEYS")
if not ALLOWED_API_KEYS_STR:
    raise ValueError("REQUIRED: ALLOWED_API_KEYS environment variable (comma-separated keys for clients) not set.")
ALLOWED_KEYS = set(key.strip() for key in ALLOWED_API_KEYS_STR.split(',') if key.strip())
if not ALLOWED_KEYS:
    raise ValueError("ALLOWED_API_KEYS must contain at least one non-empty key.")
logger.info(f"Loaded {len(ALLOWED_KEYS)} allowed client API keys.")

# --- Upstream API Configuration ---
# URL to fetch upstream API keys from (one key per line).
UPSTREAM_KEYS_URL = os.getenv("UPSTREAM_KEYS_URL")
# Optional: a single fallback/default upstream key (used if URL fetch fails or isn't provided).
DEFAULT_OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# Upstream API base URL; defaults to the aiXplain OpenAI-compatible endpoint.
OPENAI_API_BASE = os.getenv("OPENAI_API_BASE", "https://models.aixplain.com/api/v1")
OPENAI_CHAT_ENDPOINT = f"{OPENAI_API_BASE.rstrip('/')}/chat/completions"

if not UPSTREAM_KEYS_URL and not DEFAULT_OPENAI_API_KEY:
    raise ValueError("REQUIRED: Either UPSTREAM_KEYS_URL or OPENAI_API_KEY environment variable must be set for upstream authentication.")

# --- Authentication Dependency (Client -> Proxy) ---
api_key_header_auth = APIKeyHeader(name="Authorization", auto_error=False)


class ChatMessage:
    # OpenAI chat message shape (annotation-only; kept for documentation purposes).
    role: str
    content: str


async def verify_api_key(api_key_header: typing.Optional[str] = Depends(api_key_header_auth)) -> str:
    """Dependency to verify the client's API key provided to this proxy.

    Expects an ``Authorization: Bearer <key>`` header whose key is in ALLOWED_KEYS.
    Raises 401 on a missing/malformed header, 403 on an unknown key.
    """
    if not api_key_header:
        logger.warning("Missing Authorization header from client")
        raise HTTPException(status_code=401, detail="Missing Authorization header")
    parts = api_key_header.split()
    if len(parts) != 2 or parts[0].lower() != "bearer":
        logger.warning(f"Invalid Authorization header format from client.")
        raise HTTPException(status_code=401, detail="Invalid Authorization header format. Use 'Bearer YOUR_KEY'.")
    client_api_key = parts[1]
    if client_api_key not in ALLOWED_KEYS:
        # Never log the full key; show only a truncated form.
        truncated_key = client_api_key[:4] + "..." + client_api_key[-4:] if len(client_api_key) > 8 else client_api_key
        logger.warning(f"Invalid Client API Key received: {truncated_key}")
        raise HTTPException(status_code=403, detail="Invalid API Key provided")
    logger.info(f"Client authenticated successfully (Key ending: ...{client_api_key[-4:]})")
    return client_api_key


# --- Key Fetching and Rotation Logic ---
async def fetch_upstream_keys(url: str) -> list[str]:
    """Fetches keys from the given URL, one key per line.

    Returns an empty list on any network or HTTP-status error so the caller
    can fall back to DEFAULT_OPENAI_API_KEY.
    """
    keys: list[str] = []
    try:
        async with httpx.AsyncClient(timeout=15.0) as client:  # Temporary client; the app-wide one isn't built yet
            logger.info(f"Fetching upstream API keys from: {url}")
            response = await client.get(url)
            response.raise_for_status()  # Raise for 4xx/5xx status codes
            content = response.text
            keys = [line.strip() for line in content.splitlines() if line.strip()]
            logger.info(f"Successfully fetched {len(keys)} upstream API keys.")
            if not keys:
                logger.warning(f"No valid keys found at {url}. The response was empty or contained only whitespace.")
            return keys
    except httpx.RequestError as e:
        logger.error(f"Error fetching upstream keys from {url}: {e}")
        return []  # Return empty list on fetch error
    except httpx.HTTPStatusError as e:
        logger.error(f"Error fetching upstream keys from {url}: Status {e.response.status_code}")
        logger.error(f"Response body: {e.response.text}")
        return []  # Return empty list on bad status


# --- HTTP Client and Key Iterator Management (Lifespan) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """App lifespan: load upstream keys, build the shared HTTP client, clean up on shutdown."""
    # --- Initialize Upstream Key Iterator ---
    upstream_keys: list[str] = []
    if UPSTREAM_KEYS_URL:
        upstream_keys = await fetch_upstream_keys(UPSTREAM_KEYS_URL)
    if not upstream_keys:
        logger.warning("No upstream keys fetched from URL or URL not provided.")
        if DEFAULT_OPENAI_API_KEY:
            logger.info("Using fallback OPENAI_API_KEY for upstream authentication.")
            upstream_keys = [DEFAULT_OPENAI_API_KEY]
        else:
            # Critical failure - no keys available; refuse to start.
            logger.critical("FATAL: No upstream API keys available (URL fetch failed/empty and no fallback OPENAI_API_KEY). Exiting.")
            raise RuntimeError("Failed to load any upstream API keys. Cannot start service.")

    # Store keys and create the cycling iterator in app.state.
    app.state.upstream_api_keys = upstream_keys
    app.state.key_iterator = itertools.cycle(upstream_keys)
    logger.info(f"Initialized key rotation with {len(upstream_keys)} keys.")

    # --- Initialize HTTPX Client ---
    logger.info("Initializing main HTTPX client...")
    # Long read timeout: upstream completions can take minutes.
    timeout = httpx.Timeout(5.0, read=180.0, write=5.0, connect=5.0)
    app.state.http_client = httpx.AsyncClient(timeout=timeout)  # No base_url needed; full URLs are used
    logger.info("HTTPX client initialized.")

    yield  # Application runs here

    # --- Cleanup ---
    logger.info("Closing HTTPX client...")
    await app.state.http_client.aclose()
    logger.info("HTTPX client closed.")
    app.state.upstream_api_keys = []  # Clear keys
    app.state.key_iterator = None
    logger.info("Upstream keys cleared.")


# --- FastAPI App ---
app = FastAPI(lifespan=lifespan)


# --- Streaming Helper ---
async def yield_openai_chunks(response_body):
    """Re-emit a non-streaming upstream JSON body as a single SSE chunk plus DONE marker.

    Rewrites each choice's "message" field into the streaming "delta" shape.
    Errors are logged rather than raised so the client stream ends cleanly.
    """
    logger.info("Starting to stream chunks from upstream...")
    try:
        resp_json = json.loads(response_body)
        for choice in resp_json["choices"]:
            if "message" in choice:
                choice["delta"] = choice["message"]
                del choice["message"]
        yield "data:" + json.dumps(resp_json) + "\n\n"
        yield "data: [DONE]"
    except Exception as e:
        logger.error(f"Error during streaming upstream response: {e}")
    finally:
        logger.info("Upstream streaming response closed.")


def format_messages_for_xplain(messages) -> str:
    """Flatten an OpenAI-style message list into a single prompt string for the upstream API."""
    prompt = ""
    for msg in messages:
        role = msg["role"]
        # Map OpenAI roles onto the plain-text prompt format the upstream expects.
        if role == "system":
            # System messages go at the front as guidance.
            prompt = f"{msg['content']}\n\n" + prompt
        elif role == "user":
            prompt += f"Human: {msg['content']}\n\n"
        elif role == "assistant":
            prompt += f"Assistant: {msg['content']}\n\n"
        else:
            # Any other role is treated as a user message.
            prompt += f"Human ({role}): {msg['content']}\n\n"
    # If the last message isn't from the user, add a Human prefix to prompt the model to answer.
    if messages and messages[-1]["role"] != "user":
        prompt += "Human: "
    return prompt.strip()


async def _run_aixplain_job(client: httpx.AsyncClient, payload: dict, headers: dict):
    """Submit a prompt to the aiXplain execute endpoint and poll until completion.

    Returns (last_response, response_body) where response_body is a JSON-encoded
    chat.completion.chunk string, or None if the job never completed within the
    polling budget (~100 attempts, 1 s apart).
    """
    new_payload = {
        "text": format_messages_for_xplain(payload["messages"]),
    }
    req = client.build_request(
        "POST",
        "https://models.aixplain.com/api/v1/execute/" + payload["model"],
        json=new_payload,
        headers=headers,
    )
    response = await client.send(req)
    # The execute call returns a polling URL in "data".
    poll_url = response.json()["data"]
    poll_req = client.build_request("GET", poll_url)
    response_body = None
    for _ in range(100):  # Bounded polling
        response = await client.send(poll_req)
        rj = response.json()
        if rj["completed"]:
            data = {
                "id": "123456-456789-123456",
                "object": "chat.completion.chunk",
                "choices": [{"delta": {"content": rj["data"]}, "index": 0, "finish_reason": None}],
            }
            response_body = json.dumps(data)
            break
        # Non-blocking wait: the original used time.sleep(1), which froze the event loop.
        await asyncio.sleep(1)
    return response, response_body


# --- Proxy Endpoint ---
@app.post("/v1/chat/completions")
async def proxy_openai_chat(request: Request, _client_key: str = Depends(verify_api_key)):  # Use Depends for auth
    """
    Proxies requests to the configured Chat Completions endpoint AFTER verifying client API key.
    Uses rotated keys for upstream authentication.
    """
    client: httpx.AsyncClient = request.app.state.http_client
    key_iterator = request.app.state.key_iterator
    if not client or not key_iterator:
        logger.error("HTTPX client or Key Iterator not available (app state issue).")
        raise HTTPException(status_code=503, detail="Service temporarily unavailable")

    # --- Get Next Upstream API Key ---
    try:
        current_upstream_key = next(key_iterator)
        # Log rotation (optional; consider security of logging key info)
        # logger.info(f"Using upstream key ending: ...{current_upstream_key[-4:]}")
    except StopIteration:
        # Should not happen if lifespan logic is correct and keys were loaded.
        logger.error("Upstream key iterator exhausted unexpectedly.")
        raise HTTPException(status_code=500, detail="Internal Server Error: Key rotation failed")
    except Exception as e:
        logger.error(f"Unexpected error getting next key: {e}")
        raise HTTPException(status_code=500, detail="Internal Server Error: Key rotation failed")

    # --- Get Request Data ---
    try:
        request_body = await request.body()
        payload = json.loads(request_body)
    except json.JSONDecodeError:
        raise HTTPException(status_code=400, detail="Invalid JSON body")

    api_type = payload.get("type", "claude")  # renamed from `type` (shadowed the builtin)
    is_streaming = payload.get("stream", False)

    # --- Prepare Upstream Request ---
    upstream_headers = {
        "Content-Type": request.headers.get("Content-Type", "application/json"),
        "Accept": request.headers.get("Accept", "application/json"),
    }

    # --- Upstream Authentication (Using Rotated Key) ---
    # Decide based on the target API (e.g., freeaichatplayground vs standard OpenAI).
    if "freeaichatplayground.com" in OPENAI_API_BASE:
        logger.debug("Using payload apiKey for upstream authentication (freeaichatplayground specific).")
        payload["apiKey"] = current_upstream_key  # Inject ROTATED key into payload
    else:
        # Default to standard Bearer token authentication for upstream.
        logger.debug("Using Authorization header for upstream authentication.")
        upstream_headers["Authorization"] = f"Bearer {current_upstream_key}"  # Use ROTATED key
        upstream_headers["x-api-key"] = f"{current_upstream_key}"  # Use ROTATED key

    if is_streaming and "text/event-stream" not in upstream_headers["Accept"]:
        logger.info("Adding 'Accept: text/event-stream' for streaming request")
        upstream_headers["Accept"] = "text/event-stream, application/json"

    logger.info(f"Forwarding request to {OPENAI_CHAT_ENDPOINT} (Streaming: {is_streaming})")

    # --- Make Request to Upstream ---
    response = None  # Defined here so it's available in the except blocks
    response_body = None
    try:
        if api_type == "claude":
            response, response_body = await _run_aixplain_job(client, payload, upstream_headers)
            if response_body is None:
                # Polling budget exhausted without completion: surface as an upstream timeout.
                raise HTTPException(status_code=504, detail="Request to upstream API timed out.")
        else:
            req = client.build_request(
                "POST",
                OPENAI_CHAT_ENDPOINT,  # Use the full URL
                json=payload,
                headers=upstream_headers,
            )
            response = await client.send(req, stream=True)
            if response.status_code >= 400:
                error_body = await response.aread()  # Read error fully
                await response.aclose()
                logger.error(f"Upstream API returned error: {response.status_code} Key ending: ...{current_upstream_key[-4:]} Body: {error_body.decode()}")
                try:
                    detail = json.loads(error_body)
                except json.JSONDecodeError:
                    detail = error_body.decode()
                raise HTTPException(status_code=response.status_code, detail=detail)
            response_body = (await response.aread()).decode()

        # --- Handle Streaming Response ---
        if is_streaming:
            logger.info(f"Received OK streaming response from upstream (Status: {response.status_code}). Piping to client.")
            status_code = response.status_code
            if status_code == 201:
                status_code = 200  # Normalize 201 for SSE clients
            return StreamingResponse(
                yield_openai_chunks(response_body),  # Generator handles closing the stream
                status_code=status_code,
                media_type=response.headers.get("content-type", "text/event-stream"),
            )
        # --- Handle Non-Streaming Response ---
        else:
            logger.info(f"Received OK non-streaming response from upstream (Status: {response.status_code}). Reading full body.")
            raw_body = await response.aread()  # Second read returns the cached content
            await response.aclose()  # Ensure closed
            content_type = response.headers.get("content-type", "application/json")
            return Response(  # Return raw response; FastAPI handles the content type
                content=raw_body,
                status_code=response.status_code,
                media_type=content_type,
            )
    except httpx.TimeoutException as e:
        logger.error(f"Request to upstream timed out: {e}")
        if response:
            await response.aclose()
        raise HTTPException(status_code=504, detail="Request to upstream API timed out.")
    except httpx.RequestError as e:
        logger.error(f"Error requesting upstream API: {e}")
        if response:
            await response.aclose()
        raise HTTPException(status_code=502, detail=f"Error contacting upstream API: {e}")
    except HTTPException as e:
        # Re-raise FastAPI HTTPExceptions (like the 4xx check above).
        if response and not response.is_closed:
            await response.aclose()
        raise e
    except Exception as e:
        logger.exception("An unexpected error occurred during response processing.")
        if response and not response.is_closed:
            await response.aclose()
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")


# --- Health Check Endpoint ---
@app.get("/health")
async def health_check():
    """Simple health check endpoint."""
    # Reports whether upstream keys were loaded (cleared to [] after shutdown).
    key_count = len(app.state.upstream_api_keys) if hasattr(app.state, 'upstream_api_keys') else 0
    return {"status": "ok", "upstream_keys_loaded": key_count > 0, "key_count": key_count}


# --- Main Execution Guard ---
if __name__ == "__main__":
    import uvicorn

    # Startup checks are implicitly handled by the config loading at the top of the module.
    print("--- Starting FastAPI OpenAI Proxy with Custom Auth & Key Rotation ---")
    print(f"Proxying requests to: {OPENAI_CHAT_ENDPOINT}")
    if UPSTREAM_KEYS_URL:
        print(f"Fetching upstream keys from: {UPSTREAM_KEYS_URL}")
    elif DEFAULT_OPENAI_API_KEY:
        print("Using single OPENAI_API_KEY for upstream.")
    else:
        print("ERROR: No upstream key source configured!")  # Should have failed earlier
    print(f"Clients must provide a valid API key in 'Authorization: Bearer ' header.")
    print(f"Number of allowed client keys configured: {len(ALLOWED_KEYS)}")
    print("---")
    uvicorn.run(app, host="0.0.0.0", port=7860)