import os

import easyocr
import gradio as gr
import numpy
import omegaconf
import torch
from PIL import Image

from vietocr.model.transformerocr import VietOCR
from vietocr.model.vocab import Vocab
from vietocr.translate import translate, process_input

# EasyOCR detects text regions; VietOCR then recognises the Vietnamese text
# inside each detected region.
reader = easyocr.Reader(['vi'])

# Collect example images shipped with the app (used by the commented-out
# `examples=` argument of gr.Interface below).
examples_data = [os.path.join('examples', name) for name in os.listdir('examples')]

# Build the VietOCR model from its YAML config and load pretrained weights on CPU.
config = omegaconf.OmegaConf.load("vgg-seq2seq.yaml")
config = omegaconf.OmegaConf.to_container(config, resolve=True)
vocab = Vocab(config['vocab'])
model = VietOCR(len(vocab),
                config['backbone'],
                config['cnn'],
                config['transformer'],
                config['seq_modeling'])
model.load_state_dict(torch.load('train_old.pth', map_location=torch.device('cpu')))


def viet_ocr_predict(inp):
    """Run VietOCR on a single cropped PIL image and return the decoded text."""
    img = process_input(inp,
                        config['dataset']['image_height'],
                        config['dataset']['image_min_width'],
                        config['dataset']['image_max_width'])
    out = translate(img, model)[0].tolist()
    return vocab.decode(out)


def predict(filepath):
    """Detect text regions with EasyOCR, then recognise each crop with VietOCR."""
    bounds = reader.readtext(filepath)
    inp = numpy.asarray(Image.open(filepath))
    height, width = inp.shape[:2]  # numpy images are (rows, cols, channels)

    texts = ''
    for bbox, text, prob in bounds:
        # Each bounding box is four (x, y) corner points:
        # top-left, top-right, bottom-right, bottom-left.
        (tl, tr, br, bl) = bbox
        xs = [int(tl[0]), int(tr[0]), int(br[0]), int(bl[0])]
        ys = [int(tl[1]), int(tr[1]), int(br[1]), int(bl[1])]

        # Clamp the crop rectangle to the image bounds.
        min_x = max(0, min(xs))
        max_x = min(width - 1, max(xs))
        min_y = max(0, min(ys))
        max_y = min(height - 1, max(ys))

        # Crop the region of interest (ROI) and recognise it with VietOCR;
        # fall back to EasyOCR's own transcription if the crop fails.
        try:
            cropped_image = Image.fromarray(inp[min_y:max_y, min_x:max_x, :])
            out = viet_ocr_predict(cropped_image)
        except Exception:
            out = text
        print(out)
        texts = texts + '\t' + out
    return texts


gr.Interface(fn=predict,
             title='Vietnamese Handwriting Recognition',
             inputs=gr.Image(type='filepath'),
             outputs=gr.Text(),
             #examples=examples_data,
             ).launch()
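
# A minimal smoke test of the two-stage pipeline, assuming at least one image
# exists under examples/ (the directory listing below is only an illustration).
# Run it in a separate session with the gr.Interface(...).launch() call above
# disabled, since launch() blocks:
#
#   sample = os.path.join('examples', os.listdir('examples')[0])
#   print(predict(sample))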