shukdevdatta123 commited on
Commit
5dbe551
·
verified ·
1 Parent(s): ea794b6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +344 -0
app.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers.image_utils import load_image
3
+ from threading import Thread
4
+ import time
5
+ import torch
6
+ import spaces
7
+ import cv2
8
+ import numpy as np
9
+ from PIL import Image
10
+ import re
11
+ from transformers import (
12
+ Qwen2VLForConditionalGeneration,
13
+ AutoProcessor,
14
+ TextIteratorStreamer,
15
+ )
16
+ from transformers import Qwen2_5_VLForConditionalGeneration
17
+
18
+ # ---------------------------
19
+ # Helper Functions
20
+ # ---------------------------
21
+ def progress_bar_html(label: str, primary_color: str = "#4B0082", secondary_color: str = "#9370DB") -> str:
22
+ """
23
+ Returns an HTML snippet for a thin animated progress bar with a label.
24
+ Colors can be customized; default colors are used for Qwen2VL/Aya‑Vision.
25
+ """
26
+ return f'''
27
+ <div style="display: flex; align-items: center;">
28
+ <span style="margin-right: 10px; font-size: 14px;">{label}</span>
29
+ <div style="width: 110px; height: 5px; background-color: {secondary_color}; border-radius: 2px; overflow: hidden;">
30
+ <div style="width: 100%; height: 100%; background-color: {primary_color}; animation: loading 1.5s linear infinite;"></div>
31
+ </div>
32
+ </div>
33
+ <style>
34
+ @keyframes loading {{
35
+ 0% {{ transform: translateX(-100%); }}
36
+ 100% {{ transform: translateX(100%); }}
37
+ }}
38
+ </style>
39
+ '''
40
+
41
+ def downsample_video(video_path):
42
+ """
43
+ Downsamples a video file by extracting 10 evenly spaced frames.
44
+ Returns a list of tuples (PIL.Image, timestamp).
45
+ """
46
+ vidcap = cv2.VideoCapture(video_path)
47
+ total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
48
+ fps = vidcap.get(cv2.CAP_PROP_FPS)
49
+ frames = []
50
+ if total_frames <= 0 or fps <= 0:
51
+ vidcap.release()
52
+ return frames
53
+ # Determine 10 evenly spaced frame indices.
54
+ frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
55
+ for i in frame_indices:
56
+ vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
57
+ success, image = vidcap.read()
58
+ if success:
59
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
60
+ pil_image = Image.fromarray(image)
61
+ timestamp = round(i / fps, 2)
62
+ frames.append((pil_image, timestamp))
63
+ vidcap.release()
64
+ return frames
65
+
66
+ def extract_medicine_names(text):
67
+ """
68
+ Extracts medicine names from OCR text output.
69
+ Uses a combination of pattern matching and formatting to identify medications.
70
+ Returns a formatted list of medicines found.
71
+ """
72
+ # Common medicine patterns (extended to catch more formats)
73
+ lines = text.split('\n')
74
+ medicines = []
75
+
76
+ # Look for patterns typical in prescriptions
77
+ for line in lines:
78
+ # Clean and standardize the line
79
+ clean_line = line.strip()
80
+
81
+ # Skip very short lines, headers, or non-relevant text
82
+ if len(clean_line) < 3 or re.search(r'(prescription|rx|patient|name|date|doctor|hospital|clinic|address)', clean_line.lower()):
83
+ continue
84
+
85
+ # Medicine names often appear at the beginning of lines, with dosage info following
86
+ # Look for tablet/capsule/mg indicators - strong indicators of medication
87
+ if re.search(r'(tab|tablet|cap|capsule|mg|ml|injection|syrup|solution|suspension|ointment|cream|gel|patch|suppository|inhaler|drops)', clean_line.lower()):
88
+ # Extract the likely medicine name - the part before the dosage/form or the entire line if it's short
89
+ medicine_match = re.split(r'(\d+\s*mg|\d+\s*ml|\d+\s*tab|\d+\s*cap)', clean_line, 1)[0].strip()
90
+ if medicine_match and len(medicine_match) > 2:
91
+ medicines.append(medicine_match)
92
+
93
+ # Check for brand names or generic medication patterns
94
+ elif re.match(r'^[A-Z][a-z]+\s*[A-Z0-9]', clean_line) or re.match(r'^[A-Z][a-z]+', clean_line):
95
+ # Likely a medicine name starting with a capital letter
96
+ medicine_parts = re.split(r'(\d+|\s+\d+\s*times|\s+\d+\s*times\s+daily)', clean_line, 1)
97
+ if medicine_parts and len(medicine_parts[0]) > 2:
98
+ medicines.append(medicine_parts[0].strip())
99
+
100
+ # Remove duplicates while preserving order
101
+ unique_medicines = []
102
+ for med in medicines:
103
+ if med not in unique_medicines:
104
+ unique_medicines.append(med)
105
+
106
+ return unique_medicines
107
+
108
+ # Model and Processor Setup
109
+ # Qwen2VL OCR (default branch)
110
+ QV_MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct" # [or] prithivMLmods/Qwen2-VL-OCR2-2B-Instruct
111
+ qwen_processor = AutoProcessor.from_pretrained(QV_MODEL_ID, trust_remote_code=True)
112
+ qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
113
+ QV_MODEL_ID,
114
+ trust_remote_code=True,
115
+ torch_dtype=torch.float16
116
+ ).to("cuda").eval()
117
+
118
+ # RolmOCR branch (@RolmOCR)
119
+ ROLMOCR_MODEL_ID = "reducto/RolmOCR"
120
+ rolmocr_processor = AutoProcessor.from_pretrained(ROLMOCR_MODEL_ID, trust_remote_code=True)
121
+ rolmocr_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
122
+ ROLMOCR_MODEL_ID,
123
+ trust_remote_code=True,
124
+ torch_dtype=torch.bfloat16
125
+ ).to("cuda").eval()
126
+
127
+ # Main Inference Function
128
+ @spaces.GPU
129
+ def model_inference(input_dict, history):
130
+ text = input_dict["text"].strip()
131
+ files = input_dict.get("files", [])
132
+
133
+ # Check for prescription-specific command
134
+ if text.lower().startswith("@prescription") or text.lower().startswith("@med"):
135
+ # Specific mode for medicine extraction
136
+ if not files:
137
+ yield "Error: Please upload a prescription image to extract medicine names."
138
+ return
139
+
140
+ # Use RolmOCR for better text extraction from prescriptions
141
+ images = [load_image(image) for image in files[:1]] # Taking just the first image for processing
142
+
143
+ messages = [{
144
+ "role": "user",
145
+ "content": [
146
+ {"type": "image", "image": images[0]},
147
+ {"type": "text", "text": "Extract all text from this medical prescription image, focus on medicine names, dosages, and instructions."},
148
+ ],
149
+ }]
150
+
151
+ prompt_full = rolmocr_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
152
+ inputs = rolmocr_processor(
153
+ text=[prompt_full],
154
+ images=images,
155
+ return_tensors="pt",
156
+ padding=True,
157
+ ).to("cuda")
158
+
159
+ # First, get the complete OCR text
160
+ streamer = TextIteratorStreamer(rolmocr_processor, skip_prompt=True, skip_special_tokens=True)
161
+ generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
162
+ thread = Thread(target=rolmocr_model.generate, kwargs=generation_kwargs)
163
+ thread.start()
164
+
165
+ ocr_text = ""
166
+ yield progress_bar_html("Processing Prescription with Medicine Extractor")
167
+
168
+ for new_text in streamer:
169
+ ocr_text += new_text
170
+ ocr_text = ocr_text.replace("<|im_end|>", "")
171
+ time.sleep(0.01)
172
+
173
+ # After getting full OCR text, extract medicine names
174
+ medicines = extract_medicine_names(ocr_text)
175
+
176
+ # Format the results nicely
177
+ result = "## Extracted Medicine Names\n\n"
178
+ if medicines:
179
+ for i, med in enumerate(medicines, 1):
180
+ result += f"{i}. {med}\n"
181
+ else:
182
+ result += "No medicine names detected in the prescription.\n\n"
183
+
184
+ result += "\n\n## Full OCR Text\n\n```\n" + ocr_text + "\n```"
185
+ yield result
186
+ return
187
+
188
+ # RolmOCR Inference (@RolmOCR)
189
+ if text.lower().startswith("@rolmocr"):
190
+ # Remove the tag from the query.
191
+ text_prompt = text[len("@rolmocr"):].strip()
192
+ # Check if a video is provided for inference.
193
+ if files and isinstance(files[0], str) and files[0].lower().endswith((".mp4", ".avi", ".mov")):
194
+ video_path = files[0]
195
+ frames = downsample_video(video_path)
196
+ if not frames:
197
+ yield "Error: Could not extract frames from the video."
198
+ return
199
+ # Build the message: prompt followed by each frame with its timestamp.
200
+ content_list = [{"type": "text", "text": text_prompt}]
201
+ for image, timestamp in frames:
202
+ content_list.append({"type": "text", "text": f"Frame {timestamp}:"})
203
+ content_list.append({"type": "image", "image": image})
204
+ messages = [{"role": "user", "content": content_list}]
205
+ # For video, extract images only.
206
+ video_images = [image for image, _ in frames]
207
+ prompt_full = rolmocr_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
208
+ inputs = rolmocr_processor(
209
+ text=[prompt_full],
210
+ images=video_images,
211
+ return_tensors="pt",
212
+ padding=True,
213
+ ).to("cuda")
214
+ else:
215
+ # Assume image(s) or text query.
216
+ if len(files) > 1:
217
+ images = [load_image(image) for image in files]
218
+ elif len(files) == 1:
219
+ images = [load_image(files[0])]
220
+ else:
221
+ images = []
222
+ if text_prompt == "" and not images:
223
+ yield "Error: Please input a text query and/or provide an image for the @RolmOCR feature."
224
+ return
225
+ messages = [{
226
+ "role": "user",
227
+ "content": [
228
+ *[{"type": "image", "image": image} for image in images],
229
+ {"type": "text", "text": text_prompt},
230
+ ],
231
+ }]
232
+ prompt_full = rolmocr_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
233
+ inputs = rolmocr_processor(
234
+ text=[prompt_full],
235
+ images=images if images else None,
236
+ return_tensors="pt",
237
+ padding=True,
238
+ ).to("cuda")
239
+ streamer = TextIteratorStreamer(rolmocr_processor, skip_prompt=True, skip_special_tokens=True)
240
+ generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
241
+ thread = Thread(target=rolmocr_model.generate, kwargs=generation_kwargs)
242
+ thread.start()
243
+ buffer = ""
244
+ # Use a different color scheme for RolmOCR (purple-themed).
245
+ yield progress_bar_html("Processing with Qwen2.5VL (RolmOCR)")
246
+ for new_text in streamer:
247
+ buffer += new_text
248
+ buffer = buffer.replace("<|im_end|>", "")
249
+ time.sleep(0.01)
250
+ yield buffer
251
+ return
252
+
253
+ # Default Inference: Qwen2VL OCR
254
+ # Process files: support multiple images.
255
+ if len(files) > 1:
256
+ images = [load_image(image) for image in files]
257
+ elif len(files) == 1:
258
+ images = [load_image(files[0])]
259
+ else:
260
+ images = []
261
+
262
+ if text == "" and not images:
263
+ yield "Error: Please input a text query and optionally image(s)."
264
+ return
265
+ if text == "" and images:
266
+ yield "Error: Please input a text query along with the image(s)."
267
+ return
268
+
269
+ messages = [{
270
+ "role": "user",
271
+ "content": [
272
+ *[{"type": "image", "image": image} for image in images],
273
+ {"type": "text", "text": text},
274
+ ],
275
+ }]
276
+ prompt_full = qwen_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
277
+ inputs = qwen_processor(
278
+ text=[prompt_full],
279
+ images=images if images else None,
280
+ return_tensors="pt",
281
+ padding=True,
282
+ ).to("cuda")
283
+ streamer = TextIteratorStreamer(qwen_processor, skip_prompt=True, skip_special_tokens=True)
284
+ generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
285
+ thread = Thread(target=qwen_model.generate, kwargs=generation_kwargs)
286
+ thread.start()
287
+ buffer = ""
288
+ yield progress_bar_html("Processing with Qwen2VL OCR")
289
+ for new_text in streamer:
290
+ buffer += new_text
291
+ buffer = buffer.replace("<|im_end|>", "")
292
+ time.sleep(0.01)
293
+ yield buffer
294
+
295
+ # Gradio Interface
296
+ examples = [
297
+ [{"text": "@Prescription Extract medicines from this prescription", "files": ["examples/prescription1.jpg"]}],
298
+ [{"text": "@RolmOCR OCR the Text in the Image", "files": ["rolm/1.jpeg"]}],
299
+ [{"text": "@RolmOCR Explain the Ad in Detail", "files": ["examples/videoplayback.mp4"]}],
300
+ [{"text": "@RolmOCR OCR the Image", "files": ["rolm/3.jpeg"]}],
301
+ [{"text": "Extract as JSON table from the table", "files": ["examples/4.jpg"]}],
302
+ ]
303
+
304
+ css = """
305
+ .gradio-container {
306
+ font-family: 'Roboto', sans-serif;
307
+ }
308
+ .prescription-header {
309
+ background-color: #4B0082;
310
+ color: white;
311
+ padding: 10px;
312
+ border-radius: 5px;
313
+ margin-bottom: 10px;
314
+ }
315
+ """
316
+
317
+ description = """
318
+ # **Multimodal OCR with Medicine Extraction**
319
+
320
+ ## Modes:
321
+ - **@Prescription** - Upload a prescription image to extract medicine names
322
+ - **@RolmOCR** - Use RolmOCR for general text extraction
323
+ - **Default** - Use Qwen2VL OCR for general purposes
324
+
325
+ Upload your medical prescription images and get the medicine names extracted automatically!
326
+ """
327
+
328
+ demo = gr.ChatInterface(
329
+ fn=model_inference,
330
+ description=description,
331
+ examples=examples,
332
+ textbox=gr.MultimodalTextbox(
333
+ label="Query Input",
334
+ file_types=["image", "video"],
335
+ file_count="multiple",
336
+ placeholder="Use @Prescription to extract medicines, @RolmOCR for RolmOCR, or leave blank for default Qwen2VL OCR"
337
+ ),
338
+ stop_btn="Stop Generation",
339
+ multimodal=True,
340
+ cache_examples=False,
341
+ css=css
342
+ )
343
+
344
+ demo.launch(debug=True)