import os
import json

import gradio as gr
import torch
from PIL import Image
from huggingface_hub import hf_hub_download
from transformers import (
    AutoProcessor,
    LayoutLMv3Model,
    T5ForConditionalGeneration,
    AutoTokenizer,
)
# ── 1) CONFIG & CHECKPOINT ────────────────────────────────────────────────
HF_REPO = "Uddipan107/ocr-layoutlmv3-base-t5-small"
CKPT_NAME = "pytorch_model.bin"
# 1a) Download the fine-tuned checkpoint from the Hub
ckpt_path = hf_hub_download(repo_id=HF_REPO, filename=CKPT_NAME)
ckpt = torch.load(ckpt_path, map_location="cpu")
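# The checkpoint bundles three state dicts, keyed "layout_model",
# "t5_model", and "projection"; each is loaded into its module below.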
# ── 2) BUILD MODELS ───────────────────────────────────────────────────────
# 2a) Processor for LayoutLMv3
processor = AutoProcessor.from_pretrained(
    "microsoft/layoutlmv3-base", apply_ocr=False
)
# 2b) LayoutLMv3 encoder
layout_model = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base")
layout_model.load_state_dict(ckpt["layout_model"], strict=False)  # strict=False tolerates non-matching keys
layout_model.eval().to("cpu")
# 2c) T5 decoder + tokenizer
t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")
t5_model.load_state_dict(ckpt["t5_model"], strict=False)
t5_model.eval().to("cpu")
tokenizer = AutoTokenizer.from_pretrained("t5-small")
# 2d) Projection head
proj_state = ckpt["projection"]
projection = torch.nn.Sequential(
    torch.nn.Linear(768, t5_model.config.d_model),
    torch.nn.LayerNorm(t5_model.config.d_model),
    torch.nn.GELU()
)
projection.load_state_dict(proj_state)
projection.eval().to("cpu")
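# Note: 768 is LayoutLMv3-base's hidden size and t5-small's d_model is 512,
# so this head maps encoder features into the T5 embedding space.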
# 2e) Ensure we have a valid start token for generation
#     (T5's tokenizer defines no BOS token, so fall back to PAD if needed)
if t5_model.config.decoder_start_token_id is None:
    t5_model.config.decoder_start_token_id = (
        tokenizer.bos_token_id or tokenizer.pad_token_id
    )
if t5_model.config.bos_token_id is None:
    t5_model.config.bos_token_id = t5_model.config.decoder_start_token_id
# ── 3) INFERENCE ──────────────────────────────────────────────────────────
def infer(image_path, json_file):
    img_name = os.path.basename(image_path)

    # 3a) Read the uploaded NDJSON & find the matching record
    entry = None
    with open(json_file.name, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            obj = json.loads(line)
            if obj.get("img_name") == img_name:
                entry = obj
                break
    if entry is None:
        return f"❌ No JSON entry for: {img_name}"

    words = entry["src_word_list"]
    boxes = entry["src_wordbox_list"]
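    # Each NDJSON record is assumed to look roughly like (illustrative values):
    #   {"img_name": "page_01.png",
    #    "src_word_list": ["Total", "Amount", ...],
    #    "src_wordbox_list": [[71, 12, 110, 25], ...]}
    # with boxes in LayoutLM's 0-1000 normalized [x0, y0, x1, y1] format.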
    # 3b) Preprocess: image + OCR tokens + boxes
    img = Image.open(image_path).convert("RGB")
    enc = processor(
        [img], [words], boxes=[boxes],
        return_tensors="pt", padding=True, truncation=True
    )
    pixel_values = enc.pixel_values.to("cpu")
    input_ids = enc.input_ids.to("cpu")
    attention_mask = enc.attention_mask.to("cpu")
    bbox = enc.bbox.to("cpu")
    # 3c) Forward pass: encode, project into T5 space, then generate
    with torch.no_grad():
        out = layout_model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            bbox=bbox
        )
        # LayoutLMv3 appends visual patch embeddings after the text tokens,
        # so keep only the first seq_len positions (the text features)
        seq_len = input_ids.size(1)
        text_feats = out.last_hidden_state[:, :seq_len, :]
        proj_feats = projection(text_feats)
        gen_ids = t5_model.generate(
            inputs_embeds=proj_feats,
            attention_mask=attention_mask,
            max_length=512,
            decoder_start_token_id=t5_model.config.decoder_start_token_id
        )

    # 3d) Decode & return
    return tokenizer.decode(gen_ids[0], skip_special_tokens=True)
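# Local smoke test (a sketch, not part of the app): Gradio passes infer a
# file path and an object exposing .name, which a stand-in can mimic:
#   class _Upload:  # hypothetical helper; paths are placeholders
#       name = "records.ndjson"
#   print(infer("sample.png", _Upload()))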
# ── 4) GRADIO APP ─────────────────────────────────────────────────────────
demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Image(type="filepath", label="Upload Image"),
        gr.File(label="Upload JSON (NDJSON)")
    ],
    outputs="text",
    title="OCR Reorder Pipeline"
)
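# Note: gr.Image(type="filepath") hands infer a temporary file path, which is
# why os.path.basename(image_path) is matched against "img_name" above.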
if __name__ == "__main__":
    demo.launch(share=True)