Spaces:
Running
Running
import asyncio | |
import functools | |
import uuid | |
from paddleocr import PaddleOCR, draw_ocr | |
from PIL import Image | |
import gradio as gr | |
# Per-language settings, keyed by PaddleOCR language code. "num_workers" is
# the number of model instances created for that language.
LANG_CONFIG = {
    lang_code: {"num_workers": worker_count}
    for lang_code, worker_count in (
        ("ch", 4),
        ("en", 4),
        ("fr", 1),
        ("german", 1),
        ("korean", 1),
        ("japan", 1),
    )
}

# Passed to gr.Interface as its concurrency_limit.
CONCURRENCY_LIMIT = 8
class PaddleOCRModelWrapper(object):
    """Pair an OCR model with a name and an availability state.

    The state is a plain string. It starts as "IDLE" and is otherwise
    whatever the owner of the wrapper assigns to ``state``.
    """

    def __init__(self, model, name=None):
        """Wrap *model*.

        Args:
            model: Object exposing an ``ocr(**kwargs)`` method.
            name: Identifier for this wrapper. When omitted (or falsy) a
                random UUID4 string is generated.
        """
        super().__init__()
        self._model = model
        self._name = name or self._get_random_name()
        self._state = "IDLE"

    # ``name`` and ``state`` must be properties: the rest of the file reads
    # and assigns them as attributes (``model.name``, ``model.state == "IDLE"``,
    # ``model.state = "RUNNING"``). As plain methods the comparison against
    # "IDLE" could never be true and the getter was shadowed by the setter.
    @property
    def name(self):
        """Return the identifier of this wrapper."""
        return self._name

    @property
    def state(self):
        """Return the current state string."""
        return self._state

    @state.setter
    def state(self, state):
        """Replace the current state string."""
        self._state = state

    def infer(self, **kwargs):
        """Run OCR on one image and return the annotated visualisation.

        Args:
            **kwargs: Forwarded unchanged to the wrapped model's ``ocr``
                method. Must contain ``img``, the path of the input image.

        Returns:
            Whatever ``draw_ocr`` returns for the image, the detected boxes,
            the recognised texts and their scores.

        Raises:
            KeyError: If ``img`` is not supplied.
        """
        img_path = kwargs["img"]
        # Only the first entry of the result is used (a single image is
        # passed in). NOTE(review): some PaddleOCR versions reportedly give
        # None instead of an empty list when no text is found - confirm
        # against the pinned version. Treat it as "no lines" so the
        # comprehensions below do not raise TypeError.
        result = self._model.ocr(**kwargs)[0] or []
        # Close the file handle as soon as the pixel data has been copied.
        with Image.open(img_path) as source:
            image = source.convert("RGB")
        boxes = [entry[0] for entry in result]
        txts = [entry[1][0] for entry in result]
        scores = [entry[1][1] for entry in result]
        im_show = draw_ocr(image, boxes, txts, scores,
                           font_path="./simfang.ttf")
        return im_show

    def _get_random_name(self):
        """Return a fresh random UUID4 rendered as a string."""
        return str(uuid.uuid4())
class PaddleOCRModelManager(object):
    """Hold a fixed number of wrapped models and lend them to requests.

    A request takes the first wrapper whose state is "IDLE", marks it
    "RUNNING" while inference executes in the event loop's default executor,
    and marks it "IDLE" again afterwards.
    """

    def __init__(self, num_models, model_factory, *, polling_interval=0.1):
        """Create the manager and build its models immediately.

        Args:
            num_models: How many wrapped models to create.
            model_factory: Zero-argument callable returning a raw model.
            polling_interval: Seconds to sleep between checks while every
                model is busy.
        """
        super().__init__()
        self._num_models = num_models
        self._model_factory = model_factory
        self._polling_interval = polling_interval
        self._models = {}
        self.new_models()

    def new_models(self):
        """Discard all current models and create ``num_models`` new ones."""
        self._models.clear()
        for _ in range(self._num_models):
            wrapper = self._new_model()
            self._models[wrapper.name] = wrapper

    async def infer(self, **kwargs):
        """Wait for an idle model, run inference on it and return the result.

        Args:
            **kwargs: Forwarded to the chosen model's ``infer`` method.

        Raises:
            RuntimeError: If the manager holds no models at all.
        """
        worker = self._get_available_model()
        while not worker:
            # Every model is busy: yield to the event loop, then look again.
            await asyncio.sleep(self._polling_interval)
            worker = self._get_available_model()

        worker.state = "RUNNING"
        # NOTE: Optimistic approach - the model is assumed to still be usable
        # even if inference fails, so it always goes back to "IDLE".
        try:
            return await self._new_inference_task(worker, **kwargs)
        finally:
            worker.state = "IDLE"

    def _new_model(self):
        """Build one raw model with the factory and wrap it."""
        return PaddleOCRModelWrapper(self._model_factory())

    def _get_available_model(self):
        """Return the first wrapper in state "IDLE", or None if all are busy.

        Raises:
            RuntimeError: If there are no models to choose from.
        """
        if not self._models:
            raise RuntimeError("No living models")
        idle_models = (
            candidate
            for candidate in self._models.values()
            if candidate.state == "IDLE"
        )
        return next(idle_models, None)

    def _new_inference_task(self, model, **kwargs):
        """Schedule ``model.infer(**kwargs)`` on the default executor.

        Returns the awaitable produced by ``run_in_executor``.
        """
        loop = asyncio.get_running_loop()
        job = functools.partial(model.infer, **kwargs)
        return loop.run_in_executor(None, job)
def create_model(lang):
    """Build a PaddleOCR instance for *lang* with angle classification on and GPU off."""
    ocr_options = {"lang": lang, "use_angle_cls": True, "use_gpu": False}
    return PaddleOCR(**ocr_options)
# One model manager per configured language, created at import time.
model_managers = {}
for lang, config in LANG_CONFIG.items():
    # Bind the language now so the manager can call the factory with no arguments.
    factory = functools.partial(create_model, lang=lang)
    model_manager = PaddleOCRModelManager(config["num_workers"], factory)
    model_managers[lang] = model_manager
async def inference(img, lang):
    """Run OCR on *img* with the manager registered for *lang*.

    Args:
        img: Input image, forwarded as the ``img`` keyword argument.
        lang: Key into ``model_managers``.

    Returns:
        The result of the selected manager's ``infer`` coroutine.

    Raises:
        KeyError: If no manager is registered for *lang*.
    """
    return await model_managers[lang].infer(img=img, cls=True)
# Title and Markdown description passed to the Gradio interface.
title = "PaddleOCR"
description = """
- Gradio demo for PaddleOCR. PaddleOCR demo supports Chinese, English, French, German, Korean and Japanese.
- To use it, simply upload your image and choose a language from the dropdown menu, or click one of the examples to load them. Read more at the links below.
- [Docs](https://paddlepaddle.github.io/PaddleOCR/), [Github Repository](https://github.com/PaddlePaddle/PaddleOCR).
"""

# Example inputs: [image file, language code].
examples = [
    ["en_example.jpg", "en"],
    ["cn_example.jpg", "ch"],
    ["jp_example.jpg", "japan"],
]

# Custom CSS passed to the Gradio interface.
css = ".output_image, .input_image {height: 40rem !important; width: 100% !important;}"
# Build the demo UI: an image upload plus a language dropdown in, an
# annotated image out. Then start serving it.
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(type="filepath", label="Input"),
        gr.Dropdown(choices=list(LANG_CONFIG.keys()), value="en", label="language"),
    ],
    outputs=gr.Image(type="pil", label="Output"),
    title=title,
    description=description,
    examples=examples,
    cache_examples=False,
    css=css,
    concurrency_limit=CONCURRENCY_LIMIT,
)
demo.launch(debug=False)