File size: 2,307 Bytes
b32257c
 
 
 
9488f66
 
b32257c
 
 
 
 
 
 
cf95585
b32257c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9488f66
 
 
1dc65c4
 
 
b32257c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa8925c
cb848ef
b32257c
 
 
 
 
 
 
 
 
 
 
113bcef
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import gradio as gr
import omegaconf
import torch
import cv2

import easyocr
from PIL import Image

from vietocr.model.transformerocr import VietOCR
from vietocr.model.vocab import Vocab
from vietocr.translate import translate, process_input

# Text detector: easyocr locates text regions; recognition itself is done by VietOCR below.
reader = easyocr.Reader(['vi'])

# Gallery of example images for the Gradio UI.
# NOTE(review): the original applied line.split('\t')[0] to each listdir entry —
# a leftover from reading a tab-separated annotation file; directory entries are
# plain filenames, so the split was a no-op and has been removed.
examples_data = [os.path.join('examples', name) for name in os.listdir('examples')]

# Model/dataset hyper-parameters come from a YAML config, resolved to a plain dict.
config = omegaconf.OmegaConf.load("vgg-seq2seq.yaml")
config = omegaconf.OmegaConf.to_container(config, resolve=True)

# Build the VietOCR recognizer (VGG backbone + seq2seq head) and load CPU weights.
vocab = Vocab(config['vocab'])
model = VietOCR(len(vocab),
        config['backbone'],
        config['cnn'],
        config['transformer'],
        config['seq_modeling'])
model.load_state_dict(torch.load('train_old.pth', map_location=torch.device('cpu')))
def viet_ocr_predict(inp):
    """Recognize the text in a single cropped text-line image.

    Args:
        inp: image crop (as produced by cv2 slicing) containing one text region.

    Returns:
        The decoded text string predicted by the VietOCR model.
    """
    dataset_cfg = config['dataset']
    # Resize/normalize the crop to the width/height range the model was trained on.
    tensor = process_input(
        inp,
        dataset_cfg['image_height'],
        dataset_cfg['image_min_width'],
        dataset_cfg['image_max_width'],
    )
    # translate() returns a batch of token-index sequences; take the first (only) one.
    token_ids = translate(tensor, model)[0].tolist()
    return vocab.decode(token_ids)
def predict(filepath):
    """Detect text regions in an image and recognize each one with VietOCR.

    Args:
        filepath: path to the input image on disk.

    Returns:
        A single string: every recognized region's text, each preceded by a
        tab character (so the result is '' when nothing is detected).
    """
    bounds = reader.readtext(filepath)
    inp = cv2.imread(filepath)

    # BUGFIX: cv2 images have shape (height, width, channels); the original
    # unpacked them as (width, height, _) and then had a dead-on-arrival
    # "swap" branch that referenced an undefined name `img` (NameError).
    height, width, _ = inp.shape

    parts = []
    for (bbox, text, prob) in bounds:
        # bbox is four (x, y) corner points: top-left, top-right,
        # bottom-right, bottom-left (order not relied upon here).
        xs = [int(point[0]) for point in bbox]
        ys = [int(point[1]) for point in bbox]

        # Axis-aligned bounding box, clamped to the image bounds.
        min_x = max(0, min(xs))
        max_x = min(width - 1, max(xs))
        min_y = max(0, min(ys))
        max_y = min(height - 1, max(ys))

        # Crop the region of interest and run recognition on it.
        cropped_image = inp[min_y:max_y, min_x:max_x]
        parts.append(viet_ocr_predict(cropped_image))

    # Preserve original output format: each text prefixed by a tab.
    return ''.join('\t' + out for out in parts)

# Launch the Gradio demo: the user uploads an image (passed to predict() as a
# file path) and the recognized text is shown in a text box. The examples
# gallery is populated from the local 'examples' directory.
gr.Interface(fn=predict,
             title='Vietnamese Handwriting Recognition',
             inputs=gr.Image(type='filepath'),
             outputs=gr.Text(),
             examples=examples_data,
).launch()