Spaces:
Runtime error
Runtime error
import random | |
from torch.utils.data import Sampler | |
class BatchSampler(Sampler):
    '''
    A `torch.utils.data.Sampler` which samples batches according to a
    maximum number of graph nodes.

    A batch is budgeted by its padded size: the largest node count in the
    batch multiplied by the number of elements in the batch must not exceed
    `max_batch_nodes`. Elements are taken greedily in index order (after the
    optional shuffle), and a new batch is started as soon as adding the next
    element would exceed the budget.

    Elements whose own node count exceeds `max_batch_nodes` are dropped and
    never appear in any batch.

    Batches are formed once at construction time and reused on every
    iteration; they are only re-formed if the stored batch list is empty.
    With `shuffle=True` the shuffle therefore happens when batches are
    formed, not on every pass over the sampler.

    :param node_counts: array of node counts in the dataset to sample from
    :param max_batch_nodes: the maximum number of nodes in any batch,
                            including batches of a single element
    :param shuffle: if `True`, batches in shuffled order
    '''

    def __init__(self, node_counts, max_batch_nodes=10000, shuffle=True):
        self.node_counts = node_counts
        # Keep only the elements that fit in a batch on their own.
        self.idx = [i for i in range(len(node_counts))
                    if node_counts[i] <= max_batch_nodes]
        self.shuffle = shuffle
        self.max_batch_nodes = max_batch_nodes
        self._form_batches()

    def _form_batches(self):
        '''
        (Re)build `self.batches` as a list of lists of dataset indices.

        When `self.shuffle` is set, `self.idx` is shuffled in place first.
        '''
        self.batches = []
        if self.shuffle:
            random.shuffle(self.idx)

        idx = self.idx
        node_counts = self.node_counts
        limit = self.max_batch_nodes

        # Walk the index list with a cursor instead of repeatedly slicing
        # off the head (`idx[1:]`), which copies the remainder every step.
        pos, total = 0, len(idx)
        while pos < total:
            batch = []
            max_n_node = 0
            while pos < total:
                candidate = idx[pos]
                # Padded size of the batch if `candidate` were added.
                widest = max(node_counts[candidate], max_n_node)
                if widest * (len(batch) + 1) > limit:
                    break
                max_n_node = widest
                batch.append(candidate)
                pos += 1
            # For a non-negative limit, the filter in __init__ guarantees
            # that a lone element always fits, so every pass of the outer
            # loop consumes at least one index.
            self.batches.append(batch)

    def __len__(self):
        '''Return the number of batches, forming them first if none exist.'''
        if not self.batches:
            self._form_batches()
        return len(self.batches)

    def __iter__(self):
        '''Yield each batch (a list of dataset indices) in stored order.'''
        if not self.batches:
            self._form_batches()
        for batch in self.batches:
            yield batch