# File size: 16,428 Bytes
# e7b1094
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
import asyncio
import itertools  # For key rotation
import json
import logging
import os
import time
import typing
from contextlib import asynccontextmanager

import httpx
from fastapi import Depends, FastAPI, HTTPException, Request, Response
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import APIKeyHeader

# --- Logging ---
# Configure logging before anything logs. A module-level ``logging.info(...)``
# on an unconfigured root logger implicitly runs ``basicConfig()`` at the
# default WARNING level, which would turn a later ``basicConfig(level=INFO)``
# into a no-op and silently drop every INFO message from this service.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Configuration ---

# --- Client Authentication (Proxy Access) ---
# Comma-separated API keys that clients must present to use this proxy.
ALLOWED_API_KEYS_STR = os.getenv("ALLOWED_API_KEYS")
if not ALLOWED_API_KEYS_STR:
    raise ValueError("REQUIRED: ALLOWED_API_KEYS environment variable (comma-separated keys for clients) not set.")
ALLOWED_KEYS = {key.strip() for key in ALLOWED_API_KEYS_STR.split(",") if key.strip()}
if not ALLOWED_KEYS:
    raise ValueError("ALLOWED_API_KEYS must contain at least one non-empty key.")
logger.info(f"Loaded {len(ALLOWED_KEYS)} allowed client API keys.")

# --- Upstream API Configuration ---
# Optional URL serving the upstream API keys as plain text, one key per line.
UPSTREAM_KEYS_URL = os.getenv("UPSTREAM_KEYS_URL")

# Optional single upstream key. It is used as the fallback when
# UPSTREAM_KEYS_URL is unset or yields no keys (see lifespan).
DEFAULT_OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# Upstream API base URL and the chat-completions endpoint derived from it.
OPENAI_API_BASE = os.getenv("OPENAI_API_BASE", "https://models.aixplain.com/api/v1")
OPENAI_CHAT_ENDPOINT = f"{OPENAI_API_BASE.rstrip('/')}/chat/completions"

# At least one upstream key source must be configured; fail fast at import.
if not UPSTREAM_KEYS_URL and not DEFAULT_OPENAI_API_KEY:
    raise ValueError("REQUIRED: Either UPSTREAM_KEYS_URL or OPENAI_API_KEY environment variable must be set for upstream authentication.")

# --- Authentication Dependency (Client -> Proxy) ---
# Extracts the raw ``Authorization`` header. auto_error=False hands a missing
# header to verify_api_key (as None) so it can produce its own 401 response.
api_key_header_auth = APIKeyHeader(name="Authorization", auto_error=False)


# Request model for an OpenAI API chat message (not referenced elsewhere in
# the visible source).
class ChatMessage:
    """One OpenAI-style chat message.

    The class only declares type annotations. It has no ``__init__`` and no
    default values, so instances do not receive these attributes
    automatically.
    """

    role: str
    content: str

async def verify_api_key(api_key_header: typing.Optional[str] = Depends(api_key_header_auth)):
    """Authenticate a client of this proxy (FastAPI dependency).

    The client must send ``Authorization: Bearer <key>``, where ``<key>`` is
    one of ``ALLOWED_KEYS``.

    Args:
        api_key_header: Raw ``Authorization`` header value, or None when the
            header is absent.

    Returns:
        The validated client API key.

    Raises:
        HTTPException: 401 if the header is missing or not in
            ``Bearer <key>`` form; 403 if the key is not allowed.
    """
    if not api_key_header:
        logger.warning("Missing Authorization header from client")
        raise HTTPException(status_code=401, detail="Missing Authorization header")

    parts = api_key_header.split()
    if len(parts) != 2 or parts[0].lower() != "bearer":
        logger.warning("Invalid Authorization header format from client.")
        raise HTTPException(status_code=401, detail="Invalid Authorization header format. Use 'Bearer YOUR_KEY'.")

    client_api_key = parts[1]
    # Log a key fragment only when the key is long enough that the fragment
    # does not reveal most (or all) of it.
    key_hint = f"...{client_api_key[-4:]}" if len(client_api_key) > 8 else "<redacted>"

    if client_api_key not in ALLOWED_KEYS:
        logger.warning(f"Invalid Client API Key received: {key_hint}")
        raise HTTPException(status_code=403, detail="Invalid API Key provided")

    logger.info(f"Client authenticated successfully (Key ending: {key_hint})")
    return client_api_key

# --- Key Fetching and Rotation Logic ---
async def fetch_upstream_keys(url: str) -> list[str]:
    """Download the upstream API key list from ``url``.

    The reply body is read as text with one key per line. Whitespace around
    each line is stripped and blank lines are dropped.

    Args:
        url: Location of the plain-text key list.

    Returns:
        The keys in the order they appear. An empty list is returned (after
        logging) when the request fails at the transport level or the server
        answers with an error status. Other exceptions are not caught.
    """
    try:
        # A throwaway client with its own 15 s timeout. In lifespan this is
        # awaited before the shared app.state client is created.
        async with httpx.AsyncClient(timeout=15.0) as fetcher:
            logger.info(f"Fetching upstream API keys from: {url}")
            reply = await fetcher.get(url)
            reply.raise_for_status()  # error statuses -> httpx.HTTPStatusError
            cleaned = (raw.strip() for raw in reply.text.splitlines())
            found = [candidate for candidate in cleaned if candidate]
            logger.info(f"Successfully fetched {len(found)} upstream API keys.")
            if not found:
                logger.warning(f"No valid keys found at {url}. The response was empty or contained only whitespace.")
            return found
    except httpx.RequestError as exc:
        logger.error(f"Error fetching upstream keys from {url}: {exc}")
        return []
    except httpx.HTTPStatusError as exc:
        logger.error(f"Error fetching upstream keys from {url}: Status {exc.response.status_code}")
        logger.error(f"Response body: {exc.response.text}")
        return []

# --- HTTP Client and Key Iterator Management (Lifespan) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Set up and tear down the shared state used by the proxy endpoint.

    Startup:
      * load upstream API keys from ``UPSTREAM_KEYS_URL`` (if configured),
        falling back to ``OPENAI_API_KEY`` when that yields nothing;
      * store the keys and a round-robin iterator over them on ``app.state``;
      * create the shared ``httpx.AsyncClient`` as ``app.state.http_client``.

    Shutdown: close the client and clear the key state. The cleanup runs in
    a ``finally`` block, so it also executes if an exception is raised into
    the lifespan at the ``yield``.

    Raises:
        RuntimeError: if no upstream key is available. Startup is refused
            rather than letting every request fail later.
    """
    # --- Upstream keys ---
    upstream_keys: list[str] = []
    if UPSTREAM_KEYS_URL:
        upstream_keys = await fetch_upstream_keys(UPSTREAM_KEYS_URL)

    if not upstream_keys:
        logger.warning("No upstream keys fetched from URL or URL not provided.")
        if not DEFAULT_OPENAI_API_KEY:
            logger.critical("FATAL: No upstream API keys available (URL fetch failed/empty and no fallback OPENAI_API_KEY). Exiting.")
            raise RuntimeError("Failed to load any upstream API keys. Cannot start service.")
        logger.info("Using fallback OPENAI_API_KEY for upstream authentication.")
        upstream_keys = [DEFAULT_OPENAI_API_KEY]

    app.state.upstream_api_keys = upstream_keys
    # itertools.cycle repeats the key list endlessly; each request takes next().
    app.state.key_iterator = itertools.cycle(upstream_keys)
    logger.info(f"Initialized key rotation with {len(upstream_keys)} keys.")

    # --- Shared HTTPX client ---
    logger.info("Initializing main HTTPX client...")
    # 5 s for connect/write/pool; 180 s to read upstream replies.
    timeout = httpx.Timeout(5.0, read=180.0, write=5.0, connect=5.0)
    # Requests use absolute URLs, so no base_url is configured.
    app.state.http_client = httpx.AsyncClient(timeout=timeout)
    logger.info("HTTPX client initialized.")

    try:
        yield  # the application serves requests here
    finally:
        logger.info("Closing HTTPX client...")
        await app.state.http_client.aclose()
        logger.info("HTTPX client closed.")
        app.state.upstream_api_keys = []
        app.state.key_iterator = None
        logger.info("Upstream keys cleared.")


# --- FastAPI App ---
app = FastAPI(lifespan=lifespan)

# --- Streaming Helper ---
async def yield_openai_chunks(response_body):
    """Re-emit a complete chat-completion JSON body as a Server-Sent Events stream.

    The upstream reply has already been read in full. This helper reframes it
    in the streaming wire format:
      * every ``choices[i]["message"]`` is renamed to ``choices[i]["delta"]``;
      * the whole object is sent as one ``data: ...`` event;
      * the stream ends with the ``data: [DONE]`` sentinel.

    Args:
        response_body: JSON text (str or bytes) whose top-level object has a
            ``choices`` list.

    Yields:
        SSE-framed strings, each terminated by the blank line that ends an
        event.

    If the body cannot be parsed or has no ``choices``, the error is logged
    with its traceback and the stream ends.
    """
    logger.info("Starting to stream chunks from upstream...")
    try:
        resp_json = json.loads(response_body)
        for choice in resp_json["choices"]:
            if "message" in choice:
                choice["delta"] = choice.pop("message")

        yield "data: " + json.dumps(resp_json) + "\n\n"
        # An SSE event is only dispatched once its terminating blank line
        # arrives, so the sentinel needs the trailing "\n\n" as well.
        yield "data: [DONE]\n\n"
    except Exception as e:
        logger.exception(f"Error during streaming upstream response: {e}")
    finally:
        logger.info("Upstream streaming response closed.")

def _message_text(content) -> str:
    """Return the plain text of an OpenAI message ``content`` value.

    ``None`` becomes ``""``. For a list of content parts, the ``text`` of each
    ``{"type": "text", ...}`` part is joined with newlines and other part
    types (e.g. images) are skipped. Anything else is converted with ``str``.
    """
    if content is None:
        return ""
    if isinstance(content, list):
        return "\n".join(
            part.get("text") or ""
            for part in content
            if isinstance(part, dict) and part.get("type") == "text"
        )
    return str(content)


def format_messages_for_xplain(messages) -> str:
    """Flatten OpenAI-style chat messages into one plain-text prompt.

    Layout of the result:
      * ``system`` messages come first, in their original order, each
        followed by a blank line;
      * ``user`` turns become ``Human: ...``, ``assistant`` turns become
        ``Assistant: ...``, and any other role becomes
        ``Human (<role>): ...``, each followed by a blank line;
      * if the last message is not from the user, a trailing ``Human:`` cue
        is appended.
    Leading and trailing whitespace is stripped from the final prompt.

    Args:
        messages: Sequence of dicts with ``role`` and ``content`` keys.
            ``content`` may be a string, ``None`` or a list of content parts.

    Returns:
        The prompt text (``""`` for an empty sequence).

    Raises:
        KeyError: if a message lacks ``role`` or ``content``.
    """
    system_parts = []
    turns = []
    for msg in messages:
        role = msg["role"]
        content = _message_text(msg["content"])
        if role == "system":
            # Collected separately so every system message ends up ahead of
            # the conversation while keeping its relative order.
            system_parts.append(f"{content}\n\n")
        elif role == "user":
            turns.append(f"Human: {content}\n\n")
        elif role == "assistant":
            turns.append(f"Assistant: {content}\n\n")
        else:
            # Any other role (e.g. "tool") is rendered as a labelled human turn.
            turns.append(f"Human ({role}): {content}\n\n")

    # Original intent (translated): if the last message is not the user's,
    # add a "Human" prefix to guide the model's answer.
    if messages and messages[-1]["role"] != "user":
        turns.append("Human: ")

    return "".join(system_parts + turns).strip()


# --- Proxy Endpoint ---

# aiXplain "execute" endpoint used for the default ("claude") request type;
# the requested model id is appended to it.
_XPLAIN_EXECUTE_URL = "https://models.aixplain.com/api/v1/execute/"
# Polling budget for one aiXplain execution: at most _XPLAIN_MAX_POLLS status
# checks, _XPLAIN_POLL_INTERVAL seconds apart.
_XPLAIN_MAX_POLLS = 100
_XPLAIN_POLL_INTERVAL = 1.0
# Fixed completion id reported to clients for aiXplain results.
_XPLAIN_COMPLETION_ID = "123456-456789-123456"


def _raise_for_upstream_error(response: "httpx.Response", key: str) -> None:
    """Turn an upstream 4xx/5xx reply into an HTTPException for the client.

    The upstream status code is kept. Its body is passed through as
    ``detail``: parsed JSON when possible, otherwise the decoded text.
    Replies below 400 are ignored.
    """
    if response.status_code < 400:
        return
    body = response.content
    text = body.decode(errors="replace")
    logger.error(f"Upstream API returned error: {response.status_code} Key ending: ...{key[-4:]} Body: {text}")
    try:
        detail = json.loads(body)
    except ValueError:  # JSONDecodeError, or bytes that are not valid text
        detail = text
    raise HTTPException(status_code=response.status_code, detail=detail)


async def _run_xplain(
    client: "httpx.AsyncClient", payload: dict, headers: dict, key: str, is_streaming: bool
) -> tuple[int, bytes, str]:
    """Run an aiXplain model execution and shape its result like an OpenAI reply.

    Starts the execution for ``payload["model"]`` with the flattened
    ``payload["messages"]``. It then polls the URL returned in the start
    reply's ``data`` field until the poll result reports ``completed``.

    Returns:
        ``(status_code, body, content_type)``. ``body`` is one
        ``chat.completion.chunk`` when streaming, otherwise a complete
        ``chat.completion``.

    Raises:
        HTTPException: 400 if ``model``/``messages`` are missing, the upstream
            status for upstream errors, 504 if the polling budget runs out.
    """
    try:
        model = payload["model"]
        messages = payload["messages"]
    except KeyError as exc:
        raise HTTPException(status_code=400, detail=f"Missing required field: {exc.args[0]}") from None

    started = await client.post(
        _XPLAIN_EXECUTE_URL + model,
        json={"text": format_messages_for_xplain(messages)},
        headers=headers,
    )
    _raise_for_upstream_error(started, key)
    poll_url = started.json()["data"]

    for _ in range(_XPLAIN_MAX_POLLS):
        # NOTE(review): the poll is sent without the upstream auth headers;
        # confirm the poll URL never requires them.
        polled = await client.get(poll_url)
        _raise_for_upstream_error(polled, key)
        result = polled.json()
        if result.get("completed"):
            break
        # Must not be time.sleep(): that would block the event loop and stall
        # every other request served by this process.
        await asyncio.sleep(_XPLAIN_POLL_INTERVAL)
    else:
        logger.error(f"aiXplain execution did not complete after {_XPLAIN_MAX_POLLS} polls: {poll_url}")
        raise HTTPException(status_code=504, detail="Upstream model did not complete in time.")

    text = result["data"]
    created = int(time.time())
    if is_streaming:
        body = {
            "id": _XPLAIN_COMPLETION_ID,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model,
            "choices": [{"delta": {"content": text}, "index": 0, "finish_reason": None}],
        }
    else:
        body = {
            "id": _XPLAIN_COMPLETION_ID,
            "object": "chat.completion",
            "created": created,
            "model": model,
            "choices": [
                {"index": 0, "message": {"role": "assistant", "content": text}, "finish_reason": "stop"}
            ],
        }
    return 200, json.dumps(body).encode(), "application/json"


async def _forward_chat(
    client: "httpx.AsyncClient", payload: dict, headers: dict, key: str
) -> tuple[int, bytes, str]:
    """POST ``payload`` to OPENAI_CHAT_ENDPOINT and return the full upstream reply.

    Returns:
        ``(status_code, body, content_type)`` taken from the upstream reply.

    Raises:
        HTTPException: with the upstream status if it is 4xx/5xx.
    """
    upstream = await client.post(OPENAI_CHAT_ENDPOINT, json=payload, headers=headers)
    _raise_for_upstream_error(upstream, key)
    return upstream.status_code, upstream.content, upstream.headers.get("content-type", "application/json")


@app.post("/v1/chat/completions")
async def proxy_openai_chat(request: Request, _client_key: str = Depends(verify_api_key)):  # auth via Depends
    """Proxy an OpenAI-style chat-completions request to the configured upstream.

    The client must pass ``verify_api_key``. Each request takes the next
    upstream key from the round-robin iterator on ``app.state``. The
    non-standard ``type`` field (default ``"claude"``) selects the route:

    * ``"claude"``: the messages are flattened into a text prompt and executed
      on aiXplain for ``model``. The result is polled for and returned in
      OpenAI shape.
    * anything else: the JSON payload is forwarded to OPENAI_CHAT_ENDPOINT.

    With ``"stream": true`` the fully buffered result is returned as
    Server-Sent Events. An upstream reply that is already SSE is relayed
    unchanged.

    Raises:
        HTTPException: 503 if app state is missing, 400 for a bad request
            body, the upstream status for upstream errors, 504/502 for
            upstream timeouts/transport errors, 500 for anything unexpected.
    """
    client: typing.Optional[httpx.AsyncClient] = getattr(request.app.state, "http_client", None)
    key_iterator = getattr(request.app.state, "key_iterator", None)
    if client is None or key_iterator is None:
        logger.error("HTTPX client or Key Iterator not available (app state issue).")
        raise HTTPException(status_code=503, detail="Service temporarily unavailable")

    # --- Next upstream key (round-robin) ---
    try:
        current_upstream_key = next(key_iterator)
    except StopIteration:
        # Only possible for an empty key list, which lifespan refuses at startup.
        logger.error("Upstream key iterator exhausted unexpectedly.")
        raise HTTPException(status_code=500, detail="Internal Server Error: Key rotation failed")

    # --- Client payload ---
    try:
        payload = json.loads(await request.body())
    except ValueError:  # JSONDecodeError or undecodable bytes
        raise HTTPException(status_code=400, detail="Invalid JSON body")
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Invalid JSON body")
    request_type = payload.get("type", "claude")
    is_streaming = payload.get("stream", False)

    # --- Upstream headers and authentication (rotated key) ---
    upstream_headers = {
        "Content-Type": request.headers.get("Content-Type", "application/json"),
        "Accept": request.headers.get("Accept", "application/json"),
    }
    if "freeaichatplayground.com" in OPENAI_API_BASE:
        # This upstream expects the key inside the JSON payload.
        logger.debug("Using payload apiKey for upstream authentication (freeaichatplayground specific).")
        payload["apiKey"] = current_upstream_key
    else:
        logger.debug("Using Authorization header for upstream authentication.")
        upstream_headers["Authorization"] = f"Bearer {current_upstream_key}"
        upstream_headers["x-api-key"] = current_upstream_key

    if is_streaming and "text/event-stream" not in upstream_headers["Accept"]:
        logger.info("Adding 'Accept: text/event-stream' for streaming request")
        upstream_headers["Accept"] = "text/event-stream, application/json"

    target = _XPLAIN_EXECUTE_URL if request_type == "claude" else OPENAI_CHAT_ENDPOINT
    logger.info(f"Forwarding request to {target} (Streaming: {is_streaming})")

    # --- Upstream call ---
    try:
        if request_type == "claude":
            status_code, body, content_type = await _run_xplain(
                client, payload, upstream_headers, current_upstream_key, is_streaming
            )
        else:
            status_code, body, content_type = await _forward_chat(
                client, payload, upstream_headers, current_upstream_key
            )
    except httpx.TimeoutException as e:  # must precede RequestError (subclass)
        logger.error(f"Request to upstream timed out: {e}")
        raise HTTPException(status_code=504, detail="Request to upstream API timed out.") from e
    except httpx.RequestError as e:
        logger.error(f"Error requesting upstream API: {e}")
        raise HTTPException(status_code=502, detail=f"Error contacting upstream API: {e}") from e
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("An unexpected error occurred during response processing.")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e

    # --- Non-streaming reply: pass the body through ---
    if not is_streaming:
        logger.info(f"Received OK non-streaming response from upstream (Status: {status_code}).")
        return Response(content=body, status_code=status_code, media_type=content_type)

    # --- Streaming reply: always SSE towards the client ---
    logger.info(f"Received OK streaming response from upstream (Status: {status_code}). Piping to client.")
    if status_code == 201:
        status_code = 200
    if "text/event-stream" in content_type:
        # Upstream already answered in SSE format; relay it unchanged.
        return Response(content=body, status_code=status_code, media_type="text/event-stream")
    # Upstream returned a single JSON body; reframe it as an SSE stream.
    return StreamingResponse(
        yield_openai_chunks(body.decode("utf-8", errors="replace")),
        status_code=status_code,
        media_type="text/event-stream",
    )


# --- Health Check Endpoint ---
@app.get("/health")
async def health_check():
    """Report liveness and how many upstream keys are currently loaded.

    ``key_count`` is 0 when no ``upstream_api_keys`` attribute exists on
    ``app.state`` (lifespan has not stored any keys).
    """
    key_count = len(getattr(app.state, "upstream_api_keys", ()))
    return {
        "status": "ok",
        "upstream_keys_loaded": bool(key_count),
        "key_count": key_count,
    }

# --- Main Execution Guard ---
if __name__ == "__main__":
    import uvicorn

    # Required settings were validated when the module was imported, so the
    # final "ERROR" branch is only a defensive fallback.
    if UPSTREAM_KEYS_URL:
        key_source_line = f"Fetching upstream keys from: {UPSTREAM_KEYS_URL}"
    elif DEFAULT_OPENAI_API_KEY:
        key_source_line = "Using single OPENAI_API_KEY for upstream."
    else:
        key_source_line = "ERROR: No upstream key source configured!"

    banner = (
        "--- Starting FastAPI OpenAI Proxy with Custom Auth & Key Rotation ---",
        f"Proxying requests to: {OPENAI_CHAT_ENDPOINT}",
        key_source_line,
        "Clients must provide a valid API key in 'Authorization: Bearer <key>' header.",
        f"Number of allowed client keys configured: {len(ALLOWED_KEYS)}",
        "---",
    )
    for banner_line in banner:
        print(banner_line)

    uvicorn.run(app, host="0.0.0.0", port=7860)