import torch
from transformers import Pipeline
from PIL import Image
import base64
from io import BytesIO
from huggingface_hub import hf_hub_download

# HF model repo containing pytorch_model.bin with 'projection' state
HF_MODEL_REPO = "Uddipan107/ocr-layoutlmv3-base-t5-small"

class OcrReorderPipeline(Pipeline):
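    """
    Reorders OCR text with a LayoutLMv3 encoder bridged to a T5-small decoder.

    The pipeline takes a base64-encoded page image together with its OCR `words`
    and bounding `boxes`, encodes them with LayoutLMv3, projects the text-token
    features (768-dim) into T5-small's 512-dim embedding space via the fine-tuned
    projection head downloaded from HF_MODEL_REPO, and generates the reordered text.
    """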
    def __init__(self, model, tokenizer, processor, device=0):
        super().__init__(model=model, tokenizer=tokenizer,
                         feature_extractor=processor, device=device)

        # ── Download your fine-tuned checkpoint ───────────────────────────
        ckpt_path = hf_hub_download(repo_id=HF_MODEL_REPO, filename="pytorch_model.bin")
        ckpt      = torch.load(ckpt_path, map_location="cpu")
        proj_state = ckpt["projection"]

        # ── Rebuild & load your projection head (T5-small hidden size = 512) ─
        d_model = 512
        self.projection = torch.nn.Sequential(
            torch.nn.Linear(768, d_model),
            torch.nn.LayerNorm(d_model),
            torch.nn.GELU()
        )
        self.projection.load_state_dict(proj_state)
        self.projection.to(self.device)

    def _sanitize_parameters(self, **kwargs):
        # Route the custom 'words'/'boxes' kwargs to preprocess; the image itself
        # is passed positionally as the pipeline's 'inputs' argument
        preprocess_kwargs = {}
        for key in ("words", "boxes"):
            if key in kwargs:
                preprocess_kwargs[key] = kwargs[key]
        return preprocess_kwargs, {}, {}

    def preprocess(self, image, words, boxes):
        # 'image' is the pipeline's positional 'inputs' argument: a base64-encoded image string
        data = base64.b64decode(image)
        img  = Image.open(BytesIO(data)).convert("RGB")
        return self.feature_extractor(
            [img], [words], boxes=[boxes],
            return_tensors="pt", padding=True, truncation=True
        )

    def _forward(self, model_inputs):
        pv, ids, mask, bbox = (
            model_inputs[k].to(self.device)
            for k in ("pixel_values", "input_ids", "attention_mask", "bbox")
        )

        # Encode the page with LayoutLMv3 (text tokens + bounding boxes + image patches)
        vision_out = self.model.vision_model(
            pixel_values=pv,
            input_ids=ids,
            attention_mask=mask,
            bbox=bbox
        )

        # LayoutLMv3 puts the text tokens first in its output sequence, so the first
        # seq_len positions are the token features; project them from the 768-dim
        # encoder space into T5-small's 512-dim embedding space
        seq_len    = ids.size(1)
        text_feats = vision_out.last_hidden_state[:, :seq_len, :]
        proj_feats = self.projection(text_feats)

        # Decode with T5, feeding the projected features as encoder embeddings
        gen_ids = self.model.text_model.generate(
            inputs_embeds=proj_feats,
            attention_mask=mask,
            max_length=512
        )
        return {"generated_ids": gen_ids}

    def postprocess(self, model_outputs):
        return self.tokenizer.batch_decode(
            model_outputs["generated_ids"],
            skip_special_tokens=True
        )
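
# ── Usage sketch (illustrative only) ──────────────────────────────────────────
# This file does not define how the composite model is built or loaded, so the
# wrapper below is a hypothetical stand-in that exposes the `.vision_model` /
# `.text_model` attributes `_forward` expects; the base checkpoints named here
# ("microsoft/layoutlmv3-base", "t5-small") are assumptions, not confirmed by
# this repo.
#
#   import torch.nn as nn
#   from transformers import (AutoProcessor, AutoTokenizer,
#                             LayoutLMv3Model, T5ForConditionalGeneration)
#
#   class OcrReorderModel(nn.Module):                      # hypothetical wrapper
#       def __init__(self):
#           super().__init__()
#           self.vision_model = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base")
#           self.text_model   = T5ForConditionalGeneration.from_pretrained("t5-small")
#
#   processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
#   tokenizer = AutoTokenizer.from_pretrained("t5-small")
#   pipe = OcrReorderPipeline(model=OcrReorderModel(), tokenizer=tokenizer,
#                             processor=processor, device=-1)
#
#   # `words` / `boxes` are the page's OCR tokens and their bounding boxes
#   with open("page.png", "rb") as f:
#       img_b64 = base64.b64encode(f.read()).decode("utf-8")
#   ordered_text = pipe(img_b64, words=words, boxes=boxes)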