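"""Gradio app: extract text from an uploaded image with Qwen2-VL and search it.

Each upload is also indexed with Byaldi (ColPali) so the image stays retrievable
as a document page.
"""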
import gradio as gr
from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
from PIL import Image
import os
import traceback
import spaces
import re
# Check if CUDA is available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Load the Byaldi and Qwen2-VL models
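# Byaldi wraps the ColPali late-interaction retriever, which embeds document pages
# directly as images, so pages can be indexed and searched without a separate OCR step.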
rag_model = RAGMultiModalModel.from_pretrained("vidore/colpali") # Byaldi model
qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True, torch_dtype=torch.bfloat16
).to(device)  # Move Qwen2-VL to the selected device
# Processor for Qwen2-VL
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True)
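# The processor bundles the tokenizer and image preprocessor needed to turn
# chat-format messages into model inputs.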
# Global variable to store extracted text
extracted_text = ""
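# On Hugging Face ZeroGPU Spaces, @spaces.GPU allocates a GPU only for the duration
# of the decorated call; duration=120 raises the per-call time limit.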
@spaces.GPU(duration=120)  # Allow up to 120 seconds of GPU time for this call
def ocr_and_extract(image):
    global extracted_text
    try:
        # Save the uploaded image temporarily
        temp_image_path = "temp_image.jpg"
        image.save(temp_image_path)

        # Index the image with Byaldi, forcing an overwrite of the existing index
        rag_model.index(
            input_path=temp_image_path,
            index_name="image_index",  # Reuse the same index
            store_collection_with_index=False,
            overwrite=True,  # Overwrite the index for every new image
        )
        # Query the index (the result is not used further below; with a single
        # indexed image this simply confirms the page is retrievable)
        results = rag_model.search("", k=1)
        # Prepare the input for Qwen2-VL; an explicit text instruction is included
        # so the model transcribes the image rather than describing it (the exact
        # prompt wording here is one reasonable choice, not fixed by the model)
        image_data = Image.open(temp_image_path)
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_data},
                    {"type": "text", "text": "Extract all of the text in this image."},
                ],
            }
        ]
        # Render the chat template and extract the vision inputs from the messages
        text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, _ = process_vision_info(messages)

        # Batch the text and image inputs and move them to the selected device
        inputs = processor(
            text=[text_input],
            images=image_inputs,
            padding=True,
            return_tensors="pt",
        ).to(device)
        # Generate the output with Qwen2-VL; a budget of 50 new tokens would truncate
        # most OCR output, so a larger limit is used here
        generated_ids = qwen_model.generate(**inputs, max_new_tokens=512)

        # Decode only the newly generated tokens by trimming the prompt tokens; this
        # keeps the chat scaffolding ("system"/"user"/"assistant") out of the result
        # without filtering lines by keyword, which could also drop legitimate text
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        extracted_text = output_text[0].strip()
        # Clean up the temporary file
        os.remove(temp_image_path)

        return extracted_text

    except Exception as e:
        error_message = str(e)
        traceback.print_exc()
        return f"Error: {error_message}"

def search_keywords(keywords):
    if not extracted_text:
        return "No text extracted yet. Please upload an image."

    # Highlight matching keywords in the extracted text
    highlighted_text = extracted_text
    for keyword in keywords.split():
        highlighted_text = re.sub(
            f"({re.escape(keyword)})", r"<mark>\1</mark>", highlighted_text, flags=re.IGNORECASE
        )

    # Return the highlighted text
    return highlighted_text
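
# Example: searching for "invoice" in "Invoice #123" yields "<mark>Invoice</mark> #123";
# the capture group preserves the original casing of each case-insensitive match.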

# Gradio interface for image input and keyword search
with gr.Blocks() as iface:
    # Image upload and text extraction section
    with gr.Column():
        img_input = gr.Image(type="pil", label="Upload an Image")
        extracted_output = gr.Textbox(label="Extracted Text", interactive=False)

        # Button that triggers the OCR and extraction
        img_button = gr.Button("Extract Text")
        img_button.click(fn=ocr_and_extract, inputs=img_input, outputs=extracted_output)

    # Keyword search section
    with gr.Column():
        search_input = gr.Textbox(label="Enter keywords to search")
        search_output = gr.HTML(label="Search Results")

        # Button that searches within the extracted text
        search_button = gr.Button("Search")
        search_button.click(fn=search_keywords, inputs=search_input, outputs=search_output)

iface.launch()
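# Hypothetical smoke test (run in place of launch(); assumes a local "sample.jpg"):
#   print(ocr_and_extract(Image.open("sample.jpg")))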