Uddipan Basu Bir committed
Commit 5b9baff
1 Parent(s): 124a92f

Add custom OCR reorder pipeline + Gradio UI

Files changed (3):
  1. app.py +41 -0
  2. inference.py +56 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,41 @@
+ import json, base64
+ from io import BytesIO
+ from PIL import Image
+ import gradio as gr
+ import torch
+ from inference import OcrReorderPipeline
+ from transformers import (
+     AutoProcessor,
+     LayoutLMv3Model,
+     T5ForConditionalGeneration,
+     AutoTokenizer
+ )
+
+ # Load encoder, tokenizer, and processor from the model repo
+ repo = "Uddipan107/ocr-layoutlmv3-base-t5-small"
+ model = LayoutLMv3Model.from_pretrained(repo)
+ tokenizer = AutoTokenizer.from_pretrained(repo)
+ processor = AutoProcessor.from_pretrained(repo, apply_ocr=False)
+ # Assumption: the decoder starts from the stock t5-small checkpoint;
+ # adjust if the fine-tuned T5 weights live elsewhere in the repo
+ t5 = T5ForConditionalGeneration.from_pretrained("t5-small")
+ # Fall back to CPU when no GPU is available
+ device = 0 if torch.cuda.is_available() else -1
+ pipe = OcrReorderPipeline(model, t5, tokenizer, processor, device=device)
+
+ def infer(image, words_json, boxes_json):
+     words = json.loads(words_json)
+     boxes = json.loads(boxes_json)
+     buf = BytesIO(); image.save(buf, "PNG")
+     b64 = base64.b64encode(buf.getvalue()).decode()
+     # words/boxes must be keyword args so the pipeline routes them
+     # to preprocess(); the call returns a list of strings, take first
+     return pipe(b64, words=words, boxes=boxes)[0]
+
+ demo = gr.Interface(
+     fn=infer,
+     inputs=[
+         gr.Image(type="pil", label="Image"),
+         gr.Textbox(label="Words (JSON list)"),
+         gr.Textbox(label="Boxes (JSON list)")
+     ],
+     outputs="text",
+     title="OCR Reorder Pipeline"
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
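For reference, the two textboxes expect parallel JSON arrays: one OCR word per entry and one [x0, y0, x1, y1] box per word, normalized to LayoutLMv3's 0-1000 coordinate scale (required because the processor is loaded with apply_ocr=False). A hypothetical input pair:

# Hypothetical values for the "Words" and "Boxes" textboxes;
# boxes are [x0, y0, x1, y1] on the 0-1000 normalized scale
words_json = '["Invoice", "No.", "42"]'
boxes_json = '[[70, 52, 189, 75], [195, 52, 243, 75], [250, 52, 290, 75]]'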
inference.py ADDED
@@ -0,0 +1,56 @@
+ import torch
+ from transformers import Pipeline
+ from PIL import Image
+ import base64
+ from io import BytesIO
+
+ class OcrReorderPipeline(Pipeline):
+     def __init__(self, model, t5, tokenizer, processor, device=0):
+         super().__init__(model=model, tokenizer=tokenizer,
+                          feature_extractor=processor, device=device)
+         # T5 decoder that generates the reordered text
+         self.t5 = t5.to(self.device)
+         # Projection head weights live under the "projection" key of the
+         # checkpoint, which is expected in the working directory
+         proj_state = torch.load("pytorch_model.bin", map_location="cpu")["projection"]
+         self.projection = torch.nn.Sequential(
+             torch.nn.Linear(768, t5.config.d_model),
+             torch.nn.LayerNorm(t5.config.d_model),
+             torch.nn.GELU()
+         )
+         self.projection.load_state_dict(proj_state)
+         self.projection.to(self.device)
+
+     def _sanitize_parameters(self, words=None, boxes=None, **kwargs):
+         # Route the words/boxes keyword arguments through to preprocess()
+         preprocess_kwargs = {}
+         if words is not None:
+             preprocess_kwargs["words"] = words
+         if boxes is not None:
+             preprocess_kwargs["boxes"] = boxes
+         return preprocess_kwargs, {}, {}
+
+     def preprocess(self, image, words=None, boxes=None):
+         # The image arrives as a base64-encoded PNG string
+         data = base64.b64decode(image)
+         img = Image.open(BytesIO(data)).convert("RGB")
+         return self.feature_extractor(
+             [img], [words], boxes=[boxes],
+             return_tensors="pt", padding=True, truncation=True
+         )
+
+     def _forward(self, model_inputs):
+         pv, ids, mask, bbox = (
+             model_inputs[k].to(self.device)
+             for k in ("pixel_values", "input_ids", "attention_mask", "bbox")
+         )
+         # Encode image, words, and layout with the LayoutLMv3 backbone
+         encoder_out = self.model(
+             pixel_values=pv,
+             input_ids=ids,
+             attention_mask=mask,
+             bbox=bbox
+         )
+         # LayoutLMv3 appends image patch tokens after the text tokens;
+         # keep only the text positions
+         seq_len = ids.size(1)
+         text_feats = encoder_out.last_hidden_state[:, :seq_len, :]
+         # Project 768-dim LayoutLMv3 features into the T5 embedding space
+         proj_feats = self.projection(text_feats)
+         gen_ids = self.t5.generate(
+             inputs_embeds=proj_feats,
+             attention_mask=mask,
+             max_length=512
+         )
+         return {"generated_ids": gen_ids}
+
+     def postprocess(self, model_outputs):
+         return self.tokenizer.batch_decode(
+             model_outputs["generated_ids"],
+             skip_special_tokens=True
+         )
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch
+ transformers
+ Pillow
+ gradio