File size: 4,792 Bytes
230c9a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
import json
import random
from PIL import Image, ImageDraw
from pdf_extract_kit.registry.registry import TASK_REGISTRY
from pdf_extract_kit.utils.data_preprocess import load_pdf
from pdf_extract_kit.tasks.base_task import BaseTask


@TASK_REGISTRY.register("ocr")
class OCRTask(BaseTask):
    """OCR task registered under the name ``"ocr"`` in ``TASK_REGISTRY``.

    Runs the wrapped model's ``predict`` on single images or on every page of
    a PDF, and optionally writes the results as JSON files and as images with
    the detected boxes drawn on them.
    """

    def __init__(self, model):
        """Initialise the task with the given model.

        Args:
            model: task model; it must provide a ``predict`` function
                (``predict_image`` calls ``self.model.predict``, so the base
                class is presumably what stores it as ``self.model``).
        """
        super().__init__(model)

    def predict_image(self, image):
        """Predict on one image and return text detection/recognition results.

        Args:
            image: PIL.Image.Image. If ``model.predict`` supports other input
                types, remember to add the format conversion inside
                ``model.predict``.

        Returns:
            List[dict]: list of text boxes with their content.

        Return example:
            [
                {
                    "category_type": "text",
                    "poly": [
                        380.6792698635707,
                        159.85058512958923,
                        765.1419999999998,
                        159.85058512958923,
                        765.1419999999998,
                        192.51073013642917,
                        380.6792698635707,
                        192.51073013642917
                    ],
                    "text": "this is an example text",
                    "score": 0.97
                },
                ...
            ]
        """
        return self.model.predict(image)

    def prepare_input_files(self, input_path):
        """Expand ``input_path`` into the list of files to process.

        Args:
            input_path: path to a single file, or to a directory whose direct
                children should be processed.

        Returns:
            List[str]: for a directory, the regular files directly inside it
            in sorted order (sub-directories are skipped because they can be
            opened neither as an image nor as a PDF); otherwise a one-element
            list holding ``input_path`` itself.
        """
        if not os.path.isdir(input_path):
            return [input_path]
        # os.listdir order is arbitrary; sort so results come back in a
        # reproducible order.
        candidates = (
            os.path.join(input_path, fname)
            for fname in sorted(os.listdir(input_path))
        )
        return [path for path in candidates if os.path.isfile(path)]

    def process(self, input_path, save_dir=None, visualize=False):
        """Run OCR on a file or on every file of a directory.

        Args:
            input_path: image/PDF file, or a directory of such files.
            save_dir: if set, JSON results are written there: images as
                ``<save_dir>/<name>.json`` and PDFs as
                ``<save_dir>/<name>/page_<k>.json`` (pages numbered from 1).
            visualize: if True and ``save_dir`` is set, also save a copy of
                each image/page with the detected boxes drawn on it
                (``.png`` for images, ``.jpg`` for PDF pages).

        Returns:
            list: one entry per input file, in the order returned by
            ``prepare_input_files``. For an image the entry is the result of
            ``predict_image``; for a PDF it is a list with one such result
            per page.
        """
        res_list = []
        for fpath in self.prepare_input_files(input_path):
            # splitext copes with extensions of any length (".jpeg", ".tiff")
            # and with names that have no extension at all.
            basename = os.path.splitext(os.path.basename(fpath))[0]
            if fpath.lower().endswith(".pdf"):
                res = self._process_pdf(fpath, basename, save_dir, visualize)
            else:
                res = self._process_image(fpath, basename, save_dir, visualize)
            res_list.append(res)
        return res_list

    def _process_pdf(self, fpath, basename, save_dir, visualize):
        """Run OCR on each page of one PDF; return the per-page result list."""
        out_dir = os.path.join(save_dir, basename) if save_dir else None
        pdf_res = []
        for page_no, img in enumerate(load_pdf(fpath), start=1):
            page_res = self.predict_image(img)
            pdf_res.append(page_res)
            if out_dir is None:
                continue
            os.makedirs(out_dir, exist_ok=True)
            self.save_json_result(
                page_res, os.path.join(out_dir, f"page_{page_no}.json")
            )
            if visualize:
                self.visualize_image(
                    img, page_res, os.path.join(out_dir, f"page_{page_no}.jpg")
                )
        return pdf_res

    def _process_image(self, fpath, basename, save_dir, visualize):
        """Run OCR on one image file; return its result list."""
        # The context manager releases the underlying file handle once
        # prediction and the optional visualization are done.
        with Image.open(fpath) as image:
            img_res = self.predict_image(image)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
                self.save_json_result(
                    img_res, os.path.join(save_dir, f"{basename}.json")
                )
                if visualize:
                    self.visualize_image(
                        image, img_res, os.path.join(save_dir, f"{basename}.png")
                    )
        return img_res

    def visualize_image(self, image, ocr_res, save_path="", cate2color=None):
        """Plot each result's bounding box and category on the image.

        The boxes are drawn directly onto ``image`` (it is modified in place).

        Args:
            image: PIL.Image.Image
            ocr_res: list of OCR detection/recognition results, in the format
                returned by ``predict_image``.
            save_path: path to save the visualized image; nothing is saved
                when it is empty.
            cate2color: optional mapping from ``category_type`` to the box
                outline colour; categories not in the mapping are drawn in
                green ``(0, 255, 0)``.
        """
        if cate2color is None:
            cate2color = {}
        draw = ImageDraw.Draw(image)
        for res in ocr_res:
            category = res['category_type']
            box_color = cate2color.get(category, (0, 255, 0))
            # "poly" is a flat [x0, y0, x1, y1, ...] list. Take the extremes
            # over all vertices so the rectangle is well-formed regardless of
            # vertex order or rotation.
            xs = [int(v) for v in res['poly'][0::2]]
            ys = [int(v) for v in res['poly'][1::2]]
            x_min, y_min = min(xs), min(ys)
            x_max, y_max = max(xs), max(ys)
            draw.rectangle(
                [x_min, y_min, x_max, y_max], fill=None, outline=box_color, width=1
            )
            draw.text((x_min, y_min), category, fill=(255, 0, 0))
        if save_path:
            image.save(save_path)

    def save_json_result(self, ocr_res, save_path):
        """Save results to a JSON file.

        Args:
            ocr_res: list of OCR detection/recognition results, in the format
                returned by ``predict_image``.
            save_path: path of the JSON file to write (UTF-8, non-ASCII text
                is written as-is rather than escaped).
        """
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(ocr_res, f, indent=2, ensure_ascii=False)