|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import glob |
|
import logging |
|
import os |
|
from dataclasses import dataclass |
|
|
|
import hydra |
|
from hydra.core.config_store import ConfigStore |
|
from joblib import Parallel, delayed |
|
from omegaconf import MISSING |
|
|
|
try: |
|
from wds2idx import IndexCreator |
|
|
|
INDEX_CREATOR_AVAILABLE = True |
|
except (ImportError, ModuleNotFoundError): |
|
INDEX_CREATOR_AVAILABLE = False |
|
|
|
""" |
|
python create_dali_tarred_dataset_index.py \ |
|
tar_dir=<path to the directory which contains tarred dataset> \ |
|
workers=-1 |
|
|
|
""" |
|
|
|
logging.basicConfig(level=logging.INFO) |
|
|
|
|
|
@dataclass |
|
class DALITarredIndexConfig: |
|
tar_dir: str = MISSING |
|
workers: int = -1 |
|
|
|
|
|
def process_index_path(tar_paths, index_dir): |
|
""" |
|
Appends the folder `{index_dir}` to the filepath of all tarfiles. |
|
Example: |
|
/X/Y/Z/audio_0.tar -> /X/Y/Z/{index_dir}/audio_0.index |
|
""" |
|
index_paths = [] |
|
for path in tar_paths: |
|
basepath, filename = os.path.split(path) |
|
path = filename.replace('.tar', '.index') |
|
path = os.path.join(basepath, path) |
|
base, name = os.path.split(path) |
|
index_path = os.path.join(index_dir, name) |
|
index_paths.append(index_path) |
|
|
|
return index_paths |
|
|
|
|
|
def build_index(tarpath, indexfile): |
|
with IndexCreator(tarpath, indexfile) as index: |
|
index.create_index() |
|
|
|
|
|
@hydra.main(config_path=None, config_name='index_config') |
|
def main(cfg: DALITarredIndexConfig): |
|
if not INDEX_CREATOR_AVAILABLE: |
|
logging.error("`wds2idx` is not installed. Please install NVIDIA DALI >= 1.11") |
|
exit(1) |
|
|
|
tar_files = list(glob.glob(os.path.join(cfg.tar_dir, "*.tar"))) |
|
|
|
index_dir = os.path.join(cfg.tar_dir, "dali_index") |
|
if not os.path.exists(index_dir): |
|
os.makedirs(index_dir, exist_ok=True) |
|
|
|
index_paths = process_index_path(tar_files, index_dir) |
|
|
|
with Parallel(n_jobs=cfg.workers, verbose=len(tar_files)) as parallel: |
|
_ = parallel(delayed(build_index)(tarpath, indexfile) for tarpath, indexfile in zip(tar_files, index_paths)) |
|
|
|
logging.info("Finished constructing index files !") |
|
|
|
|
|
ConfigStore.instance().store(name='index_config', node=DALITarredIndexConfig) |
|
|
|
|
|
if __name__ == '__main__': |
|
main() |
|
|