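r"""Fairseq task for transcribing text images (TrOCR-style text recognition).

Importing this module registers the ``text_recognition`` task with fairseq.
A typical invocation might look like the sketch below (illustrative only:
the data path and hyper-parameter values are placeholders, not values
shipped with this repository)::

    fairseq-train /path/to/data \
        --task text_recognition --data-type SROIE --preprocess DA2 \
        --input-size 384 --decoder-pretrained roberta
"""
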
import logging
import os

import torch

from fairseq import search
from fairseq.data import Dictionary, encoders  # noqa: F401 (importing encoders registers the BPE classes)
from fairseq.tasks import LegacyFairseqTask, register_task

try:
    from .data import SROIETextRecognitionDataset, Receipt53KDataset, SyntheticTextRecognitionDataset
    from .data_aug import build_data_aug, OptForDataAugment, DataAugment
except ImportError:
    from data import SROIETextRecognitionDataset, Receipt53KDataset, SyntheticTextRecognitionDataset
    from data_aug import build_data_aug, OptForDataAugment, DataAugment


logger = logging.getLogger(__name__)


@register_task('text_recognition')
class TextRecognitionTask(LegacyFairseqTask):
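    """Image-to-text recognition task (TrOCR-style OCR) for fairseq.

    The encoder consumes images (there is no source-side dictionary), and
    the decoder generates transcriptions over ``target_dict``, which is
    loaded from a local path or URL, or chosen to match the pretrained
    decoder (unilm or roberta).
    """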
    
    @staticmethod
    def add_args(parser):
        parser.add_argument('data', metavar='DIR',
                            help='path to the data directory')

        parser.add_argument('--reset-dictionary', action='store_true',
                            help='reset the dictionary and the related parameters')
        parser.add_argument('--adapt-dictionary', action='store_true',
                            help='adapt the dictionary and the related parameters')
        parser.add_argument('--adapt-encoder-pos-embed', action='store_true',
                            help='adapt the encoder positional embeddings')

        parser.add_argument('--add-empty-sample', action='store_true',
                            help='add empty samples to the dataset (for multilingual datasets)')
        parser.add_argument('--preprocess', default='DA2', type=str,
                            help='the image preprocess method (DA2|RandAugment)')
        parser.add_argument('--decoder-pretrained', default=None, type=str,
                            help='load pretrained decoder parameters (unilm or a roberta variant)')
        parser.add_argument('--decoder-pretrained-url', default=None, type=str,
                            help='the checkpoint URL for decoder pretraining (only unilm for now)')
        parser.add_argument('--dict-path-or-url', default=None, type=str,
                            help='the local path or URL of the dictionary file')
        parser.add_argument('--input-size', type=int, nargs='+', required=True,
                            help='image input size as H W (a single value gives a square input)')
        parser.add_argument('--data-type', type=str, default='SROIE',
                            help='the dataset type used for the task (SROIE|Receipt53K|STR)')

        # Augmentation parameters
        parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                            help='Color jitter factor (default: 0.4)')
        # The --aa string uses timm's AutoAugment/RandAugment config format:
        # 'rand-m9-mstd0.5-inc1' means RandAugment, magnitude 9, magnitude
        # noise std 0.5, with severity increasing alongside magnitude.
        parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                            help='Use AutoAugment policy, e.g. "v0" or "original" '
                                 '(default: rand-m9-mstd0.5-inc1)')
        parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
        parser.add_argument('--train-interpolation', type=str, default='bicubic',
                            help='Training interpolation (random|bilinear|bicubic; default: bicubic)')

        parser.add_argument('--repeated-aug', action='store_true')
        parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
        parser.set_defaults(repeated_aug=True)

        # * Random Erase params
        parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                            help='Random erase prob (default: 0.25)')
        parser.add_argument('--remode', type=str, default='pixel',
                            help='Random erase mode (default: "pixel")')
        parser.add_argument('--recount', type=int, default=1,
                            help='Random erase count (default: 1)')
        parser.add_argument('--resplit', action='store_true', default=False,
                            help='Do not random erase the first (clean) augmentation split')

    @classmethod
    def setup_task(cls, args, **kwargs):
        import io
        import urllib.request

        def load_dict_from_url(url):
            # Dictionary.load accepts a file-like object, so wrap the
            # downloaded content in a StringIO buffer.
            logger.info('Load dictionary from {}'.format(url))
            dict_content = urllib.request.urlopen(url).read().decode()
            return Dictionary.load(io.StringIO(dict_content))

        if getattr(args, "dict_path_or_url", None) is not None:
            if args.dict_path_or_url.startswith('http'):
                target_dict = load_dict_from_url(args.dict_path_or_url)
            else:
                target_dict = Dictionary.load(args.dict_path_or_url)
        elif getattr(args, "decoder_pretrained", None) is not None:
            if args.decoder_pretrained == 'unilm':
                target_dict = load_dict_from_url(
                    'https://layoutlm.blob.core.windows.net/trocr/dictionaries/unilm3.dict.txt')
            elif args.decoder_pretrained.startswith('roberta'):
                target_dict = load_dict_from_url(
                    'https://layoutlm.blob.core.windows.net/trocr/dictionaries/gpt2_with_mask.dict.txt')
            else:
                raise ValueError('Unknown decoder_pretrained: {}'.format(args.decoder_pretrained))
        else:
            raise ValueError('Either dict_path_or_url or decoder_pretrained should be set.')

        logger.info('[label] load dictionary: {} types'.format(len(target_dict)))

        return cls(args, target_dict)

    def __init__(self, args, target_dict):
        super().__init__(args)
        self.args = args
        self.data_dir = args.data
        self.target_dict = target_dict
        # In distributed training, let rank 0 build the BPE first (it may
        # download vocabulary files); the other ranks wait at the barrier
        # and then pick up the cached files.
        if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0':
            torch.distributed.barrier()
        self.bpe = self.build_bpe(args)
        if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] == '0':
            torch.distributed.barrier()

    def load_dataset(self, split, **kwargs):
        # Normalize --input-size to an (H, W) tuple; a single value means a
        # square input.
        input_size = self.args.input_size
        if isinstance(input_size, list):
            if len(input_size) == 1:
                input_size = (input_size[0], input_size[0])
            else:
                input_size = tuple(input_size)
        elif isinstance(input_size, int):
            input_size = (input_size, input_size)

        logger.info('The input size is {}, the height is {} and the width is {}'.format(
            input_size, input_size[0], input_size[1]))

        if self.args.preprocess == 'DA2':
            tfm = build_data_aug(input_size, mode=split)
        elif self.args.preprocess == 'RandAugment':
            opt = OptForDataAugment(eval=(split != 'train'), isrand_aug=True,
                                    imgW=input_size[1], imgH=input_size[0],
                                    intact_prob=0.5, augs_num=3, augs_mag=None)
            tfm = DataAugment(opt)
        else:
            raise ValueError('Undefined image preprocess method: {} '
                             '(expected DA2 or RandAugment)'.format(self.args.preprocess))

        # Load the dataset for the requested split.
        if self.args.data_type == 'SROIE':
            root_dir = os.path.join(self.data_dir, split)
            self.datasets[split] = SROIETextRecognitionDataset(root_dir, tfm, self.bpe, self.target_dict)
        elif self.args.data_type == 'Receipt53K':
            gt_path = os.path.join(self.data_dir, 'gt_{}.txt'.format(split))
            self.datasets[split] = Receipt53KDataset(gt_path, tfm, self.bpe, self.target_dict)
        elif self.args.data_type == 'STR':
            gt_path = os.path.join(self.data_dir, 'gt_{}.txt'.format(split))
            self.datasets[split] = SyntheticTextRecognitionDataset(gt_path, tfm, self.bpe, self.target_dict)
        else:
            raise ValueError('Undefined dataset type: ' + self.args.data_type)
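
    # Typical call sequence, driven by fairseq's training loop (shown only
    # as an illustration; fairseq invokes these entry points itself):
    #
    #   task = TextRecognitionTask.setup_task(args)
    #   task.load_dataset('train')
    #   train_set = task.datasets['train']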
    
    @property
    def source_dictionary(self):
        # The encoder consumes images, so there is no source-side dictionary.
        return None

    @property
    def target_dictionary(self):
        return self.target_dict

    def build_generator(
        self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None
    ):

        if getattr(args, "score_reference", False):
            from fairseq.sequence_scorer import SequenceScorer

            return SequenceScorer(
                self.target_dictionary,
                compute_alignment=getattr(args, "print_alignment", False),
            )

        from fairseq.sequence_generator import (
            SequenceGenerator,
            SequenceGeneratorWithAlignment,
        )
        try:
            from .generator import TextRecognitionGenerator
        except ImportError:
            from generator import TextRecognitionGenerator

        try:
            # Only available in Facebook-internal fairseq builds.
            from fairseq.fb_sequence_generator import FBSequenceGenerator
        except ModuleNotFoundError:
            pass

        # Choose search strategy. Defaults to Beam Search.
        sampling = getattr(args, "sampling", False)
        sampling_topk = getattr(args, "sampling_topk", -1)
        sampling_topp = getattr(args, "sampling_topp", -1.0)
        diverse_beam_groups = getattr(args, "diverse_beam_groups", -1)
        diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5)
        match_source_len = getattr(args, "match_source_len", False)
        diversity_rate = getattr(args, "diversity_rate", -1)
        constrained = getattr(args, "constraints", False)
        prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None)
        if (
            sum(
                int(cond)
                for cond in [
                    sampling,
                    diverse_beam_groups > 0,
                    match_source_len,
                    diversity_rate > 0,
                ]
            )
            > 1
        ):
            raise ValueError("Provided Search parameters are mutually exclusive.")
        assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
        assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"

        if sampling:
            search_strategy = search.Sampling(
                self.target_dictionary, sampling_topk, sampling_topp
            )
        elif diverse_beam_groups > 0:
            search_strategy = search.DiverseBeamSearch(
                self.target_dictionary, diverse_beam_groups, diverse_beam_strength
            )
        elif match_source_len:
            # this is useful for tagging applications where the output
            # length should match the input length, so we hardcode the
            # length constraints for simplicity
            search_strategy = search.LengthConstrainedBeamSearch(
                self.target_dictionary,
                min_len_a=1,
                min_len_b=0,
                max_len_a=1,
                max_len_b=0,
            )
        elif diversity_rate > -1:
            search_strategy = search.DiverseSiblingsSearch(
                self.target_dictionary, diversity_rate
            )
        elif constrained:
            search_strategy = search.LexicallyConstrainedBeamSearch(
                self.target_dictionary, args.constraints
            )
        elif prefix_allowed_tokens_fn:
            search_strategy = search.PrefixConstrainedBeamSearch(
                self.target_dictionary, prefix_allowed_tokens_fn
            )
        else:
            search_strategy = search.BeamSearch(self.target_dictionary)

        extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
        if seq_gen_cls is None:
            if getattr(args, "print_alignment", False):
                seq_gen_cls = SequenceGeneratorWithAlignment
                extra_gen_cls_kwargs["print_alignment"] = args.print_alignment
            elif getattr(args, "fb_seq_gen", False):
                seq_gen_cls = FBSequenceGenerator
            else:
                seq_gen_cls = TextRecognitionGenerator

        return seq_gen_cls(
            models,
            self.target_dictionary,
            beam_size=getattr(args, "beam", 5),
            max_len_a=getattr(args, "max_len_a", 0),
            max_len_b=getattr(args, "max_len_b", 200),
            min_len=getattr(args, "min_len", 1),
            normalize_scores=(not getattr(args, "unnormalized", False)),
            len_penalty=getattr(args, "lenpen", 1),
            unk_penalty=getattr(args, "unkpen", 0),
            temperature=getattr(args, "temperature", 1.0),
            match_source_len=getattr(args, "match_source_len", False),
            no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
            search_strategy=search_strategy,
            **extra_gen_cls_kwargs,
        )

    def filter_indices_by_size(
        self, indices, dataset, max_positions=None, ignore_invalid_inputs=False
    ):
        # Inputs are fixed-size images, so no sample needs to be filtered
        # out by length.
        return indices
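

# Minimal inference sketch (an illustration, not part of this module's API:
# the checkpoint path is a placeholder, and the exact objects returned by
# load_model_ensemble_and_task vary across fairseq versions):
#
#   from fairseq import checkpoint_utils
#   models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
#       ['/path/to/checkpoint.pt'])
#   generator = task.build_generator(models, cfg.generation)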