""" |
|
# This script converts an existing audio dataset with a manifest to |
|
# a tarred and sharded audio dataset that can be read by the |
|
# TarredAudioToTextDataLayer. |
|
|
|
# Please make sure your audio_filepath DOES NOT CONTAIN '-sub'! |
|
# Because we will use it to handle files which have duplicate filenames but with different offsets |
|
# (see function create_shard for details) |
|
|
|
|
|
# Bucketing can help to improve the training speed. You may use --buckets_num to specify the number of buckets. |
|
# It creates multiple tarred datasets, one per bucket, based on the audio durations. |
|
# The range of [min_duration, max_duration) is split into equal sized buckets. |
|
# Recommend to use --sort_in_shards to speedup the training by reducing the paddings in the batches |
|
# More info on how to use bucketing feature: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/datasets.html |
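# For example, a hypothetical 4-bucket invocation (all paths and values below are
# placeholders) would look like:
#     python convert_to_tarred_audio_dataset.py \
#         --manifest_path=manifest.json \
#         --target_dir=./tarred_buckets \
#         --num_shards=8 \
#         --buckets_num=4 \
#         --min_duration=0.01 \
#         --max_duration=16.7 \
#         --sort_in_shards
# This writes four tarred datasets under ./tarred_buckets/bucket1 ... ./tarred_buckets/bucket4,
# one per duration range.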

# If a valid NVIDIA DALI version is installed, this script will also generate the corresponding
# DALI index files that need to be supplied to the config in order to utilize webdataset
# for efficient handling of large datasets.
# NOTE: DALI + webdataset is NOT compatible with bucketing!

# Usage:
1) Creating a new tarfile dataset

        python convert_to_tarred_audio_dataset.py \
            --manifest_path=<path to the manifest file> \
            --target_dir=<path to output directory> \
            --num_shards=<number of tarfiles that will contain the audio> \
            --max_duration=<float representing maximum duration of audio samples> \
            --min_duration=<float representing minimum duration of audio samples> \
            --shuffle --shuffle_seed=1 \
            --sort_in_shards \
            --workers=-1


2) Concatenating more tarfiles to a pre-existing tarred dataset

        python convert_to_tarred_audio_dataset.py \
            --manifest_path=<path to the tarred manifest file> \
            --metadata_path=<path to the metadata.yaml (or metadata_version_{X}.yaml) file> \
            --target_dir=<path to output directory where the original tarfiles are contained> \
            --max_duration=<float representing maximum duration of audio samples> \
            --min_duration=<float representing minimum duration of audio samples> \
            --shuffle --shuffle_seed=1 \
            --sort_in_shards \
            --workers=-1 \
            --concat_manifest_paths \
            <space separated paths to 1 or more manifest files to concatenate into the original tarred dataset>


3) Writing an empty metadata file

        python convert_to_tarred_audio_dataset.py \
            --target_dir=<path to output directory> \
            # any other optional argument
            --num_shards=8 \
            --max_duration=16.7 \
            --min_duration=0.01 \
            --shuffle \
            --workers=-1 \
            --sort_in_shards \
            --shuffle_seed=1 \
            --write_metadata
"""
|
import argparse
import copy
import json
import os
import random
import tarfile
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, List, Optional

from joblib import Parallel, delayed
from omegaconf import DictConfig, OmegaConf, open_dict

try:
    import create_dali_tarred_dataset_index as dali_index

    DALI_INDEX_SCRIPT_AVAILABLE = True
except (ImportError, ModuleNotFoundError, FileNotFoundError):
    DALI_INDEX_SCRIPT_AVAILABLE = False

parser = argparse.ArgumentParser(
    description="Convert an existing ASR dataset to tarballs compatible with TarredAudioToTextDataLayer."
)
|
parser.add_argument(
    "--manifest_path", default=None, type=str, required=False, help="Path to the existing dataset's manifest."
)
parser.add_argument(
    '--concat_manifest_paths',
    nargs='+',
    default=None,
    type=str,
    required=False,
    help="Paths to additional dataset manifests that will be concatenated with the base dataset.",
)
parser.add_argument(
    "--target_dir",
    default='./tarred',
    type=str,
    help="Target directory for resulting tarballs and manifest. Defaults to `./tarred`. Creates the path if necessary.",
)
parser.add_argument(
    "--metadata_path", required=False, default=None, type=str, help="Path to metadata file for the dataset.",
)
parser.add_argument(
    "--num_shards",
    default=-1,
    type=int,
    help="Number of shards (tarballs) to create. Used for partitioning data among workers.",
)
parser.add_argument(
    '--max_duration',
    default=None,
    required=True,
    type=float,
    help='Maximum duration of audio clips in the dataset. Samples with duration >= this value are filtered out. Required.',
)
parser.add_argument(
    '--min_duration',
    default=None,
    type=float,
    help='Minimum duration of audio clips in the dataset. By default, it is None and no files are filtered out.',
)
parser.add_argument(
    "--shuffle",
    action='store_true',
    help="Whether or not to randomly shuffle the samples in the manifest before tarring/sharding.",
)
parser.add_argument(
    "--keep_files_together",
    action='store_true',
    help="Whether or not to keep entries from the same file (but different offsets) together when sorting before tarring/sharding.",
)
parser.add_argument(
    "--sort_in_shards",
    action='store_true',
    help="Whether or not to sort samples inside the shards based on their duration.",
)
parser.add_argument(
    "--buckets_num", type=int, default=1, help="Number of buckets to create based on duration.",
)
parser.add_argument("--shuffle_seed", type=int, default=None, help="Random seed to use if shuffling is enabled.")
parser.add_argument(
    '--write_metadata',
    action='store_true',
    help=(
        "Flag to write a blank metadata file with the current call config. "
        "Note that the metadata will not contain the number of shards, "
        "and it must be filled out by the user."
    ),
)
parser.add_argument('--workers', type=int, default=1, help='Number of worker processes. Use -1 to use all available CPUs.')
args = parser.parse_args()
|
|
@dataclass
class ASRTarredDatasetConfig:
    num_shards: int = -1
    shuffle: bool = False
    max_duration: Optional[float] = None
    min_duration: Optional[float] = None
    shuffle_seed: Optional[int] = None
    sort_in_shards: bool = True
    keep_files_together: bool = False
|
|
@dataclass
class ASRTarredDatasetMetadata:
    created_datetime: Optional[str] = None
    version: int = 0
    num_samples_per_shard: Optional[int] = None
    is_concatenated_manifest: bool = False

    # Default factories ensure each instance gets its own config/history objects
    # rather than sharing a single mutable default.
    dataset_config: Optional[ASRTarredDatasetConfig] = field(default_factory=ASRTarredDatasetConfig)
    history: Optional[List[Any]] = field(default_factory=list)

    def __post_init__(self):
        self.created_datetime = self.get_current_datetime()

    def get_current_datetime(self):
        return datetime.now().strftime("%m-%d-%Y %H-%M-%S")

    @classmethod
    def from_config(cls, config: DictConfig):
        obj = cls()
        obj.__dict__.update(**config)
        return obj

    @classmethod
    def from_file(cls, filepath: str):
        config = OmegaConf.load(filepath)
        return cls.from_config(config=config)
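# A minimal usage sketch for reloading saved metadata (the path below is illustrative):
#     metadata = ASRTarredDatasetMetadata.from_file("./tarred/metadata.yaml")
#     print(metadata.dataset_config.num_shards, metadata.num_samples_per_shard)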
|
|
class ASRTarredDatasetBuilder:
    """
    Helper class that constructs a tarred dataset from scratch, or concatenates tarred datasets
    together and constructs manifests for them.
    """

    def __init__(self):
        self.config = None

    def configure(self, config: ASRTarredDatasetConfig):
        """
        Sets the config generated from command line overrides.

        Args:
            config: ASRTarredDatasetConfig dataclass object.
        """
        self.config = config

        if self.config.num_shards <= 0:
            raise ValueError("`num_shards` must be > 0. Please fill in the metadata information correctly.")
|
    def create_new_dataset(self, manifest_path: str, target_dir: str = "./tarred/", num_workers: int = 1):
        """
        Creates a new tarred dataset from a given manifest file.

        Args:
            manifest_path: Path to the original ASR manifest.
            target_dir: Output directory.
            num_workers: Integer denoting number of parallel worker processes which will write tarfiles.
                Defaults to 1, which denotes a sequential worker process.

        Output:
            Writes tarfiles, along with the tarred dataset compatible manifest file.
            Also preserves a record of the metadata used to construct this tarred dataset.
        """
        if self.config is None:
            raise ValueError("Config has not been set. Please call `configure(config: ASRTarredDatasetConfig)`")

        if manifest_path is None:
            raise FileNotFoundError("Manifest filepath cannot be None!")

        config = self.config

        if not os.path.exists(target_dir):
            os.makedirs(target_dir)

        entries, total_duration, filtered_entries, filtered_duration = self._read_manifest(manifest_path, config)
|
        if len(filtered_entries) > 0:
            print(f"Filtered {len(filtered_entries)} files which amounts to {filtered_duration:0.2f} seconds of audio.")
            print(
                f"After filtering, manifest has {len(entries)} files which amounts to {total_duration:0.2f} seconds of audio."
            )

        if len(entries) == 0:
            print("No tarred dataset was created as there were 0 valid samples after filtering!")
            return

        if config.shuffle:
            random.seed(config.shuffle_seed)
            print("Shuffling...")
            if config.keep_files_together:
                # Group entries by source file so that all offsets of a file land in the same shard.
                filename_entries = defaultdict(list)
                for ent in entries:
                    filename_entries[ent["audio_filepath"]].append(ent)
                filenames = list(filename_entries.keys())
                random.shuffle(filenames)
                shuffled_entries = []
                for filename in filenames:
                    shuffled_entries += filename_entries[filename]
                entries = shuffled_entries
            else:
                random.shuffle(entries)

        print(f"Number of samples added : {len(entries)}")
        print(f"Remainder: {len(entries) % config.num_shards}")
|
        start_indices = []
        end_indices = []
        # Build equally sized index ranges, one per shard; leftover entries at the end are discarded.
        for i in range(config.num_shards):
            start_idx = (len(entries) // config.num_shards) * i
            end_idx = start_idx + (len(entries) // config.num_shards)
            print(f"Shard {i} has entries {start_idx} ~ {end_idx}")
            files = set()
            for ent_id in range(start_idx, end_idx):
                files.add(entries[ent_id]["audio_filepath"])
            print(f"Shard {i} contains {len(files)} files")
            if i == config.num_shards - 1:
                print(f"Have {len(entries) - end_idx} entries left over that will be discarded.")

            start_indices.append(start_idx)
            end_indices.append(end_idx)
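        # Worked example of the partitioning above (numbers are illustrative): with 1003
        # entries and num_shards=8, each shard covers 1003 // 8 = 125 entries, and the
        # final 1003 - 8 * 125 = 3 entries are discarded.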
|
        manifest_folder, _ = os.path.split(manifest_path)

        with Parallel(n_jobs=num_workers, verbose=config.num_shards) as parallel:
            # Write each shard's tarfile in a separate worker process.
            new_entries_list = parallel(
                delayed(self._create_shard)(entries[start_idx:end_idx], target_dir, i, manifest_folder)
                for i, (start_idx, end_idx) in enumerate(zip(start_indices, end_indices))
            )

        # Flatten the list of per-shard manifests into a single list of entries.
        new_entries = [sample for manifest in new_entries_list for sample in manifest]
        del new_entries_list

        print("Total number of entries in manifest :", len(new_entries))

        new_manifest_path = os.path.join(target_dir, 'tarred_audio_manifest.json')
        with open(new_manifest_path, 'w') as m2:
            for entry in new_entries:
                json.dump(entry, m2)
                m2.write('\n')

        # Write the metadata recording how this dataset was constructed.
        new_metadata_path = os.path.join(target_dir, 'metadata.yaml')
        metadata = ASRTarredDatasetMetadata()

        metadata.dataset_config = config
        metadata.num_samples_per_shard = len(new_entries) // config.num_shards

        metadata_yaml = OmegaConf.structured(metadata)
        OmegaConf.save(metadata_yaml, new_metadata_path, resolve=True)
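        # At this point `target_dir` contains (shard names shown for num_shards=8 as an
        # illustration):
        #     audio_0.tar ... audio_7.tar    <- tarred audio, one tarball per shard
        #     tarred_audio_manifest.json     <- one JSON entry per line, with squashed paths
        #     metadata.yaml                  <- record of the config used to build the dataset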
|
    def create_concatenated_dataset(
        self,
        base_manifest_path: str,
        manifest_paths: List[str],
        metadata: ASRTarredDatasetMetadata,
        target_dir: str = "./tarred_concatenated/",
        num_workers: int = 1,
    ):
        """
        Creates new tarfiles in order to create a concatenated dataset, whose manifest contains the data for
        both the original dataset as well as the new data submitted in manifest paths.

        Args:
            base_manifest_path: Path to the manifest file which contains the information for the original
                tarred dataset (with flattened paths).
            manifest_paths: List of one or more paths to manifest files that will be concatenated with the
                base tarred dataset.
            metadata: ASRTarredDatasetMetadata dataclass instance with overrides from the command line.
            target_dir: Output directory.
            num_workers: Integer denoting number of parallel worker processes which will write tarfiles.

        Output:
            Writes tarfiles whose shard indices continue from the original tarred dataset's, along with
            a tarred dataset compatible manifest file which includes information about all the datasets
            that comprise the concatenated dataset.

            Also preserves a record of the metadata used to construct this tarred dataset.
        """
|
        if not os.path.exists(target_dir):
            os.makedirs(target_dir)

        if base_manifest_path is None:
            raise FileNotFoundError("Base manifest filepath cannot be None!")

        if manifest_paths is None or len(manifest_paths) == 0:
            raise FileNotFoundError("List of additional manifest filepaths cannot be None!")

        # `metadata.dataset_config` is a DictConfig here, loaded from the metadata YAML file.
        config = ASRTarredDatasetConfig(**(metadata.dataset_config))

        # Read the manifest of the original (base) tarred dataset.
        base_entries, _, _, _ = self._read_manifest(base_manifest_path, config)
        print(f"Read base manifest containing {len(base_entries)} samples.")

        # Precompute how many samples each shard of the base dataset holds.
        if metadata.num_samples_per_shard is None:
            num_samples_per_shard = len(base_entries) // config.num_shards
        else:
            num_samples_per_shard = metadata.num_samples_per_shard

        print("Number of samples per shard :", num_samples_per_shard)

        print(f"Selected max duration : {config.max_duration}")
        print(f"Selected min duration : {config.min_duration}")
|
        entries = []
        for new_manifest_idx in range(len(manifest_paths)):
            new_entries, total_duration, filtered_new_entries, filtered_duration = self._read_manifest(
                manifest_paths[new_manifest_idx], config
            )

            if len(filtered_new_entries) > 0:
                print(
                    f"Filtered {len(filtered_new_entries)} files which amounts to {filtered_duration:0.2f}"
                    f" seconds of audio from manifest {manifest_paths[new_manifest_idx]}."
                )
                print(
                    f"After filtering, manifest has {len(new_entries)} files which amounts to "
                    f"{total_duration:0.2f} seconds of audio."
                )

            entries.extend(new_entries)

        if len(entries) == 0:
            print("No tarred dataset was created as there were 0 valid samples after filtering!")
            return
|
        if config.shuffle:
            random.seed(config.shuffle_seed)
            print("Shuffling...")
            random.shuffle(entries)

        # Drop the last few samples that cannot fill a uniformly sized shard
        # (guarded, since `entries[:-0]` would drop everything).
        drop_count = len(entries) % num_samples_per_shard
        total_new_entries = len(entries)
        if drop_count > 0:
            entries = entries[:-drop_count]

        print(
            f"Dropping {drop_count} samples from total new samples {total_new_entries} since they cannot "
            f"be added into a uniformly sized chunk."
        )
|
|
|
num_added_shards = len(entries) // num_samples_per_shard |
|
|
|
print(f"Number of samples in base dataset : {len(base_entries)}") |
|
print(f"Number of samples in additional datasets : {len(entries)}") |
|
print(f"Number of added shards : {num_added_shards}") |
|
print(f"Remainder: {len(entries) % num_samples_per_shard}") |
|
|
|
start_indices = [] |
|
end_indices = [] |
|
shard_indices = [] |
|
for i in range(num_added_shards): |
|
start_idx = (len(entries) // num_added_shards) * i |
|
end_idx = start_idx + (len(entries) // num_added_shards) |
|
shard_idx = i + config.num_shards |
|
print(f"Shard {shard_idx} has entries {start_idx + len(base_entries)} ~ {end_idx + len(base_entries)}") |
|
|
|
start_indices.append(start_idx) |
|
end_indices.append(end_idx) |
|
shard_indices.append(shard_idx) |
|
|
|
manifest_folder, _ = os.path.split(base_manifest_path) |
|
|
|
with Parallel(n_jobs=num_workers, verbose=num_added_shards) as parallel: |
|
|
|
new_entries_list = parallel( |
|
delayed(self._create_shard)(entries[start_idx:end_idx], target_dir, shard_idx, manifest_folder) |
|
for i, (start_idx, end_idx, shard_idx) in enumerate(zip(start_indices, end_indices, shard_indices)) |
|
) |
|
|
|
|
|
        # Flatten the list of per-shard manifests into a single list of entries.
        new_entries = [sample for manifest in new_entries_list for sample in manifest]
        del new_entries_list

        # Bump the version for the concatenated manifest and metadata files.
        if metadata is None:
            new_version = 1
        else:
            new_version = metadata.version + 1

        print("Total number of entries in manifest :", len(base_entries) + len(new_entries))

        new_manifest_path = os.path.join(target_dir, f'tarred_audio_manifest_version_{new_version}.json')
        with open(new_manifest_path, 'w') as m2:
            # First write all the entries of the base manifest.
            for entry in base_entries:
                json.dump(entry, m2)
                m2.write('\n')

            # Then write the new entries from the additional manifests.
            for entry in new_entries:
                json.dump(entry, m2)
                m2.write('\n')

        # Preserve the old metadata for the history record.
        base_metadata = metadata

        # Write the new (versioned) metadata alongside the new manifest.
        new_metadata_path = os.path.join(target_dir, f'metadata_version_{new_version}.yaml')
        metadata = ASRTarredDatasetMetadata()

        # The concatenated dataset now spans the original shards plus the added ones.
        config.num_shards = config.num_shards + num_added_shards

        metadata.version = new_version
        metadata.dataset_config = config
        metadata.num_samples_per_shard = num_samples_per_shard
        metadata.is_concatenated_manifest = True
        metadata.created_datetime = metadata.get_current_datetime()

        # Attach the flattened history of all preceding versions of this dataset.
        current_metadata = OmegaConf.structured(base_metadata.history)
        metadata.history = current_metadata

        metadata_yaml = OmegaConf.structured(metadata)
        OmegaConf.save(metadata_yaml, new_metadata_path, resolve=True)
|
    def _read_manifest(self, manifest_path: str, config: ASRTarredDatasetConfig):
        """Reads and filters entries from the manifest, keeping durations in [min_duration, max_duration)."""
        entries = []
        total_duration = 0.0
        filtered_entries = []
        filtered_duration = 0.0
        with open(manifest_path, 'r') as m:
            for line in m:
                entry = json.loads(line)
                if (config.max_duration is None or entry['duration'] < config.max_duration) and (
                    config.min_duration is None or entry['duration'] >= config.min_duration
                ):
                    entries.append(entry)
                    total_duration += entry["duration"]
                else:
                    filtered_entries.append(entry)
                    filtered_duration += entry['duration']

        return entries, total_duration, filtered_entries, filtered_duration
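    # A minimal sketch of the manifest format this script expects (values are made up):
    # one JSON object per line, with at least `audio_filepath` and `duration`:
    #     {"audio_filepath": "/data/set1/utt1.wav", "duration": 3.27, "text": "hello world"}
    # Optional keys such as `label`, `text`, `offset`, and `lang` are carried over into the
    # tarred manifest by `_create_shard` below.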
|
    def _create_shard(self, entries, target_dir, shard_id, manifest_folder):
        """Creates a tarball containing the audio files from `entries`."""
        if self.config.sort_in_shards:
            entries.sort(key=lambda x: x["duration"], reverse=False)

        new_entries = []
        tar = tarfile.open(os.path.join(target_dir, f'audio_{shard_id}.tar'), mode='w', dereference=True)

        count = dict()
        for entry in entries:
            # Resolve the audio path: it is either absolute or relative to the manifest's folder.
            if os.path.exists(entry["audio_filepath"]):
                audio_filepath = entry["audio_filepath"]
            else:
                audio_filepath = os.path.join(manifest_folder, entry["audio_filepath"])
                if not os.path.exists(audio_filepath):
                    raise FileNotFoundError(f"Could not find {entry['audio_filepath']}!")

            # "Squash" the filepath into a flat, unique archive name. Dots are replaced as
            # well, since webdataset-style readers split the archive key on the first period.
            base, ext = os.path.splitext(audio_filepath)
            base = base.replace('/', '_')
            base = base.replace('.', '_')
            squashed_filename = f'{base}{ext}'
            if squashed_filename not in count:
                tar.add(audio_filepath, arcname=squashed_filename)
                to_write = squashed_filename
                count[squashed_filename] = 1
            else:
                # Duplicate filename (e.g., the same file with a different offset): do not
                # add the audio again; disambiguate the manifest entry with a '-sub' suffix.
                to_write = base + "-sub" + str(count[squashed_filename]) + ext
                count[squashed_filename] += 1

            new_entry = {
                'audio_filepath': to_write,
                'duration': entry['duration'],
                'shard_id': shard_id,  # Keep shard ID for record keeping.
            }

            if 'label' in entry:
                new_entry['label'] = entry['label']

            if 'text' in entry:
                new_entry['text'] = entry['text']

            if 'offset' in entry:
                new_entry['offset'] = entry['offset']

            if 'lang' in entry:
                new_entry['lang'] = entry['lang']

            new_entries.append(new_entry)

        tar.close()
        return new_entries
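    # A worked example of the filename squashing above (path is illustrative): the entry
    # "/data/set1/utt1.wav" is stored in the tar under the arcname "_data_set1_utt1.wav".
    # If a later entry in the same shard points at the same file (e.g., with a different
    # offset), the audio is not added to the tar again; that entry's manifest line instead
    # gets audio_filepath "_data_set1_utt1-sub1.wav". This is why input paths must not
    # already contain '-sub'.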
|
    @classmethod
    def setup_history(cls, base_metadata: ASRTarredDatasetMetadata, history: List[Any]):
        """Recursively flattens nested metadata history into `history`.

        Note: in practice `base_metadata` is a DictConfig (built via `OmegaConf.structured`),
        which is what makes `.keys()` and `open_dict` below valid.
        """
        if 'history' in base_metadata.keys():
            for history_val in base_metadata.history:
                cls.setup_history(history_val, history)

        if base_metadata is not None:
            metadata_copy = copy.deepcopy(base_metadata)
            with open_dict(metadata_copy):
                metadata_copy.pop('history', None)
            history.append(metadata_copy)
|
|
def main():
    if args.buckets_num > 1:
        if args.min_duration is None:
            raise ValueError("`--min_duration` must be set when `--buckets_num` > 1.")
        bucket_length = (args.max_duration - args.min_duration) / float(args.buckets_num)
        for i in range(args.buckets_num):
            min_duration = args.min_duration + i * bucket_length
            max_duration = min_duration + bucket_length
            if i == args.buckets_num - 1:
                # Add a small tolerance so that samples with duration == args.max_duration
                # fall into the last bucket instead of being filtered out.
                max_duration += 1e-5
            target_dir = os.path.join(args.target_dir, f"bucket{i+1}")
            print(f"Creating bucket {i+1} with min_duration={min_duration} and max_duration={max_duration} ...")
            print(f"Results are being saved at: {target_dir}.")
            create_tar_datasets(min_duration=min_duration, max_duration=max_duration, target_dir=target_dir)
            print(f"Bucket {i+1} is created.")
    else:
        create_tar_datasets(min_duration=args.min_duration, max_duration=args.max_duration, target_dir=args.target_dir)
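# Worked example of the bucket split in `main` above (numbers are illustrative): with
# --min_duration=0.01, --max_duration=16.01, and --buckets_num=4, bucket_length is
# (16.01 - 0.01) / 4 = 4.0, producing buckets [0.01, 4.01), [4.01, 8.01), [8.01, 12.01),
# and [12.01, 16.01 + 1e-5).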
|
|
def create_tar_datasets(min_duration: float, max_duration: float, target_dir: str):
    builder = ASRTarredDatasetBuilder()

    if args.write_metadata:
        metadata = ASRTarredDatasetMetadata()
        dataset_cfg = ASRTarredDatasetConfig(
            num_shards=args.num_shards,
            shuffle=args.shuffle,
            max_duration=max_duration,
            min_duration=min_duration,
            shuffle_seed=args.shuffle_seed,
            sort_in_shards=args.sort_in_shards,
            keep_files_together=args.keep_files_together,
        )
        metadata.dataset_config = dataset_cfg

        if not os.path.exists(target_dir):
            os.makedirs(target_dir)
        output_path = os.path.join(target_dir, 'default_metadata.yaml')
        OmegaConf.save(metadata, output_path, resolve=True)
        print(f"Default metadata written to {output_path}")
        exit(0)
|
    if args.concat_manifest_paths is None or len(args.concat_manifest_paths) == 0:
        print("Creating new tarred dataset ...")

        # Create a tarred dataset from scratch.
        config = ASRTarredDatasetConfig(
            num_shards=args.num_shards,
            shuffle=args.shuffle,
            max_duration=max_duration,
            min_duration=min_duration,
            shuffle_seed=args.shuffle_seed,
            sort_in_shards=args.sort_in_shards,
            keep_files_together=args.keep_files_together,
        )
        builder.configure(config)
        builder.create_new_dataset(manifest_path=args.manifest_path, target_dir=target_dir, num_workers=args.workers)
|
    else:
        if args.buckets_num > 1:
            raise ValueError("Concatenation feature does not support buckets_num > 1.")
        print("Concatenating multiple tarred datasets ...")

        # Load metadata for the original tarred dataset.
        if args.metadata_path is not None:
            metadata = ASRTarredDatasetMetadata.from_file(args.metadata_path)
        else:
            raise ValueError("`metadata` yaml file path must be provided!")

        # Flatten the metadata's nested history into a linear record of preceding versions.
        history = []
        builder.setup_history(OmegaConf.structured(metadata), history)
        metadata.history = history

        # Overwrite the loaded config with the command line arguments of this run.
        metadata.dataset_config.max_duration = max_duration
        metadata.dataset_config.min_duration = min_duration
        metadata.dataset_config.shuffle = args.shuffle
        metadata.dataset_config.shuffle_seed = args.shuffle_seed
        metadata.dataset_config.sort_in_shards = args.sort_in_shards

        builder.configure(metadata.dataset_config)

        # Concatenate the new manifests onto the original tarred dataset.
        builder.create_concatenated_dataset(
            base_manifest_path=args.manifest_path,
            manifest_paths=args.concat_manifest_paths,
            metadata=metadata,
            target_dir=target_dir,
            num_workers=args.workers,
        )

    if DALI_INDEX_SCRIPT_AVAILABLE and dali_index.INDEX_CREATOR_AVAILABLE:
        print("Constructing DALI Tarfile Index - ", target_dir)
        index_config = dali_index.DALITarredIndexConfig(tar_dir=target_dir, workers=args.workers)
        dali_index.main(index_config)
|
|
if __name__ == "__main__":
    main()