File size: 4,872 Bytes
21db53c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from time import time

import numpy as np
import torch
from PIL import Image
from loguru import logger

from app.Services.lifespan_service import LifespanService
from app.config import config


class OCRService(LifespanService):
    """Common base for the OCR backends defined in this module.

    Resolves the compute device at construction time and supplies the default
    image preprocessing. Concrete backends override ``ocr_interface``.
    """

    def __init__(self):
        # "auto" resolves to CUDA when torch reports it as available, else CPU.
        self._device = config.device
        if self._device == "auto":
            self._device = "cuda" if torch.cuda.is_available() else "cpu"

    @staticmethod
    def _image_preprocess(img: Image.Image) -> Image.Image:
        """Return an RGB copy of ``img`` centred on a black 1024x1024 canvas.

        Images with a side longer than 1024 pixels are shrunk with
        ``Image.thumbnail`` (LANCZOS resampling) first. The caller's image is
        never modified.

        :param img: Source image in any PIL mode.
        :return: A new 1024x1024 RGB image containing ``img``.
        """
        rgb = img if img.mode == 'RGB' else img.convert('RGB')
        if rgb.size[0] > 1024 or rgb.size[1] > 1024:
            if rgb is img:
                # thumbnail() resizes in place. Without this copy an image that
                # is already RGB would be shrunk in the caller's hands too.
                rgb = img.copy()
            rgb.thumbnail((1024, 1024), Image.Resampling.LANCZOS)
        canvas = Image.new('RGB', (1024, 1024), (0, 0, 0))
        # Centre the (possibly shrunk) image on the canvas.
        canvas.paste(rgb, ((1024 - rgb.size[0]) // 2, (1024 - rgb.size[1]) // 2))
        return canvas

    def ocr_interface(self, img: Image.Image, need_preprocess=True) -> str:
        """Placeholder overridden by the concrete backends.

        This base implementation does nothing and returns ``None``.

        :param img: Image to run OCR on.
        :param need_preprocess: Whether the image should be preprocessed first.
        """
        pass


class EasyPaddleOCRService(OCRService):
    """OCR backend that delegates to ``easypaddleocr.EasyPaddleOCR``."""

    def __init__(self):
        super().__init__()
        from easypaddleocr import EasyPaddleOCR

        # An empty/falsy configured path is passed to the library as None.
        local_model_dir = config.model.easypaddleocr or None
        self._paddle_ocr_module = EasyPaddleOCR(
            use_angle_cls=True,
            needWarmUp=True,
            devices=self._device,
            warmup_size=(960, 960),
            model_local_dir=local_model_dir,
        )
        logger.success("EasyPaddleOCR loaded successfully")

    @staticmethod
    def _image_preprocess(img: Image.Image) -> Image.Image:
        """Convert ``img`` to RGB if needed; no resizing or padding is done.

        Optimized `easypaddleocr` doesn't require scaling preprocess.
        """
        return img if img.mode == 'RGB' else img.convert('RGB')

    def _easy_paddleocr_process(self, img: Image.Image) -> str:
        """Run the engine on ``img`` and return the accepted text.

        Only entries whose score is above
        ``config.ocr_search.ocr_min_confidence`` are kept; the pieces are
        concatenated without a separator. Returns ``""`` when the engine
        yields nothing.
        """
        # Only the middle element of the returned triple is used here.
        _, detections, _ = self._paddle_ocr_module.ocr(np.array(img))
        if not detections:
            return ""
        accepted = []
        for entry in detections:
            # entry[0] is used as the text and entry[1] as its score.
            if float(entry[1]) > config.ocr_search.ocr_min_confidence:
                accepted.append(entry[0])
        return "".join(accepted)

    def ocr_interface(self, img: Image.Image, need_preprocess=True) -> str:
        """Recognise text in ``img``, logging the elapsed time.

        :param img: Image to run OCR on.
        :param need_preprocess: When true, ``_image_preprocess`` is applied first.
        :return: The concatenated recognised text.
        """
        started = time()
        logger.info("Processing text with EasyPaddleOCR...")
        if need_preprocess:
            img = self._image_preprocess(img)
        text = self._easy_paddleocr_process(img)
        logger.success("OCR processed done. Time elapsed: {:.2f}s", time() - started)
        return text


class EasyOCRService(OCRService):
    """OCR backend that delegates to an ``easyocr.Reader``."""

    def __init__(self):
        super().__init__()
        # noinspection PyPackageRequirements
        import easyocr  # pylint: disable=import-error

        # The reader only takes a GPU on/off flag, derived from the resolved device.
        use_gpu = self._device == "cuda"
        self._easy_ocr_module = easyocr.Reader(config.ocr_search.ocr_language, gpu=use_gpu)
        logger.success("easyOCR loaded successfully")

    def _easyocr_process(self, img: Image.Image) -> str:
        """Run the reader on ``img`` and return the accepted text.

        Entries whose score is above ``config.ocr_search.ocr_min_confidence``
        are kept and joined with single spaces.
        """
        detections = self._easy_ocr_module.readtext(np.array(img))
        accepted = []
        for entry in detections:
            # entry[1] is used as the text and entry[2] as its score.
            if entry[2] > config.ocr_search.ocr_min_confidence:
                accepted.append(entry[1])
        return " ".join(accepted)

    def ocr_interface(self, img: Image.Image, need_preprocess=True) -> str:
        """Recognise text in ``img``, logging the elapsed time.

        :param img: Image to run OCR on.
        :param need_preprocess: When true, ``_image_preprocess`` is applied first.
        :return: The space-joined recognised text.
        """
        started = time()
        logger.info("Processing text with easyOCR...")
        if need_preprocess:
            img = self._image_preprocess(img)
        text = self._easyocr_process(img)
        logger.success("OCR processed done. Time elapsed: {:.2f}s", time() - started)
        return text


class PaddleOCRService(OCRService):
    """OCR backend that delegates to ``paddleocr.PaddleOCR``."""

    def __init__(self):
        super().__init__()
        # noinspection PyPackageRequirements
        import paddleocr  # pylint: disable=import-error

        # The engine only takes a GPU on/off flag, derived from the resolved device.
        use_gpu = self._device == "cuda"
        self._paddle_ocr_module = paddleocr.PaddleOCR(
            lang="ch",
            use_angle_cls=True,
            use_gpu=use_gpu,
        )
        logger.success("PaddleOCR loaded successfully")

    def _paddleocr_process(self, img: Image.Image) -> str:
        """Run the engine on ``img`` and return the accepted text.

        Only the first element of the engine's result is inspected. Entries
        whose score is above ``config.ocr_search.ocr_min_confidence`` are
        concatenated without a separator. Returns ``""`` when that first
        element is empty or falsy.
        """
        pages = self._paddle_ocr_module.ocr(np.array(img), cls=True)
        first_page = pages[0]
        if not first_page:
            return ""
        accepted = []
        for entry in first_page:
            # entry[1] is used as a (text, score) pair.
            if entry[1][1] > config.ocr_search.ocr_min_confidence:
                accepted.append(entry[1][0])
        return "".join(accepted)

    def ocr_interface(self, img: Image.Image, need_preprocess=True) -> str:
        """Recognise text in ``img``, logging the elapsed time.

        :param img: Image to run OCR on.
        :param need_preprocess: When true, ``_image_preprocess`` is applied first.
        :return: The concatenated recognised text.
        """
        started = time()
        logger.info("Processing text with PaddleOCR...")
        if need_preprocess:
            img = self._image_preprocess(img)
        text = self._paddleocr_process(img)
        logger.success("OCR processed done. Time elapsed: {:.2f}s", time() - started)
        return text


class DisabledOCRService(OCRService):
    """Stand-in backend that loads no OCR model and rejects every OCR request."""

    def __init__(self):
        super().__init__()
        notice = "OCR search is disabled. Skipping OCR model loading."
        logger.warning(notice)

    def ocr_interface(self, img: Image.Image, need_preprocess=True) -> str:
        """Always fail: no OCR engine is loaded in this service.

        :raises NotImplementedError: On every call.
        """
        reason = "OCR module is disabled. Consider enable it in config."
        raise NotImplementedError(reason)