Source code for betty.optim.adamw

import math

import torch

from betty.optim.optimizer import DifferentiableOptimizerBase


class DifferentiableAdamW(DifferentiableOptimizerBase):
    """
    Differentiable version of PyTorch's
    `AdamW <https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html#adamw>`_
    optimizer. All in-place operations are replaced.
    """

    def step(self, params):
        for param_group, param_mapping in zip(self.param_groups, self.param_mappings):
            amsgrad = param_group["amsgrad"]
            beta1, beta2 = param_group["betas"]
            for param_idx in param_mapping:
                p = params[param_idx]
                if p.grad is None:
                    continue
                grad = p.grad

                state = self.state[param_idx]

                state["step"] += 1
                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]

                # Exponential moving averages of the gradient and its square,
                # computed out-of-place so the result stays differentiable.
                state["exp_avg"] = state["exp_avg"] * beta1 + (1 - beta1) * grad
                state["exp_avg_sq"] = (
                    state["exp_avg_sq"] * beta2 + (1 - beta2) * grad * grad
                )
                if amsgrad:
                    state["max_exp_avg_sq"] = torch.max(
                        state["max_exp_avg_sq"], state["exp_avg_sq"]
                    )
                    denom = (
                        state["max_exp_avg_sq"].sqrt() / math.sqrt(bias_correction2)
                        + param_group["eps"]
                    )
                else:
                    denom = (
                        state["exp_avg_sq"].sqrt() / math.sqrt(bias_correction2)
                        + param_group["eps"]
                    )

                step_size = param_group["lr"] / bias_correction1

                # Decoupled weight decay (the "W" in AdamW) is folded into the
                # update attached to the original parameter tensor, so that it
                # is picked up by ``new_params`` below.
                p.update = (
                    param_group["lr"] * param_group["weight_decay"] * p
                    + step_size * (state["exp_avg"] / denom)
                )

        new_params = tuple(p - p.update for p in params if hasattr(p, "update"))
        for p in params:
            if hasattr(p, "update"):
                del p.update

        return new_params
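
For reference, the arithmetic performed by ``step`` is the standard decoupled-weight-decay AdamW rule, written with out-of-place tensor operations so that gradients can flow through the optimizer step itself. The following sketch replays that rule for a single tensor in plain PyTorch; it is illustrative only, does not go through Betty's engine or this class, and the tensor shape and hyperparameter values are assumptions rather than library defaults.

import math

import torch

# Illustrative hyperparameters (assumed values, not Betty defaults).
lr, weight_decay, eps = 1e-3, 1e-2, 1e-8
beta1, beta2 = 0.9, 0.999

# A single parameter whose gradient is kept in the autograd graph.
p = torch.randn(4, requires_grad=True)
loss = (p ** 2).sum()
(grad,) = torch.autograd.grad(loss, p, create_graph=True)

# State after the first step (step = 1), matching the moving averages above.
step = 1
exp_avg = (1 - beta1) * grad
exp_avg_sq = (1 - beta2) * grad * grad
bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step

# Decoupled weight decay plus the bias-corrected Adam step, all out-of-place.
denom = exp_avg_sq.sqrt() / math.sqrt(bias_correction2) + eps
update = lr * weight_decay * p + (lr / bias_correction1) * (exp_avg / denom)
new_p = p - update

# Because nothing was modified in place, new_p is a differentiable function
# of p, so higher-order gradients can be taken through the update.
print(torch.autograd.grad(new_p.sum(), p))

For a single step with ``amsgrad`` disabled this should reproduce the same update rule as ``torch.optim.AdamW``, except that ``new_p`` stays connected to ``p`` in the autograd graph, which is what a differentiable optimizer needs for bilevel optimization.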