Source code for secmlt.optimization.scheduler_factory

"""Learning Rate Schedulers creation tools."""

from __future__ import annotations  # noqa: I001

import functools
from typing import ClassVar, TYPE_CHECKING

from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR

if TYPE_CHECKING:
    from torch.optim import Optimizer

# Name under which LRSchedulerFactory registers torch's CosineAnnealingLR.
COSINE_ANNEALING: str = "cosine_annealing"
# Name under which LRSchedulerFactory registers the no-op NoScheduler.
NO_SCHEDULER: str = "no_scheduler"


class NoScheduler(_LRScheduler):
    """
    No learning rate scheduler, does nothing.

    The optimizer's learning rates are never modified; the class only exists
    so that code expecting a scheduler object can run with a constant lr.
    """

    def __init__(self, optimizer: Optimizer, last_epoch: int = -1) -> None:
        """
        Create a NoScheduler instance.

        Parameters
        ----------
        optimizer : Optimizer
            Optimizer whose learning rates are left unchanged.
        last_epoch : int, optional
            Forwarded unchanged to the base scheduler constructor.
            By default -1.
        """
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> list[float]:
        """
        Return the current learning rate of every parameter group.

        Returns
        -------
        list[float]
            The ``"lr"`` value of each group in ``optimizer.param_groups``.
        """
        return [group["lr"] for group in self.optimizer.param_groups]

    def step(self, epoch: int | None = None) -> None:
        """
        Leave the learning rates unchanged.

        Parameters
        ----------
        epoch : int | None, optional
            Ignored; accepted for signature compatibility with the base class.
        """
        # Do not touch the optimizer, but record the current rates so that
        # the inherited ``get_last_lr()`` (and wrappers relying on it) work.
        # The base constructor calls ``step()`` once, so this is set on init.
        self._last_lr = self.get_lr()
class LRSchedulerFactory:
    """
    Creator class for learning rate schedulers.

    Every ``create_*`` method returns a :class:`functools.partial` wrapping a
    scheduler class; the caller completes it by passing the optimizer (plus
    any scheduler argument not bound at creation time).
    """

    # Supported scheduler names mapped to the scheduler classes they build.
    # The values are classes, not instances, hence ``type[...]``.
    SCHEDULERS: ClassVar[dict[str, type[_LRScheduler]]] = {
        NO_SCHEDULER: NoScheduler,
        COSINE_ANNEALING: CosineAnnealingLR,
    }

    @staticmethod
    def create_from_name(
        scheduler_name: str,
        **kwargs,
    ) -> functools.partial[_LRScheduler]:
        """
        Create a learning rate scheduler.

        Parameters
        ----------
        scheduler_name : str
            One of the available scheduler names. Available:
            ``"cosine_annealing"``, ``"no_scheduler"``.
        **kwargs
            Keyword arguments bound into the returned partial and passed to
            the scheduler constructor (e.g. ``T_max`` for cosine annealing).
            Keywords given when the partial is called take precedence.

        Returns
        -------
        functools.partial[LRScheduler]
            The created scheduler, still to be called with the optimizer.

        Raises
        ------
        ValueError
            Raises ValueError when the requested scheduler is not in the
            list of implemented schedulers.
        """
        if scheduler_name == COSINE_ANNEALING:
            return LRSchedulerFactory.create_cosine_annealing(**kwargs)
        if scheduler_name == NO_SCHEDULER:
            return LRSchedulerFactory.create_no_scheduler(**kwargs)
        # Build the message without a backslash continuation inside the
        # f-string, which would embed the next line's indentation in the text.
        msg = (
            "Scheduler not found. Use one of: "
            f"{list(LRSchedulerFactory.SCHEDULERS.keys())}"
        )
        raise ValueError(msg)

    @staticmethod
    def create_no_scheduler(**kwargs) -> functools.partial[_LRScheduler]:
        """
        Create a NoScheduler instance.

        Parameters
        ----------
        **kwargs
            Keyword arguments bound into the partial for ``NoScheduler``.

        Returns
        -------
        functools.partial[LRScheduler]
            NoScheduler instance, still to be called with the optimizer.
        """
        return functools.partial(NoScheduler, **kwargs)

    @staticmethod
    def create_cosine_annealing(**kwargs) -> functools.partial[_LRScheduler]:
        """
        Create the Cosine Annealing scheduler.

        Parameters
        ----------
        **kwargs
            Keyword arguments bound into the partial for
            ``CosineAnnealingLR`` (e.g. ``T_max``, ``eta_min``).

        Returns
        -------
        functools.partial[LRScheduler]
            Cosine Annealing scheduler, still to be called with the optimizer.
        """
        return functools.partial(CosineAnnealingLR, **kwargs)