# Source code for nntoolbox.utils.sampler

from torch.utils.data import Sampler, WeightedRandomSampler, BatchSampler, Dataset
import torch
import numpy as np
from .utils import compute_num_batch


# NOTE: ``BatchSampler`` is re-exported from ``torch.utils.data``;
# ``BalancedSampler`` is defined in this module and listed here so that
# ``from nntoolbox.utils.sampler import *`` exposes it like its siblings.
__all__ = ['MultiRandomSampler', 'BalancedSampler', 'BatchSampler', 'BatchStratifiedSampler']


class MultiRandomSampler(Sampler):
    """Samples elements randomly, producing a permutation (or, with
    replacement, a same-sized random draw) of the whole dataset.

    :param data_source: dataset to sample from
    :param batch_size: intended batch size; must not exceed ``len(data_source)``
    :param replacement: if ``True``, indices are drawn with replacement,
        default=``False``
    """

    def __init__(self, data_source, batch_size: int, replacement: bool=False):
        self.data_source = data_source
        self.replacement = replacement
        self.batch_size = batch_size
        # A batch cannot be larger than the dataset itself.
        assert self.batch_size <= len(data_source)
        if not isinstance(self.replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(self.replacement))

    def __iter__(self):
        """Yield dataset indices: with replacement, ``n`` uniform draws;
        otherwise a random permutation of ``range(n)``."""
        n = len(self.data_source)
        if self.replacement:
            indices = torch.randint(high=n, size=(n,), dtype=torch.int64).tolist()
        else:
            indices = torch.randperm(n).tolist()
        return iter(indices)

    def __len__(self):
        """Number of indices yielded per epoch (the full dataset size)."""
        return len(self.data_source)
class BalancedSampler(WeightedRandomSampler):
    """Sample each data point with weight inversely proportional to the
    number of points in its class, so minority classes are oversampled.
    """

    def __init__(self, data_source: Dataset, num_samples: int, replacement: bool=True):
        """
        :param data_source: dataset yielding ``(input, label)`` pairs with
            non-negative integer labels
        :param num_samples: number of samples to draw per epoch
        :param replacement: if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row
        """
        all_labels = [labels for _, labels in data_source]
        counts = np.bincount(all_labels)
        # Compute 1/count only where count > 0: class ids missing from the
        # data would otherwise trigger a divide-by-zero warning. Missing
        # classes keep weight 0 and are never indexed by an observed label.
        class_weights = np.zeros(len(counts), dtype=np.float64)
        present = counts > 0
        class_weights[present] = 1.0 / counts[present]
        weights = [class_weights[label] for label in all_labels]
        super(BalancedSampler, self).__init__(weights, num_samples, replacement)
class BatchStratifiedSampler(BatchSampler):
    """Ensure that each class sampled for a batch contributes the same
    number of examples."""

    def __init__(self, data_source: Dataset, n_sample_per_class: int,
                 n_class_per_batch: int, drop_last: bool=False):
        """
        :param data_source: dataset yielding ``(input, label)`` pairs with
            non-negative integer labels
        :param n_sample_per_class: number of examples drawn per class in a batch
        :param n_class_per_batch: number of classes to sample for each batch
        :param drop_last: if ``True``, drop the last (incomplete) batch.
        """
        all_labels = []
        for inputs, labels in data_source:
            all_labels.append(labels)
        self.labels = np.array(all_labels)
        self.label_counts = np.bincount(all_labels)
        # Only class ids that actually occur can be sampled; picking an
        # absent id would make np.random.choice fail on an empty index set.
        self._present_classes = np.flatnonzero(self.label_counts)
        assert len(self._present_classes) >= n_class_per_batch
        # BUGFIX: dataset size is the number of labels, not their sum
        # (np.sum(all_labels) counted label *values*, not examples).
        self.n_data = len(all_labels)
        self.batch_size = n_class_per_batch * n_sample_per_class
        self.n_sample_per_class = n_sample_per_class
        self.n_class_per_batch, self.drop_last = n_class_per_batch, drop_last

    def __iter__(self):
        """Yield index arrays of shape (n_class_per_batch * n_sample_per_class,)."""
        for _ in range(len(self)):
            classes = np.random.choice(self._present_classes, self.n_class_per_batch, replace=False)
            batch_idx = []
            for c in classes:
                c_idx = np.where(self.labels == c)[0]
                # Within a class, examples are drawn with replacement so a
                # class with fewer than n_sample_per_class examples still works.
                batch_idx.append(np.random.choice(c_idx, size=self.n_sample_per_class, replace=True))
            yield np.concatenate(batch_idx, axis=0)

    def __len__(self):
        """Number of batches per epoch."""
        if self.drop_last:
            return self.n_data // self.batch_size
        return compute_num_batch(self.n_data, self.batch_size)