# inference.py (ocr-reorder-space)
# Author: Uddipan Basu Bir
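"""OcrReorderPipeline: a custom transformers Pipeline that encodes a page
image plus its OCR words and boxes with LayoutLMv3, projects the text
features into the T5 embedding space, and generates the words in reading
order. The wrapped `model` is expected to expose `.vision_model`
(LayoutLMv3) and `.text_model` (T5), as used in `_forward` below.
"""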
import torch
from transformers import Pipeline
from PIL import Image
import base64
from io import BytesIO
from huggingface_hub import hf_hub_download

# Point at the HF model repo that hosts the fine-tuned checkpoint
HF_MODEL_REPO = "Uddipan107/ocr-layoutlmv3-base-t5-small"

class OcrReorderPipeline(Pipeline):
    def __init__(self, model, tokenizer, processor, device=0):
        super().__init__(model=model, tokenizer=tokenizer,
                         feature_extractor=processor, device=device)

        # ── Download the fine-tuned checkpoint from the HF Hub ─────────────
        ckpt_path = hf_hub_download(repo_id=HF_MODEL_REPO,
                                    filename="pytorch_model.bin")
        ckpt = torch.load(ckpt_path, map_location="cpu")
        proj_state = ckpt["projection"]

        # ── Rebuild & load the projection head ─────────────────────────────
        # Maps LayoutLMv3's 768-dim text features into the T5 embedding space
        self.projection = torch.nn.Sequential(
            torch.nn.Linear(768, model.config.d_model),
            torch.nn.LayerNorm(model.config.d_model),
            torch.nn.GELU(),
        )
        self.projection.load_state_dict(proj_state)
        self.projection.to(self.device)

    def _sanitize_parameters(self, **kwargs):
        # Forward `words` / `boxes` from the call site to preprocess();
        # returning empty dicts here would silently drop them.
        preprocess_kwargs = {k: kwargs[k] for k in ("words", "boxes") if k in kwargs}
        return preprocess_kwargs, {}, {}

    def preprocess(self, image, words, boxes):
        # `image` arrives as a base64-encoded string, not a PIL image
        data = base64.b64decode(image)
        img = Image.open(BytesIO(data)).convert("RGB")
        return self.feature_extractor(
            [img], [words], boxes=[boxes],
            return_tensors="pt", padding=True, truncation=True
        )

    def _forward(self, model_inputs):
        pv, ids, mask, bbox = (
            model_inputs[k].to(self.device)
            for k in ("pixel_values", "input_ids", "attention_mask", "bbox")
        )

        # Encode the page image together with the OCR tokens via LayoutLMv3
        vision_out = self.model.vision_model(
            pixel_values=pv,
            input_ids=ids,
            attention_mask=mask,
            bbox=bbox
        )

        # LayoutLMv3 emits text-token states first, then image-patch states,
        # so the first `seq_len` positions hold the text features
        seq_len = ids.size(1)
        text_feats = vision_out.last_hidden_state[:, :seq_len, :]
        proj_feats = self.projection(text_feats)

        # Feed the projected features to T5 as encoder embeddings
        gen_ids = self.model.text_model.generate(
            inputs_embeds=proj_feats,
            attention_mask=mask,
            max_length=512
        )
        return {"generated_ids": gen_ids}

    def postprocess(self, model_outputs):
        # Decode the generated ids back into reordered text
        return self.tokenizer.batch_decode(
            model_outputs["generated_ids"],
            skip_special_tokens=True
        )
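
# ── Usage sketch ─────────────────────────────────────────────────────────────
# Illustrative only. `CombinedModel` is a hypothetical stand-in: this file only
# requires that the model expose `.vision_model` (LayoutLMv3), `.text_model`
# (T5) and `config.d_model`; the repo's actual wrapper class may differ.
if __name__ == "__main__":
    from transformers import (AutoProcessor, AutoTokenizer, LayoutLMv3Model,
                              T5ForConditionalGeneration)

    class CombinedModel(torch.nn.Module):
        """Hypothetical wrapper pairing the two sub-models this pipeline uses."""
        def __init__(self):
            super().__init__()
            self.vision_model = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base")
            self.text_model = T5ForConditionalGeneration.from_pretrained("t5-small")
            self.config = self.text_model.config  # __init__ reads config.d_model

    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base",
                                              apply_ocr=False)
    pipe = OcrReorderPipeline(CombinedModel(), tokenizer, processor, device=-1)

    # Inputs: a base64-encoded page image plus OCR words with their
    # 0-1000-normalized LayoutLMv3 bounding boxes.
    with open("page.png", "rb") as f:
        b64_image = base64.b64encode(f.read()).decode("utf-8")
    words = ["world", "hello"]
    boxes = [[600, 40, 700, 60], [40, 40, 140, 60]]
    print(pipe(b64_image, words=words, boxes=boxes))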