import contextlib
import subprocess
import time
from typing import Iterator, Callable

import gradio as gr
import spaces
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
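
# Install FlashAttention 2 at runtime. Setting FLASH_ATTENTION_SKIP_CUDA_BUILD tells the
# package's setup to skip compiling the CUDA kernels at install time, a common workaround
# on Hugging Face Spaces where no GPU is visible during the build step.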
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
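
# Display name shown in the dropdown -> Hugging Face model repository ID.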
MODEL_NAMES = {
    "Qwen2.5-VL-7B-Instruct": "Qwen/Qwen2.5-VL-7B-Instruct",
    "Qwen2.5-VL-3B-Instruct": "Qwen/Qwen2.5-VL-3B-Instruct",
}


@contextlib.contextmanager
def measure_time() -> Iterator[Callable[[], float]]:
    """
    Context manager that measures the execution time (in seconds) of its code block.

    Usage:
        with measure_time() as timer:
            ...  # code to be timed
        print(f"Code took: {timer()} seconds")
    """
    start_time = end_time = time.perf_counter()
    # The yielded lambda closes over start_time/end_time, so calling it after the
    # block has exited returns the elapsed time (inside the block it returns 0.0).
    yield lambda: end_time - start_time
    end_time = time.perf_counter()
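

# spaces.GPU requests a ZeroGPU device for the call; duration=300 allows the function
# to hold the GPU for up to 300 seconds, giving long video inputs time to finish.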
@spaces.GPU(duration=300)
def run_inference(model_key, input_type, text, image, video, fps, system_prompt, add_vision_id):
    """
    Load the selected Qwen2.5-VL model and run inference on text, image, or video.
    """
    model_id = MODEL_NAMES[model_key]
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype="auto",
        device_map="auto",
    )
    processor = AutoProcessor.from_pretrained(model_id)
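
    # Text-only path: tokenize the raw prompt directly (no chat template), generate,
    # and return the decoded output; note the decoded text includes the echoed prompt.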
    if input_type == "text":
        inputs = processor(
            text=text,
            return_tensors="pt",
            padding=True,
        )
        inputs = inputs.to(model.device)
        outputs = model.generate(**inputs, max_new_tokens=512)
        return processor.batch_decode(outputs, skip_special_tokens=True)[0]
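
    # Image/video path: build a chat-style message whose user turn carries the media
    # entries followed by the text prompt.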
    content = []
    if input_type == "image" and image:
        for img_path in image:
            content.append({"type": "image", "image": img_path})
    elif input_type == "video" and video:
        # Pass the local video as a file:// URI, as in the Qwen2.5-VL examples.
        video_src = video if str(video).startswith("file://") else f"file://{video}"
        content.append({"type": "video", "video": video_src, "fps": fps})
    content.append({"type": "text", "text": text or ""})
    msg = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": content},
    ]
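
    # Render the chat template to a prompt string and let qwen_vl_utils collect the
    # image/video inputs plus any video-specific kwargs (e.g. the sampling fps).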
    text_prompt = processor.apply_chat_template(
        msg,
        tokenize=False,
        add_generation_prompt=True,
        add_vision_id=add_vision_id,
    )
    image_inputs, video_inputs, video_kwargs = process_vision_info(msg, return_video_kwargs=True)
    inputs = processor(
        text=[text_prompt],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
        **video_kwargs,
    )
    inputs = inputs.to(model.device)
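
    # Time only the generation step, so the reported figure excludes model loading
    # and preprocessing.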
    with measure_time() as timer:
        gen_ids = model.generate(**inputs, max_new_tokens=512)
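
    # Drop the prompt tokens from each sequence so only the newly generated text is decoded.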
    trimmed = [out_ids[len(inp_ids):] for inp_ids, out_ids in zip(inputs.input_ids, gen_ids)]
    result = processor.batch_decode(trimmed, skip_special_tokens=True)[0]

    gr.Info(f"Finished in {timer():.2f}s", title="Success", duration=5)
    return result
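

# Gradio UI: the visible input widgets are switched according to the selected input type.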
demo = gr.Blocks()
with demo:
    gr.Markdown("# Qwen2.5-VL Multimodal Demo")
    model_select = gr.Dropdown(list(MODEL_NAMES.keys()), label="Select Model")
    input_type = gr.Radio(["text", "image", "video"], label="Input Type")
    system_prompt_input = gr.Textbox(
        lines=2,
        placeholder="System prompt…",
        value="You are a helpful assistant.",
        label="System Prompt",
    )
    vision_id_checkbox = gr.Checkbox(
        label="Add vision ID",
        value=False,
    )
    text_input = gr.Textbox(
        lines=3,
        placeholder="Enter text ...",
        visible=True,
    )
    image_input = gr.File(
        file_count="multiple",
        file_types=["image"],
        label="Upload Images",
        visible=False,
    )
    video_input = gr.Video(visible=False)
    fps_input = gr.Number(
        value=2.0,
        label="FPS",
        visible=False,
    )
    output = gr.Textbox(label="Output")
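
    # Toggle which input widgets are visible when the input type changes: the text box
    # stays visible, images show for "image", and the video player plus FPS show for "video".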
    def update_inputs(choice):
        return (
            gr.update(visible=True),
            gr.update(visible=(choice == "image")),
            gr.update(visible=(choice == "video")),
            gr.update(visible=(choice == "video")),
        )

    input_type.change(update_inputs, input_type, [text_input, image_input, video_input, fps_input])

    run_btn = gr.Button("Generate")
    run_btn.click(
        run_inference,
        [
            model_select,
            input_type,
            text_input,
            image_input,
            video_input,
            fps_input,
            system_prompt_input,
            vision_id_checkbox,
        ],
        output,
    )


if __name__ == "__main__":
    demo.launch()