File size: 615 Bytes
9ff4511
 
 
 
 
 
 
 
 
 
 
 
 
e2b0b28
9ff4511
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from typing import Tuple
from torch.utils.data import Dataset

from datasets.valentini import Valentini
from transforms import Transform

# Registry of selectable datasets: each key is a dataset name that may appear
# as the key under cfg['dataset'], and each value is the dataset class that
# get_datasets instantiates for that name.
DATASETS_POOL = dict(
    valentini=Valentini,
)


def get_datasets(cfg) -> Tuple[Dataset, Dataset]:
    """Build the train and validation datasets described by ``cfg``.

    ``cfg['dataset']`` maps a dataset name (a key of ``DATASETS_POOL``) to the
    keyword arguments for that dataset class. Only the first entry, in
    iteration order, is used. ``cfg['dataloader']`` supplies extra keyword
    arguments for ``Transform``.

    Args:
        cfg: Configuration mapping with ``'dataset'`` and ``'dataloader'``
            entries. The dataset parameters must contain ``'sample_rate'``.

    Returns:
        A ``(train_dataset, valid_dataset)`` pair. Both are built from the
        same dataset parameters and share one ``Transform`` instance; they
        differ only in the ``valid`` flag passed to the dataset class.

    Raises:
        ValueError: If ``cfg['dataset']`` has no entries.
        KeyError: If the dataset name is not registered in ``DATASETS_POOL``,
            or if a required config key (e.g. ``'sample_rate'``) is missing.
    """
    # Take the first (name, params) pair without materializing every item.
    try:
        name, dataset_params = next(iter(cfg['dataset'].items()))
    except StopIteration:
        raise ValueError(
            "cfg['dataset'] is empty; expected one dataset entry"
        ) from None

    # Resolve the dataset class once; report the valid choices on a typo
    # instead of a bare KeyError carrying only the bad name.
    try:
        dataset_cls = DATASETS_POOL[name]
    except KeyError:
        raise KeyError(
            f"Unknown dataset {name!r}; available: {sorted(DATASETS_POOL)}"
        ) from None

    # NOTE(review): 'sample_rate' presumably is the dataset's native rate that
    # Transform converts from — confirm against Transform's implementation.
    transform = Transform(input_sample_rate=dataset_params['sample_rate'], **cfg['dataloader'])
    train_dataset = dataset_cls(valid=False, transform=transform, **dataset_params)
    valid_dataset = dataset_cls(valid=True, transform=transform, **dataset_params)
    return train_dataset, valid_dataset