Source code for nntoolbox.callbacks.bs_scheduler
from .callbacks import Callback
from typing import Dict, Any, Callable
from torch.utils.data import DataLoader


__all__ = ['BatchSizeScheduler']


class BatchSizeScheduler(Callback):
    """Adjust the batch size of the training DataLoader according to a schedule
    function of the epoch or iteration count."""

    def __init__(self, train_data: DataLoader, bs_schedule_fn: Callable[[int], int], timescale: str="iter"):
        assert timescale == "iter" or timescale == "epoch"
        self.timescale = timescale
        self._train_data = train_data
        self._bs_schedule_fn = bs_schedule_fn
    def on_epoch_end(self, logs: Dict[str, Any]) -> bool:
        if self.timescale == "epoch":
            new_bs = self._bs_schedule_fn(logs["epoch"])
            # Recent PyTorch versions forbid reassigning DataLoader.batch_size
            # after construction, so update the underlying batch_sampler instead.
            self._train_data.batch_sampler.batch_size = new_bs
        return False
    def on_batch_end(self, logs: Dict[str, Any]):
        if self.timescale == "iter":
            new_bs = self._bs_schedule_fn(logs["iter_cnt"])
            # BatchSampler reads batch_size on every step, so this takes effect
            # for batches drawn after the update.
            self._train_data.batch_sampler.batch_size = new_bs
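

# --- Usage sketch (not part of the original module): a minimal, hedged example
# of how BatchSizeScheduler might be wired up. The training loop that feeds
# logs["epoch"] into on_epoch_end is assumed, not shown here.
#
#     import torch
#     from torch.utils.data import DataLoader, TensorDataset
#
#     dataset = TensorDataset(torch.randn(1024, 10), torch.randint(0, 2, (1024,)))
#     train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
#
#     # Double the batch size every 10 epochs: 32 -> 64 -> 128 -> ...
#     bs_scheduler = BatchSizeScheduler(
#         train_loader,
#         bs_schedule_fn=lambda epoch: 32 * (2 ** (epoch // 10)),
#         timescale="epoch",
#     )
#     # The trainer then invokes bs_scheduler.on_epoch_end({"epoch": e, ...})
#     # after each epoch, resizing all subsequent batches.
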
# UNTESTED
class BatchSizeIncreaser(Callback):
    """
    Increase the batch size during training instead of decaying the learning
    rate; once the maximum batch size would be exceeded, step the learning
    rate scheduler instead.

    Reference: https://arxiv.org/pdf/1711.00489.pdf
    """
    def __init__(
        self, train_data: DataLoader, update_after: int, update_every: int,
        bs_init: int, lr_scheduler, bs_max: int, bs_inc_rate: float=5.0
    ):
        self._update_after = update_after
        self._update_every = update_every
        self._train_data = train_data
        self.bs_inc_rate = bs_inc_rate
        self.bs_max = bs_max
        self.cur_bs = bs_init
        self._lr_scheduler = lr_scheduler
    def on_epoch_end(self, logs: Dict[str, Any]) -> bool:
        if logs['epoch'] >= self._update_after and (logs['epoch'] - self._update_after) % self._update_every == 0:
            new_bs = int(self.cur_bs * self.bs_inc_rate)
            if new_bs < self.bs_max:
                self.cur_bs = new_bs
                # Update the batch_sampler directly; recent PyTorch versions
                # forbid setting DataLoader.batch_size after initialization.
                self._train_data.batch_sampler.batch_size = new_bs
                print("Increase batch size to " + str(new_bs))
            else:
                # Maximum batch size reached: decay the learning rate instead.
                self._lr_scheduler.step()
        return False
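

# --- Usage sketch (not part of the original module): pairing BatchSizeIncreaser
# with a learning rate scheduler, per "Don't Decay the Learning Rate, Increase
# the Batch Size" (arXiv:1711.00489). The optimizer, scheduler, and `model`
# below are illustrative assumptions, not a confirmed API.
#
#     import torch
#
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
#     lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.2)
#
#     # Start at batch size 128; from epoch 60 on, multiply it by 5 every 30
#     # epochs. Once 4096 would be exceeded, fall back to lr_scheduler.step().
#     bs_increaser = BatchSizeIncreaser(
#         train_loader, update_after=60, update_every=30, bs_init=128,
#         lr_scheduler=lr_scheduler, bs_max=4096, bs_inc_rate=5.0,
#     )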