# Visual PDF search with ColPali.
# Launch with: streamlit run <path-to-this-file>

import streamlit as st
import torch
from PIL import Image
import fitz  # PyMuPDF
from transformers.utils.import_utils import is_flash_attn_2_available
from colpali_engine.models import ColQwen2, ColQwen2Processor

# -----------------------------
# Load ColPali model
# -----------------------------
@st.cache_resource
def load_colpali():
    """Load the ColQwen2 model and processor once and cache them across reruns."""
    model_name = "vidore/colqwen2-v1.0"
    model = ColQwen2.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="cuda:0" if torch.cuda.is_available() else "cpu",
        # Use Flash Attention 2 only when the package is actually installed.
        attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
    ).eval()
    processor = ColQwen2Processor.from_pretrained(model_name)
    return model, processor

colpali_model, colpali_processor = load_colpali()

st.title("🔍 Visual PDF Search with ColPali")

pdf_file = st.file_uploader("Upload a PDF", type="pdf")

# -----------------------------
# Convert PDF pages to images
# -----------------------------
def render_pdf_pages_as_images(doc, zoom=2.0):
    """Rasterize every page of the PDF at the given zoom factor; return PIL images."""
    images = []
    for page in doc:
        mat = fitz.Matrix(zoom, zoom)
        pix = page.get_pixmap(matrix=mat)
        img = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
        images.append(img)
    return images

# -----------------------------
# Chunk an image into a grid of pieces
# -----------------------------
def chunk_image(image, rows=3, cols=3):
    """Split an image into a rows x cols grid and resize each cell to 512x512."""
    width, height = image.size
    chunk_width = width // cols
    chunk_height = height // rows
    chunks = []
    for r in range(rows):
        for c in range(cols):
            left = c * chunk_width
            top = r * chunk_height
            right = left + chunk_width
            bottom = top + chunk_height
            chunk = image.crop((left, top, right, bottom)).resize((512, 512))
            chunks.append(chunk)
    return chunks

if pdf_file:
    doc = fitz.open(stream=pdf_file.read(), filetype="pdf")
    images = render_pdf_pages_as_images(doc)
    doc.close()  # pixel data has been copied into PIL images, so the doc can be closed

    if not images:
        st.warning("Failed to read content from the PDF.")
    else:
        # Split every page into a 2x2 grid so a match can be localized within a page.
        all_chunks = []
        for image in images:
            all_chunks.extend(chunk_image(image, rows=2, cols=2))

        user_query = st.text_input("What are you looking for in the document?")

        if user_query:
            # Preprocess the image chunks and the query, then embed both with the model.
            batch_images = colpali_processor.process_images(all_chunks).to(colpali_model.device)
            batch_queries = colpali_processor.process_queries([user_query]).to(colpali_model.device)

            with torch.no_grad():
                image_embeddings = colpali_model(**batch_images)
                query_embeddings = colpali_model(**batch_queries)

            # Late-interaction (MaxSim) scores, shape (num_queries, num_chunks).
            scores = colpali_processor.score_multi_vector(query_embeddings, image_embeddings)

            best_idx = scores[0].argmax().item()
            best_image = all_chunks[best_idx]
            best_score = scores[0, best_idx].item()

            st.markdown("### 🔍 Most Relevant Image Chunk")
            # use_container_width replaces the deprecated use_column_width argument.
            st.image(best_image, caption=f"Score: {best_score:.4f}", use_container_width=True)
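
# -----------------------------
# Illustrative sketch: what score_multi_vector computes
# -----------------------------
# Not called by the app. This is a minimal reimplementation of ColPali-style
# late-interaction (MaxSim) scoring for a single query/image pair, modulo the
# padding and batching the library handles internally: for each query-token
# embedding, take its maximum dot product over all image-patch embeddings,
# then sum across query tokens. It assumes `q` has shape (num_query_tokens, dim)
# and `p` has shape (num_patches, dim); the name `maxsim_score` is ours, not
# part of colpali_engine.
def maxsim_score(q: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    sim = q @ p.T                       # (num_query_tokens, num_patches) dot products
    return sim.max(dim=1).values.sum()  # best-matching patch per query token, summed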