# EdgeTA/data/build_cl/scenario.py
import copy
from typing import List

from utils.common.log import logger

from ..dataloader import build_dataloader
from data import get_dataset, MergedDataset, Scenario as DAScenario


class _ABDatasetMetaInfo:
    """Lightweight metadata describing one target dataset/task in the CL sequence."""
    def __init__(self, name, classes, task_type, object_type, class_aliases, shift_type, ignore_classes, idx_map):
self.name = name
self.classes = classes
self.class_aliases = class_aliases
self.shift_type = shift_type
self.task_type = task_type
self.object_type = object_type
self.ignore_classes = ignore_classes
self.idx_map = idx_map
def __repr__(self) -> str:
return f'({self.name}, {self.classes}, {self.idx_map})'
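
# Note (assumed semantics, not stated in this file): `ignore_classes` lists the
# dataset classes excluded from a task, and `idx_map` remaps the dataset's local
# label ids onto the global class-incremental label space, e.g. {0: 100, 1: 101}
# for a 2-class target task that starts after 100 source classes.
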
class Scenario:
    def __init__(self, config, target_datasets_info: List[_ABDatasetMetaInfo],
                 num_classes: int, num_source_classes: int, data_dirs):
        self.config = config
        self.target_datasets_info = target_datasets_info
        self.num_classes = num_classes
        self.cur_task_index = 0
        self.num_source_classes = num_source_classes
        # labels of the current task start right after all previously seen classes
        self.cur_class_offset = num_source_classes
        self.data_dirs = data_dirs
        self.target_tasks_order = [i.name for i in self.target_datasets_info]
        # NOTE: despite the name, this counts classes across all target tasks, not tasks
        self.num_tasks_to_be_learn = sum(len(i.classes) for i in target_datasets_info)
        logger.info(f'[scenario build] # classes: {num_classes}, '
                    f'# tasks to be learnt: {len(target_datasets_info)}, '
                    f'# classes per task: {config["num_classes_per_task"]}')
    def to_json(self):
        config = copy.deepcopy(self.config)
        # the DA scenario object serializes itself; everything else is JSON-ready
        da_scenario: DAScenario = config['da_scenario']
        config['da_scenario'] = da_scenario.to_json()
        target_datasets_info = [str(i) for i in self.target_datasets_info]
        return dict(
            config=config, target_datasets_info=target_datasets_info,
            num_classes=self.num_classes
        )
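
    # Example serialized shape (illustrative values only):
    # {'config': {..., 'da_scenario': {...}},
    #  'target_datasets_info': ['(CIFAR100|task0, [100, 101, ...], {0: 100, ...})', ...],
    #  'num_classes': 115}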
def __str__(self):
return f'Scenario({self.to_json()})'
def get_cur_class_offset(self):
return self.cur_class_offset
def get_cur_num_class(self):
return len(self.target_datasets_info[self.cur_task_index].classes)
    def get_nc_per_task(self):
        # assumes every target task has the same number of classes as the first one
        return len(self.target_datasets_info[0].classes)
    def next_task(self):
        self.cur_class_offset += len(self.target_datasets_info[self.cur_task_index].classes)
        self.cur_task_index += 1
        logger.info(f'now, cur task: {self.cur_task_index}, cur class offset: {self.cur_class_offset}')
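
    # Example (illustrative): with 100 source classes and 5-class target tasks,
    # cur_class_offset advances 100 -> 105 -> 110 across next_task() calls, so each
    # task's labels occupy a disjoint slice of the global label space.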
    def get_cur_task_datasets(self):
        dataset_info = self.target_datasets_info[self.cur_task_index]
        dataset_name = dataset_info.name.split('|')[0]
        # train: only the current task; val/test: all tasks seen so far, merged
        # (each merged split loads from its own task's dataset)
        res = {
            'train': get_dataset(dataset_name=dataset_name,
                                 root_dir=self.data_dirs[dataset_name],
                                 split='train',
                                 transform=None,
                                 ignore_classes=dataset_info.ignore_classes,
                                 idx_map=dataset_info.idx_map),
            **{split: MergedDataset([get_dataset(dataset_name=di.name.split('|')[0],
                                                 root_dir=self.data_dirs[di.name.split('|')[0]],
                                                 split=split,
                                                 transform=None,
                                                 ignore_classes=di.ignore_classes,
                                                 idx_map=di.idx_map)
                                     for di in self.target_datasets_info[0: self.cur_task_index + 1]])
               for split in ['val', 'test']}
        }
        # small splits are repeated 5x so downstream loaders see enough samples
        if len(res['train']) < 1000:
            res['train'] = MergedDataset([res['train']] * 5)
            logger.info('aug train dataset')
        if len(res['val']) < 1000:
            res['val'] = MergedDataset(res['val'].datasets * 5)
            logger.info('aug val dataset')
        if len(res['test']) < 1000:
            res['test'] = MergedDataset(res['test'].datasets * 5)
            logger.info('aug test dataset')
        for k, v in res.items():
            logger.info(f'{k} dataset: {len(v)}')
return res
    def get_cur_task_train_datasets(self):
        dataset_info = self.target_datasets_info[self.cur_task_index]
        dataset_name = dataset_info.name.split('|')[0]
        return get_dataset(dataset_name=dataset_name,
                           root_dir=self.data_dirs[dataset_name],
                           split='train',
                           transform=None,
                           ignore_classes=dataset_info.ignore_classes,
                           idx_map=dataset_info.idx_map)
    def get_online_cur_task_samples_for_training(self, num_samples):
        # draw one batch of `num_samples` inputs from the current task's train split
        dataset = self.get_cur_task_datasets()['train']
        return next(iter(build_dataloader(dataset, num_samples, 0, True, None)))[0]
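

if __name__ == '__main__':
    # Minimal smoke test of the task/offset bookkeeping. Everything below is an
    # illustrative sketch: the dataset name, class ids, and data_dirs path are
    # made-up placeholders, not real EdgeTA settings, and no data is loaded.
    tasks = [
        _ABDatasetMetaInfo(
            name=f'CIFAR100|task{i}',
            classes=list(range(100 + i * 5, 100 + (i + 1) * 5)),
            task_type='Image Classification',
            object_type='generic',
            class_aliases={},
            shift_type=None,
            ignore_classes=[],
            idx_map={j: 100 + i * 5 + j for j in range(5)},
        )
        for i in range(3)
    ]
    scenario = Scenario(
        config={'num_classes_per_task': 5},
        target_datasets_info=tasks,
        num_classes=115,  # 100 source classes + 3 tasks x 5 classes
        num_source_classes=100,
        data_dirs={'CIFAR100': '/path/to/cifar100'},
    )
    for _ in tasks:
        # offsets advance 100 -> 105 -> 110; each task reports 5 classes
        print(scenario.get_cur_class_offset(), scenario.get_cur_num_class())
        scenario.next_task()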