"""Gradient-clipping callbacks for nntoolbox (module ``nntoolbox.callbacks.gradient``)."""
from .callbacks import Callback
from torch.nn.utils import clip_grad_value_, clip_grad_norm_
# Public API: the two gradient-clipping callbacks defined in this module.
__all__ = [
    'GradientValueClipping',
    'GradientNormClipping',
]
# UNTESTED
class GradientValueClipping(Callback):
    """Clamp every gradient element into ``[-clip_value, clip_value]`` after backward.

    Calls ``torch.nn.utils.clip_grad_value_`` on the learner's model parameters
    in the ``after_backward`` hook, modifying the ``.grad`` tensors in place.
    """

    def __init__(self, clip_value: float):
        """
        :param clip_value: largest allowed absolute value of any gradient element;
            gradients are clamped elementwise into the closed range
            ``[-clip_value, clip_value]``. Must be non-negative.
        :raises ValueError: if ``clip_value`` is negative.
        """
        if clip_value < 0:
            # With min > max, clamping sets every element to max, so a negative
            # clip_value would overwrite every gradient element with clip_value.
            raise ValueError(f"clip_value must be non-negative, got {clip_value}")
        self.clip_value = clip_value

    def after_backward(self) -> bool:
        """Clip the model's gradients in place.

        :return: always ``True``.
        """
        # NOTE(review): self.learner is not assigned in this class; presumably the
        # learner / Callback base attaches it before hooks run — TODO confirm.
        clip_grad_value_(self.learner._model.parameters(), self.clip_value)
        return True
# UNTESTED
class GradientNormClipping(Callback):
    """Rescale gradients so their total norm is at most ``max_norm`` after backward.

    Calls ``torch.nn.utils.clip_grad_norm_`` on the learner's model parameters
    in the ``after_backward`` hook. The norm is computed over all gradients
    together, and the ``.grad`` tensors are scaled in place.
    """

    def __init__(self, max_norm: float, norm_type: float = 2):
        """
        :param max_norm: maximum allowed total norm of the model's gradients.
            Must be non-negative.
        :param norm_type: order of the p-norm used to measure the gradients
            (``float('inf')`` selects the infinity norm). Defaults to 2.
        :raises ValueError: if ``max_norm`` is negative.
        """
        if max_norm < 0:
            # clip_grad_norm_ scales gradients by max_norm / total_norm (capped at 1);
            # a negative max_norm would flip the sign of every gradient.
            raise ValueError(f"max_norm must be non-negative, got {max_norm}")
        self.max_norm = max_norm
        self.norm_type = norm_type

    def after_backward(self) -> bool:
        """Clip the model's gradient norm in place.

        :return: always ``True``.
        """
        # NOTE(review): self.learner is not assigned in this class; presumably the
        # learner / Callback base attaches it before hooks run — TODO confirm.
        # The total norm returned by clip_grad_norm_ is currently discarded.
        clip_grad_norm_(
            self.learner._model.parameters(),
            max_norm=self.max_norm,
            norm_type=self.norm_type,
        )
        return True