# vietocr/translate.py
import torch
import numpy as np
import math
from PIL import Image
from torch.nn.functional import softmax

def translate(img, model, max_seq_length=128, sos_token=1, eos_token=2):
    """Greedily decode a batch of images. img: BxCxHxW."""
    model.eval()
    device = img.device

    with torch.no_grad():
        # Encode the image batch once; the decoder reuses this memory.
        src = model.cnn(img)
        memory = model.transformer.forward_encoder(src)

        # Start every sequence in the batch with the <sos> token.
        translated_sentence = [[sos_token] * len(img)]
        max_length = 0

        # Decode one token per step until every sequence has emitted
        # <eos> or the maximum length is reached.
        while max_length <= max_seq_length and not all(
            np.any(np.asarray(translated_sentence).T == eos_token, axis=1)
        ):
            # Keep the decoder input on the same device as the image.
            tgt_inp = torch.LongTensor(translated_sentence).to(device)
            output, memory = model.transformer.forward_decoder(tgt_inp, memory)
            output = softmax(output, dim=-1)

            # topk(5) is computed, but only the best candidate at the last
            # time step is kept, i.e. plain greedy decoding.
            _, indices = torch.topk(output, 5)
            indices = indices[:, -1, 0]
            indices = indices.tolist()

            translated_sentence.append(indices)
            max_length += 1

        # Rows become per-image token id sequences (B x seq_len).
        translated_sentence = np.asarray(translated_sentence).T

    return translated_sentence
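

# A minimal decoding sketch for translate()'s output. The `vocab` argument
# is an assumption: it stands for an object with a decode(list[int]) -> str
# method, like the Vocab class in the upstream VietOCR project; nothing in
# this file defines it.
def decode_batch(translated_sentence, vocab):
    # Each row is one image's token ids, including the <sos>/<eos> markers.
    return [vocab.decode(row.tolist()) for row in translated_sentence]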

def resize(w, h, expected_height, image_min_width, image_max_width):
    # Scale the width so the aspect ratio is preserved at the target height.
    new_w = int(expected_height * float(w) / float(h))

    # Round the width up to a multiple of 10, then clamp it to the
    # allowed range.
    round_to = 10
    new_w = math.ceil(new_w / round_to) * round_to
    new_w = max(new_w, image_min_width)
    new_w = min(new_w, image_max_width)

    return new_w, expected_height
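

# Worked example (illustrative numbers): a 400x64 crop resized to height 32
# keeps its aspect ratio, so new_w = int(32 * 400 / 64) = 200, already a
# multiple of 10 and inside [32, 512]; a very wide crop gets clamped:
#
#     assert resize(400, 64, 32, 32, 512) == (200, 32)
#     assert resize(4000, 64, 32, 32, 512) == (512, 32)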

def process_image(image, image_height, image_min_width, image_max_width):
    img = image.convert('RGB')

    w, h = img.size
    new_w, image_height = resize(w, h, image_height, image_min_width, image_max_width)
    img = img.resize((new_w, image_height), Image.Resampling.LANCZOS)

    # HWC -> CHW, then scale pixel values from [0, 255] to [0, 1].
    img = np.asarray(img).transpose(2, 0, 1)
    img = img / 255
    return img
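

# Quick shape check (illustrative): a 200x50 image at target height 32 gets
# width int(32 * 200 / 50) = 128, rounded up to 130, so the result is a
# CHW float array in [0, 1]:
#
#     demo = process_image(Image.new('RGB', (200, 50)), 32, 32, 512)
#     assert demo.shape == (3, 32, 130)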

def process_input(image, image_height, image_min_width, image_max_width):
    img = process_image(image, image_height, image_min_width, image_max_width)
    # Add a batch dimension: CxHxW -> 1xCxHxW.
    img = img[np.newaxis, ...]
    img = torch.FloatTensor(img)
    return img
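

# End-to-end usage sketch. `model` and `vocab` are placeholders for however
# the surrounding app builds the VietOCR model and vocabulary; the file name
# is hypothetical too.
#
#     image = Image.open('line.png')
#     img = process_input(image, image_height=32,
#                         image_min_width=32, image_max_width=512)
#     s = translate(img, model)            # B x seq_len token id matrix
#     texts = decode_batch(s, vocab)       # see decode_batch above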