|
"""This is a slightly modified version of https://github.com/ufoym/imbalanced-dataset-sampler.""" |
|
|
|
from typing import Callable, List, Optional |
|
|
|
import pandas as pd |
|
import torch |
|
import torch.utils.data |
|
|
|
|
|
class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):
    """Samples elements randomly from a given list of indices for imbalanced dataset.

    Every index is drawn (with replacement) with a probability proportional to
    the inverse frequency of its label, so that each class is expected to be
    sampled equally often regardless of how many elements it has.

    Arguments:
        dataset: the dataset to sample from; it is only consulted for its length
            (when ``indices`` is omitted) and for its labels (when ``labels`` is
            omitted)
        labels: optional labels, one per entry of ``indices`` and in the same order
        indices: a list of indices (defaults to ``range(len(dataset))``)
        num_samples: number of samples to draw (defaults to ``len(indices)``)
        callback_get_label: a callback-like function which takes one argument -
            the dataset - and returns its labels

    Raises:
        ValueError: if the number of labels differs from the number of indices.
        NotImplementedError: if the labels cannot be inferred from ``dataset``.
    """

    def __init__(
        self,
        dataset,
        labels: Optional[List] = None,
        indices: Optional[List] = None,
        num_samples: Optional[int] = None,
        callback_get_label: Optional[Callable] = None,
    ):
        # If indices is not provided, all elements of the dataset are considered.
        self.indices = list(range(len(dataset))) if indices is None else indices

        # Custom label getter; must be set before _get_labels() is called.
        self.callback_get_label = callback_get_label

        # If num_samples is not provided, draw len(indices) samples per iteration.
        self.num_samples = len(self.indices) if num_samples is None else num_samples

        raw_labels = self._get_labels(dataset) if labels is None else labels
        label_series = pd.Series(self._to_list(raw_labels))
        if len(label_series) != len(self.indices):
            raise ValueError(
                f"Got {len(label_series)} labels for {len(self.indices)} indices; "
                "exactly one label per index is required."
            )

        # Weight of a sample = 1 / (number of samples sharing its label).
        # The weights are kept in the same positional order as self.indices
        # (no re-sorting), because __iter__ maps a sampled position back to
        # self.indices[position].
        label_to_count = label_series.value_counts()
        weights = 1.0 / label_series.map(label_to_count)

        self.weights = torch.tensor(weights.tolist(), dtype=torch.double)

    @staticmethod
    def _to_list(labels):
        """Convert tensors / arrays / Series / iterables to a plain list of labels."""
        # .tolist() turns tensor or array elements into Python scalars, so equal
        # labels are counted by value rather than by object identity.
        return labels.tolist() if hasattr(labels, "tolist") else list(labels)

    def _get_labels(self, dataset):
        """Return the labels of ``dataset``, one per element, in dataset order."""
        if self.callback_get_label is not None:
            return self.callback_get_label(dataset)
        elif isinstance(dataset, torch.utils.data.TensorDataset):
            return dataset.tensors[1]
        elif isinstance(dataset, torch.utils.data.Subset):
            parent = dataset.dataset
            if hasattr(parent, "imgs"):
                # ImageFolder-style parent: imgs is a list of (path, label) pairs.
                return [parent.imgs[i][1] for i in dataset.indices]
            # Otherwise resolve the parent's labels and keep only the selected
            # entries, in the order the subset exposes them.
            parent_labels = self._to_list(self._get_labels(parent))
            return [parent_labels[i] for i in dataset.indices]
        elif isinstance(dataset, torch.utils.data.Dataset):
            return dataset.get_labels()
        else:
            raise NotImplementedError(
                f"Cannot infer labels from {type(dataset).__name__}; "
                "pass `labels` or `callback_get_label`."
            )

    def __iter__(self):
        # Draw positions according to the class-balancing weights, then map
        # each position back to the corresponding dataset index.
        sampled_positions = torch.multinomial(
            self.weights, self.num_samples, replacement=True
        )
        return (self.indices[i] for i in sampled_positions.tolist())

    def __len__(self):
        return self.num_samples
|
|