import datetime
import multiprocessing
import os
from multiprocessing import Pool

import numpy as np
import torch
import torch.multiprocessing
from tqdm import tqdm

from utils.utils import load_dataset

torch.multiprocessing.set_sharing_strategy("file_system")


class DatasetBase:
    def __init__(self, coord_dim, num_samples, num_nodes, annotation,
                 parallel, random_seed, num_cpus):
        self.coord_dim = coord_dim
        self.num_samples = num_samples
        self.num_nodes = num_nodes
        self.annotation = annotation
        self.parallel = parallel
        self.num_cpus = num_cpus
        self.seed = random_seed

    def generate_instance(self, seed):
        raise NotImplementedError

    def generate_dataset(self):
        dataset = []
        num_required_samples = self.num_samples
        seed = self.seed
        print("Data generation started.", flush=True)
        while True:
            # Generate one batch of candidate instances, one seed per instance.
            seeds = seed + np.arange(num_required_samples)
            instances = [
                self.generate_instance(seed=s)
                for s in tqdm(seeds, desc="Generating instances")
            ]
            if self.annotation:
                if self.parallel:
                    instances = self.generate_labeldata_para(instances, self.num_cpus)
                else:
                    instances = self.generate_labeldata(instances)
            # annotate() returns None for infeasible instances; drop them.
            dataset.extend(filter(None, instances))
            seed += num_required_samples
            num_required_samples = self.num_samples - len(dataset)
            if num_required_samples == 0:
                break
            print(f"No feasible tour was found for {num_required_samples} instances. "
                  f"Trying {num_required_samples} new instances.", flush=True)
        print("Data generation completed.", flush=True)
        return dataset

    def annotate(self, instance):
        raise NotImplementedError

    def generate_labeldata(self, dataset):
        """
        Annotates instances sequentially.

        Parameters
        ----------
        dataset: list
            generated instances to annotate

        Returns
        -------
        list: annotated instances (None where annotation failed)
        """
        return [self.annotate(instance) for instance in tqdm(dataset, desc="Annotating instances")]

    def generate_labeldata_para(self, dataset, num_cpus):
        # Same as generate_labeldata, but spread across num_cpus worker processes.
        with Pool(num_cpus) as pool:
            annotation_data = list(tqdm(pool.imap(self.annotate, dataset),
                                        total=len(dataset),
                                        desc="Annotating instances"))
        return annotation_data


class DataLoaderBase(torch.utils.data.Dataset):
    def __init__(self, fpath, sequential=False, parallel=False, num_cpus=1):
        now = datetime.datetime.now()
        dir_name = f"test/data_load_{now.strftime('%Y%m%d_%H%M%S%f')}"
        os.makedirs(dir_name)
        annotation_data = load_dataset(fpath)
        load = self.load_sequentially if sequential else self.load_randomly
        if parallel:
            data = []
            chunk_size = 1000
            num_process = multiprocessing.cpu_count()
            # Workers write their chunk results to temporary pickle files,
            # which are read back and deleted here to bound peak memory usage.
            with torch.multiprocessing.Pool(num_process) as pool:
                for i in tqdm(range(0, len(annotation_data), chunk_size)):
                    chunk_data = annotation_data[i:i + chunk_size]
                    fnames = pool.starmap(
                        load,
                        [(instance, f"{dir_name}/chunk{i}_{j}.pkl")
                         for j, instance in enumerate(chunk_data)])
                    for fname in fnames:
                        data.extend(load_dataset(fname))
                        os.remove(fname)
            self.data = data
        else:
            self.data = [elem for instance in tqdm(annotation_data) for elem in load(instance)]
        self.size = len(self.data)

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return self.data[idx]

    def load_sequentially(self, instance, fname=None):
        raise NotImplementedError

    def load_randomly(self, instance, fname=None):
        raise NotImplementedError
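

# ---------------------------------------------------------------------------
# Minimal subclassing sketch (illustrative only, not part of the original
# module): DatasetBase expects generate_instance() to build one problem
# instance from a seed and annotate() to label it, returning None when no
# feasible tour exists so that generate_dataset() can filter and retry.
# The toy problem below (random points, identity tour) is a hypothetical
# stand-in for the real problem definition.
# ---------------------------------------------------------------------------
class ExampleRandomDataset(DatasetBase):
    def generate_instance(self, seed):
        # Hypothetical instance: num_nodes random points in [0, 1)^coord_dim.
        rng = np.random.default_rng(seed)
        return rng.random((self.num_nodes, self.coord_dim))

    def annotate(self, instance):
        # Hypothetical label: the identity visit order. A real subclass would
        # run a solver here and return None for infeasible instances.
        return {"coords": instance, "tour": np.arange(len(instance))}

# Usage sketch:
#   ExampleRandomDataset(coord_dim=2, num_samples=4, num_nodes=10,
#                        annotation=True, parallel=False, random_seed=0,
#                        num_cpus=1).generate_dataset()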
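

# ---------------------------------------------------------------------------
# Sketch of the loader hooks (inferred from DataLoaderBase.__init__ above,
# not confirmed by the original source): each hook turns one annotated
# instance into a list of training elements. When fname is given (parallel
# mode), the list must be written to fname, in whatever format
# utils.utils.load_dataset reads back (assumed here to be a pickle), and the
# filename returned to the parent process. Reuses the toy annotation format
# from ExampleRandomDataset above.
# ---------------------------------------------------------------------------
import pickle


class ExampleLoader(DataLoaderBase):
    def load_randomly(self, instance, fname=None):
        coords, tour = instance["coords"], instance["tour"]
        # Hypothetical elements: one (coordinates, next node) pair per step.
        elems = [(coords, int(node)) for node in tour]
        if fname is None:
            return elems
        with open(fname, "wb") as f:
            pickle.dump(elems, f)  # assumes load_dataset() reads pickles
        return fname

    def load_sequentially(self, instance, fname=None):
        # Hypothetical: identical to load_randomly; a real loader would keep
        # elements in tour order here and in a shuffle-friendly order above.
        return self.load_randomly(instance, fname)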