# Source code for nntoolbox.callbacks.device
from .callbacks import Callback
from typing import Dict, List, Union, Optional, Callable
from torch import Tensor, device
from torch.nn import DataParallel, Module, AdaptiveAvgPool2d, Sequential
from ..utils import get_device, cut_model
# Public names exported by `from ... import *`.
# NOTE(review): DataParallelismCallback and MixedParallelismCB are defined in this module but not listed here,
# presumably because they are marked UNTESTED — confirm before exporting them.
__all__ = ['ToDeviceCallback']
class ToDeviceCallback(Callback):
    """
    Move the learner's model(s) and every tensor of each batch to a single device.

    :param device: target passed to ``.to()``. When ``None`` (the default), ``get_device()`` is called when the
        callback is constructed. A default of ``get_device()`` in the signature would instead be evaluated only
        once, at module import time.
    """
    def __init__(self, device=None):
        # Resolve the device lazily (per instance) rather than once at import time.
        self._device = get_device() if device is None else device
        self.learner = None

    def on_train_begin(self):
        """
        Move the learner's model to the target device.

        If the learner has a ``_model`` attribute, that model is moved. Otherwise, if it has ``_models``, each
        entry is moved in place. If it has neither attribute, nothing happens.
        """
        if hasattr(self.learner, '_model'):
            self.learner._model = self.learner._model.to(self._device)
        elif hasattr(self.learner, '_models'):
            models = self.learner._models
            for i, model in enumerate(models):
                # Assigning to an existing index does not change the container's length, so this is safe while
                # iterating.
                models[i] = model.to(self._device)

    def on_batch_begin(self, data: Dict[str, Tensor], train: bool) -> Dict[str, Tensor]:
        """
        Move every value of the batch dictionary to the target device.

        :param data: batch dictionary; it is updated in place and also returned
        :param train: unused here; part of the callback hook signature
        :return: the same dictionary, with every value moved to the target device
        """
        for key, value in data.items():
            # Overwriting existing keys does not resize the dict, so mutating during iteration is safe.
            data[key] = value.to(self._device)
        return data
# UNTESTED
class DataParallelismCallback(Callback):
    """
    Callback for naive data parallelism: copy the model to each device, then split each batch into micro-batches
    that are processed independently.

    At the start of training, the learner's ``_model`` is wrapped in ``torch.nn.DataParallel``.

    :param device_ids: devices to replicate the model on (forwarded to ``DataParallel``; ``None`` uses its default)
    :param output_device: device on which outputs are gathered (forwarded to ``DataParallel``)
    :param dim: dimension along which each batch is scattered (forwarded to ``DataParallel``)
    """
    def __init__(
        self, device_ids: Optional[List[Union[int, device]]]=None,
        output_device: Optional[Union[int, device]]=None, dim: int=0
    ):
        self.device_ids = device_ids
        self.output_device = output_device
        self.dim = dim

    # Backward-compatible alias for the attribute's original, misspelled name.
    @property
    def ouput_device(self) -> Optional[Union[int, device]]:
        return self.output_device

    @ouput_device.setter
    def ouput_device(self, value: Optional[Union[int, device]]):
        self.output_device = value

    def on_train_begin(self):
        """Wrap the learner's model in ``DataParallel`` with the configured devices and scatter dimension."""
        self.learner._model = DataParallel(
            self.learner._model, device_ids=self.device_ids, output_device=self.output_device, dim=self.dim
        )
class MixedParallelismCB(Callback):
    """
    Callback for mixed parallelism: data parallelism for the convolution/feature layers, and model parallelism for
    the head. As implemented, the head is placed on the single device where the feature outputs are gathered.
    (UNTESTED)

    References:

        https://discuss.pytorch.org/t/why-not-giving-the-whole-model-to-dataparallel-in-the-imagenet-example/4092

        https://arxiv.org/pdf/1404.5997.pdf

    :param device_ids: devices to replicate the feature layers on (forwarded to ``DataParallel``)
    :param output_device: device on which feature outputs are gathered (forwarded to ``DataParallel``)
    :param dim: dimension along which each batch is scattered (forwarded to ``DataParallel``)
    :param sep: passed as ``sep`` to ``cut_model`` to split the model into features and head (presumably the
        layer type at the split point — confirm against ``cut_model``)
    """
    def __init__(
        self, device_ids: Optional[List[Union[int, device]]]=None,
        output_device: Optional[Union[int, device]]=None, dim: int=0,
        sep: Callable[..., Module]=AdaptiveAvgPool2d
    ):
        self.device_ids = device_ids
        self.output_device = output_device
        self.dim = dim
        self.sep = sep

    # Backward-compatible alias for the attribute's original, misspelled name.
    @property
    def ouput_device(self) -> Optional[Union[int, device]]:
        return self.output_device

    @ouput_device.setter
    def ouput_device(self, value: Optional[Union[int, device]]):
        self.output_device = value

    def on_train_begin(self):
        """Split the learner's model, wrap the feature part in ``DataParallel``, and put the head on the gather device."""
        features, head = cut_model(self.learner._model, sep=self.sep)
        features = DataParallel(
            features, device_ids=self.device_ids, output_device=self.output_device, dim=self.dim
        )
        # DataParallel gathers the feature outputs on its *resolved* output device (device_ids[0] when
        # output_device is None). ``.to(None)`` would leave the head where it was and cause a device mismatch.
        # When no accelerator is available, DataParallel may return early without setting ``output_device``,
        # so fall back to the configured value in that case.
        head_device = getattr(features, 'output_device', self.output_device)
        # Move only the head. Moving the whole Sequential could pull the DataParallel source module off
        # device_ids[0], which DataParallel's forward requires.
        head = head.to(head_device)
        self.learner._model = Sequential(features, head)