File size: 3,032 Bytes
ef07835
6387027
 
4c8d2d0
6387027
 
ef07835
6387027
4c8d2d0
6387027
 
4c8d2d0
 
 
 
6387027
 
4c8d2d0
6387027
4c8d2d0
 
6387027
4c8d2d0
6387027
4c8d2d0
 
6387027
 
4c8d2d0
6387027
4c8d2d0
6387027
 
4c8d2d0
 
 
 
 
6387027
4c8d2d0
 
 
23c1839
4c8d2d0
 
 
6387027
4c8d2d0
 
 
 
 
 
 
 
 
 
6387027
4c8d2d0
 
 
6387027
4c8d2d0
 
 
 
 
 
6387027
4c8d2d0
6387027
4c8d2d0
 
 
6387027
 
4c8d2d0
 
 
 
 
 
 
6387027
4c8d2d0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import streamlit as st
import torch
from PIL import Image
import fitz  # PyMuPDF
from transformers.utils.import_utils import is_flash_attn_2_available
from colpali_engine.models import ColQwen2, ColQwen2Processor

# -----------------------------
# Load ColPali Model
# -----------------------------
@st.cache_resource
def load_colpali():
    """Load the ColQwen2 model and its processor once per Streamlit session.

    Cached via ``st.cache_resource`` so reruns reuse the loaded weights.

    Returns:
        Tuple of (model in eval mode, processor).
    """
    model_name = "vidore/colqwen2-v1.0"
    # Prefer GPU when available; fall back to CPU.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # Use flash-attention 2 only when the kernel is actually installed.
    attention = "flash_attention_2" if is_flash_attn_2_available() else None
    model = ColQwen2.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map=device,
        attn_implementation=attention,
    )
    model = model.eval()
    processor = ColQwen2Processor.from_pretrained(model_name)
    return model, processor

# Load model/processor at import time; cached, so reruns are cheap.
colpali_model, colpali_processor = load_colpali()

st.title("πŸ” Visual PDF Search with ColPali")
# Returns None until the user uploads a file; gates the main flow below.
pdf_file = st.file_uploader("Upload a PDF", type="pdf")

# -----------------------------
# Convert PDF to image
# -----------------------------
def render_pdf_page_as_image(doc, zoom=2.0):
    """Rasterize every page of an open PyMuPDF document to a PIL image.

    Args:
        doc: an opened ``fitz`` document; iterating it yields pages.
        zoom: scale factor applied on both axes (2.0 renders at double
            resolution for sharper downstream chunks).

    Returns:
        List of RGB PIL images, one per page, in document order.
    """
    # The transform is the same for every page, so build it once.
    scale = fitz.Matrix(zoom, zoom)
    rendered = []
    for page in doc:
        pixmap = page.get_pixmap(matrix=scale)
        rendered.append(
            Image.frombytes("RGB", [pixmap.width, pixmap.height], pixmap.samples)
        )
    return rendered

# -----------------------------
# Chunk image into pieces
# -----------------------------
def chunk_image(image, rows=3, cols=3, size=(512, 512)):
    """Split *image* into a ``rows x cols`` grid of equally sized tiles.

    Args:
        image: a PIL-style image exposing ``.size``, ``.crop`` and ``.resize``.
        rows: number of horizontal bands to cut the image into.
        cols: number of vertical bands to cut the image into.
        size: (width, height) every tile is resized to. Defaults to
            (512, 512) — presumably matched to the embedding model's
            expected input; confirm before changing.

    Returns:
        List of ``rows * cols`` resized tiles in row-major order
        (left-to-right, then top-to-bottom).

    Note:
        Tile dimensions use integer division, so up to ``cols - 1`` pixel
        columns on the right and ``rows - 1`` pixel rows at the bottom
        may be silently dropped when the image size is not divisible.
    """
    width, height = image.size
    chunk_width = width // cols
    chunk_height = height // rows

    chunks = []
    for r in range(rows):
        for c in range(cols):
            left = c * chunk_width
            top = r * chunk_height
            right = left + chunk_width
            bottom = top + chunk_height
            chunk = image.crop((left, top, right, bottom)).resize(size)
            chunks.append(chunk)
    return chunks

if pdf_file:
    # Open the uploaded PDF from its in-memory bytes (no temp file needed).
    doc = fitz.open(stream=pdf_file.read(), filetype="pdf")
    images = render_pdf_page_as_image(doc)

    if not images:
        st.warning("Failed to read content from the PDF.")
    else:
        # Flatten every page into 2x2 tiles so retrieval can localize
        # matches within a page, not just pick a whole page.
        all_chunks = []
        for image in images:
            all_chunks.extend(chunk_image(image, rows=2, cols=2))

        user_query = st.text_input("What are you looking for in the document?")

        if user_query:
            # Move both batches onto the model's device before the forward pass.
            batch_images = colpali_processor.process_images(all_chunks).to(colpali_model.device)
            batch_queries = colpali_processor.process_queries([user_query]).to(colpali_model.device)

            # Inference only — no gradients needed.
            with torch.no_grad():
                image_embeddings = colpali_model(**batch_images)
                query_embeddings = colpali_model(**batch_queries)

            # Late-interaction scoring; presumably shaped
            # (num_queries, num_chunks) — TODO confirm against colpali docs.
            scores = colpali_processor.score_multi_vector(query_embeddings, image_embeddings)
            # Flat argmax equals the chunk index here because there is
            # exactly one query row in `scores`.
            best_idx = torch.argmax(scores).item()
            best_image = all_chunks[best_idx]
            best_score = scores[0, best_idx].item()

            st.markdown("### πŸ” Most Relevant Image Chunk")
            # NOTE(review): use_column_width is deprecated in recent
            # Streamlit (use_container_width) — confirm pinned version.
            st.image(best_image, caption=f"Score: {best_score:.4f}", use_column_width=True)