Source code for nntoolbox.utils.utils
import copy
from typing import List, Optional, Tuple

import numpy as np
import torch
from torch import nn, Tensor
from torch.nn import Module

__all__ = [
    'compute_num_batch', 'copy_model', 'save_model',
    'load_model', 'get_device', 'get_trainable_parameters',
    'count_trainable_parameters', 'to_onehot',
    'to_onehotv2', 'is_nan', 'is_valid',
    'get_children', 'get_all_submodules', 'find_index',
    'dropout_mask'
]

def compute_num_batch(data_size: int, batch_size: int) -> int:
    """
    Compute the number of batches per epoch

    :param data_size: number of datapoints
    :param batch_size: number of datapoints per batch
    :return: number of batches needed to cover the data (the last batch may be partial)
    """
    return int(np.ceil(data_size / float(batch_size)))
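
# Example with hypothetical numbers: 1000 datapoints and a batch size of 32
# require ceil(1000 / 32) = 32 batches, the last of which holds only 8 datapoints.
# assert compute_num_batch(1000, 32) == 32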

def copy_model(model: Module) -> Module:
    """
    Return an exact copy of the model (same architecture and current weights,
    with no weight sharing between the copy and the original)

    :param model: model to be copied
    :return: a copy of the model
    """
    return copy.deepcopy(model)

def save_model(model: Module, path: str):
    """
    Save a model's weights (its state dict) to disk

    :param model: model whose weights are saved
    :param path: path to save the model at
    """
    torch.save(model.state_dict(), path)
    print("Model saved")

def load_model(model: Module, path: str):
    """
    Load saved weights into a model

    :param model: model to load the weights into (its architecture must match the saved state dict)
    :param path: path of the saved model
    """
    model.load_state_dict(torch.load(path))
    print("Model loaded")

def get_device() -> torch.device:
    """
    Convenience helper for getting the default device

    :return: a torch device object (GPU if available, else CPU)
    """
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_trainable_parameters(model: Module) -> List[Tensor]:
    """
    :param model:
    :return: list of the model's parameters that require gradients
    """
    return [p for p in model.parameters() if p.requires_grad]

def count_trainable_parameters(model: Module) -> int:
    """
    Count the trainable (gradient-requiring) parameters of a model.
    Based on https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/8

    :param model:
    :return: total number of trainable scalar parameters
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
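
# For example, nn.Linear(10, 5) has a 5 x 10 weight matrix plus a bias of
# size 5, so it contains 55 trainable parameters:
# assert count_trainable_parameters(nn.Linear(10, 5)) == 55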

def to_onehot(label: Tensor, n_class: Optional[int] = None) -> Tensor:
    """
    Return the one-hot encoding of label, with the class dimension inserted at dim 1

    :param label: tensor of class indices, shape (batch_size, *)
    :param n_class: number of classes; inferred as max(label) + 1 if not given
    :return: one-hot tensor of shape (batch_size, n_class, *)
    """
    if n_class is None:
        n_class = int(torch.max(label)) + 1
    label_oh = torch.zeros([label.shape[0], n_class] + list(label.shape)[1:]).long().to(label.device)
    label = label.unsqueeze(1)
    label_oh.scatter_(dim=1, index=label, value=1)
    return label_oh
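
# A small sketch: labels of shape (2,) with 3 classes yield a (2, 3) tensor
# with a single 1 per row:
# to_onehot(torch.tensor([0, 2]), n_class=3)
# -> tensor([[1, 0, 0],
#            [0, 0, 1]])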

def to_onehotv2(label: Tensor, n_class: Optional[int] = None) -> Tensor:
    """
    Return the one-hot encoding of label, with the class dimension appended as the last dim

    :param label: tensor of class indices
    :param n_class: number of classes; inferred as max(label) + 1 if not given
    :return: one-hot tensor of shape label.shape + (n_class,)
    """
    if n_class is None:
        n_class = int(torch.max(label)) + 1
    label_oh = torch.zeros(list(label.shape) + [n_class]).long().to(label.device)
    label = label.unsqueeze(-1)
    label_oh.scatter_(dim=-1, index=label, value=1)
    return label_oh
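
# Contrast with to_onehot on a 2-D label tensor (hypothetical shapes):
# labels = torch.zeros(4, 7).long()        # (batch=4, seq_len=7)
# to_onehot(labels, 5).shape               # torch.Size([4, 5, 7])
# to_onehotv2(labels, 5).shape             # torch.Size([4, 7, 5])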

def is_nan(tensor: Tensor) -> bool:
    """
    Check if any element of a tensor is NaN

    :param tensor:
    :return: whether any element of the tensor is NaN
    """
    return bool(torch.isnan(tensor).any())

def is_valid(tensor: Tensor) -> bool:
    """
    Check if a tensor is valid (contains no inf and no NaN)

    :param tensor:
    :return: whether the tensor is valid
    """
    total = float(tensor.sum().cpu().detach())
    # a sum containing any inf or NaN is itself non-finite; total != total holds only for NaN
    return total != float('-inf') and total != float('inf') and total == total
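
# For instance, is_valid(torch.log(torch.zeros(3))) is False since log(0) = -inf,
# while is_valid(torch.ones(3)) is True.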

def get_children(model: Module) -> List[Module]:
    """
    :param model:
    :return: list of all immediate children of a model
    """
    return list(model.children())

def get_all_submodules(module: Module) -> List[Module]:
    """
    Get all submodules of a module, skipping nn.Sequential containers

    :param module:
    :return: list of all submodules of the module (including the module itself, if it is not an nn.Sequential)
    """
    return [submodule for submodule in module.modules() if type(submodule) is not nn.Sequential]

def find_index(array, value):
    """
    :param array: numpy array to search
    :param value: value to look for
    :return: index of the first occurrence of value in array (raises IndexError if value is absent)
    """
    return np.where(array == value)[0][0]

def dropout_mask(t: Tensor, size: Tuple[int, ...], drop_p: float) -> Tensor:
    """
    Create an inverted dropout mask: each element is 0 with probability drop_p
    and 1 / (1 - drop_p) otherwise, so multiplying by the mask preserves the
    input's expected value

    :param t: tensor whose dtype and device the mask should match
    :param size: shape of the mask
    :param drop_p: probability of zeroing each element
    :return: the dropout mask
    """
    return t.new(*size).bernoulli_(1 - drop_p).div(1 - drop_p)
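
# A minimal sketch of using the mask for locked/variational dropout, where the
# same mask is broadcast across the time dimension (tensor and shapes hypothetical):
# x = torch.randn(32, 10, 64)                       # (batch, seq_len, features)
# mask = dropout_mask(x, (32, 1, 64), drop_p=0.5)   # one mask per sequence
# x_dropped = x * mask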