Source code for betty.hypergradient.cg
import warnings
import torch
from betty.utils import neg_with_none, to_vec
def cg(vector, curr, prev, sync):
"""
Approximate the matrix-vector multiplication with the best response Jacobian by the
(PyTorch's) default autograd method. Users may need to specify learning rate (``cg_alpha``) and
conjugate gradient descent iterations (``cg_iterations``) in ``Config``.
:param vector:
Vector with which matrix-vector multiplication with best-response Jacobian (matrix) would
be performed.
:type vector: Sequence of Tensor
:param curr: A current level problem
:type curr: Problem
:param prev: A directly lower-level problem to the current problem
:type prev: Problem
:return: (Intermediate) gradient
:rtype: Sequence of Tensor
"""
    assert len(curr.paths) == 0, "cg method is not supported for higher order MLO!"
    config = curr.config
    in_loss = curr.training_step_exec(curr.cur_batch)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        in_grad = torch.autograd.grad(
            in_loss, curr.trainable_parameters(), create_graph=True
        )
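    # CG state: with x0 = 0, the initial residual r equals the target vector v,
    # and the first search direction p is a copy of r.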
    x = [torch.zeros_like(vi) for vi in vector]
    r = [torch.zeros_like(vi).copy_(vi) for vi in vector]
    p = [torch.zeros_like(rr).copy_(rr) for rr in r]
    for _ in range(config.cg_iterations):
        # Hessian-vector product H p, obtained by differentiating in_grad again.
        hvp = torch.autograd.grad(
            in_grad, curr.trainable_parameters(), grad_outputs=p, retain_graph=True
        )
        hvp_vec = to_vec(hvp, alpha=config.cg_alpha)
        r_vec = to_vec(r)
        p_vec = to_vec(p)
        numerator = torch.dot(r_vec, r_vec)
        denominator = torch.dot(hvp_vec, p_vec)
        alpha = numerator / denominator
        # CG updates: x <- x + alpha * p, r <- r - alpha * H p.
        x_new = [xx + alpha * pp for xx, pp in zip(x, p)]
        r_new = [rr - alpha * hh for rr, hh in zip(r, hvp)]
        r_new_vec = to_vec(r_new)
        beta = torch.dot(r_new_vec, r_new_vec) / numerator
        # The new search direction uses the *updated* residual:
        # p <- r_new + beta * p.
        p_new = [rr + beta * pp for rr, pp in zip(r_new, p)]
        x, p, r = x_new, p_new, r_new
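    # Rescale the solution by cg_alpha, mirroring the cg_alpha scaling that
    # to_vec applied to the Hessian-vector products inside the loop.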
    x = [config.cg_alpha * xx for xx in x]
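    # Map x through the mixed second derivative of the training loss w.r.t.
    # prev's parameters. With sync=True the negated result is accumulated
    # directly into .grad via torch.autograd.backward; otherwise it is returned.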
    if sync:
        x = [neg_with_none(x_i) for x_i in x]
        torch.autograd.backward(
            in_grad, inputs=prev.trainable_parameters(), grad_tensors=x
        )
        implicit_grad = None
    else:
        implicit_grad = torch.autograd.grad(
            in_grad, prev.trainable_parameters(), grad_outputs=x
        )
        implicit_grad = [neg_with_none(ig) for ig in implicit_grad]
    return implicit_grad
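

# ---------------------------------------------------------------------------
# Illustrative sanity check (a sketch for this listing, not part of the
# library): the loop above is standard conjugate gradient, so the same update
# rules can be verified against an explicit symmetric positive-definite
# matrix. All names below (n, B, A, v) are made up for this example, and the
# setup corresponds to cg_alpha = 1, so no extra scaling enters the updates.
if __name__ == "__main__":
    n = 8
    B = torch.randn(n, n, dtype=torch.float64)
    A = B @ B.T + n * torch.eye(n, dtype=torch.float64)  # SPD system matrix
    v = torch.randn(n, dtype=torch.float64)
    x, r = torch.zeros_like(v), v.clone()
    p = r.clone()
    for _ in range(n):  # CG converges in at most n steps in exact arithmetic
        Ap = A @ p
        alpha = torch.dot(r, r) / torch.dot(Ap, p)
        x, r_new = x + alpha * p, r - alpha * Ap
        beta = torch.dot(r_new, r_new) / torch.dot(r, r)
        p, r = r_new + beta * p, r_new
    assert torch.allclose(A @ x, v)  # x is (numerically) A^{-1} v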