# ignaciaginting's picture
# Upload 396 files
# 230c9a6 verified
import os
import json
import random
from PIL import Image, ImageDraw
from pdf_extract_kit.registry.registry import TASK_REGISTRY
from pdf_extract_kit.utils.data_preprocess import load_pdf
from pdf_extract_kit.tasks.base_task import BaseTask
@TASK_REGISTRY.register("ocr")
class OCRTask(BaseTask):
    """OCR task: run the wrapped model on images or PDF pages and export results."""

    def __init__(self, model):
        """Init the task based on the given model.

        Args:
            model: task model, must contain a predict function.
        """
        super().__init__(model)

    def predict_image(self, image):
        """Predict on one image, return text detection and recognition results.

        Args:
            image: PIL.Image.Image (if the model.predict function supports
                other types, remember to add a format conversion in
                model.predict).

        Returns:
            List[dict]: list of text bboxes with their content.

        Return example:
            [
                {
                    "category_type": "text",
                    "poly": [
                        380.6792698635707,
                        159.85058512958923,
                        765.1419999999998,
                        159.85058512958923,
                        765.1419999999998,
                        192.51073013642917,
                        380.6792698635707,
                        192.51073013642917
                    ],
                    "text": "this is an example text",
                    "score": 0.97
                },
                ...
            ]
        """
        return self.model.predict(image)

    def prepare_input_files(self, input_path):
        """Expand an input path into the list of files to process.

        Args:
            input_path: path of a single file, or of a directory whose direct
                children should be processed.

        Returns:
            List[str]: for a directory, the regular files it directly contains
            in sorted order (sub-directories are skipped); otherwise a
            one-element list holding ``input_path`` itself.
        """
        if not os.path.isdir(input_path):
            return [input_path]
        # os.listdir order is arbitrary; sort so results are reproducible.
        candidates = (
            os.path.join(input_path, fname)
            for fname in sorted(os.listdir(input_path))
        )
        # Sub-directories cannot be opened as images or PDFs, so skip them.
        return [path for path in candidates if os.path.isfile(path)]

    def process(self, input_path, save_dir=None, visualize=False):
        """Run OCR on a file or on every file of a directory.

        Files whose name ends with ".pdf" (case-insensitive) are loaded with
        ``load_pdf`` and predicted page by page; every other file is opened
        as an image with PIL.

        Args:
            input_path: path of an image/PDF file or of a directory of them.
            save_dir: optional output directory. When set, image results go to
                ``<save_dir>/<name>.json`` and PDF results go to
                ``<save_dir>/<name>/page_<k>.json`` (k starts at 1).
            visualize: when True and ``save_dir`` is set, also save the
                annotated image/page next to its JSON file. It has no effect
                without ``save_dir``.

        Returns:
            list: one entry per input file, in processing order. An image
            yields its ``predict_image`` result; a PDF yields a list with one
            such result per page.
        """
        res_list = []
        for fpath in self.prepare_input_files(input_path):
            # splitext copes with extensions of any length (".jpeg", ".tiff").
            basename = os.path.splitext(os.path.basename(fpath))[0]

            if fpath.lower().endswith(".pdf"):
                pdf_res = []
                for page_no, img in enumerate(load_pdf(fpath), start=1):
                    page_res = self.predict_image(img)
                    pdf_res.append(page_res)
                    if save_dir:
                        page_dir = os.path.join(save_dir, basename)
                        os.makedirs(page_dir, exist_ok=True)
                        self.save_json_result(
                            page_res,
                            os.path.join(page_dir, f"page_{page_no}.json"),
                        )
                        if visualize:
                            self.visualize_image(
                                img,
                                page_res,
                                os.path.join(page_dir, f"page_{page_no}.jpg"),
                            )
                res_list.append(pdf_res)
                continue

            # The context manager releases the file handle once the image has
            # been predicted (and, if requested, visualized).
            with Image.open(fpath) as image:
                img_res = self.predict_image(image)
                res_list.append(img_res)
                if save_dir:
                    os.makedirs(save_dir, exist_ok=True)
                    self.save_json_result(
                        img_res, os.path.join(save_dir, f"{basename}.json")
                    )
                    if visualize:
                        self.visualize_image(
                            image,
                            img_res,
                            os.path.join(save_dir, f"{basename}.png"),
                        )
        return res_list

    def visualize_image(self, image, ocr_res, save_path="", cate2color=None):
        """Plot each result's bbox and category on the image.

        The boxes and labels are drawn directly onto ``image`` (it is modified
        in place).

        Args:
            image: PIL.Image.Image
            ocr_res: list of ocr det and rec results, in the format returned
                by self.predict_image.
            save_path: path to save the visualized image; nothing is saved
                when empty.
            cate2color: optional mapping from category_type to outline color;
                categories not in the mapping are drawn in green.
        """
        if cate2color is None:
            cate2color = {}
        draw = ImageDraw.Draw(image)
        for res in ocr_res:
            category = res["category_type"]
            box_color = cate2color.get(category, (0, 255, 0))
            # Use the extent of all vertices so the rectangle is well-formed
            # whatever the vertex order or rotation of the polygon.
            poly = res["poly"]
            xs, ys = poly[0::2], poly[1::2]
            x_min, y_min = int(min(xs)), int(min(ys))
            x_max, y_max = int(max(xs)), int(max(ys))
            draw.rectangle(
                [x_min, y_min, x_max, y_max],
                fill=None,
                outline=box_color,
                width=1,
            )
            draw.text((x_min, y_min), category, (255, 0, 0))
        if save_path:
            image.save(save_path)

    def save_json_result(self, ocr_res, save_path):
        """Save results to a json file.

        Args:
            ocr_res: list of ocr det and rec results, in the format returned
                by self.predict_image.
            save_path: path of the JSON file to write (UTF-8, indented,
                non-ASCII characters kept as-is).
        """
        # Serialize before opening so a serialization error does not leave a
        # truncated file behind.
        payload = json.dumps(ocr_res, indent=2, ensure_ascii=False)
        with open(save_path, "w", encoding="utf-8") as f:
            f.write(payload)