DocTRv2 / vietocr /model /transformerocr.py
hantech's picture
Upload 28 files
0667c13 verified
raw
history blame
1.62 kB
from vietocr.model.backbone.cnn import CNN
from vietocr.model.seqmodel.transformer import LanguageTransformer
from vietocr.model.seqmodel.seq2seq import Seq2Seq
from vietocr.model.seqmodel.convseq2seq import ConvSeq2Seq
from torch import nn
class VietOCR(nn.Module):
def __init__(self, vocab_size,
backbone,
cnn_args,
transformer_args, seq_modeling='transformer'):
super(VietOCR, self).__init__()
self.cnn = CNN(backbone, **cnn_args)
self.seq_modeling = seq_modeling
if seq_modeling == 'transformer':
self.transformer = LanguageTransformer(vocab_size, **transformer_args)
elif seq_modeling == 'seq2seq':
self.transformer = Seq2Seq(vocab_size, **transformer_args)
elif seq_modeling == 'convseq2seq':
self.transformer = ConvSeq2Seq(vocab_size, **transformer_args)
else:
raise('Not Support Seq Model')
def forward(self, img, tgt_input, tgt_key_padding_mask):
"""
Shape:
- img: (N, C, H, W)
- tgt_input: (T, N)
- tgt_key_padding_mask: (N, T)
- output: b t v
"""
src = self.cnn(img)
if self.seq_modeling == 'transformer':
outputs = self.transformer(src, tgt_input, tgt_key_padding_mask=tgt_key_padding_mask)
elif self.seq_modeling == 'seq2seq':
outputs = self.transformer(src, tgt_input)
elif self.seq_modeling == 'convseq2seq':
outputs = self.transformer(src, tgt_input)
return outputs