# LLM_demo/app.py
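"""LLM chat demo.

Wires an Ollama-served LLM into a Gradio UI with optional local RAG,
web search, role-play personas, and TTS speech synthesis.

Run with `python app.py`; the app is served via `demo.launch(share=True)`.
"""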
import ollama
import gradio as gr
import numpy as np
import json
from tts_api import TTSapi, DEFAULT_TTS_MODEL_NAME
from config import *
from utils import *
from knowledge_base import LocalRAG, CosPlayer
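# `config` and `utils` are star-imported: between them they supply the
# constants (e.g. DEFAULT_MODEL_NAME, MAX_MODEL_CTX, RAG_TOP_K, CSS) and
# helpers (e.g. web_search, parse_output, count_tokens_local, log_in) used
# below; which module provides which name is not visible from this file.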
def handle_retry(history, thinking_history, config, section_state, retry_data: gr.RetryData):
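    """Roll the session back to the retried user message and regenerate the reply.

    `retry_data.index` is the position of the retried user message in the
    visible chat history; that turn and everything after it are discarded
    before `predict` is invoked again.
    """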
    # Recover the user message that is being retried
    previous_message = history[retry_data.index]['content']
    # Drop the retried turn and everything after it from the visible history
    new_history = history[:retry_data.index]
    # section_state['chat_history'] may carry a leading system message, which
    # shifts its indices by one relative to the visible history
    offset = 1 if (section_state['chat_history'] and section_state['chat_history'][0]['role'] == 'system') else 0
    section_state['chat_history'] = section_state['chat_history'][:retry_data.index + offset]
    try:
        # Each logged entry ends with the separator, so splitting leaves a
        # trailing empty chunk; drop the last entry plus that chunk and keep
        # the trailing separator so later appends stay well-formed.
        items = thinking_history.split('\n==================\n')
        if len(items) > 2:
            new_thinking_history = '\n==================\n'.join(items[:-2]) + '\n==================\n'
        else:
            new_thinking_history = ''
        items = section_state['thinking_history'].split('\n==================\n')
        if len(items) > 2:
            section_state['thinking_history'] = '\n==================\n'.join(items[:-2]) + '\n==================\n'
        else:
            section_state['thinking_history'] = ''
    except Exception as e:
        print('-----------------------------------')
        print(e)
        print('-----------------------------------')
        print('Error while trimming the thinking history; resetting it to empty')
        section_state['thinking_history'] = ''
        new_thinking_history = ''
    # Regenerate the reply
return predict(previous_message, new_history, new_thinking_history, config, section_state)
def predict(message, chat_history, thinking_history, config, section_state):
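    """Produce one assistant turn for `message`.

    Assembles the model input from the session history plus optional RAG and
    web-search context and the persona settings, queries the LLM via
    `ollama.chat`, and optionally synthesizes speech for the reply.

    Returns `("", chat_history, thinking_history, (sample_rate, audio))`,
    matching the Gradio components it is wired to.
    """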
global local_rag, TTS_LOADED, synthesiser
    print(config)
    print(f"Mode: {config['mode_selected']}")
    print(f'Role-play description: {config["character_description"]}')
    print(f"Character-setting mode: {config['character_setting_mode']}")
    print(f"Selected LLM: {config['llm_model']}")
    print(f"Local RAG knowledge base enabled: {config['kb_on']}")
    print(f"Selected knowledge base: {config['current_knowledge_base']}")
    print(f"Web search enabled: {config['net_on']}")
    print(f"Selected TTS model: {config['tts_model']}")
    print(f"Speech synthesis enabled: {config['tts_on']}")
    print(f"Reference audio path: {config['ref_audio']}")
    print(f"Reference audio transcript: {config['ref_audio_transcribe']}")
context = ''
net_search_res = []
docs = []
    if config['kb_on'] and config['current_knowledge_base']:
        # Retrieve similar documents from the local knowledge base
        doc_and_scores = local_rag.vector_db.similarity_search(message, k=local_rag.rag_top_k)
        # doc_and_scores = list(filter(lambda x: x[1] <= 0.4, doc_and_scores))
        if len(doc_and_scores) > 0:
            docs, scores = list(zip(*doc_and_scores))
            docs, scores = list(docs), list(scores)
            context_local = "【本地知识库】" + "\n".join([concate_metadata(d.metadata) + d.page_content for d in docs])
            context = context + context_local
    if config['net_on']:
        # Run a web search for related results
        ret = web_search(message, max_results=MAX_RESULTS)
        net_search_res = parse_net_search(ret)
        context_net = "\n【网络搜索结果】" + ''.join(net_search_res)
        context = context + context_net
    if config['character_description']:
        if config['character_setting_mode'] == 'by system':
            if len(section_state['chat_history']) == 0 or section_state['chat_history'][0]['role'] != 'system':
                section_state['chat_history'].insert(0, {"role": "system", "content": config["character_description"]})
        elif config['character_setting_mode'] == 'by prompt':
            if len(section_state['chat_history']) > 0 and section_state['chat_history'][0]['role'] == 'system':
                section_state['chat_history'].pop(0)
            # Prepend the core persona setting to the retrieved context
            context = f'【系统核心设定】:{config["character_description"]}\n' + context
        else:
            raise ValueError(f"Unknown character-setting mode: {config['character_setting_mode']}")
if len(context) > 0:
prompt = f"""请充分理解以下上下文信息,并结合当前及历史对话产生回复':\n
上下文:{context}
用户当前输入:{message}
回复:
"""
input_message = section_state["chat_history"] + [{"role": "user", "content": prompt}]
else:
input_message = section_state["chat_history"] + [{"role": "user", "content": message}]
    # Disable the default thinking mode of Qwen3-series models
    if config['llm_model'].startswith('qwen3'):
        input_message[-1]['content'] += '/no_think'
    # Append the user message to the session history
    section_state["chat_history"].append({"role": "user", "content": message})
    # Count tokens and size the context window dynamically to stay within ollama's limit
try:
tokenizer = load_tokenizer(config['llm_model'])
except Exception as e:
if config['llm_model'] in BASE_MODEL_TABLE:
tokenizer = load_tokenizer(BASE_MODEL_TABLE[config['llm_model']])
else:
raise e
token_cnt = count_tokens_local(input_message, tokenizer)
if token_cnt >= MAX_MODEL_CTX:
gr.Warning("当前对话已经超出模型上下文长度,请开启新会话...")
    try:
        # Call the model
response = ollama.chat(
model=config['llm_model'],
messages=input_message,
stream=False,
options={'num_ctx': min(int(token_cnt * 1.2), MAX_MODEL_CTX)}
)
        # Parse the response
        thinking, response_content = parse_output(response['message']['content'])
        # Update the visible chat history
        chat_history.append({'role': 'user', 'content': message})
        if len(context) > 0:
            # Build a reply with a collapsible retrieval-details section
formatted_response = f"""
<details class="rag-details">
<summary style='cursor: pointer; color: #666;'>
🔍 检索完成✅(共{len(docs)+len(net_search_res)}条)
</summary>
<div style='margin:10px 0;padding:10px;background:#f5f5f5;border-radius:8px;'>
{
"<br>".join(
["<br>".join(wash_up_content(content if isinstance(content, str) else (content.page_content, scores[idx])))
for idx, content in enumerate(docs + net_search_res)]
)
}
</div>
</details>
<div style="margin-top: 10px;">{response_content}</div> <!-- container adding top spacing -->
"""
chat_history.append({'role': 'assistant', 'content': formatted_response})
else:
chat_history.append({'role': 'assistant', 'content': response_content})
thinking_history += f"User: {message}\nThinking: {thinking}" + '\n==================\n'
        # Append the assistant response to the session history
section_state["chat_history"].append({"role": "assistant", "content": response_content})
section_state["thinking_history"] += f"User: {message}\nThinking: {thinking}" + '\n==================\n'
        if (not config['tts_on']) or len(response_content) == 0:
            audio_output = np.array([0], dtype=np.int16)
            if len(response_content) == 0:
                print("LLM reply is empty; nothing to synthesize")
        else:
            if not TTS_LOADED:
                print('Loading TTS model for the first time...')
                gr.Info("初次加载TTS模型,请稍候..", duration=63)
                synthesiser = TTSapi(model_name=config['tts_model'])
                TTS_LOADED = True
                print('TTS model loaded')
            # Reload if the loaded model differs from the selected one
            if config['tts_model'] != synthesiser.model_name:
                print(f'Loaded TTS model {synthesiser.model_name} is not the selected one; reloading')
                synthesiser.reload(model_name=config['tts_model'])
            # If a reference audio clip is given, prepend its transcript to
            # response_content as the synthesis prompt
            if config['ref_audio']:
                prompt_text = config['ref_audio_transcribe']
                if prompt_text is None:
                    # prompt_text = ...
                    raise NotImplementedError('A transcript must be provided for now')  # TODO: add an ASR model for this later
                response_content = prompt_text + response_content
            audio_output = synthesiser.forward(response_content, speech_prompt=config['ref_audio'])
    except Exception as e:
        print('!!!!!!!!')
        print(e)
        print('!!!!!!!!')
        error_msg = f"Error: {str(e)}"
        # chatbot uses the 'messages' format, so append a dict, not a tuple
        chat_history.append({'role': 'assistant', 'content': error_msg})
        thinking_history += f"Error occurred: {str(e)}" + '\n'
        # Keep audio_output defined even if the failure happened before synthesis
        audio_output = np.array([0], dtype=np.int16)
    return "", chat_history, thinking_history, (synthesiser.sr if synthesiser else 16000, audio_output)
def init_model(init_llm=False, init_rag=False, init_tts=False):
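    """Optionally warm up the LLM and pre-load the RAG index and the TTS model.

    All flags default to False, so the heavyweight components are loaded
    lazily on first use rather than at startup.
    """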
    if init_llm:
        print(f'Loading LLM: {DEFAULT_MODEL_NAME}...')
        ollama.chat(model=DEFAULT_MODEL_NAME, messages=[])
    if init_rag:
        gr.Info("正在加载知识库,请稍候...")
        local_rag = LocalRAG(rag_top_k=RAG_TOP_K)
    else:
        local_rag = None
    if init_tts:
        print(f'Loading TTS model: {DEFAULT_TTS_MODEL_NAME}...')
        synthesiser = TTSapi()
        TTS_LOADED = True
    else:
        synthesiser = None
        TTS_LOADED = False
    return local_rag, synthesiser, TTS_LOADED
if __name__ == "__main__":
import time
st = time.time()
    print('******************** Loading models ************************')
    local_rag, synthesiser, TTS_LOADED = init_model()
    print('******************** Models loaded ************************')
    print('Elapsed:', time.time() - st)
state = {}
resp, state = log_in(0, state)
cosplayer = CosPlayer(description_file=DEFAULT_COSPLAY_SETTING)
print("===== 初始化开始 =====")
with gr.Blocks(css=CSS, title="LLM Chat Demo", theme=gr.themes.Soft(font=["sans-serif", "Arial"])) as demo:
gr.Markdown("""
# LLM Chat Demo
## 用法介绍
### 用户登录
* 输入用户名,点击Log In按钮。首次登录会自动创建用户目录,聊天记录会保存在下面,如不登录,默认为公共目录'0'
### 模型选择
目前支持Qwen、Deepseek-R1蒸馏系列等部分模型,可下拉菜单选择
### 高级设置
* 模式选择:可以选择角色扮演模式/普通模式
* 角色设定选择:支持加载不同角色设定文件
* 角色配置方式:
* by system: 角色设定将作为system prompt存在于输入首部
* by prompt: 角色设定每次被添加到当前上下文中
* 知识库配置: 支持自由选择、组合知识库
""")
        section_state = gr.State(value=state)  # Per-session state object
with gr.Row():
uid_input = gr.Textbox(label="Type Your UID:")
response = gr.Textbox(label='', value=resp)
login_button = gr.Button("Log In")
        llm_select = gr.Dropdown(label="模型选择", choices=AVALIABLE_MODELS, value=DEFAULT_MODEL_NAME, visible=True)
gr.Markdown("## 高级设置")
with gr.Accordion("点击展开折叠", open=False, visible=True):
            mode_select = gr.Radio(label='模式选择', choices=SUPPORT_MODES, value=DEFAULT_MODE)
            coser_select = gr.Dropdown(label="角色设定选择", choices=cosplayer.get_all_characters(), value=DEFAULT_COSPLAY_SETTING, visible=True)
            coser_setting = gr.Radio(label='角色配置方式', choices=CHARACTER_SETTING_MODES, value=DEFAULT_C_SETTING_MODE, visible=True)
            kb_select = gr.Dropdown(label="知识库配置", choices=AVALIABLE_KNOWLEDGE_BASE, value=None, visible=True, multiselect=True)
with gr.Row():
            # Left side of the page
with gr.Column(scale=3):
chatbot = gr.Chatbot(label="对话记录", height=500, show_copy_button=True, type='messages')
with gr.Row():
msg = gr.Textbox(label="输入消息", placeholder="请输入您的问题...", scale=7)
with gr.Column(scale=1, min_width=15):
with gr.Row():
rag_switch = gr.Checkbox(label="本地RAG", value=False, info="")
net_switch = gr.Checkbox(label="联网搜索", value=False, info="")
                    submit_btn = gr.Button("发送", variant="primary", min_width=15)  # elem_classes=['custom-btn']
with gr.Row():
gr.Examples(
examples=[[example] for example in EXAMPLES],
inputs=msg,
outputs=chatbot,
fn=predict,
visible=True,
cache_examples=False
)
with gr.Row():
save_btn = gr.Button("保存对话")
clear_btn = gr.Button("清空对话")
chat_history_select = gr.Dropdown(label='加载历史对话', choices=state['available_history'], visible=True, interactive=True)
            # Right side of the page
with gr.Column(scale=2):
                thinking_display = gr.TextArea(label="思考过程", interactive=False,
                                               placeholder="模型思考过程将在此显示...")
tts_switch = gr.Checkbox(label="TTS开关", value=False, info="Check me to hear voice")
with gr.Tabs() as audio_tabs:
                # Tab 1: audio playback
with gr.Tab("音频输出", id="audio_output"):
audio_player = gr.Audio(
label="听听我声音~",
type="numpy",
interactive=False
)
                # Tab 2: TTS configuration
with gr.Tab("TTS配置", id="tts_config"):
                    # TTS model selection
tts_model = gr.Dropdown(
label="选择TTS模型",
choices=["ChatTTS", "HKUSTAudio/LLasa-1B", "HKUSTAudio/LLasa-3B", "HKUSTAudio/LLasa-8B"],
value=DEFAULT_TTS_MODEL_NAME,
interactive=True
)
                    # Reference audio upload
ref_audio = gr.Audio(
label="上传参考音频",
type="filepath",
interactive=True
)
ref_audio_transcribe = gr.Textbox(label="参考音频对应文本", visible=True)
        # ================= State management =================
current_config = gr.State({
"llm_model": DEFAULT_MODEL_NAME,
"tts_model": DEFAULT_TTS_MODEL_NAME,
"tts_on": False,
"kb_on": False,
"net_on": False,
"ref_audio": None,
"ref_audio_transcribe": None,
"mode_selected": DEFAULT_MODE,
"character_description": cosplayer.get_core_setting(),
"character_setting_mode": DEFAULT_C_SETTING_MODE,
"current_knowledge_base": AVALIABLE_KNOWLEDGE_BASE[0]
})
        # Event handlers
login_button.click(log_in, inputs=[uid_input, section_state], outputs=[response, section_state])
        gr.on(
            triggers=[llm_select.change, tts_model.change, ref_audio.change,
                      ref_audio_transcribe.change, tts_switch.select, rag_switch.select,
                      net_switch.select, mode_select.change],
            fn=lambda model1, model2, audio, text, tts_on, kb_on, net_on, mode, character_setting, kb_select: {
                "llm_model": model1,
                "tts_model": model2,
                "ref_audio": audio,
                "ref_audio_transcribe": text,
                "tts_on": tts_on,
                "kb_on": kb_on,
                "net_on": net_on,
                "mode_selected": mode,
                "character_description": None if mode == '普通模式' else cosplayer.get_core_setting(),
                "character_setting_mode": character_setting,
                "current_knowledge_base": kb_select,
            },
            inputs=[llm_select, tts_model, ref_audio, ref_audio_transcribe, tts_switch,
                    rag_switch, net_switch, mode_select, coser_setting, kb_select],
            outputs=current_config,
        )
msg.submit(
predict,
[msg, chatbot, thinking_display, current_config, section_state],
[msg, chatbot, thinking_display, audio_player],
queue=False
)
chatbot.retry(fn=handle_retry,
inputs=[chatbot, thinking_display, current_config, section_state],
outputs=[msg, chatbot, thinking_display, audio_player])
submit_btn.click(
predict,
[msg, chatbot, thinking_display, current_config, section_state],
[msg, chatbot, thinking_display, audio_player],
queue=False
)
def save_chat(state):
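            """Write the chat and thinking histories to timestamped files in the
            user's directory and register the new entry for the history dropdown."""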
from datetime import datetime
now = datetime.now().strftime('%Y%m%d_%H%M%S')
with open(state['user_dir'] / f'chat_history_{now}.json', 'w', encoding='utf-8') as file:
json.dump(state["chat_history"], file, ensure_ascii=False, indent=4)
            with open(state['user_dir'] / f'thinking_history_{now}.txt', 'w', encoding='utf-8') as file:
if isinstance(state["thinking_history"], list):
for item in state["thinking_history"]:
file.write(item + '\n')
else:
file.write(state["thinking_history"])
gr.Info("聊天记录已保存!")
state['available_history'].append(f'chat_history_{now}')
return state
def clear_chat(state):
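            """Reset the session, re-inserting the persona's prologue if it has one."""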
state["chat_history"] = []
state["thinking_history"] = []
prologue = cosplayer.get_prologue()
if prologue:
state['chat_history'].append({'role': 'assistant', 'content': prologue})
chatbot = [{'role': 'assistant', 'content': prologue}]
else:
chatbot = []
return chatbot, [], state
def load_chat(state, chat_file):
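            """Load a saved chat history (and its thinking log) into the session."""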
            # NOTE: loads a saved conversation. Intended to run before a chat
            # starts; if the current conversation is underway, it overwrites it.
if chat_file:
think_file = chat_file.replace("chat_", "thinking_")
chat_file_path = state['user_dir'] / (chat_file + '.json')
think_file_path = state['user_dir'] / (think_file + '.txt')
if not chat_file_path.exists():
gr.Warning(f'聊天记录文件:{chat_file}.json不存在, 加载失败')
return [], '', state
with open(chat_file_path, 'r', encoding='utf-8') as f:
content = json.load(f)
state['chat_history'] = content
think = ''
if think_file_path.exists():
                with open(think_file_path, 'r', encoding='utf-8') as f:
think = f.read()
state['thinking_history'] = think
            # Since the chatbot type is 'messages', the saved history needs no conversion
            # bot_content = parse_chat_history(content)
            bot_content = content
return bot_content, think, state
return [], '', state
def update_history(state):
return gr.update(choices=state['available_history'])
def update_visible(mode):
if mode != '普通模式':
gr.Warning("当前为角色扮演模式,请确认已配置好该角色的知识库...")
return gr.update(visible=True), gr.update(visible=True)
return gr.update(visible=False), gr.update(visible=False)
def update_cosplay(cos_select, config, chatbot, think_display, state):
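            """Switch the active persona, saving and then resetting the current chat."""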
cosplayer.update(cos_select)
config['character_description'] = cosplayer.get_core_setting()
            # After the persona changes, auto-save the current chat, then clear the history
if len(state['chat_history']) > 1:
state = save_chat(state)
gr.Warning("我的角色已更换,对话已重置。请检查知识库是否需要更新...")
chatbot, think_display, state = clear_chat(state)
return gr.update(value=cos_select), config, chatbot, think_display, state
def update_character_setting_mode(coser_setting, config):
config['character_setting_mode'] = coser_setting
return gr.update(value=coser_setting), config
def update_knowledge_base(knowledge_base, config):
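            """Apply the knowledge-base selection, (re)loading the RAG index as needed."""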
global local_rag
config['current_knowledge_base'] = knowledge_base
if len(knowledge_base) == 0:
gr.Warning("当前未选中任何知识库,本地RAG将失效。请确认...")
else:
if local_rag is None:
gr.Info("初次加载知识库,请稍候...")
local_rag = LocalRAG(rag_top_k=RAG_TOP_K, doc_dir=knowledge_base)
gr.Info("知识库加载完成!")
else:
gr.Info("重新加载知识库,请稍候...")
local_rag.reload_knowledge_base(knowledge_base)
gr.Info("知识库加载完成!")
return gr.update(value=knowledge_base), config
def init_kb(rag_on, kb_select, config):
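            """Build the RAG index lazily when the local-RAG switch is turned on."""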
global local_rag
if rag_on:
                # Initialize the local knowledge base
if config['mode_selected'] == "角色扮演":
gr.Warning("当前为角色扮演模式,请确认已配置好该角色的知识库...")
if local_rag is None:
gr.Info("初次加载知识库,请稍候...")
local_rag = LocalRAG(rag_top_k=RAG_TOP_K, doc_dir=kb_select)
gr.Info("知识库加载完成!")
return gr.update(value=rag_on)
        # When a non-default (role-play) mode is selected, show the character-setting controls
mode_select.change(update_visible,
inputs=mode_select,
outputs=[coser_select, coser_setting])
coser_select.change(update_cosplay,
inputs=[coser_select, current_config, chatbot, thinking_display, section_state],
outputs=[coser_select, current_config, chatbot, thinking_display, section_state])
        # TODO: update the displayed examples dynamically as the persona changes
        # coser_select.change(update_examples,
        #                     inputs=[coser_select],
        #                     outputs=[examples_show])
coser_setting.change(update_character_setting_mode,
inputs=[coser_setting, current_config],
outputs=[coser_setting, current_config])
kb_select.change(update_knowledge_base,
inputs=[kb_select, current_config],
outputs=[kb_select, current_config])
        # When local RAG is enabled in role-play mode, remind the user to configure the knowledge-base directory
rag_switch.select(init_kb, inputs=[rag_switch, kb_select, current_config], outputs=rag_switch)
clear_btn.click(
clear_chat,
inputs=section_state,
outputs=[chatbot, thinking_display, section_state],
queue=False
)
save_btn.click(
save_chat,
inputs=section_state,
outputs=section_state,
queue=False
).then(
fn=update_history,
inputs=section_state,
outputs=chat_history_select
)
chat_history_select.change(load_chat,
inputs=[section_state, chat_history_select],
outputs=[chatbot, thinking_display, section_state])
section_state.change(update_history,
inputs=section_state,
outputs=chat_history_select)
print("===== 初始化完成 =====")
demo.launch(share=True)