betty.optim¶
Iterative differentiation (ITD) differentiates through the optimization path, which requires tracking the intermediate states of the parameters during optimization. However, PyTorch’s native optimizers overwrite these intermediate states with in-place operations, a deliberate design choice that saves memory. We therefore provide functionality for patching PyTorch’s native optimizers, replacing all involved in-place operations so that ITD becomes possible. When users pass their optimizer to the IterativeProblem class through the constructor, the initialize method of IterativeProblem automatically calls patch_optimizer and patch_scheduler to patch the user-provided PyTorch native optimizer/scheduler along with stateful modules.
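To see why this patching is necessary, here is a minimal plain-PyTorch sketch (the tensor names and the 0.1 step size are illustrative): an out-of-place update keeps the step on the autograd graph, while the in-place update that native optimizers perform does not.

```python
import torch

w = torch.randn(3, requires_grad=True)
loss = (w ** 2).sum()
grad = torch.autograd.grad(loss, w, create_graph=True)[0]

# Out-of-place update: the result carries a grad_fn, so a later loss
# (e.g. a validation loss) can be differentiated through this step.
w_new = w - 0.1 * grad
assert w_new.grad_fn is not None

# In-place update, as native PyTorch optimizers do under no_grad:
# the overwritten parameter carries no history, which breaks ITD.
with torch.no_grad():
    w.sub_(0.1 * grad)
assert w.grad_fn is None
```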
- betty.optim.patch_optimizer(optimizer, module)[source]¶
Patch PyTorch’s native optimizer by replacing all in-place operations in its step method.
- Parameters:
optimizer (torch.optim.Optimizer) – User-provided PyTorch native optimizer
module (torch.nn.Module) – User-provided PyTorch module that is being optimized by the optimizer
- Returns:
Corresponding differentiable optimizer
- Return type:
DifferentiableOptimizer
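A minimal usage sketch (the module and learning rate are placeholders; in normal Betty usage, the initialize method of IterativeProblem makes this call for you):

```python
import torch
from betty.optim import patch_optimizer

module = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(module.parameters(), lr=0.1)

# The returned optimizer performs only out-of-place updates in step,
# so intermediate parameter states remain on the autograd graph.
diff_optimizer = patch_optimizer(optimizer, module)
```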
- betty.optim.patch_scheduler(old_scheduler, new_optimizer)[source]¶
Patch the original learning rate scheduler to work with a patched differentiable optimizer.
- Parameters:
old_scheduler (torch.optim.lr_scheduler) – User-provided PyTorch learning rate scheduler
new_optimizer (DifferentiableOptimizer) – Patched differentiable optimizer
- Returns:
Patched learning rate scheduler
- Return type:
torch.optim.lr_scheduler
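A minimal sketch of re-attaching an existing scheduler to a patched optimizer (StepLR and its settings are arbitrary placeholders):

```python
import torch
from betty.optim import patch_optimizer, patch_scheduler

module = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(module.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

diff_optimizer = patch_optimizer(optimizer, module)
# Rebind the original schedule to the differentiable optimizer so that
# learning-rate updates apply to the patched optimizer going forward.
diff_scheduler = patch_scheduler(scheduler, diff_optimizer)
```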
Supported optimizers¶
Below is the list of differentiable optimizers supported by Betty.
- class betty.optim.sgd.DifferentiableSGD(optimizer, module)[source]¶
Differentiable version of PyTorch’s SGD optimizer. All in-place operations are replaced.
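As a sketch of how this class is reached, patching a torch.optim.SGD instance should yield a DifferentiableSGD; the isinstance check below encodes that assumption, which is inferred from the class description rather than stated elsewhere in these docs:

```python
import torch
from betty.optim import patch_optimizer
from betty.optim.sgd import DifferentiableSGD

net = torch.nn.Linear(4, 2)
sgd = torch.optim.SGD(net.parameters(), lr=0.01)

# Assumption: patch_optimizer dispatches torch.optim.SGD to its
# differentiable counterpart described above.
diff_sgd = patch_optimizer(sgd, net)
assert isinstance(diff_sgd, DifferentiableSGD)
```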