# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# USAGE: python get_data.py --data-root=<data_root> --data-sets=<data_sets> --num-workers=<num_workers>
# where <data_sets> can be: dev_clean, dev_other, test_clean,
# test_other, train_clean_100, train_clean_360, train_other_500 or ALL.
# You can also pass more than one data set, comma-separated:
# --data-sets=dev_clean,train_clean_100
# The shortcut "mini" expands to dev_clean,train_clean_100.

import argparse
import fnmatch
import json
import multiprocessing
import os
import subprocess
import tarfile
import urllib.request
from pathlib import Path

from tqdm import tqdm

parser = argparse.ArgumentParser(description='Download LibriTTS and create manifests')
parser.add_argument("--data-root", required=True, type=Path)
parser.add_argument("--data-sets", default="dev_clean", type=str)
parser.add_argument("--num-workers", default=4, type=int)
args = parser.parse_args()

URLS = {
    'TRAIN_CLEAN_100': "https://www.openslr.org/resources/60/train-clean-100.tar.gz",
    'TRAIN_CLEAN_360': "https://www.openslr.org/resources/60/train-clean-360.tar.gz",
    'TRAIN_OTHER_500': "https://www.openslr.org/resources/60/train-other-500.tar.gz",
    'DEV_CLEAN': "https://www.openslr.org/resources/60/dev-clean.tar.gz",
    'DEV_OTHER': "https://www.openslr.org/resources/60/dev-other.tar.gz",
    'TEST_CLEAN': "https://www.openslr.org/resources/60/test-clean.tar.gz",
    'TEST_OTHER': "https://www.openslr.org/resources/60/test-other.tar.gz",
}


def __maybe_download_file(source_url, destination_path):
    """Download source_url to destination_path unless it already exists.

    The file is written to a .tmp path first and renamed on success, so an
    interrupted download is redone on the next run.
    """
    if not destination_path.exists():
        tmp_file_path = destination_path.with_suffix('.tmp')
        urllib.request.urlretrieve(source_url, filename=str(tmp_file_path))
        tmp_file_path.rename(destination_path)


def __extract_file(filepath, data_dir):
    try:
        with tarfile.open(filepath) as tar:
            tar.extractall(data_dir)
    except Exception:
        print(f"Error while extracting {filepath}. Already extracted?")


def __process_transcript(file_path: str):
    """Build manifest entries for a single *.normalized.txt transcript."""
    entries = []
    with open(file_path, encoding="utf-8") as fin:
        text = fin.readlines()[0].strip()
        # TODO(oktai15): add normalized text via Normalizer/NormalizerWithAudio
        wav_file = file_path.replace(".normalized.txt", ".wav")
        # LibriTTS is laid out as .../<speaker>/<chapter>/<utterance>.*, so the
        # speaker id is the third path component from the end.
        speaker_id = file_path.split('/')[-3]
        assert os.path.exists(wav_file), f"{wav_file} not found!"
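        # `soxi -D` (part of SoX) prints the audio duration in seconds; SoX
        # must be installed and on PATH for this step to work.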
        # Invoke soxi with an argument list rather than through a shell, so
        # paths containing spaces or shell metacharacters are handled safely.
        duration = subprocess.check_output(["soxi", "-D", wav_file])
        entry = {
            'audio_filepath': os.path.abspath(wav_file),
            'duration': float(duration),
            'text': text,
            'speaker': int(speaker_id),
        }
        entries.append(entry)
    return entries


def __process_data(data_folder, manifest_file, num_workers):
    files = []
    entries = []

    for root, dirnames, filenames in os.walk(data_folder):
        # we will use normalized text provided by the original dataset
        for filename in fnmatch.filter(filenames, '*.normalized.txt'):
            files.append(os.path.join(root, filename))

    with multiprocessing.Pool(num_workers) as p:
        results = p.imap(__process_transcript, files)
        for result in tqdm(results, total=len(files)):
            entries.extend(result)

    with open(manifest_file, 'w') as fout:
        for m in entries:
            fout.write(json.dumps(m) + '\n')


def main():
    data_root = args.data_root
    data_sets = args.data_sets
    num_workers = args.num_workers

    if data_sets == "ALL":
        data_sets = "dev_clean,dev_other,train_clean_100,train_clean_360,train_other_500,test_clean,test_other"

    if data_sets == "mini":
        data_sets = "dev_clean,train_clean_100"

    for data_set in data_sets.split(','):
        filepath = data_root / f"{data_set}.tar.gz"

        print(f"Downloading data for {data_set}...")
        __maybe_download_file(URLS[data_set.upper()], filepath)

        print("Extracting...")
        __extract_file(str(filepath), str(data_root))

        print("Processing and building manifest.")
        # The archives extract to <data_root>/LibriTTS/<data-set-with-dashes>.
        __process_data(
            str(data_root / "LibriTTS" / data_set.replace("_", "-")),
            str(data_root / "LibriTTS" / f"{data_set}.json"),
            num_workers=num_workers,
        )


if __name__ == "__main__":
    main()
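
# A minimal sketch of how the resulting manifests can be consumed downstream.
# This reader is illustrative, not part of the original script; the field names
# simply mirror what __process_transcript writes above, and the manifest path
# is relative to --data-root.
#
#     import json
#
#     def read_manifest(manifest_file):
#         """Yield one dict per line of the JSON-lines manifest."""
#         with open(manifest_file, encoding="utf-8") as fin:
#             for line in fin:
#                 yield json.loads(line)
#
#     # e.g. total audio in the dev_clean split, in hours:
#     hours = sum(e['duration'] for e in read_manifest("LibriTTS/dev_clean.json")) / 3600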