# Source code for nntoolbox.optim.utils

import torch
from torch.optim import Optimizer
from typing import Callable, Union, List
import numpy as np
import matplotlib.pyplot as plt


# Public API of this module.
__all__ = [
    'get_lr',
    'change_lr',
    'plot_schedule',
    'save_optimizer',
    'load_optimizer',
]


def get_lr(optim: Optimizer) -> List[float]:
    """
    Read the current learning rate of every param group of an optimizer

    :param optim: optimizer to inspect
    :return: one learning rate per param group, in ``optim.param_groups`` order
    """
    current = []
    for group in optim.param_groups:
        current.append(group['lr'])
    return current
# UNTESTED
def change_lr(optim: Optimizer, lrs: Union[float, List[float]]):
    """
    Change the learning rate of an optimizer in place

    :param optim: optimizer whose param groups are updated
    :param lrs: target learning rate. A single number is applied to every
        param group; a list (or tuple) gives one learning rate per param
        group, in ``optim.param_groups`` order
    :raises ValueError: if a per-group list/tuple does not contain exactly
        one entry per param group
    """
    n_groups = len(optim.param_groups)
    # Tuples are per-group sequences too; otherwise they would fall through to
    # the scalar branch and every group's 'lr' would become the tuple itself.
    if isinstance(lrs, (list, tuple)):
        # Explicit check rather than ``assert`` so it is not stripped under ``python -O``.
        if len(lrs) != n_groups:
            raise ValueError(
                f"Expected {n_groups} learning rates (one per param group), got {len(lrs)}"
            )
    else:
        lrs = [lrs] * n_groups
    for param_group, lr in zip(optim.param_groups, lrs):
        param_group['lr'] = lr
def plot_schedule(schedule_fn: Callable[[int], float], iterations: int=30):
    """
    Plot a learning rate schedule function and show the figure

    :param schedule_fn: a function that returns a learning rate given an iteration
    :param iterations: maximum number of iterations (or epochs); the schedule is
        evaluated at 0, 1, ..., iterations - 1
    """
    # Use a plain ``range`` so ``schedule_fn`` receives Python ``int``s, as its
    # annotation promises. ``np.arange`` would pass ``np.int64`` scalars, which
    # can overflow (e.g. ``2 ** it`` for large ``it``) and reject negative
    # integer powers (e.g. ``2 ** -it``).
    steps = list(range(iterations))
    lrs = [schedule_fn(step) for step in steps]
    plt.plot(steps, lrs)
    plt.xlabel("Iterations")
    plt.ylabel("Learning Rate")
    plt.show()
# UNTESTED
def save_optimizer(optimizer: Optimizer, path: str):
    """
    Write an optimizer's state dict to disk so training can be resumed later

    :param optimizer: optimizer whose ``state_dict()`` is serialized
    :param path: destination file passed to ``torch.save``
    """
    snapshot = optimizer.state_dict()
    torch.save(snapshot, path)
    print("Optimizer state saved.")
# UNTESTED
def load_optimizer(optimizer: Optimizer, path: str, map_location=None):
    """
    Load optimizer state from disk for resuming training

    :param optimizer: optimizer whose state is overwritten via ``load_state_dict``
    :param path: file containing a ``torch.save``'d optimizer state dict
    :param map_location: forwarded to ``torch.load``. ``None`` (the default,
        matching previous behavior) restores tensors to the device they were
        saved from; e.g. ``"cpu"`` allows loading a checkpoint saved on a GPU
        on a machine without that device.
    """
    # NOTE: ``torch.load`` unpickles the file; only load checkpoints from trusted sources.
    state = torch.load(path, map_location=map_location)
    optimizer.load_state_dict(state)
    print("Optimizer state loaded.")