File size: 13,515 Bytes
5dbe551
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
import gradio as gr
from transformers.image_utils import load_image
from threading import Thread
import time
import torch
import spaces
import cv2
import numpy as np
from PIL import Image
import re
from transformers import (
    Qwen2VLForConditionalGeneration,
    AutoProcessor,
    TextIteratorStreamer,
)
from transformers import Qwen2_5_VLForConditionalGeneration

# ---------------------------
# Helper Functions
# ---------------------------
def progress_bar_html(label: str, primary_color: str = "#4B0082", secondary_color: str = "#9370DB") -> str:
    """
    Build the HTML/CSS snippet shown while a model is still generating.

    The snippet is a flex row holding ``label`` followed by a thin 110x5 px
    track (``secondary_color``) with a sliding fill (``primary_color``) driven
    by the ``loading`` keyframes declared in the trailing ``<style>`` element.
    ``label`` and both colors are interpolated verbatim (no escaping).
    """
    track_style = (
        f"width: 110px; height: 5px; background-color: {secondary_color}; "
        "border-radius: 2px; overflow: hidden;"
    )
    fill_style = (
        f"width: 100%; height: 100%; background-color: {primary_color}; "
        "animation: loading 1.5s linear infinite;"
    )
    markup_lines = [
        # The snippet starts with a newline and ends with an indented empty
        # line, hence the empty first entry and the four-space last entry.
        "",
        '<div style="display: flex; align-items: center;">',
        f'    <span style="margin-right: 10px; font-size: 14px;">{label}</span>',
        f'    <div style="{track_style}">',
        f'        <div style="{fill_style}"></div>',
        "    </div>",
        "</div>",
        "<style>",
        "@keyframes loading {",
        "    0% { transform: translateX(-100%); }",
        "    100% { transform: translateX(100%); }",
        "}",
        "</style>",
        "    ",
    ]
    return "\n".join(markup_lines)

def downsample_video(video_path, num_frames=10):
    """
    Downsample a video file by extracting evenly spaced frames.

    Args:
        video_path: Path of the video, passed straight to ``cv2.VideoCapture``.
        num_frames: Maximum number of frames to sample (default 10).

    Returns:
        A list of ``(PIL.Image, timestamp)`` tuples in playback order, where
        ``timestamp`` is the frame index divided by the reported FPS, rounded
        to two decimals. The list is empty when the capture reports no frames
        or no FPS, and may be shorter than ``num_frames`` when the clip has
        fewer frames or when individual frames fail to decode.
    """
    vidcap = cv2.VideoCapture(video_path)
    frames = []
    # try/finally so the capture handle is released even if decoding or the
    # colour conversion raises part-way through.
    try:
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        if total_frames <= 0 or fps <= 0 or num_frames <= 0:
            return frames
        # linspace with an integer dtype repeats indices when the clip has
        # fewer than num_frames frames; np.unique keeps each frame only once
        # (and leaves the indices in ascending order).
        frame_indices = np.unique(
            np.linspace(0, total_frames - 1, num_frames, dtype=int)
        )
        for raw_index in frame_indices:
            index = int(raw_index)  # plain int for OpenCV and for the timestamp
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
            success, image = vidcap.read()
            if not success:
                # Skip frames that cannot be decoded instead of aborting.
                continue
            # OpenCV decodes to BGR; PIL expects RGB.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            timestamp = round(index / fps, 2)
            frames.append((Image.fromarray(image), timestamp))
    finally:
        vidcap.release()
    return frames

# Lines containing one of these whole words are treated as prescription
# boilerplate. Word boundaries keep brand names that merely contain a keyword
# (e.g. "Namenda" contains "name") from being discarded.
_HEADER_WORDS = re.compile(
    r'\b(?:prescription|rx|patient|name|date|doctor|hospital|clinic|address)\b',
    re.IGNORECASE,
)
# Substrings that strongly suggest a line describes a medication.
_FORM_HINTS = re.compile(
    r'(tab|tablet|cap|capsule|mg|ml|injection|syrup|solution|suspension|ointment|cream|gel|patch|suppository|inhaler|drops)',
    re.IGNORECASE,
)
# A dosage such as "500mg", "0.5 ml", "2 TAB"; the name is whatever precedes it.
_DOSAGE = re.compile(r'\d+(?:\.\d+)?\s*(?:mg|ml|tab|cap)', re.IGNORECASE)
# A line starting with a capitalised word is taken as a possible brand/generic name.
_CAPITALIZED_WORD = re.compile(r'[A-Z][a-z]+')
_FIRST_DIGIT = re.compile(r'\d')


def extract_medicine_names(text: str) -> list:
    """
    Extract likely medicine names from OCR text of a prescription.

    Each line is examined independently:

    * lines shorter than 3 characters or containing a header keyword
      (patient, date, doctor, ...) as a whole word are skipped;
    * a line mentioning a dosage form or unit (tablet, mg, syrup, ...)
      contributes the text before its first dosage (``500mg``, ``0.5 ml``),
      or the whole line when no dosage is present;
    * otherwise a line starting with a capitalised word contributes the text
      before its first digit.

    Candidates must be longer than 2 characters after stripping.

    Args:
        text: Raw OCR output, one prescription line per text line.

    Returns:
        The candidate names with duplicates removed, in order of first
        appearance. Empty list for empty input.
    """
    if not text:
        return []

    medicines = []
    for line in text.split('\n'):
        clean_line = line.strip()

        # Skip very short lines, headers, or non-relevant text.
        if len(clean_line) < 3 or _HEADER_WORDS.search(clean_line):
            continue

        if _FORM_HINTS.search(clean_line):
            # Name is the part before the dosage; matching is case-insensitive
            # and decimal-aware so "500 MG" and "0.5mg" are cut correctly.
            candidate = _DOSAGE.split(clean_line, 1)[0].strip()
        elif _CAPITALIZED_WORD.match(clean_line):
            # No dosage form on the line: cut at the first digit, which drops
            # trailing frequency text such as "2 times daily".
            candidate = _FIRST_DIGIT.split(clean_line, 1)[0].strip()
        else:
            continue

        if len(candidate) > 2:
            medicines.append(candidate)

    # dict preserves insertion order, giving an order-preserving dedup in O(n).
    return list(dict.fromkeys(medicines))

# Model and Processor Setup
# Both checkpoints below are loaded at import time, moved to the "cuda" device
# and switched to eval mode; model_inference() then picks one processor/model
# pair per request. Importing this module therefore requires a CUDA device.
# Qwen2VL OCR (default branch)
QV_MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"  # [or] prithivMLmods/Qwen2-VL-OCR2-2B-Instruct
qwen_processor = AutoProcessor.from_pretrained(QV_MODEL_ID, trust_remote_code=True)
# Loaded in float16 (the RolmOCR model below uses bfloat16 instead).
qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
    QV_MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16
).to("cuda").eval()

# RolmOCR branch (@RolmOCR)
# Also used by the @Prescription / @Med mode for the OCR pass.
ROLMOCR_MODEL_ID = "reducto/RolmOCR" 
rolmocr_processor = AutoProcessor.from_pretrained(ROLMOCR_MODEL_ID, trust_remote_code=True)
# NOTE(review): trust_remote_code=True allows code shipped with the checkpoint
# repository to run — confirm both repositories are trusted/pinned.
rolmocr_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    ROLMOCR_MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
).to("cuda").eval()

# Inference helpers
_VIDEO_EXTENSIONS = (".mp4", ".avi", ".mov")


def _prepare_inputs(processor, messages, images):
    """Render the chat template and tokenize prompt plus images onto the GPU."""
    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    return processor(
        text=[prompt_full],
        images=images if images else None,  # text-only queries pass no images
        return_tensors="pt",
        padding=True,
    ).to("cuda")


def _start_generation(model, processor, inputs, max_new_tokens=1024):
    """Start ``model.generate`` on a background thread and return its streamer."""
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    return streamer


def _stream_text(streamer, delay=0.01):
    """Yield the accumulated output after each streamed chunk, pausing ``delay`` seconds."""
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        # Drop the end-of-turn marker in case it is emitted as plain text.
        buffer = buffer.replace("<|im_end|>", "")
        if delay:
            time.sleep(delay)
        yield buffer


# Main Inference Function
@spaces.GPU
def model_inference(input_dict, history):
    """
    Chat handler: route a multimodal message to one of three OCR modes.

    * ``@prescription`` / ``@med`` prefix: OCR the first uploaded image with
      RolmOCR, then list the medicine names found by
      ``extract_medicine_names`` followed by the full OCR text.
    * ``@rolmocr`` prefix: answer the remaining prompt with RolmOCR over the
      uploaded image(s), or over frames sampled from an uploaded video.
    * no prefix: answer the prompt with the default Qwen2VL OCR model.

    Args:
        input_dict: Message dict with a ``"text"`` string and an optional
            ``"files"`` list of file paths.
        history: Chat history supplied by the chat interface (unused).

    Yields:
        First a progress-bar HTML snippet, then either the growing response
        text (streaming modes) or the final formatted report (prescription
        mode). Validation problems are yielded as a single ``"Error: ..."``
        string.
    """
    # Tolerate a missing/None text or file list instead of crashing.
    text = (input_dict.get("text") or "").strip()
    files = input_dict.get("files") or []
    command = text.lower()

    # Check for prescription-specific command
    if command.startswith(("@prescription", "@med")):
        if not files:
            yield "Error: Please upload a prescription image to extract medicine names."
            return

        # Use RolmOCR for text extraction; only the first image is processed.
        images = [load_image(files[0])]
        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": images[0]},
                {"type": "text", "text": "Extract all text from this medical prescription image, focus on medicine names, dosages, and instructions."},
            ],
        }]
        inputs = _prepare_inputs(rolmocr_processor, messages, images)
        streamer = _start_generation(rolmocr_model, rolmocr_processor, inputs)
        yield progress_bar_html("Processing Prescription with Medicine Extractor")

        # The full OCR text is needed before extraction, so nothing is yielded
        # per chunk and no per-chunk delay is applied.
        ocr_text = ""
        for partial in _stream_text(streamer, delay=0):
            ocr_text = partial

        medicines = extract_medicine_names(ocr_text)

        # Format the results: numbered medicine list, then the raw OCR text.
        result = "## Extracted Medicine Names\n\n"
        if medicines:
            for i, med in enumerate(medicines, 1):
                result += f"{i}. {med}\n"
        else:
            result += "No medicine names detected in the prescription.\n\n"

        result += "\n\n## Full OCR Text\n\n```\n" + ocr_text + "\n```"
        yield result
        return

    # RolmOCR Inference (@RolmOCR)
    if command.startswith("@rolmocr"):
        # Remove the tag from the query.
        text_prompt = text[len("@rolmocr"):].strip()
        # A video is recognised by the extension of the first uploaded file.
        is_video = bool(files) and isinstance(files[0], str) and files[0].lower().endswith(_VIDEO_EXTENSIONS)
        if is_video:
            frames = downsample_video(files[0])
            if not frames:
                yield "Error: Could not extract frames from the video."
                return
            # Build the message: prompt followed by each frame with its timestamp.
            content_list = [{"type": "text", "text": text_prompt}]
            for image, timestamp in frames:
                content_list.append({"type": "text", "text": f"Frame {timestamp}:"})
                content_list.append({"type": "image", "image": image})
            messages = [{"role": "user", "content": content_list}]
            images = [image for image, _ in frames]
        else:
            # Assume image(s) or a text-only query.
            images = [load_image(image) for image in files]
            if text_prompt == "" and not images:
                yield "Error: Please input a text query and/or provide an image for the @RolmOCR feature."
                return
            messages = [{
                "role": "user",
                "content": [
                    *[{"type": "image", "image": image} for image in images],
                    {"type": "text", "text": text_prompt},
                ],
            }]
        inputs = _prepare_inputs(rolmocr_processor, messages, images)
        streamer = _start_generation(rolmocr_model, rolmocr_processor, inputs)
        # Progress bar uses the default colours of progress_bar_html.
        yield progress_bar_html("Processing with Qwen2.5VL (RolmOCR)")
        yield from _stream_text(streamer)
        return

    # Default Inference: Qwen2VL OCR (supports multiple images).
    images = [load_image(image) for image in files]

    if text == "" and not images:
        yield "Error: Please input a text query and optionally image(s)."
        return
    if text == "" and images:
        yield "Error: Please input a text query along with the image(s)."
        return

    messages = [{
        "role": "user",
        "content": [
            *[{"type": "image", "image": image} for image in images],
            {"type": "text", "text": text},
        ],
    }]
    inputs = _prepare_inputs(qwen_processor, messages, images)
    streamer = _start_generation(qwen_model, qwen_processor, inputs)
    yield progress_bar_html("Processing with Qwen2VL OCR")
    yield from _stream_text(streamer)

# Gradio Interface
# Example messages offered in the UI. Each entry wraps one multimodal message
# dict ("text" + "files") in the same shape model_inference() receives; the
# leading @tag selects the inference mode.
examples = [
    [{"text": "@Prescription Extract medicines from this prescription", "files": ["examples/prescription1.jpg"]}],
    [{"text": "@RolmOCR OCR the Text in the Image", "files": ["rolm/1.jpeg"]}],
    [{"text": "@RolmOCR Explain the Ad in Detail", "files": ["examples/videoplayback.mp4"]}],
    [{"text": "@RolmOCR OCR the Image", "files": ["rolm/3.jpeg"]}],
    [{"text": "Extract as JSON table from the table", "files": ["examples/4.jpg"]}],
]

# Custom stylesheet passed to the interface below.
# NOTE(review): nothing in this file references the .prescription-header
# class — confirm whether it is still needed.
css = """
.gradio-container {
    font-family: 'Roboto', sans-serif;
}
.prescription-header {
    background-color: #4B0082;
    color: white;
    padding: 10px;
    border-radius: 5px;
    margin-bottom: 10px;
}
"""

# Markdown shown as the interface description; lists the three modes handled
# by model_inference().
description = """
# **Multimodal OCR with Medicine Extraction**

## Modes:
- **@Prescription** - Upload a prescription image to extract medicine names
- **@RolmOCR** - Use RolmOCR for general text extraction
- **Default** - Use Qwen2VL OCR for general purposes

Upload your medical prescription images and get the medicine names extracted automatically!
"""

# Chat UI wired to model_inference; the textbox accepts text plus multiple
# image/video uploads.
# NOTE(review): the default branch of model_inference rejects an empty text
# query, so "leave blank" in the placeholder presumably means omitting the
# @tag — consider rewording.
demo = gr.ChatInterface(
    fn=model_inference,
    description=description,
    examples=examples,
    textbox=gr.MultimodalTextbox(
        label="Query Input", 
        file_types=["image", "video"], 
        file_count="multiple", 
        placeholder="Use @Prescription to extract medicines, @RolmOCR for RolmOCR, or leave blank for default Qwen2VL OCR"
    ),
    stop_btn="Stop Generation",
    multimodal=True,
    cache_examples=False,
    css=css
)

# Starts the app as soon as the module is executed (no __main__ guard).
demo.launch(debug=True)