gauri-sharan committed on
Commit d85fa29 · verified · 1 Parent(s): 7508b05

Update app.py

Files changed (1)
  1. app.py +80 -43
app.py CHANGED
@@ -6,91 +6,128 @@ import torch
 from PIL import Image
 import os
 import traceback
-import spaces
+import re
 
-# Load the models
+# Load models
 rag_model = RAGMultiModalModel.from_pretrained("vidore/colpali")
 qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
     "Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True, torch_dtype=torch.bfloat16
 )
 processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True)
 
-# Global variable to store extracted text
-extracted_text = ""
+extracted_text = ""  # Store the extracted text globally for keyword search
 
-@spaces.GPU(duration=120)
-def ocr_and_extract(image, text_query):
+def ocr_and_extract(image, text_query=None):
     global extracted_text
     try:
+        # Save the uploaded image temporarily
         temp_image_path = "temp_image.jpg"
         image.save(temp_image_path)
 
-        rag_model.index(input_path=temp_image_path, index_name="image_index", store_collection_with_index=False, overwrite=True)
+        # Index the image with Byaldi
+        rag_model.index(
+            input_path=temp_image_path,
+            index_name="image_index",
+            store_collection_with_index=False,
+            overwrite=True
+        )
+
+        # Perform the search query on the indexed image
         results = rag_model.search(text_query, k=1)
 
+        # Prepare the input for Qwen2-VL
         image_data = Image.open(temp_image_path)
+
         messages = [
-            {"role": "user", "content": [{"type": "image", "image": image_data}, {"type": "text", "text": text_query}]}
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image", "image": image_data},
+                    {"type": "text", "text": text_query},
+                ],
+            }
         ]
 
+        # Process input for Qwen2-VL
         text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
         image_inputs, _ = process_vision_info(messages)
-        inputs = processor(text=[text_input], images=image_inputs, padding=True, return_tensors="pt")
+
+        inputs = processor(
+            text=[text_input],
+            images=image_inputs,
+            padding=True,
+            return_tensors="pt",
+        )
 
         qwen_model.to("cuda")
         inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+        # Generate the output with Qwen2-VL
         generated_ids = qwen_model.generate(**inputs, max_new_tokens=50)
         output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
 
+        # Store the extracted text for keyword search
        extracted_text = output_text[0]
         os.remove(temp_image_path)
 
         return extracted_text
 
     except Exception as e:
+        error_message = str(e)
         traceback.print_exc()
-        return f"Error: {str(e)}"
-
-def keyword_search(keywords):
-    if extracted_text:
-        found_keywords = [word for word in keywords.split() if word in extracted_text]
-        if found_keywords:
-            return f"Keywords found: {', '.join(found_keywords)}"
-        else:
-            return "No matching keywords found."
-    else:
+        return f"Error: {error_message}"
+
+def search_keywords(keyword):
+    global extracted_text
+    if not extracted_text:
         return "No text extracted yet. Please upload an image."
+
+    # Perform basic keyword search within the extracted text
+    if re.search(rf"\b{re.escape(keyword)}\b", extracted_text, re.IGNORECASE):
+        highlighted_text = re.sub(rf"({re.escape(keyword)})", r"<mark>\1</mark>", extracted_text, flags=re.IGNORECASE)
+        return f"Keyword found! {highlighted_text}"
+    else:
+        return "Keyword not found in the extracted text."
+
+# Gradio interface
+image_input = gr.Image(type="pil")
+text_output = gr.Textbox(label="Extracted Text", interactive=True)
+keyword_search = gr.Textbox(label="Enter keywords to search")
+search_button = gr.Button("Search Keywords")
+search_output = gr.HTML()
 
-# Interface Layout
-extract_text_button = gr.Button("Extract Text")
-extracted_text_box = gr.Textbox(label="Extracted Text", placeholder="Text will appear here...", interactive=False)
-keyword_search_box = gr.Textbox(label="Enter keywords to search", placeholder="Type keywords here...")
-search_results = gr.Textbox(label="Search Results", interactive=False)
+extract_button = gr.Button("Extract Text")
 
-# Re-order the components: Extract Text button goes above Extracted Text box
+# Layout update
 iface = gr.Interface(
     fn=ocr_and_extract,
-    inputs=[gr.Image(type="pil"), gr.Textbox(label="Enter your query (optional)")],
-    outputs=[extracted_text_box],
+    inputs=[image_input],
+    outputs=[text_output],
     title="Image OCR with Byaldi + Qwen2-VL",
-    description="Upload an image (JPEG/PNG) containing Hindi and English text for OCR."
+    description="Upload an image containing Hindi and English text for OCR. Then, search for specific keywords.",
 )
 
-# Layout for keyword search
-search_interface = gr.Interface(
-    fn=keyword_search,
-    inputs=[keyword_search_box],
-    outputs=[search_results],
-    title="Keyword Search within Extracted Text",
-    description="Enter keywords to search within the extracted text."
+# Keyword search layout
+iface_search = gr.Interface(
+    fn=search_keywords,
+    inputs=[keyword_search],
+    outputs=[search_output],
 )
 
-# Combining both interfaces with keyword search on the same page
-combined_interface = gr.Blocks()
-with combined_interface:
-    extract_text_button.render()
-    extracted_text_box.render()
-    keyword_search_box.render()
-    search_results.render()
+# Move extract button above the text output
+def combined_interface(image, keyword):
+    ocr_text = ocr_and_extract(image)
+    search_result = search_keywords(keyword)
+    return ocr_text, search_result
+
+combined_iface = gr.Interface(
+    fn=combined_interface,
+    inputs=[image_input, keyword_search],
+    outputs=[text_output, search_output],
+    live=True,
+    title="Image OCR & Keyword Search",
+    description="Extract text from the image and search for specific keywords."
+)
 
-combined_interface.launch()
+# Launch the app
+combined_iface.launch()
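
A quick way to sanity-check the keyword-highlighting logic introduced in search_keywords, outside of Gradio: the snippet below is not part of the commit (sample_text and keyword are made-up values), it just applies the same word-boundary search and <mark> substitution to a sample string before the result would be rendered in a gr.HTML component.

import re

# Stand-alone sketch of the highlighting used in search_keywords.
# sample_text and keyword are hypothetical values, not app output.
sample_text = "Invoice No. 42 / चालान संख्या 42"
keyword = "invoice"

if re.search(rf"\b{re.escape(keyword)}\b", sample_text, re.IGNORECASE):
    # Wrap every case-insensitive match in <mark> tags, as the app does.
    highlighted = re.sub(rf"({re.escape(keyword)})", r"<mark>\1</mark>", sample_text, flags=re.IGNORECASE)
    print(f"Keyword found! {highlighted}")
else:
    print("Keyword not found in the extracted text.")

Running this prints "Keyword found! <mark>Invoice</mark> No. 42 / चालान संख्या 42", which is the kind of markup the new search_output HTML component is expected to display.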