File size: 1,711 Bytes
8918ac7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import random
from torch.utils.data import Sampler

class BatchSampler(Sampler):
    '''
    A `torch.utils.data.Sampler` which samples batches of dataset indices
    according to a maximum number of graph nodes.

    A batch is only grown while ``max(node count in batch) * len(batch)``
    stays within `max_batch_nodes`, i.e. the budget is measured as if every
    graph in the batch were padded to the largest one. Elements whose own
    node count exceeds `max_batch_nodes` are dropped and never yielded.

    :param node_counts: array of node counts in the dataset to sample from
    :param max_batch_nodes: the maximum number of (padded) nodes in any
                            batch, including batches of a single element
    :param shuffle: if `True`, elements are shuffled before batching, and the
                    batches are re-formed after every complete pass so each
                    epoch sees freshly shuffled batches
    '''
    def __init__(self, node_counts, max_batch_nodes=10000, shuffle=True):
        self.node_counts = node_counts
        # Keep only the elements that fit in a batch on their own.
        self.idx = [i for i, n in enumerate(node_counts) if n <= max_batch_nodes]
        self.shuffle = shuffle
        self.max_batch_nodes = max_batch_nodes
        self._form_batches()

    def _form_batches(self):
        '''
        Greedily split `self.idx` (shuffled in place first when
        `self.shuffle` is set) into consecutive batches, stored in
        `self.batches` as lists of indices.
        '''
        self.batches = []
        if self.shuffle:
            random.shuffle(self.idx)
        batch = []
        max_n_node = 0  # largest node count in the current batch
        # Single linear pass; the previous head-slicing (`idx[1:]`) was O(n^2).
        for i in self.idx:
            n_node = self.node_counts[i]
            new_max = max(n_node, max_n_node)
            # Close the current batch when adding element i would exceed the
            # padded-node budget. Requiring a non-empty batch guarantees that
            # every element lands in some batch, so no empty batch (and no
            # endless loop) can occur even if node_counts changed after init.
            if batch and new_max * (len(batch) + 1) > self.max_batch_nodes:
                self.batches.append(batch)
                batch = []
                new_max = n_node
            batch.append(i)
            max_n_node = new_max
        if batch:
            self.batches.append(batch)

    def __len__(self):
        '''Number of batches the next pass over the sampler will yield.'''
        if not self.batches:
            self._form_batches()
        return len(self.batches)

    def __iter__(self):
        '''Yield each batch (a list of dataset indices) of the current epoch.'''
        if not self.batches:
            self._form_batches()
        yield from self.batches
        # Reshuffle for the next epoch only after a complete pass, so that
        # `len()` always matches the number of batches the next pass yields.
        if self.shuffle:
            self._form_batches()