nntoolbox.utils.sampler module
class nntoolbox.utils.sampler.BatchSampler(sampler: torch.utils.data.sampler.Sampler[int], batch_size: int, drop_last: bool)

Bases: torch.utils.data.sampler.Sampler[List[int]]

Wraps another sampler to yield a mini-batch of indices.
Args:
    sampler (Sampler or Iterable): Base sampler. Can be any iterable object.
    batch_size (int): Size of mini-batch.
    drop_last (bool): If True, the sampler will drop the last batch if its size would be less than batch_size.

Example:
    >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
    >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
    [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
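The batching rule documented above can be sketched in plain Python, independently of PyTorch. This is a minimal illustration of the drop_last semantics, assuming the base sampler is any iterable of integer indices; batch_indices is a hypothetical helper, not part of nntoolbox or torch.

```python
from typing import Iterable, Iterator, List


def batch_indices(sampler: Iterable[int], batch_size: int, drop_last: bool) -> Iterator[List[int]]:
    """Hypothetical helper: group indices from `sampler` into lists of `batch_size`."""
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    batch: List[int] = []
    for idx in sampler:
        batch.append(idx)
        if len(batch) == batch_size:
            # A full batch is ready: emit it and start collecting the next one.
            yield batch
            batch = []
    # Any leftover indices form a partial batch; keep it unless drop_last is set.
    if batch and not drop_last:
        yield batch


print(list(batch_indices(range(10), batch_size=3, drop_last=False)))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
print(list(batch_indices(range(10), batch_size=3, drop_last=True)))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
```

With drop_last=False this yields ceil(n / batch_size) batches for n indices, and with drop_last=True it yields floor(n / batch_size); for n = 10 and batch_size = 3 that is 4 and 3 batches respectively, matching the example above.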
class nntoolbox.utils.sampler.BatchStratifiedSampler(data_source: torch.utils.data.dataset.Dataset, n_sample_per_class: int, n_class_per_batch: int, drop_last: bool = False)

Bases: torch.utils.data.sampler.Sampler[List[int]]

Ensures that each class in a batch has the same number of examples.
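One way to meet that guarantee is sketched below. The docstring does not say how labels are read from data_source or how classes are chosen, so this sketch assumes the labels are given as a plain list (labels[i] is the class of example i), picks n_class_per_batch distinct classes uniformly at random for each batch, and draws n_sample_per_class examples without replacement from each. stratified_batches and its n_batches and seed parameters are hypothetical, and drop_last is not modelled.

```python
import random
from collections import defaultdict
from typing import Dict, Iterator, List, Optional, Sequence


def stratified_batches(
    labels: Sequence[int],
    n_sample_per_class: int,
    n_class_per_batch: int,
    n_batches: int,
    seed: Optional[int] = None,
) -> Iterator[List[int]]:
    """Hypothetical sketch: each batch holds n_class_per_batch classes x n_sample_per_class examples."""
    rng = random.Random(seed)
    # Group example indices by their class label.
    by_class: Dict[int, List[int]] = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    # Only classes with at least n_sample_per_class examples can fill their share of a batch.
    eligible = [c for c, idxs in by_class.items() if len(idxs) >= n_sample_per_class]
    if len(eligible) < n_class_per_batch:
        raise ValueError("not enough classes with n_sample_per_class examples")
    for _ in range(n_batches):
        batch: List[int] = []
        # Pick distinct classes, then the same number of examples from each one.
        for c in rng.sample(eligible, n_class_per_batch):
            batch.extend(rng.sample(by_class[c], n_sample_per_class))
        yield batch


labels = [0, 0, 0, 1, 1, 1, 2, 2, 2, 2]
for batch in stratified_batches(labels, n_sample_per_class=2, n_class_per_batch=2, n_batches=3, seed=0):
    print(batch, [labels[i] for i in batch])
```

Under these assumptions every batch contains n_class_per_batch * n_sample_per_class indices, here 2 * 2 = 4.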
class nntoolbox.utils.sampler.MultiRandomSampler(data_source, batch_size: int, replacement=False)

Bases: Generic[torch.utils.data.sampler.T_co]

Samples elements randomly to form multiple batches.
Arguments:
    data_source (Dataset): Dataset to sample from.
    batch_size (int): Size of each batch.
    replacement (bool): If True, samples are drawn with replacement. Default: False.
    num_samples (int): Number of samples to draw. Default: len(dataset). This argument is supposed to be specified only when replacement is True.
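The documented arguments suggest behaviour along the lines of PyTorch's RandomSampler (a shuffled permutation of all indices, or num_samples independent draws when replacement is True), with the drawn indices then grouped into batches. The sketch below is written under exactly that assumption; how MultiRandomSampler actually uses batch_size is not stated here, and random_batches is a hypothetical name.

```python
import random
from typing import Iterator, List, Optional


def random_batches(
    n: int,
    batch_size: int,
    replacement: bool = False,
    num_samples: Optional[int] = None,
    seed: Optional[int] = None,
) -> Iterator[List[int]]:
    """Hypothetical sketch: draw indices in [0, n) at random, then cut them into batches."""
    if num_samples is not None and not replacement:
        # Mirrors the docstring: num_samples is only meant for sampling with replacement.
        raise ValueError("num_samples should only be given when replacement=True")
    rng = random.Random(seed)
    if replacement:
        # With replacement: draw num_samples indices (default n) independently.
        k = n if num_samples is None else num_samples
        indices = [rng.randrange(n) for _ in range(k)]
    else:
        # Without replacement: every index appears exactly once, in shuffled order.
        indices = list(range(n))
        rng.shuffle(indices)
    # Split the sampled sequence into consecutive batches; the last one may be shorter.
    for start in range(0, len(indices), batch_size):
        yield indices[start:start + batch_size]


print(list(random_batches(10, batch_size=4, seed=0)))
```

Without replacement each index appears exactly once per pass; with replacement some indices may repeat and others may never be drawn.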