"""
OLM-CLLM OCR – Gradio Space
Upload any PDF ➜ get clean, linearised text.

🚀  Model: allenai/olmOCR-7B-0225-preview
🔧  Prompts / render helpers come from the `olmocr` toolkit
"""

import json, base64, tempfile, os, gc, shutil
from io import BytesIO

import gradio as gr
import torch
from PIL import Image
from pypdf import PdfReader

from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from olmocr.data.renderpdf import render_pdf_to_base64png          # page → base64 PNG
from olmocr.prompts.anchor import get_anchor_text                  # page → anchor text
from olmocr.prompts import build_finetuning_prompt                 # anchor → final prompt

# ---------- 1. Model & processor (load once, then stay in memory) ----------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "allenai/olmOCR-7B-0225-preview",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
).to(device).eval()
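# Note: bfloat16 roughly halves GPU memory use vs. float32; the CPU path keeps
# float32 since bfloat16 support on CPU is patchy and often slower.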

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
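# olmOCR-7B-0225-preview is fine-tuned from Qwen2-VL-7B-Instruct, so it reuses
# the base model's processor (tokenizer + image preprocessor).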

# ---------- 2. Utility ------------------------------------------------------
def _decode_llm_json(raw_str: str) -> str:
    """
    olmOCR returns a JSON string like:
    {
        "primary_language": "...",
        ...
        "natural_text": "THE ACTUAL PAGE TEXT"
    }
    Pull out the `natural_text` field; fall back to raw string if parsing fails.
    """
    try:
        page_json = json.loads(raw_str.strip())
        return page_json.get("natural_text") or ""
    except Exception:
        return raw_str.strip()

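# Illustrative behaviour of the helper above:
#   _decode_llm_json('{"primary_language": "en", "natural_text": "Hi"}')  -> "Hi"
#   _decode_llm_json("not valid json")                                    -> "not valid json"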
# ---------- 3. Core pipeline ------------------------------------------------
def pdf_to_text(pdf_file):
    """
    • Save the uploaded file to a temp path (the toolkit needs a real path)
    • Iterate over the pages
    • For each page:
        – render the page image → base64
        – generate anchor text for the page
        – build the prompt (+ image) and run the model
        – collect `natural_text`
    • Return the merged text
    """

    if pdf_file is None:
        return "⬆️ Please upload a PDF first."

    with tempfile.TemporaryDirectory() as tmpdir:
        local_pdf_path = os.path.join(tmpdir, "input.pdf")
        # gr.File yields a file path (or an object exposing .name), not an open
        # byte stream, so copy the file instead of calling .read() on it
        src_path = pdf_file.name if hasattr(pdf_file, "name") else pdf_file
        shutil.copy(src_path, local_pdf_path)

        reader = PdfReader(local_pdf_path)
        n_pages = len(reader.pages)

        extracted_pages = []

        for page_idx in range(1, n_pages + 1):          # 1-indexed
            # a. Image
            img_b64 = render_pdf_to_base64png(
                local_pdf_path, page_idx, target_longest_image_dim=1024
            )
            page_image = Image.open(BytesIO(base64.b64decode(img_b64)))

            # b. Anchor text & prompt
            anchor = get_anchor_text(
                local_pdf_path,
                page_idx,
                pdf_engine="pdfreport",                 # uses pypdf / pdfium, no Poppler dependency
                target_length=4000,
            )
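            # build_finetuning_prompt wraps the anchor text in the instruction
            # template the model was fine-tuned on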
            prompt = build_finetuning_prompt(anchor)

            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text",  "text": prompt},
                        {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_b64}"}},
                    ],
                }
            ]

            # c. Tokenise + generate
            text_in = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            inputs = processor(text=[text_in], images=[page_image], return_tensors="pt", padding=True)
            inputs = {k: v.to(device) for k, v in inputs.items()}

            with torch.no_grad():
                gen = model.generate(
                    **inputs,
                    max_new_tokens=512,
                    do_sample=False,  # greedy decoding; a temperature setting would be ignored here
                )

            prompt_len = inputs["input_ids"].shape[1]
            new_tokens = gen[:, prompt_len:]
            raw_out   = processor.tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0]

            extracted_pages.append(_decode_llm_json(raw_out))

            # optional per-page memory clean-up
            del inputs, gen
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    return "\n\n".join(extracted_pages) or "πŸ€” Nothing returned."

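# Illustrative smoke test (assumes a sample.pdf next to this script):
#     print(pdf_to_text("sample.pdf"))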
# ---------- 4. Gradio UI ----------------------------------------------------
with gr.Blocks(title="olmOCR 7B PDF Extractor") as demo:
    gr.Markdown(
        """
        # 🧠 **OLM-CLLM OCR**  
        Upload a PDF → get high-quality, linearised text (tables → Markdown, equations → LaTeX).  
        Fine-tuned Vision-LLM: **allenai/olmOCR-7B-0225-preview**.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            up = gr.File(label="πŸ“„ Upload PDF", file_types=[".pdf"])
            go = gr.Button("Extract Text", variant="primary", size="lg")
        with gr.Column(scale=2):
            out = gr.Textbox(
                label="πŸ“œ Extracted text",
                lines=25,
                interactive=False,
                show_copy_button=True,
            )

    go.click(pdf_to_text, inputs=up, outputs=out)

# ---------- 5. Launch locally (Space will ignore this) ----------------------
if __name__ == "__main__":
    demo.launch()