File size: 5,226 Bytes
f17ef4c
2e3ddd8
f17ef4c
 
 
d317ae4
f0c7145
2e3ddd8
 
f0c7145
2e3ddd8
cece0ec
2e3ddd8
 
 
d317ae4
2e3ddd8
 
 
 
 
 
 
f17ef4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e3ddd8
b3d5d95
2e3ddd8
 
 
 
 
 
 
c5c055b
2e3ddd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ed6c98
 
2e3ddd8
 
 
 
 
b3d5d95
 
 
 
2e3ddd8
 
 
b3d5d95
 
 
 
2e3ddd8
 
 
 
 
 
 
 
 
 
 
 
f17ef4c
 
 
 
 
 
 
 
2e3ddd8
 
 
 
 
 
 
 
b3d5d95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd47744
b3d5d95
 
 
 
 
2e3ddd8
 
 
 
 
b3d5d95
2e3ddd8
 
 
 
 
 
 
 
 
b3d5d95
 
 
 
 
 
 
 
 
 
2e3ddd8
 
 
 
 
 
f0c7145
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import contextlib
import os
import subprocess
import sys
import time
from typing import Iterator, Callable

import gradio as gr
import spaces
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

# Best-effort install of flash-attn at startup; failures are deliberately not fatal
# (no check=True), matching the original behaviour.
# - The child environment is the current one plus the skip-CUDA-build flag. Passing
#   only the flag would drop PATH/HOME/CUDA variables for the install.
# - `sys.executable -m pip` guarantees the package lands in the interpreter running
#   this app, and the argument list avoids going through a shell.
# NOTE(review): run_inference's from_pretrained call does not pass
# attn_implementation="flash_attention_2", so this install appears unused by the
# code visible here — confirm before keeping it.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)

# UI display name -> Hugging Face Hub model ID. Every supported checkpoint lives under
# the "Qwen" organisation with the same name, so the ID is derived from the display name.
# Insertion order is preserved and determines the dropdown order.
MODEL_NAMES = {
    display_name: f"Qwen/{display_name}"
    for display_name in ("Qwen2.5-VL-7B-Instruct", "Qwen2.5-VL-3B-Instruct")
}


@contextlib.contextmanager
def measure_time() -> Iterator[Callable[[], float]]:
    """
    A context manager for measuring execution time (in seconds) within its code block.

    usage:
        with code_timer() as timer:
            # Code snippet to be timed
        print(f"Code took: {timer()} seconds")
    """
    start_time = end_time = time.perf_counter()
    yield lambda: end_time - start_time
    end_time = time.perf_counter()


@spaces.GPU(duration=300)
def run_inference(model_key, input_type, text, image, video, fps, system_prompt, add_vision_id):
    """
    Load the selected Qwen2.5-VL model and run inference on text, image, or video.
    """
    model_id = MODEL_NAMES[model_key]
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype="auto",
        device_map="auto",
    )
    processor = AutoProcessor.from_pretrained(model_id)

    # Text-only inference
    if input_type == "text":
        inputs = processor(
            text=text,
            return_tensors="pt",
            padding=True
        )
        inputs = inputs.to(model.device)
        outputs = model.generate(**inputs, max_new_tokens=512)
        return processor.batch_decode(outputs, skip_special_tokens=True)[0]

    # Multimodal inference (image or video)
    content = []
    if input_type == "image" and image:
        for img_path in image:
            content.append({"type": "image", "image": img_path})
    elif input_type == "video" and video:
        # Ensure file URI for local files
        video_src = video if str(video).startswith("file://") else f"file://{video}"
        content.append({"type": "video", "video": video_src, "fps": fps})
    content.append({"type": "text", "text": text or ""})
    msg = [
        {"role": "system", "content": system_prompt},
        {"role": "user",   "content": content}
    ]

    # Prepare inputs for model with video kwargs
    text_prompt = processor.apply_chat_template(
        msg,
        tokenize=False,
        add_generation_prompt=True,
        add_vision_id=add_vision_id
    )
    image_inputs, video_inputs, video_kwargs = process_vision_info(msg, return_video_kwargs=True)
    inputs = processor(
        text=[text_prompt],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
        **video_kwargs
    )
    inputs = inputs.to(model.device)

    with measure_time() as timer:
        gen_ids = model.generate(**inputs, max_new_tokens=512)
        # Trim the prompt tokens
        trimmed = [out_ids[len(inp_ids):] for inp_ids, out_ids in zip(inputs.input_ids, gen_ids)]
        result = processor.batch_decode(trimmed, skip_special_tokens=True)[0]

    gr.Info(f"Finished in {timer():.2f}s", title="Success", duration=5)  # green-style info toast :contentReference[oaicite:0]{index=0}
    return result


# Build Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Qwen2.5-VL Multimodal Demo")
    # Pre-select the first model and the "text" input type, so an immediate
    # "Generate" click does not send None for either value.
    model_select = gr.Dropdown(
        list(MODEL_NAMES.keys()),
        value=next(iter(MODEL_NAMES), None),
        label="Select Model",
    )
    input_type = gr.Radio(["text", "image", "video"], value="text", label="Input Type")
    system_prompt_input = gr.Textbox(
        lines=2,
        placeholder="System prompt…",
        value="You are a helpful assistant.",
        label="System Prompt",
    )
    vision_id_checkbox = gr.Checkbox(
        label="Add vision ID",
        value=False,
    )
    # Initial visibility below matches the default "text" selection.
    text_input = gr.Textbox(
        lines=3,
        placeholder="Enter text ...",
        visible=True,
    )
    image_input = gr.File(
        file_count="multiple",
        file_types=["image"],
        label="Upload Images",
        visible=False,
    )
    video_input = gr.Video(visible=False)
    fps_input = gr.Number(
        value=2.0,
        label="FPS",
        visible=False,
    )
    output = gr.Textbox(label="Output")

    def update_inputs(choice):
        """
        Show or hide input components for the selected input type.

        Returns visibility updates in this order:
          * the text box, which is always shown because every type carries a prompt;
          * the image uploader, shown only for "image";
          * the video player and its FPS field, shown only for "video".
        """
        return (
            gr.update(visible=True),
            gr.update(visible=(choice == "image")),
            gr.update(visible=(choice == "video")),
            gr.update(visible=(choice == "video")),
        )

    input_type.change(update_inputs, input_type, [text_input, image_input, video_input, fps_input])
    run_btn = gr.Button("Generate")
    # Input order must match run_inference's positional parameters.
    run_btn.click(
        run_inference,
        [
            model_select,
            input_type,
            text_input,
            image_input,
            video_input,
            fps_input,
            system_prompt_input,
            vision_id_checkbox,
        ],
        output,
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()