nntoolbox.utils.utils module
-
nntoolbox.utils.utils.compute_num_batch(data_size: int, batch_size: int)
Compute the number of batches per epoch
- Parameters
data_size – number of datapoints
batch_size – number of datapoints per batch
- Returns
number of batches per epoch
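The entry does not say whether a final, incomplete batch is counted. A minimal sketch, assuming it is (ceiling division), might look like this. It is a hypothetical re-implementation, not the library's code.

```python
def compute_num_batch(data_size: int, batch_size: int) -> int:
    """Hypothetical sketch: batches per epoch, counting a final partial batch."""
    # Integer ceiling division avoids float rounding:
    # 1000 samples in batches of 32 -> 31 full batches + 1 partial batch = 32
    return (data_size + batch_size - 1) // batch_size

print(compute_num_batch(1000, 32))  # 32
print(compute_num_batch(1024, 32))  # 32
print(compute_num_batch(1, 32))     # 1
```

If the library instead drops the last partial batch (as `drop_last=True` does in a PyTorch `DataLoader`), the result would be `data_size // batch_size`.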
-
nntoolbox.utils.utils.copy_model(model: torch.nn.modules.module.Module) → torch.nn.modules.module.Module
Return an exact copy of the model (same architecture and same initial weights) whose weights are not tied to the original's
- Parameters
model – model to be copied
- Returns
a copy of the model
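One standard way to get an untied copy in PyTorch is `copy.deepcopy`. The sketch below assumes that approach, since the library's implementation is not shown here. It checks that the copy starts with identical weights and that updating the copy leaves the original unchanged.

```python
import copy

import torch
from torch import nn

def copy_model(model: nn.Module) -> nn.Module:
    """Sketch: deep-copy architecture and weights so no parameter tensors are shared."""
    return copy.deepcopy(model)

original = nn.Linear(4, 2)
clone = copy_model(original)

# Same initial weights...
print(torch.equal(original.weight, clone.weight))  # True

# ...but not tied: an in-place update on the copy does not touch the original
with torch.no_grad():
    clone.weight.add_(1.0)
print(torch.equal(original.weight, clone.weight))  # False
```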
-
nntoolbox.utils.utils.count_trainable_parameters(model: torch.nn.modules.module.Module) → int
Count the trainable parameters of a model. Based on https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/8
- Parameters
model – model whose parameters are counted
- Returns
number of trainable parameters
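The linked forum thread counts parameters by summing `numel()` over the parameters that have `requires_grad` set. A sketch along those lines follows, assuming the library does the same. Freezing a layer removes its parameters from the count.

```python
from torch import nn

def count_trainable_parameters(model: nn.Module) -> int:
    # Sum the element counts of parameters that receive gradients
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1))
print(count_trainable_parameters(model))  # 10*5 + 5 + 5*1 + 1 = 61

# Freeze the first layer: its 55 parameters are no longer counted
for p in model[0].parameters():
    p.requires_grad_(False)
print(count_trainable_parameters(model))  # 6
```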
-
nntoolbox.utils.utils.get_all_submodules(module: torch.nn.modules.module.Module) → List[torch.nn.modules.module.Module]
Get all submodules of a module
- Parameters
module – module whose submodules are collected
- Returns
list of all submodules of the module
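The entry does not define "all submodules" precisely. The hypothetical sketch below assumes it means every nested descendant, excluding the module itself. The library may instead return only leaf modules. `Module.modules()` yields the module itself first and then its descendants depth-first, so the sketch drops the first element.

```python
from typing import List

from torch import nn

def get_all_submodules(module: nn.Module) -> List[nn.Module]:
    # modules() yields the module itself first, then every descendant depth-first
    return list(module.modules())[1:]

block = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
model = nn.Sequential(block, nn.Linear(4, 1))
subs = get_all_submodules(model)
print([type(m).__name__ for m in subs])  # ['Sequential', 'Linear', 'ReLU', 'Linear']
```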
-
nntoolbox.utils.utils.get_children(model: torch.nn.modules.module.Module) → List[torch.nn.modules.module.Module]
Get the children of a model
- Parameters
model – model to inspect
- Returns
list of all children of a model
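This sketch assumes "children" has PyTorch's usual meaning: the direct submodules, as returned by `Module.children()`, without their descendants. It is illustrative and does not reproduce the library's own code.

```python
from typing import List

from torch import nn

def get_children(model: nn.Module) -> List[nn.Module]:
    # children() yields only direct submodules, not their descendants
    return list(model.children())

block = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
model = nn.Sequential(block, nn.Linear(4, 1))
print([type(m).__name__ for m in get_children(model)])  # ['Sequential', 'Linear']
```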
-
nntoolbox.utils.utils.get_device()
Convenience helper for getting the device
- Returns
a torch device object (the GPU if one is available, otherwise the CPU)
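A common way to write such a helper is shown below. This is a sketch, and the exact device string the library uses is an assumption. The returned device can then be passed to tensor constructors or `Module.to`.

```python
import torch

def get_device() -> torch.device:
    # Prefer the GPU when CUDA is available, otherwise fall back to the CPU
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

device = get_device()
x = torch.zeros(3, device=device)  # allocate directly on the chosen device
print(device, x.device)
```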
-
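nntoolbox.utils.utils.get_trainable_parameters(model: torch.nn.modules.module.Module) → List[torch.Tensor]
Get the trainable parameters of a model
- Parameters
model – model to inspect
- Returns
list of the model's trainable parameter tensors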
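The sketch below assumes "trainable" means `requires_grad` is set. A typical use is to hand only those tensors to an optimizer. It is a hypothetical re-implementation, not the library's code.

```python
from typing import List

import torch
from torch import nn

def get_trainable_parameters(model: nn.Module) -> List[torch.Tensor]:
    # Keep only parameters that will receive gradients
    return [p for p in model.parameters() if p.requires_grad]

model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 1))
model[0].weight.requires_grad_(False)  # freeze one weight matrix

params = get_trainable_parameters(model)
print(len(params))  # 3: bias of layer 0, weight and bias of layer 1
optimizer = torch.optim.SGD(params, lr=0.1)  # optimizer only sees trainable tensors
```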
-
nntoolbox.utils.utils.is_nan(tensor: torch.Tensor) → bool
Check if any element of a tensor is NaN
- Parameters
tensor – tensor to check
- Returns
whether any element of the tensor is NaN
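In PyTorch, this check can be expressed with `torch.isnan` followed by `any()`. The sketch below assumes that approach and converts the result to a Python `bool` to match the documented return type. Note that infinities are not NaN.

```python
import torch

def is_nan(tensor: torch.Tensor) -> bool:
    # isnan is elementwise; any() reduces to a 0-d bool tensor; item() gives a Python bool
    return bool(torch.isnan(tensor).any().item())

print(is_nan(torch.tensor([1.0, 2.0])))           # False
print(is_nan(torch.tensor([1.0, float("nan")])))  # True
print(is_nan(torch.tensor([float("inf")])))       # False: inf is not NaN
```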
-
nntoolbox.utils.utils.is_valid(tensor: torch.Tensor) → bool
Check whether a tensor is valid, i.e. contains no infinite and no NaN elements
- Parameters
tensor – tensor to check
- Returns
whether a tensor is valid
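`torch.isfinite` is `False` for `+inf`, `-inf` and NaN, so requiring it to hold for every element captures "not inf and not NaN" in one call. The sketch below is a hypothetical re-implementation built on that idea. It could, for example, guard a loss value before calling `backward()`.

```python
import torch

def is_valid(tensor: torch.Tensor) -> bool:
    # isfinite rejects +inf, -inf and NaN, so all() means "no inf and no NaN"
    return bool(torch.isfinite(tensor).all().item())

print(is_valid(torch.tensor([0.5, -3.0])))          # True
print(is_valid(torch.tensor([1.0, float("inf")])))  # False
print(is_valid(torch.tensor([float("nan")])))       # False
```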
-
nntoolbox.utils.utils.load_model(model: torch.nn.modules.module.Module, path: str)
Load the model from path
- Parameters
model – model into which the saved state is loaded
path – path of saved model
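This entry does not document the file format. The hypothetical sketch below assumes the file holds a `state_dict` written with `torch.save`, which is the pattern the PyTorch docs recommend. The demo writes one model's weights to a temporary file and loads them into a freshly initialized model of the same architecture.

```python
import os
import tempfile

import torch
from torch import nn

def load_model(model: nn.Module, path: str) -> None:
    # Assumption: the file holds a state_dict written by torch.save(model.state_dict(), path)
    state_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(state_dict)  # copies values into the model's existing tensors

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "linear.pt")
    trained = nn.Linear(3, 1)
    torch.save(trained.state_dict(), path)

    fresh = nn.Linear(3, 1)  # same architecture, different random weights
    load_model(fresh, path)

print(torch.equal(fresh.weight, trained.weight))  # True
```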
-
nntoolbox.utils.utils.save_model(model: torch.nn.modules.module.Module, path: str)
Save a model to path
- Parameters
model – model to save
path – path at which to save the model
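The sketch below assumes `save_model` serializes the model's `state_dict` with `torch.save`. That choice is an assumption, since the entry does not say whether the whole module is pickled instead.

```python
import os
import tempfile

import torch
from torch import nn

def save_model(model: nn.Module, path: str) -> None:
    # Assumption: persist only the learned tensors (state_dict), not the pickled class
    torch.save(model.state_dict(), path)

model = nn.Linear(3, 1)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "linear.pt")
    save_model(model, path)
    saved_exists = os.path.exists(path)
    saved = torch.load(path)

print(sorted(saved.keys()))  # ['bias', 'weight']
```

Saving the `state_dict` rather than the whole module keeps the file independent of the pickled class definition. The trade-off is that the model has to be constructed again before the weights can be loaded back into it.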