import gradio as gr
from transformers.image_utils import load_image
from threading import Thread
import time
import torch
import spaces
import cv2
import numpy as np
from PIL import Image
from transformers import (
    Qwen2VLForConditionalGeneration,
    AutoProcessor,
    TextIteratorStreamer,
)
from transformers import Qwen2_5_VLForConditionalGeneration
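# Gradio demo with two OCR branches: the default Qwen2-VL OCR model, and
# reducto/RolmOCR (Qwen2.5-VL based), selected by prefixing a query with "@rolmocr".
# Video inputs are downsampled to 10 evenly spaced frames before inference.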

# ---------------------------
# Helper Functions
# ---------------------------
def progress_bar_html(label: str, primary_color: str = "#4B0082", secondary_color: str = "#9370DB") -> str:
    """
    Returns an HTML snippet for a thin animated progress bar with a label.
    Colors can be customized; default colors are used for Qwen2VL/Aya‑Vision.
    """
    return f'''
<div style="display: flex; align-items: center;">
    <span style="margin-right: 10px; font-size: 14px;">{label}</span>
    <div style="width: 110px; height: 5px; background-color: {secondary_color}; border-radius: 2px; overflow: hidden;">
        <div style="width: 100%; height: 100%; background-color: {primary_color}; animation: loading 1.5s linear infinite;"></div>
    </div>
</div>
<style>
@keyframes loading {{
    0% {{ transform: translateX(-100%); }}
    100% {{ transform: translateX(100%); }}
}}
</style>
    '''
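# The HTML bar above is yielded by model_inference as a lightweight "processing"
# indicator and is replaced as soon as the first streamed tokens arrive.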

def downsample_video(video_path):
    """
    Downsamples a video file by extracting 10 evenly spaced frames.
    Returns a list of tuples (PIL.Image, timestamp).
    """
    vidcap = cv2.VideoCapture(video_path)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    frames = []
    if total_frames <= 0 or fps <= 0:
        vidcap.release()
        return frames
    # Determine 10 evenly spaced frame indices.
    frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
    for i in frame_indices:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
        success, image = vidcap.read()
        if success:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(image)
            timestamp = round(i / fps, 2)
            frames.append((pil_image, timestamp))
    vidcap.release()
    return frames

# ---------------------------
# Model and Processor Setup
# ---------------------------
# Qwen2VL OCR (default branch)
QV_MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"  # alternative: prithivMLmods/Qwen2-VL-OCR2-2B-Instruct
qwen_processor = AutoProcessor.from_pretrained(QV_MODEL_ID, trust_remote_code=True)
qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
    QV_MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16
).to("cuda").eval()

# RolmOCR branch (@RolmOCR)
ROLMOCR_MODEL_ID = "reducto/RolmOCR"
rolmocr_processor = AutoProcessor.from_pretrained(ROLMOCR_MODEL_ID, trust_remote_code=True)
rolmocr_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    ROLMOCR_MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
).to("cuda").eval()

# ---------------------------
# Main Inference Function
# ---------------------------
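# model_inference streams partial results through TextIteratorStreamer so the chat UI
# updates token by token; the @spaces.GPU decorator requests GPU execution when the
# app runs on Hugging Face Spaces (e.g. ZeroGPU hardware).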
@spaces.GPU
def model_inference(input_dict, history):
    text = input_dict["text"].strip()
    files = input_dict.get("files", [])

    # RolmOCR Inference (@RolmOCR)
    if text.lower().startswith("@rolmocr"):
        # Remove the tag from the query.
        text_prompt = text[len("@rolmocr"):].strip()
        # Check if a video is provided for inference.
        if files and isinstance(files[0], str) and files[0].lower().endswith((".mp4", ".avi", ".mov")):
            video_path = files[0]
            frames = downsample_video(video_path)
            if not frames:
                yield "Error: Could not extract frames from the video."
                return
            # Build the message: prompt followed by each frame with its timestamp.
            content_list = [{"type": "text", "text": text_prompt}]
            for image, timestamp in frames:
                content_list.append({"type": "text", "text": f"Frame {timestamp}:"})
                content_list.append({"type": "image", "image": image})
            messages = [{"role": "user", "content": content_list}]
            # For video, extract images only.
            video_images = [image for image, _ in frames]
            prompt_full = rolmocr_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            inputs = rolmocr_processor(
                text=[prompt_full],
                images=video_images,
                return_tensors="pt",
                padding=True,
            ).to("cuda")
        else:
            # Assume image(s) or text query.
            if len(files) > 1:
                images = [load_image(image) for image in files]
            elif len(files) == 1:
                images = [load_image(files[0])]
            else:
                images = []
            if text_prompt == "" and not images:
                yield "Error: Please input a text query and/or provide an image for the @RolmOCR feature."
                return
            messages = [{
                "role": "user",
                "content": [
                    *[{"type": "image", "image": image} for image in images],
                    {"type": "text", "text": text_prompt},
                ],
            }]
            prompt_full = rolmocr_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            inputs = rolmocr_processor(
                text=[prompt_full],
                images=images if images else None,
                return_tensors="pt",
                padding=True,
            ).to("cuda")
        streamer = TextIteratorStreamer(rolmocr_processor, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
        thread = Thread(target=rolmocr_model.generate, kwargs=generation_kwargs)
        thread.start()
        buffer = ""
        # Show the default purple-themed progress bar while RolmOCR streams.
        yield progress_bar_html("Processing with Qwen2.5VL (RolmOCR)")
        for new_text in streamer:
            buffer += new_text
            buffer = buffer.replace("<|im_end|>", "")
            time.sleep(0.01)
            yield buffer
        return

    # Default Inference: Qwen2VL OCR
    # Process files: support multiple images.
    if len(files) > 1:
        images = [load_image(image) for image in files]
    elif len(files) == 1:
        images = [load_image(files[0])]
    else:
        images = []
    
    if text == "" and not images:
        yield "Error: Please input a text query and optionally image(s)."
        return
    if text == "" and images:
        yield "Error: Please input a text query along with the image(s)."
        return

    messages = [{
        "role": "user",
        "content": [
            *[{"type": "image", "image": image} for image in images],
            {"type": "text", "text": text},
        ],
    }]
    prompt_full = qwen_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = qwen_processor(
        text=[prompt_full],
        images=images if images else None,
        return_tensors="pt",
        padding=True,
    ).to("cuda")
    streamer = TextIteratorStreamer(qwen_processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
    thread = Thread(target=qwen_model.generate, kwargs=generation_kwargs)
    thread.start()
    buffer = ""
    yield progress_bar_html("Processing with Qwen2VL OCR")
    for new_text in streamer:
        buffer += new_text
        buffer = buffer.replace("<|im_end|>", "")
        time.sleep(0.01)
        yield buffer

# ---------------------------
# Gradio Interface
# ---------------------------
examples = [
    [{"text": "@RolmOCR OCR the Text in the Image", "files": ["rolm/1.jpeg"]}],
    [{"text": "@RolmOCR Explain the Ad in Detail", "files": ["examples/videoplayback.mp4"]}],
    [{"text": "@RolmOCR OCR the Image", "files": ["rolm/3.jpeg"]}],
    [{"text": "Extract as JSON table from the table", "files": ["examples/4.jpg"]}],
]
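# The example assets (rolm/*.jpeg, examples/4.jpg, examples/videoplayback.mp4) are
# assumed to ship alongside this script in the Space repository.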

demo = gr.ChatInterface(
    fn=model_inference,
    description="# **Multimodal OCR `@RolmOCR and Default Qwen2VL OCR`**",
    examples=examples,
    textbox=gr.MultimodalTextbox(
        label="Query Input", 
        file_types=["image", "video"], 
        file_count="multiple", 
        placeholder="Prefix your query with @RolmOCR to use RolmOCR; otherwise the default Qwen2VL OCR model is used"
    ),
    stop_btn="Stop Generation",
    multimodal=True,
    cache_examples=False,
)

demo.launch(debug=True)