nntoolbox.utils.sampler module

class nntoolbox.utils.sampler.BatchSampler(sampler: torch.utils.data.sampler.Sampler[int], batch_size: int, drop_last: bool)[source]

Bases: torch.utils.data.sampler.Sampler[List[int]]

Wraps another sampler to yield a mini-batch of indices.

Args:

sampler (Sampler or Iterable): Base sampler. Can be any iterable object.

batch_size (int): Size of mini-batch.

drop_last (bool): If True, the sampler will drop the last batch if its size would be less than batch_size.

Example:
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
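
The batching rule shown in the example can be sketched in plain Python, without torch. This is only an illustration of the documented behaviour; `batch_indices` is a hypothetical helper name and is not part of nntoolbox or PyTorch.

```python
from typing import Iterable, Iterator, List


def batch_indices(sampler: Iterable[int], batch_size: int,
                  drop_last: bool) -> Iterator[List[int]]:
    """Group indices from `sampler` into lists of length `batch_size`.

    Hypothetical helper: mirrors the documented semantics, not the
    library's actual implementation.
    """
    batch: List[int] = []
    for idx in sampler:
        batch.append(idx)
        if len(batch) == batch_size:
            # A full batch is ready; hand it out and start a new one.
            yield batch
            batch = []
    # A shorter final batch is kept only when drop_last is False.
    if batch and not drop_last:
        yield batch
```

Calling `list(batch_indices(range(10), 3, False))` reproduces the first result in the example; passing `drop_last=True` omits the trailing `[9]`.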
class nntoolbox.utils.sampler.BatchStratifiedSampler(data_source: torch.utils.data.dataset.Dataset, n_sample_per_class: int, n_class_per_batch: int, drop_last: bool = False)[source]

Bases: torch.utils.data.sampler.Sampler[List[int]]

Ensures that each class in a batch has the same number of examples.
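
One way to build such class-balanced batches is sketched below in plain Python. It assumes that each batch holds `n_class_per_batch` distinct classes with `n_sample_per_class` indices from each, which is inferred from the parameter names only. The function `stratified_batches`, its `labels` argument, and its `seed` argument are hypothetical, and the sketch always discards classes that can no longer fill a full share (roughly the `drop_last=True` case).

```python
import random
from collections import defaultdict
from typing import Dict, Hashable, Iterator, List, Sequence


def stratified_batches(labels: Sequence[Hashable], n_sample_per_class: int,
                       n_class_per_batch: int, seed: int = 0) -> Iterator[List[int]]:
    """Yield index batches with an equal number of examples per class.

    Hypothetical sketch; not the nntoolbox implementation.
    """
    rng = random.Random(seed)

    # Group dataset indices by their class label.
    by_class: Dict[Hashable, List[int]] = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    # Shuffle each class pool once so that popping gives random indices.
    pools = {c: rng.sample(idxs, len(idxs)) for c, idxs in by_class.items()}

    while True:
        # Only classes that can still supply a full share are eligible.
        eligible = [c for c, pool in pools.items()
                    if len(pool) >= n_sample_per_class]
        if len(eligible) < n_class_per_batch:
            return  # Not enough classes left for a balanced batch.
        batch: List[int] = []
        for c in rng.sample(eligible, n_class_per_batch):
            for _ in range(n_sample_per_class):
                batch.append(pools[c].pop())
        yield batch
```

Each index is used at most once, so the number of batches depends on how many examples each class has.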

class nntoolbox.utils.sampler.MultiRandomSampler(data_source, batch_size: int, replacement=False)[source]

Bases: Generic[torch.utils.data.sampler.T_co]

Samples elements randomly to form multiple batches.

Arguments:

data_source (Dataset): dataset to sample from

replacement (bool): samples are drawn with replacement if True, default=``False``

num_samples (int): number of samples to draw, default=`len(dataset)`. This argument is supposed to be specified only when replacement is True.
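
The interaction between `replacement` and `num_samples` described above can be illustrated with a small plain-Python sketch. The function `random_batches` and its `n_items` and `seed` arguments are hypothetical, and the way indices are chunked into batches of `batch_size` is an assumption rather than the documented behaviour of `MultiRandomSampler`.

```python
import random
from typing import Iterator, List, Optional


def random_batches(n_items: int, batch_size: int, replacement: bool = False,
                   num_samples: Optional[int] = None,
                   seed: Optional[int] = None) -> Iterator[List[int]]:
    """Yield batches of randomly drawn indices in range(n_items).

    Hypothetical sketch; not the nntoolbox implementation.
    """
    if num_samples is None:
        num_samples = n_items  # Default: one pass over the dataset.
    elif not replacement:
        # Matches the note that num_samples only applies with replacement.
        raise ValueError("num_samples requires replacement=True")

    rng = random.Random(seed)
    if replacement:
        # Independent draws: an index may appear more than once.
        indices = [rng.randrange(n_items) for _ in range(num_samples)]
    else:
        # A random permutation: every index appears exactly once.
        indices = rng.sample(range(n_items), n_items)

    # Assumption: consecutive chunks of the drawn indices form the batches.
    for start in range(0, len(indices), batch_size):
        yield indices[start:start + batch_size]
```

Without replacement the batches together cover every index exactly once; with replacement the total number of drawn indices equals `num_samples`.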