import random

from torch.utils.data import Sampler


class BatchSampler(Sampler):
    '''
    A `torch.utils.data.Sampler` which samples batches according to a
    maximum number of graph nodes.

    :param node_counts: array of node counts in the dataset to sample from
    :param max_batch_nodes: the maximum number of nodes in any batch,
                            including batches of a single element
    :param shuffle: if `True`, yield the batches in shuffled order
    '''
    def __init__(self, node_counts, max_batch_nodes=10000, shuffle=True):
        self.node_counts = node_counts
        # Keep only the graphs that fit within the node budget on their own.
        self.idx = [i for i in range(len(node_counts))
                    if node_counts[i] <= max_batch_nodes]
        self.shuffle = shuffle
        self.max_batch_nodes = max_batch_nodes
        self._form_batches()

    def _form_batches(self):
        self.batches = []
        if self.shuffle:
            random.shuffle(self.idx)
        idx = self.idx
        while idx:
            batch = []
            max_n_node = 0
            while idx:
                # Estimate the batch cost as (largest node count, including
                # the candidate graph) * (number of graphs after adding it),
                # and stop before the cost would exceed max_batch_nodes.
                if max(self.node_counts[idx[0]], max_n_node) * (len(batch) + 1) > self.max_batch_nodes:
                    break
                next_idx, idx = idx[0], idx[1:]
                current_n_node = self.node_counts[next_idx]
                if current_n_node > max_n_node:
                    max_n_node = current_n_node
                batch.append(next_idx)
            self.batches.append(batch)

    def __len__(self):
        if not self.batches:
            self._form_batches()
        return len(self.batches)

    def __iter__(self):
        if not self.batches:
            self._form_batches()
        for batch in self.batches:
            yield batch
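

# Minimal usage sketch (an illustration, not part of the original file): pass
# the sampler to a DataLoader via its `batch_sampler` argument. The dataset,
# feature sizes, and collate function below are hypothetical placeholders.
if __name__ == '__main__':
    import torch
    from torch.utils.data import DataLoader

    # Hypothetical dataset of variable-size graphs, each a dict holding an
    # 'x' node-feature tensor; replace with a real graph dataset.
    my_dataset = [{'x': torch.randn(n, 16)} for n in [50, 300, 1200, 80, 640]]
    node_counts = [d['x'].shape[0] for d in my_dataset]

    # The 1200-node graph exceeds max_batch_nodes and is skipped entirely.
    sampler = BatchSampler(node_counts, max_batch_nodes=1000, shuffle=True)

    # `batch_sampler` yields lists of dataset indices, one list per batch;
    # this collate_fn simply returns the list of graph dicts unchanged.
    loader = DataLoader(my_dataset, batch_sampler=sampler,
                        collate_fn=lambda graphs: graphs)
    for batch in loader:
        print(len(batch), 'graphs,',
              sum(g['x'].shape[0] for g in batch), 'nodes')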