|
import argparse |
|
import os |
|
import shutil |
|
from tqdm import tqdm |
|
import logging |
|
from src.utils.common import read_yaml, create_directories |
|
from datasets import load_dataset |
|
|
|
STAGE = "stage 01" |
|
|
|
logging.basicConfig( |
|
filename=os.path.join("logs", 'running_logs.log'), |
|
level=logging.INFO, |
|
format="[%(asctime)s: %(levelname)s: %(module)s]: %(message)s", |
|
filemode="a" |
|
) |
|
|
|
|
|
def main(config_path, params_path): |
|
|
|
config = read_yaml(config_path) |
|
params = read_yaml(params_path) |
|
artifacts = config["artifacts"] |
|
|
|
dataset = params["train"]["dataset_name"] |
|
cache_dir_ = artifacts["cache_dir"] |
|
create_directories([cache_dir_]) |
|
|
|
logging.info(f"load dataset") |
|
datasets = load_dataset(dataset, cache_dir=cache_dir_) |
|
logging.info(f"dataset saved in : {cache_dir_}") |
|
|
|
if __name__ == '__main__': |
|
args = argparse.ArgumentParser() |
|
args.add_argument("--config", "-c", default="configs/config.yaml") |
|
args.add_argument("--params", "-p", default="params.yaml") |
|
parsed_args = args.parse_args() |
|
|
|
try: |
|
logging.info("\n********************") |
|
logging.info(f">>>>> stage {STAGE} started <<<<<") |
|
main(config_path=parsed_args.config, params_path=parsed_args.params) |
|
logging.info(f">>>>> stage {STAGE} completed!<<<<<\n") |
|
except Exception as e: |
|
logging.exception(e) |
|
raise e |