shukdevdatta123 committed on
Commit
beecb06
·
verified ·
1 Parent(s): 7fb8860

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -120
app.py CHANGED
@@ -1,146 +1,128 @@
1
  import gradio as gr
 
 
 
2
  import torch
3
- from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
4
- from qwen_vl_utils import process_vision_info
5
- import re
 
 
 
6
 
7
- # Load the model on CPU
8
- def load_model():
9
- model = Qwen2VLForConditionalGeneration.from_pretrained(
10
- "prithivMLmods/Qwen2-VL-OCR-2B-Instruct",
11
- torch_dtype=torch.float32,
12
- device_map="cpu"
13
- )
14
- processor = AutoProcessor.from_pretrained("prithivMLmods/Qwen2-VL-OCR-2B-Instruct")
15
- return model, processor
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- # Function to extract medicine names
18
- def extract_medicine_names(image):
19
- model, processor = load_model()
20
-
21
- # Prepare the message with the specific prompt for medicine extraction
22
- messages = [
23
- {
24
- "role": "user",
25
- "content": [
26
- {
27
- "type": "image",
28
- "image": image,
29
- },
30
- {"type": "text", "text": "Extract and list ONLY the names of medicines/drugs from this prescription image. Output the medicine names as a numbered list without any additional information or descriptions."},
31
- ],
32
- }
33
- ]
34
-
35
- # Prepare for inference
36
- text = processor.apply_chat_template(
37
- messages, tokenize=False, add_generation_prompt=True
38
- )
39
- image_inputs, video_inputs = process_vision_info(messages)
40
- inputs = processor(
41
- text=[text],
42
- images=image_inputs,
43
- videos=video_inputs,
44
- padding=True,
45
- return_tensors="pt",
46
- )
47
-
48
- # Generate output
49
- generated_ids = model.generate(**inputs, max_new_tokens=256)
50
- generated_ids_trimmed = [
51
- out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
52
- ]
53
- output_text = processor.batch_decode(
54
- generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
55
- )[0]
56
-
57
- # Remove <|im_end|> and any other special tokens that might appear in the output
58
- output_text = output_text.replace("<|im_end|>", "").strip()
59
-
60
- return output_text
61
 
62
- # Create a singleton model and processor to avoid reloading for each request
63
- model_instance = None
64
- processor_instance = None
65
-
66
- def get_model_and_processor():
67
- global model_instance, processor_instance
68
- if model_instance is None or processor_instance is None:
69
- model_instance, processor_instance = load_model()
70
- return model_instance, processor_instance
71
-
72
- # Optimized extraction function that uses the singleton model
73
- def extract_medicine_names_optimized(image):
74
- if image is None:
75
- return "Please upload an image."
76
 
77
- model, processor = get_model_and_processor()
78
 
79
- # Prepare the message with the specific prompt for medicine extraction
80
- messages = [
81
- {
82
- "role": "user",
83
- "content": [
84
- {
85
- "type": "image",
86
- "image": image,
87
- },
88
- {"type": "text", "text": "Extract and list ONLY the names of medicines/drugs from this prescription image. Output the medicine names as a numbered list without any additional information or descriptions."},
89
- ],
90
- }
91
- ]
92
 
93
- # Prepare for inference
94
- text = processor.apply_chat_template(
95
- messages, tokenize=False, add_generation_prompt=True
96
- )
97
- image_inputs, video_inputs = process_vision_info(messages)
 
 
 
 
98
  inputs = processor(
99
- text=[text],
100
- images=image_inputs,
101
- videos=video_inputs,
102
- padding=True,
103
  return_tensors="pt",
104
- )
 
 
 
 
105
 
106
- # Generate output
107
- generated_ids = model.generate(**inputs, max_new_tokens=256)
108
- generated_ids_trimmed = [
109
- out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
110
- ]
111
- output_text = processor.batch_decode(
112
- generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
113
- )[0]
114
 
115
- # Remove <|im_end|> and any other special tokens that might appear in the output
116
- output_text = output_text.replace("<|im_end|>", "").strip()
117
 
118
- return output_text
 
 
 
 
119
 
120
- # Create Gradio interface
121
- with gr.Blocks(title="Medicine Name Extractor") as app:
122
  gr.Markdown("# Medicine Name Extractor")
123
- gr.Markdown("Upload a medical prescription image to extract the names of medicines.")
124
 
125
  with gr.Row():
126
  with gr.Column():
127
- input_image = gr.Image(type="pil", label="Upload Prescription Image")
 
 
 
 
128
  extract_btn = gr.Button("Extract Medicine Names", variant="primary")
129
 
130
  with gr.Column():
131
- output_text = gr.Textbox(label="Extracted Medicine Names", lines=10)
132
 
133
  extract_btn.click(
134
- fn=extract_medicine_names_optimized,
135
- inputs=input_image,
136
- outputs=output_text
 
 
 
 
 
 
 
 
 
 
 
137
  )
138
 
139
- gr.Markdown("### Notes")
140
- gr.Markdown("- This tool uses the Qwen2-VL-OCR model to extract text from prescription images")
141
- gr.Markdown("- For best results, ensure the prescription image is clear and readable")
142
- gr.Markdown("- Processing may take some time as the model runs on CPU")
 
 
143
 
144
- # Launch the app
145
- if __name__ == "__main__":
146
- app.launch()
 
1
  import gradio as gr
2
+ from transformers.image_utils import load_image
3
+ from threading import Thread
4
+ import time
5
  import torch
6
+ from PIL import Image
7
+ from transformers import (
8
+ Qwen2VLForConditionalGeneration,
9
+ AutoProcessor,
10
+ TextIteratorStreamer,
11
+ )
12
 
13
+ # ---------------------------
14
+ # Helper Functions
15
+ # ---------------------------
16
def progress_bar_html(label: str, primary_color: str = "#4B0082", secondary_color: str = "#9370DB") -> str:
    """Return an HTML snippet for a thin animated progress bar with a label.

    Args:
        label: Text shown to the left of the bar. It is HTML-escaped, so
            characters such as ``<`` or ``&`` are displayed literally instead
            of being interpreted as markup.
        primary_color: CSS colour of the moving bar.
        secondary_color: CSS colour of the track behind the bar.

    Returns:
        A string holding the bar markup followed by a ``<style>`` element that
        defines the ``loading`` keyframes the bar animates with.
    """
    from html import escape

    # Escape every interpolated value so none of them can break out of the
    # element or attribute it is placed in. Plain text and ordinary colour
    # values pass through unchanged.
    safe_label = escape(label)
    bar_color = escape(primary_color)
    track_color = escape(secondary_color)

    markup = (
        "\n"
        '<div style="display: flex; align-items: center;">\n'
        f'<span style="margin-right: 10px; font-size: 14px;">{safe_label}</span>\n'
        f'<div style="width: 110px; height: 5px; background-color: {track_color}; '
        'border-radius: 2px; overflow: hidden;">\n'
        f'<div style="width: 100%; height: 100%; background-color: {bar_color}; '
        'animation: loading 1.5s linear infinite;"></div>\n'
        "</div>\n"
        "</div>\n"
    )
    # The keyframes slide the bar from fully left of the track to fully right.
    keyframes = (
        "<style>\n"
        "@keyframes loading {\n"
        "0% { transform: translateX(-100%); }\n"
        "100% { transform: translateX(100%); }\n"
        "}\n"
        "</style>\n"
    )
    return markup + keyframes
34
 
35
# Model and processor setup (CPU version). Both are loaded once at import time
# and shared by every request.
MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"

# NOTE(review): trust_remote_code=True allows code shipped in the model
# repository to run locally - confirm it is actually required for this model.
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float32,  # float32 for CPU compatibility
)
# Keep the weights on the CPU and switch to inference (eval) mode.
model = model.to("cpu").eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
+ # Main Inference Function
45
def extract_medicines(image_files):
    """Stream the medicine names read from one or more prescription images.

    This is a generator: every yielded string is a full replacement for the
    output shown to the user, so the result grows as the model produces text.

    Args:
        image_files: The uploaded images as delivered by the file input. Each
            item is passed to ``load_image``. ``None`` or an empty list means
            nothing was uploaded.

    Yields:
        str: If nothing was uploaded, a single message asking for an image.
        Otherwise an HTML progress bar first, then the accumulated model
        output after each streamed chunk.
    """
    if not image_files:
        # A plain ``return "..."`` inside a generator becomes the
        # StopIteration value and is never displayed, so the message has to
        # be yielded before returning.
        yield "Please upload a prescription image."
        return

    images = [load_image(image) for image in image_files]

    # Specific prompt to extract only medicine names
    text = "Extract ONLY the names of medications/medicines from this prescription image. Format the output as a numbered list of medicine names only, without dosages or instructions."

    # One user turn holding every image followed by the instruction text.
    messages = [{
        "role": "user",
        "content": [
            *[{"type": "image", "image": image} for image in images],
            {"type": "text", "text": text},
        ],
    }]

    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[prompt_full],
        images=images,
        return_tensors="pt",
        padding=True,
    ).to("cpu")

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)

    # generate() blocks until it finishes, so it runs in a worker thread while
    # this generator consumes the streamer and forwards partial text.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    yield progress_bar_html("Extracting Medicine Names")

    for new_text in streamer:
        buffer += new_text
        # Drop the end-of-turn marker in case it appears in the decoded text.
        buffer = buffer.replace("<|im_end|>", "")
        time.sleep(0.01)  # presumably paces UI updates - TODO confirm it is needed
        yield buffer

    # The streamer is exhausted, so generation is finishing; wait for the
    # worker so no thread outlives the request.
    thread.join()
85
 
86
+ # Gradio Interface
87
# Gradio interface: upload column on the left, extracted names on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Medicine Name Extractor")
    gr.Markdown("Upload prescription images to extract medicine names")

    with gr.Row():
        with gr.Column():
            image_input = gr.File(
                label="Upload Prescription Image(s)",
                file_count="multiple",
                file_types=["image"]
            )
            extract_btn = gr.Button("Extract Medicine Names", variant="primary")

        with gr.Column():
            output = gr.Markdown(label="Extracted Medicine Names")

    # extract_medicines is a generator, so each value it yields replaces the
    # content of the output component.
    extract_btn.click(
        fn=extract_medicines,
        inputs=image_input,
        outputs=output
    )

    # NOTE(review): with cache_examples=True the handler is run on the example
    # files to build the cache - confirm both files exist under examples/.
    gr.Examples(
        examples=[
            ["examples/prescription1.jpg"],
            ["examples/prescription2.jpg"],
        ],
        inputs=image_input,
        outputs=output,
        fn=extract_medicines,
        cache_examples=True,
    )

    gr.Markdown(
        "\n"
        "### Notes:\n"
        "- This app is optimized to run on CPU\n"
        "- Upload clear images of prescriptions for best results\n"
        "- Only medicine names will be extracted\n"
    )

# Enable the request queue before launching.
demo.queue()

# Start the server only when this file is executed as a script, so importing
# the module (for tests or reuse of `demo`) does not block on launch().
if __name__ == "__main__":
    demo.launch(debug=True)