|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException |
|
from fastapi.staticfiles import StaticFiles |
|
from fastapi.responses import HTMLResponse, FileResponse |
|
import uvicorn |
|
import json |
|
import asyncio |
|
import os |
|
from pathlib import Path |
|
from datetime import datetime |
|
from bw_utils import is_image, load_json_file |
|
from BookWorld import BookWorld |
|
|
|
app = FastAPI()

# Fallback avatar for messages/files whose icon is missing or not an image.
# NOTE(review): this is a relative path, so it resolves against the current
# working directory, not against this file's directory.
default_icon_path = './frontend/assets/images/default-icon.jpg'

# Server configuration, read relative to the current working directory.
config = load_json_file('config.json')

# Export every config entry whose name contains "API_KEY" as an environment
# variable. Empty/None placeholders are skipped so they do not clobber a key that
# is already exported in the shell, and values are coerced to str because
# os.environ only accepts strings (a JSON null/number previously raised TypeError).
for key, value in config.items():
    if "API_KEY" in key and value:
        os.environ[key] = str(value)

# Serve the frontend assets from the "frontend" directory next to this file.
static_file_abspath = os.path.abspath(os.path.join(os.path.dirname(__file__), 'frontend'))
app.mount("/frontend", StaticFiles(directory=static_file_abspath), name="frontend")
|
|
|
class ConnectionManager:
    """Track WebSocket clients and run one story-generation task per client.

    A single BookWorld instance (``self.bw``) is built from ``config`` at
    construction time and is shared by every connected client.
    """

    def __init__(self):
        # client_id -> accepted WebSocket connection.
        self.active_connections: dict[str, WebSocket] = {}
        # client_id -> background task running generate_story().
        self.story_tasks: dict[str, asyncio.Task] = {}

        # Developer toggle: change to False to run against the BookWorld_test stub.
        if True:
            # Prefer an explicit, existing preset file; otherwise derive one from the genre.
            if "preset_path" in config and config["preset_path"] and os.path.exists(config["preset_path"]):
                preset_path = config["preset_path"]
            elif "genre" in config and config["genre"]:
                genre = config["genre"]
                preset_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), f"./config/experiment_{genre}.json")
            else:
                raise ValueError("Please set the preset_path in `config.json`.")

            self.bw = BookWorld(preset_path=preset_path,
                                world_llm_name=config["world_llm_name"],
                                role_llm_name=config["role_llm_name"],
                                embedding_name=config["embedding_model_name"])
            self.bw.set_generator(rounds=config["rounds"],
                                  save_dir=config["save_dir"],
                                  if_save=config["if_save"],
                                  mode=config["mode"],
                                  scene_mode=config["scene_mode"])
        else:
            from BookWorld_test import BookWorld_test
            self.bw = BookWorld_test()

    async def connect(self, websocket: WebSocket, client_id: str):
        """Accept ``websocket`` and register it under ``client_id``.

        A previous connection registered under the same id is replaced.
        """
        await websocket.accept()
        self.active_connections[client_id] = websocket

    def disconnect(self, client_id: str):
        """Unregister ``client_id`` and cancel its story task, if any."""
        self.active_connections.pop(client_id, None)
        self.stop_story(client_id)

    def stop_story(self, client_id: str):
        """Cancel and forget the story task of ``client_id``; no-op if there is none."""
        task = self.story_tasks.pop(client_id, None)
        if task is not None:
            task.cancel()

    async def start_story(self, client_id: str):
        """(Re)start story generation for ``client_id`` as a background task.

        Any task already registered for this client is cancelled first, so at most
        one generator task is tracked per client.
        """
        self.stop_story(client_id)
        self.story_tasks[client_id] = asyncio.create_task(
            self.generate_story(client_id)
        )

    async def generate_story(self, client_id: str):
        """Coroutine that keeps generating story messages for ``client_id``.

        Each iteration sends a ``message`` frame and a ``status_update`` frame, then
        sleeps one second. Runs until the client is no longer registered, the task
        is cancelled, or an error occurs.
        """
        try:
            while True:
                # Look the socket up once per iteration so both sends use the same
                # connection (a second dict lookup could raise KeyError mid-iteration).
                websocket = self.active_connections.get(client_id)
                if websocket is None:
                    break
                message, status = await self.get_next_message()
                await websocket.send_json({
                    'type': 'message',
                    'data': message
                })
                await websocket.send_json({
                    'type': 'status_update',
                    'data': status
                })
                await asyncio.sleep(1)
        except asyncio.CancelledError:
            print(f"Story generation cancelled for client {client_id}")
            # Re-raise so the task is really marked as cancelled (swallowing
            # CancelledError makes the task look like it finished normally).
            raise
        except Exception as e:
            print(f"Error in generate_story: {e}")
        finally:
            # When the loop ends on its own (disconnect/error) drop our stale entry.
            # stop_story() already removed it on cancellation, and a newer task may
            # have replaced it, so only delete the entry if it is still this task.
            if self.story_tasks.get(client_id) is asyncio.current_task():
                del self.story_tasks[client_id]

    async def get_initial_data(self):
        """Collect the data a freshly connected client needs to render the world."""
        return {
            'characters': self.bw.get_characters_info(),
            'map': self.bw.get_map_info(),
            'settings': self.bw.get_settings_info(),
            'status': self.bw.get_current_status(),
            'history_messages': self.bw.get_history_messages(save_dir=config["save_dir"]),
        }

    async def get_next_message(self):
        """Fetch the next message from BookWorld together with the current status.

        Returns:
            ``(message, status)``. ``message["icon"]`` is replaced by the default
            icon when it is missing, empty, or not recognised as an image.
        """
        # NOTE(review): generate_next_message() is called synchronously, so the
        # event loop (and every client's incoming control messages) is blocked
        # until it returns. If it is slow, consider moving it off the loop.
        message = self.bw.generate_next_message()
        icon = message.get("icon")
        if not icon or not is_image(icon):
            message["icon"] = default_icon_path
        status = self.bw.get_current_status()
        return message, status
|
|
|
# Process-wide manager: every WebSocket connection shares it (and its BookWorld).
manager: ConnectionManager = ConnectionManager()
|
|
|
@app.get("/")
async def get():
    """Serve the single-page frontend.

    ``index.html`` is read from the current working directory on every request.
    It is decoded explicitly as UTF-8. Without an explicit encoding, ``read_text``
    uses the locale's default (e.g. GBK or cp1252 on Windows), which can raise
    UnicodeDecodeError or garble any non-ASCII content.
    """
    html_file = Path("index.html")
    return HTMLResponse(html_file.read_text(encoding="utf-8"))
|
|
|
@app.get("/data/{full_path:path}")
async def get_file(full_path: str):
    """Serve a file from an allowed data directory, falling back to the default icon.

    ``full_path`` comes straight from the URL. Joining it onto a base directory
    without checks lets ``..`` segments or an absolute path escape that directory
    and read arbitrary files. The joined path is therefore normalised and must
    stay inside the base directory before it is served.

    Returns:
        A FileResponse for the first base directory that contains the file;
        otherwise a FileResponse for the default icon.

    Raises:
        HTTPException(404): if neither the file nor the default icon exists.
    """
    base_paths = [
        Path("/data/"),
    ]

    for base_path in base_paths:
        # Lexical normalisation collapses "..", so the containment check below
        # cannot be bypassed with traversal segments; an absolute full_path
        # replaces the base entirely and is rejected by the same check.
        file_path = Path(os.path.normpath(base_path / full_path))
        try:
            file_path.relative_to(base_path)
        except ValueError:
            continue  # outside this base directory: never serve it
        if file_path.is_file():
            return FileResponse(file_path)

    # Missing files fall back to the default icon, as before (presumably so broken
    # image links still render something). Every base path is checked first.
    if os.path.isfile(default_icon_path):
        return FileResponse(default_icon_path)

    raise HTTPException(status_code=404, detail="File not found")
|
|
|
@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
    """Serve one client's WebSocket session.

    Sends the initial world data once, then dispatches incoming JSON messages on
    their ``type`` field until the client disconnects. On exit the client is
    always unregistered and its story task cancelled.
    """
    await manager.connect(websocket, client_id)
    try:
        initial_data = await manager.get_initial_data()
        await websocket.send_json({
            'type': 'initial_data',
            'data': initial_data
        })

        while True:
            data = await websocket.receive_text()
            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                # One malformed frame should not tear down the whole session.
                print(f"Ignoring malformed message from client {client_id}")
                continue
            if not isinstance(message, dict):
                continue
            msg_type = message.get('type')

            if msg_type == 'user_message':
                # Echo the user's message back so it appears in the chat log.
                await websocket.send_json({
                    'type': 'message',
                    'data': {
                        'username': 'User',
                        'timestamp': message['timestamp'],
                        'text': message['text'],
                        'icon': default_icon_path,
                    }
                })

            elif msg_type == 'control':
                action = message['action']
                if action == 'start':
                    await manager.start_story(client_id)
                elif action in ('pause', 'stop'):
                    # Pause and stop currently behave the same: cancel the story task.
                    manager.stop_story(client_id)

            elif msg_type == 'edit_message':
                edit_data = message['data']
                manager.bw.handle_message_edit(
                    record_id=edit_data['uuid'],
                    new_text=edit_data['text']
                )

            elif msg_type == 'request_scene_characters':
                manager.bw.select_scene(message['scene'])
                scene_characters = manager.bw.get_characters_info()
                await websocket.send_json({
                    'type': 'scene_characters',
                    'data': scene_characters
                })

            elif msg_type == 'generate_story':
                story_text = manager.bw.generate_story()
                await websocket.send_json({
                    'type': 'message',
                    'data': {
                        'username': 'System',
                        'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                        'text': story_text,
                        'icon': default_icon_path,
                        'type': 'story'
                    }
                })

            elif msg_type == 'request_api_configs':
                # NOTE(review): API_CONFIGS is neither defined nor imported in this
                # module, so referencing it directly raised NameError and dropped the
                # connection. Fall back to an empty mapping until it is provided.
                await websocket.send_json({
                    'type': 'api_configs',
                    'data': globals().get('API_CONFIGS', {})
                })

            elif msg_type == 'api_settings':
                settings = message['data']
                env_key = settings['envKey']
                api_key = settings['apiKey']

                # The variable name comes from the client. Accept only well-formed
                # names containing "API_KEY" (the same rule applied when exporting
                # keys from config.json), with string values. Otherwise any client
                # could overwrite arbitrary variables such as PATH or LD_PRELOAD.
                if not (isinstance(env_key, str) and env_key.isidentifier()
                        and "API_KEY" in env_key and isinstance(api_key, str)):
                    await websocket.send_json({
                        'type': 'message',
                        'data': {
                            'username': 'System',
                            'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                            'text': 'API设置未更新：envKey 无效',
                            'icon': default_icon_path,
                            'type': 'system'
                        }
                    })
                    continue

                os.environ[env_key] = api_key
                manager.bw.update_api_settings(
                    provider=settings['provider'],
                    model=settings['model']
                )
                await websocket.send_json({
                    'type': 'message',
                    'data': {
                        'username': 'System',
                        'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                        'text': f'已更新 {settings["provider"]} API设置',
                        'icon': default_icon_path,
                        'type': 'system'
                    }
                })
    except WebSocketDisconnect:
        # Normal client close; cleanup happens in `finally`.
        pass
    except Exception as e:
        print(f"WebSocket error: {e}")
    finally:
        manager.disconnect(client_id)
|
|
|
if __name__ == "__main__":
    # uvicorn only supports reload=True when the app is passed as an import
    # string rather than as the app object itself.
    uvicorn.run(
        "server:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
    )
|
|