betty.problems

Problem

class betty.problems.problem.Problem(name, config=None, module=None, optimizer=None, scheduler=None, train_data_loader=None, extra_config=None)[source]

This is the base class for an optimization problem in multilevel optimization (MLO). Specifically, each problem is defined by its parameters (or module), the sets of upper- and lower-level constraining problems, its dataset, loss function, optimizer, and other optimization configurations (e.g., the best-response Jacobian calculation algorithm, the number of unrolling steps, etc.).
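
For orientation, here is a minimal sketch (not taken from the Betty documentation) of how these pieces typically come together. Problem itself is abstract, so the sketch uses the ImplicitProblem subclass documented below; the model, data, and Config arguments are illustrative placeholders.

    import torch
    import torch.nn.functional as F
    from torch.utils.data import DataLoader, TensorDataset

    from betty.problems import ImplicitProblem
    from betty.configs import Config


    class Classifier(ImplicitProblem):
        def training_step(self, batch):
            # user-defined loss function for this problem
            x, y = batch
            return F.cross_entropy(self.module(x), y)


    # toy module / optimizer / data loader standing in for real ones
    net = torch.nn.Linear(784, 10)
    loader = DataLoader(
        TensorDataset(torch.randn(64, 784), torch.randint(0, 10, (64,))), batch_size=8
    )

    classifier = Classifier(
        name="classifier",
        module=net,
        optimizer=torch.optim.SGD(net.parameters(), lr=0.1),
        train_data_loader=loader,
        config=Config(type="darts", unroll_steps=1),  # best-response Jacobian method and unroll steps
    )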

add_child(problem)[source]

Add problem to the lower-level problem list.

Parameters:

problem (Problem) – lower-level problem in the dependency graph

add_env(env)[source]

Add environment to the current problem.

Parameters:

env – Environment.

add_logger(logger)[source]

Add logger to the current problem.

Parameters:

logger – logger defined by users in Engine.

add_parent(problem)[source]

Add problem to the upper-level problem list.

Parameters:

problem (Problem) – upper-level problem in the dependency graph

add_paths(paths)[source]

Add new hypergradient backpropagation paths.

backward(loss, params, paths, create_graph=False, retain_graph=True, allow_unused=True)[source]

Calculate the gradient of loss with respect to params based on a user-defined config.

Parameters:
  • loss (Tensor) – Outputs of the differentiated function.

  • params (Sequence of Tensor) – Inputs with respect to which the gradient will be returned.

  • paths (List of list of Problem) – Paths on which the gradient will be calculated.

  • create_graph (bool, optional) – If True, the graph of the derivative will be constructed, allowing higher-order derivative products to be computed. Default: False.

  • retain_graph (bool, optional) – If False, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. Default: True.

  • allow_unused (bool, optional) – If False, specifying inputs that were not used when computing outputs (and whose grad is therefore always zero) is an error. Default: True.

abstract cache_states()[source]

Cache params, buffers, optimizer states when config.roll_back is set to True in step.

check_ready()[source]

Check if the unrolling processes of lower-level problems in the hierarchical dependency graph are all ready/done. The step function is only executed when this method returns True.

Return type:

bool

property children

Return lower-level problems for the current problem.

clear_dependencies()[source]

Clear the dependencies of the current problem.

clip_grad()[source]

Perform gradient clipping based on the norm provided in Config.

property config

Return the configuration for the current problem.

configure_device(device)[source]

Set the device for the current problem.

configure_distributed_training(dictionary)[source]

Set the configuration for distributed training.

Parameters:

dictionary (dict) – Python dictionary of distributed training provided by Engine.

configure_roll_back(roll_back)[source]

Set the roll-back (warm-start) option from Engine.

Parameters:

roll_back (bool) – roll-back (warm-start) on/off

property count

Return the local step for the current problem.

Returns:

local step

Return type:

int

epoch_callback_exec()[source]
eval()[source]

Set the current problem to the evaluation mode.

forward(*args, **kwargs)[source]

Users define the forward (or call) behavior of the problem here.
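
As a hypothetical sketch, a problem might override forward so that other problems calling it receive a post-processed output (here, per-example weights in (0, 1)); the class and behavior below are illustrative, not part of the library.

    import torch
    from betty.problems import ImplicitProblem


    class Reweighter(ImplicitProblem):
        # other problems that call this problem, e.g. self.reweighter(x),
        # receive weights squashed into (0, 1)
        def forward(self, x):
            return torch.sigmoid(self.module(x))

        # training_step (the problem's own loss) omitted for brevity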

get_batch()[source]

Load a training batch from the user-provided data loader.

Returns:

New training batch

Return type:

Any

get_batch_single_loader(idx)[source]

Load a training batch from one of the user-provided data loader(s).

Parameters:

idx (int) – index of the data loader to load the batch from

Returns:

New training batch

Return type:

Any

get_loss(batch)[source]

Calculate loss and log metrics for the current batch based on the user-defined loss function.

Returns:

loss and log metrics (e.g. classification accuracy)

Return type:

dict

get_opt_param_group_for_param(param)[source]

Get the optimizer param_group for a specific parameter.

Parameters:

param (torch.nn.Parameter) – Parameter for which the optimizer param_group is queried

Returns:

param_group for the given parameter

Return type:

dict

get_opt_state_for_param(param)[source]

Get the optimizer state for a specific parameter.

Parameters:

param (torch.nn.Parameter) – Parameter for which the optimizer state is queried

Returns:

optimizer state for the given parameter

Return type:

dict
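
Both lookups can be used, for example, to inspect how the optimizer treats one particular parameter. The snippet below is an illustrative sketch that continues the hypothetical classifier instance sketched under the class description above.

    param = next(classifier.module.parameters())
    print(classifier.get_opt_param_group_for_param(param)["lr"])  # e.g. the learning rate used for this parameter
    print(classifier.get_opt_state_for_param(param))              # e.g. momentum buffers (may be empty before the first step)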

gradient_accumulation_boundary()[source]

Check whether the current step is on a gradient accumulation boundary.

initialize()[source]

Initialize the problem: patch/set up the module, optimizer, data loader, etc. after compiling the user-provided configuration (e.g., fp16 training, iterative differentiation).

is_implemented(fn_name)[source]

Check if the fn_name method is implemented in the class.

Return type:

bool

is_rank_zero()[source]

Check whether the current device is rank 0.

property leaf

Return whether the current problem is a leaf or not.

Returns:

leaf

Return type:

bool

load_state_dict(state_dict)[source]

Load the state for the Problem.

Parameters:

state_dict (dict) – Python dictionary of Problem states.

log(stats, global_step)[source]

Log (training) stats to self.logger.

Parameters:
  • stats (Any) – log metrics such as loss and classification accuracy.

  • global_step (int) – global/local step associated with the stats.

property name

Return the user-defined name of the problem.

one_step_descent(batch=None)[source]
abstract optimizer_step(*args, **kwargs)[source]

Update weights as in PyTorch’s native optim.step().

abstract parameters()[source]

Return all parameters for the current problem.

property parents

Return upper-level problems for the current problem.

patch_data_loader(loader)[source]

Patch the data loader given the systems configuration (e.g., DDP, FSDP).

patch_everything()[source]

Patch the module, optimizer, data loader, and LR scheduler for device placement, distributed training, the ZeRO optimizer, FSDP, etc.

patch_module()[source]

Patch the module given the systems configuration (e.g., DDP, FSDP).

patch_optimizer()[source]

Patch the optimizer given the systems configuration (e.g., DDP, FSDP).

patch_scheduler()[source]

Patch the scheduler given the systems configuration (e.g., DDP, FSDP).

property paths

Return hypergradient calculation paths for the current problem.

abstract recover_states()[source]

Recover params, buffers, optimizer states when config.roll_back is set to True in step.

set_grads(params, grads)[source]

Set gradients for trainable parameters (i.e., params.grad = grads).

Parameters:
  • params (Sequence of Tensor) – Trainable parameters

  • grads (Sequence of Tensor) – Calculated gradients

set_module(module)[source]

Set new module for the current Problem class.

set_optimizer(optimizer)[source]

Set new optimizer for the current Problem class.

set_scheduler(scheduler)[source]

Set new scheduler for the current Problem class.

set_train_data_loader(loader, idx=0)[source]

Set new data loader for the current Problem class.

state_dict()[source]

Return all states involved in Problem as a Python dictionary. By default, it includes self.module.state_dict and self.optimizer.state_dict. Depending on users’ configurations, it may include self.scheduler.state_dict (lr scheduler) and self.scaler.state_dict (fp16 training).
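
A hedged checkpointing sketch built on state_dict and load_state_dict, again using the hypothetical classifier instance from the class-level sketch above; the file name is arbitrary.

    import torch

    torch.save(classifier.state_dict(), "classifier_ckpt.pt")     # module + optimizer (+ scheduler/scaler if configured)
    classifier.load_state_dict(torch.load("classifier_ckpt.pt"))  # restore the saved states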

step(global_step=None)[source]

The step method abstracts a one-step gradient descent update with four sub-steps: 1) data loading, 2) cost calculation, 3) gradient calculation, and 4) parameter update. It also calls the upper-level problems’ step methods after unrolling gradient steps, based on the hierarchical dependency graph.

Parameters:

global_step (int, optional) – global step of the whole multilevel optimization. Defaults to None.

step_after_roll_back()[source]
step_normal(global_step=None)[source]
synchronize_params(params)[source]

Synchronize parameters across distributed data-parallel processes.

train()[source]

Set the current problem to the training mode.

abstract trainable_parameters()[source]

Define all trainable parameters for the current problem.

abstract training_step(batch)[source]

Users define the loss function of the problem here.
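
The loss is typically returned as a scalar tensor; the sketch below additionally returns log metrics in a dictionary, mirroring the get_loss entry above. Treat the exact dictionary convention (a "loss" key plus metric entries) as an assumption to verify against your Betty version.

    import torch.nn.functional as F
    from betty.problems import ImplicitProblem


    class Classifier(ImplicitProblem):
        def training_step(self, batch):
            x, y = batch
            out = self.module(x)
            loss = F.cross_entropy(out, y)
            acc = (out.argmax(dim=-1) == y).float().mean()
            # assumed convention: "loss" drives the update, other entries are logged
            return {"loss": loss, "acc": acc}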

training_step_exec(batch)[source]
zero_grad()[source]

Set gradients of the trainable parameters of the current problem to 0. Similar to PyTorch’s optim.zero_grad() or module.zero_grad().

Implicit Problem

ImplicitProblem is used when the best-response Jacobian for the current problem is calculated with (approximate) implicit differentiation (AID). AID doesn’t require patching the module/optimizer, and it usually achieves better memory and compute efficiency, especially when a large number of unrolling steps is used. We recommend that users use ImplicitProblem as the default class to define problems in MLO.

class betty.problems.implicit_problem.ImplicitProblem(name, config, module=None, optimizer=None, scheduler=None, train_data_loader=None, extra_config=None)[source]

ImplicitProblem is subclassed from Problem.
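
A hedged end-to-end sketch of the recommended pattern: two ImplicitProblem subclasses connected through Engine with upper-to-lower and lower-to-upper dependencies. Class names, hyperparameters, and the toy data are illustrative, and the dependency wiring follows Betty's documented Engine usage; inside training_step, sibling problems are reached by their names (self.outer, self.inner) once Engine sets up the dependencies.

    import torch
    import torch.nn.functional as F
    from torch.utils.data import DataLoader, TensorDataset

    from betty.engine import Engine
    from betty.problems import ImplicitProblem
    from betty.configs import Config, EngineConfig


    class Inner(ImplicitProblem):
        def training_step(self, batch):
            x, y = batch
            losses = F.cross_entropy(self.module(x), y, reduction="none")
            weights = self.outer(x)               # upper-level problem, accessed by name
            return (weights.squeeze(-1) * losses).mean()


    class Outer(ImplicitProblem):
        def forward(self, x):
            return torch.sigmoid(self.module(x))  # per-example weights for the lower level

        def training_step(self, batch):
            x, y = batch
            return F.cross_entropy(self.inner.module(x), y)


    def toy_loader():
        return DataLoader(
            TensorDataset(torch.randn(64, 16), torch.randint(0, 4, (64,))), batch_size=8
        )


    inner_net, outer_net = torch.nn.Linear(16, 4), torch.nn.Linear(16, 1)
    inner = Inner(
        name="inner",
        module=inner_net,
        optimizer=torch.optim.SGD(inner_net.parameters(), lr=0.1),
        train_data_loader=toy_loader(),
        config=Config(type="darts", unroll_steps=5),
    )
    outer = Outer(
        name="outer",
        module=outer_net,
        optimizer=torch.optim.Adam(outer_net.parameters(), lr=1e-3),
        train_data_loader=toy_loader(),
        config=Config(type="darts"),
    )

    engine = Engine(
        config=EngineConfig(train_iters=20),
        problems=[outer, inner],
        dependencies={"u2l": {outer: [inner]}, "l2u": {inner: [outer]}},
    )
    engine.run()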

Iterative Problem

IterativeProblem is used when the best-response Jacobian for the current problem is calculated with iterative differentiation (ITD). ITD requires patching the module/optimizer to track intermediate parameter states for the gradient flow. We discourage users from using this class, because the memory/computation efficiency of ITD is oftentimes worse than that of AID. In addition, users may need to be familiar with a functional programming style due to the use of stateless modules.

class betty.problems.iterative_problem.IterativeProblem(name, config, module=None, optimizer=None, scheduler=None, train_data_loader=None, extra_config=None)[source]

IterativeProblem is subclassed from Problem.

patch_modules()[source]

Patch PyTorch’s native stateful module into a stateless module so as to support a functional forward that takes params as its input.

patch_optimizer()[source]

Patch PyTorch’s native optimizer by replacing all involved in-place operations to allow gradient flow through the parameter update process.

patch_scheduler()[source]

Patch the scheduler to be compatible with the patched optimizer.