ocr-reorder-space / inference.py
Uddipan Basu Bir
Download checkpoint from HF hub in OcrReorderPipeline
4956f20
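"""
OcrReorderPipeline: a custom Hugging Face Pipeline that encodes a document
image together with its OCR words and bounding boxes using LayoutLMv3,
projects the 768-dim token features into T5-small's 512-dim embedding space
via a small projection head downloaded from the Hub checkpoint, and decodes
the reordered text with the T5 decoder.
"""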
import torch
from transformers import Pipeline
from PIL import Image
import base64
from io import BytesIO
from huggingface_hub import hf_hub_download

# HF model repo containing pytorch_model.bin with 'projection' state
HF_MODEL_REPO = "Uddipan107/ocr-layoutlmv3-base-t5-small"


class OcrReorderPipeline(Pipeline):
    def __init__(self, model, tokenizer, processor, device=0):
        super().__init__(model=model, tokenizer=tokenizer,
                         feature_extractor=processor, device=device)

        # ── Download your fine-tuned checkpoint ───────────────────────────
        ckpt_path = hf_hub_download(repo_id=HF_MODEL_REPO, filename="pytorch_model.bin")
        ckpt = torch.load(ckpt_path, map_location="cpu")
        proj_state = ckpt["projection"]

        # ── Rebuild & load your projection head (T5-small hidden size = 512) ─
        d_model = 512
        self.projection = torch.nn.Sequential(
            torch.nn.Linear(768, d_model),
            torch.nn.LayerNorm(d_model),
            torch.nn.GELU()
        )
        self.projection.load_state_dict(proj_state)
        self.projection.to(self.device)

    def _sanitize_parameters(self, **kwargs):
        # Only the custom 'words'/'boxes' kwargs go to preprocess; the image
        # itself is passed positionally as 'inputs'
        preprocess_kwargs = {}
        for key in ("words", "boxes"):
            if key in kwargs:
                preprocess_kwargs[key] = kwargs[key]
        return preprocess_kwargs, {}, {}

    def preprocess(self, image, words, boxes):
        # 'image' is the positional pipeline input: a base64-encoded image
        data = base64.b64decode(image)
        img = Image.open(BytesIO(data)).convert("RGB")
        return self.feature_extractor(
            [img], [words], boxes=[boxes],
            return_tensors="pt", padding=True, truncation=True
        )

    def _forward(self, model_inputs):
        pv, ids, mask, bbox = (
            model_inputs[k].to(self.device)
            for k in ("pixel_values", "input_ids", "attention_mask", "bbox")
        )

        # Encode the image patches and OCR tokens with the LayoutLMv3 encoder
        vision_out = self.model.vision_model(
            pixel_values=pv,
            input_ids=ids,
            attention_mask=mask,
            bbox=bbox
        )

        # LayoutLMv3 places the text tokens first, so keep those positions and
        # project them from 768-dim into T5-small's 512-dim embedding space
        seq_len = ids.size(1)
        text_feats = vision_out.last_hidden_state[:, :seq_len, :]
        proj_feats = self.projection(text_feats)

        # Generate output tokens with the T5 decoder from the projected embeddings
        gen_ids = self.model.text_model.generate(
            inputs_embeds=proj_feats,
            attention_mask=mask,
            max_length=512
        )
        return {"generated_ids": gen_ids}

    def postprocess(self, model_outputs):
        # Decode the generated token ids back into text
        return self.tokenizer.batch_decode(
            model_outputs["generated_ids"],
            skip_special_tokens=True
        )
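

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original commit). The fine-tuned
# wrapper model exposing `.vision_model` (LayoutLMv3) and `.text_model`
# (T5-small) is built by the accompanying training code and appears below only
# as a placeholder; the tokenizer/processor names are standard public
# checkpoints, and the words/boxes values are made-up examples.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer, LayoutLMv3Processor

    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    processor = LayoutLMv3Processor.from_pretrained(
        "microsoft/layoutlmv3-base", apply_ocr=False
    )
    model = ...  # placeholder: the LayoutLMv3 + T5 wrapper from the training code

    pipe = OcrReorderPipeline(model=model, tokenizer=tokenizer,
                              processor=processor, device=-1)  # -1 = CPU

    # The pipeline expects a base64-encoded image plus OCR words and boxes
    with open("page.png", "rb") as f:
        img_b64 = base64.b64encode(f.read()).decode("utf-8")

    words = ["world", "Hello"]                      # OCR tokens (illustrative)
    boxes = [[60, 10, 110, 20], [10, 10, 50, 20]]   # 0-1000 scaled, one per word

    print(pipe(img_b64, words=words, boxes=boxes))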