|
import gradio as gr |
|
from infer.worldmodel import Worldinfer |
|
from PIL import Image |
|
from huggingface_hub import hf_hub_download, snapshot_download |
|
import re |
|
import html |
|
|
|
# Vision encoder family handed to Worldinfer together with the encoder weights.
encoder_type = "siglip"

# Conversation/session state mutated by the Gradio callbacks below.
# NOTE(review): these are module-level globals, so presumably every browser
# session served by this process shares one image and one model state;
# per-user isolation would need gr.State — confirm before multi-user use.
current_image = None    # image of the ongoing conversation, or None when none is loaded
current_state = None    # state returned by model.generate, passed back on follow-up turns
first_question = False  # True when the next turn must send the image to the model

# Fetch the RWKV language-model checkpoint into ./model_weights/ and the
# SigLIP2 vision encoder snapshot from the Hugging Face Hub.
llm_path = hf_hub_download(
    repo_id="WorldRWKV/RWKV7-0.4B-G1-SigLIP2-ColdStart",
    filename="rwkv-0.pth",
    local_dir="./model_weights/",
)
encoder_path = snapshot_download(repo_id="google/siglip2-base-patch16-384")

model = Worldinfer(
    model_path=llm_path,
    encoder_type=encoder_type,
    encoder_path=encoder_path,
)
|
|
|
|
|
# Reasoning / final-answer spans in the model output; compiled once at import
# time rather than on every request.
_THINK_PATTERN = re.compile(r'<think>(.*?)</think>', re.DOTALL)
_ANSWER_PATTERN = re.compile(r'<answer>(.*?)</answer>', re.DOTALL)


def _format_response(raw):
    """Render raw model output as HTML-escaped chat markup.

    Each ``<think>`` span becomes a collapsible ``<details>`` section, and each
    ``<answer>`` span is shown after an "Answer 💡" header line. If the output
    contains neither tag, the escaped raw text is returned so the reply is
    never blank.
    """
    thinks = _THINK_PATTERN.findall(raw)
    answers = _ANSWER_PATTERN.findall(raw)
    if not thinks and not answers:
        return html.escape(raw)
    parts = [
        f"<details><summary>Think 🤔 </summary>{html.escape(match)}</details>"
        for match in thinks
    ]
    parts.extend(f"Answer 💡\n{html.escape(match)}" for match in answers)
    return "".join(parts)


def chat_fn(user_input, chat_history, image=None):
    """Answer one user turn about the current image and record it in the history.

    Args:
        user_input: The user's question text.
        chat_history: Chatbot history as a list of ``(user, bot)`` tuples; it is
            appended to in place (``None`` is treated as an empty history).
        image: Optional image from the upload widget; when given it replaces the
            stored image.

    Returns:
        ``("", chat_history)``. The empty string clears the input textbox.
    """
    global current_image, current_state, first_question

    if chat_history is None:
        chat_history = []

    if image is not None:
        current_image = image

    if current_image is None:
        chat_history.append((user_input, "请先上传一张图片!"))
        return "", chat_history

    # Non-PIL inputs (presumably numpy arrays) are converted for the model.
    # Strings (the 'none' sentinel) are excluded by type rather than compared
    # with `!=`, which would be evaluated elementwise on an array.
    if not isinstance(current_image, (Image.Image, str)):
        current_image = Image.fromarray(current_image)

    prompt = f'\x16User: {user_input}\x17Assistant:'

    # The image is sent to the model only when a conversation starts: right
    # after an upload, or whenever there is no state to continue from (e.g. a
    # previous turn failed and reset it). Otherwise the model would receive
    # neither the image nor any state.
    send_image = first_question or current_state is None

    try:
        if send_image:
            result, state = model.generate(prompt, current_image, state=None)
        else:
            result, state = model.generate(prompt, 'none', state=current_state)
        first_question = False
        current_state = state
        bot_response = _format_response(result)
    except Exception as e:  # UI boundary: show the error in the chat instead of crashing
        bot_response = f"生成回复时出错: {str(e)}"
        current_state = None

    chat_history.append((user_input, bot_response))
    return "", chat_history
|
|
|
def update_image(image):
    """Handle a change of the image widget and start a fresh conversation.

    This is wired to the image component's ``change`` event. That event also
    fires when another callback sets the widget to ``None`` (both clear
    buttons do), so ``image`` may be ``None``. In that case the image is
    forgotten and nothing is queued for the model.

    Args:
        image: The newly uploaded image, or ``None`` when it was removed.

    Returns:
        The status message for the image status textbox.
    """
    global current_image, current_state, first_question

    current_image = image
    current_state = None  # a new (or no) image always invalidates the old conversation
    # Only schedule sending the image on the next turn if there actually is one.
    first_question = image is not None

    if image is None:
        return "图片已清除,请上传新图片。"
    return "图片已上传成功!可以开始提问了。"
|
|
|
|
|
def clear_image():
    """Forget the current image and the conversation state built on it.

    Returns:
        ``(None, message)``. ``None`` clears the image widget and ``message``
        goes to the image status textbox.
    """
    global current_image, current_state, first_question

    current_image = None
    current_state = None
    # No image is pending any more, so nothing must be sent on the next turn.
    first_question = False

    return None, "图片已清除,请上传新图片。"
|
|
|
|
|
def clear_all():
    """Reset the image, the conversation state and the pending-image flag.

    Returns:
        ``(history, text, status)``: an empty chat history, an empty input
        text, and the status message for the image status textbox.
    """
    global current_image, current_state, first_question

    current_image = None
    current_state = None
    # No image is pending any more, so nothing must be sent on the next turn.
    first_question = False

    return [], "", "图片和对话已清空,请重新上传图片。"
|
|
|
|
|
def chat_without_image_update(user_input, chat_history):
    """Handle a textbox submit by answering against the already-stored image.

    Delegates to ``chat_fn`` with no image argument. The stored image is left
    untouched and the turn continues the existing conversation.
    """
    reply = chat_fn(user_input, chat_history, image=None)
    return reply
|
|
|
|
|
with gr.Blocks(title="WORLD RWKV", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# WORLD RWKV")
    gr.Markdown("上传一张图片,然后可以进行多轮提问")

    with gr.Row():
        # Left column: image upload, its status line and the remove button.
        with gr.Column(scale=2):
            image_input = gr.Image(
                type="pil",
                label="上传图片",
                height=400,
            )
            with gr.Row():
                image_status = gr.Textbox(
                    label="图片状态",
                    value="请上传图片",
                    interactive=False,
                )
                clear_img_btn = gr.Button("删除图片")

        # Right column: conversation view, question box and action buttons.
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                label="对话记录",
                bubble_full_width=False,
                height=500,
            )
            with gr.Row():
                user_input = gr.Textbox(
                    placeholder="请输入问题...",
                    scale=7,
                    container=False,
                    label="问题输入",
                )
                with gr.Column(scale=1):
                    submit_btn = gr.Button("发送", variant="primary")
                    clear_btn = gr.Button("清空所有")

    # Any change of the image widget starts a fresh conversation.
    image_input.change(
        fn=update_image,
        inputs=[image_input],
        outputs=[image_status],
    )

    # Use the shared helper so the server-side session state is reset directly,
    # not only as a side effect of the follow-up image change event.
    clear_img_btn.click(
        fn=clear_image,
        inputs=None,
        outputs=[image_input, image_status],
    )

    # The send button also passes the widget's image; pressing Enter reuses
    # the stored one.
    submit_btn.click(
        fn=chat_fn,
        inputs=[user_input, chatbot, image_input],
        outputs=[user_input, chatbot],
    )

    user_input.submit(
        fn=chat_without_image_update,
        inputs=[user_input, chatbot],
        outputs=[user_input, chatbot],
    )

    # clear_all resets the session state and returns (history, text, status);
    # the trailing None clears the image widget as well.
    clear_btn.click(
        fn=lambda: (*clear_all(), None),
        inputs=None,
        outputs=[chatbot, user_input, image_status, image_input],
        queue=False,
    )


if __name__ == "__main__":
    demo.launch()