Uddipan Basu Bir committed on
Commit
a01cae7
·
1 Parent(s): 0d4b0fc

Download checkpoint from HF hub in OcrReorderPipeline

Browse files
Files changed (1) hide show
  1. app.py +31 -77
app.py CHANGED
@@ -16,89 +16,43 @@ from transformers import (
16
  # ── 1) MODEL SETUP ─────────────────────────────────────────────────────
17
  repo = "Uddipan107/ocr-layoutlmv3-base-t5-small"
18
 
19
- # Processor for LayoutLMv3
20
  processor = AutoProcessor.from_pretrained(
21
  repo,
22
  subfolder="preprocessor",
23
  apply_ocr=False
24
  )
25
 
26
- # LayoutLMv3 encoder
27
- layout_model = LayoutLMv3Model.from_pretrained(repo)
28
- layout_model.eval()
29
-
30
- # T5 decoder & tokenizer
31
- t5_model = T5ForConditionalGeneration.from_pretrained(repo)
32
- t5_model.eval()
33
- tokenizer = AutoTokenizer.from_pretrained(
34
  repo, subfolder="preprocessor"
35
  )
36
 
37
- # Ensure decoder_start_token_id is set
38
  if t5_model.config.decoder_start_token_id is None:
39
- # Fallback to bos_token_id if present
40
- t5_model.config.decoder_start_token_id = tokenizer.bos_token_id
41
-
42
- # Projection head: load from checkpoint
43
- ckpt_file = hf_hub_download(repo_id=repo, filename="pytorch_model.bin")
44
- ckpt = torch.load(ckpt_file, map_location="cpu")
45
- proj_state= ckpt["projection"]
46
- projection = torch.nn.Sequential(
47
- torch.nn.Linear(768, t5_model.config.d_model),
48
- torch.nn.LayerNorm(t5_model.config.d_model),
49
- torch.nn.GELU()
50
- )
51
- projection.load_state_dict(proj_state)
52
- projection.eval()
53
-
54
- # Move models to CPU (Spaces are CPU-only)
55
- device = torch.device("cpu")
56
- layout_model.to(device)
57
- t5_model.to(device)
58
- projection.to(device)
59
- repo = "Uddipan107/ocr-layoutlmv3-base-t5-small"
60
-
61
- # Processor for LayoutLMv3
62
- processor = AutoProcessor.from_pretrained(
63
- repo,
64
- subfolder="preprocessor",
65
- apply_ocr=False
66
- )
67
-
68
- # LayoutLMv3 encoder
69
- layout_model = LayoutLMv3Model.from_pretrained(repo)
70
- layout_model.eval()
71
-
72
- # T5 decoder & tokenizer
73
- t5_model = T5ForConditionalGeneration.from_pretrained(repo)
74
- t5_model.eval()
75
- tokenizer = AutoTokenizer.from_pretrained(
76
- repo, subfolder="preprocessor"
77
- )
78
-
79
- # Projection head: load from checkpoint
80
- ckpt_file = hf_hub_download(repo_id=repo, filename="pytorch_model.bin")
81
- ckpt = torch.load(ckpt_file, map_location="cpu")
82
- proj_state= ckpt["projection"]
83
- projection = torch.nn.Sequential(
84
  torch.nn.Linear(768, t5_model.config.d_model),
85
  torch.nn.LayerNorm(t5_model.config.d_model),
86
  torch.nn.GELU()
87
- )
88
  projection.load_state_dict(proj_state)
89
- projection.eval()
90
-
91
- # Move models to CPU (Spaces are CPU-only)
92
- device = torch.device("cpu")
93
- layout_model.to(device)
94
- t5_model.to(device)
95
- projection.to(device)
96
 
97
  # ── 2) INFERENCE FUNCTION ─────────────────────────────────────────────
98
  def infer(image_path, json_file):
99
  img_name = os.path.basename(image_path)
100
 
101
- # 2.a) Load NDJSON file (one JSON object per line)
102
  data = []
103
  with open(json_file.name, "r", encoding="utf-8") as f:
104
  for line in f:
@@ -106,7 +60,6 @@ def infer(image_path, json_file):
106
  continue
107
  data.append(json.loads(line))
108
 
109
- # 2.b) Find entry matching uploaded image
110
  entry = next((e for e in data if e.get("img_name") == img_name), None)
111
  if entry is None:
112
  return f"❌ No JSON entry found for image '{img_name}'"
@@ -114,21 +67,21 @@ def infer(image_path, json_file):
114
  words = entry.get("src_word_list", [])
115
  boxes = entry.get("src_wordbox_list", [])
116
 
117
- # 2.c) Open and preprocess the image + tokens + boxes
118
  img = Image.open(image_path).convert("RGB")
119
  encoding = processor(
120
  [img], [words], boxes=[boxes],
121
  return_tensors="pt", padding=True, truncation=True
122
  )
123
- pixel_values = encoding.pixel_values.to(device)
124
- input_ids = encoding.input_ids.to(device)
125
- attention_mask = encoding.attention_mask.to(device)
126
- bbox = encoding.bbox.to(device)
127
 
128
- # 2.d) Forward pass
129
  with torch.no_grad():
130
  # LayoutLMv3 encoding
131
- lm_out = layout_model(
132
  pixel_values=pixel_values,
133
  input_ids=input_ids,
134
  attention_mask=attention_mask,
@@ -137,22 +90,23 @@ def infer(image_path, json_file):
137
  seq_len = input_ids.size(1)
138
  text_feats = lm_out.last_hidden_state[:, :seq_len, :]
139
 
140
- # Projection → T5 decoding
141
  proj_feats = projection(text_feats)
142
- gen_ids = t5_model.generate(
143
  inputs_embeds=proj_feats,
144
  attention_mask=attention_mask,
145
  max_length=512,
146
- decoder_start_token_id=t5_model.config.decoder_start_token_id
 
147
  )
148
 
149
- # Decode to text
150
  result = tokenizer.batch_decode(
151
  gen_ids, skip_special_tokens=True
152
  )[0]
153
  return result
154
 
155
- # ── 3) GRADIO UI ───────────────────────────────────────────────────────
156
  demo = gr.Interface(
157
  fn=infer,
158
  inputs=[
 
16
  # ── 1) MODEL SETUP ─────────────────────────────────────────────────────
17
  repo = "Uddipan107/ocr-layoutlmv3-base-t5-small"
18
 
19
+ # Processor
20
  processor = AutoProcessor.from_pretrained(
21
  repo,
22
  subfolder="preprocessor",
23
  apply_ocr=False
24
  )
25
 
26
+ # Encoder & Decoder
27
+ layout_model = LayoutLMv3Model.from_pretrained(repo).to("cpu").eval()
28
+ t5_model = T5ForConditionalGeneration.from_pretrained(repo).to("cpu").eval()
29
+ tokenizer = AutoTokenizer.from_pretrained(
 
 
 
 
30
  repo, subfolder="preprocessor"
31
  )
32
 
33
+ # Ensure decoder_start_token_id and bos_token_id are set
34
  if t5_model.config.decoder_start_token_id is None:
35
+ fallback = tokenizer.bos_token_id or tokenizer.eos_token_id
36
+ t5_model.config.decoder_start_token_id = fallback
37
+ if t5_model.config.bos_token_id is None:
38
+ t5_model.config.bos_token_id = t5_model.config.decoder_start_token_id
39
+
40
+ # Projection head
41
+ ckpt_file = hf_hub_download(repo_id=repo, filename="pytorch_model.bin")
42
+ ckpt = torch.load(ckpt_file, map_location="cpu")
43
+ proj_state = ckpt["projection"]
44
+ projection = torch.nn.Sequential(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  torch.nn.Linear(768, t5_model.config.d_model),
46
  torch.nn.LayerNorm(t5_model.config.d_model),
47
  torch.nn.GELU()
48
+ ).to("cpu")
49
  projection.load_state_dict(proj_state)
 
 
 
 
 
 
 
50
 
51
  # ── 2) INFERENCE FUNCTION ─────────────────────────────────────────────
52
  def infer(image_path, json_file):
53
  img_name = os.path.basename(image_path)
54
 
55
+ # Load NDJSON
56
  data = []
57
  with open(json_file.name, "r", encoding="utf-8") as f:
58
  for line in f:
 
60
  continue
61
  data.append(json.loads(line))
62
 
 
63
  entry = next((e for e in data if e.get("img_name") == img_name), None)
64
  if entry is None:
65
  return f"❌ No JSON entry found for image '{img_name}'"
 
67
  words = entry.get("src_word_list", [])
68
  boxes = entry.get("src_wordbox_list", [])
69
 
70
+ # Preprocess image + tokens
71
  img = Image.open(image_path).convert("RGB")
72
  encoding = processor(
73
  [img], [words], boxes=[boxes],
74
  return_tensors="pt", padding=True, truncation=True
75
  )
76
+ pixel_values = encoding.pixel_values.to("cpu")
77
+ input_ids = encoding.input_ids.to("cpu")
78
+ attention_mask = encoding.attention_mask.to("cpu")
79
+ bbox = encoding.bbox.to("cpu")
80
 
81
+ # Forward pass
82
  with torch.no_grad():
83
  # LayoutLMv3 encoding
84
+ lm_out = layout_model(
85
  pixel_values=pixel_values,
86
  input_ids=input_ids,
87
  attention_mask=attention_mask,
 
90
  seq_len = input_ids.size(1)
91
  text_feats = lm_out.last_hidden_state[:, :seq_len, :]
92
 
93
+ # Projection + T5 decoding
94
  proj_feats = projection(text_feats)
95
+ gen_ids = t5_model.generate(
96
  inputs_embeds=proj_feats,
97
  attention_mask=attention_mask,
98
  max_length=512,
99
+ decoder_start_token_id=t5_model.config.decoder_start_token_id,
100
+ bos_token_id=t5_model.config.bos_token_id
101
  )
102
 
103
+ # Decode and return
104
  result = tokenizer.batch_decode(
105
  gen_ids, skip_special_tokens=True
106
  )[0]
107
  return result
108
 
109
+ # ── 3) GRADIO INTERFACE ────────────────────────────────────────────────
110
  demo = gr.Interface(
111
  fn=infer,
112
  inputs=[