File size: 5,406 Bytes
c77dc40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2cda1e
c77dc40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1af9124
c77dc40
 
 
1af9124
 
 
 
 
 
c77dc40
1af9124
 
 
 
c77dc40
 
 
 
 
 
 
1af9124
c77dc40
 
1af9124
 
 
 
 
 
 
c77dc40
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import os

# Runtime dependency installation.
# torch / torchvision are pinned to the 1.10.0 / 0.11 CUDA 11.1 builds published on the
# PyTorch "torch_stable" wheel index, and detectron2 is installed from its GitHub repository.
# Background on missing prebuilt detectron2 wheels for newer PyTorch releases:
# https://github.com/facebookresearch/detectron2/issues/3158
# detectron2 installation instructions: https://detectron2.readthedocs.io/tutorials/install.html
import subprocess
import sys


def _pip_install(*pip_args):
    """Run ``pip install`` with *pip_args* for the interpreter executing this script.

    ``sys.executable -m pip`` is used instead of a bare ``pip`` shell command so the
    packages land in the environment that will import them, and the arguments are
    passed as a list so no shell is involved.

    The install stays best-effort: a non-zero exit status is reported on stdout but
    not raised, so an environment that already has the packages can still start.

    Returns:
        The exit status of the pip process (0 on success).
    """
    command = [sys.executable, "-m", "pip", "install", *pip_args]
    returncode = subprocess.run(command, check=False).returncode
    if returncode != 0:
        print(f"Warning: {' '.join(command)!r} exited with status {returncode}")
    return returncode


_pip_install(
    "-q",
    "torch==1.10.0+cu111",
    "torchvision==0.11+cu111",
    "-f",
    "https://download.pytorch.org/whl/torch_stable.html",
)
_pip_install("git+https://github.com/facebookresearch/detectron2.git")

import detectron2
from detectron2.utils.logger import setup_logger
# Call detectron2's logging setup helper (default arguments) once, at import time.
setup_logger()

import gradio as gr
import re
import string
import torch

from operator import itemgetter
import collections

import pypdf
from pypdf import PdfReader
from pypdf.errors import PdfReadError

import pypdfium2 as pdfium
import langdetect
from langdetect import detect_langs

import pandas as pd
import numpy as np
import random
import tempfile
import itertools

from matplotlib import font_manager
from PIL import Image, ImageDraw, ImageFont
import cv2

import pathlib
from pathlib import Path
import shutil

# Tesseract
# Start-up diagnostics: print the OS release files and the result of searching apt for
# "tesseract". The output is only printed; nothing below depends on it.
for _diagnostic_command in (
    'cat /etc/debian_version',
    'cat /etc/issue',
    'apt search tesseract',
):
    # Use the pipe as a context manager so it is closed deterministically
    # instead of being left open until garbage collection.
    with os.popen(_diagnostic_command) as _pipe:
        print(_pipe.read())
import pytesseract

## Key parameters

# Color associated with each layout category, listed as (category, color) pairs.
_LABEL_COLOR_PAIRS = (
    ("Caption", "brown"),
    ("Footnote", "orange"),
    ("Formula", "gray"),
    ("List-item", "yellow"),
    ("Page-footer", "red"),
    ("Page-header", "red"),
    ("Picture", "violet"),
    ("Section-header", "orange"),
    ("Table", "green"),
    ("Text", "blue"),
    ("Title", "pink"),
)
label2color = dict(_LABEL_COLOR_PAIRS)

# Bounding boxes marking the start (cls_box) and the end (sep_box) of a sequence.
cls_box = [0] * 4
sep_box = [1000] * 4

# model
model_id = "pierreguillou/layout-xlm-base-finetuned-with-DocLayNet-base-at-paragraphlevel-ml512"

# tokenizer
tokenizer_id = "xlm-roberta-base"

# (tokenization) The maximum length of a feature (sequence).
# It is derived from the model id: the first supported length whose digits appear
# in the id is used (384 is checked before 512).
for _supported_length in (384, 512):
    if str(_supported_length) in model_id:
        max_length = _supported_length
        break
else:
    # Fail right here with an explicit message: merely printing would leave
    # max_length undefined and surface later as an unrelated NameError.
    raise ValueError(
        f"Error with max_length of chunks! Neither 384 nor 512 appears in model_id {model_id!r}."
    )

# (tokenization) overlap
doc_stride = 128  # Maximum overlap allowed between two parts of the context when it has to be split.

# max PDF page images that will be displayed
max_imgboxes = 2

# get files
# Local folder that receives the auxiliary files; it is created if it does not exist yet.
examples_dir = 'files/'
os.makedirs(examples_dir, exist_ok=True)
from huggingface_hub import hf_hub_download

# Hub repository (a Space) that hosts the files listed below under its "files/" folder.
_FILES_REPO_ID = "pierreguillou/Inference-APP-Document-Understanding-at-paragraphlevel-v2"

files = [
    "example.pdf",
    "blank.pdf",
    "blank.png",
    "languages_iso.csv",
    "languages_tesseract.csv",
    "wo_content.png",
]
for file_name in files:
    # Download one file from the Hub, then copy it into the local folder.
    path_to_file = hf_hub_download(
        repo_id=_FILES_REPO_ID,
        filename=f"files/{file_name}",
        repo_type="space",
    )
    shutil.copy(path_to_file, examples_dir)

# path to files
image_wo_content = f"{examples_dir}wo_content.png"  # image without content
pdf_blank = f"{examples_dir}blank.pdf"  # blank PDF
image_blank = f"{examples_dir}blank.png"  # blank image

## get langdetect2Tesseract dictionary
# Two CSV tables, each with a "Language" and a "LangCode" column, are joined on the
# normalized language name to map one table's codes onto the other's (and back).
t = "files/languages_tesseract.csv"
l = "files/languages_iso.csv"

df_t = pd.read_csv(t)
df_l = pd.read_csv(l)

# Translation table that deletes ASCII punctuation; built once and shared.
_strip_punctuation = str.maketrans('', '', string.punctuation)


def _normalize_language_name(name):
    """Return *name* lower-cased, stripped of surrounding whitespace and without punctuation."""
    return name.lower().strip().translate(_strip_punctuation)


langs_t = [_normalize_language_name(name) for name in df_t["Language"].to_list()]
langs_l = [_normalize_language_name(name) for name in df_l["Language"].to_list()]
langscode_t = df_t["LangCode"].to_list()
langscode_l = df_l["LangCode"].to_list()

# Normalized name -> code for the languages_iso table. setdefault keeps the FIRST
# occurrence of a duplicated name, which is what list.index() would have returned.
_code_by_language_l = {}
for _lang_l, _langcode_l in zip(langs_l, langscode_l):
    _code_by_language_l.setdefault(_lang_l, _langcode_l)

# "Chinese - Simplified" in the Tesseract table is matched against plain "chinese".
_chinese_simplified = _normalize_language_name("Chinese - Simplified")

Tesseract2langdetect = {}
for lang_t, langcode_t in zip(langs_t, langscode_t):
    if lang_t == _chinese_simplified:
        lang_t = "chinese"
    # Languages absent from the languages_iso table are skipped.
    if lang_t in _code_by_language_l:
        Tesseract2langdetect[langcode_t] = _code_by_language_l[lang_t]

# Reverse mapping; when several keys share a value, the last one wins.
langdetect2Tesseract = {v: k for k, v in Tesseract2langdetect.items()}

## model / feature extractor / tokenizer

# get device
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

import transformers
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import LayoutLMv2ForTokenClassification # LayoutXLMTokenizerFast,
from transformers import LayoutLMv2FeatureExtractor

# Checkpoint ids used below. The per-model names were referenced without ever being
# defined (NameError at import time); the LayoutXLM ones are bound to the constants
# defined earlier in this file.
model_id_layoutxlm = model_id
tokenizer_id_layoutxlm = tokenizer_id
# NOTE(review): no LiLT checkpoint id exists anywhere in this file. The id below is
# assumed to be the LiLT counterpart of model_id (same dataset, paragraph level,
# ml512) -- confirm it against the Hugging Face Hub before deploying.
model_id_lilt = "pierreguillou/lilt-xlm-roberta-base-finetuned-with-DocLayNet-base-at-paragraphlevel-ml512"

## model LiLT
# Tokenizer and token-classification model are both loaded from the LiLT checkpoint.
tokenizer_lilt = AutoTokenizer.from_pretrained(model_id_lilt)
model_lilt = AutoModelForTokenClassification.from_pretrained(model_id_lilt)
model_lilt.to(device)

## model LayoutXLM
model_layoutxlm = LayoutLMv2ForTokenClassification.from_pretrained(model_id_layoutxlm)
model_layoutxlm.to(device)

# feature extractor (constructed with apply_ocr=False)
feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)

# tokenizer (loaded from tokenizer_id_layoutxlm, not from the model checkpoint)
tokenizer_layoutxlm = AutoTokenizer.from_pretrained(tokenizer_id_layoutxlm)

# get labels: id <-> label mappings stored in each model's config
id2label_lilt = model_lilt.config.id2label
label2id_lilt = model_lilt.config.label2id
num_labels_lilt = len(id2label_lilt)

id2label_layoutxlm = model_layoutxlm.config.id2label
label2id_layoutxlm = model_layoutxlm.config.label2id
num_labels_layoutxlm = len(id2label_layoutxlm)