File size: 285 Bytes
9e426da
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
from typing import Callable, Optional

from torchvision.datasets import CelebA


class LocalDataset(CelebA):
    """Thin wrapper around torchvision's ``CelebA`` dataset.

    Defaults to the ``"train"`` split, exactly as before. The split and the
    usual ``CelebA`` options can now be overridden by callers.
    """

    def __init__(
        self,
        root: str,
        split: str = "train",
        *,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        """Create the dataset.

        Args:
            root: Root directory handed to ``CelebA``.
            split: Which CelebA split to load. Defaults to ``"train"`` so
                existing ``LocalDataset(root)`` calls keep their behavior.
            transform: Optional callable applied to each sample by ``CelebA``.
            target_transform: Optional callable applied to each target by
                ``CelebA``.
            download: Forwarded to ``CelebA``. ``False`` matches the previous
                implicit behavior.
        """
        super().__init__(
            root,
            split=split,
            transform=transform,
            target_transform=target_transform,
            download=download,
        )

    def __getitem__(self, idx):
        """Return the sample at ``idx``, exactly as ``CelebA`` produces it.

        Kept as an explicit override so per-sample post-processing can be
        added here later. It currently adds nothing beyond the parent.
        """
        return super().__getitem__(idx)