import numpy as np
from tqdm import tqdm
from multiprocessing import Pool, Manager
from utils.utils import load_dataset
import os
import torch
import datetime


class DatasetBase:
    """Base class for dataset generators. Subclasses implement
    generate_instance() and annotate()."""

    def __init__(self, coord_dim, num_samples, num_nodes, annotation, parallel, random_seed, num_cpus):
        self.coord_dim = coord_dim
        self.num_samples = num_samples
        self.num_nodes = num_nodes
        self.annotation = annotation
        self.parallel = parallel
        self.num_cpus = num_cpus
        self.seed = random_seed

    def generate_instance(self, seed):
        raise NotImplementedError
    def generate_dataset(self):
        dataset = []
        num_required_samples = self.num_samples
        seed = self.seed
        end = False
        print("Data generation started.", flush=True)
        while not end:
            # Generate one candidate instance per seed.
            seeds = seed + np.arange(num_required_samples)
            instances = [
                self.generate_instance(seed=s)
                for s in tqdm(seeds, desc="Generating instances")
            ]
            if self.annotation:
                if self.parallel:
                    instances = self.generate_labeldata_para(instances, self.num_cpus)
                else:
                    instances = self.generate_labeldata(instances)
            # Annotation returns None for infeasible instances; drop them.
            dataset.extend(filter(None, instances))
            seed += num_required_samples
            num_required_samples = self.num_samples - len(dataset)
            if len(dataset) == self.num_samples:
                end = True
            else:
                print(f"Only {len(dataset)}/{self.num_samples} instances had a feasible tour. Trying {num_required_samples} more instances.", flush=True)
        print("Data generation completed.", flush=True)
        return dataset
    def annotate(self, instance):
        raise NotImplementedError
    def generate_labeldata(self, dataset):
        """
        Parameters
        ----------
        dataset: list
            generated instances to annotate

        Returns
        -------
        dataset: list
            annotated instances (None for instances with no feasible tour)
        """
        return [self.annotate(instance) for instance in tqdm(dataset, desc="Annotating instances")]
    def generate_labeldata_para(self, dataset, num_cpus):
        with Pool(num_cpus) as pool:
            annotation_data = list(tqdm(pool.imap(self.annotate, dataset), total=len(dataset), desc="Annotating instances"))
        return annotation_data
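

# Illustrative sketch (an assumption, not part of the original code): a minimal
# concrete subclass showing the two hooks generate_dataset() relies on.
# generate_instance() builds one instance from a seed; annotate() attaches a
# label or returns None, which generate_dataset() filters out and retries.
class _ExampleRandomDataset(DatasetBase):
    def generate_instance(self, seed):
        rng = np.random.default_rng(seed)
        # Uniform random node coordinates in the unit hypercube.
        return {"coords": rng.random((self.num_nodes, self.coord_dim))}

    def annotate(self, instance):
        # Placeholder label: visit nodes in index order. A real subclass would
        # run a solver here and return None when no feasible tour exists.
        instance["tour"] = np.arange(self.num_nodes)
        return instance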

import multiprocessing
import torch.multiprocessing

# Share tensors through the file system instead of file descriptors, so many
# worker processes do not exhaust the open-file-descriptor limit.
torch.multiprocessing.set_sharing_strategy("file_system")

class DataLoaderBase(torch.utils.data.Dataset):
    """Map-style dataset built from an annotated dataset file. Subclasses
    implement load_sequentially() and/or load_randomly()."""

    def __init__(self, fpath, sequential=False, parallel=False, num_cpus=1):
        # Scratch directory for the temporary per-chunk files written in parallel mode.
        now = datetime.datetime.now()
        dir_name = f"test/data_load_{now.strftime('%Y%m%d_%H%M%S%f')}"
        os.makedirs(dir_name)
        annotation_data = load_dataset(fpath)
        load = self.load_sequentially if sequential else self.load_randomly
        if parallel:
            data = []
            chunk_size = 1000
            num_process = num_cpus
            pool = torch.multiprocessing.Pool(num_process)
            for i in tqdm(range(0, len(annotation_data), chunk_size)):
                chunk_data = annotation_data[i:i+chunk_size]
                # Each worker writes its loaded elements to a temporary file and
                # returns the file name; read the files back, then delete them.
                for fname in pool.starmap(load, [(instance, f"{dir_name}/chunk{i}_{j}.pkl") for j, instance in enumerate(chunk_data)]):
                    data.extend(load_dataset(fname))
                    os.remove(fname)
            pool.close()
            pool.join()
            self.data = data
        else:
            self.data = [elem for instance in tqdm(annotation_data) for elem in load(instance)]
        self.size = len(self.data)
    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return self.data[idx]
    def load_sequentially(self, instance, fname=None):
        raise NotImplementedError

    def load_randomly(self, instance, fname=None):
        raise NotImplementedError
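

# Illustrative usage sketch (assumption: "annotated.pkl" names a dataset file
# saved in whatever format utils.utils.load_dataset expects). A concrete loader
# implements load_randomly() (and/or load_sequentially()) to expand one
# annotated instance into training elements; in parallel mode the elements must
# instead be written to `fname` and the file name returned. The base class then
# works as a standard map-style torch Dataset.
class _ExampleDataLoader(DataLoaderBase):
    def load_randomly(self, instance, fname=None):
        # Simplest possible expansion: one training element per instance.
        return [instance]


if __name__ == "__main__":
    dataset = _ExampleDataLoader("annotated.pkl", sequential=False)
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
    for batch in loader:
        pass  # training step goes here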