|
import gradio as gr |
|
import torch |
|
from PIL import Image |
|
from transformers import AutoProcessor, AutoModelForVision2Seq |
|
|
|
|
|
# Hub repository id of the vision-language checkpoint served by this demo.
# NOTE(review): confirm this repo id exists on the Hugging Face Hub; published
# Qwen VL checkpoints use names such as "Qwen/Qwen-VL-Chat" — TODO verify.
model_id = "Qwen/Qwen1.5-VL-Chat"

# trust_remote_code=True executes Python code shipped in the model repository;
# only keep it for repositories you trust.
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    trust_remote_code=True,
    device_map="auto",
    # bfloat16 when a CUDA device is visible, float32 otherwise.
    torch_dtype=(
        torch.bfloat16 if torch.cuda.is_available() else torch.float32
    ),
    # NOTE(review): presumably only meaningful for GPTQ-quantized weights; it is
    # forwarded as an extra from_pretrained kwarg — confirm it is still needed.
    disable_exllama=True,
)
# Switch to evaluation mode (disables dropout etc.); eval() returns the model itself.
model.eval()
|
|
|
|
|
def chat(image, question):
    """Answer a free-text question about an uploaded image.

    Args:
        image: The uploaded image, or ``None`` when nothing was uploaded
            (the UI wires this to ``gr.Image(type="pil")``).
        question: The question text. ``None``, empty, or whitespace-only
            text is treated as missing.

    Returns:
        The decoded model answer with surrounding whitespace removed, or a
        prompt asking the user to upload an image and enter a question.
    """
    # Guard against None as well as blank text; Gradio may pass None.
    if image is None or not question or not question.strip():
        return "請上傳圖片並輸入問題。"

    # NOTE(review): many chat-tuned VLMs expect a chat template / image
    # placeholder token in the text; the raw question is passed as-is here.
    inputs = processor(text=question, images=image, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=512)

    # Decoder-only models return prompt tokens followed by the completion.
    # Strip the prompt only when the output really starts with it, so models
    # that return just the generated tokens are left untouched.
    prompt_ids = inputs.get("input_ids")
    if prompt_ids is not None:
        prompt_len = prompt_ids.shape[-1]
        if torch.equal(outputs[:, :prompt_len], prompt_ids):
            outputs = outputs[:, prompt_len:]

    answer = processor.batch_decode(outputs, skip_special_tokens=True)[0]
    return answer.strip()
|
|
|
|
|
# Web UI: image + question inputs on the left, the model's answer on the right.
with gr.Blocks(title="Qwen1.5-VL 圖文問答 Demo") as demo:
    gr.Markdown("## 🧠 Qwen1.5-VL 圖文問答 Demo")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="📷 請上傳圖片")
            question_input = gr.Textbox(label="請輸入問題", placeholder="例如:這是什麼地方?")
            submit_btn = gr.Button("Submit", variant="primary")
            clear_btn = gr.Button("Clear")
        with gr.Column():
            answer_output = gr.Textbox(label="💬 答案", lines=8)

    submit_btn.click(fn=chat, inputs=[image_input, question_input], outputs=answer_output)

    # Reset all three components. gr.Image treats a str value as a file
    # path/URL, so the image must be cleared with None rather than "".
    clear_btn.click(lambda: (None, "", ""), outputs=[image_input, question_input, answer_output])

# share=True additionally requests a public share link for the app.
demo.launch(share=True)