# Source code for nntoolbox.optim.coord_desc
"""
Coordinate Descent Optimizer
"""
from torch.optim import Optimizer
from typing import List
__all__ = ['CoordDescOptimizer']
class CoordDescOptimizer(Optimizer):
    """
    Coordinate descent over a list of optimizers.

    Each call to ``step`` advances exactly one of the wrapped optimizers, cycling through them in
    order (optimizer 0, then 1, ..., then back to 0). ``zero_grad`` clears the gradients of every
    wrapped optimizer.

    ``Optimizer.__init__`` is deliberately not called: this wrapper owns no parameters of its own.
    Instead, ``param_groups``, ``state_dict`` and ``load_state_dict`` delegate to the wrapped
    optimizers, so that ``repr``, checkpointing and learning-rate schedulers work on the wrapper.

    :param optimizers: the optimizers to alternate between; must be non-empty
    :raises ValueError: if ``optimizers`` is empty
    """

    def __init__(self, optimizers: List[Optimizer]):
        if len(optimizers) == 0:
            # Fail fast here rather than with a ZeroDivisionError on the first ``step``.
            raise ValueError("CoordDescOptimizer requires at least one optimizer")
        self.optimizers = optimizers
        # Number of successful steps so far; ``ind % len(optimizers)`` selects the next optimizer.
        self.ind = 0

    @property
    def param_groups(self) -> List[dict]:
        """
        Parameter groups of all wrapped optimizers, in wrapper order.

        The group dicts are shared with the wrapped optimizers (not copied), so editing e.g.
        ``group['lr']`` (as learning-rate schedulers do) changes the underlying optimizer.
        """
        return [group for optimizer in self.optimizers for group in optimizer.param_groups]

    def zero_grad(self, set_to_none=None):
        """
        Clear the gradients of every wrapped optimizer.

        :param set_to_none: if given, forwarded as ``set_to_none`` to each wrapped ``zero_grad``;
            if None (default), each optimizer's ``zero_grad()`` is called with its own default
        """
        for optimizer in self.optimizers:
            if set_to_none is None:
                optimizer.zero_grad()
            else:
                optimizer.zero_grad(set_to_none=set_to_none)

    def step(self, closure=None):
        """
        Perform one optimization step using only the next optimizer in the cycle.

        :param closure: optional closure that re-evaluates the model and returns the loss;
            forwarded unchanged to the selected optimizer's ``step``
        :return: whatever the selected optimizer's ``step`` returns (typically the closure's loss)
        """
        loss = self.optimizers[self.ind % len(self.optimizers)].step(closure)
        # Advance the cursor only after the step succeeded, so a failing step is retried next time.
        self.ind += 1
        return loss

    def state_dict(self) -> dict:
        """
        Return the wrapper's state: the step counter and each wrapped optimizer's state dict.

        :return: a dict with keys ``'ind'`` and ``'optimizers'`` (a list, in wrapper order)
        """
        return {
            'ind': self.ind,
            'optimizers': [optimizer.state_dict() for optimizer in self.optimizers],
        }

    def load_state_dict(self, state_dict: dict):
        """
        Restore a state previously produced by ``state_dict``.

        :param state_dict: a dict with keys ``'ind'`` and ``'optimizers'``
        :raises ValueError: if the number of saved optimizer states differs from the number of
            wrapped optimizers
        """
        saved_states = state_dict['optimizers']
        if len(saved_states) != len(self.optimizers):
            raise ValueError(
                "state_dict holds {} optimizer states but {} optimizers are wrapped".format(
                    len(saved_states), len(self.optimizers)
                )
            )
        for optimizer, optimizer_state in zip(self.optimizers, saved_states):
            optimizer.load_state_dict(optimizer_state)
        self.ind = state_dict['ind']