import ollama
import gradio as gr
import numpy as np
import json
from tts_api import TTSapi, DEFAULT_TTS_MODEL_NAME
from config import *
from utils import *
from knowledge_base import LocalRAG, CosPlayer


def handle_retry(history, thinking_history, config, section_state, retry_data: gr.RetryData):
    # Fetch the user's earlier message
    previous_message = history[retry_data.index]['content']
    # Drop the replies and thinking traces that followed it
    new_history = history[:retry_data.index]
    section_state['chat_history'] = section_state['chat_history'][:retry_data.index + 1]
    try:
        items = thinking_history.split('\n==================\n')
        if len(items) > 2:
            new_thinking_history = '\n==================\n'.join(items[:-2])
        else:
            new_thinking_history = ''
        items = section_state['thinking_history'].split('\n==================\n')
        if len(items) > 2:
            section_state['thinking_history'] = '\n==================\n'.join(items[:-2])
        else:
            section_state['thinking_history'] = ''
    except Exception as e:
        print('-----------------------------------')
        print(e)
        print('-----------------------------------')
        print('思考过程发生异常,重置为空')
        section_state['thinking_history'] = ''
        new_thinking_history = ''
    # Regenerate the reply
    return predict(previous_message, new_history, new_thinking_history, config, section_state)


def predict(message, chat_history, thinking_history, config, section_state):
    global local_rag, TTS_LOADED, synthesiser

    print(config)
    print(f"当前模式:{config['mode_selected']}")
    print(f'角色扮演描述:{config["character_description"]}')
    print(f"写入角色设定方式:{config['character_setting_mode']}")
    print(f"选中LLM:{config['llm_model']}")
    print(f"是否使用RAG本地知识库:{config['kb_on']}")
    print(f"选中知识库:{config['current_knowledge_base']}")
    print(f"是否联网搜索:{config['net_on']}")
    print(f"选中TTS模型:{config['tts_model']}")
    print(f"是否合成语音:{config['tts_on']}")
    print(f"参考音频路径:{config['ref_audio']}")
    print(f"参考音频文本:{config['ref_audio_transcribe']}")

    context = ''
    net_search_res = []
    docs = []
    if config['kb_on'] and len(config['current_knowledge_base']) > 0:
        # Retrieve similar documents from the local knowledge base
        doc_and_scores = local_rag.vector_db.similarity_search(message, k=local_rag.rag_top_k)
        # doc_and_scores = list(filter(lambda x: x[1] <= 0.4, doc_and_scores))
        if len(doc_and_scores) > 0:
            docs, scores = list(zip(*doc_and_scores))
            docs, scores = list(docs), list(scores)
            context_local = "【本地知识库】" + "\n".join([concate_metadata(d.metadata) + d.page_content for d in docs])
            context = context + context_local
    if config['net_on']:
        # Retrieve web search results
        ret = web_search(message, max_results=MAX_RESULTS)
        net_search_res = parse_net_search(ret)
        context_net = "\n【网络搜索结果】" + ''.join(net_search_res)
        context = context + context_net

    if config['character_description']:
        if config['character_setting_mode'] == 'by system':
            if len(section_state['chat_history']) == 0 or section_state['chat_history'][0]['role'] != 'system':
                section_state['chat_history'].insert(0, {"role": "system", "content": config["character_description"]})
        elif config['character_setting_mode'] == 'by prompt':
            if len(section_state['chat_history']) > 0 and section_state['chat_history'][0]['role'] == 'system':
                section_state['chat_history'].pop(0)
            # Prepend the core character setting to the retrieved context
            context = f'【系统核心设定】:{config["character_description"]}\n' + context
        else:
            raise ValueError(f"未知的角色设定模式:{config['character_setting_mode']}")

    if len(context) > 0:
        prompt = f"""请充分理解以下上下文信息,并结合当前及历史对话产生回复:
上下文:{context}
用户当前输入:{message}
回复:
"""
        input_message = section_state["chat_history"] + [{"role": "user", "content": prompt}]
    else:
        input_message = section_state["chat_history"] + [{"role": "user", "content": message}]

    # Disable the Qwen3 series' default thinking mode
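    # Qwen3 honors a '/no_think' soft switch appended to the user prompt, which
    # suppresses its chain-of-thought block for that turn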
    if config['llm_model'].startswith('qwen3'):
        input_message[-1]['content'] += '/no_think'

    # Append the user message to the session history
    section_state["chat_history"].append({"role": "user", "content": message})

    # Count the tokens of the current context and size ollama's context window
    # dynamically to work around its default limit
    try:
        tokenizer = load_tokenizer(config['llm_model'])
    except Exception as e:
        if config['llm_model'] in BASE_MODEL_TABLE:
            tokenizer = load_tokenizer(BASE_MODEL_TABLE[config['llm_model']])
        else:
            raise e
    token_cnt = count_tokens_local(input_message, tokenizer)
    if token_cnt >= MAX_MODEL_CTX:
        gr.Warning("当前对话已经超出模型上下文长度,请开启新会话...")

    # Fallback audio so the return below is well-defined even on errors
    audio_output = np.array([0], dtype=np.int16)
    try:
        # Call the model
        response = ollama.chat(
            model=config['llm_model'],
            messages=input_message,
            stream=False,
            options={'num_ctx': min(int(token_cnt * 1.2), MAX_MODEL_CTX)}
        )
        # Parse the response
        thinking, response_content = parse_output(response['message']['content'])

        # Update the visible chat history
        chat_history.append({'role': 'user', 'content': message})
        if len(context) > 0:
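            # gr.Chatbot(type='messages') renders HTML in message content, so the
            # retrieved snippets are wrapped in a <details> block that stays
            # collapsed until expanded; the plain answer follows underneath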
            # Build the reply with a collapsible retrieval section
            retrieved = "\n".join(
                ["\n".join(wash_up_content(content if isinstance(content, str)
                                           else (content.page_content, scores[idx])))
                 for idx, content in enumerate(docs + net_search_res)]
            )
            formatted_response = f"""<details>
<summary>🔍 检索完成✅(共{len(docs) + len(net_search_res)}条)</summary>

{retrieved}
</details>

{response_content}
"""
            chat_history.append({'role': 'assistant', 'content': formatted_response})
        else:
            chat_history.append({'role': 'assistant', 'content': response_content})
        thinking_history += f"User: {message}\nThinking: {thinking}" + '\n==================\n'

        # Append the assistant response to the session history
        section_state["chat_history"].append({"role": "assistant", "content": response_content})
        section_state["thinking_history"] += f"User: {message}\nThinking: {thinking}" + '\n==================\n'

        if (not config['tts_on']) or len(response_content) == 0:
            audio_output = np.array([0], dtype=np.int16)
            if len(response_content) == 0:
                print("LLM 回复为空,无法合成语音")
        else:
            if not TTS_LOADED:
                print('TTS模型首次加载...')
                gr.Info("初次加载TTS模型,请稍候..", duration=63)
                synthesiser = TTSapi(model_name=config['tts_model'])
                TTS_LOADED = True
                print('加载完毕...')
            # Check whether the loaded TTS model matches the selection
            if config['tts_model'] != synthesiser.model_name:
                print(f'当前TTS模型{synthesiser.model_name}非所选,重新加载')
                synthesiser.reload(model_name=config['tts_model'])
            # If a reference audio clip is provided, its transcript has to be
            # prepended to response_content as the synthesis prefix
            if config['ref_audio']:
                prompt_text = config['ref_audio_transcribe']
                if prompt_text is None:
                    # prompt_text = ...
                    # TODO: consider adding an ASR model later
                    raise NotImplementedError('暂时必须提供文本')
                response_content = prompt_text + response_content
            audio_output = synthesiser.forward(response_content, speech_prompt=config['ref_audio'])
    except Exception as e:
        print('!!!!!!!!')
        print(e)
        print('!!!!!!!!')
        error_msg = f"Error: {str(e)}"
        # The chatbot is in 'messages' format, so append a role/content dict
        # rather than a (user, bot) tuple
        chat_history.append({'role': 'assistant', 'content': error_msg})
        thinking_history += f"Error occurred: {str(e)}" + '\n'

    return "", chat_history, thinking_history, (synthesiser.sr if synthesiser else 16000, audio_output)


def init_model(init_llm=False, init_rag=False, init_tts=False):
    if init_llm:
        print(f'正在加载LLM:{DEFAULT_MODEL_NAME}...')
        ollama.chat(model=DEFAULT_MODEL_NAME, messages=[])
    if init_rag:
        gr.Info("正在加载知识库,请稍候...")
        local_rag = LocalRAG(rag_top_k=RAG_TOP_K)
    else:
        local_rag = None
    if init_tts:
        print(f'正在加载TTS模型:{DEFAULT_TTS_MODEL_NAME}...')
        synthesiser = TTSapi()
        TTS_LOADED = True
    else:
        synthesiser = None
        TTS_LOADED = False
    return local_rag, synthesiser, TTS_LOADED


if __name__ == "__main__":
    import time

    st = time.time()
    print('********************模型加载中************************')
    local_rag, synthesiser, TTS_LOADED = init_model()
    print('********************模型加载完成************************')
    print('耗时:', time.time() - st)

    state = {}
    resp, state = log_in(0, state)
    cosplayer = CosPlayer(description_file=DEFAULT_COSPLAY_SETTING)

    print("===== 初始化开始 =====")
    with gr.Blocks(css=CSS, title="LLM Chat Demo", theme=gr.themes.Soft(font=["sans-serif", "Arial"])) as demo:
        gr.Markdown("""
# LLM Chat Demo
## 用法介绍
### 用户登录
* 输入用户名,点击Log In按钮。首次登录会自动创建用户目录,聊天记录会保存在下面,如不登录,默认为公共目录'0'
### 模型选择
目前支持Qwen、Deepseek-R1蒸馏系列等部分模型,可下拉菜单选择
### 高级设置
* 模式选择:可以选择角色扮演模式/普通模式
* 角色设定选择:支持加载不同角色设定文件
* 角色配置方式:
    * by system: 角色设定将作为system prompt存在于输入首部
    * by prompt: 角色设定每次被添加到当前上下文中
* 知识库配置: 支持自由选择、组合知识库
""")

        section_state = gr.State(value=state)  # per-session state object

        with gr.Row():
            uid_input = gr.Textbox(label="Type Your UID:")
            response = gr.Textbox(label='', value=resp)
            login_button = gr.Button("Log In")

        llm_select = gr.Dropdown(label="模型选择", choices=AVALIABLE_MODELS, value=DEFAULT_MODEL_NAME, visible=True)

        gr.Markdown("## 高级设置")
        with gr.Accordion("点击展开折叠", open=False, visible=True):
            mode_select = gr.Radio(label='模式选择', choices=SUPPORT_MODES, value=DEFAULT_MODE)
            coser_select = gr.Dropdown(label="角色设定选择", choices=cosplayer.get_all_characters(),
                                       value=DEFAULT_COSPLAY_SETTING, visible=True)
            coser_setting = gr.Radio(label='角色配置方式', choices=CHARACTER_SETTING_MODES,
                                     value=DEFAULT_C_SETTING_MODE, visible=True)
            kb_select = gr.Dropdown(label="知识库配置", choices=AVALIABLE_KNOWLEDGE_BASE, value=None,
                                    visible=True, multiselect=True)

        with gr.Row():
            # Left side of the page
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(label="对话记录", height=500, show_copy_button=True, type='messages')
                with gr.Row():
                    msg = gr.Textbox(label="输入消息", placeholder="请输入您的问题...", scale=7)
                    with gr.Column(scale=1, min_width=15):
                        with gr.Row():
                            rag_switch = gr.Checkbox(label="本地RAG", value=False, info="")
                            net_switch = gr.Checkbox(label="联网搜索", value=False, info="")
                        submit_btn = gr.Button("发送", variant="primary", min_width=15)  # elem_classes=['custom-btn']
                with gr.Row():
                    gr.Examples(
                        examples=[[example] for example in EXAMPLES],
                        inputs=msg,
                        outputs=chatbot,
                        fn=predict,
                        visible=True,
                        cache_examples=False
                    )
                with gr.Row():
                    save_btn = gr.Button("保存对话")
                    clear_btn = gr.Button("清空对话")
                    chat_history_select = gr.Dropdown(label='加载历史对话', choices=state['available_history'],
                                                      visible=True, interactive=True)
            # Right side of the page
            with gr.Column(scale=2):
                thinking_display = gr.TextArea(label="思考过程", interactive=False,
                                               placeholder="模型思考过程将在此显示...")
                tts_switch = gr.Checkbox(label="TTS开关", value=False, info="Check me to hear voice")
                with gr.Tabs() as audio_tabs:
                    # Tab 1: audio playback
                    with gr.Tab("音频输出", id="audio_output"):
                        audio_player = gr.Audio(
                            label="听听我声音~",
                            type="numpy",
                            interactive=False
                        )
                    # Tab 2: TTS configuration
                    with gr.Tab("TTS配置", id="tts_config"):
                        # TTS model selection
                        tts_model = gr.Dropdown(
                            label="选择TTS模型",
                            choices=["ChatTTS", "HKUSTAudio/LLasa-1B", "HKUSTAudio/LLasa-3B", "HKUSTAudio/LLasa-8B"],
                            value=DEFAULT_TTS_MODEL_NAME,
                            interactive=True
                        )
                        # Reference audio upload
                        ref_audio = gr.Audio(
                            label="上传参考音频",
                            type="filepath",
                            interactive=True
                        )
                        ref_audio_transcribe = gr.Textbox(label="参考音频对应文本", visible=True)

        # ================= State management =================
        current_config = gr.State({
            "llm_model": DEFAULT_MODEL_NAME,
            "tts_model": DEFAULT_TTS_MODEL_NAME,
            "tts_on": False,
            "kb_on": False,
            "net_on": False,
            "ref_audio": None,
            "ref_audio_transcribe": None,
            "mode_selected": DEFAULT_MODE,
            "character_description": cosplayer.get_core_setting(),
            "character_setting_mode": DEFAULT_C_SETTING_MODE,
            "current_knowledge_base": AVALIABLE_KNOWLEDGE_BASE[0]
        })

        # Event handling
        login_button.click(log_in, inputs=[uid_input, section_state], outputs=[response, section_state])

        gr.on(triggers=[llm_select.change, tts_model.change, ref_audio.change, ref_audio_transcribe.change,
                        tts_switch.select, rag_switch.select, net_switch.select, mode_select.change],
              fn=lambda model1, model2, audio, text, tts_on, kb_on, net_on, mode, character_setting, kb_select: {
                  "llm_model": model1,
                  "tts_model": model2,
                  "ref_audio": audio,
                  "ref_audio_transcribe": text,
                  "tts_on": tts_on,
                  "kb_on": kb_on,
                  "net_on": net_on,
                  "mode_selected": mode,
                  "character_description": None if mode == '普通模式' else cosplayer.get_core_setting(),
                  "character_setting_mode": character_setting,
                  "current_knowledge_base": kb_select},
              inputs=[llm_select, tts_model, ref_audio, ref_audio_transcribe, tts_switch, rag_switch,
                      net_switch, mode_select, coser_setting, kb_select],
              outputs=current_config)

        msg.submit(
            predict,
            [msg, chatbot, thinking_display, current_config, section_state],
            [msg, chatbot, thinking_display, audio_player],
            queue=False
        )
        chatbot.retry(fn=handle_retry,
                      inputs=[chatbot, thinking_display, current_config, section_state],
                      outputs=[msg, chatbot, thinking_display, audio_player])
        submit_btn.click(
            predict,
            [msg, chatbot, thinking_display, current_config, section_state],
            [msg, chatbot, thinking_display, audio_player],
            queue=False
        )
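        # Per-user persistence: each save writes chat_history_<timestamp>.json and
        # a matching thinking_history_<timestamp>.txt into state['user_dir'];
        # load_chat below resolves the pair from the selected dropdown entry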
        def save_chat(state):
            from datetime import datetime
            now = datetime.now().strftime('%Y%m%d_%H%M%S')
            with open(state['user_dir'] / f'chat_history_{now}.json', 'w', encoding='utf-8') as file:
                json.dump(state["chat_history"], file, ensure_ascii=False, indent=4)
            with open(state['user_dir'] / f'thinking_history_{now}.txt', 'w') as file:
                if isinstance(state["thinking_history"], list):
                    for item in state["thinking_history"]:
                        file.write(item + '\n')
                else:
                    file.write(state["thinking_history"])
            gr.Info("聊天记录已保存!")
            state['available_history'].append(f'chat_history_{now}')
            return state

        def clear_chat(state):
            state["chat_history"] = []
            state["thinking_history"] = []
            prologue = cosplayer.get_prologue()
            if prologue:
                state['chat_history'].append({'role': 'assistant', 'content': prologue})
                chatbot = [{'role': 'assistant', 'content': prologue}]
            else:
                chatbot = []
            return chatbot, [], state

        def load_chat(state, chat_file):
            # NOTE: Loads a saved conversation. Normally done before a conversation
            # starts; if the current conversation has already begun, this overwrites it
            if chat_file:
                think_file = chat_file.replace("chat_", "thinking_")
                chat_file_path = state['user_dir'] / (chat_file + '.json')
                think_file_path = state['user_dir'] / (think_file + '.txt')
                if not chat_file_path.exists():
                    gr.Warning(f'聊天记录文件:{chat_file}.json不存在, 加载失败')
                    return [], '', state
                with open(chat_file_path, 'r', encoding='utf-8') as f:
                    content = json.load(f)
                state['chat_history'] = content
                think = ''
                if think_file_path.exists():
                    with open(think_file_path, 'r') as f:
                        think = f.read()
                state['thinking_history'] = think
                # With the chatbot type set to 'messages', the saved history needs
                # no conversion
                # bot_content = parse_chat_history(content)
                bot_content = content
                return bot_content, think, state
            return [], '', state

        def update_history(state):
            return gr.update(choices=state['available_history'])

        def update_visible(mode):
            if mode != '普通模式':
                gr.Warning("当前为角色扮演模式,请确认已配置好该角色的知识库...")
                return gr.update(visible=True), gr.update(visible=True)
            return gr.update(visible=False), gr.update(visible=False)

        def update_cosplay(cos_select, config, chatbot, think_display, state):
            cosplayer.update(cos_select)
            config['character_description'] = cosplayer.get_core_setting()
            # After the character changes, save the current chat automatically and
            # then clear the history
            if len(state['chat_history']) > 1:
                state = save_chat(state)
                gr.Warning("我的角色已更换,对话已重置。请检查知识库是否需要更新...")
            chatbot, think_display, state = clear_chat(state)
            return gr.update(value=cos_select), config, chatbot, think_display, state

        def update_character_setting_mode(coser_setting, config):
            config['character_setting_mode'] = coser_setting
            return gr.update(value=coser_setting), config

        def update_knowledge_base(knowledge_base, config):
            global local_rag
            config['current_knowledge_base'] = knowledge_base
            if len(knowledge_base) == 0:
                gr.Warning("当前未选中任何知识库,本地RAG将失效。请确认...")
            else:
                if local_rag is None:
                    gr.Info("初次加载知识库,请稍候...")
                    local_rag = LocalRAG(rag_top_k=RAG_TOP_K, doc_dir=knowledge_base)
                    gr.Info("知识库加载完成!")
                else:
                    gr.Info("重新加载知识库,请稍候...")
                    local_rag.reload_knowledge_base(knowledge_base)
                    gr.Info("知识库加载完成!")
            return gr.update(value=knowledge_base), config

        def init_kb(rag_on, kb_select, config):
            global local_rag
            if rag_on:
                # Initialize the local knowledge base
                if config['mode_selected'] == "角色扮演":
                    gr.Warning("当前为角色扮演模式,请确认已配置好该角色的知识库...")
                if local_rag is None:
                    gr.Info("初次加载知识库,请稍候...")
                    local_rag = LocalRAG(rag_top_k=RAG_TOP_K, doc_dir=kb_select)
                    gr.Info("知识库加载完成!")
            return gr.update(value=rag_on)

        # Selecting a non-normal mode (role play) reveals the character-setting
        # controls
        mode_select.change(update_visible, inputs=mode_select, outputs=[coser_select, coser_setting])
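        # Switching characters saves the running conversation, resets the chat,
        # and loads the new character's core setting (see update_cosplay above)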
        coser_select.change(update_cosplay,
                            inputs=[coser_select, current_config, chatbot, thinking_display, section_state],
                            outputs=[coser_select, current_config, chatbot, thinking_display, section_state])
        # TODO: show examples dynamically according to the selected character
        # coser_select.change(update_examples,
        #                     inputs=[coser_select],
        #                     outputs=[examples_show])
        coser_setting.change(update_character_setting_mode,
                             inputs=[coser_setting, current_config],
                             outputs=[coser_setting, current_config])
        kb_select.change(update_knowledge_base,
                         inputs=[kb_select, current_config],
                         outputs=[kb_select, current_config])
        # When local RAG is switched on in role-play mode, remind the user to set
        # up the character's knowledge-base directory
        rag_switch.select(init_kb, inputs=[rag_switch, kb_select, current_config], outputs=rag_switch)

        clear_btn.click(
            clear_chat,
            inputs=section_state,
            outputs=[chatbot, thinking_display, section_state],
            queue=False
        )
        save_btn.click(
            save_chat,
            inputs=section_state,
            outputs=section_state,
            queue=False
        ).then(
            fn=update_history,
            inputs=section_state,
            outputs=chat_history_select
        )
        chat_history_select.change(load_chat,
                                   inputs=[section_state, chat_history_select],
                                   outputs=[chatbot, thinking_display, section_state])
        section_state.change(update_history, inputs=section_state, outputs=chat_history_select)

    print("===== 初始化完成 =====")
    demo.launch(share=True)
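    # Minimal programmatic smoke test for predict() (a sketch, left commented out
    # since demo.launch() blocks; it assumes a local ollama server with
    # DEFAULT_MODEL_NAME already pulled, and the dict keys mirror current_config):
    #
    # cfg = {"llm_model": DEFAULT_MODEL_NAME, "tts_model": DEFAULT_TTS_MODEL_NAME,
    #        "tts_on": False, "kb_on": False, "net_on": False,
    #        "ref_audio": None, "ref_audio_transcribe": None,
    #        "mode_selected": DEFAULT_MODE, "character_description": None,
    #        "character_setting_mode": DEFAULT_C_SETTING_MODE,
    #        "current_knowledge_base": []}
    # _, history, thinking, _ = predict("你好", [], '', cfg,
    #                                   {"chat_history": [], "thinking_history": ''})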