Uddipan Basu Bir committed on
Commit
93dce4d
·
1 Parent(s): 0cfc73f

Download checkpoint from HF hub in OcrReorderPipeline

Browse files
Files changed (1) hide show
  1. inference.py +6 -8
inference.py CHANGED
@@ -5,7 +5,7 @@ import base64
5
  from io import BytesIO
6
  from huggingface_hub import hf_hub_download
7
 
8
- # point at your HF model repo
9
  HF_MODEL_REPO = "Uddipan107/ocr-layoutlmv3-base-t5-small"
10
 
11
  class OcrReorderPipeline(Pipeline):
@@ -18,10 +18,11 @@ class OcrReorderPipeline(Pipeline):
18
  ckpt = torch.load(ckpt_path, map_location="cpu")
19
  proj_state= ckpt["projection"]
20
 
21
- # ── Rebuild & load your projection head ────────────────────────────
 
22
  self.projection = torch.nn.Sequential(
23
- torch.nn.Linear(768, model.config.d_model),
24
- torch.nn.LayerNorm(model.config.d_model),
25
  torch.nn.GELU()
26
  )
27
  self.projection.load_state_dict(proj_state)
@@ -41,20 +42,17 @@ class OcrReorderPipeline(Pipeline):
41
  def _forward(self, model_inputs):
42
  pv, ids, mask, bbox = (
43
  model_inputs[k].to(self.device)
44
- for k in ("pixel_values", "input_ids", "attention_mask", "bbox")
45
  )
46
-
47
  vision_out = self.model.vision_model(
48
  pixel_values=pv,
49
  input_ids=ids,
50
  attention_mask=mask,
51
  bbox=bbox
52
  )
53
-
54
  seq_len = ids.size(1)
55
  text_feats = vision_out.last_hidden_state[:, :seq_len, :]
56
  proj_feats = self.projection(text_feats)
57
-
58
  gen_ids = self.model.text_model.generate(
59
  inputs_embeds=proj_feats,
60
  attention_mask=mask,
 
5
  from io import BytesIO
6
  from huggingface_hub import hf_hub_download
7
 
8
+ # HF model repo containing pytorch_model.bin with 'projection' state
9
  HF_MODEL_REPO = "Uddipan107/ocr-layoutlmv3-base-t5-small"
10
 
11
  class OcrReorderPipeline(Pipeline):
 
18
  ckpt = torch.load(ckpt_path, map_location="cpu")
19
  proj_state= ckpt["projection"]
20
 
21
+ # ── Rebuild & load your projection head (T5-small hidden size = 512) ─
22
+ d_model = 512
23
  self.projection = torch.nn.Sequential(
24
+ torch.nn.Linear(768, d_model),
25
+ torch.nn.LayerNorm(d_model),
26
  torch.nn.GELU()
27
  )
28
  self.projection.load_state_dict(proj_state)
 
42
  def _forward(self, model_inputs):
43
  pv, ids, mask, bbox = (
44
  model_inputs[k].to(self.device)
45
+ for k in ("pixel_values","input_ids","attention_mask","bbox")
46
  )
 
47
  vision_out = self.model.vision_model(
48
  pixel_values=pv,
49
  input_ids=ids,
50
  attention_mask=mask,
51
  bbox=bbox
52
  )
 
53
  seq_len = ids.size(1)
54
  text_feats = vision_out.last_hidden_state[:, :seq_len, :]
55
  proj_feats = self.projection(text_feats)
 
56
  gen_ids = self.model.text_model.generate(
57
  inputs_embeds=proj_feats,
58
  attention_mask=mask,