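"""OCR Reorder Pipeline: encode a document image plus its OCR words and
boxes with LayoutLMv3, project the text features into T5's embedding
space, and let T5 generate the reordered text. Served as a Gradio app."""
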
import os, json
from PIL import Image
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from transformers import (
    AutoProcessor,
    LayoutLMv3Model,
    T5ForConditionalGeneration,
    AutoTokenizer
)

# ── 1) CONFIG & CHECKPOINT ────────────────────────────────────────────────
HF_REPO     = "Uddipan107/ocr-layoutlmv3-base-t5-small"
CKPT_NAME   = "pytorch_model.bin"

# 1a) Download the checkpoint dict from your Hub
ckpt_path   = hf_hub_download(repo_id=HF_REPO, filename=CKPT_NAME)
ckpt        = torch.load(ckpt_path, map_location="cpu")
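# The checkpoint is expected to be a dict of three state_dicts, keyed
# "layout_model", "t5_model", and "projection" (each consumed below).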

# ── 2) BUILD MODELS ───────────────────────────────────────────────────────
# 2a) Processor for LayoutLMv3
processor   = AutoProcessor.from_pretrained(
    "microsoft/layoutlmv3-base", apply_ocr=False
)
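# With apply_ocr=False the processor expects us to supply the OCR words and
# bounding boxes ourselves (here: from the uploaded NDJSON) instead of
# running its built-in Tesseract OCR.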

# 2b) LayoutLMv3 encoder
layout_model = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base")
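# strict=False tolerates benign key mismatches between the fine-tuned
# checkpoint and the stock architecture (e.g. absent pooler weights).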
layout_model.load_state_dict(ckpt["layout_model"], strict=False)
layout_model.eval().to("cpu")

# 2c) T5 decoder + tokenizer
t5_model    = T5ForConditionalGeneration.from_pretrained("t5-small")
t5_model.load_state_dict(ckpt["t5_model"], strict=False)
t5_model.eval().to("cpu")

tokenizer   = AutoTokenizer.from_pretrained("t5-small")

# 2d) Projection head
proj_state  = ckpt["projection"]
projection  = torch.nn.Sequential(
    torch.nn.Linear(768, t5_model.config.d_model),
    torch.nn.LayerNorm(t5_model.config.d_model),
    torch.nn.GELU()
)
projection.load_state_dict(proj_state)
projection.eval().to("cpu")

# 2e) Ensure we have a valid start token for generation
if t5_model.config.decoder_start_token_id is None:
    t5_model.config.decoder_start_token_id = tokenizer.bos_token_id or tokenizer.pad_token_id
if t5_model.config.bos_token_id is None:
    t5_model.config.bos_token_id = t5_model.config.decoder_start_token_id
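# (For stock T5 the pad token doubles as the decoder start token, so the
# fallback above matches the model's usual behavior.)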

# ── 3) INFERENCE ─────────────────────────────────────────────────────────
def infer(image_path, json_file):
    img_name = os.path.basename(image_path)

    # 3a) Read the uploaded NDJSON & find the record matching the image name.
    # gr.File hands back a tempfile-like object in Gradio 3.x but a plain
    # filepath string in Gradio 4.x; getattr covers both cases.
    json_path = getattr(json_file, "name", json_file)
    entry = None
    with open(json_path, "r", encoding="utf-8") as f:
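        # Each NDJSON line is assumed to hold one record of the form:
        # {"img_name": ..., "src_word_list": [...], "src_wordbox_list": [...]}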
        for line in f:
            line = line.strip()
            if not line:
                continue
            obj = json.loads(line)
            if obj.get("img_name") == img_name:
                entry = obj
                break

    if entry is None:
        return f"❌ No JSON entry for: {img_name}"

    words = entry["src_word_list"]
    boxes = entry["src_wordbox_list"]

    # 3b) Preprocess: image + OCR tokens + boxes
    img = Image.open(image_path).convert("RGB")
    enc = processor([img], [words], boxes=[boxes],
                    return_tensors="pt", padding=True, truncation=True)
    pixel_values   = enc.pixel_values.to("cpu")
    input_ids      = enc.input_ids.to("cpu")
    attention_mask = enc.attention_mask.to("cpu")
    bbox           = enc.bbox.to("cpu")

    # 3c) Forward pass
    with torch.no_grad():
        out = layout_model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            bbox=bbox
        )
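        # LayoutLMv3 emits text-token hidden states first, followed by
        # image-patch states; keep only the text positions so they stay
        # aligned with attention_mask.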
        seq_len    = input_ids.size(1)
        text_feats = out.last_hidden_state[:, :seq_len, :]
        proj_feats = projection(text_feats)
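        # The projected features now live in T5's d_model space; generate()
        # below consumes them as encoder input embeddings before decoding.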

        gen_ids = t5_model.generate(
            inputs_embeds=proj_feats,
            attention_mask=attention_mask,
            max_length=512,
            decoder_start_token_id=t5_model.config.decoder_start_token_id
        )

    # 3d) Decode & return
    return tokenizer.decode(gen_ids[0], skip_special_tokens=True)

# ── 4) GRADIO APP ────────────────────────────────────────────────────────
demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Image(type="filepath", label="Upload Image"),
        gr.File(label="Upload JSON (NDJSON)")
    ],
    outputs="text",
    title="OCR Reorder Pipeline"
)

if __name__ == "__main__":
    demo.launch(share=True)