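"""Source code for ``betty.optim``.

Utilities for patching native PyTorch optimizers and learning-rate schedulers
into their differentiable counterparts.
"""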

import inspect

import torch

from .sgd import DifferentiableSGD
from .adam import DifferentiableAdam
from .adamw import DifferentiableAdamW

# Native PyTorch optimizer class -> its corresponding differentiable optimizer.
optimizer_mapping = dict(
    (
        (torch.optim.SGD, DifferentiableSGD),
        (torch.optim.Adam, DifferentiableAdam),
        (torch.optim.AdamW, DifferentiableAdamW),
    )
)


class UnsupportedOptimizerError(TypeError, AssertionError):
    """Raised by :func:`patch_optimizer` for optimizers with no differentiable counterpart.

    Also derives from :class:`AssertionError` so code written against the
    earlier ``assert``-based check keeps catching it.
    """


def patch_optimizer(optimizer, module):
    """Patch PyTorch's native optimizer by replacing all in-place operations in its ``step`` method

    :param optimizer: User-provided PyTorch's native optimizer
    :type optimizer: `torch.optim.Optimizer <https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer>`_
    :param module: User-provided PyTorch module that is being optimized by the ``optimizer``
    :type module: `torch.nn.Module <https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=module#torch.nn.Module>`_
    :return: Corresponding differentiable optimizer
    :rtype: DifferentiableOptimizer
    :raises UnsupportedOptimizerError: if ``type(optimizer)`` is not a key of
        ``optimizer_mapping``
    """
    # Exact-type lookup: subclasses of a supported optimizer are not matched
    # (presumably because they may override ``step``).
    patched_cls = optimizer_mapping.get(type(optimizer))
    if patched_cls is None:
        # An explicit raise, unlike ``assert``, is not stripped under ``python -O``.
        supported = ", ".join(cls.__name__ for cls in optimizer_mapping)
        raise UnsupportedOptimizerError(
            f"cannot patch optimizer of type {type(optimizer).__name__}; "
            f"supported types are: {supported}"
        )
    return patched_cls(optimizer, module)
def patch_scheduler(old_scheduler, new_optimizer):
    """Patch the original learning rate scheduler to work with a patched differentiable optimizer.

    A new scheduler of the same class is built by reading every constructor
    argument back from the attribute of the same name on ``old_scheduler``.
    There are three exceptions:

    - ``optimizer`` is replaced by ``new_optimizer``.
    - ``last_epoch`` is shifted back by one.
    - ``*args``/``**kwargs`` parameters are left empty.

    :param old_scheduler: User-provided PyTorch's learning rate scheduler
    :type old_scheduler: `torch.optim.lr_scheduler <https://pytorch.org/docs/stable/optim.html?highlight=lr_scheduler>`_
    :param new_optimizer: Patched differentiable optimizer
    :type new_optimizer: DifferentiableOptimizer
    :return: Patched learning rate scheduler
    :rtype: `torch.optim.lr_scheduler <https://pytorch.org/docs/stable/optim.html?highlight=lr_scheduler>`_
    :raises AttributeError: if a constructor argument of the scheduler class is
        not stored on ``old_scheduler`` under the same attribute name
    """
    scheduler_cls = old_scheduler.__class__
    # The first parameter of an unbound ``__init__`` is the instance itself.
    params = list(inspect.signature(scheduler_cls.__init__).parameters.values())[1:]

    args = []
    kwargs = {}
    for param in params:
        name = param.name
        if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
            # ``*args`` / ``**kwargs`` have no attribute to read back from.
            continue
        if name == "optimizer":
            value = new_optimizer
        else:
            try:
                value = getattr(old_scheduler, name)
            except AttributeError as err:
                raise AttributeError(
                    f"cannot patch {scheduler_cls.__name__}: constructor argument "
                    f"{name!r} is not stored as an attribute of the scheduler"
                ) from err
            if name == "last_epoch":
                # PyTorch's LRScheduler.__init__ performs one initial step()
                # that increments last_epoch; step back so the rebuilt
                # scheduler lands on the same epoch as the old one.
                value = value - 1
        if param.kind is param.POSITIONAL_ONLY:
            args.append(value)
        else:
            kwargs[name] = value

    return scheduler_cls(*args, **kwargs)