itsanurag committed
Commit f5a475a · verified · 1 Parent(s): 8c2b304

Create app.py

Files changed (1)
  app.py  +387  -0
app.py ADDED
@@ -0,0 +1,387 @@
+ import os
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+ from PIL import Image, ImageDraw
+ import traceback
+
+ import gradio as gr
+
+ import torch
+ from docquery import pipeline
+ from docquery.document import load_document, ImageDocument
+ from docquery.ocr_reader import get_ocr_reader
+
+
+
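+ # Normalize pipeline output to a list: depending on top_k, the question-answering
+ # pipeline may return either a single prediction dict or a list of them.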
+ def ensure_list(x):
+     if isinstance(x, list):
+         return x
+     else:
+         return [x]
+
+
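+ # Model registry: display name -> Hugging Face Hub checkpoint. PIPELINES caches the
+ # constructed pipeline per model so switching back to a model doesn't reload it.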
+ CHECKPOINTS = {
+     "LayoutLMv1": "impira/layoutlm-document-qa",
+     "LayoutLMv1 for Invoices": "impira/layoutlm-invoices",
+     "Donut": "naver-clova-ix/donut-base-finetuned-docvqa",
+ }
+
+
+
+ PIPELINES = {}
+
+
+
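+ # Lazily build (and cache) a docquery pipeline for the selected checkpoint,
+ # placing it on GPU when CUDA is available and on CPU otherwise.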
+ def construct_pipeline(task, model):
+     global PIPELINES
+     if model in PIPELINES:
+         return PIPELINES[model]
+
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     ret = pipeline(task=task, model=CHECKPOINTS[model], device=device)
+     PIPELINES[model] = ret
+     return ret
+
+
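+ # Run the cached pipeline for one question. Note the local name `pipeline` shadows
+ # the `pipeline` function imported from docquery inside this function body.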
+ def run_pipeline(model, question, document, top_k):
+     pipeline = construct_pipeline("document-question-answering", model)
+     return pipeline(question=question, **document.context, top_k=top_k)
+
+
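+ # Bounding-box helpers: lift_word_boxes pulls the OCR word boxes for one page,
+ # expand_bbox merges a set of word boxes into their enclosing rectangle, and
+ # normalize_bbox converts LayoutLM's 0-1000 coordinates into pixel coordinates.
+ # For example, ignoring the small padding, a box [100, 200, 300, 400] on an
+ # 800x1000 image maps to [80, 200, 240, 400].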
+ # TODO: Move into docquery
+ # TODO: Support words past the first page (or window?)
+ def lift_word_boxes(document, page):
+     return document.context["image"][page][1]
+
+
+ def expand_bbox(word_boxes):
+     if len(word_boxes) == 0:
+         return None
+
+     min_x, min_y, max_x, max_y = zip(*[x[1] for x in word_boxes])
+     min_x, min_y, max_x, max_y = [min(min_x), min(min_y), max(max_x), max(max_y)]
+     return [min_x, min_y, max_x, max_y]
+
+
+ # LayoutLM boxes are normalized to 0, 1000
+ def normalize_bbox(box, width, height, padding=0.005):
+     min_x, min_y, max_x, max_y = [c / 1000 for c in box]
+     if padding != 0:
+         min_x = max(0, min_x - padding)
+         min_y = max(0, min_y - padding)
+         max_x = min(max_x + padding, 1)
+         max_y = min(max_y + padding, 1)
+     return [min_x * width, min_y * height, max_x * width, max_y * height]
+
+
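+ # Load a document from a local path or URL and return one value per output component:
+ # (document, image gallery, clear button, JSON output, answer text, URL error box).
+ # On failure the error message is surfaced in the URL error textbox instead.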
+ def process_path(path):
+     error = None
+     if path:
+         try:
+             document = load_document(path)
+             return (
+                 document,
+                 gr.update(visible=True, value=document.preview),
+                 gr.update(visible=True),
+                 gr.update(visible=False, value=None),
+                 gr.update(visible=False, value=None),
+                 None,
+             )
+         except Exception as e:
+             traceback.print_exc()
+             error = str(e)
+     # Both branches must return exactly one value per output component (six in total).
+     return (
+         None,
+         gr.update(visible=False, value=None),
+         gr.update(visible=False),
+         gr.update(visible=False, value=None),
+         gr.update(visible=False, value=None),
+         gr.update(visible=True, value=error) if error is not None else None,
+     )
+
+
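+ # Thin wrapper around process_path for the file-upload widget; it returns the same
+ # shape of tuple so both inputs can share the same output components.
+ # (The `colors` list below appears to be unused in this file.)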
+ def process_upload(file):
+     if file:
+         return process_path(file.name)
+     else:
+         return (
+             None,
+             gr.update(visible=False, value=None),
+             gr.update(visible=False),
+             gr.update(visible=False, value=None),
+             gr.update(visible=False, value=None),
+             None,
+         )
+
+
+ colors = ["#64A087", "black", "black"]
+
+
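+ # Answer a question against the loaded document: run the selected model with top_k=3,
+ # take the top prediction as the answer text, and draw a translucent green rectangle
+ # over the answer's word boxes on the matching page before returning the gallery,
+ # raw predictions, and answer-text updates.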
+ def process_question(question, document, model=list(CHECKPOINTS.keys())[0]):
+     if not question or document is None:
+         return None, None, None
+
+     text_value = None
+     predictions = run_pipeline(model, question, document, 3)
+     pages = [x.copy().convert("RGB") for x in document.preview]
+     for i, p in enumerate(ensure_list(predictions)):
+         if i == 0:
+             text_value = p["answer"]
+         else:
+             # Keep the code around to produce multiple boxes, but only show the top
+             # prediction for now
+             break
+
+         if "word_ids" in p:
+             image = pages[p["page"]]
+             draw = ImageDraw.Draw(image, "RGBA")
+             word_boxes = lift_word_boxes(document, p["page"])
+             x1, y1, x2, y2 = normalize_bbox(
+                 expand_bbox([word_boxes[i] for i in p["word_ids"]]),
+                 image.width,
+                 image.height,
+             )
+             draw.rectangle(((x1, y1), (x2, y2)), fill=(0, 255, 0, int(0.4 * 255)))
+
+     return (
+         gr.update(visible=True, value=pages),
+         gr.update(visible=True, value=predictions),
+         gr.update(
+             visible=True,
+             value=text_value,
+         ),
+     )
+
+
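+ # Build a document from a clicked example image and immediately answer its question.
+ # NOTE: `question_files` (used here) and `examples` (used by gr.Examples below) are
+ # never defined in this file, so the app would raise a NameError as committed; they
+ # presumably should map example questions to document paths and list [image, question]
+ # pairs, e.g. (hypothetical placeholders):
+ #   question_files = {}
+ #   examples = [["invoice_example.png", "What is the invoice number?"]]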
+ def load_example_document(img, question, model):
+     if img is not None:
+         if question in question_files:
+             document = load_document(question_files[question])
+         else:
+             document = ImageDocument(Image.fromarray(img), get_ocr_reader())
+         preview, answer, answer_text = process_question(question, document, model)
+         return document, question, preview, gr.update(visible=True), answer, answer_text
+     else:
+         return None, None, None, gr.update(visible=False), None, None
+
+
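+ # Custom CSS for the Gradio page: upload/URL box sizing, yellow primary button,
+ # and the dark "Top Answer" box styling.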
+ CSS = """
+ #question input {
+     font-size: 16px;
+ }
+ #url-textbox {
+     padding: 0 !important;
+ }
+ #short-upload-box .w-full {
+     min-height: 10rem !important;
+ }
+ /* I think something like this can be used to re-shape
+  * the table
+  */
+ /*
+ .gr-samples-table tr {
+     display: inline;
+ }
+ .gr-samples-table .p-2 {
+     width: 100px;
+ }
+ */
+ #select-a-file {
+     width: 100%;
+ }
+ #file-clear {
+     padding-top: 2px !important;
+     padding-bottom: 2px !important;
+     padding-left: 8px !important;
+     padding-right: 8px !important;
+     margin-top: 10px;
+ }
+ .gradio-container .gr-button-primary {
+     background: linear-gradient(180deg, #FAED27 0%, #FAED27 100%);
+     border: 1px solid #000000;
+     border-radius: 8px;
+     color: #000000;
+ }
+ .gradio-container.dark button#submit-button {
+     background: linear-gradient(180deg, #FAED27 0%, #FAED27 100%);
+     border: 1px solid #000000;
+     border-radius: 8px;
+     color: #000000
+ }
+ table.gr-samples-table tr td {
+     border: none;
+     outline: none;
+ }
+ table.gr-samples-table tr td:first-of-type {
+     width: 0%;
+ }
+ div#short-upload-box div.absolute {
+     display: none !important;
+ }
+ gradio-app > div > div > div > div.w-full > div, .gradio-app > div > div > div > div.w-full > div {
+     gap: 0px 2%;
+ }
+ gradio-app div div div div.w-full, .gradio-app div div div div.w-full {
+     gap: 0px;
+ }
+ gradio-app h2, .gradio-app h2 {
+     padding-top: 10px;
+ }
+ #answer {
+     overflow-y: scroll;
+     color: white;
+     background: #666;
+     border-color: #666;
+     font-size: 20px;
+     font-weight: bold;
+ }
+ #answer span {
+     color: white;
+ }
+ #answer textarea {
+     color: white;
+     background: #777;
+     border-color: #777;
+     font-size: 18px;
+ }
+ #url-error input {
+     color: red;
+ }
+ """
+
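+ # UI definition. The two gr.Markdown() calls directly below are empty in this commit
+ # (presumably meant to hold a title/description). gr.Variable is the older Gradio
+ # name for what newer releases call gr.State; it just holds the loaded document
+ # between events.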
+ with gr.Blocks(css=CSS) as demo:
+     gr.Markdown()
+     gr.Markdown(
+
+     )
+
+     document = gr.Variable()
+     example_question = gr.Textbox(visible=False)
+     example_image = gr.Image(visible=False)
+
+     with gr.Row(equal_height=True):
+         with gr.Column():
+             with gr.Row():
+                 gr.Markdown("## 1. Select a file", elem_id="select-a-file")
+                 img_clear_button = gr.Button(
+                     "Clear", variant="secondary", elem_id="file-clear", visible=False
+                 )
+             image = gr.Gallery(visible=False)
+             with gr.Row(equal_height=True):
+                 with gr.Column():
+                     with gr.Row():
+                         url = gr.Textbox(
+                             show_label=False,
+                             placeholder="URL",
+                             lines=1,
+                             max_lines=1,
+                             elem_id="url-textbox",
+                         )
+                         submit = gr.Button("Get")
+                     url_error = gr.Textbox(
+                         visible=False,
+                         elem_id="url-error",
+                         max_lines=1,
+                         interactive=False,
+                         label="Error",
+                     )
+             gr.Markdown("— or —")
+             upload = gr.File(label=None, interactive=True, elem_id="short-upload-box")
+             gr.Examples(
+                 examples=examples,
+                 inputs=[example_image, example_question],
+             )
+
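+         # Right-hand column: question box, model picker, and the answer/JSON outputs
+         # (both outputs stay hidden until there is something to show).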
+         with gr.Column() as col:
+             gr.Markdown("## 2. Ask a question")
+             question = gr.Textbox(
+                 label="Question",
+                 placeholder="e.g. What is the invoice number?",
+                 lines=1,
+                 max_lines=1,
+             )
+             model = gr.Radio(
+                 choices=list(CHECKPOINTS.keys()),
+                 value=list(CHECKPOINTS.keys())[0],
+                 label="Model",
+             )
+
+             with gr.Row():
+                 clear_button = gr.Button("Clear", variant="secondary")
+                 submit_button = gr.Button(
+                     "Submit", variant="primary", elem_id="submit-button"
+                 )
+             with gr.Column():
+                 output_text = gr.Textbox(
+                     label="Top Answer", visible=False, elem_id="answer"
+                 )
+                 output = gr.JSON(label="Output", visible=False)
+
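+     # Wiring: both Clear buttons reset every component back to its initial hidden/empty
+     # state; the ten values returned by the lambda line up one-to-one with the ten
+     # components in `outputs`.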
+     for cb in [img_clear_button, clear_button]:
+         cb.click(
+             lambda _: (
+                 gr.update(visible=False, value=None),
+                 None,
+                 gr.update(visible=False, value=None),
+                 gr.update(visible=False, value=None),
+                 gr.update(visible=False),
+                 None,
+                 None,
+                 None,
+                 gr.update(visible=False, value=None),
+                 None,
+             ),
+             inputs=clear_button,
+             outputs=[
+                 image,
+                 document,
+                 output,
+                 output_text,
+                 img_clear_button,
+                 example_image,
+                 upload,
+                 url,
+                 url_error,
+                 question,
+             ],
+         )
+
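+     # Remaining events: URL fetch and file upload load a document; submitting a
+     # question, clicking Submit, or switching the model re-runs process_question;
+     # selecting an example loads it and answers its paired question.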
+     upload.change(
+         fn=process_upload,
+         inputs=[upload],
+         outputs=[document, image, img_clear_button, output, output_text, url_error],
+     )
+     submit.click(
+         fn=process_path,
+         inputs=[url],
+         outputs=[document, image, img_clear_button, output, output_text, url_error],
+     )
+
+     question.submit(
+         fn=process_question,
+         inputs=[question, document, model],
+         outputs=[image, output, output_text],
+     )
+
+     submit_button.click(
+         process_question,
+         inputs=[question, document, model],
+         outputs=[image, output, output_text],
+     )
+
+     model.change(
+         process_question,
+         inputs=[question, document, model],
+         outputs=[image, output, output_text],
+     )
+
+     example_image.change(
+         fn=load_example_document,
+         inputs=[example_image, example_question, model],
+         outputs=[document, question, image, img_clear_button, output, output_text],
+     )
+
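+ # Entry point. `enable_queue` is a Gradio 3.x launch flag; on newer Gradio releases
+ # (4.x) it has been removed and queuing is configured via demo.queue() instead, so
+ # this line may need adjusting depending on the installed version.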
+ if __name__ == "__main__":
+     demo.launch(enable_queue=False)