File size: 6,624 Bytes
9929ce6 2474e52 9929ce6 2474e52 9929ce6 2474e52 9929ce6 ec90bb6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
import gradio as gr
from infer.worldmodel import Worldinfer
from PIL import Image
from huggingface_hub import hf_hub_download, snapshot_download
import re
import html
# Conversation state shared by the Gradio callbacks in this module.
current_image = None    # image the questions refer to
current_state = None    # state returned by the previous model.generate call
first_question = False  # set to True by update_image; cleared after a generation

# Download the language-model checkpoint and the SigLIP2 encoder snapshot.
# (For offline use, assign local paths to llm_path / encoder_path instead.)
encoder_type = 'siglip'
llm_path = hf_hub_download(
    repo_id="WorldRWKV/RWKV7-0.4B-G1-SigLIP2-ColdStart",
    filename="rwkv-0.pth",
    local_dir="./model_weights/",
)
encoder_path = snapshot_download(repo_id="google/siglip2-base-patch16-384")

# Build the inference wrapper once, at import time.
model = Worldinfer(
    model_path=llm_path,
    encoder_type=encoder_type,
    encoder_path=encoder_path,
)
# Tagged sections the model may emit; compiled once instead of on every call.
_THINK_PATTERN = re.compile(r'<think>(.*?)</think>', re.DOTALL)
_ANSWER_PATTERN = re.compile(r'<answer>(.*?)</answer>', re.DOTALL)


def _format_response(raw_response):
    """Turn raw model output into HTML-escaped chat markup.

    Every ``<think>`` section becomes a collapsible ``<details>`` element and
    every ``<answer>`` section is appended under an "Answer" heading. If the
    output contains neither tag, the escaped raw text is returned so the chat
    never shows a blank reply.
    """
    parts = [
        f"<details><summary>Think 🤔 </summary>{html.escape(thought)}</details>"
        for thought in _THINK_PATTERN.findall(raw_response)
    ]
    parts.extend(
        "Answer 💡" + "\n" + html.escape(answer)
        for answer in _ANSWER_PATTERN.findall(raw_response)
    )
    if not parts:
        return html.escape(raw_response)
    return "".join(parts)


def chat_fn(user_input, chat_history, image=None):
    """Answer one question about the current image and record the turn.

    Args:
        user_input: Question text entered by the user.
        chat_history: List of ``(user, bot)`` pairs; it is appended to in place.
        image: Optional image supplied by the UI. When it is not None it
            replaces the stored ``current_image``.

    Returns:
        A ``("", chat_history)`` pair: the empty string clears the input box
        and the history now includes the new turn.
    """
    global current_image, current_state, first_question

    if image is not None:
        current_image = image

    # Without any image there is nothing to ask about.
    if current_image is None:
        chat_history.append((user_input, "请先上传一张图片!"))
        return "", chat_history

    # Convert anything that is neither a PIL image nor the 'none' sentinel.
    # The string check comes first so an array is never compared with a str
    # (that comparison is elementwise and cannot be used as a condition).
    is_sentinel = isinstance(current_image, str) and current_image == 'none'
    if not is_sentinel and not isinstance(current_image, Image.Image):
        current_image = Image.fromarray(current_image)

    prompt = f'\x16User: {user_input}\x17Assistant:'

    # The image has to be passed again whenever there is no saved state to
    # continue from: right after an upload, or after an error reset the state.
    # Otherwise the follow-up would run with neither image nor state.
    needs_image = first_question or current_state is None
    try:
        if needs_image:
            result, state = model.generate(prompt, current_image, state=None)
        else:
            # Follow-up turn: 'none' is passed in place of the image and the
            # saved state carries the conversation so far.
            result, state = model.generate(prompt, 'none', state=current_state)
        first_question = False
        current_state = state
        bot_response = _format_response(result)
    except Exception as e:
        # UI boundary: report the failure in the chat instead of crashing the
        # callback, and drop the state so the next turn starts from the image.
        bot_response = f"生成回复时出错: {str(e)}"
        current_state = None

    chat_history.append((user_input, bot_response))
    return "", chat_history
def update_image(image):
    """Store the image from the upload widget and reset the conversation state.

    This is bound to the image component's change event, which also fires
    when the image is cleared, so ``image`` may be None. In that case the
    function reports that the image was removed instead of claiming a
    successful upload.

    Args:
        image: The new image value, or None when the widget was emptied.

    Returns:
        The status message to show for the image.
    """
    global current_image, current_state, first_question
    current_image = image
    current_state = None
    # Only a real upload should make the next question use the image.
    first_question = image is not None
    if image is None:
        return "图片已清除,请上传新图片。"
    return "图片已上传成功!可以开始提问了。"
def clear_image():
    """Forget the stored image and the saved model state.

    Returns:
        A ``(None, status_message)`` pair: None for the image component and
        the message for the status component.
    """
    global current_image, current_state
    current_image = current_state = None
    status_message = "图片已清除,请上传新图片。"
    return None, status_message
def clear_all():
    """Forget the stored image and model state and produce empty UI values.

    Returns:
        A 3-tuple of an empty history list, an empty input string, and a
        status message.
    """
    global current_image, current_state
    current_image = current_state = None
    empty_history = []
    status_message = "图片和对话已清空,请重新上传图片。"
    return empty_history, "", status_message
def chat_without_image_update(user_input, chat_history):
    """Run a chat turn without passing an image from the UI.

    Delegates to ``chat_fn`` with its ``image`` argument left at None, so the
    previously stored image is used.
    """
    return chat_fn(user_input, chat_history, image=None)
# UI layout and event wiring.
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation had been flattened; confirm it against the intended layout.
with gr.Blocks(title="WORLD RWKV", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# WORLD RWKV")
    gr.Markdown("上传一张图片,然后可以进行多轮提问")

    with gr.Row():
        # Left side: image upload area.
        with gr.Column(scale=2):
            image_input = gr.Image(label="上传图片", type="pil", height=400)
            # Image status and the delete-image action.
            with gr.Row():
                image_status = gr.Textbox(
                    value="请上传图片", label="图片状态", interactive=False
                )
                clear_img_btn = gr.Button("删除图片")

        # Right side: conversation area.
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                label="对话记录", height=500, bubble_full_width=False
            )
            # Controls: question box plus action buttons.
            with gr.Row():
                user_input = gr.Textbox(
                    label="问题输入",
                    placeholder="请输入问题...",
                    container=False,
                    scale=7,
                )
                with gr.Column(scale=1):
                    submit_btn = gr.Button("发送", variant="primary")
                    clear_btn = gr.Button("清空所有")

    def _reset_image_widgets():
        """Values for (image_input, image_status) after deleting the image."""
        return None, "图片已清除,请上传新图片。"

    def _reset_all_widgets():
        """Values for (chatbot, user_input, image_status, image_input)."""
        return [], "", "图片和对话已清空,请重新上传图片。", None

    # Image value changed.
    image_input.change(
        fn=update_image, inputs=[image_input], outputs=[image_status]
    )
    # Delete-image button: outputs are ordered (image, status).
    clear_img_btn.click(
        fn=_reset_image_widgets,
        inputs=None,
        outputs=[image_input, image_status],
    )
    # Send button passes the image component along with the question.
    submit_btn.click(
        fn=chat_fn,
        inputs=[user_input, chatbot, image_input],
        outputs=[user_input, chatbot],
    )
    # Enter in the text box uses the variant that takes no image argument.
    user_input.submit(
        fn=chat_without_image_update,
        inputs=[user_input, chatbot],
        outputs=[user_input, chatbot],
    )
    # Clear-all button.
    clear_btn.click(
        fn=_reset_all_widgets,
        inputs=None,
        outputs=[chatbot, user_input, image_status, image_input],
        queue=False,
    )
# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()