Source code for betty.optim.sgd
from betty.optim.optimizer import DifferentiableOptimizerBase


class DifferentiableSGD(DifferentiableOptimizerBase):
    """
    Differentiable version of PyTorch's
    `SGD <https://pytorch.org/docs/stable/generated/torch.optim.SGD.html#sgd>`_ optimizer.
    All in-place operations are replaced with out-of-place equivalents so that
    autograd can differentiate through the parameter update.
    """
    def step(self, params):
        for param_group, param_mapping in zip(self.param_groups, self.param_mappings):
            weight_decay = param_group["weight_decay"]
            momentum = param_group["momentum"]
            dampening = param_group["dampening"]
            nesterov = param_group["nesterov"]

            for param_idx in param_mapping:
                p = params[param_idx]
                if p.grad is None:
                    continue
                grad = p.grad
                # Weight decay without grad.add_(p, alpha=weight_decay).
                if weight_decay != 0:
                    grad = grad + weight_decay * p
                param_state = self.state[param_idx]
                if (
                    "momentum_buffer" not in param_state
                    or param_state["momentum_buffer"] is None
                ):
                    # First step: initialize the buffer with the (decayed) gradient.
                    buf = param_state["momentum_buffer"] = grad
                else:
                    buf = param_state["momentum_buffer"]
                    # Buffer update without buf.mul_(momentum).add_(grad).
                    buf = momentum * buf + (1 - dampening) * grad
                    param_state["momentum_buffer"] = buf
                if nesterov:
                    grad = grad + momentum * buf
                else:
                    grad = buf
                # Stage the scaled update; it is applied out-of-place below.
                p.update = param_group["lr"] * grad

        # Parameters that received no gradient are passed through unchanged so
        # the returned tuple keeps the same length and ordering as `params`.
        new_params = tuple(
            p - p.update if hasattr(p, "update") else p for p in params
        )
        for p in params:
            if hasattr(p, "update"):
                del p.update

        return new_params
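
A minimal sketch (not part of the betty source; the names `w`, `lr`, and both
losses are illustrative) of why the out-of-place update `p - p.update` matters:
since nothing is mutated in place, autograd can backpropagate a downstream loss
through the optimizer step, here into a learnable scalar learning rate.

import torch

# Hypothetical setup: one parameter and a learnable learning rate.
w = torch.randn(3, requires_grad=True)
lr = torch.tensor(0.1, requires_grad=True)

inner_loss = (w**2).sum()
# create_graph=True keeps the graph of the gradient itself.
(grad,) = torch.autograd.grad(inner_loss, w, create_graph=True)

new_w = w - lr * grad             # out-of-place, like `p - p.update` above
meta_loss = (new_w**2).sum()
meta_loss.backward()              # differentiates through the update step
print(lr.grad)                    # d(meta_loss)/d(lr)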