pierreguillou committed on
Commit a2cda1e · 1 Parent(s): b3dd1cc

Update files/functions.py

Files changed (1)
  1. files/functions.py +1 -724
files/functions.py CHANGED
@@ -25,8 +25,7 @@ import pypdf
25
  from pypdf import PdfReader
26
  from pypdf.errors import PdfReadError
27
 
28
- import pdf2image
29
- from pdf2image import convert_from_path
30
  import langdetect
31
  from langdetect import detect_langs
32
 
@@ -170,725 +169,3 @@ id2label_layoutxlm = model_layoutxlm.config.id2label
170
  label2id_layoutxlm = model_layoutxlm.config.label2id
171
  num_labels_layoutxlm = len(id2label_layoutxlm)
172
 
173
- ## General
174
-
175
- # get text and bounding boxes from an image
176
- # https://stackoverflow.com/questions/61347755/how-can-i-get-line-coordinates-that-readed-by-tesseract
177
- # https://medium.com/geekculture/tesseract-ocr-understanding-the-contents-of-documents-beyond-their-text-a98704b7c655
178
- def get_data_paragraph(results, factor, conf_min=0):
179
-
180
- data = {}
181
- for i in range(len(results['line_num'])):
182
- level = results['level'][i]
183
- block_num = results['block_num'][i]
184
- par_num = results['par_num'][i]
185
- line_num = results['line_num'][i]
186
- top, left = results['top'][i], results['left'][i]
187
- width, height = results['width'][i], results['height'][i]
188
- conf = results['conf'][i]
189
- text = results['text'][i]
190
- if not (text == '' or text.isspace()):
191
- if conf >= conf_min:
192
- tup = (text, left, top, width, height)
193
- if block_num in list(data.keys()):
194
- if par_num in list(data[block_num].keys()):
195
- if line_num in list(data[block_num][par_num].keys()):
196
- data[block_num][par_num][line_num].append(tup)
197
- else:
198
- data[block_num][par_num][line_num] = [tup]
199
- else:
200
- data[block_num][par_num] = {}
201
- data[block_num][par_num][line_num] = [tup]
202
- else:
203
- data[block_num] = {}
204
- data[block_num][par_num] = {}
205
- data[block_num][par_num][line_num] = [tup]
206
-
207
- # get paragraphs dictionary with list of lines
208
- par_data = {}
209
- par_idx = 1
210
- for _, b in data.items():
211
- for _, p in b.items():
212
- line_data = {}
213
- line_idx = 1
214
- for _, l in p.items():
215
- line_data[line_idx] = l
216
- line_idx += 1
217
- par_data[par_idx] = line_data
218
- par_idx += 1
219
-
220
- # get lines of text, grouped by paragraph
221
- texts_pars = list()
222
- row_indexes = list()
223
- texts_lines = list()
224
- texts_lines_par = list()
225
- row_index = 0
226
- for _,par in par_data.items():
227
- count_lines = 0
228
- lines_par = list()
229
- for _,line in par.items():
230
- if count_lines == 0: row_indexes.append(row_index)
231
- line_text = ' '.join([item[0] for item in line])
232
- texts_lines.append(line_text)
233
- lines_par.append(line_text)
234
- count_lines += 1
235
- row_index += 1
236
- # lines.append("\n")
237
- row_index += 1
238
- texts_lines_par.append(lines_par)
239
- texts_pars.append(' '.join(lines_par))
240
- # lines = lines[:-1]
241
-
242
- # get paragraph boxes (par_boxes)
243
- # get line boxes (line_boxes)
244
- par_boxes = list()
245
- par_idx = 1
246
- line_boxes, lines_par_boxes = list(), list()
247
- line_idx = 1
248
- for _, par in par_data.items():
249
- xmins, ymins, xmaxs, ymaxs = list(), list(), list(), list()
250
- line_boxes_par = list()
251
- count_line_par = 0
252
- for _, line in par.items():
253
- xmin, ymin = line[0][1], line[0][2]
254
- xmax, ymax = (line[-1][1] + line[-1][3]), (line[-1][2] + line[-1][4])
255
- line_boxes.append([int(xmin/factor), int(ymin/factor), int(xmax/factor), int(ymax/factor)])
256
- line_boxes_par.append([int(xmin/factor), int(ymin/factor), int(xmax/factor), int(ymax/factor)])
257
- xmins.append(xmin)
258
- ymins.append(ymin)
259
- xmaxs.append(xmax)
260
- ymaxs.append(ymax)
261
- line_idx += 1
262
- count_line_par += 1
263
- xmin, ymin, xmax, ymax = min(xmins), min(ymins), max(xmaxs), max(ymaxs)
264
- par_bbox = [int(xmin/factor), int(ymin/factor), int(xmax/factor), int(ymax/factor)]
265
- par_boxes.append(par_bbox)
266
- lines_par_boxes.append(line_boxes_par)
267
- par_idx += 1
268
-
269
- return texts_lines, texts_pars, texts_lines_par, row_indexes, par_boxes, line_boxes, lines_par_boxes
270
-
271
- # rescale image to get 300dpi
272
- def set_image_dpi_resize(image):
273
- """
274
- Rescaling image to 300dpi while resizing
275
- :param image: An image
276
- :return: The resize factor and the path to the rescaled image
277
- """
278
- length_x, width_y = image.size
279
- factor = min(1, float(1024.0 / length_x))
280
- size = int(factor * length_x), int(factor * width_y)
281
- # image_resize = image.resize(size, Image.Resampling.LANCZOS)
282
- image_resize = image.resize(size, Image.LANCZOS)
283
- temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='1.png')
284
- temp_filename = temp_file.name
285
- image_resize.save(temp_filename, dpi=(300, 300))
286
- return factor, temp_filename
287
-
288
- # each bounding box must be in (upper left, lower right) format.
289
- # source: https://github.com/NielsRogge/Transformers-Tutorials/issues/129
290
- def upperleft_to_lowerright(bbox):
291
- x0, y0, x1, y1 = tuple(bbox)
292
- if bbox[2] < bbox[0]:
293
- x0 = bbox[2]
294
- x1 = bbox[0]
295
- if bbox[3] < bbox[1]:
296
- y0 = bbox[3]
297
- y1 = bbox[1]
298
- return [x0, y0, x1, y1]
299
-
300
- # convert bounding boxes from (left, top, width, height) format to (left, top, left+width, top+height) format.
301
- def convert_box(bbox):
302
- x, y, w, h = tuple(bbox) # the row comes in (left, top, width, height) format
303
- return [x, y, x+w, y+h] # we turn it into (left, top, left+width, top+height) to get the actual box
304
-
305
- # the LiLT model expects bounding boxes normalized to a 1000x1000 scale
306
- def normalize_box(bbox, width, height):
307
- return [
308
- int(1000 * (bbox[0] / width)),
309
- int(1000 * (bbox[1] / height)),
310
- int(1000 * (bbox[2] / width)),
311
- int(1000 * (bbox[3] / height)),
312
- ]
313
-
314
- # convert boxes back from the normalized 1000x1000 scale to pixel coordinates
315
- def denormalize_box(bbox, width, height):
316
- return [
317
- int(width * (bbox[0] / 1000)),
318
- int(height * (bbox[1] / 1000)),
319
- int(width* (bbox[2] / 1000)),
320
- int(height * (bbox[3] / 1000)),
321
- ]
322
-
323
- # get back original size
324
- def original_box(box, original_width, original_height, coco_width, coco_height):
325
- return [
326
- int(original_width * (box[0] / coco_width)),
327
- int(original_height * (box[1] / coco_height)),
328
- int(original_width * (box[2] / coco_width)),
329
- int(original_height* (box[3] / coco_height)),
330
- ]
331
-
332
- def get_blocks(bboxes_block, categories, texts):
333
-
334
- # get list of unique block boxes
335
- bbox_block_dict, bboxes_block_list, bbox_block_prec = dict(), list(), list()
336
- for count_block, bbox_block in enumerate(bboxes_block):
337
- if bbox_block != bbox_block_prec:
338
- bbox_block_indexes = [i for i, bbox in enumerate(bboxes_block) if bbox == bbox_block]
339
- bbox_block_dict[count_block] = bbox_block_indexes
340
- bboxes_block_list.append(bbox_block)
341
- bbox_block_prec = bbox_block
342
-
343
- # get list of categories and texts by unique block boxes
344
- category_block_list, text_block_list = list(), list()
345
- for bbox_block in bboxes_block_list:
346
- count_block = bboxes_block.index(bbox_block)
347
- bbox_block_indexes = bbox_block_dict[count_block]
348
- category_block = np.array(categories, dtype=object)[bbox_block_indexes].tolist()[0]
349
- category_block_list.append(category_block)
350
- text_block = np.array(texts, dtype=object)[bbox_block_indexes].tolist()
351
- text_block = [text.replace("\n","").strip() for text in text_block]
352
- if id2label[category_block] == "Text" or id2label[category_block] == "Caption" or id2label[category_block] == "Footnote":
353
- text_block = ' '.join(text_block)
354
- else:
355
- text_block = '\n'.join(text_block)
356
- text_block_list.append(text_block)
357
-
358
- return bboxes_block_list, category_block_list, text_block_list
359
-
360
- # function to sort bounding boxes
361
- def get_sorted_boxes(bboxes):
362
-
363
- # sort by y from page top to bottom
364
- sorted_bboxes = sorted(bboxes, key=itemgetter(1), reverse=False)
365
- y_list = [bbox[1] for bbox in sorted_bboxes]
366
-
367
- # sort by x from page left to right for boxes with the same y
368
- if len(list(set(y_list))) != len(y_list):
369
- y_list_duplicates_indexes = dict()
370
- y_list_duplicates = [item for item, count in collections.Counter(y_list).items() if count > 1]
371
- for item in y_list_duplicates:
372
- y_list_duplicates_indexes[item] = [i for i, e in enumerate(y_list) if e == item]
373
- bbox_list_y_duplicates = sorted(np.array(sorted_bboxes, dtype=object)[y_list_duplicates_indexes[item]].tolist(), key=itemgetter(0), reverse=False)
374
- np_array_bboxes = np.array(sorted_bboxes)
375
- np_array_bboxes[y_list_duplicates_indexes[item]] = np.array(bbox_list_y_duplicates)
376
- sorted_bboxes = np_array_bboxes.tolist()
377
-
378
- return sorted_bboxes
379
-
380
- # sort data from top to bottom of the page (then left to right for boxes with the same y)
381
- def sort_data(bboxes, categories, texts):
382
-
383
- sorted_bboxes = get_sorted_boxes(bboxes)
384
- sorted_bboxes_indexes = [bboxes.index(bbox) for bbox in sorted_bboxes]
385
- sorted_categories = np.array(categories, dtype=object)[sorted_bboxes_indexes].tolist()
386
- sorted_texts = np.array(texts, dtype=object)[sorted_bboxes_indexes].tolist()
387
-
388
- return sorted_bboxes, sorted_categories, sorted_texts
389
-
390
- # sort data from top to bottom of the page (then left to right for boxes with the same y)
391
- def sort_data_wo_labels(bboxes, texts):
392
-
393
- sorted_bboxes = get_sorted_boxes(bboxes)
394
- sorted_bboxes_indexes = [bboxes.index(bbox) for bbox in sorted_bboxes]
395
- sorted_texts = np.array(texts, dtype=object)[sorted_bboxes_indexes].tolist()
396
-
397
- return sorted_bboxes, sorted_texts
398
-
399
- ## PDF processing
400
-
401
- # get filename and images of PDF pages
402
- def pdf_to_images(uploaded_pdf):
403
-
404
- # Check if None object
405
- if uploaded_pdf is None:
406
- path_to_file = pdf_blank
407
- filename = path_to_file.replace(examples_dir,"")
408
- msg = "Invalid PDF file."
409
- images = [Image.open(image_blank)]
410
- else:
411
- # path to the uploaded PDF
412
- path_to_file = uploaded_pdf.name
413
- filename = path_to_file.replace("/tmp/","")
414
-
415
- try:
416
- PdfReader(path_to_file)
417
- except PdfReadError:
418
- path_to_file = pdf_blank
419
- filename = path_to_file.replace(examples_dir,"")
420
- msg = "Invalid PDF file."
421
- images = [Image.open(image_blank)]
422
- else:
423
- try:
424
- images = convert_from_path(path_to_file, last_page=max_imgboxes)
425
- num_imgs = len(images)
426
- msg = f'The PDF "{filename}" was converted into {num_imgs} images.'
427
- except:
428
- msg = f'Error with the PDF "{filename}": it was not converted into images.'
429
- images = [Image.open(image_wo_content)]
430
-
431
- return filename, msg, images
432
-
433
- # Extraction of image data (text and bounding boxes)
434
- def extraction_data_from_image(images):
435
-
436
- num_imgs = len(images)
437
-
438
- if num_imgs > 0:
439
-
440
- # https://pyimagesearch.com/2021/11/15/tesseract-page-segmentation-modes-psms-explained-how-to-improve-your-ocr-accuracy/
441
- custom_config = r'--oem 3 --psm 3 -l eng' # default config PyTesseract: --oem 3 --psm 3 -l eng+deu+fra+jpn+por+spa+rus+hin+chi_sim
442
- results, texts_lines, texts_pars, texts_lines_par, row_indexes, par_boxes, line_boxes, lines_par_boxes, images_pixels = dict(), dict(), dict(), dict(), dict(), dict(), dict(), dict(), dict()
443
- images_ids_list, texts_lines_list, texts_pars_list, texts_lines_par_list, par_boxes_list, line_boxes_list, lines_par_boxes_list, images_list, images_pixels_list, page_no_list, num_pages_list = list(), list(), list(), list(), list(), list(), list(), list(), list(), list(), list()
444
-
445
- try:
446
- for i,image in enumerate(images):
447
- # image preprocessing
448
- # https://docs.opencv.org/3.0-beta/doc/py_tutorials/py_imgproc/py_thresholding/py_thresholding.html
449
- img = image.copy()
450
- factor, path_to_img = set_image_dpi_resize(img) # Rescaling to 300dpi
451
- img = Image.open(path_to_img)
452
- img = np.array(img, dtype='uint8') # convert PIL to cv2
453
- img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) # gray scale image
454
- ret,img = cv2.threshold(img,127,255,cv2.THRESH_BINARY)
455
-
456
- # OCR PyTesseract | get langs of page
457
- txt = pytesseract.image_to_string(img, config=custom_config)
458
- txt = txt.strip().lower()
459
- txt = re.sub(r" +", " ", txt) # multiple space
460
- txt = re.sub(r"(\n\s*)+\n+", "\n", txt) # multiple line
461
- # txt = os.popen(f'tesseract {img_filepath} - {custom_config}').read()
462
- try:
463
- langs = detect_langs(txt)
464
- langs = [langdetect2Tesseract[langs[i].lang] for i in range(len(langs))]
465
- langs_string = '+'.join(langs)
466
- except:
467
- langs_string = "eng"
468
- langs_string += '+osd'
469
- custom_config = f'--oem 3 --psm 3 -l {langs_string}' # default config PyTesseract: --oem 3 --psm 3
470
-
471
- # OCR PyTesseract | get data
472
- results[i] = pytesseract.image_to_data(img, config=custom_config, output_type=pytesseract.Output.DICT)
473
- # results[i] = os.popen(f'tesseract {img_filepath} - {custom_config}').read()
474
-
475
- # get image pixels
476
- images_pixels[i] = feature_extractor(images[i], return_tensors="pt").pixel_values
477
-
478
- texts_lines[i], texts_pars[i], texts_lines_par[i], row_indexes[i], par_boxes[i], line_boxes[i], lines_par_boxes[i] = get_data_paragraph(results[i], factor, conf_min=0)
479
- texts_lines_list.append(texts_lines[i])
480
- texts_pars_list.append(texts_pars[i])
481
- texts_lines_par_list.append(texts_lines_par[i])
482
- par_boxes_list.append(par_boxes[i])
483
- line_boxes_list.append(line_boxes[i])
484
- lines_par_boxes_list.append(lines_par_boxes[i])
485
- images_ids_list.append(i)
486
- images_pixels_list.append(images_pixels[i])
487
- images_list.append(images[i])
488
- page_no_list.append(i)
489
- num_pages_list.append(num_imgs)
490
-
491
- except:
492
- print(f"There was an error within the extraction of PDF text by the OCR!")
493
- else:
494
- from datasets import Dataset
495
- dataset = Dataset.from_dict({"images_ids": images_ids_list, "images": images_list, "images_pixels": images_pixels_list, "page_no": page_no_list, "num_pages": num_pages_list, "texts_line": texts_lines_list, "texts_par": texts_pars_list, "texts_lines_par": texts_lines_par_list, "bboxes_par": par_boxes_list, "bboxes_lines_par":lines_par_boxes_list})
496
-
497
-
498
- # print(f"The text data was successfully extracted by the OCR!")
499
-
500
- return dataset, texts_lines, texts_pars, texts_lines_par, row_indexes, par_boxes, line_boxes, lines_par_boxes
501
-
502
- ## Inference
503
-
504
- def prepare_inference_features_paragraph(example, cls_box = cls_box, sep_box = sep_box):
505
-
506
- images_ids_list, chunks_ids_list, input_ids_list, attention_mask_list, bb_list, images_pixels_list = list(), list(), list(), list(), list(), list()
507
-
508
- # get batch
509
- # batch_page_hash = example["page_hash"]
510
- batch_images_ids = example["images_ids"]
511
- batch_images = example["images"]
512
- batch_images_pixels = example["images_pixels"]
513
- batch_bboxes_par = example["bboxes_par"]
514
- batch_texts_par = example["texts_par"]
515
- batch_images_size = [image.size for image in batch_images]
516
-
517
- batch_width, batch_height = [image_size[0] for image_size in batch_images_size], [image_size[1] for image_size in batch_images_size]
518
-
519
- # add a dimension if not a batch but only one image
520
- if not isinstance(batch_images_ids, list):
521
- batch_images_ids = [batch_images_ids]
522
- batch_images = [batch_images]
523
- batch_images_pixels = [batch_images_pixels]
524
- batch_bboxes_par = [batch_bboxes_par]
525
- batch_texts_par = [batch_texts_par]
526
- batch_width, batch_height = [batch_width], [batch_height]
527
-
528
- # process all images of the batch
529
- for num_batch, (image_id, image_pixels, boxes, texts_par, width, height) in enumerate(zip(batch_images_ids, batch_images_pixels, batch_bboxes_par, batch_texts_par, batch_width, batch_height)):
530
- tokens_list = []
531
- bboxes_list = []
532
-
533
- # add a dimension if there is only one image
534
- if not isinstance(texts_par, list):
535
- texts_par, boxes = [texts_par], [boxes]
536
-
537
- # normalize boxes (after ensuring (upper left, lower right) format)
538
- normalize_bboxes_par = [normalize_box(upperleft_to_lowerright(box), width, height) for box in boxes]
539
-
540
- # sort boxes with texts
541
- # we want sorted lists from top to bottom of the image
542
- boxes, texts_par = sort_data_wo_labels(normalize_bboxes_par, texts_par)
543
-
544
- count = 0
545
- for box, text_par in zip(boxes, texts_par):
546
- tokens_par = tokenizer.tokenize(text_par)
547
- num_tokens_par = len(tokens_par) # get number of tokens
548
- tokens_list.extend(tokens_par)
549
- bboxes_list.extend([box] * num_tokens_par) # number of boxes must be the same as the number of tokens
550
-
551
- # use of return_overflowing_tokens=True / stride=doc_stride
552
- # to get parts of image with overlap
553
- # source: https://huggingface.co/course/chapter6/3b?fw=tf#handling-long-contexts
554
- encodings = tokenizer(" ".join(texts_par),
555
- truncation=True,
556
- padding="max_length",
557
- max_length=max_length,
558
- stride=doc_stride,
559
- return_overflowing_tokens=True,
560
- return_offsets_mapping=True
561
- )
562
-
563
- otsm = encodings.pop("overflow_to_sample_mapping")
564
- offset_mapping = encodings.pop("offset_mapping")
565
-
566
- # get the boxes of each chunk
567
- sequence_length_prev = 0
568
- for i, offsets in enumerate(offset_mapping):
569
- # truncate tokens, boxes and labels based on length of chunk - 2 (special tokens <s> and </s>)
570
- sequence_length = len(encodings.input_ids[i]) - 2
571
- if i == 0: start = 0
572
- else: start += sequence_length_prev - doc_stride
573
- end = start + sequence_length
574
- sequence_length_prev = sequence_length
575
-
576
- # get tokens, boxes and labels of this image chunk
577
- bb = [cls_box] + bboxes_list[start:end] + [sep_box]
578
-
579
- # as the last chunk can have a length < max_length
580
- # we must pad with [tokenizer.pad_token] (tokens), [sep_box] (boxes) and [-100] (labels)
581
- if len(bb) < max_length:
582
- bb = bb + [sep_box] * (max_length - len(bb))
583
-
584
- # append results
585
- input_ids_list.append(encodings["input_ids"][i])
586
- attention_mask_list.append(encodings["attention_mask"][i])
587
- bb_list.append(bb)
588
- images_ids_list.append(image_id)
589
- chunks_ids_list.append(i)
590
- images_pixels_list.append(image_pixels)
591
-
592
- return {
593
- "images_ids": images_ids_list,
594
- "chunk_ids": chunks_ids_list,
595
- "input_ids": input_ids_list,
596
- "attention_mask": attention_mask_list,
597
- "normalized_bboxes": bb_list,
598
- "images_pixels": images_pixels_list
599
- }
600
-
601
- from torch.utils.data import Dataset
602
-
603
- class CustomDataset(Dataset):
604
- def __init__(self, dataset, tokenizer):
605
- self.dataset = dataset
606
- self.tokenizer = tokenizer
607
-
608
- def __len__(self):
609
- return len(self.dataset)
610
-
611
- def __getitem__(self, idx):
612
- # get item
613
- example = self.dataset[idx]
614
- encoding = dict()
615
- encoding["images_ids"] = example["images_ids"]
616
- encoding["chunk_ids"] = example["chunk_ids"]
617
- encoding["input_ids"] = example["input_ids"]
618
- encoding["attention_mask"] = example["attention_mask"]
619
- encoding["bbox"] = example["normalized_bboxes"]
620
- encoding["images_pixels"] = example["images_pixels"]
621
-
622
- return encoding
623
-
624
- import torch.nn.functional as F
625
-
626
- # get predictions at token level
627
- def predictions_token_level(images, custom_encoded_dataset):
628
-
629
- num_imgs = len(images)
630
- if num_imgs > 0:
631
-
632
- chunk_ids, input_ids, bboxes, pixels_values, outputs, token_predictions = dict(), dict(), dict(), dict(), dict(), dict()
633
- images_ids_list = list()
634
-
635
- for i,encoding in enumerate(custom_encoded_dataset):
636
-
637
- # get custom encoded data
638
- image_id = encoding['images_ids']
639
- chunk_id = encoding['chunk_ids']
640
- input_id = torch.tensor(encoding['input_ids'])[None]
641
- attention_mask = torch.tensor(encoding['attention_mask'])[None]
642
- bbox = torch.tensor(encoding['bbox'])[None]
643
- pixel_values = torch.tensor(encoding["images_pixels"])
644
-
645
- # save data in dictionaries
646
- if image_id not in images_ids_list: images_ids_list.append(image_id)
647
-
648
- if image_id in chunk_ids: chunk_ids[image_id].append(chunk_id)
649
- else: chunk_ids[image_id] = [chunk_id]
650
-
651
- if image_id in input_ids: input_ids[image_id].append(input_id)
652
- else: input_ids[image_id] = [input_id]
653
-
654
- if image_id in bboxes: bboxes[image_id].append(bbox)
655
- else: bboxes[image_id] = [bbox]
656
-
657
- if image_id in pixels_values: pixels_values[image_id].append(pixel_values)
658
- else: pixels_values[image_id] = [pixel_values]
659
-
660
- # get prediction with forward pass
661
- with torch.no_grad():
662
- output = model(
663
- input_ids=input_id.to(device),
664
- attention_mask=attention_mask.to(device),
665
- bbox=bbox.to(device),
666
- image=pixel_values.to(device)
667
- )
668
-
669
- # save probabilities of predictions in dictionary
670
- if image_id in outputs: outputs[image_id].append(F.softmax(output.logits.squeeze(), dim=-1))
671
- else: outputs[image_id] = [F.softmax(output.logits.squeeze(), dim=-1)]
672
-
673
- return outputs, images_ids_list, chunk_ids, input_ids, bboxes
674
-
675
- else:
676
- print("An error occurred while getting predictions!")
677
-
678
- from functools import reduce
679
-
680
- # Get predictions (paragraph level)
681
- def predictions_paragraph_level(dataset, outputs, images_ids_list, chunk_ids, input_ids, bboxes):
682
-
683
- ten_probs_dict, ten_input_ids_dict, ten_bboxes_dict = dict(), dict(), dict()
684
- bboxes_list_dict, input_ids_dict_dict, probs_dict_dict, df = dict(), dict(), dict(), dict()
685
-
686
- if len(images_ids_list) > 0:
687
-
688
- for i, image_id in enumerate(images_ids_list):
689
-
690
- # get image information
691
- images_list = dataset.filter(lambda example: example["images_ids"] == image_id)["images"]
692
- image = images_list[0]
693
- width, height = image.size
694
-
695
- # get data
696
- chunk_ids_list = chunk_ids[image_id]
697
- outputs_list = outputs[image_id]
698
- input_ids_list = input_ids[image_id]
699
- bboxes_list = bboxes[image_id]
700
-
701
- # create zero tensors
702
- ten_probs = torch.zeros((outputs_list[0].shape[0] - 2)*len(outputs_list), outputs_list[0].shape[1])
703
- ten_input_ids = torch.ones(size=(1, (outputs_list[0].shape[0] - 2)*len(outputs_list)), dtype =int)
704
- ten_bboxes = torch.zeros(size=(1, (outputs_list[0].shape[0] - 2)*len(outputs_list), 4), dtype =int)
705
-
706
- if len(outputs_list) > 1:
707
-
708
- for num_output, (output, input_id, bbox) in enumerate(zip(outputs_list, input_ids_list, bboxes_list)):
709
- start = num_output*(max_length - 2) - max(0,num_output)*doc_stride
710
- end = start + (max_length - 2)
711
-
712
- if num_output == 0:
713
- ten_probs[start:end,:] += output[1:-1]
714
- ten_input_ids[:,start:end] = input_id[:,1:-1]
715
- ten_bboxes[:,start:end,:] = bbox[:,1:-1,:]
716
- else:
717
- ten_probs[start:start + doc_stride,:] += output[1:1 + doc_stride]
718
- ten_probs[start:start + doc_stride,:] = ten_probs[start:start + doc_stride,:] * 0.5
719
- ten_probs[start + doc_stride:end,:] += output[1 + doc_stride:-1]
720
-
721
- ten_input_ids[:,start:start + doc_stride] = input_id[:,1:1 + doc_stride]
722
- ten_input_ids[:,start + doc_stride:end] = input_id[:,1 + doc_stride:-1]
723
-
724
- ten_bboxes[:,start:start + doc_stride,:] = bbox[:,1:1 + doc_stride,:]
725
- ten_bboxes[:,start + doc_stride:end,:] = bbox[:,1 + doc_stride:-1,:]
726
-
727
- else:
728
- ten_probs += outputs_list[0][1:-1]
729
- ten_input_ids = input_ids_list[0][:,1:-1]
730
- ten_bboxes = bboxes_list[0][:,1:-1]
731
-
732
- ten_probs_list, ten_input_ids_list, ten_bboxes_list = ten_probs.tolist(), ten_input_ids.tolist()[0], ten_bboxes.tolist()[0]
733
- bboxes_list = list()
734
- input_ids_dict, probs_dict = dict(), dict()
735
- bbox_prev = [-100, -100, -100, -100]
736
- for probs, input_id, bbox in zip(ten_probs_list, ten_input_ids_list, ten_bboxes_list):
737
- bbox = denormalize_box(bbox, width, height)
738
- if bbox != bbox_prev and bbox != cls_box and bbox != sep_box and bbox[0] != bbox[2] and bbox[1] != bbox[3]:
739
- bboxes_list.append(bbox)
740
- input_ids_dict[str(bbox)] = [input_id]
741
- probs_dict[str(bbox)] = [probs]
742
- elif bbox != cls_box and bbox != sep_box and bbox[0] != bbox[2] and bbox[1] != bbox[3]:
743
- input_ids_dict[str(bbox)].append(input_id)
744
- probs_dict[str(bbox)].append(probs)
745
- bbox_prev = bbox
746
-
747
- probs_bbox = dict()
748
- for i,bbox in enumerate(bboxes_list):
749
- probs = probs_dict[str(bbox)]
750
- probs = np.array(probs).T.tolist()
751
-
752
- probs_label = list()
753
- for probs_list in probs:
754
- prob_label = reduce(lambda x, y: x*y, probs_list)
755
- prob_label = prob_label**(1./(len(probs_list))) # normalization
756
- probs_label.append(prob_label)
757
- max_value = max(probs_label)
758
- max_index = probs_label.index(max_value)
759
- probs_bbox[str(bbox)] = max_index
760
-
761
- bboxes_list_dict[image_id] = bboxes_list
762
- input_ids_dict_dict[image_id] = input_ids_dict
763
- probs_dict_dict[image_id] = probs_bbox
764
-
765
- df[image_id] = pd.DataFrame()
766
- df[image_id]["bboxes"] = bboxes_list
767
- df[image_id]["texts"] = [tokenizer.decode(input_ids_dict[str(bbox)]) for bbox in bboxes_list]
768
- df[image_id]["labels"] = [id2label[probs_bbox[str(bbox)]] for bbox in bboxes_list]
769
-
770
- return probs_bbox, bboxes_list_dict, input_ids_dict_dict, probs_dict_dict, df
771
-
772
- else:
773
- print("An error occurred while getting predictions!")
774
-
775
- # Get labeled images with predicted bounding boxes
776
- def get_labeled_images(dataset, images_ids_list, bboxes_list_dict, probs_dict_dict):
777
-
778
- labeled_images = list()
779
-
780
- for i, image_id in enumerate(images_ids_list):
781
-
782
- # get image
783
- images_list = dataset.filter(lambda example: example["images_ids"] == image_id)["images"]
784
- image = images_list[0]
785
- width, height = image.size
786
-
787
- # get predicted boxes and labels
788
- bboxes_list = bboxes_list_dict[image_id]
789
- probs_bbox = probs_dict_dict[image_id]
790
-
791
- draw = ImageDraw.Draw(image)
792
- # https://stackoverflow.com/questions/66274858/choosing-a-pil-imagefont-by-font-name-rather-than-filename-and-cross-platform-f
793
- font = font_manager.FontProperties(family='sans-serif', weight='bold')
794
- font_file = font_manager.findfont(font)
795
- font_size = 30
796
- font = ImageFont.truetype(font_file, font_size)
797
-
798
- for bbox in bboxes_list:
799
- predicted_label = id2label[probs_bbox[str(bbox)]]
800
- draw.rectangle(bbox, outline=label2color[predicted_label])
801
- draw.text((bbox[0] + 10, bbox[1] - font_size), text=predicted_label, fill=label2color[predicted_label], font=font)
802
-
803
- labeled_images.append(image)
804
-
805
- return labeled_images
806
-
807
- # get data of encoded chunk
808
- def get_encoded_chunk_inference(index_chunk=None):
809
-
810
- # get datasets
811
- example = dataset
812
- encoded_example = encoded_dataset
813
-
814
- # randomly pick a chunk from the dataset
815
- if index_chunk == None: index_chunk = random.randint(0, len(encoded_example)-1)
816
- encoded_example = encoded_example[index_chunk]
817
- encoded_image_ids = encoded_example["images_ids"]
818
-
819
- # get the image
820
- example = example.filter(lambda example: example["images_ids"] == encoded_image_ids)[0]
821
- image = example["images"] # original image
822
- width, height = image.size
823
- page_no = example["page_no"]
824
- num_pages = example["num_pages"]
825
-
826
- # get boxes, texts, categories
827
- bboxes, input_ids = encoded_example["normalized_bboxes"][1:-1], encoded_example["input_ids"][1:-1]
828
- bboxes = [denormalize_box(bbox, width, height) for bbox in bboxes]
829
- num_tokens = len(input_ids) + 2
830
-
831
- # get unique bboxes and corresponding labels
832
- bboxes_list, input_ids_list = list(), list()
833
- input_ids_dict = dict()
834
- bbox_prev = [-100, -100, -100, -100]
835
- for i, (bbox, input_id) in enumerate(zip(bboxes, input_ids)):
836
- if bbox != bbox_prev:
837
- bboxes_list.append(bbox)
838
- input_ids_dict[str(bbox)] = [input_id]
839
- else:
840
- input_ids_dict[str(bbox)].append(input_id)
841
-
842
- # start_indexes_list.append(i)
843
- bbox_prev = bbox
844
-
845
- # do not keep "</s><pad><pad>..."
846
- if input_ids_dict[str(bboxes_list[-1])][0] == (tokenizer.convert_tokens_to_ids('</s>')):
847
- del input_ids_dict[str(bboxes_list[-1])]
848
- bboxes_list = bboxes_list[:-1]
849
-
850
- # get texts by line
851
- input_ids_list = input_ids_dict.values()
852
- texts_list = [tokenizer.decode(input_ids) for input_ids in input_ids_list]
853
-
854
- # display DataFrame
855
- df = pd.DataFrame({"texts": texts_list, "input_ids": input_ids_list, "bboxes": bboxes_list})
856
-
857
- return image, df, num_tokens, page_no, num_pages
858
-
859
- # display chunk of PDF image and its data
860
- def display_chunk_paragraphs_inference(index_chunk=None):
861
-
862
- # get image and image data
863
- image, df, num_tokens, page_no, num_pages = get_encoded_chunk_inference(index_chunk=index_chunk)
864
-
865
- # get data from dataframe
866
- input_ids = df["input_ids"]
867
- texts = df["texts"]
868
- bboxes = df["bboxes"]
869
-
870
- print(f'Chunk ({num_tokens} tokens) of the PDF (page: {page_no+1} / {num_pages})\n')
871
-
872
- # display image with bounding boxes
873
- print(">> PDF image with bounding boxes of paragraphs\n")
874
- draw = ImageDraw.Draw(image)
875
-
876
- labels = list()
877
- for box, text in zip(bboxes, texts):
878
- color = "red"
879
- draw.rectangle(box, outline=color)
880
-
881
- # resize image for display
882
- width, height = image.size
883
- image = image.resize((int(0.5*width), int(0.5*height)))
884
-
885
- # convert to cv and display
886
- img = np.array(image, dtype='uint8') # PIL to cv2
887
- cv2_imshow(img)
888
- cv2.waitKey(0)
889
-
890
- # display image dataframe
891
- print("\n>> Dataframe of annotated paragraphs\n")
892
- cols = ["texts", "bboxes"]
893
- df = df[cols]
894
- display(df)
 
25
  from pypdf import PdfReader
26
  from pypdf.errors import PdfReadError
27
 
28
+ import pypdfium2 as pdfium
 
29
  import langdetect
30
  from langdetect import detect_langs
31
 
 
169
  label2id_layoutxlm = model_layoutxlm.config.label2id
170
  num_labels_layoutxlm = len(id2label_layoutxlm)
171
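
Note on the replacement import: the removed pdf_to_images helper relied on pdf2image.convert_from_path, while the updated file imports pypdfium2 instead. Below is a minimal sketch (not part of this commit; the helper name and the 300 dpi scale factor are illustrative assumptions) of how the same pages-to-PIL conversion could look with pypdfium2:

import pypdfium2 as pdfium

def pdf_to_pil_images(path_to_file, last_page):
    # hypothetical helper: render the first `last_page` pages of a PDF to PIL images
    pdf = pdfium.PdfDocument(path_to_file)
    num_pages = min(last_page, len(pdf))
    # PDF user space is 72 dpi, so scale = 300/72 gives roughly 300 dpi output
    images = [pdf[i].render(scale=300 / 72).to_pil() for i in range(num_pages)]
    pdf.close()
    return images

# usage, mirroring the removed call: images = pdf_to_pil_images(path_to_file, max_imgboxes)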