betty.problems
Problem
- class betty.problems.problem.Problem(name, config=None, module=None, optimizer=None, scheduler=None, train_data_loader=None, extra_config=None)[source]
This is the base class for an optimization problem in multilevel optimization. Specifically, each problem is defined by its parameter (or module), the sets of upper- and lower-level constraining problems, the dataset, the loss function, the optimizer, and other optimization configurations (e.g., the best-response Jacobian calculation algorithm or the number of unrolling steps).
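For concreteness, here is a minimal construction sketch using the concrete `ImplicitProblem` subclass recommended below. The `training_step` hook name and the `betty.configs.Config` import follow Betty's tutorials and should be treated as assumptions here; the toy model and data are hypothetical:

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from betty.configs import Config            # assumed import path
from betty.problems import ImplicitProblem  # concrete Problem subclass


class Classifier(ImplicitProblem):
    # User-defined loss hook; the name `training_step` follows Betty's
    # tutorials and is an assumption here.
    def training_step(self, batch):
        x, y = batch
        return F.cross_entropy(self.module(x), y)


net = torch.nn.Linear(4, 2)
classifier = Classifier(
    name="classifier",
    config=Config(type="darts", unroll_steps=1),  # best-response Jacobian options
    module=net,
    optimizer=torch.optim.SGD(net.parameters(), lr=0.1),
    train_data_loader=DataLoader(
        TensorDataset(torch.randn(64, 4), torch.randint(0, 2, (64,))),
        batch_size=16,
    ),
)
```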
- add_child(problem)[source]
Add `problem` to the lower-level problem list.
- Parameters:
problem (Problem) – lower-level problem in the dependency graph
- add_logger(logger)[source]
Add a logger to the current problem.
- Parameters:
logger – logger defined by users in `Engine`.
- add_parent(problem)[source]
Add `problem` to the upper-level problem list.
- Parameters:
problem (Problem) – upper-level problem in the dependency graph
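A short hedged sketch of these two methods, for hypothetical problems `outer` (upper level) and `inner` (lower level); `Engine` normally builds these edges from its `dependencies` argument, but they can also be set manually:

```python
outer.add_child(inner)   # register inner as a lower-level problem of outer
inner.add_parent(outer)  # register outer as an upper-level problem of inner
```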
- backward(loss, params, paths, create_graph=False, retain_graph=True, allow_unused=True)[source]
Calculate the gradient of `loss` with respect to `params` based on a user-defined `config`.
- Parameters:
loss (Tensor) – Outputs of the differentiated function.
params (Sequence of Tensor) – Inputs with respect to which the gradient will be returned.
paths (List of list of Problem) – Paths on which the gradient will be calculated.
create_graph (bool, optional) – If `True`, the graph of the derivative will be constructed, allowing higher-order derivative products to be computed. Default: `False`.
retain_graph (bool, optional) – If `False`, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to `True` is not needed and can often be worked around in a much more efficient way. Default: `True`.
allow_unused (bool, optional) – If `False`, specifying inputs that were not used when computing outputs (and therefore whose grad is always zero) is an error. Default: `True`.
- abstract cache_states()[source]
Cache params, buffers, and optimizer states when `config.roll_back` is set to `True` in `step`.
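As a hedged one-liner for context: with roll-back enabled, `step` is expected to call `cache_states` before unrolling and `recover_states` afterwards. Passing `roll_back` as a `Config` field is an assumption based on the `config.roll_back` reference above:

```python
from betty.configs import Config

# roll_back as a Config field is an assumption here.
config = Config(type="darts", unroll_steps=5, roll_back=True)
```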
- check_ready()[source]
Check whether the unrolling processes of lower-level problems in the hierarchical dependency graph are all ready/done. The `step` function is only executed when this method returns `True`.
- Return type:
bool
- property children
Return lower-level problems for the current problem.
- property config
Return the configuration for the current problem.
- configure_distributed_training(dictionary)[source]
Set the configuration for distributed training.
- Parameters:
dictionary (dict) – Python dictionary of distributed training settings provided by Engine.
- configure_roll_back(roll_back)[source]
Set the roll-back (warm-start) option from Engine.
- Parameters:
roll_back (bool) – roll-back (warm-start) on/off
- property count
Return the local step for the current problem.
- Returns:
local step
- Return type:
int
- forward(*args, **kwargs)[source]
Users define the forward (or call) behavior of the problem here.
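A minimal sketch of a `forward` override, reusing the hypothetical `Classifier` subclass from the earlier sketch:

```python
from betty.problems import ImplicitProblem


class Classifier(ImplicitProblem):
    # Make the problem callable: other problems in the dependency
    # graph can then invoke it (e.g., self.classifier(x)) inside
    # their own loss definitions.
    def forward(self, x):
        return self.module(x)
```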
- get_batch()[source]
Load a training batch from the user-provided data loader.
- Returns:
New training batch
- Return type:
Any
- get_batch_single_loader(idx)[source]
Load a training batch from one of the user-provided data loaders.
- Returns:
New training batch
- Return type:
Any
- get_loss(batch)[source]
Calculate loss and log metrics for the current batch based on the user-defined loss function.
- Returns:
loss and log metrics (e.g. classification accuracy)
- Return type:
dict
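A hedged sketch of the user-defined loss function that `get_loss` wraps; the hook name `training_step` and the dict-return convention follow Betty's tutorials and are assumptions here:

```python
import torch
import torch.nn.functional as F
from betty.problems import ImplicitProblem


class Classifier(ImplicitProblem):
    def training_step(self, batch):
        x, y = batch
        logits = self.module(x)
        loss = F.cross_entropy(logits, y)
        # Extra dict entries are assumed to be logged as metrics
        # alongside the loss (e.g., classification accuracy).
        acc = (logits.argmax(dim=-1) == y).float().mean()
        return {"loss": loss, "acc": acc}
```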
- get_opt_param_group_for_param(param)[source]
Get the optimizer param_group for a specific parameter.
- Parameters:
param (torch.nn.Parameter) – Parameter whose optimizer param_group is queried
- Returns:
param_group for the given parameter
- Return type:
dict
- get_opt_state_for_param(param)[source]
Get the optimizer state for a specific parameter.
- Parameters:
param (torch.nn.Parameter) – Parameter whose optimizer state is queried
- Returns:
optimizer state for the given parameter
- Return type:
dict
- gradient_accumulation_boundary()[source]
Check whether the current step is on a gradient accumulation boundary.
- initialize()[source]
`initialize` patches/sets up the module, optimizer, data loader, etc. after compiling a user-provided configuration (e.g., fp16 training, iterative differentiation).
- is_implemented(fn_name)[source]
Check if the `fn_name` method is implemented in the class.
- Return type:
bool
- property leaf
Return whether the current problem is a leaf or not.
- Returns:
leaf
- Return type:
bool
- load_state_dict(state_dict)[source]
Load the state for the `Problem`.
- Parameters:
state_dict (dict) – Python dictionary of Problem states.
- log(stats, global_step)[source]
Log (training) stats to `self.logger`.
- Parameters:
stats (Any) – log metrics such as loss and classification accuracy.
global_step (int) – global/local step associated with the `stats`.
- property name
Return the user-defined name of the problem.
- abstract optimizer_step(*args, **kwargs)[source]
Update weights as in PyTorch's native `optim.step()`.
- property parents
Return upper-level problems for the current problem.
- patch_data_loader(loader)[source]
Patch the data loader given the systems configuration (e.g., DDP, FSDP).
- patch_everything()[source]
Patch the module, optimizer, data loader, and lr scheduler for device placement, distributed training, the ZeRO optimizer, FSDP, etc.
- property paths
Return hypergradient calculation paths for the current problem.
- abstract recover_states()[source]
Recover params, buffers, and optimizer states when `config.roll_back` is set to `True` in `step`.
- set_grads(params, grads)[source]
Set gradients for trainable parameters, i.e., `params.grad = grads`.
- Parameters:
params (Sequence of Tensor) – Trainable parameters
grads (Sequence of Tensor) – Calculated gradients
- state_dict()[source]
Return all states involved in `Problem` as a Python dictionary. By default, it includes `self.module.state_dict` and `self.optimizer.state_dict`. Depending on users' configurations, it may include `self.scheduler.state_dict` (lr scheduler) and `self.scaler.state_dict` (fp16 training).
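For illustration, a hedged checkpointing round trip using these two methods, where `classifier` is the hypothetical problem from the earlier sketch:

```python
import torch

# Save module/optimizer (and, per configuration, scheduler/scaler)
# states, then restore them later.
torch.save(classifier.state_dict(), "classifier.pt")
classifier.load_state_dict(torch.load("classifier.pt"))
```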
- step(global_step=None)[source]
The `step` method abstracts a one-step gradient descent update with four sub-steps: 1) data loading, 2) cost calculation, 3) gradient calculation, and 4) parameter update. It also calls upper-level problems' `step` methods after unrolling gradient steps based on the hierarchical dependency graph.
- Parameters:
global_step (int, optional) – global step of the whole multilevel optimization. Defaults to None.
Implicit Problem
An implicit problem is used when the best-response Jacobian for the current problem is calculated with (approximate) implicit differentiation (AID). AID doesn't require patching the module/optimizer and is usually more memory- and compute-efficient, especially when a large number of unrolling steps is used. We recommend that users use `ImplicitProblem` as the default class for defining problems in MLO, as in the sketch below.
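A hedged end-to-end sketch of wiring two `ImplicitProblem`s into a bilevel program. The `Engine`/`EngineConfig` classes, the `Config` fields, the `training_step` hook, and the "u2l"/"l2u" dependency dictionary follow Betty's README and should be treated as assumptions here; the toy data-reweighting setup is hypothetical:

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from betty.engine import Engine
from betty.configs import Config, EngineConfig  # assumed import paths
from betty.problems import ImplicitProblem

data = TensorDataset(torch.randn(64, 4), torch.randint(0, 2, (64,)))


class Inner(ImplicitProblem):
    def forward(self, x):
        return self.module(x)

    def training_step(self, batch):
        x, y = batch
        # Problems are assumed to reach their dependencies through
        # registered names, so self.outer is the problem defined below.
        w = torch.sigmoid(self.outer(x)).squeeze(-1)
        losses = F.cross_entropy(self.module(x), y, reduction="none")
        return (w * losses).mean()


class Outer(ImplicitProblem):
    def forward(self, x):
        return self.module(x)

    def training_step(self, batch):
        x, y = batch
        # Validation-style loss of the lower-level model.
        return F.cross_entropy(self.inner(x), y)


inner_net, outer_net = torch.nn.Linear(4, 2), torch.nn.Linear(4, 1)
inner = Inner(name="inner", config=Config(type="darts", unroll_steps=1),
              module=inner_net,
              optimizer=torch.optim.SGD(inner_net.parameters(), lr=0.1),
              train_data_loader=DataLoader(data, batch_size=16))
outer = Outer(name="outer", config=Config(),
              module=outer_net,
              optimizer=torch.optim.Adam(outer_net.parameters(), lr=1e-3),
              train_data_loader=DataLoader(data, batch_size=16))

engine = Engine(config=EngineConfig(train_iters=100),
                problems=[outer, inner],
                dependencies={"u2l": {outer: [inner]}, "l2u": {inner: [outer]}})
engine.run()
```

`engine.run()` then drives each problem's `step` according to the hierarchical dependency graph, unrolling the lower-level problem before updating the upper-level one.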
Iterative Problem
An iterative problem is used when the best-response Jacobian for the current problem is calculated with iterative differentiation (ITD). ITD requires patching the module/optimizer to track intermediate parameter states for the gradient flow. We discourage users from using this class, because the memory/computation efficiency of ITD is often worse than that of AID. In addition, users may need to be familiar with a functional programming style due to the use of stateless modules.
- class betty.problems.iterative_problem.IterativeProblem(name, config, module=None, optimizer=None, scheduler=None, train_data_loader=None, extra_config=None)[source]
`IterativeProblem` is subclassed from `Problem`.
- patch_modules()[source]
Patch PyTorch's native stateful module into a stateless module to support a functional forward that takes parameters as input.
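To illustrate the stateless/functional style that ITD relies on, here is a general PyTorch sketch using `torch.func.functional_call` (this is not Betty's internal patching code, only an analogy for what a stateless forward looks like):

```python
import torch
from torch.func import functional_call

module = torch.nn.Linear(4, 2)
params = dict(module.named_parameters())

x = torch.randn(8, 4)
# Stateless forward: parameters are passed explicitly instead of being
# read from the module, so unrolled parameter states can remain nodes
# on the autograd graph.
out = functional_call(module, params, (x,))
```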