nntoolbox.utils.sampler module

class nntoolbox.utils.sampler.BatchSampler(sampler: torch.utils.data.sampler.Sampler[int], batch_size: int, drop_last: bool)
Bases: torch.utils.data.sampler.Sampler[List[int]]

Wraps another sampler to yield a mini-batch of indices.
- Args:
  sampler (Sampler or Iterable): Base sampler. Can be any iterable object.
  batch_size (int): Size of mini-batch.
  drop_last (bool): If True, the sampler will drop the last batch if its size would be less than batch_size.
- Example:
  >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
  [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
  >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
  [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
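The grouping behaviour in the example can be sketched in plain Python. This is a minimal reimplementation for illustration only, not the nntoolbox source: `batch_indices` and `num_batches` are hypothetical helper names, and any iterable of ints stands in for the base sampler so the sketch runs without PyTorch.

```python
from typing import Iterable, Iterator, List


def batch_indices(sampler: Iterable[int], batch_size: int, drop_last: bool) -> Iterator[List[int]]:
    """Group the indices produced by `sampler` into lists of `batch_size`."""
    batch: List[int] = []
    for idx in sampler:
        batch.append(idx)
        if len(batch) == batch_size:
            yield batch
            batch = []
    # Whatever is left over forms a short final batch; keep it unless drop_last.
    if batch and not drop_last:
        yield batch


def num_batches(num_indices: int, batch_size: int, drop_last: bool) -> int:
    """Number of batches produced: floor division with drop_last, ceiling otherwise."""
    if drop_last:
        return num_indices // batch_size
    return (num_indices + batch_size - 1) // batch_size


print(list(batch_indices(range(10), batch_size=3, drop_last=False)))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
print(list(batch_indices(range(10), batch_size=3, drop_last=True)))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
```

In PyTorch, a batch sampler like this is normally handed to `DataLoader` through its `batch_sampler` argument; in that case `batch_size`, `shuffle`, `sampler`, and `drop_last` must be left at their defaults, since `DataLoader` treats them as mutually exclusive with `batch_sampler`.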

class nntoolbox.utils.sampler.BatchStratifiedSampler(data_source: torch.utils.data.dataset.Dataset, n_sample_per_class: int, n_class_per_batch: int, drop_last: bool = False)
Bases: torch.utils.data.sampler.Sampler[List[int]]

Ensures that each class in a batch has the same number of examples.
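One way to get equal per-class counts is sketched below. It is an illustration under stated assumptions, not the nntoolbox implementation: the hypothetical `stratified_batches` takes a plain sequence of class labels instead of a `Dataset` (how the real class reads labels is not documented here), chooses `n_class_per_batch` classes at random for each batch, discards each class's leftover examples that do not fill a whole group of `n_sample_per_class`, and adds a `seed` parameter that is not in the original signature.

```python
import random
from collections import Counter, defaultdict
from typing import Dict, Hashable, Iterator, List, Sequence


def stratified_batches(
    labels: Sequence[Hashable],
    n_sample_per_class: int,
    n_class_per_batch: int,
    drop_last: bool = False,
    seed: int = 0,
) -> Iterator[List[int]]:
    """Yield index batches in which every class present contributes
    exactly n_sample_per_class examples."""
    rng = random.Random(seed)

    # Bucket dataset indices by class label.
    by_class: Dict[Hashable, List[int]] = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    # Shuffle each class and cut it into full chunks of n_sample_per_class;
    # a trailing partial chunk is discarded so per-class counts stay equal.
    pools: Dict[Hashable, List[List[int]]] = {}
    for label, idxs in by_class.items():
        rng.shuffle(idxs)
        n_full = len(idxs) // n_sample_per_class
        pools[label] = [
            idxs[i * n_sample_per_class:(i + 1) * n_sample_per_class]
            for i in range(n_full)
        ]

    while True:
        available = [label for label, chunks in pools.items() if chunks]
        if not available:
            break
        # Fewer classes remain than a full batch needs.
        if len(available) < n_class_per_batch and drop_last:
            break
        chosen = rng.sample(available, min(n_class_per_batch, len(available)))
        # Take one chunk from each chosen class.
        yield [i for label in chosen for i in pools[label].pop()]


labels = [0] * 5 + [1] * 4 + [2] * 6
for batch in stratified_batches(labels, n_sample_per_class=2, n_class_per_batch=2):
    print(batch, Counter(labels[i] for i in batch))
```

With `drop_last=False` the tail batches may contain fewer than `n_class_per_batch` classes, but every class that appears still contributes exactly `n_sample_per_class` examples.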

class nntoolbox.utils.sampler.MultiRandomSampler(data_source, batch_size: int, replacement=False)
Bases: Generic[torch.utils.data.sampler.T_co]

Samples elements randomly to form multiple batches.
- Arguments:
  data_source (Dataset): dataset to sample from.
  replacement (bool): samples are drawn with replacement if True, default=False.
  num_samples (int): number of samples to draw, default=len(dataset). This argument is supposed to be specified only when replacement is True.
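The description above does not say exactly how the batches are formed. One plausible reading, assumed here rather than taken from the nntoolbox source, is a single random pass of `len(dataset)` draws (matching the `num_samples` default in the argument list) cut into consecutive batches of `batch_size`, with `replacement` switching between a random permutation and independent uniform draws. `random_batches` is a hypothetical name, `n` stands in for `len(data_source)`, and `seed` is an added parameter.

```python
import random
from typing import Iterator, List


def random_batches(
    n: int, batch_size: int, replacement: bool = False, seed: int = 0
) -> Iterator[List[int]]:
    """Draw n random indices from range(n) and split them into consecutive batches."""
    rng = random.Random(seed)
    if replacement:
        # Independent uniform draws: an index may appear more than once.
        order = [rng.randrange(n) for _ in range(n)]
    else:
        # A single random permutation: each index appears exactly once.
        order = list(range(n))
        rng.shuffle(order)
    for start in range(0, n, batch_size):
        yield order[start:start + batch_size]


for batch in random_batches(10, batch_size=3):
    print(batch)
```

Without replacement, every index appears in exactly one batch per pass; with replacement, batches keep the same sizes but may repeat or skip indices.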