File size: 1,216 Bytes
113bcef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e5923d4
113bcef
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import os
import gradio as gr
import omegaconf
import torch
from vietocr.model.transformerocr import VietOCR
from vietocr.model.vocab import Vocab
from vietocr.translate import translate, process_input

# Example inputs for the Gradio UI: the regular, non-hidden files in `examples/`.
# os.listdir() returns entries in arbitrary order, so sort them to keep the
# example gallery stable between runs. Entries are plain filenames, so no
# further parsing is needed to build the path.
# NOTE(review): assumes every remaining file is an image gr.Image can open.
examples_data = [os.path.join('examples', name)
                 for name in sorted(os.listdir('examples'))
                 if not name.startswith('.')]
examples_data = [path for path in examples_data if os.path.isfile(path)]

# Load the model hyper-parameters and convert them to plain Python containers
# (resolve=True expands any ${...} interpolations in the YAML).
config = omegaconf.OmegaConf.load("vgg-seq2seq.yaml")
config = omegaconf.OmegaConf.to_container(config, resolve=True)

# Build the character vocabulary and the VietOCR network described by the config.
vocab = Vocab(config['vocab'])
model = VietOCR(len(vocab),
        config['backbone'],
        config['cnn'],
        config['transformer'],
        config['seq_modeling'])
# Map the stored tensors to CPU so the app also runs on machines without a GPU.
model.load_state_dict(torch.load('train_old.pth', map_location=torch.device('cpu')))
# This process only runs inference: switch dropout / batch-norm layers to
# evaluation mode explicitly instead of relying on the decoding helper to do it.
model.eval()

def predict(inp):
    """Recognise the handwritten text in a single image.

    Args:
        inp: The PIL image delivered by the ``gr.Image(type='pil')`` input,
            or ``None`` when the user submits without uploading an image.

    Returns:
        The decoded text string, or ``''`` when no image was given.
    """
    if inp is None:
        # Nothing to recognise; returning early avoids a crash inside
        # process_input() when the image field is empty.
        return ''
    # Preprocess the image with the size settings (height, min/max width)
    # from the dataset section of the config, into the form translate() expects.
    img = process_input(inp, config['dataset']['image_height'],
                        config['dataset']['image_min_width'],
                        config['dataset']['image_max_width'])
    # NOTE(review): `[0]` assumes the first element of translate()'s result is
    # the id sequence for this single image; confirm against the pinned
    # vietocr version, since its return value differs between releases.
    ids = translate(img, model)[0].tolist()
    return vocab.decode(ids)

# Web UI: one image upload in, recognised text out, with the files collected
# in `examples_data` offered as clickable examples.
demo = gr.Interface(fn=predict,
                    title='Vietnamese Handwriting Recognition',
                    inputs=gr.Image(type='pil'),
                    outputs=gr.Text(),
                    examples=examples_data,
)

# Start the server only when this file is run as a script (e.g. `python app.py`),
# so importing the module does not launch the web app as a side effect.
if __name__ == "__main__":
    demo.launch()