olmOcR_grinda / app.py
ibrahim313's picture
Create app.py
120db54 verified
"""
OLM-CLLM OCR – Gradio Space
Upload any PDF ➜ get clean, linearised text.
🚀 Model: allenai/olmOCR-7B-0225-preview
🔧 Prompts / render helpers come from the `olmocr` toolkit
"""
import json, base64, tempfile, os, gc
from io import BytesIO
import gradio as gr
import torch
from PIL import Image
from pypdf import PdfReader
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from olmocr.data.renderpdf import render_pdf_to_base64png # page → base64 PNG
from olmocr.prompts.anchor import get_anchor_text # page → anchor text
from olmocr.prompts import build_finetuning_prompt # anchor → final prompt
# ---------- 1. Model & processor (load once, then stay in memory) ----------
# Run on the GPU when torch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the OCR checkpoint once at import time: bfloat16 on CUDA, float32 on CPU.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "allenai/olmOCR-7B-0225-preview",
    torch_dtype=torch.float32 if device.type == "cpu" else torch.bfloat16,
)
model = model.to(device)
model = model.eval()  # inference only: switch off dropout & friends

# The processor (tokenizer + image pre-processing) is taken from the
# Qwen2-VL instruct repository rather than from the olmOCR checkpoint.
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
# ---------- 2. Utility ------------------------------------------------------
def _decode_llm_json(raw_str: str) -> str:
"""
olmOCR returns a JSON string like:
{
"primary_language": "...",
...
"natural_text": "THE ACTUAL PAGE TEXT"
}
Pull out the `natural_text` field; fall back to raw string if parsing fails.
"""
try:
page_json = json.loads(raw_str.strip())
return page_json.get("natural_text") or ""
except Exception:
return raw_str.strip()
# ---------- 3. Core pipeline ------------------------------------------------
def _resolve_pdf_path(pdf_file, workdir):
    """Return a filesystem path for the upload, writing it into *workdir* if it is not already a path."""
    # A path string / path-like object can be used as it is.
    if isinstance(pdf_file, (str, os.PathLike)):
        return os.fspath(pdf_file)

    name = getattr(pdf_file, "name", None)
    if isinstance(pdf_file, (bytes, bytearray)):
        payload = bytes(pdf_file)
    elif callable(getattr(pdf_file, "read", None)):
        payload = pdf_file.read()
        if not payload and name and os.path.exists(name):
            # Stream already consumed / positioned at EOF: use the backing file.
            return os.fspath(name)
    elif name:
        return os.fspath(name)
    else:
        raise TypeError(f"Unsupported upload type: {type(pdf_file).__name__}")

    local_pdf_path = os.path.join(workdir, "input.pdf")
    with open(local_pdf_path, "wb") as fh:
        fh.write(payload)
    return local_pdf_path


def _ocr_page(pdf_path, page_idx):
    """Run the model on one page (1-indexed) of *pdf_path* and return its extracted text."""
    # a. Page image as base64 PNG
    img_b64 = render_pdf_to_base64png(
        pdf_path, page_idx, target_longest_image_dim=1024
    )

    # b. Anchor text & prompt
    anchor = get_anchor_text(
        pdf_path,
        page_idx,
        # Engine choice kept from the original code, whose author noted it
        # avoids a Poppler dependency (not verifiable from this file).
        pdf_engine="pdfreport",
        target_length=4000,
    )
    prompt = build_finetuning_prompt(anchor)
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_b64}"}},
            ],
        }
    ]

    # c. Tokenise + generate
    text_in = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Close the PIL image as soon as the processor has turned it into tensors.
    with Image.open(BytesIO(base64.b64decode(img_b64))) as page_image:
        inputs = processor(text=[text_in], images=[page_image], return_tensors="pt", padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        # Greedy decoding. No `temperature` here: it is ignored when
        # do_sample=False and only makes transformers emit a warning.
        gen = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=False,
        )

    # `generate` returns prompt + continuation; keep only the continuation.
    prompt_len = inputs["input_ids"].shape[1]
    new_tokens = gen[:, prompt_len:]
    raw_out = processor.tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0]
    return _decode_llm_json(raw_out)


def pdf_to_text(pdf_file):
    """
    Convert an uploaded PDF into plain text, page by page.

    Steps:
      • Obtain a real file path for the upload (the toolkit helpers take a path)
      • Iterate over the pages (1-indexed)
      • For each page:
          – render the page image → base64
          – build the anchor text and the prompt
          – run the model on prompt + image
          – collect the decoded `natural_text`
      • Return the non-empty page texts joined by blank lines

    Args:
        pdf_file: The upload. ``None`` means nothing was uploaded. Otherwise a
            path (str / os.PathLike), raw bytes, an object with ``read()``,
            or an object exposing the path via ``.name``.

    Returns:
        The merged text, or a short hint message when no file was supplied
        or no text was extracted.
    """
    if pdf_file is None:
        return "⬆️ Please upload a PDF first."

    with tempfile.TemporaryDirectory() as tmpdir:
        local_pdf_path = _resolve_pdf_path(pdf_file, tmpdir)
        n_pages = len(PdfReader(local_pdf_path).pages)

        extracted_pages = []
        for page_idx in range(1, n_pages + 1):  # 1-indexed
            extracted_pages.append(_ocr_page(local_pdf_path, page_idx))
            # Per-page memory clean-up: the page's tensors went out of scope
            # when _ocr_page returned, so they can be reclaimed now.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Skip blank pages so the fallback message also appears when every page
    # of a multi-page PDF came back empty.
    page_texts = [str(text) for text in extracted_pages if text and str(text).strip()]
    return "\n\n".join(page_texts) or "🤔 Nothing returned."
# ---------- 4. Gradio UI ----------------------------------------------------
# Two-column layout: upload + button on the left, extracted text on the right.
with gr.Blocks(title="olmOCR 7B PDF Extractor") as demo:
    gr.Markdown(
        """
# 🧠 **OLM-CLLM OCR**
Upload a PDF → get high-quality, linearised text (tables → Markdown, equations → LaTeX).
Fine-tuned Vision-LLM: **allenai/olmOCR-7B-0225-preview**.
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            up = gr.File(label="📄 Upload PDF", file_types=[".pdf"])
            go = gr.Button("Extract Text", variant="primary", size="lg")
        with gr.Column(scale=2):
            out = gr.Textbox(
                label="📜 Extracted text",
                lines=25,
                interactive=False,  # output only
                show_copy_button=True,
            )
    # Clicking the button sends the uploaded file through the OCR pipeline
    # and shows the result in the text box.
    go.click(pdf_to_text, inputs=up, outputs=out)
# ---------- 5. Launch locally (Space will ignore this) ----------------------
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script,
    # not when it is imported as a module.
    demo.launch()