camenduru's picture
thanks to NVIDIA ❤
7934b29
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# USAGE: python get_data.py --data-root=<where to put data> --data-sets=<datasets_to_download> --num-workers=<number of parallel workers>
# where <datasets_to_download> can be: dev_clean, dev_other, test_clean,
# test_other, train_clean_100, train_clean_360, train_other_500, ALL (every
# subset) or mini (dev_clean plus train_clean_100).
# You can also put more than one data set, comma-separated:
# --data-sets=dev_clean,train_clean_100
import argparse
import fnmatch
import functools
import json
import multiprocessing
import os
import subprocess
import tarfile
import urllib.request
from pathlib import Path
from tqdm import tqdm
# Command-line interface. Arguments are parsed at import time and the resulting
# namespace is kept in the module-level `args`, which main() reads.
parser = argparse.ArgumentParser(description='Download LibriTTS and create manifests')
parser.add_argument(
    "--data-root",
    required=True,
    type=Path,
    help="Directory where the archives are downloaded and extracted.",
)
parser.add_argument(
    "--data-sets",
    default="dev_clean",
    type=str,
    help=(
        "Comma-separated subsets to fetch (dev_clean, dev_other, test_clean, test_other, "
        "train_clean_100, train_clean_360, train_other_500), ALL for every subset, "
        "or mini for dev_clean plus train_clean_100. Default: dev_clean."
    ),
)
parser.add_argument(
    "--num-workers",
    default=4,
    type=int,
    help="Number of worker processes used to build each manifest. Default: 4.",
)
args = parser.parse_args()
# Download URL of every LibriTTS subset, keyed by the upper-cased,
# underscore-separated subset name (e.g. "DEV_CLEAN").
URLS = {
    subset.upper().replace("-", "_"): f"https://www.openslr.org/resources/60/{subset}.tar.gz"
    for subset in (
        "train-clean-100",
        "train-clean-360",
        "train-other-500",
        "dev-clean",
        "dev-other",
        "test-clean",
        "test-other",
    )
}
def __maybe_download_file(source_url, destination_path):
    """Download ``source_url`` to ``destination_path`` unless it already exists.

    The payload is written to a sibling ``.tmp`` file first and renamed only
    after the transfer has finished, so an interrupted download never leaves a
    partial file under the final name.

    Args:
        source_url: URL of the file to fetch.
        destination_path: ``pathlib.Path`` where the file should end up.
    """
    if destination_path.exists():
        return

    # urlretrieve cannot create missing directories, so make sure the target
    # folder exists before writing into it.
    destination_path.parent.mkdir(parents=True, exist_ok=True)

    tmp_file_path = destination_path.with_suffix('.tmp')
    urllib.request.urlretrieve(source_url, filename=str(tmp_file_path))
    tmp_file_path.rename(destination_path)
def __extract_file(filepath, data_dir):
    """Unpack the tar archive at ``filepath`` into ``data_dir``.

    Extraction is best-effort: any failure is reported on stdout and then
    swallowed, so the caller carries on with whatever is already on disk.

    Args:
        filepath: Path of the archive to unpack.
        data_dir: Directory that receives the archive contents.
    """
    try:
        # The context manager closes the archive even when extraction fails.
        with tarfile.open(filepath) as tar:
            if hasattr(tarfile, "data_filter"):
                # Extraction filters are available: reject members that would
                # be written outside data_dir.
                tar.extractall(data_dir, filter="data")
            else:
                tar.extractall(data_dir)
    except Exception as err:
        # Include the actual cause; "already extracted" is only a guess.
        print(f"Error while extracting {filepath}: {err}. Already extracted?")
def __process_transcript(file_path: str):
    """Build the manifest entry for one ``*.normalized.txt`` transcript.

    The audio clip is looked up next to the transcript, under the same name
    with ``.normalized.txt`` replaced by ``.wav``. Its duration is queried with
    the external ``soxi`` tool, which therefore has to be on ``PATH``. The
    speaker id is the name of the directory two levels above the transcript.

    Args:
        file_path: Path of the transcript file.

    Returns:
        A one-element list holding a dict with the keys ``audio_filepath``
        (absolute path), ``duration`` (float), ``text`` and ``speaker`` (int).

    Raises:
        AssertionError: If the matching ``.wav`` file does not exist.
        subprocess.CalledProcessError: If ``soxi`` exits with an error.
    """
    with open(file_path, encoding="utf-8") as fin:
        # Only the first line of the transcript is used.
        text = fin.readlines()[0].strip()

    # TODO(oktai15): add normalized text via Normalizer/NormalizerWithAudio
    wav_file = file_path.replace(".normalized.txt", ".wav")
    # Path.parts honours the platform's separators, unlike a literal split('/').
    speaker_id = Path(file_path).parts[-3]
    assert os.path.exists(wav_file), f"{wav_file} not found!"

    # Pass an argument list and no shell, so paths containing spaces or shell
    # metacharacters reach soxi as a single, uninterpreted argument.
    duration = subprocess.check_output(["soxi", "-D", wav_file])

    return [
        {
            'audio_filepath': os.path.abspath(wav_file),
            'duration': float(duration),
            'text': text,
            'speaker': int(speaker_id),
        }
    ]
def __process_data(data_folder, manifest_file, num_workers):
    """Write a JSON-lines manifest for every transcript found under ``data_folder``.

    Args:
        data_folder: Directory searched recursively for ``*.normalized.txt`` files.
        manifest_file: Output path; receives one JSON object per line.
        num_workers: Size of the process pool that handles the transcripts.
    """
    files = []
    for root, _dirnames, filenames in os.walk(data_folder):
        # we will use normalized text provided by the original dataset
        files.extend(os.path.join(root, filename) for filename in fnmatch.filter(filenames, '*.normalized.txt'))
    # Directory listings come back in arbitrary order; sort so the manifest is
    # identical from one run to the next.
    files.sort()

    entries = []
    with multiprocessing.Pool(num_workers) as p:
        # imap yields results in input order, so entries follow the sorted list.
        results = p.imap(__process_transcript, files)
        for result in tqdm(results, total=len(files)):
            entries.extend(result)

    with open(manifest_file, 'w', encoding="utf-8") as fout:
        fout.writelines(json.dumps(m) + '\n' for m in entries)
def main():
    """Download, extract and build a manifest for every requested subset.

    Reads ``data_root``, ``data_sets`` and ``num_workers`` from the parsed
    command-line arguments. ``ALL`` and ``mini`` are expanded to their subset
    lists. Every name is checked against ``URLS`` before any work starts, so a
    typo is reported immediately through ``parser.error``.

    For each subset the archive is stored as ``<data_root>/<subset>.tar.gz``,
    extracted into ``data_root``, and its manifest is written to
    ``<data_root>/LibriTTS/<subset>.json``.
    """
    data_root = args.data_root
    data_sets = args.data_sets
    num_workers = args.num_workers

    if data_sets == "ALL":
        data_sets = "dev_clean,dev_other,train_clean_100,train_clean_360,train_other_500,test_clean,test_other"
    if data_sets == "mini":
        data_sets = "dev_clean,train_clean_100"

    # Normalise the list: tolerate blanks around commas and upper-case names.
    # Lower-casing matters because the extracted folder name is derived from
    # the subset name below.
    requested = [name.strip().lower() for name in data_sets.split(',') if name.strip()]
    if not requested:
        parser.error("--data-sets must name at least one subset")
    unknown = [name for name in requested if name.upper() not in URLS]
    if unknown:
        valid = ", ".join(key.lower() for key in URLS)
        parser.error(f"unknown data set(s): {', '.join(unknown)} (valid: {valid}, ALL, mini)")

    # The archives are written directly into data_root, so it has to exist.
    data_root.mkdir(parents=True, exist_ok=True)

    for data_set in requested:
        filepath = data_root / f"{data_set}.tar.gz"
        print(f"Downloading data for {data_set}...")
        __maybe_download_file(URLS[data_set.upper()], filepath)
        print("Extracting...")
        __extract_file(str(filepath), str(data_root))
        print("Processing and building manifest.")
        __process_data(
            str(data_root / "LibriTTS" / data_set.replace("_", "-")),
            str(data_root / "LibriTTS" / f"{data_set}.json"),
            num_workers=num_workers,
        )


if __name__ == "__main__":
    main()