import streamlit as st
import torch
from PIL import Image
import fitz  # PyMuPDF
from transformers.utils.import_utils import is_flash_attn_2_available

from colpali_engine.models import ColQwen2, ColQwen2Processor
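
# Assumed dependencies (pip): streamlit, torch, pillow, pymupdf, colpali-engine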

# -----------------------------
# Load ColPali Model
# -----------------------------
@st.cache_resource  # cache across Streamlit reruns so the model loads only once
def load_colpali():
    model_name = "vidore/colqwen2-v1.0"
    model = ColQwen2.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="cuda:0" if torch.cuda.is_available() else "cpu",
        attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
    ).eval()
    processor = ColQwen2Processor.from_pretrained(model_name)
    return model, processor

colpali_model, colpali_processor = load_colpali()
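
# Note: bfloat16 and FlashAttention-2 assume a recent NVIDIA GPU; on CPU the
# model should still load, but inference will be much slower.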
st.title("🔍 Visual PDF Search with ColPali")

pdf_file = st.file_uploader("Upload a PDF", type="pdf")

# -----------------------------
# Convert PDF pages to images
# -----------------------------
def render_pdf_pages_as_images(doc, zoom=2.0):
    images = []
    for page in doc:
        mat = fitz.Matrix(zoom, zoom)
        pix = page.get_pixmap(matrix=mat)
        img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
        images.append(img)
    return images
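
# Note: PyMuPDF renders at 72 dpi by default, so zoom=2.0 yields roughly
# 144 dpi, a reasonable trade-off between visual detail and embedding cost.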

# -----------------------------
# Chunk image into pieces
# -----------------------------
def chunk_image(image, rows=3, cols=3):
    width, height = image.size
    chunk_width = width // cols
    chunk_height = height // rows
    chunks = []
    for r in range(rows):
        for c in range(cols):
            left = c * chunk_width
            top = r * chunk_height
            right = left + chunk_width
            bottom = top + chunk_height
            chunk = image.crop((left, top, right, bottom)).resize((512, 512))
            chunks.append(chunk)
    return chunks
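
# Example: with rows=2, cols=2 a 1000x800 page yields four 500x400 crops,
# each resized to 512x512 before embedding. Because // floors, a few
# right/bottom edge pixels are dropped when the page size isn't divisible.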

if pdf_file:
    doc = fitz.open(stream=pdf_file.read(), filetype="pdf")
    images = render_pdf_pages_as_images(doc)
    if not images:
        st.warning("Failed to read content from the PDF.")
    else:
        all_chunks = []
        for image in images:
            all_chunks.extend(chunk_image(image, rows=2, cols=2))

        user_query = st.text_input("What are you looking for in the document?")
        if user_query:
            # Embed all chunks in one batch; for very large PDFs this may
            # need to be split into smaller batches to fit in memory.
            batch_images = colpali_processor.process_images(all_chunks).to(colpali_model.device)
            batch_queries = colpali_processor.process_queries([user_query]).to(colpali_model.device)

            with torch.no_grad():
                image_embeddings = colpali_model(**batch_images)
                query_embeddings = colpali_model(**batch_queries)

            # Late-interaction (MaxSim) scoring: shape (num_queries, num_chunks)
            scores = colpali_processor.score_multi_vector(query_embeddings, image_embeddings)
            best_idx = torch.argmax(scores).item()
            best_image = all_chunks[best_idx]
            best_score = scores[0, best_idx].item()

            st.markdown("### 🔍 Most Relevant Image Chunk")
            st.image(best_image, caption=f"Score: {best_score:.4f}", use_container_width=True)
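
            # (Optional) A minimal top-k sketch, not part of the original app:
            # since `scores` ranks every chunk, the runner-up crops can be
            # surfaced as well. `k=3` is an arbitrary choice here.
            top_scores, top_idxs = scores[0].topk(k=min(3, len(all_chunks)))
            with st.expander("Other candidate chunks"):
                for rank, (s, i) in enumerate(zip(top_scores.tolist(), top_idxs.tolist()), start=1):
                    st.image(all_chunks[i], caption=f"#{rank}: score {s:.4f}", use_container_width=True)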