""" |
|
This script creates data splits of the Google Text Normalization dataset
of the format mentioned in the `text_normalization doc <https://github.com/NVIDIA/NeMo/blob/main/docs/source/nlp/text_normalization.rst>`.

USAGE Example:
1. Download the Google TN dataset from https://www.kaggle.com/google-nlu/text-normalization
2. Unzip the English subset (e.g., by running `tar zxvf en_with_types.tgz`). Then there will be a folder named `en_with_types`.
3. Run this script:
# python data_split.py \
    --data_dir=en_with_types/ \
    --output_dir=data_split/ \
    --lang=en

In this example, the split files will be stored in the `data_split` folder.
The folder should contain three subfolders `train`, `dev`, and `test` with `.tsv` files.
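
Each line in the output `.tsv` files has three tab-separated columns (semiotic class,
input token, written form), and each sentence is terminated by an `<eos>\t<eos>` line,
matching what this script writes below. A few illustrative lines (the values here are
made up and only show the layout):

    PLAIN   Hello   <self>
    DATE    2012    twenty twelve
    <eos>   <eos>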
|
""" |
|
|
|
from argparse import ArgumentParser |
|
from os import listdir, mkdir |
|
from os.path import isdir, isfile, join |
|
|
|
from tqdm import tqdm |
|
|
|
from nemo.collections.nlp.data.text_normalization import constants |
|
|
|
|
|
TEST_SIZE_EN = 100002 |
|
TEST_SIZE_RUS = 100007 |
|
|
|
|
|
def read_google_data(data_file: str, lang: str, split: str, add_test_full=False):
    """
    The function can be used to read the raw data files of the Google Text Normalization
    dataset (which can be downloaded from https://www.kaggle.com/google-nlu/text-normalization)

    Args:
        data_file: Path to the data file. Should be of the form output-xxxxx-of-00100
        lang: Selected language.
        split: Data split ('train', 'dev', or 'test').
        add_test_full: If True, do not truncate the test data, i.e., read the whole test file
            instead of only the first TEST_SIZE_EN / TEST_SIZE_RUS lines.
    Return:
        data: List of examples; each example is a tuple of three parallel lists
            (classes, tokens, written forms) for one sentence.
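            Example of a single returned element (the values are illustrative only):
                (['DATE', 'PLAIN'], ['2012', 'left'], ['twenty twelve', '<self>'])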
|
""" |
|
    data = []
    cur_classes, cur_tokens, cur_outputs = [], [], []
    with open(data_file, 'r', encoding='utf-8') as f:
        for linectx, line in tqdm(enumerate(f)):
            es = line.strip().split('\t')
            if split == "test" and not add_test_full:
                # The standard test sets only use the first TEST_SIZE_EN (English) or
                # TEST_SIZE_RUS (Russian) lines of the test file.
                if lang == constants.ENGLISH and linectx == TEST_SIZE_EN:
                    break
                if lang == constants.RUSSIAN and linectx == TEST_SIZE_RUS:
                    break
            if len(es) == 2 and es[0] == '<eos>':
                # End of the current sentence
                data.append((cur_classes, cur_tokens, cur_outputs))
                cur_classes, cur_tokens, cur_outputs = [], [], []
                continue

            # Remove the '_trans' suffix used in the Russian data files
            if lang == constants.RUSSIAN:
                es[2] = es[2].replace('_trans', '')

            assert len(es) == 3
            cur_classes.append(es[0])
            cur_tokens.append(es[1])
            cur_outputs.append(es[2])
    return data
|
|
if __name__ == '__main__':
    parser = ArgumentParser(description='Preprocess Google text normalization dataset')
    parser.add_argument('--data_dir', type=str, required=True, help='Path to folder with data')
    parser.add_argument('--output_dir', type=str, default='preprocessed', help='Path to folder with preprocessed data')
    parser.add_argument(
        '--lang', type=str, default=constants.ENGLISH, choices=constants.SUPPORTED_LANGS, help='Language'
    )
    parser.add_argument(
        '--add_test_full',
        action='store_true',
        help='If True, additional folder test_full will be created without truncation of files',
    )
    args = parser.parse_args()

    # Create the output directories
    if not isdir(args.output_dir):
        mkdir(args.output_dir)
    mkdir(args.output_dir + '/train')
    mkdir(args.output_dir + '/dev')
    mkdir(args.output_dir + '/test')
    if args.add_test_full:
        mkdir(args.output_dir + '/test_full')
|
    for fn in sorted(listdir(args.data_dir))[::-1]:
        fp = join(args.data_dir, fn)
        if not isfile(fp):
            continue
        if not fn.startswith('output'):
            continue

        # Determine the split of the current file: files 0-89 go to train, files 90-94 to dev,
        # and only file 99 is used for test. The remaining files are skipped.
        split_nb = int(fn.split('-')[1])
        if split_nb < 90:
            cur_split = "train"
        elif split_nb < 95:
            cur_split = "dev"
        elif split_nb == 99:
            cur_split = "test"
        else:
            continue
        data = read_google_data(data_file=fp, lang=args.lang, split=cur_split)
|
        output_file = join(args.output_dir, f'{cur_split}', f'{fn}.tsv')
        print(fp)
        print(output_file)
        output_f = open(output_file, 'w', encoding='utf-8')
        for inst in data:
            cur_classes, cur_tokens, cur_outputs = inst
            for c, t, o in zip(cur_classes, cur_tokens, cur_outputs):
                output_f.write(f'{c}\t{t}\t{o}\n')
            output_f.write('<eos>\t<eos>\n')
        output_f.close()

        print(f'{cur_split}_sentences: {len(data)}')
|
        # Additionally, write the untruncated test file to the test_full folder
        if cur_split == "test" and args.add_test_full:
            data = read_google_data(data_file=fp, lang=args.lang, split=cur_split, add_test_full=True)

            output_file = join(args.output_dir, 'test_full', f'{fn}.tsv')
            output_f = open(output_file, 'w', encoding='utf-8')
            for inst in data:
                cur_classes, cur_tokens, cur_outputs = inst
                for c, t, o in zip(cur_classes, cur_tokens, cur_outputs):
                    output_f.write(f'{c}\t{t}\t{o}\n')
                output_f.write('<eos>\t<eos>\n')
            output_f.close()

            print(f'test_full_sentences: {len(data)}')
|
|