shukdevdatta123 commited on
Commit
d352924
·
verified ·
1 Parent(s): c7cc8ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +130 -54
app.py CHANGED
@@ -1,70 +1,146 @@
1
  import gradio as gr
2
  import torch
3
- from PIL import Image
4
- import time
5
- from threading import Thread
6
- from transformers import (
7
- Qwen2VLForConditionalGeneration,
8
- AutoProcessor,
9
- TextIteratorStreamer,
10
- )
11
 
12
- # Load model and processor - CPU version
13
- MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
14
- processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
15
- model = Qwen2VLForConditionalGeneration.from_pretrained(
16
- MODEL_ID,
17
- trust_remote_code=True,
18
- torch_dtype=torch.float32 # Using float32 for CPU compatibility
19
- ).to("cpu").eval()
 
20
 
21
- def extract_medicines(image):
22
- """Extract medicine names from prescription images."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  if image is None:
24
- return "Please upload a prescription image."
25
 
26
- # Process the image
27
- text = "Extract ONLY the names of medications/medicines from this prescription image. Format the output as a numbered list of medicine names only, without dosages or instructions."
28
 
29
- messages = [{
30
- "role": "user",
31
- "content": [
32
- {"type": "image", "image": Image.open(image)},
33
- {"type": "text", "text": text},
34
- ],
35
- }]
 
 
 
 
 
 
36
 
37
- prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
 
 
 
38
  inputs = processor(
39
- text=[prompt_full],
40
- images=[Image.open(image)],
41
- return_tensors="pt",
42
  padding=True,
43
- ).to("cpu")
44
-
45
- # Generate response
46
- with torch.no_grad():
47
- output = model.generate(**inputs, max_new_tokens=512)
48
 
49
- # Decode and return response
50
- response = processor.decode(output[0], skip_special_tokens=True)
 
 
 
 
 
 
51
 
52
- # Clean up the response to get just the model's answer
53
- if "<|assistant|>" in response:
54
- response = response.split("<|assistant|>")[1].strip()
55
 
56
- return response
57
 
58
- # Create a simple Gradio interface
59
- demo = gr.Interface(
60
- fn=extract_medicines,
61
- inputs=gr.Image(type="filepath", label="Upload Prescription Image"),
62
- outputs=gr.Textbox(label="Extracted Medicine Names"),
63
- title="Medicine Name Extractor",
64
- description="Upload prescription images to extract medicine names",
65
- examples=[["examples/prescription1.jpg"]], # Update with your actual example paths or remove if not available
66
- cache_examples=True,
67
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
 
69
  if __name__ == "__main__":
70
- demo.launch(debug=True)
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
4
+ from qwen_vl_utils import process_vision_info
5
+ import re
 
 
 
 
 
6
 
7
+ # Load the model on CPU
8
+ def load_model():
9
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
10
+ "prithivMLmods/Qwen2-VL-OCR-2B-Instruct",
11
+ torch_dtype=torch.float32,
12
+ device_map="cpu"
13
+ )
14
+ processor = AutoProcessor.from_pretrained("prithivMLmods/Qwen2-VL-OCR-2B-Instruct")
15
+ return model, processor
16
 
17
+ # Function to extract medicine names
18
+ def extract_medicine_names(image):
19
+ model, processor = load_model()
20
+
21
+ # Prepare the message with the specific prompt for medicine extraction
22
+ messages = [
23
+ {
24
+ "role": "user",
25
+ "content": [
26
+ {
27
+ "type": "image",
28
+ "image": image,
29
+ },
30
+ {"type": "text", "text": "Extract and list ONLY the names of medicines/drugs from this prescription image. Output the medicine names as a numbered list without any additional information or descriptions."},
31
+ ],
32
+ }
33
+ ]
34
+
35
+ # Prepare for inference
36
+ text = processor.apply_chat_template(
37
+ messages, tokenize=False, add_generation_prompt=True
38
+ )
39
+ image_inputs, video_inputs = process_vision_info(messages)
40
+ inputs = processor(
41
+ text=[text],
42
+ images=image_inputs,
43
+ videos=video_inputs,
44
+ padding=True,
45
+ return_tensors="pt",
46
+ )
47
+
48
+ # Generate output
49
+ generated_ids = model.generate(**inputs, max_new_tokens=256)
50
+ generated_ids_trimmed = [
51
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
52
+ ]
53
+ output_text = processor.batch_decode(
54
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
55
+ )[0]
56
+
57
+ # Remove <|im_end|> and any other special tokens that might appear in the output
58
+ output_text = output_text.replace("<|im_end|>", "").strip()
59
+
60
+ return output_text
61
+
62
+ # Create a singleton model and processor to avoid reloading for each request
63
+ model_instance = None
64
+ processor_instance = None
65
+
66
+ def get_model_and_processor():
67
+ global model_instance, processor_instance
68
+ if model_instance is None or processor_instance is None:
69
+ model_instance, processor_instance = load_model()
70
+ return model_instance, processor_instance
71
+
72
+ # Optimized extraction function that uses the singleton model
73
+ def extract_medicine_names_optimized(image):
74
  if image is None:
75
+ return "Please upload an image."
76
 
77
+ model, processor = get_model_and_processor()
 
78
 
79
+ # Prepare the message with the specific prompt for medicine extraction
80
+ messages = [
81
+ {
82
+ "role": "user",
83
+ "content": [
84
+ {
85
+ "type": "image",
86
+ "image": image,
87
+ },
88
+ {"type": "text", "text": "Extract and list ONLY the names of medicines/drugs from this prescription image. Output the medicine names as a numbered list without any additional information or descriptions."},
89
+ ],
90
+ }
91
+ ]
92
 
93
+ # Prepare for inference
94
+ text = processor.apply_chat_template(
95
+ messages, tokenize=False, add_generation_prompt=True
96
+ )
97
+ image_inputs, video_inputs = process_vision_info(messages)
98
  inputs = processor(
99
+ text=[text],
100
+ images=image_inputs,
101
+ videos=video_inputs,
102
  padding=True,
103
+ return_tensors="pt",
104
+ )
 
 
 
105
 
106
+ # Generate output
107
+ generated_ids = model.generate(**inputs, max_new_tokens=256)
108
+ generated_ids_trimmed = [
109
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
110
+ ]
111
+ output_text = processor.batch_decode(
112
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
113
+ )[0]
114
 
115
+ # Remove <|im_end|> and any other special tokens that might appear in the output
116
+ output_text = output_text.replace("<|im_end|>", "").strip()
 
117
 
118
+ return output_text
119
 
120
+ # Create Gradio interface
121
+ with gr.Blocks(title="Medicine Name Extractor") as app:
122
+ gr.Markdown("# Medicine Name Extractor")
123
+ gr.Markdown("Upload a medical prescription image to extract the names of medicines.")
124
+
125
+ with gr.Row():
126
+ with gr.Column():
127
+ input_image = gr.Image(type="pil", label="Upload Prescription Image")
128
+ extract_btn = gr.Button("Extract Medicine Names", variant="primary")
129
+
130
+ with gr.Column():
131
+ output_text = gr.Textbox(label="Extracted Medicine Names", lines=10)
132
+
133
+ extract_btn.click(
134
+ fn=extract_medicine_names_optimized,
135
+ inputs=input_image,
136
+ outputs=output_text
137
+ )
138
+
139
+ gr.Markdown("### Notes")
140
+ gr.Markdown("- This tool uses the Qwen2-VL-OCR model to extract text from prescription images")
141
+ gr.Markdown("- For best results, ensure the prescription image is clear and readable")
142
+ gr.Markdown("- Processing may take some time as the model runs on CPU")
143
 
144
+ # Launch the app
145
  if __name__ == "__main__":
146
+ app.launch()