import os
from typing import TYPE_CHECKING, List, Union

from datasets import concatenate_datasets, interleave_datasets, load_dataset

from llmtuner.dsets.utils import checksum, EXT2TYPE
from llmtuner.extras.logging import get_logger

if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset
    from llmtuner.hparams import ModelArguments, DataArguments


logger = get_logger(__name__)


def get_dataset(
    model_args: "ModelArguments",
    data_args: "DataArguments"
) -> Union["Dataset", "IterableDataset"]:
    r"""
    Loads every dataset in `data_args.dataset_list` and merges them according to
    `data_args.mix_strategy` ("concat", "interleave_under" or "interleave_over").
    """
    max_samples = data_args.max_samples
    all_datasets: List[Union["Dataset", "IterableDataset"]] = []  # support multiple datasets

    for dataset_attr in data_args.dataset_list:
        logger.info("Loading dataset {}...".format(dataset_attr))

        # resolve the dataset location: a Hub repo id, a loading script, or local files
        if dataset_attr.load_from == "hf_hub":
            data_path = dataset_attr.dataset_name
            data_files = None
        elif dataset_attr.load_from == "script":
            data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
            data_files = None
        elif dataset_attr.load_from == "file":
            data_path = None  # inferred from the file extension via EXT2TYPE
            data_files: List[str] = []

            local_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
            if os.path.isdir(local_path):  # directory: collect all files, which must share one type
                for file_name in os.listdir(local_path):
                    data_files.append(os.path.join(local_path, file_name))
                    if data_path is None:
                        data_path = EXT2TYPE.get(file_name.split(".")[-1], None)
                    else:
                        assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file types do not match."
            elif os.path.isfile(local_path):  # single file
                data_files.append(local_path)
                data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None)
            else:
                raise ValueError("File not found.")

            assert data_path, "File extension must be txt, csv, json or jsonl."
            checksum(data_files, dataset_attr.dataset_sha1)
        else:
            raise NotImplementedError

        dataset = load_dataset(
            data_path,
            data_files=data_files,
            split=data_args.split,
            cache_dir=model_args.cache_dir,
            streaming=data_args.streaming,
            use_auth_token=True if model_args.use_auth_token else None
        )

        if max_samples is not None:  # truncate the dataset
            if data_args.streaming:
                dataset = dataset.take(max_samples)  # IterableDataset supports neither len() nor select()
            else:
                dataset = dataset.select(range(min(len(dataset), max_samples)))

        for column_name in ["prompt", "query", "response", "history"]:  # align column names across datasets
            if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
                dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)

        if dataset_attr.system_prompt:  # add a system prompt column
            if data_args.streaming:
                dataset = dataset.map(lambda _: {"system": dataset_attr.system_prompt})
            else:
                dataset = dataset.add_column("system", [dataset_attr.system_prompt] * len(dataset))

        all_datasets.append(dataset)

    if len(data_args.dataset_list) == 1:
        return all_datasets[0]
    elif data_args.mix_strategy == "concat":
        if data_args.streaming:
            logger.warning("The samples between different datasets will not be mixed in streaming mode.")
        return concatenate_datasets(all_datasets)
    elif data_args.mix_strategy.startswith("interleave"):
        if not data_args.streaming:
            logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
        stopping_strategy = "first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
        return interleave_datasets(all_datasets, data_args.interleave_probs, stopping_strategy=stopping_strategy)
    else:
        raise ValueError("Unknown mixing strategy.")
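

# Minimal usage sketch. In the real pipeline the two argument objects are the
# ModelArguments/DataArguments dataclasses from llmtuner.hparams, parsed from
# CLI flags; the SimpleNamespace stand-ins below are illustrative assumptions
# that fake only the attributes get_dataset() actually reads. The dataset id
# "tatsu-lab/alpaca" and its column names (instruction/input/output) are used
# as a concrete example, not something this module prescribes.
if __name__ == "__main__":
    from types import SimpleNamespace

    dataset_attr = SimpleNamespace(
        load_from="hf_hub",
        dataset_name="tatsu-lab/alpaca",
        dataset_sha1=None,
        prompt="instruction",  # map the dataset's "instruction" column to "prompt"
        query="input",
        response="output",
        history=None,
        system_prompt=None,
    )
    model_args = SimpleNamespace(cache_dir=None, use_auth_token=False)
    data_args = SimpleNamespace(
        dataset_list=[dataset_attr],
        dataset_dir="data",
        split="train",
        streaming=False,
        max_samples=100,
        mix_strategy="concat",  # unused with a single dataset, shown for shape
        interleave_probs=None,
    )

    dataset = get_dataset(model_args, data_args)
    print(dataset)  # expected: a Dataset with prompt/query/response columns, 100 rows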