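"""Multimodal RAG demo (Streamlit).

Indexes an uploaded image with ColPali (via byaldi) for retrieval and
answers text queries about it with Qwen2-VL-2B-Instruct.
"""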
import streamlit as st
import os
import torch
from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

# Check for CUDA availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Caching the model loading
@st.cache_resource
def load_rag_model():
    return RAGMultiModalModel.from_pretrained("vidore/colpali")

@st.cache_resource
def load_qwen_model():
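    # bfloat16 halves the memory footprint; on a CPU-only host inference
    # will still be slow (see the warning shown in the UI below).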
    return Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16
    ).to(device).eval()

@st.cache_resource
def load_processor():
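    # The processor provides the chat template and image preprocessing for Qwen2-VL.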
    return AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True)

# Load models
RAG = load_rag_model()
model = load_qwen_model()
processor = load_processor()

st.title("Multimodal RAG App")

st.warning("⚠️ Disclaimer: This app is currently running on CPU, which may result in slow processing times. For optimal performance, download and run the app locally on a machine with GPU support.")

# Add download link
st.markdown("[📥 Download the app code](https://huggingface.co/spaces/clayton07/colpali-qwen2-ocr/blob/main/app.py)")

# Initialize session state for tracking if index is created
if 'index_created' not in st.session_state:
    st.session_state.index_created = False

# File uploader
image_source = st.radio("Choose image source:", ("Upload an image", "Use example image"))

if image_source == "Upload an image":
    uploaded_file = st.file_uploader("Choose an image file", type=["png", "jpg", "jpeg"])
else:
    # Use a pre-defined example image
    example_image_path = "hindi-qp.jpg"
    uploaded_file = example_image_path

if uploaded_file is not None:
    # If using the example image, no need to save it
    if image_source == "Upload an image":
        with open("temp_image.png", "wb") as f:
            f.write(uploaded_file.getvalue())
        image_path = "temp_image.png"
    else:
        image_path = uploaded_file

    # Streamlit reruns this script on every widget interaction, so guard the
    # indexing step: only (re)index when a new image arrives; otherwise the
    # same image would be re-added to the index on every rerun.
    image_id = uploaded_file.name if image_source == "Upload an image" else image_path
    if st.session_state.get("indexed_image") != image_id:
        if not st.session_state.index_created:
            # Initialize the index for the first image
            RAG.index(
                input_path=image_path,
                index_name="temp_index",
                store_collection_with_index=False,
                overwrite=True
            )
            st.session_state.index_created = True
        else:
            # Add to the existing index for subsequent images
            RAG.add_to_index(
                input_item=image_path,
                store_collection_with_index=False
            )
        st.session_state.indexed_image = image_id

    st.image(uploaded_file, caption="Selected Image", use_column_width=True)

    # Text query input
    text_query = st.text_input("Enter your query about the image:")

    if text_query:
        # Retrieve the most relevant indexed page(s) for the query. With a
        # single image per session the image is passed to the model directly
        # below, so `results` is informational here.
        results = RAG.search(text_query, k=2)

        # Process with Qwen2VL model
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": image_path,
                    },
                    {"type": "text", "text": text_query},
                ],
            }
        ]
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
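        # process_vision_info loads the images/videos referenced in the messages.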
        image_inputs, video_inputs = process_vision_info(messages)
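        # Bundle the prompt text and preprocessed image into model-ready tensors.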
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(device)
        generated_ids = model.generate(**inputs, max_new_tokens=100)
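        # Strip the prompt tokens from each sequence so only the newly
        # generated answer is decoded.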
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        # Display results
        st.subheader("Results:")
        st.write(output_text[0])

    # Clean up temporary file
    if image_source == "Upload an image":
        os.remove("temp_image.png")
else:
    st.write("Please upload an image to get started.")