# File size: 2,247 Bytes
# b32257c 9488f66 b32257c cf95585 b32257c 9488f66 b32257c fa8925c b32257c 113bcef
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
import os
import gradio as gr
import omegaconf
import torch
import cv2
import easyocr
from PIL import Image
from vietocr.model.transformerocr import VietOCR
from vietocr.model.vocab import Vocab
from vietocr.translate import translate, process_input
# Text-region detector, created with the 'vi' language list.
reader = easyocr.Reader(['vi'])

# Paths offered as clickable examples in the UI: one per entry of the
# 'examples' directory.
# NOTE(review): the split on '\t' is a no-op for ordinary file names; it looks
# like a leftover from parsing an annotation file -- confirm before removing.
examples_data = [
    os.path.join('examples', entry.split('\t')[0])
    for entry in os.listdir('examples')
]

# Load the model configuration from YAML and turn it into plain Python
# containers with interpolations resolved.
config = omegaconf.OmegaConf.to_container(
    omegaconf.OmegaConf.load("vgg-seq2seq.yaml"),
    resolve=True,
)

# Build the vocabulary and the recognition network from the configuration,
# then restore the trained weights onto the CPU.
vocab = Vocab(config['vocab'])
model = VietOCR(
    len(vocab),
    config['backbone'],
    config['cnn'],
    config['transformer'],
    config['seq_modeling'],
)
model.load_state_dict(
    torch.load('train_old.pth', map_location=torch.device('cpu'))
)
def viet_ocr_predict(inp):
    """Recognise the text in a single image with the module-level VietOCR model.

    Args:
        inp: Image to recognise; it is handed unchanged to
            ``process_input`` (presumably a PIL image -- confirm against
            the vietocr API).

    Returns:
        The result of ``vocab.decode`` applied to the first sequence
        produced by ``translate``.
    """
    dataset_cfg = config['dataset']
    # Prepare the image using the height and width limits from the config.
    model_input = process_input(
        inp,
        dataset_cfg['image_height'],
        dataset_cfg['image_min_width'],
        dataset_cfg['image_max_width'],
    )
    # Only the first element of the translate() result is decoded.
    token_ids = translate(model_input, model)[0].tolist()
    return vocab.decode(token_ids)
def predict(filepath):
    """Detect text regions in an image and recognise each one.

    ``reader.readtext`` supplies the detected regions; for each region the
    axis-aligned bounding rectangle is cropped from the image and passed to
    ``viet_ocr_predict``.

    Args:
        filepath: Path of the image file to process.

    Returns:
        A single string in which every recognised fragment is preceded by a
        tab character, in detection order. The string is empty when no usable
        region is found.

    Raises:
        ValueError: If the image at ``filepath`` cannot be read.
    """
    inp = cv2.imread(filepath)
    if inp is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise ValueError(f"Could not read image: {filepath!r}")

    # OpenCV images are NumPy arrays indexed (row, column), so the shape is
    # (height, width, channels) -- not (width, height, channels).
    height, width = inp.shape[:2]

    fragments = []
    for bbox, _text, _prob in reader.readtext(filepath):
        xs = [int(point[0]) for point in bbox]
        ys = [int(point[1]) for point in bbox]

        # Clamp the bounding rectangle to the image. Slice ends are
        # exclusive, so the upper limits are the image sizes themselves.
        min_x = max(0, min(xs))
        max_x = min(width, max(xs))
        min_y = max(0, min(ys))
        max_y = min(height, max(ys))

        # Skip degenerate boxes: an empty crop cannot be recognised.
        if max_x <= min_x or max_y <= min_y:
            continue

        # cv2.imread yields BGR channel order, while PIL expects RGB.
        region = cv2.cvtColor(inp[min_y:max_y, min_x:max_x], cv2.COLOR_BGR2RGB)
        fragments.append(viet_ocr_predict(Image.fromarray(region)))

    # Output format kept as before: each fragment is prefixed with a tab.
    return ''.join('\t' + fragment for fragment in fragments)
# Build the web UI: the uploaded image reaches `predict` as a file path
# (type='filepath'), the returned string is shown in a text box, and the
# entries of `examples_data` are offered as examples.
demo = gr.Interface(
    fn=predict,
    title='Vietnamese Handwriting Recognition',
    inputs=gr.Image(type='filepath'),
    outputs=gr.Text(),
    examples=examples_data,
)

# Start serving the interface.
demo.launch()