MarkusDressel committed on
Commit
b0247d5
·
1 Parent(s): 8f8ef3f

Upload app.py

Browse files
Files changed (1)
  1. app.py +87 -0
app.py ADDED
@@ -0,0 +1,87 @@
import os
# The commented-out installs below pin the environment this demo was built for:
# os.system('pip install gradio --upgrade')
# os.system('pip install git+https://github.com/huggingface/transformers.git --upgrade')
# os.system('pip install pyyaml==5.1')
# # workaround: install an old version of pytorch since detectron2 hasn't released packages for pytorch 1.9 (issue: https://github.com/facebookresearch/detectron2/issues/3158)
# os.system('pip install torch==1.8.0+cu101 torchvision==0.9.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html')
# # install detectron2 that matches pytorch 1.8
# # See https://detectron2.readthedocs.io/tutorials/install.html for instructions
# os.system('pip install -q detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu101/torch1.8/index.html')
# ## install PyTesseract
# os.system('pip install -q pytesseract')
import gradio as gr
import numpy as np
import torch
from transformers import LayoutLMv2Processor, LayoutLMv2ForTokenClassification
from PIL import Image, ImageDraw, ImageFont
import PIL

# Load the processor (OCR + tokenization) and the CORD fine-tuned model.
processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased")
model = LayoutLMv2ForTokenClassification.from_pretrained("MarkusDressel/cord")
id2label = model.config.id2label

# Pick a random PIL color name for every label class. Sizing by len(id2label)
# (rather than a hard-coded 30) avoids an IndexError if the label set differs.
label_color_pil = [k for k, _ in PIL.ImageColor.colormap.items()]
label_ints = np.random.randint(0, len(label_color_pil), len(id2label))
label_color = [label_color_pil[i] for i in label_ints]

# Key the color map with the same normalized names that iob_to_label() below
# produces (IOB prefix stripped, lowercased, '' -> 'other'), so the lookups in
# process_image() never miss.
label2color = {}
for k, v in id2label.items():
    name = v[2:].lower() if v[2:] else 'other'
    label2color[name] = label_color[k]

def unnormalize_box(bbox, width, height):
    # LayoutLMv2 boxes are normalized to a 0-1000 grid; scale back to pixels.
    return [
        width * (bbox[0] / 1000),
        height * (bbox[1] / 1000),
        width * (bbox[2] / 1000),
        height * (bbox[3] / 1000),
    ]
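# Worked example: for an 800x1000 image, a normalized box [100, 200, 300, 400]
# maps back to pixel coordinates [80.0, 200.0, 240.0, 400.0].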
def iob_to_label(label):
    # Drop the IOB prefix ("B-"/"I-"); a plain "O" tag becomes 'other'.
    label = label[2:]
    if not label:
        return 'other'
    return label

def process_image(image):
    # LayoutLMv2's feature extractor and ImageDraw both expect an RGB image.
    image = image.convert("RGB")
    width, height = image.size
    # encode: the processor runs Tesseract OCR and tokenizes the detected words
    encoding = processor(image, truncation=True, return_offsets_mapping=True, return_tensors="pt")
    offset_mapping = encoding.pop('offset_mapping')
    # forward pass (no gradients needed at inference time)
    with torch.no_grad():
        outputs = model(**encoding)
    # get predictions
    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    token_boxes = encoding.bbox.squeeze().tolist()
    # only keep non-subword predictions: a token is a subword continuation
    # when its character offset does not start at 0
    is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0
    true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
    true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
    # draw predictions over the image
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()
    for prediction, box in zip(true_predictions, true_boxes):
        predicted_label = iob_to_label(prediction).lower()
        draw.rectangle(box, outline=label2color[predicted_label], width=5)
        draw.text((box[0] + 10, box[1] - 10), text=predicted_label, fill=label2color[predicted_label], font=font)

    return image

title = "CORD demo: LayoutLMv2"
description = "Demo for Microsoft's LayoutLMv2. This particular model is fine-tuned on CORD, a dataset of manually annotated receipts. It annotates the words appearing in the image in up to 30 classes. To use it, simply upload an image or use the example image below and click 'Submit'. Results will show up in a few seconds. If you want to make the output bigger, right-click on it and select 'Open image in new tab'."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2012.14740' target='_blank'>LayoutLMv2: Multi-modal Pre-training for Visually-Rich Document Understanding</a> | <a href='https://github.com/microsoft/unilm' target='_blank'>Github Repo</a></p>"
examples = [['receipt_00189.png']]
css = ".output_image, .input_image {height: 40rem !important; width: 100% !important;}"
# css = "@media screen and (max-width: 600px) { .output_image, .input_image {height:20rem !important; width: 100% !important;} }"
# css = ".output_image, .input_image {height: 600px !important}"
iface = gr.Interface(fn=process_image,
                     inputs=gr.inputs.Image(type="pil"),
                     outputs=gr.outputs.Image(type="pil", label="annotated image"),
                     title=title,
                     description=description,
                     article=article,
                     examples=examples,
                     css=css,
                     enable_queue=True)
iface.launch(debug=True)
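
For quick local verification without the web UI, a minimal sketch of calling process_image directly, e.g. in a notebook or REPL after running the definitions above (skipping iface.launch); the filenames are illustrative, with receipt_00189.png being the example image referenced in the examples list:

# Hypothetical local smoke test, not part of the commit. Assumes the example
# image receipt_00189.png is on disk next to this file.
img = Image.open("receipt_00189.png")
annotated = process_image(img)
annotated.save("annotated_receipt.png")  # open to inspect the drawn boxes and labels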