File size: 13,425 Bytes
dba6b4c 23f04f9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 |
import asyncio
import json
import logging
import os
import re
import secrets
import time
import uuid
from typing import List, Optional

import requests
from fastapi import Depends, FastAPI, HTTPException, Request, Response, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)
# 设置API密钥
API_KEY = os.getenv("API_KEY", "linux.do")
# 创建FastAPI应用
app = FastAPI()
# 添加CORS中间件
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 定义请求模型
class ChatMessage(BaseModel):
role: str
content: str
class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
temperature: Optional[float] = 0.6
top_p: Optional[float] = 0.95
stream: Optional[bool] = True
system_message: Optional[str] = "You are a helpful assistant."
# 获取会话cookie的函数
async def get_session_cookie():
try:
response = requests.get(
"https://chat.akash.network/api/auth/session",
headers={"Content-Type": "application/json"}
)
if response.cookies:
return '; '.join([f"{cookie.name}={cookie.value}" for cookie in response.cookies])
return ""
except Exception as e:
logger.error(f"获取会话cookie失败: {e}")
return ""
# API密钥验证依赖项
async def verify_api_key(request: Request):
auth_header = request.headers.get("Authorization")
if auth_header != "Bearer " + API_KEY:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={"success": False, "message": "Unauthorized: Invalid API key"}
)
return True
# 获取模型列表的端点
@app.get("/v1/models")
async def get_models(authorized: bool = Depends(verify_api_key)):
# 获取最新的会话cookie
cookie = await get_session_cookie()
# 构建请求头
headers = {"Content-Type": "application/json", "Cookie": cookie}
# 发送请求到Akash API
response = requests.get("https://chat.akash.network/api/models", headers=headers)
models_data = response.json()
current_timestamp = int(time.time())
converted_data = {
"object": "list",
"data": [
{
"id": model["id"],
"object": "model",
"created": current_timestamp,
"owned_by": "openai" if "Meta" in model["id"] else "third_party",
"permissions": [],
"root": model["id"],
"parent": None,
"capabilities": {
"temperature": model["temperature"],
"top_p": model["top_p"],
},
"name": model["name"],
"description": model["description"],
"available": model["available"],
}
for model in models_data.get("models", [])
],
}
return converted_data
# 聊天完成端点
@app.post("/v1/chat/completions")
async def chat_completions(request: Request, authorized: bool = Depends(verify_api_key)):
logger.info("开始处理聊天完成请求")
try:
# 获取请求体
body = await request.json()
# 获取最新的会话cookie
cookie = await get_session_cookie()
# 检查模型参数
if not body.get("model"):
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"error": "Missing model parameter"}
)
# 获取可用模型列表
models_response = requests.get(
"https://chat.akash.network/api/models",
headers={"Content-Type": "application/json", "Cookie": cookie}
)
models_data = models_response.json()
available_models = [model["id"] for model in models_data.get("models", [])]
# 模型名称修正
requested_model = body.get("model")
matched_model = next((model for model in available_models if model.lower() == requested_model.lower()), None)
if not matched_model:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"error": f"Model '{requested_model}' not found"}
)
# 生成唯一ID
chat_id = str(uuid.uuid4()).replace("-", "")[:16]
# 构建Akash格式的请求数据
akash_data = {
"id": chat_id,
"messages": body.get("messages", []),
"model": matched_model, # 使用修正后的模型名称
"system": body.get("system_message", "You are a helpful assistant."),
"temperature": body.get("temperature", 0.6),
"topP": body.get("top_p", 0.95),
}
# 构建请求头
headers = {"Content-Type": "application/json", "Cookie": cookie}
_stream = body.get("stream", True)
# AkashGen模型特殊处理
if body.get("model", "DeepSeek-R1") == "AkashGen":
_stream = False
# 发送请求到Akash API
response = requests.post(
"https://chat.akash.network/api/chat",
json=akash_data,
headers=headers,
stream=_stream,
)
logger.debug(f"Akash API响应: {response.text}")
# 处理流式响应
if _stream is True:
async def generate():
content_buffer = ""
for line in response.iter_lines():
if not line:
continue
try:
# 解析行数据,格式为 "type:json_data"
line_str = line.decode("utf-8")
msg_type, msg_data = line_str.split(":", 1)
# 处理内容类型的消息
if msg_type == "0":
# 只去掉两边的双引号
if msg_data.startswith('"') and msg_data.endswith('"'):
msg_data = msg_data.replace('\\"', '"')
msg_data = msg_data[1:-1]
msg_data = msg_data.replace("\\n", "\n")
content_buffer += msg_data
# 构建 OpenAI 格式的响应块
chunk = {
"id": f"chatcmpl-{chat_id}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": body.get("model", "DeepSeek-R1"),
"choices": [
{
"delta": {"content": msg_data},
"index": 0,
"finish_reason": None,
}
],
}
yield f"data: {json.dumps(chunk)}\n\n"
# 处理结束消息
elif msg_type in ["e", "d"]:
chunk = {
"id": f"chatcmpl-{chat_id}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": body.get("model", "DeepSeek-R1"),
"choices": [
{"delta": {}, "index": 0, "finish_reason": "stop"}
],
}
yield f"data: {json.dumps(chunk)}\n\n"
yield "data: [DONE]\n\n"
break
except Exception as e:
logger.error(f"处理响应行时出错: {e}")
continue
return StreamingResponse(
generate(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Content-Type": "text/event-stream",
},
)
else:
# 处理非流式响应
if body.get("model", "DeepSeek-R1") != "AkashGen":
text_matches = re.findall(r'0:"(.*?)"', response.text)
parsed_text = "".join(text_matches)
response_data = {
"object": "chat.completion",
"created": int(time.time() * 1000),
"model": body.get("model", "DeepSeek-R1"),
"choices": [
{
"index": 0,
"message": {"role": "user", "content": parsed_text},
"finish_reason": "stop",
}
],
}
logger.debug(json.dumps(response_data, ensure_ascii=False))
return Response(
content=json.dumps(response_data, ensure_ascii=False),
status_code=response.status_code,
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Content-Type": "application/json",
},
)
else:
# 处理AkashGen模型(图像生成)
match = re.search(r"jobId='([^']+)'", response.text)
job_id = None
if match:
job_id = match.group(1)
logger.info(f"获取到图像生成任务ID: {job_id}")
# 轮询图像生成状态
async def check_image_status():
while True:
try:
_img_response = requests.get(
"https://chat.akash.network/api/image-status?ids=" + job_id,
headers=headers,
)
_data = _img_response.json()
# 检查图像是否生成完成
if _data[0]["status"] == "completed":
logger.info(f"图片生成完成: {job_id}")
return _data[0]["result"]
else:
logger.debug(f"图片生成中: {job_id}")
except Exception as e:
logger.error(f"请求图像状态时出现异常: {e}")
# 每隔5秒请求一次
await asyncio.sleep(5)
# 等待图像生成完成
image_url = await check_image_status()
# 返回包含图像URL的响应
response_data = {
"object": "chat.completion",
"created": int(time.time() * 1000),
"model": body.get("model", "DeepSeek-R1"),
"choices": [
{
"index": 0,
"message": {
"role": "user",
"content": f"根据您的描述,这里是一张生成的图片:\n\n",
},
"finish_reason": "stop",
}
],
}
return Response(
content=json.dumps(response_data, ensure_ascii=False),
status_code=response.status_code,
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Content-Type": "application/json",
},
)
else:
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={"error": "当前官方服务异常"}
)
except Exception as e:
logger.error(f"处理聊天完成请求时发生错误: {e}")
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={"error": str(e)}
)
# 启动服务器
if __name__ == "__main__":
import uvicorn
logger.info("启动FastAPI服务器,监听端口8080")
uvicorn.run("app:app", host="0.0.0.0", port=8080, reload=True) |