File size: 1,481 Bytes
9e426da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import os.path
import random

import torch
from torch.utils.data import Dataset



class RandomNDataset(Dataset):
    """Dataset of Gaussian-noise latents paired with class labels.

    Item ``idx`` maps to class index ``idx // num_seeds`` and seed slot
    ``idx % num_seeds``. Each item is a float32 tensor of shape
    ``latent_shape`` drawn with ``torch.randn`` from a ``torch.Generator``
    seeded with that item's seed.

    Args:
        latent_shape: Shape of each returned latent tensor.
        num_classes: Number of classes (labels ``0 .. num_classes - 1``).
            Ignored when ``selected_classes`` is given.
        selected_classes: Optional explicit list of labels. When given, the
            dataset uses ``len(selected_classes)`` classes, returns the listed
            label values instead of class indices, and ``max_num_instances``
            is overridden to ``10 * len(selected_classes)``.
        seeds: Optional list of seeds. When given, every class uses exactly
            these seeds (so latents are reproducible) and
            ``max_num_instances`` is ignored. When ``None``, a fresh seed is
            drawn from the ``random`` module on every ``__getitem__`` call,
            so reading the same index twice yields different latents.
        max_num_instances: Requested dataset size when ``seeds`` is ``None``;
            rounded up to a multiple of the number of classes.
    """

    def __init__(
        self,
        latent_shape=(4, 64, 64),
        num_classes=1000,
        selected_classes: list = None,
        seeds=None,
        max_num_instances=50000,
    ):
        self.selected_classes = selected_classes
        if selected_classes is not None:
            num_classes = len(selected_classes)
            # NOTE(review): any caller-supplied max_num_instances is silently
            # replaced by 10 items per class here. Kept for compatibility;
            # confirm this is intended.
            max_num_instances = 10 * num_classes
        self.num_classes = num_classes
        self.seeds = seeds
        if seeds is not None:
            self.num_seeds = len(seeds)
            self.max_num_instances = self.num_seeds * num_classes
        else:
            # Ceil-divide so every class gets the same number of seed slots
            # and the total is at least the requested max_num_instances.
            self.num_seeds = (max_num_instances + num_classes - 1) // num_classes
            self.max_num_instances = self.num_seeds * num_classes

        self.latent_shape = latent_shape

    def __getitem__(self, idx):
        """Return ``(latent, label, filename)`` for item ``idx``.

        ``filename`` is the string ``"{label}_{seed}.png"``. Negative indices
        count from the end. An out-of-range index raises ``IndexError``; that
        is also what lets plain ``for``-iteration over the dataset terminate.
        """
        size = len(self)
        requested = idx
        if idx < 0:
            idx += size
        if not 0 <= idx < size:
            raise IndexError(
                f"index {requested} out of range for dataset of size {size}"
            )

        label = idx // self.num_seeds
        if self.selected_classes is not None:
            label = self.selected_classes[label]

        if self.seeds is not None:
            seed = self.seeds[idx % self.num_seeds]
        else:
            # Unseeded mode: a new seed on every call, so results are not
            # reproducible. Only touch the global RNG when actually needed.
            seed = random.randint(0, 1 << 31)

        # A plain str (no trailing comma, which would make it a 1-tuple).
        filename = f"{label}_{seed}.png"
        generator = torch.Generator().manual_seed(seed)
        latent = torch.randn(self.latent_shape, generator=generator, dtype=torch.float32)
        return latent, label, filename

    def __len__(self):
        """Return the total number of items (num_seeds * number of classes)."""
        return self.max_num_instances