Source code for nntoolbox.callbacks.resizing
from .callbacks import Callback
from typing import Dict, Any
from torch import Tensor
import torch.nn.functional as F
__all__ = ['InputProgressiveResizing']
# UNTESTED
[docs]class InputProgressiveResizing(Callback):
"""
Implement a callback for progressive resizing (input only)
"""
def __init__(self, initial_size: int, max_size: int, upscale_every: int, upscale_factor: float, mode='bilinear'):
self.size, self.max_size = initial_size, max_size
self.initial_size = initial_size
self.upscale_every, self.upscale_factor = upscale_every, upscale_factor
self.mode = mode
[docs] def on_batch_begin(self, data: Dict[str, Tensor], train) -> Dict[str, Tensor]:
data["inputs"] = F.interpolate(data["inputs"], size=(self.size, self.size), mode=self.mode)
return data
[docs] def on_epoch_end(self, logs: Dict[str, Any]) -> bool:
if logs["epoch"] % self.upscale_every == 0 and self.size * self.upscale_factor <= self.max_size:
self.size = int(self.initial_size * (self.upscale_factor ** (logs["epoch"] // self.upscale_every)))
print("Increasing the scale of input to " + str(self.size))
return False