import torch
from torch.nn import Module
from torch import Tensor
from torch.autograd import grad
from typing import List, Callable, Union, Iterable
__all__ = [
    'compute_gradient', 'compute_jacobian', 'compute_jacobian_v2',
    'update_gradient', 'accumulate_gradient', 'compute_gradient_norm',
    'hessian_diagonal', 'gather_flat_grad'
]
def compute_gradient(output: Tensor, model: Module) -> List[Tensor]:
    """
    Compute the gradient of the output of a model with respect to its parameters
    :param output: a scalar tensor, e.g. a loss
    :param model: the model whose parameters are differentiated against
    :return: list of gradients of the model parameters
    """
    ret = []
    output.backward(retain_graph=True)
    for parameter in model.parameters():
        ret.append(parameter.grad)
        parameter.grad = None  # Reset gradient accumulation
    return ret
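# Usage sketch (illustrative; the model and loss below are placeholders, not part of
# this module):
#
#     model = torch.nn.Linear(3, 1)
#     loss = model(torch.randn(4, 3)).sum()
#     grads = compute_gradient(loss, model)  # one tensor per parameter; .grad fields reset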
def compute_jacobian(
        input: Tensor, fn: Callable[[Tensor], Tensor], is_batch: bool = True, requires_grad: bool = True
) -> Tensor:
    """
    Compute the Jacobian of fn(input) with respect to input. For most purposes,
    compute_jacobian_v2 should be preferred.
    :param input: the input tensor; assumed to have requires_grad = True
    :param fn: the function whose Jacobian is computed
    :param is_batch: whether the first dimension of input is a batch dimension, in which case
        the Jacobian is computed per sample
    :param requires_grad: whether the returned Jacobian should itself be differentiable
    :return: a tensor of shape fn(input).shape + input.shape (per sample if is_batch)
    """
    if is_batch:
        return torch.stack(
            [compute_jacobian(input[ind], fn, False, requires_grad) for ind in range(len(input))], dim=0
        )
    else:
        output = fn(input)
        output_shape = output.shape
        input_shape = input.shape
        output = output.view(-1)
        grads = [
            grad(output[ind], [input], allow_unused=True, create_graph=requires_grad, retain_graph=True)[0]
            for ind in range(len(output))
        ]
        # allow_unused can return None for entries that do not depend on input; treat those as zero
        grads = [g if g is not None else torch.zeros_like(input) for g in grads]
        return torch.stack(grads, dim=0).reshape(output_shape + input_shape)
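# Usage sketch (illustrative): for fn(x) = x ** 2 on a batch of shape (B, D), each
# per-sample Jacobian is diagonal, so the result has shape (B, D, D).
#
#     x = torch.randn(5, 3, requires_grad=True)
#     jac = compute_jacobian(x, lambda t: t ** 2)  # shape (5, 3, 3)
#     # jac[b] == torch.diag(2 * x[b]) for every sample b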
def compute_jacobian_v2(
        output: Tensor, input: Union[Tensor, Iterable[Tensor]], requires_grad: bool = True
) -> Union[Tensor, List[Tensor]]:
    """
    Compute the Jacobian of a vector with respect to an input tensor
    :param output: a 1D vector of length L
    :param input: either a tensor (parameter) or an iterable of parameters
    :param requires_grad: whether the returned Jacobian should itself be differentiable
    :return: the Jacobian, of shape (L,) + input.shape, or a list of such Jacobians
    """
    if isinstance(input, Tensor):
        assert len(output.shape) == 1
        grads = [
            grad(output[ind], input, create_graph=requires_grad, retain_graph=True)[0]
            for ind in range(len(output))
        ]
        return torch.stack(grads, dim=0)
    else:
        return [compute_jacobian_v2(output, param, requires_grad) for param in input]
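# Usage sketch (illustrative): Jacobian of a model's output vector with respect to every
# parameter of the model.
#
#     model = torch.nn.Linear(3, 2)
#     out = model(torch.randn(3))  # a 1D vector, L = 2
#     jacs = compute_jacobian_v2(out, model.parameters())
#     # jacs[k] has shape (2,) + parameter_k.shape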
def update_gradient(gradients: Iterable[Tensor], model: Module, fn: Callable[[Tensor], Tensor] = lambda x: x):
    """Overwrite the .grad field of each model parameter with fn(gradient)"""
    for gradient, parameter in zip(gradients, model.parameters()):
        parameter.grad = fn(gradient)  # Overwrite, rather than accumulate into, the gradient
def accumulate_gradient(gradients: Iterable[Tensor], model: Module, fn: Callable[[Tensor], Tensor] = lambda x: x):
    """Add fn(gradient) to the .grad field of each model parameter"""
    for gradient, parameter in zip(gradients, model.parameters()):
        # Initialise .grad when backward has not populated it yet, otherwise accumulate
        parameter.grad = fn(gradient) if parameter.grad is None else parameter.grad + fn(gradient)
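# Usage sketch (illustrative; loss1, loss2 and optimizer are placeholders): write
# externally computed gradients back into a model, e.g. to average the gradients of two
# losses before an optimizer step.
#
#     g1 = compute_gradient(loss1, model)
#     g2 = compute_gradient(loss2, model)
#     update_gradient(g1, model, fn=lambda g: 0.5 * g)
#     accumulate_gradient(g2, model, fn=lambda g: 0.5 * g)
#     optimizer.step()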
def compute_gradient_norm(output: Tensor, model: Module) -> float:
    """
    Compute the squared L2 norm of the gradient of an output (e.g. a loss) with respect to the
    model's parameters
    :param output: a scalar tensor
    :param model: the model whose parameters are differentiated against
    :return: the squared gradient norm, as a float
    """
    ret = 0.0
    output.backward(retain_graph=True)
    for parameter in model.parameters():
        if parameter.grad is not None:
            ret += parameter.grad.pow(2).sum().item()
            parameter.grad = None  # Reset gradient accumulation
    return ret
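# Usage sketch (illustrative): monitor the squared gradient norm of a loss, e.g. for
# logging or as a convergence check.
#
#     sq_norm = compute_gradient_norm(loss, model)  # float; .grad fields are reset to None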
def hessian_diagonal(
        output: Tensor, input: Union[Tensor, Iterable], requires_grad: bool = True
) -> Union[Tensor, List[Tensor]]:
    """
    Compute the diagonal of the Hessian
    :param output: a scalar tensor
    :param input: either a tensor (parameter), or a list/generator of parameters
    :param requires_grad: whether the result should itself be differentiable
    :return: a tensor (or a list of tensors) with the same shape as input, holding the diagonal
        of the Hessian of output with respect to input
    """
    if isinstance(input, Tensor):
        assert output.numel() == 1
        grads = grad(output, input, create_graph=True)[0]
        if not grads.requires_grad:
            # The gradient does not depend on input, so all second derivatives are zero
            return torch.zeros_like(input)
        # H[i, i] = d(grads_i) / d(input_i): differentiate each gradient entry separately and
        # keep only the matching entry of the result
        flat_grads = grads.view(-1)
        diag = [
            grad(flat_grads[ind], input, retain_graph=True, create_graph=requires_grad)[0].view(-1)[ind]
            for ind in range(flat_grads.numel())
        ]
        return torch.stack(diag).view(input.shape)
    else:
        return [hessian_diagonal(output, param, requires_grad) for param in input]
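# Usage sketch (illustrative): for output = (x ** 3).sum() the Hessian is diagonal with
# entries 6 * x, so hessian_diagonal recovers exactly 6 * x.
#
#     x = torch.randn(4, requires_grad=True)
#     diag = hessian_diagonal((x ** 3).sum(), x)  # equals 6 * x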
def gather_flat_grad(params: Iterable[Tensor]) -> Tensor:
    """
    Gather the gradients of all the parameters and flatten them into a single vector. Adapted
    from PyTorch's L-BFGS implementation.
    :param params: an iterable of parameters
    :return: the flattened gradient vector of the parameters
    """
    views = []
    for p in params:
        if p.grad is None:
            view = p.new_zeros(p.numel())
        elif p.grad.is_sparse:
            view = p.grad.to_dense().view(-1)
        else:
            view = p.grad.view(-1)
        views.append(view)
    return torch.cat(views, 0)
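# Usage sketch (illustrative; loss and model are placeholders): flatten a model's
# gradients into one vector, e.g. to feed a vector-based solver.
#
#     loss.backward()
#     flat = gather_flat_grad(model.parameters())  # 1D tensor, length = total parameter count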