
Source code for betty.hypergradient.darts

import torch

from betty.utils import to_vec, replace_none_with_zero
from betty.hypergradient.utils import precondition


def darts(vector, curr, prev, sync):
    """
    Approximate the matrix-vector multiplication with the best-response Jacobian by the
    finite difference method. More specifically, we modified the finite difference method
    proposed in `DARTS: Differentiable Architecture Search
    <https://arxiv.org/pdf/1806.09055.pdf>`_ by re-interpreting it from the implicit
    differentiation perspective. Empirically, this method achieves better memory
    efficiency, training wall time, and test accuracy than other methods.

    :param vector: Vector with which matrix-vector multiplication with the best-response
        Jacobian (matrix) would be performed.
    :type vector: Sequence of Tensor
    :param curr: A current level problem
    :type curr: Problem
    :param prev: A directly lower-level problem to the current problem
    :type prev: Problem
    :param sync: If ``True``, the (negative) gradient is accumulated on ``prev`` via
        ``backward`` so it can be synchronized in distributed training; otherwise the
        intermediate gradient is returned directly.
    :type sync: bool
    :return: (Intermediate) gradient
    :rtype: Sequence of Tensor
    """
    config = curr.config
    R = config.darts_alpha

    # With FSDP, each rank only holds a shard of the flattened parameters, so slice the
    # corresponding shard out of the vector.
    if curr._strategy == "fsdp":
        curr_flat_param = curr.module._fsdp_wrapped_module.flat_param
        param_len = curr_flat_param.numel() - curr_flat_param._shard_numel_padded
        offset = curr._rank * param_len
        vector = [vector[0][offset : offset + param_len]]

    if config.darts_preconditioned and curr._strategy not in ["zero", "fsdp"]:
        vector = precondition(vector, curr)

    # Finite-difference step size, scaled by the norm of the perturbation vector.
    eps = R / to_vec(vector).norm().add_(1e-12).item()

    # positive perturbation: theta + eps * v
    for p, v in zip(curr.trainable_parameters(), vector):
        p.data.add_(v.data, alpha=eps)
    loss_p = curr.training_step_exec(curr.cur_batch)
    grad_p = torch.autograd.grad(loss_p, prev.trainable_parameters(), allow_unused=True)
    grad_p = replace_none_with_zero(grad_p, prev.trainable_parameters())

    if sync:
        grad_p = [-g_p.div_(2 * eps) for g_p in grad_p]
        if prev._strategy == "fsdp":
            prev_flat_param = prev.module._fsdp_wrapped_module.flat_param
            offset = prev._rank * prev_flat_param.numel()
            valid_len = prev_flat_param.numel() - prev_flat_param._shard_numel_padded
            prev_grad_shard = grad_p[0].narrow(0, offset, valid_len)
            new_grad_p = torch.zeros_like(prev.trainable_parameters()[0])
            new_grad_p[:valid_len] = prev_grad_shard
            grad_p = [new_grad_p]
        prev.set_grads(prev.trainable_parameters(), grad_p)

    # negative perturbation: theta - eps * v
    for p, v in zip(curr.trainable_parameters(), vector):
        p.data.add_(v.data, alpha=-2 * eps)
    loss_n = curr.training_step_exec(curr.cur_batch)
    if sync:
        torch.autograd.backward(loss_n / (2 * eps), inputs=prev.trainable_parameters())
    else:
        grad_n = torch.autograd.grad(
            loss_n, prev.trainable_parameters(), allow_unused=True
        )
        grad_n = replace_none_with_zero(grad_n, prev.trainable_parameters())

    # reverse weight change: restore the original parameters of curr
    for p, v in zip(curr.trainable_parameters(), vector):
        p.data.add_(v.data, alpha=eps)

    implicit_grad = None
    if not sync:
        implicit_grad = [(x - y).div_(2 * eps) for x, y in zip(grad_n, grad_p)]

    return implicit_grad
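
To make the finite-difference trick concrete, here is a minimal, self-contained sketch of the central-difference approximation that darts performs, written outside of Betty's Problem abstraction. The tensors w, a, v and the quadratic loss are toy placeholders, not part of the Betty API; the estimate is compared against the exact mixed second derivative computed with double backward.

import torch

# Toy objective: L(w, a) = 0.5 * ((a * w).sum()) ** 2
# w plays the role of curr's parameters, a the role of prev's parameters.
w = torch.tensor([1.0, 2.0], requires_grad=True)
a = torch.tensor([3.0, 4.0], requires_grad=True)
v = torch.tensor([0.1, -0.2])  # vector to multiply with the mixed Jacobian

def loss(w, a):
    return 0.5 * (a * w).sum() ** 2

R = 0.01
eps = R / v.norm().add_(1e-12).item()

# Positive perturbation: gradient w.r.t. a at w + eps * v
w_p = (w + eps * v).detach().requires_grad_(True)
grad_p = torch.autograd.grad(loss(w_p, a), a)[0]

# Negative perturbation: gradient w.r.t. a at w - eps * v
w_n = (w - eps * v).detach().requires_grad_(True)
grad_n = torch.autograd.grad(loss(w_n, a), a)[0]

# Central-difference estimate of (d^2 L / da dw) @ v
fd_estimate = (grad_p - grad_n) / (2 * eps)

# Exact value via double backward, for comparison
g_w = torch.autograd.grad(loss(w, a), w, create_graph=True)[0]
exact = torch.autograd.grad(g_w, a, grad_outputs=v)[0]

print(fd_estimate, exact)  # the two should closely agree

Note that darts itself returns the negative of this quantity, i.e. (grad_n - grad_p) / (2 * eps), because the implicit-differentiation derivation contributes a minus sign to the final hypergradient.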