gauri-sharan committed
Commit 7508b05 · verified · 1 Parent(s): ee73de1

Update app.py

Files changed (1)
  1. app.py +57 -84
app.py CHANGED
@@ -7,117 +7,90 @@ from PIL import Image
  import os
  import traceback
  import spaces
- import re

- # Check if CUDA is available
- device = "cuda" if torch.cuda.is_available() else "cpu"
- print(f"Using device: {device}")
-
- # Load the Byaldi and Qwen2-VL models
- rag_model = RAGMultiModalModel.from_pretrained("vidore/colpali")  # Byaldi model
  qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
      "Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True, torch_dtype=torch.bfloat16
- ).to(device)  # Move Qwen2-VL to GPU
-
- # Processor for Qwen2-VL
  processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True)

  # Global variable to store extracted text
  extracted_text = ""

- @spaces.GPU(duration=120)  # Increased GPU duration to 120 seconds
- def ocr_and_extract(image):
      global extracted_text
      try:
-         # Save the uploaded image temporarily
          temp_image_path = "temp_image.jpg"
          image.save(temp_image_path)

-         # Index the image with Byaldi, and force overwrite of the existing index
-         rag_model.index(
-             input_path=temp_image_path,
-             index_name="image_index",  # Reuse the same index
-             store_collection_with_index=False,
-             overwrite=True  # Overwrite the index for every new image
-         )

-         # Perform the search query on the indexed image
-         results = rag_model.search("", k=1)
-
-         # Prepare the input for Qwen2-VL
          image_data = Image.open(temp_image_path)
-
          messages = [
-             {
-                 "role": "user",
-                 "content": [
-                     {"type": "image", "image": image_data},
-                 ],
-             }
          ]

-         # Process the message and prepare for Qwen2-VL
          text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
          image_inputs, _ = process_vision_info(messages)

-         # Move the image inputs and processor outputs to CUDA
-         inputs = processor(
-             text=[text_input],
-             images=image_inputs,
-             padding=True,
-             return_tensors="pt",
-         ).to(device)
-
-         # Generate the output with Qwen2-VL
          generated_ids = qwen_model.generate(**inputs, max_new_tokens=50)
-         output_text = processor.batch_decode(
-             generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
-         )
-
-         # Filter out "You are a helpful assistant" and "assistant" labels
-         filtered_output = [line for line in output_text[0].split("\n") if not any(kw in line.lower() for kw in ["you are a helpful assistant", "assistant", "user", "system"])]
-         extracted_text = "\n".join(filtered_output).strip()
-
-         # Clean up the temporary file
          os.remove(temp_image_path)

          return extracted_text

      except Exception as e:
-         error_message = str(e)
          traceback.print_exc()
-         return f"Error: {error_message}"
-
- def search_keywords(keywords):
-     if not extracted_text:
          return "No text extracted yet. Please upload an image."

-     # Highlight matching keywords in the extracted text
-     highlighted_text = extracted_text
-     for keyword in keywords.split():
-         highlighted_text = re.sub(f"({re.escape(keyword)})", r"<mark>\1</mark>", highlighted_text, flags=re.IGNORECASE)
-
-     # Return the highlighted text
-     return highlighted_text
-
- # Gradio interface for image input and keyword search
- with gr.Blocks() as iface:
-     # Image upload and text extraction section
-     with gr.Column():
-         img_input = gr.Image(type="pil", label="Upload an Image")
-         extracted_output = gr.Textbox(label="Extracted Text", interactive=False)
-
-         # Functionality to trigger the OCR and extraction
-         img_button = gr.Button("Extract Text")
-         img_button.click(fn=ocr_and_extract, inputs=img_input, outputs=extracted_output)
-
-     # Keyword search section
-     with gr.Column():
-         search_input = gr.Textbox(label="Enter keywords to search")
-         search_output = gr.HTML(label="Search Results")
-
-         # Functionality to search within the extracted text
-         search_button = gr.Button("Search")
-         search_button.click(fn=search_keywords, inputs=search_input, outputs=search_output)
-
- iface.launch()
 
 
 
 
 
  import os
  import traceback
  import spaces

+ # Load the models
+ rag_model = RAGMultiModalModel.from_pretrained("vidore/colpali")
  qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
      "Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True, torch_dtype=torch.bfloat16
+ )

  processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True)

  # Global variable to store extracted text
  extracted_text = ""

+ @spaces.GPU(duration=120)
+ def ocr_and_extract(image, text_query):
      global extracted_text
      try:
          temp_image_path = "temp_image.jpg"
          image.save(temp_image_path)

+         rag_model.index(input_path=temp_image_path, index_name="image_index", store_collection_with_index=False, overwrite=True)
+         results = rag_model.search(text_query, k=1)

          image_data = Image.open(temp_image_path)
          messages = [
+             {"role": "user", "content": [{"type": "image", "image": image_data}, {"type": "text", "text": text_query}]}
          ]

          text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
          image_inputs, _ = process_vision_info(messages)
+         inputs = processor(text=[text_input], images=image_inputs, padding=True, return_tensors="pt")

+         qwen_model.to("cuda")
+         inputs = {k: v.to("cuda") for k, v in inputs.items()}

          generated_ids = qwen_model.generate(**inputs, max_new_tokens=50)
+         output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+
+         extracted_text = output_text[0]

          os.remove(temp_image_path)

          return extracted_text

      except Exception as e:
          traceback.print_exc()
+         return f"Error: {str(e)}"
+
+ def keyword_search(keywords):
+     if extracted_text:
+         found_keywords = [word for word in keywords.split() if word in extracted_text]
+         if found_keywords:
+             return f"Keywords found: {', '.join(found_keywords)}"
+         else:
+             return "No matching keywords found."
+     else:
          return "No text extracted yet. Please upload an image."

+ # Interface Layout
+ extract_text_button = gr.Button("Extract Text")
+ extracted_text_box = gr.Textbox(label="Extracted Text", placeholder="Text will appear here...", interactive=False)
+ keyword_search_box = gr.Textbox(label="Enter keywords to search", placeholder="Type keywords here...")
+ search_results = gr.Textbox(label="Search Results", interactive=False)
+
+ # Re-order the components: Extract Text button goes above Extracted Text box
+ iface = gr.Interface(
+     fn=ocr_and_extract,
+     inputs=[gr.Image(type="pil"), gr.Textbox(label="Enter your query (optional)")],
+     outputs=[extracted_text_box],
+     title="Image OCR with Byaldi + Qwen2-VL",
+     description="Upload an image (JPEG/PNG) containing Hindi and English text for OCR."
+ )
+
+ # Layout for keyword search
+ search_interface = gr.Interface(
+     fn=keyword_search,
+     inputs=[keyword_search_box],
+     outputs=[search_results],
+     title="Keyword Search within Extracted Text",
+     description="Enter keywords to search within the extracted text."
+ )
+
+ # Combining both interfaces with keyword search on the same page
+ combined_interface = gr.Blocks()
+ with combined_interface:
+     extract_text_button.render()
+     extracted_text_box.render()
+     keyword_search_box.render()
+     search_results.render()
+
+ combined_interface.launch()
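Note that in the new version the standalone components (extract_text_button, extracted_text_box, keyword_search_box, search_results) are rendered inside gr.Blocks without any event wiring, so the Extract Text and Search actions still depend on the two gr.Interface objects. Below is a minimal sketch of wiring the same two callbacks directly inside a single Blocks page; only ocr_and_extract and keyword_search come from this commit, and the component names and click wiring are illustrative, not part of app.py.

import gradio as gr

# Sketch only: reuses ocr_and_extract and keyword_search defined in app.py above.
# Component names and event wiring here are illustrative assumptions.
with gr.Blocks() as demo:
    img_input = gr.Image(type="pil", label="Upload an Image")
    query_input = gr.Textbox(label="Enter your query (optional)")
    extract_btn = gr.Button("Extract Text")
    extracted_box = gr.Textbox(label="Extracted Text", interactive=False)

    keyword_input = gr.Textbox(label="Enter keywords to search")
    search_btn = gr.Button("Search")
    results_box = gr.Textbox(label="Search Results", interactive=False)

    # Each button click calls the corresponding function from app.py.
    extract_btn.click(fn=ocr_and_extract, inputs=[img_input, query_input], outputs=extracted_box)
    search_btn.click(fn=keyword_search, inputs=keyword_input, outputs=results_box)

demo.launch()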