|
import ollama |
|
import gradio as gr |
|
import numpy as np |
|
import json |
|
from tts_api import TTSapi, DEFAULT_TTS_MODEL_NAME |
|
from config import * |
|
from utils import * |
|
from knowledge_base import LocalRAG, CosPlayer |
|
|
|
|
|
def handle_retry(history, thinking_history, config, section_state, retry_data: gr.RetryData):
    """Regenerate the answer for the message the user asked to retry.

    The displayed history, the session history and both thinking logs are
    rolled back to the state before the retried user message, then the
    message is sent through ``predict`` again.

    Args:
        history: Messages currently shown in the chatbot (role/content dicts).
        thinking_history: Text currently shown in the thinking panel.
        config: Current settings dict, forwarded unchanged to ``predict``.
        section_state: Per-session dict holding ``chat_history`` and
            ``thinking_history``; both entries are rewritten here.
        retry_data: Gradio retry event; ``index`` locates the retried message
            in ``history``.

    Returns:
        Whatever ``predict`` returns for the retried message.
    """
    separator = '\n==================\n'

    def _drop_last_entry(log):
        """Return ``log`` without its most recent separator-terminated entry."""
        entries = log.split(separator)
        if len(entries) > 2:
            # Keep the trailing separator so the next entry appended by
            # predict() is not glued onto the previous one.
            return separator.join(entries[:-2]) + separator
        return ''

    previous_message = history[retry_data.index]['content']
    new_history = history[:retry_data.index]

    # The session history may start with a system message that the chatbot
    # does not display; only then are its indices shifted by one relative to
    # the displayed history. Without this check the retried user message would
    # survive the truncation and be appended a second time by predict().
    state_history = section_state['chat_history']
    state_has_system = len(state_history) > 0 and state_history[0]['role'] == 'system'
    shown_has_system = len(history) > 0 and history[0]['role'] == 'system'
    offset = 1 if state_has_system and not shown_has_system else 0
    section_state['chat_history'] = state_history[:retry_data.index + offset]

    try:
        new_thinking_history = _drop_last_entry(thinking_history)
        section_state['thinking_history'] = _drop_last_entry(section_state['thinking_history'])
    except Exception as e:
        # Best effort: a malformed thinking log (e.g. not a string) must not
        # block the retry, so both logs are simply reset.
        print('-----------------------------------')
        print(e)
        print('-----------------------------------')
        print('思考过程发生异常,重置为空')
        section_state['thinking_history'] = ''
        new_thinking_history = ''

    return predict(previous_message, new_history, new_thinking_history, config, section_state)
|
|
|
|
|
def predict(message, chat_history, thinking_history, config, section_state):
    """Answer one user message and update every conversation record.

    The prompt is optionally enriched with local knowledge-base hits, web
    search results and the role-play character description. The conversation
    is then sent to the selected Ollama model and, if enabled, the reply is
    synthesised to speech.

    Args:
        message: Text the user just submitted.
        chat_history: Messages shown in the chatbot (role/content dicts);
            the new user/assistant messages are appended to it.
        thinking_history: Text currently shown in the thinking panel.
        config: Current settings dict (model names, switches, character
            settings, reference audio).
        section_state: Per-session dict; its ``chat_history`` and
            ``thinking_history`` entries are updated.

    Returns:
        A tuple ``("", chat_history, thinking_history, (sample_rate, audio))``.

    Raises:
        ValueError: If ``config['character_setting_mode']`` is unknown.
    """
    global local_rag, TTS_LOADED, synthesiser
    print(config)
    print(f"当前模式:{config['mode_selected']}")
    print(f'角色扮演描述:{config["character_description"]}')
    print(f"写入角色设定方式:{config['character_setting_mode']}")
    print(f"选中LLM:{config['llm_model']}")
    print(f"是否使用RAG本地知识库:{config['kb_on']}")
    print(f"选中知识库:{config['current_knowledge_base']}")
    print(f"是否联网搜索:{config['net_on']}")
    print(f"选中TTS模型:{config['tts_model']}")
    print(f"是否合成语音:{config['tts_on']}")
    print(f"参考音频路径:{config['ref_audio']}")
    print(f"参考音频文本:{config['ref_audio_transcribe']}")

    context = ''
    net_search_res = []
    docs = []
    scores = []

    # Truthiness instead of len(): the knowledge-base selection may be None.
    if config['kb_on'] and config['current_knowledge_base']:
        doc_and_scores = local_rag.vector_db.similarity_search(message, k=local_rag.rag_top_k)
        if len(doc_and_scores) > 0:
            docs, scores = list(zip(*doc_and_scores))
            docs, scores = list(docs), list(scores)
            context_local = "【本地知识库】" + "\n".join([concate_metadata(d.metadata) + d.page_content for d in docs])
            context = context + context_local

    if config['net_on']:
        ret = web_search(message, max_results=MAX_RESULTS)
        net_search_res = parse_net_search(ret)
        context_net = "\n【网络搜索结果】" + ''.join(net_search_res)
        context = context + context_net

    if config['character_description']:
        if config['character_setting_mode'] == 'by system':
            # The character setting lives in a single leading system message.
            if len(section_state['chat_history']) == 0 or section_state['chat_history'][0]['role'] != 'system':
                section_state['chat_history'].insert(0, {"role": "system", "content": config["character_description"]})
        elif config['character_setting_mode'] == 'by prompt':
            # The character setting is injected into every prompt instead, so
            # a system message left over from the other mode is removed.
            if len(section_state['chat_history']) > 0 and section_state['chat_history'][0]['role'] == 'system':
                section_state['chat_history'].pop(0)
            # Prepend the setting and KEEP the retrieved context. The previous
            # "A if cond else '' + context" form discarded the context because
            # the conditional expression binds more loosely than "+".
            context = f'【系统核心设定】:{config["character_description"]}\n' + context
        else:
            raise ValueError(f"未知的角色设定模式:{config['character_setting_mode']}")

    if len(context) > 0:
        prompt = f"""请充分理解以下上下文信息,并结合当前及历史对话产生回复':\n
上下文:{context}
用户当前输入:{message}
回复:
"""
        input_message = section_state["chat_history"] + [{"role": "user", "content": prompt}]
    else:
        input_message = section_state["chat_history"] + [{"role": "user", "content": message}]

    # qwen3 models get the '/no_think' suffix on the outgoing message only;
    # the stored history keeps the raw user text.
    if config['llm_model'].startswith('qwen3'):
        input_message[-1]['content'] += '/no_think'

    section_state["chat_history"].append({"role": "user", "content": message})

    try:
        tokenizer = load_tokenizer(config['llm_model'])
    except Exception:
        # Fall back to the tokenizer of the base model when one is registered.
        if config['llm_model'] in BASE_MODEL_TABLE:
            tokenizer = load_tokenizer(BASE_MODEL_TABLE[config['llm_model']])
        else:
            raise
    token_cnt = count_tokens_local(input_message, tokenizer)
    if token_cnt >= MAX_MODEL_CTX:
        gr.Warning("当前对话已经超出模型上下文长度,请开启新会话...")

    # Silent placeholder so the return statement always has audio to hand
    # back, even when the model call fails before any audio is produced.
    audio_output = np.array([0], dtype=np.int16)
    user_shown = False
    try:
        response = ollama.chat(
            model=config['llm_model'],
            messages=input_message,
            stream=False,
            options={'num_ctx': min(int(token_cnt * 1.2), MAX_MODEL_CTX)}
        )

        thinking, response_content = parse_output(response['message']['content'])

        chat_history.append({'role': 'user', 'content': message})
        user_shown = True
        if len(context) > 0:
            # Knowledge-base hits are rendered together with their score,
            # web results (plain strings) as they are.
            retrieval_html = "<br>".join(
                "<br>".join(wash_up_content(
                    content if isinstance(content, str) else (content.page_content, scores[idx])
                ))
                for idx, content in enumerate(docs + net_search_res)
            )
            formatted_response = f"""
<details class="rag-details">
<summary style='cursor: pointer; color: #666;'>
🔍 检索完成✅(共{len(docs)+len(net_search_res)}条)
</summary>
<div style='margin:10px 0;padding:10px;background:#f5f5f5;border-radius:8px;'>
{retrieval_html}
</div>
</details>
<div style="margin-top: 10px;">{response_content}</div> <!-- 增加顶部间距容器 -->
"""
            chat_history.append({'role': 'assistant', 'content': formatted_response})
        else:
            chat_history.append({'role': 'assistant', 'content': response_content})

        thinking_entry = f"User: {message}\nThinking: {thinking}" + '\n==================\n'
        thinking_history += thinking_entry

        section_state["chat_history"].append({"role": "assistant", "content": response_content})
        # The stored log may have been reset to a list elsewhere; "+=" with a
        # string would then extend it character by character.
        stored_thinking = section_state.get("thinking_history") or ''
        if isinstance(stored_thinking, list):
            stored_thinking = ''.join(stored_thinking)
        section_state["thinking_history"] = stored_thinking + thinking_entry

        if (not config['tts_on']) or len(response_content) == 0:
            audio_output = np.array([0], dtype=np.int16)
            if len(response_content) == 0:
                print("LLM 回复为空,无法合成语音")
        else:
            if not TTS_LOADED:
                print('TTS模型首次加载...')
                gr.Info("初次加载TTS模型,请稍候..", duration=63)
                synthesiser = TTSapi(model_name=config['tts_model'])
                TTS_LOADED = True
                print('加载完毕...')

            if config['tts_model'] != synthesiser.model_name:
                print(f'当前TTS模型{synthesiser.model_name}非所选,重新加载')
                synthesiser.reload(model_name=config['tts_model'])

            if config['ref_audio']:
                prompt_text = config['ref_audio_transcribe']
                if prompt_text is None:
                    raise NotImplementedError('暂时必须提供文本')
                # The reference transcript is prepended to the text handed to
                # the synthesiser together with the reference audio.
                response_content = prompt_text + response_content

            audio_output = synthesiser.forward(response_content, speech_prompt=config['ref_audio'])

    except Exception as e:
        print('!!!!!!!!')
        print(e)
        print('!!!!!!!!')
        error_msg = f"Error: {str(e)}"
        # The chatbot holds role/content dicts, so the error is reported in
        # the same format (a tuple would not match the other entries).
        if not user_shown:
            chat_history.append({'role': 'user', 'content': message})
        chat_history.append({'role': 'assistant', 'content': error_msg})
        thinking_history += f"Error occurred: {str(e)}" + '\n'

    return "", chat_history, thinking_history, (synthesiser.sr if synthesiser else 16000, audio_output)
|
|
|
|
|
def init_model(init_llm=False, init_rag=False, init_tts=False):
    """Optionally warm up the LLM, the local knowledge base and the TTS model.

    Args:
        init_llm: When true, send an empty chat request to the default LLM.
        init_rag: When true, build the ``LocalRAG`` instance.
        init_tts: When true, build the ``TTSapi`` synthesiser.

    Returns:
        A tuple ``(local_rag, synthesiser, TTS_LOADED)``; components that were
        not requested are ``None`` and the flag is ``False`` without TTS.
    """
    if init_llm:
        print(f'正在加载LLM:{DEFAULT_MODEL_NAME}...')
        ollama.chat(model=DEFAULT_MODEL_NAME, messages=[])

    # Start from the "nothing loaded" result and fill in what was requested.
    local_rag = None
    if init_rag:
        gr.Info("正在加载知识库,请稍候...")
        local_rag = LocalRAG(rag_top_k=RAG_TOP_K)

    synthesiser, TTS_LOADED = None, False
    if init_tts:
        print(f'正在加载TTS模型:{DEFAULT_TTS_MODEL_NAME}...')
        synthesiser = TTSapi()
        TTS_LOADED = True

    return local_rag, synthesiser, TTS_LOADED
|
|
|
|
|
if __name__ == "__main__":
    import time

    # Load the models selected by init_model's defaults and report the time taken.
    st = time.time()
    print('********************模型加载中************************')
    local_rag, synthesiser, TTS_LOADED = init_model()
    print('********************模型加载完成************************')
    print('耗时:',time.time() - st)

    # Start with the shared default user (uid 0) and the default character.
    state = {}
    resp, state = log_in(0, state)
    cosplayer = CosPlayer(description_file=DEFAULT_COSPLAY_SETTING)
    print("===== 初始化开始 =====")
    with gr.Blocks(css=CSS, title="LLM Chat Demo", theme=gr.themes.Soft(font=["sans-serif", "Arial"])) as demo:
        gr.Markdown("""
        # LLM Chat Demo
        ## 用法介绍
        ### 用户登录
        * 输入用户名,点击Log In按钮。首次登录会自动创建用户目录,聊天记录会保存在下面,如不登录,默认为公共目录'0'
        ### 模型选择
        目前支持Qwen、Deepseek-R1蒸馏系列等部分模型,可下拉菜单选择
        ### 高级设置
        * 模式选择:可以选择角色扮演模式/普通模式
        * 角色设定选择:支持加载不同角色设定文件
        * 角色配置方式:
            * by system: 角色设定将作为system prompt存在于输入首部
            * by prompt: 角色设定每次被添加到当前上下文中
        * 知识库配置: 支持自由选择、组合知识库
        """)
        section_state = gr.State(value=state)
        with gr.Row():
            uid_input = gr.Textbox(label="Type Your UID:")
            response = gr.Textbox(label='', value=resp)
            login_button = gr.Button("Log In")
            llm_select = gr.Dropdown(label= "模型选择", choices=AVALIABLE_MODELS, value=DEFAULT_MODEL_NAME, visible=True)

        gr.Markdown("## 高级设置")
        with gr.Accordion("点击展开折叠", open=False, visible=True):
            mode_select = gr.Radio(label='模式选择', choices=SUPPORT_MODES, value=DEFAULT_MODE)
            coser_select = gr.Dropdown(label= "角色设定选择", choices=cosplayer.get_all_characters(), value=DEFAULT_COSPLAY_SETTING, visible=True)
            coser_setting = gr.Radio(label='角色配置方式', choices=CHARACTER_SETTING_MODES, value=DEFAULT_C_SETTING_MODE, visible=True)
            kb_select = gr.Dropdown(label= "知识库配置", choices=AVALIABLE_KNOWLEDGE_BASE, value=None, visible=True, multiselect=True)

        with gr.Row():
            # Left column: conversation, input box, examples and history tools.
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(label="对话记录", height=500, show_copy_button=True, type='messages')
                with gr.Row():
                    msg = gr.Textbox(label="输入消息", placeholder="请输入您的问题...", scale=7)
                    with gr.Column(scale=1, min_width=15):
                        with gr.Row():
                            rag_switch = gr.Checkbox(label="本地RAG", value=False, info="")
                            net_switch = gr.Checkbox(label="联网搜索", value=False, info="")

                        submit_btn = gr.Button("发送", variant="primary", min_width=15)
                with gr.Row():
                    gr.Examples(
                        examples=[[example] for example in EXAMPLES],
                        inputs=msg,
                        outputs=chatbot,
                        fn=predict,
                        visible=True,
                        cache_examples=False
                    )
                with gr.Row():
                    save_btn = gr.Button("保存对话")
                    clear_btn = gr.Button("清空对话")
                    chat_history_select = gr.Dropdown(label='加载历史对话', choices=state['available_history'], visible=True, interactive=True)

            # Right column: thinking log and text-to-speech controls.
            with gr.Column(scale=2):
                thinking_display = gr.TextArea(label="思考过程",interactive=False,
                                               placeholder="模型思考过程将在此显示..."
                                               )
                tts_switch = gr.Checkbox(label="TTS开关", value=False, info="Check me to hear voice")
                with gr.Tabs() as audio_tabs:
                    with gr.Tab("音频输出", id="audio_output"):
                        audio_player = gr.Audio(
                            label="听听我声音~",
                            type="numpy",
                            interactive=False
                        )

                    with gr.Tab("TTS配置", id="tts_config"):
                        tts_model = gr.Dropdown(
                            label="选择TTS模型",
                            choices=["ChatTTS", "HKUSTAudio/LLasa-1B", "HKUSTAudio/LLasa-3B", "HKUSTAudio/LLasa-8B"],
                            value=DEFAULT_TTS_MODEL_NAME,
                            interactive=True
                        )

                        ref_audio = gr.Audio(
                            label="上传参考音频",
                            type="filepath",
                            interactive=True
                        )
                        ref_audio_transcribe = gr.Textbox(label="参考音频对应文本", visible=True)

        # Settings snapshot handed to predict() on every request.
        current_config = gr.State({
            "llm_model": DEFAULT_MODEL_NAME,
            "tts_model": DEFAULT_TTS_MODEL_NAME,
            "tts_on": False,
            "kb_on": False,
            "net_on": False,
            "ref_audio": None,
            "ref_audio_transcribe": None,
            "mode_selected": DEFAULT_MODE,
            "character_description": cosplayer.get_core_setting(),
            "character_setting_mode": DEFAULT_C_SETTING_MODE,
            "current_knowledge_base": AVALIABLE_KNOWLEDGE_BASE[0]
        })

        login_button.click(log_in, inputs=[uid_input, section_state], outputs=[response, section_state])

        def collect_config(model1, model2, audio, text, tts_on, kb_on, net_on, mode, character_setting, kb_select):
            """Rebuild the settings snapshot from the current widget values."""
            return {
                "llm_model": model1,
                "tts_model": model2,
                "ref_audio": audio,
                "ref_audio_transcribe": text,
                "tts_on": tts_on,
                "kb_on": kb_on,
                'net_on': net_on,
                "mode_selected": mode,
                # No character description is used in the plain mode.
                "character_description": None if mode == '普通模式' else cosplayer.get_core_setting(),
                "character_setting_mode": character_setting,
                "current_knowledge_base": kb_select,
            }

        gr.on(triggers=[llm_select.change, tts_model.change, ref_audio.change,
                        ref_audio_transcribe.change, tts_switch.select, rag_switch.select, net_switch.select,
                        mode_select.change],
              fn=collect_config,
              inputs=[llm_select, tts_model, ref_audio, ref_audio_transcribe, tts_switch, rag_switch, net_switch, mode_select, coser_setting, kb_select],
              outputs=current_config
              )
        msg.submit(
            predict,
            [msg, chatbot, thinking_display, current_config, section_state],
            [msg, chatbot, thinking_display, audio_player],
            queue=False
        )
        chatbot.retry(fn=handle_retry,
                      inputs=[chatbot, thinking_display, current_config, section_state],
                      outputs=[msg, chatbot, thinking_display, audio_player])

        submit_btn.click(
            predict,
            [msg, chatbot, thinking_display, current_config, section_state],
            [msg, chatbot, thinking_display, audio_player],
            queue=False
        )

        def save_chat(state):
            """Write the session's chat and thinking logs to timestamped files.

            The new chat file name is added to ``state['available_history']``
            and the updated state is returned.
            """
            from datetime import datetime
            now = datetime.now().strftime('%Y%m%d_%H%M%S')
            with open(state['user_dir'] / f'chat_history_{now}.json', 'w', encoding='utf-8') as file:
                json.dump(state["chat_history"], file, ensure_ascii=False, indent=4)
            with open(state['user_dir'] / f'thinking_history_{now}.txt', 'w') as file:
                # Older sessions may still hold the thinking log as a list.
                if isinstance(state["thinking_history"], list):
                    for item in state["thinking_history"]:
                        file.write(item + '\n')
                else:
                    file.write(state["thinking_history"])

            gr.Info("聊天记录已保存!")
            state['available_history'].append(f'chat_history_{now}')
            return state

        def clear_chat(state):
            """Reset the conversation, keeping only the character's prologue.

            Returns the chatbot messages, the (empty) thinking text and the
            updated state.
            """
            state["chat_history"] = []
            # The thinking log is a string everywhere else (predict appends to
            # it with "+=", the panel is a text area), so reset it to an empty
            # string rather than a list.
            state["thinking_history"] = ''
            prologue = cosplayer.get_prologue()
            if prologue:
                state['chat_history'].append({'role': 'assistant', 'content': prologue})
                shown = [{'role': 'assistant', 'content': prologue}]
            else:
                shown = []
            return shown, '', state

        def load_chat(state, chat_file):
            """Load a saved conversation and its thinking log into the session.

            Returns the chatbot messages, the thinking text and the updated
            state; empty values are returned when nothing can be loaded.
            """
            if chat_file:
                think_file = chat_file.replace("chat_", "thinking_")
                chat_file_path = state['user_dir'] / (chat_file + '.json')
                think_file_path = state['user_dir'] / (think_file + '.txt')

                if not chat_file_path.exists():
                    gr.Warning(f'聊天记录文件:{chat_file}.json不存在, 加载失败')
                    return [], '', state

                with open(chat_file_path, 'r', encoding='utf-8') as f:
                    content = json.load(f)
                state['chat_history'] = content

                # The thinking log is optional; a missing file yields ''.
                think = ''
                if think_file_path.exists():
                    with open(think_file_path, 'r') as f:
                        think = f.read()

                state['thinking_history'] = think

                bot_content = content
                return bot_content, think, state

            return [], '', state

        def update_history(state):
            """Refresh the saved-conversation dropdown from the session state."""
            return gr.update(choices=state['available_history'])

        def update_visible(mode):
            """Show the character widgets only outside the plain mode."""
            if mode != '普通模式':
                gr.Warning("当前为角色扮演模式,请确认已配置好该角色的知识库...")
                return gr.update(visible=True), gr.update(visible=True)
            return gr.update(visible=False), gr.update(visible=False)

        def update_cosplay(cos_select, config, chatbot, think_display, state):
            """Switch character: save a non-trivial chat, then reset it."""
            cosplayer.update(cos_select)
            config['character_description'] = cosplayer.get_core_setting()

            if len(state['chat_history']) > 1:
                state = save_chat(state)
            gr.Warning("我的角色已更换,对话已重置。请检查知识库是否需要更新...")
            chatbot, think_display, state = clear_chat(state)
            return gr.update(value=cos_select), config, chatbot, think_display, state

        def update_character_setting_mode(coser_setting, config):
            """Store the chosen character-setting mode in the config."""
            config['character_setting_mode'] = coser_setting
            return gr.update(value=coser_setting), config

        def update_knowledge_base(knowledge_base, config):
            """Record the selected knowledge bases and (re)load them."""
            global local_rag
            config['current_knowledge_base'] = knowledge_base
            # "not" also covers a None selection, on which len() would raise.
            if not knowledge_base:
                gr.Warning("当前未选中任何知识库,本地RAG将失效。请确认...")
            else:
                if local_rag is None:
                    gr.Info("初次加载知识库,请稍候...")
                    local_rag = LocalRAG(rag_top_k=RAG_TOP_K, doc_dir=knowledge_base)
                    gr.Info("知识库加载完成!")
                else:
                    gr.Info("重新加载知识库,请稍候...")
                    local_rag.reload_knowledge_base(knowledge_base)
                    gr.Info("知识库加载完成!")
            return gr.update(value=knowledge_base), config

        def init_kb(rag_on, kb_select, config):
            """Lazily build the knowledge base when local RAG is switched on."""
            global local_rag
            if rag_on:
                # Same mode test as update_visible and collect_config: every
                # mode other than the plain one is a role-play mode.
                if config['mode_selected'] != '普通模式':
                    gr.Warning("当前为角色扮演模式,请确认已配置好该角色的知识库...")

                if local_rag is None:
                    gr.Info("初次加载知识库,请稍候...")
                    local_rag = LocalRAG(rag_top_k=RAG_TOP_K, doc_dir=kb_select)
                    gr.Info("知识库加载完成!")
            return gr.update(value=rag_on)

        mode_select.change(update_visible,
                           inputs=mode_select,
                           outputs=[coser_select, coser_setting])

        coser_select.change(update_cosplay,
                            inputs=[coser_select, current_config, chatbot, thinking_display, section_state],
                            outputs=[coser_select, current_config, chatbot, thinking_display, section_state])

        coser_setting.change(update_character_setting_mode,
                             inputs=[coser_setting, current_config],
                             outputs=[coser_setting, current_config])

        kb_select.change(update_knowledge_base,
                         inputs=[kb_select, current_config],
                         outputs=[kb_select, current_config])

        rag_switch.select(init_kb, inputs=[rag_switch, kb_select, current_config], outputs=rag_switch)

        clear_btn.click(
            clear_chat,
            inputs=section_state,
            outputs=[chatbot, thinking_display, section_state],
            queue=False
        )

        # Saving also refreshes the dropdown of loadable conversations.
        save_btn.click(
            save_chat,
            inputs=section_state,
            outputs=section_state,
            queue=False
        ).then(
            fn=update_history,
            inputs=section_state,
            outputs=chat_history_select
        )

        chat_history_select.change(load_chat,
                                   inputs=[section_state, chat_history_select],
                                   outputs=[chatbot, thinking_display, section_state])

        section_state.change(update_history,
                             inputs=section_state,
                             outputs=chat_history_select)
    print("===== 初始化完成 =====")
    demo.launch(share=True)
|
|