import gradio as gr
from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
from PIL import Image
import os
import traceback
import spaces
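# `spaces` provides the @spaces.GPU decorator used to request GPU time on Hugging Face ZeroGPU Spaces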
import re

# Check if CUDA is available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the Byaldi and Qwen2-VL models
rag_model = RAGMultiModalModel.from_pretrained("vidore/colpali")  # Byaldi model
qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True, torch_dtype=torch.bfloat16
).to(device)  # Move Qwen2-VL to GPU

# Processor for Qwen2-VL
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True)

# Global variable to store extracted text
extracted_text = ""

@spaces.GPU(duration=120)  # Increased GPU duration to 120 seconds
def ocr_and_extract(image):
    global extracted_text
    try:
        # Save the uploaded image temporarily
        temp_image_path = "temp_image.jpg"
        image.save(temp_image_path)

        # Index the image with Byaldi, and force overwrite of the existing index
        rag_model.index(
            input_path=temp_image_path,
            index_name="image_index",  # Reuse the same index
            store_collection_with_index=False,
            overwrite=True  # Overwrite the index for every new image
        )

        # Perform the search query on the indexed image
        results = rag_model.search("", k=1)
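        # Note: the search result is not used further; indexing simply registers the
        # image with Byaldi, and the raw image is passed to Qwen2-VL directly below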

        # Prepare the input for Qwen2-VL
        image_data = Image.open(temp_image_path)

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_data},
                    {"type": "text", "text": "Extract all the text visible in this image."},
                ],
            }
        ]

        # Process the message and prepare for Qwen2-VL
        text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
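        # process_vision_info pulls the image (and any video, unused here) entries out of the chat messages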
        image_inputs, _ = process_vision_info(messages)

        # Move the image inputs and processor outputs to CUDA
        inputs = processor(
            text=[text_input],
            images=image_inputs,
            padding=True,
            return_tensors="pt",
        ).to(device)

        # Generate the output with Qwen2-VL
        generated_ids = qwen_model.generate(**inputs, max_new_tokens=50)

        # Trim the prompt tokens from each generated sequence so only the model's
        # answer is decoded; this avoids dropping real text lines that happen to
        # contain words like "user" or "assistant"
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        extracted_text = output_text[0].strip()

        # Clean up the temporary file
        os.remove(temp_image_path)

        return extracted_text

    except Exception as e:
        error_message = str(e)
        traceback.print_exc()
        return f"Error: {error_message}"

def search_keywords(keywords):
    if not extracted_text:
        return "No text extracted yet. Please upload an image."

    # Highlight matching keywords in the extracted text
    highlighted_text = extracted_text
    for keyword in keywords.split():
        highlighted_text = re.sub(f"({re.escape(keyword)})", r"<mark>\1</mark>", highlighted_text, flags=re.IGNORECASE)

    # Return the highlighted text; convert newlines to <br> so line breaks survive in the HTML output
    return highlighted_text.replace("\n", "<br>")

# Gradio interface for image input and keyword search
with gr.Blocks() as iface:
    # Image upload and text extraction section
    with gr.Column():
        img_input = gr.Image(type="pil", label="Upload an Image")
        extracted_output = gr.Textbox(label="Extracted Text", interactive=False)

        # Functionality to trigger the OCR and extraction
        img_button = gr.Button("Extract Text")
        img_button.click(fn=ocr_and_extract, inputs=img_input, outputs=extracted_output)

    # Keyword search section
    with gr.Column():
        search_input = gr.Textbox(label="Enter keywords to search")
        search_output = gr.HTML(label="Search Results")

        # Functionality to search within the extracted text
        search_button = gr.Button("Search")
        search_button.click(fn=search_keywords, inputs=search_input, outputs=search_output)

iface.launch()