File size: 3,709 Bytes
0bbf6ef
 
 
e91a768
0380162
649e38b
0940e5a
649e38b
b0e8452
3cadd69
fdedf54
0bbf6ef
 
43d306c
e91a768
 
 
 
43d306c
ec2e6e8
c65777e
 
 
 
 
e91a768
 
 
 
c65777e
 
 
 
e91a768
b296597
649e38b
0940e5a
 
fedf52b
 
 
 
 
649e38b
 
 
0940e5a
 
649e38b
 
 
 
 
 
fedf52b
 
3cadd69
4504622
3cadd69
 
0940e5a
3cadd69
 
 
 
 
c65777e
3cadd69
 
ec2e6e8
4504622
 
e1af4a6
649e38b
0940e5a
fedf52b
649e38b
 
0940e5a
649e38b
 
 
 
0940e5a
649e38b
89415f2
e91a768
0bbf6ef
 
 
e91a768
 
 
6c215ad
e91a768
0bbf6ef
 
 
 
03c8fc6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import gradio as gr
from ultralytics import YOLO
import numpy as np
import fitz  # PyMuPDF
import spaces
from concurrent.futures import ThreadPoolExecutor
from multiprocessing import cpu_count
import cv2

# Load the trained YOLO weights once, at import time; the resulting `model`
# object is the one called by infer_image_and_get_boxes for every page.
model_path = 'best.pt'  # Replace with the path to your trained .pt file
model = YOLO(model_path)

# Class indices that infer_image_and_get_boxes keeps; detections of any other
# class are discarded. The numeric values must match the label order the
# weights in `model_path` were trained with.
figure_class_index = 3  # class index for figures
table_class_index = 4   # class index for tables

# Run the detector on one image and collect the figure/table bounding boxes
def infer_image_and_get_boxes(image, confidence_threshold=0.6):
    """Detect figures and tables in ``image``.

    Calls the module-level ``model`` on ``image`` and walks every box of
    every result it returns.

    Args:
        image: Input passed straight through to ``model``.
        confidence_threshold: A detection is kept only when its confidence
            is strictly greater than this value.

    Returns:
        A list of ``(x1, y1, x2, y2)`` tuples of ``int`` corner coordinates,
        one per detection whose class is ``figure_class_index`` or
        ``table_class_index`` and whose confidence passes the threshold.
    """
    wanted_classes = {figure_class_index, table_class_index}
    kept = []
    for result in model(image):
        for box in result.boxes:
            # Skip anything that is not a figure or a table.
            if int(box.cls[0]) not in wanted_classes:
                continue
            # Skip low-confidence detections (strict comparison).
            if not (box.conf[0] > confidence_threshold):
                continue
            corners = box.xyxy[0]
            kept.append(
                (int(corners[0]), int(corners[1]), int(corners[2]), int(corners[3]))
            )
    return kept

# Cut every detected region out of an image after rescaling the boxes
def crop_images_from_boxes(image, boxes, scale_factor):
    """Return one crop of ``image`` per bounding box.

    Args:
        image: Array-like supporting 2-D slice indexing
            (``image[rows, cols]``).
        boxes: Iterable of ``(x1, y1, x2, y2)`` boxes.
        scale_factor: Multiplier applied to each coordinate before it is
            truncated to ``int`` and used as a slice bound.

    Returns:
        A list with, for each box in order, the slice
        ``image[y1*s:y2*s, x1*s:x2*s]`` (bounds truncated to ``int``).
    """
    crops = []
    for x1, y1, x2, y2 in boxes:
        top, bottom = int(y1 * scale_factor), int(y2 * scale_factor)
        left, right = int(x1 * scale_factor), int(x2 * scale_factor)
        crops.append(image[top:bottom, left:right])
    return crops

# Function to process a single page's low-resolution image and perform inference
def process_low_res_page(page_num, low_res_pix, scale_factor, doc_path):
    """Run detection on one pre-rendered low-resolution page.

    Args:
        page_num: Index of the page; returned unchanged so the caller can
            match results back to pages.
        low_res_pix: Rendered page exposing ``samples``, ``height`` and
            ``width`` (a PyMuPDF pixmap when called from ``process_pdf``).
        scale_factor: Unused here; kept so the signature stays stable for
            existing callers.
        doc_path: Unused here; kept so the signature stays stable for
            existing callers. The document is no longer re-opened on every
            call, since the handle was never used and never closed.

    Returns:
        ``(page_num, boxes)`` where ``boxes`` is the list returned by
        ``infer_image_and_get_boxes`` for this page.

    Raises:
        ValueError: If the sample buffer does not hold exactly
            ``height * width * 3`` bytes (raised by ``reshape``).
    """
    # View the raw sample bytes as a height x width x 3 uint8 image.
    # NOTE(review): pixmap samples are presumably RGB while Ultralytics
    # documents numpy inputs as BGR -- confirm whether a channel swap is wanted.
    low_res_img = np.frombuffer(low_res_pix.samples, dtype=np.uint8).reshape(
        low_res_pix.height, low_res_pix.width, 3
    )

    # Get bounding boxes from low DPI image
    boxes = infer_image_and_get_boxes(low_res_img)

    return page_num, boxes

# Function to process a single page's high-resolution image for cropping
def process_high_res_page(page_num, boxes, scale_factor, doc_path, high_dpi=300):
    """Render one page at high resolution and crop the detected regions.

    Args:
        page_num: Zero-based index of the page inside the document.
        boxes: ``(x1, y1, x2, y2)`` boxes in low-resolution pixel
            coordinates.
        scale_factor: Multiplier that maps the low-resolution box
            coordinates onto the high-resolution rendering.
        doc_path: Path of the PDF to open.
        high_dpi: DPI used for the high-resolution rendering. It was
            previously read from an undefined global, which raised
            ``NameError``. The default of 300 matches the value
            ``process_pdf`` uses to derive ``scale_factor``.

    Returns:
        The list of crops produced by ``crop_images_from_boxes``.
    """
    # Open the document only for the duration of the render so the handle is
    # released even if rendering fails.
    with fitz.open(doc_path) as doc:
        high_res_pix = doc[page_num].get_pixmap(dpi=high_dpi)
        # get_pixmap is called without alpha, so 3 bytes per pixel are expected.
        high_res_img = np.frombuffer(high_res_pix.samples, dtype=np.uint8).reshape(
            high_res_pix.height, high_res_pix.width, 3
        )

    # Crop images at high DPI
    return crop_images_from_boxes(high_res_img, boxes, scale_factor)

@spaces.GPU
def process_pdf(pdf_file):
    """Extract cropped figures and tables from every page of a PDF.

    Pages are first rendered at a low DPI and run through the detector to
    get bounding boxes. Pages with detections are then re-rendered at a
    high DPI and cropped.

    Args:
        pdf_file: The uploaded file. Either an object with a ``name``
            attribute holding the file path, or the path itself. ``None``
            (nothing uploaded) yields an empty result.

    Returns:
        A flat list of cropped images, in page order.
    """
    if pdf_file is None:
        return []

    # Use one consistent path for every open() of this document.
    doc_path = getattr(pdf_file, "name", pdf_file)
    all_cropped_images = []

    # Set the DPI for inference and high resolution for cropping
    low_dpi = 50
    high_dpi = 300

    # Factor that maps low-DPI box coordinates onto the high-DPI rendering
    scale_factor = high_dpi / low_dpi

    # Keep the document open for the whole run and close it afterwards.
    with fitz.open(doc_path) as doc:
        # Pre-cache all page pixmaps at low DPI
        low_res_pixmaps = [page.get_pixmap(dpi=low_dpi) for page in doc]

        # Process low-res pages concurrently to get bounding boxes. Results
        # are collected in submission order, i.e. page order.
        # NOTE(review): every worker calls the same module-level model;
        # confirm this is safe against Ultralytics' thread-safety guidance.
        with ThreadPoolExecutor(max_workers=cpu_count()) as executor:
            futures = [
                executor.submit(
                    process_low_res_page, page_num, low_res_pix, scale_factor, doc_path
                )
                for page_num, low_res_pix in enumerate(low_res_pixmaps)
            ]
            low_res_results = [future.result() for future in futures]

        # Sequentially process high-res pages to crop images
        for page_num, boxes in low_res_results:
            if boxes:
                cropped_imgs = process_high_res_page(
                    page_num, boxes, scale_factor, doc_path
                )
                all_cropped_images.extend(cropped_imgs)

    return all_cropped_images

# Create Gradio interface: a single file upload goes to process_pdf, and the
# list of cropped images it returns is shown in a gallery.
iface = gr.Interface(
    fn=process_pdf,
    inputs=gr.File(label="Upload a PDF"),
    outputs=gr.Gallery(label="Cropped Figures and Tables from PDF Pages"),
    title="Fast document layout analysis based on YOLOv8",
    description="Upload a PDF file to get cropped figures and tables from each page."
)

# Launch the app (runs at import time; there is no __main__ guard)
iface.launch()