File size: 13,433 Bytes
dba6b4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
import asyncio
import hmac
import json
import logging
import os
import re
import time
import uuid
from typing import List, Optional

import requests
from fastapi import Depends, FastAPI, HTTPException, Request, Response, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel

# Logging: configure the root handler once at import time (INFO level,
# timestamped records that include the logger name).
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_LOG_DATE_FORMAT = '%Y-%m-%d %H:%M:%S'

logging.basicConfig(format=_LOG_FORMAT, datefmt=_LOG_DATE_FORMAT, level=logging.INFO)

# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)

# API key that clients must present as "Authorization: Bearer <API_KEY>".
# Read from the API_KEY environment variable. The fallback value is visible in
# this source file, so it is only acceptable for local testing.
API_KEY = os.getenv("API_KEY", "linux.do")
if "API_KEY" not in os.environ:
    # Make the insecure fallback visible instead of accepting it silently.
    # (Same logger object as the module-level `logger`.)
    logging.getLogger(__name__).warning(
        "未设置API_KEY环境变量,正在使用默认API密钥,请在生产环境中配置API_KEY"
    )

# The FastAPI application; every route in this file is registered on it.
app = FastAPI()

# CORS policy: any origin, any method, any header, credentials allowed.
# NOTE(review): fully permissive — narrow allow_origins if this proxy is
# reachable by untrusted browser clients.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)

# Request models (pydantic).
class ChatMessage(BaseModel):
    """A single chat message: who is speaking and what they said."""

    role: str  # presumably an OpenAI-style role such as "user"; not validated here
    content: str  # the message text

class ChatCompletionRequest(BaseModel):
    """Schema of an OpenAI-style chat-completions request body.

    NOTE(review): the ``/v1/chat/completions`` handler in this file reads the
    raw body with ``request.json()`` and does not validate it against this
    model; the defaults below mirror the fallbacks that handler applies.
    """

    model: str  # Upstream model id; the handler matches it case-insensitively.
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.6
    top_p: Optional[float] = 0.95  # Forwarded upstream under the key "topP".
    stream: Optional[bool] = True  # SSE streaming is the default.
    system_message: Optional[str] = "You are a helpful assistant."  # Sent upstream as "system".

# Fetch a fresh session cookie from the Akash chat service.
async def get_session_cookie():
    """Return a ``Cookie`` header value for a fresh Akash chat session.

    Requests ``/api/auth/session`` and joins every cookie the response sets
    into ``"name=value; name2=value2"`` form.

    The blocking ``requests`` call runs in the default executor and is bounded
    by a timeout. A slow or unreachable upstream therefore neither stalls the
    event loop for other requests nor hangs this one forever.

    Returns:
        The joined cookie string, or ``""`` when the upstream set no cookies
        or the request failed. This is deliberately best-effort: failures are
        logged, not raised.
    """
    loop = asyncio.get_running_loop()
    try:
        response = await loop.run_in_executor(
            None,
            lambda: requests.get(
                "https://chat.akash.network/api/auth/session",
                headers={"Content-Type": "application/json"},
                timeout=30,
            ),
        )
    except Exception as e:
        logger.error(f"获取会话cookie失败: {e}")
        return ""
    # An empty cookie jar joins to "", matching the "no cookies" case.
    return "; ".join(f"{cookie.name}={cookie.value}" for cookie in response.cookies)

# Dependency: require "Authorization: Bearer <API_KEY>" on protected endpoints.
async def verify_api_key(request: Request):
    """Reject the request unless it carries the configured API key.

    The expected header is ``Authorization: Bearer <API_KEY>``. The comparison
    uses ``hmac.compare_digest``, so its running time does not reveal how many
    leading characters of a guessed key were correct.

    Returns:
        ``True`` when the header matches.

    Raises:
        HTTPException: 403 with a ``{"success": False, "message": ...}``
            detail when the header is missing or does not match.
    """
    # A missing header becomes "" so it can be compared (and rejected) the
    # same way as a wrong one.
    auth_header = request.headers.get("Authorization") or ""
    expected = "Bearer " + API_KEY
    # compare_digest on str only accepts ASCII; compare UTF-8 bytes instead.
    if not hmac.compare_digest(auth_header.encode("utf-8"), expected.encode("utf-8")):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail={"success": False, "message": "Unauthorized: Invalid API key"}
        )
    return True

# Endpoint: list the upstream Akash models in OpenAI "list models" format.
@app.get("/v1/models")
async def get_models(authorized: bool = Depends(verify_api_key)):
    """Return the upstream model catalogue converted to OpenAI's list format.

    A fresh session cookie is obtained, then ``/api/models`` is requested.
    The blocking ``requests`` call runs in the default executor with a timeout,
    so it does not stall the event loop.

    Returns:
        ``{"object": "list", "data": [...]}`` with one entry per upstream
        model. If the upstream call fails or its body is not JSON, returns a
        502 JSON error instead.
    """
    cookie = await get_session_cookie()
    headers = {"Content-Type": "application/json", "Cookie": cookie}
    loop = asyncio.get_running_loop()
    try:
        response = await loop.run_in_executor(
            None,
            lambda: requests.get(
                "https://chat.akash.network/api/models",
                headers=headers,
                timeout=30,
            ),
        )
        # requests' JSON decode error subclasses ValueError.
        models_data = response.json()
    except (requests.RequestException, ValueError) as e:
        logger.error(f"获取模型列表失败: {e}")
        return JSONResponse(
            status_code=status.HTTP_502_BAD_GATEWAY,
            content={"error": "Failed to fetch models from upstream"},
        )

    # Every entry shares one "created" timestamp: the time of this listing.
    current_timestamp = int(time.time())
    converted_data = {
        "object": "list",
        "data": [
            {
                "id": model["id"],
                "object": "model",
                "created": current_timestamp,
                # NOTE(review): labels Meta models as "openai"; kept as-is,
                # confirm this is intentional.
                "owned_by": "openai" if "Meta" in model["id"] else "third_party",
                "permissions": [],
                "root": model["id"],
                "parent": None,
                "capabilities": {
                    "temperature": model["temperature"],
                    "top_p": model["top_p"],
                },
                "name": model["name"],
                "description": model["description"],
                "available": model["available"],
            }
            for model in models_data.get("models", [])
        ],
    }
    return converted_data

# Upstream timeouts (connect, read) in seconds for chat-completion calls. The
# read timeout bounds the silence between received bytes, not the total
# generation time.
_CHAT_UPSTREAM_TIMEOUT = (10, 300)
# AkashGen image jobs are polled every _IMAGE_POLL_INTERVAL seconds. After
# _IMAGE_POLL_MAX_WAIT seconds the request fails with 504 instead of
# polling forever.
_IMAGE_POLL_INTERVAL = 5
_IMAGE_POLL_MAX_WAIT = 300


async def _run_blocking(func, *args, **kwargs):
    """Run blocking ``func(*args, **kwargs)`` in the default executor and await it."""
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, lambda: func(*args, **kwargs))


def _decode_text_part(msg_data):
    """Decode the payload of an upstream ``0:`` (text) line into plain text.

    Upstream lines have the form ``type:json_data``, so a text payload is a
    JSON string and ``json.loads`` resolves every escape sequence (quotes,
    backslashes, newlines, unicode escapes). Payloads that are not a JSON
    string fall back to the previous manual handling: strip the surrounding
    quotes, unescape quotes and newlines.
    """
    try:
        decoded = json.loads(msg_data)
    except ValueError:
        decoded = None
    if isinstance(decoded, str):
        return decoded
    if msg_data.startswith('"') and msg_data.endswith('"'):
        msg_data = msg_data.replace('\\"', '"')[1:-1]
    return msg_data.replace("\\n", "\n")


def _extract_text(raw_text):
    """Concatenate the decoded text of every ``0:`` line in a buffered upstream body."""
    parts = []
    # Split on "\n" only (like iter_lines on bytes), not str.splitlines(),
    # which would also break on Unicode separators inside the JSON strings.
    for line in raw_text.split("\n"):
        msg_type, sep, msg_data = line.rstrip("\r").partition(":")
        if sep and msg_type == "0":
            parts.append(_decode_text_part(msg_data))
    return "".join(parts)


def _openai_chunk(chat_id, model_name, delta, finish_reason):
    """Build one OpenAI ``chat.completion.chunk`` object."""
    return {
        "id": f"chatcmpl-{chat_id}",
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model_name,
        "choices": [{"delta": delta, "index": 0, "finish_reason": finish_reason}],
    }


def _openai_completion(chat_id, model_name, content):
    """Build a non-streamed OpenAI ``chat.completion`` object (assistant reply)."""
    return {
        "id": f"chatcmpl-{chat_id}",
        "object": "chat.completion",
        "created": int(time.time()),  # Unix seconds, same as the stream chunks.
        "model": model_name,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": content},
                "finish_reason": "stop",
            }
        ],
    }


def _json_reply(payload, status_code):
    """Serialize ``payload`` (non-ASCII kept verbatim) into a JSON ``Response``."""
    return Response(
        content=json.dumps(payload, ensure_ascii=False),
        status_code=status_code,
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Content-Type": "application/json",
        },
    )


def _sse_events(response, chat_id, model_name):
    """Translate the upstream ``type:payload`` line stream into SSE events.

    This is deliberately a plain (synchronous) generator. ``StreamingResponse``
    iterates sync iterators in a worker thread, so the blocking reads inside
    ``iter_lines`` never stall the event loop. The upstream response is closed
    when the generator finishes or is closed, which releases the connection.
    """
    try:
        for line in response.iter_lines():
            if not line:
                continue
            try:
                msg_type, msg_data = line.decode("utf-8").split(":", 1)
            except ValueError as e:  # undecodable bytes or no ":" separator
                logger.error(f"处理响应行时出错: {e}")
                continue

            if msg_type == "0":
                # Text delta.
                delta = {"content": _decode_text_part(msg_data)}
                yield f"data: {json.dumps(_openai_chunk(chat_id, model_name, delta, None))}\n\n"
            elif msg_type in ("e", "d"):
                # Upstream finish markers: emit the stop chunk and terminator.
                yield f"data: {json.dumps(_openai_chunk(chat_id, model_name, {}, 'stop'))}\n\n"
                yield "data: [DONE]\n\n"
                break
    except requests.RequestException as e:
        logger.error(f"读取上游流式响应时出错: {e}")
    finally:
        response.close()


async def _wait_for_image(job_id, headers):
    """Poll the AkashGen status endpoint until the job reports ``completed``.

    Errors on individual polls are logged and retried.

    Returns:
        The ``result`` field of the completed job.

    Raises:
        TimeoutError: if the job has not completed within _IMAGE_POLL_MAX_WAIT seconds.
    """
    deadline = time.monotonic() + _IMAGE_POLL_MAX_WAIT
    while True:
        try:
            img_response = await _run_blocking(
                requests.get,
                "https://chat.akash.network/api/image-status?ids=" + job_id,
                headers=headers,
                timeout=_CHAT_UPSTREAM_TIMEOUT,
            )
            jobs = img_response.json()
            if jobs[0]["status"] == "completed":
                logger.info(f"图片生成完成: {job_id}")
                return jobs[0]["result"]
            logger.debug(f"图片生成中: {job_id}")
        except Exception as e:
            logger.error(f"请求图像状态时出现异常: {e}")

        if time.monotonic() >= deadline:
            raise TimeoutError(job_id)
        await asyncio.sleep(_IMAGE_POLL_INTERVAL)


# Endpoint: OpenAI-compatible chat completions proxied to Akash.
@app.post("/v1/chat/completions")
async def chat_completions(request: Request, authorized: bool = Depends(verify_api_key)):
    """Proxy an OpenAI-style chat-completions request to chat.akash.network.

    The handler works in four steps:

    1. Parse the JSON body.
    2. Resolve ``model`` case-insensitively against the upstream model list.
    3. Forward the conversation to ``/api/chat``.
    4. Translate the reply.

    Text models stream SSE chunks when ``stream`` is true (the default) and
    otherwise return one ``chat.completion`` object. ``AkashGen`` (image
    generation) is never streamed: its job is polled until it completes and
    the image URL is returned as Markdown. All blocking ``requests`` calls run
    off the event loop and have timeouts.

    Returns:
        A ``StreamingResponse`` (SSE) or a JSON response. Errors use these
        status codes:
        - 400 for a malformed body, missing model or unknown model.
        - The upstream status when the upstream chat call fails.
        - 504 if an image job does not finish in time.
        - 500 for anything else.
    """
    logger.info("开始处理聊天完成请求")
    try:
        try:
            body = await request.json()
        except ValueError:  # invalid JSON or undecodable bytes
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={"error": "Invalid JSON body"}
            )
        if not isinstance(body, dict):
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={"error": "Request body must be a JSON object"}
            )

        # Validate the model parameter before doing any upstream work.
        requested_model = body.get("model")
        if not requested_model:
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={"error": "Missing model parameter"}
            )

        cookie = await get_session_cookie()
        headers = {"Content-Type": "application/json", "Cookie": cookie}

        # Map the requested name onto the exact upstream id, ignoring case.
        models_response = await _run_blocking(
            requests.get,
            "https://chat.akash.network/api/models",
            headers=headers,
            timeout=_CHAT_UPSTREAM_TIMEOUT,
        )
        models_data = models_response.json()
        available_models = [model["id"] for model in models_data.get("models", [])]
        wanted = str(requested_model).lower()
        matched_model = next((m for m in available_models if m.lower() == wanted), None)
        if not matched_model:
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={"error": f"Model '{requested_model}' not found"}
            )

        chat_id = str(uuid.uuid4()).replace("-", "")[:16]

        # Request body in the format the Akash chat API expects.
        akash_data = {
            "id": chat_id,
            "messages": body.get("messages", []),
            "model": matched_model,  # the exact upstream id, not the user's spelling
            "system": body.get("system_message", "You are a helpful assistant."),
            "temperature": body.get("temperature", 0.6),
            "topP": body.get("top_p", 0.95),
        }

        # AkashGen (image generation) is never streamed. Compare the resolved
        # model so any casing of the name is detected.
        is_image_model = matched_model.lower() == "akashgen"
        stream = body.get("stream", True) is True and not is_image_model

        response = await _run_blocking(
            requests.post,
            "https://chat.akash.network/api/chat",
            json=akash_data,
            headers=headers,
            stream=stream,
            timeout=_CHAT_UPSTREAM_TIMEOUT,
        )

        # Surface upstream failures instead of an empty 200 stream/completion.
        if response.status_code >= 400:
            error_text = response.text
            response.close()
            logger.error(f"Akash API返回错误状态 {response.status_code}: {error_text}")
            return JSONResponse(
                status_code=response.status_code,
                content={"error": error_text}
            )

        if stream:
            # Do not touch response.text here: that would download the whole
            # body before the first chunk is sent.
            return StreamingResponse(
                _sse_events(response, chat_id, requested_model),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "Content-Type": "text/event-stream",
                },
            )

        raw_text = response.text
        logger.debug("Akash API响应: %s", raw_text)

        if not is_image_model:
            response_data = _openai_completion(chat_id, requested_model, _extract_text(raw_text))
            logger.debug(json.dumps(response_data, ensure_ascii=False))
            return _json_reply(response_data, response.status_code)

        # AkashGen: the reply carries jobId='...'; poll that job for the image.
        match = re.search(r"jobId='([^']+)'", raw_text)
        if not match:
            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content={"error": "当前官方服务异常"}
            )
        job_id = match.group(1)
        logger.info(f"获取到图像生成任务ID: {job_id}")

        try:
            image_url = await _wait_for_image(job_id, headers)
        except TimeoutError:
            logger.error(f"图片生成超时: {job_id}")
            return JSONResponse(
                status_code=status.HTTP_504_GATEWAY_TIMEOUT,
                content={"error": "图片生成超时"}
            )

        content = f"根据您的描述,这里是一张生成的图片:\n\n![生成的图片]({image_url})"
        return _json_reply(
            _openai_completion(chat_id, requested_model, content),
            response.status_code,
        )

    except Exception as e:
        logger.exception(f"处理聊天完成请求时发生错误: {e}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"error": str(e)}
        )

# Entry point: serve the app with uvicorn on 0.0.0.0:5200 with auto-reload.
if __name__ == "__main__":
    import uvicorn

    # With reload=True uvicorn must be given an import string ("module:app")
    # rather than the app object. Derive the module name from this file
    # instead of hard-coding "app_fastapi", so the script still starts after
    # a rename.
    module_name = os.path.splitext(os.path.basename(__file__))[0]
    logger.info("启动FastAPI服务器,监听端口5200")
    uvicorn.run(f"{module_name}:app", host="0.0.0.0", port=5200, reload=True)