# financial_agent/app.py
import streamlit as st
import torch
from PIL import Image
import fitz # PyMuPDF
from transformers.utils.import_utils import is_flash_attn_2_available
from colpali_engine.models import ColQwen2, ColQwen2Processor
# -----------------------------
# Load ColPali Model
# -----------------------------
@st.cache_resource
def load_colpali():
    """Load the ColQwen2 model and processor once; cached across Streamlit reruns."""
    model_name = "vidore/colqwen2-v1.0"
    model = ColQwen2.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="cuda:0" if torch.cuda.is_available() else "cpu",
        attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
    ).eval()
    processor = ColQwen2Processor.from_pretrained(model_name)
    return model, processor
colpali_model, colpali_processor = load_colpali()
st.title("🔍 Visual PDF Search with ColPali")
pdf_file = st.file_uploader("Upload a PDF", type="pdf")
# -----------------------------
# Render PDF pages as images
# -----------------------------
def render_pdf_pages_as_images(doc, zoom=2.0):
    """Rasterize every page of the document; zoom=2.0 renders at twice the default 72 dpi."""
    images = []
    for page in doc:
        mat = fitz.Matrix(zoom, zoom)
        pix = page.get_pixmap(matrix=mat)
        img = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
        images.append(img)
    return images
# -----------------------------
# Chunk image into pieces
# -----------------------------
def chunk_image(image, rows=3, cols=3):
    """Split a page image into a rows x cols grid and resize each tile to 512x512."""
    width, height = image.size
    # Integer division may drop up to (cols - 1) / (rows - 1) pixels at the
    # right / bottom edges of the page.
    chunk_width = width // cols
    chunk_height = height // rows
    chunks = []
    for r in range(rows):
        for c in range(cols):
            left = c * chunk_width
            top = r * chunk_height
            right = left + chunk_width
            bottom = top + chunk_height
            chunk = image.crop((left, top, right, bottom)).resize((512, 512))
            chunks.append(chunk)
    return chunks
if pdf_file:
    doc = fitz.open(stream=pdf_file.read(), filetype="pdf")
    images = render_pdf_pages_as_images(doc)
    if not images:
        st.warning("Failed to read content from the PDF.")
    else:
        # Split every page into a 2x2 grid of tiles.
        all_chunks = []
        for image in images:
            all_chunks.extend(chunk_image(image, rows=2, cols=2))
        user_query = st.text_input("What are you looking for in the document?")
        if user_query:
            batch_images = colpali_processor.process_images(all_chunks).to(colpali_model.device)
            batch_queries = colpali_processor.process_queries([user_query]).to(colpali_model.device)
            with torch.no_grad():
                image_embeddings = colpali_model(**batch_images)
                query_embeddings = colpali_model(**batch_queries)
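            # NOTE (hedged sketch, not part of the original app): embedding every
            # chunk in a single forward pass can exhaust GPU memory on long PDFs.
            # A batched variant could look like this (the batch size of 8 is an
            # arbitrary assumption; all tiles are 512x512, so their embeddings
            # share a sequence length and torch.cat is safe):
            #
            #   embeds = []
            #   for i in range(0, len(all_chunks), 8):
            #       batch = colpali_processor.process_images(
            #           all_chunks[i:i + 8]
            #       ).to(colpali_model.device)
            #       with torch.no_grad():
            #           embeds.append(colpali_model(**batch))
            #   image_embeddings = torch.cat(embeds, dim=0)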
            # Late-interaction (MaxSim) scoring of the query against every tile.
            scores = colpali_processor.score_multi_vector(query_embeddings, image_embeddings)
            best_idx = torch.argmax(scores).item()
            best_image = all_chunks[best_idx]
            best_score = scores[0, best_idx].item()
            st.markdown("### 🔍 Most Relevant Image Chunk")
            # use_container_width replaces the deprecated use_column_width flag.
            st.image(best_image, caption=f"Score: {best_score:.4f}", use_container_width=True)
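            # NOTE (hedged sketch, not part of the original app): only the single
            # best tile is shown above. A top-k view (k=3 is an arbitrary choice)
            # could replace the argmax:
            #
            #   top_scores, top_idx = torch.topk(scores[0], k=min(3, scores.shape[1]))
            #   for s, i in zip(top_scores.tolist(), top_idx.tolist()):
            #       st.image(all_chunks[i], caption=f"Score: {s:.4f}")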