# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script creates the data splits of the Google Text Normalization dataset
in the format described in the `text_normalization doc <https://github.com/NVIDIA/NeMo/blob/main/docs/source/nlp/text_normalization.rst>`.

USAGE Example:
1. Download the Google TN dataset from https://www.kaggle.com/google-nlu/text-normalization
2. Unzip the English subset (e.g., by running `tar zxvf en_with_types.tgz`). This creates a folder named `en_with_types`.
3. Run this script:
    python data_split.py \
        --data_dir=en_with_types/ \
        --output_dir=data_split/ \
        --lang=en

In this example, the split files will be stored in the `data_split` folder.
The folder should contain three subfolders `train`, `dev`, and `test` with `.tsv` files.
If `--add_test_full` is passed, a fourth subfolder `test_full` with untruncated test files is created as well.
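
Each `.tsv` file keeps the three-column, tab-separated format written by this script
(semiotic class, input token, normalized output), with sentences separated by an
`<eos>` row. An illustrative (not verbatim) snippet:

    PLAIN   the     <self>
    DATE    2012    twenty twelve
    PUNCT   .       sil
    <eos>   <eos>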
"""

from argparse import ArgumentParser
from os import listdir, makedirs
from os.path import isfile, join

from tqdm import tqdm

from nemo.collections.nlp.data.text_normalization import constants

# Local Constants
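# Number of test-set lines used for the results in the paper
# "RNN Approaches to Text Normalization: A Challenge"
# (the first lines of output-00099-of-00100 for each language).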
TEST_SIZE_EN = 100002
TEST_SIZE_RUS = 100007


def read_google_data(data_file: str, lang: str, split: str, add_test_full=False):
    """
    Reads a raw data file of the Google Text Normalization dataset
    (which can be downloaded from https://www.kaggle.com/google-nlu/text-normalization).

    Args:
        data_file: Path to the data file. Should be of the form output-xxxxx-of-00100
        lang: Selected language.
        split: Name of the data split ('train', 'dev', or 'test').
        add_test_full: If True, do not truncate the test data, i.e., read the whole
            test file instead of only the first TEST_SIZE_EN/TEST_SIZE_RUS lines.
    Return:
        data: List of examples, each a tuple (classes, tokens, outputs).
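
    Each non-`<eos>` line of a raw file holds three tab-separated fields
    (semiotic class, input token, normalized output); sentences are terminated
    by an `<eos>` line. An illustrative (hypothetical) returned example:
        (['PLAIN', 'CARDINAL', 'PUNCT'], ['about', '2', '.'], ['<self>', 'two', 'sil'])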
    """
    data = []
    cur_classes, cur_tokens, cur_outputs = [], [], []
    with open(data_file, 'r', encoding='utf-8') as f:
        for linectx, line in tqdm(enumerate(f)):
            es = line.strip().split('\t')
            if split == "test" and not add_test_full:
                # For the results reported in the paper "RNN Approaches to Text Normalization: A Challenge":
                # + For English, the first 100,002 lines of output-00099-of-00100 are used for the test set
                # + For Russian, the first 100,007 lines of output-00099-of-00100 are used for the test set
                if lang == constants.ENGLISH and linectx == TEST_SIZE_EN:
                    break
                if lang == constants.RUSSIAN and linectx == TEST_SIZE_RUS:
                    break
            if len(es) == 2 and es[0] == '<eos>':
                data.append((cur_classes, cur_tokens, cur_outputs))
                # Reset
                cur_classes, cur_tokens, cur_outputs = [], [], []
                continue

            # Remove _trans (for Russian)
            if lang == constants.RUSSIAN:
                es[2] = es[2].replace('_trans', '')
            # Update the current example
            assert len(es) == 3
            cur_classes.append(es[0])
            cur_tokens.append(es[1])
            cur_outputs.append(es[2])
    return data


if __name__ == '__main__':
    parser = ArgumentParser(description='Preprocess Google text normalization dataset')
    parser.add_argument('--data_dir', type=str, required=True, help='Path to folder with data')
    parser.add_argument('--output_dir', type=str, default='preprocessed', help='Path to folder with preprocessed data')
    parser.add_argument(
        '--lang', type=str, default=constants.ENGLISH, choices=constants.SUPPORTED_LANGS, help='Language'
    )
    parser.add_argument(
        '--add_test_full',
        action='store_true',
        help='If set, an additional folder test_full will be created with untruncated test files',
    )
    args = parser.parse_args()

    # Create the output dir and the split subfolders (if they do not exist)
    makedirs(join(args.output_dir, 'train'), exist_ok=True)
    makedirs(join(args.output_dir, 'dev'), exist_ok=True)
    makedirs(join(args.output_dir, 'test'), exist_ok=True)
    if args.add_test_full:
        makedirs(join(args.output_dir, 'test_full'), exist_ok=True)

    for fn in sorted(listdir(args.data_dir))[::-1]:
        fp = join(args.data_dir, fn)
        if not isfile(fp):
            continue
        if not fn.startswith('output'):
            continue

        # Determine the current split
        split_nb = int(fn.split('-')[1])
        if split_nb < 90:
            cur_split = "train"
        elif split_nb < 95:
            cur_split = "dev"
        elif split_nb == 99:
            cur_split = "test"
        else:
            # Files output-00095..output-00098 are not part of the standard splits;
            # without this guard they would be silently routed to the previously assigned split.
            continue
        data = read_google_data(data_file=fp, lang=args.lang, split=cur_split)
        # write out
        output_file = join(args.output_dir, f'{cur_split}', f'{fn}.tsv')
        print(fp)
        print(output_file)
        with open(output_file, 'w', encoding='utf-8') as output_f:
            for inst in data:
                cur_classes, cur_tokens, cur_outputs = inst
                for c, t, o in zip(cur_classes, cur_tokens, cur_outputs):
                    output_f.write(f'{c}\t{t}\t{o}\n')
                output_f.write('<eos>\t<eos>\n')

        print(f'{cur_split}_sentences: {len(data)}')

        # additionally generate full test files if needed
        if cur_split == "test" and args.add_test_full:
            data = read_google_data(data_file=fp, lang=args.lang, split=cur_split, add_test_full=True)
            # write out
            output_file = join(args.output_dir, 'test_full', f'{fn}.tsv')
            with open(output_file, 'w', encoding='utf-8') as output_f:
                for inst in data:
                    cur_classes, cur_tokens, cur_outputs = inst
                    for c, t, o in zip(cur_classes, cur_tokens, cur_outputs):
                        output_f.write(f'{c}\t{t}\t{o}\n')
                    output_f.write('<eos>\t<eos>\n')

            print(f'test_full_sentences: {len(data)}')