File size: 6,624 Bytes
9929ce6
 
 
 
 
2474e52
9929ce6
2474e52
9929ce6
2474e52
 
9929ce6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec90bb6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import gradio as gr
from infer.worldmodel import Worldinfer
from PIL import Image
from huggingface_hub import hf_hub_download, snapshot_download
import re
import html 
# Fetch the language-model checkpoint and the vision-encoder snapshot from the
# Hugging Face Hub.
llm_path = hf_hub_download(
    repo_id="WorldRWKV/RWKV7-0.4B-G1-SigLIP2-ColdStart",
    filename="rwkv-0.pth",
    local_dir="./model_weights/",
)
encoder_path = snapshot_download(repo_id="google/siglip2-base-patch16-384")
# Local-path alternatives (disabled):
# llm_path = "/mnt/B8E84E9EE84E5B30/rwkv-models/world_rwkv/world_weights/rwkv-0"
# encoder_path = "/mnt/B8E84E9EE84E5B30/rwkv-models/world_rwkv/siglip2-base-patch16-384/"
encoder_type = 'siglip'

# Module-level conversation state, read and written by the callbacks in this file.
current_image = None    # image currently being discussed
current_state = None    # state returned by the most recent model.generate call
first_question = False  # set to True by update_image; cleared after a successful generate

# Build the inference wrapper once at import time.
model = Worldinfer(
    model_path=llm_path,
    encoder_type=encoder_type,
    encoder_path=encoder_path,
)

# Core logic for handling one user message.

# Patterns for the reasoning / answer sections of the model output, compiled once.
_THINK_PATTERN = re.compile(r'<think>(.*?)</think>', re.DOTALL)
_ANSWER_PATTERN = re.compile(r'<answer>(.*?)</answer>', re.DOTALL)


def _format_response(raw_response):
    """Turn raw model output into the text shown in the chat window.

    Each ``<think>...</think>`` section becomes a collapsible ``<details>``
    element and each ``<answer>...</answer>`` section is prefixed with an
    "Answer" heading; the captured text is HTML-escaped. If the output has
    neither tag, the whole output is returned HTML-escaped so the user never
    receives an empty reply.
    """
    parts = [
        f"<details><summary>Think 🤔 </summary>{html.escape(section)}</details>"
        for section in _THINK_PATTERN.findall(raw_response)
    ]
    parts.extend(
        "Answer 💡" + "\n" + html.escape(section)
        for section in _ANSWER_PATTERN.findall(raw_response)
    )
    if not parts:
        # No recognised tags (e.g. plain or truncated output): show it as-is.
        return html.escape(raw_response)
    return "".join(parts)


def chat_fn(user_input, chat_history, image=None):
    """Answer ``user_input`` about the current image and record the exchange.

    Args:
        user_input: The question typed by the user.
        chat_history: List of ``(user, bot)`` tuples; ``None`` is treated as
            an empty history. The list is extended in place.
        image: Optional image from the upload component. When given, it
            replaces the stored ``current_image``.

    Returns:
        A ``("", chat_history)`` tuple: an empty string to clear the input
        box and the updated history.

    Side effects:
        Updates the module-level ``current_image``, ``current_state`` and
        ``first_question``. Generation errors are caught and reported as the
        bot reply, and ``current_state`` is reset to ``None``.
    """
    global current_image, current_state, first_question

    if chat_history is None:
        chat_history = []

    # Adopt a newly supplied image as the one being discussed.
    if image is not None:
        current_image = image

    # Without any image there is nothing to ask about.
    if current_image is None:
        chat_history.append((user_input, "请先上传一张图片!"))
        return "", chat_history

    # Make sure the image is a PIL Image. The 'none' placeholder string is
    # checked explicitly so that array inputs are never compared to a string.
    is_placeholder = isinstance(current_image, str) and current_image == 'none'
    if not is_placeholder and not isinstance(current_image, Image.Image):
        current_image = Image.fromarray(current_image)

    prompt = f'\x16User: {user_input}\x17Assistant:'

    try:
        if first_question or current_state is None:
            # No usable state (fresh upload, or state dropped after an error):
            # feed the image again so the model does not lose its context.
            raw_response, new_state = model.generate(prompt, current_image, state=None)
        else:
            # Follow-up turn: continue from the saved state without the image.
            raw_response, new_state = model.generate(prompt, 'none', state=current_state)

        first_question = False
        current_state = new_state
        bot_response = _format_response(raw_response)
    except Exception as e:
        # UI boundary: surface the failure in the chat instead of crashing.
        bot_response = f"生成回复时出错: {str(e)}"
        current_state = None  # drop the state so the next turn starts clean

    chat_history.append((user_input, bot_response))

    # Empty string clears the input box; the history refreshes the chatbot.
    return "", chat_history
# Handle a change of the image component (upload, replacement or removal).
def update_image(image):
    """Store the new image and reset the conversation state.

    The image component's change event also fires when the image is removed,
    in which case ``image`` is ``None``; that case is reported as a cleared
    image rather than a successful upload.

    Args:
        image: The newly uploaded image, or ``None`` when the image was cleared.

    Returns:
        The status text to display for the image.
    """
    global current_image, current_state, first_question
    current_image = image
    current_state = None  # any previous model state belongs to the old image
    if image is None:
        first_question = False
        return "图片已清除,请上传新图片。"
    first_question = True  # the next question must be sent together with the image
    return "图片已上传成功!可以开始提问了。"

# Remove the current image
def clear_image():
    """Forget the stored image and the model state.

    Returns:
        A ``(None, message)`` pair: ``None`` for the image component and the
        text for the status component.
    """
    global current_image, current_state
    # Dropping the image also invalidates the state that was built from it.
    current_image = current_state = None
    status_message = "图片已清除,请上传新图片。"
    return None, status_message

# Clear both the conversation and the image
def clear_all():
    """Forget the stored image and the model state, and blank the chat.

    Returns:
        A 3-tuple of an empty chat history, an empty input string and the
        status text.
    """
    global current_image, current_state
    current_image = current_state = None
    empty_history, empty_input = [], ""
    return empty_history, empty_input, "图片和对话已清空,请重新上传图片。"

# Chat entry point that takes no image input
def chat_without_image_update(user_input, chat_history):
    """Delegate to ``chat_fn`` with no image, so the stored image is reused."""
    return chat_fn(user_input, chat_history, image=None)

# UI layout: image panel on the left, chat panel on the right, input row below.
with gr.Blocks(title="WORLD RWKV", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# WORLD RWKV")
    gr.Markdown("上传一张图片,然后可以进行多轮提问")
    
    with gr.Row():
        # Left column: image upload area
        with gr.Column(scale=2):
            image_input = gr.Image(
                type="pil", 
                label="上传图片",
                height=400
            )
            
            # Image status text and the delete-image button
            with gr.Row():
                image_status = gr.Textbox(
                    label="图片状态", 
                    value="请上传图片", 
                    interactive=False
                )
                clear_img_btn = gr.Button("删除图片")
        
        # Right column: conversation area
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                label="对话记录",
                bubble_full_width=False,
                height=500
            )
    
    # Control row
    with gr.Row():
        # Question input box
        user_input = gr.Textbox(
            placeholder="请输入问题...",
            scale=7,
            container=False,
            label="问题输入"
        )
        
        # Action buttons
        with gr.Column(scale=1):
            submit_btn = gr.Button("发送", variant="primary")
            clear_btn = gr.Button("清空所有")

    # Event wiring
    # Image change: pass the new image to update_image and show its status text
    image_input.change(
        fn=update_image,
        inputs=[image_input],
        outputs=[image_status]
    )
    
    # Delete-image button: the lambda returns (image value, status text),
    # in the same order as the outputs list
    clear_img_btn.click(
        fn=lambda: (None, "图片已清除,请上传新图片。"),  # lambda returns values of the expected types directly
        inputs=None,
        outputs=[image_input, image_status]
    )
    
    # Send button: calls chat_fn with the question, the history and the image
    submit_btn.click(
        fn=chat_fn,
        inputs=[user_input, chatbot, image_input],
        outputs=[user_input, chatbot]
    )
    
    # Enter key in the input box: uses the wrapper that takes no image argument
    user_input.submit(
        fn=chat_without_image_update,
        inputs=[user_input, chatbot],
        outputs=[user_input, chatbot]
    )
    
    # Clear-all button: resets the chat, the input box, the status text and the image
    clear_btn.click(
        fn=lambda: ([], "", "图片和对话已清空,请重新上传图片。", None),  # one value per output component
        inputs=None,
        outputs=[chatbot, user_input, image_status, image_input],
        queue=False
    )

# Launch the app when the file is run as a script
if __name__ == "__main__":
    demo.launch()