EdgeTA / data /dataset.py
LINC-BIT's picture
Upload 1912 files
b84549f verified
raw
history blame contribute delete
1.28 kB
import importlib
from typing import Type
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.dataloader import DataLoader
from .datasets.ab_dataset import ABDataset
from .datasets import * # import all datasets
from .datasets.registery import static_dataset_registery
def get_dataset(dataset_name, root_dir, split, transform=None, ignore_classes=None, idx_map=None) -> "ABDataset":
    """Instantiate the dataset registered under ``dataset_name``.

    The dataset class is taken from the first element of the registry entry
    ``static_dataset_registery[dataset_name]`` and constructed positionally as
    ``dataset_cls(root_dir, split, transform, ignore_classes, idx_map)``.

    Args:
        dataset_name: Key into ``static_dataset_registery``.
        root_dir: Forwarded unchanged to the dataset constructor.
        split: Forwarded unchanged to the dataset constructor.
        transform: Forwarded unchanged to the dataset constructor.
        ignore_classes: Forwarded to the dataset constructor. When omitted
            (or ``None``), a new empty list is created for this call.
        idx_map: Forwarded unchanged to the dataset constructor.

    Returns:
        The newly constructed dataset instance.

    Raises:
        KeyError: If ``dataset_name`` is not present in the registry.
    """
    # A ``[]`` default would be one list object shared by every call (and by
    # every dataset that stores it); build a fresh one per call instead.
    if ignore_classes is None:
        ignore_classes = []

    dataset_cls = static_dataset_registery[dataset_name][0]
    return dataset_cls(root_dir, split, transform, ignore_classes, idx_map)
def get_num_limited_dataset(dataset: "ABDataset", num_samples: int, discard_label=True):
    """Swap ``dataset.dataset`` for a random subset of at most ``num_samples`` items.

    Shuffled batches are drawn from ``dataset`` until ``num_samples`` items
    have been collected. The first two elements of every batch are treated as
    inputs and labels; they are concatenated, truncated to ``num_samples`` and
    wrapped in a ``TensorDataset``.

    NOTE: ``dataset`` is modified in place -- its ``dataset`` attribute is
    overwritten -- and the very same object is returned.

    Args:
        dataset: Source dataset. Each batch produced by a ``DataLoader`` over
            it must expose the inputs at index 0 and the labels at index 1.
        num_samples: Maximum number of samples to keep; must be positive.
            If ``dataset`` holds fewer items, all of them are kept.
        discard_label: If true, the new ``TensorDataset`` holds only the
            inputs; otherwise it holds ``(inputs, labels)``.

    Returns:
        ``dataset`` itself, with ``dataset.dataset`` replaced.

    Raises:
        ValueError: If ``num_samples`` is not positive, or if no sample could
            be drawn from ``dataset``.
    """
    if num_samples <= 0:
        raise ValueError(f"num_samples must be a positive integer, got {num_samples}")

    # Batches are half the requested size, but never 0: with num_samples == 1
    # the plain ``num_samples // 2`` would be rejected by DataLoader.
    batch_size = max(1, num_samples // 2)

    xs, ys = [], []
    collected = 0
    # A ``for`` loop ends cleanly when the dataset is exhausted, instead of
    # leaking StopIteration from a bare ``next()`` call.
    for batch in DataLoader(dataset, batch_size, shuffle=True):
        cur_x, cur_y = batch[0], batch[1]
        xs.append(cur_x)
        ys.append(cur_y)
        collected += cur_x.size(0)
        if collected >= num_samples:
            break

    if not xs:
        raise ValueError("cannot draw samples from an empty dataset")

    # The last batch may overshoot the target, so trim to the requested size.
    x = torch.cat(xs)[:num_samples]
    y = torch.cat(ys)[:num_samples]

    new_dataset = TensorDataset(x) if discard_label else TensorDataset(x, y)
    dataset.dataset = new_dataset
    return dataset