Implicit Model-Agnostic Meta-Learning
Here we re-implement the model-agnostic meta-learning (MAML) algorithm from
Meta-Learning with Implicit Gradients,
where we learn initialization weights for convolutional neural networks (CNNs) that allow
quick adaptation to various tasks given only a few samples (i.e., few-shot learning).
In this post, we assume that readers are already familiar with the algorithm,
and therefore focus mainly on how to implement MAML with Betty. We also note that,
while this post covers implicit MAML rather than MAML, adapting the code to MAML
is straightforward: replace ImplicitProblem with IterativeProblem.
Finally, the full version of our code can be found
here.
Basics
MAML can be interpreted as a bilevel optimization problem, where the upper level learns
the initialization weights that allow quick adaptation to new tasks and
the lower level learns to adapt to a task given a few training examples. Therefore,
users need to define a Problem class for each of the two levels. Following Betty’s design
principle, each level can be defined by providing:
Module
Optimizer
Data loader (or data loading function)
Loss function
Problem configuration
Name
(Optional) learning rate scheduler, …
While data loading is decoupled across levels in most other bilevel optimization problems, in MAML it is coupled: the adaptation result from the lower level should be tested on the same task in the upper level. We will dive into this in this post.
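For reference, the bilevel view described in this section can be written down as follows. The symbols are our own notation rather than anything defined in this post: θ is the shared initialization (upper level), φ_i is the adapted weight for task i (lower level), M is the number of tasks in a meta-batch, and λ is the regularization strength. The proximal term is written without a factor of 1/2 so that it matches the reg_loss function used in the lower-level code of this post; the original paper uses a slightly different scaling.

$$
\min_{\theta}\ \frac{1}{M}\sum_{i=1}^{M} \mathcal{L}_i^{\text{test}}\big(\phi_i^{*}(\theta)\big)
\quad \text{s.t.} \quad
\phi_i^{*}(\theta) = \arg\min_{\phi}\ \mathcal{L}_i^{\text{train}}(\phi) + \lambda \lVert \phi - \theta \rVert_2^2
$$

The coupling in data loading is visible here: the train and test losses inside the same summand must come from the same task i.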
Environment
As stated above, data loading is coupled for the upper and lower levels in MAML. This requires
the user to implement a unified data loading mechanism. Unfortunately, this is hard to achieve
(or the code gets ugly) with Betty’s Problem class design, where users provide
a data loader separately for each Problem class through the class constructor.
To enable a clean implementation for such cases (where data loading is entangled across
multiple Problem classes), Betty provides the Env class, where users can specify a unified
data loading mechanism.
More specifically, we use learn2learn to sample few-shot tasks, and define the data
loading mechanism in the step method of Env.
The code is shown below.
import numpy as np
import torch
import learn2learn as l2l
from betty.envs import Env

# `args` comes from the argument parser in the full script.
# Each task holds 2 * shots samples per class: half for adaptation, half for evaluation.
tasksets = l2l.vision.benchmarks.get_tasksets(
    args.task,
    train_ways=args.ways,
    train_samples=2 * args.shots,
    test_ways=args.ways,
    test_samples=2 * args.shots,
    num_tasks=args.task_num,
    root="./data",
)


def split_data(data, labels, shots, ways):
    # Even positions go to the adaptation (train) split, odd positions to the evaluation (test) split.
    out = {"train": None, "test": None}
    adapt_indices = np.zeros(data.size(0), dtype=bool)
    adapt_indices[np.arange(shots * ways) * 2] = True
    eval_indices = torch.from_numpy(~adapt_indices)
    adapt_indices = torch.from_numpy(adapt_indices)
    out["train"] = (data[adapt_indices], labels[adapt_indices])
    out["test"] = (data[eval_indices], labels[eval_indices])
    return out


class MAMLEnv(Env):
    def __init__(self):
        super().__init__()
        self.tasks = tasksets
        self.batch = {"train": None, "test": None}

    def step(self):
        # Sample one task and expose both splits to the upper and lower problems.
        data, labels = self.tasks.train.sample()
        data, labels = data.to(self.device), labels.to(self.device)
        out = split_data(data, labels, args.shots, args.ways)
        self.batch["train"] = out["train"]
        self.batch["test"] = out["test"]


env = MAMLEnv()
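To make the split concrete, here is a minimal pure-Python sketch of the index logic used by split_data above. The helper name split_indices is hypothetical and exists only for this illustration; it works on plain index lists instead of tensors. We assume (as split_data does) that a sampled task contains 2 * shots * ways samples and that samples of the same class sit next to each other, so taking every other sample gives each split the same number of samples per class.

```python
def split_indices(num_samples, shots, ways):
    """Return (adapt, evaluation) index lists using even/odd interleaving.

    Hypothetical helper mirroring the boolean mask built in split_data.
    """
    # Positions 0, 2, 4, ... (shots * ways of them) form the adaptation split.
    adapt = [2 * i for i in range(shots * ways)]
    adapt_set = set(adapt)
    # Every remaining position forms the evaluation split.
    evaluation = [i for i in range(num_samples) if i not in adapt_set]
    return adapt, evaluation


# A 2-way, 1-shot task holds 2 * 1 * 2 = 4 samples.
adapt, evaluation = split_indices(num_samples=4, shots=1, ways=2)
print(adapt, evaluation)  # [0, 2] [1, 3]
```

The lower level adapts on the first list and the upper level evaluates on the second, which is why both splits have to come from a single call to sample().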
Upper-Level
Betty allows users to define a custom data loading mechanism through the get_batch method.
Access to Env is granted by Engine, so users can simply use self.env.
Also, since the upper problem is updated only after gradients have been accumulated over
meta_batch_size tasks, users need to specify this through the gradient_accumulation
attribute in Config. The remaining parts (e.g., defining the loss function, module, optimizer,
etc.) are relatively straightforward, and we direct readers who are unfamiliar with them
to our Tutorial.
import torch
import torch.nn.functional as F
from torch import optim
from betty.configs import Config
from betty.problems import ImplicitProblem


class Upper(ImplicitProblem):
    def training_step(self, batch):
        # Evaluate the adapted lower-level network on the held-out split of the same task.
        inputs, labels = batch
        out = self.lower(inputs)
        loss = F.cross_entropy(out, labels)
        acc = 100.0 * (out.argmax(dim=1) == labels).float().mean().item()
        return {"loss": loss, "acc": acc}

    def get_batch(self):
        inputs, labels = self.env.batch["test"]
        return inputs, labels


# ConvNet is the CNN defined in the full script.
parent_module = ConvNet(args.ways, args.hidden_size)
parent_optimizer = optim.AdamW(parent_module.parameters(), lr=3e-4)
parent_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    parent_optimizer,
    T_max=int(args.meta_batch_size * 7500),
)
parent_config = Config(
    log_step=int(args.inner_steps * args.meta_batch_size * 10),
    retain_graph=True,
    gradient_accumulation=args.meta_batch_size,
)
parent = Upper(
    name="upper",
    module=parent_module,
    optimizer=parent_optimizer,
    scheduler=parent_scheduler,
    config=parent_config,
)
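The effect of gradient_accumulation can be illustrated with a toy scalar example. This is only a sketch of the general idea (one parameter update per meta_batch_size task gradients), not a description of Betty's internals; the function name accumulate_and_step is hypothetical, and we assume the accumulated gradients are averaged, which this post does not state.

```python
def accumulate_and_step(theta, task_grads, meta_batch_size, lr):
    """Apply one SGD update per `meta_batch_size` task gradients (toy scalar version)."""
    buffer = 0.0
    num_updates = 0
    for i, grad in enumerate(task_grads, start=1):
        # Assumption: gradients are averaged over the meta-batch.
        buffer += grad / meta_batch_size
        if i % meta_batch_size == 0:
            theta -= lr * buffer  # a single update for the whole meta-batch
            buffer = 0.0
            num_updates += 1
    return theta, num_updates


# Four task gradients with meta_batch_size=2 give two updates.
theta, num_updates = accumulate_and_step(1.0, [1.0, 3.0, 2.0, 2.0], meta_batch_size=2, lr=0.5)
print(theta, num_updates)  # -1.0 2
```

In the tutorial code, each "task gradient" corresponds to one pass through the inner loop followed by an upper-level backward pass on the test split.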
Lower-Level
As in the upper level, we define the custom data loading mechanism through
the get_batch method. In addition, we need to initialize the weights of the inner CNN
to those of the outer CNN at the beginning of each inner loop. Betty offers this functionality
via the on_inner_loop_start method.
def reg_loss(parameters, reference_parameters, reg_lambda=0.25):
    # Proximal term that keeps the adapted weights close to the initialization.
    loss = 0
    for p1, p2 in zip(parameters, reference_parameters):
        loss += torch.sum(torch.pow((p1 - p2), 2))
    return reg_lambda * loss


class Lower(ImplicitProblem):
    def training_step(self, batch):
        inputs, labels = batch
        out = self.module(inputs)
        loss = F.cross_entropy(out, labels)
        reg = reg_loss(self.parameters(), self.upper.parameters(), args.reg)
        return loss + reg

    def get_batch(self):
        inputs, labels = self.env.batch["train"]
        return inputs, labels

    def on_inner_loop_start(self):
        # Reset the inner network to the current upper-level initialization.
        self.module.load_state_dict(self.upper.module.state_dict())


# Same architecture as the upper-level module, since its state_dict is loaded above.
child_module = ConvNet(args.ways, args.hidden_size)
child_optimizer = optim.SGD(child_module.parameters(), lr=1e-1)
child_config = Config(type="darts", unroll_steps=args.inner_steps)
lower = Lower(
    name="lower",
    module=child_module,
    optimizer=child_optimizer,
    config=child_config,
)
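To see what the regularized inner loop does, consider a one-dimensional toy problem where the task loss is 0.5 * a * (phi - c)^2. This sketch is our own illustration, not part of the Betty example: the names inner_adapt and closed_form and the constants a and c are hypothetical. Gradient descent starts from the initialization theta (mirroring on_inner_loop_start) and minimizes the task loss plus reg_lambda * (phi - theta)^2, the same form as reg_loss above.

```python
def inner_adapt(theta, a, c, reg_lambda, lr=0.1, steps=200):
    """Gradient descent on 0.5*a*(phi-c)^2 + reg_lambda*(phi-theta)^2, starting at theta."""
    phi = theta  # start from the upper-level initialization
    for _ in range(steps):
        grad = a * (phi - c) + 2.0 * reg_lambda * (phi - theta)
        phi -= lr * grad
    return phi


def closed_form(theta, a, c, reg_lambda):
    """Exact minimizer of the same objective, obtained by setting the gradient to zero."""
    return (a * c + 2.0 * reg_lambda * theta) / (a + 2.0 * reg_lambda)


phi = inner_adapt(theta=0.0, a=1.0, c=2.0, reg_lambda=0.25)
print(phi, closed_form(0.0, 1.0, 2.0, 0.25))  # both close to 1.3333
```

The adapted weight lands between the task optimum c and the initialization theta, and a larger reg_lambda pulls it closer to theta. Because the solution depends on theta only through this stationarity condition, implicit MAML can differentiate through it without unrolling the inner loop.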
Engine
As illustrated in our Tutorial,
the overall execution of multilevel optimization (MLO) is handled by Engine. Since we handle data loading in
the step method of Env, we have to (1) provide MAMLEnv to the Engine and (2)
coordinate the execution order of env.step with that of each problem.step. For the
first part, users can simply pass their custom Env to the Engine class
constructor. The execution order of Env and Problem can be coordinated in
the train_step method of Engine. The code below shows how to do both.
from betty.engine import Engine


class MAMLEngine(Engine):
    def train_step(self):
        # Sample a new task at the start of every inner loop.
        if self.global_step % args.inner_steps == 1 or args.inner_steps == 1:
            self.env.step()
        for leaf in self.leaves:
            leaf.step(global_step=self.global_step)

    def validation(self):
        self.upper.module.train()
        if not hasattr(self, "best_acc"):
            self.best_acc = -1
        test_net = ConvNet(args.ways, args.hidden_size).to(self.device)
        test_optim = optim.SGD(test_net.parameters(), lr=0.1)
        accs = []
        for i in range(500):
            inputs, labels = tasksets.test.sample()
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            split = split_data(inputs, labels, args.shots, args.ways)
            train_inputs, train_labels = split["train"]
            test_inputs, test_labels = split["test"]
            # Start every test task from the learned initialization.
            test_net.load_state_dict(self.upper.module.state_dict())
            for _ in range(args.inner_steps):
                out = test_net(train_inputs)
                loss = F.cross_entropy(out, train_labels)
                test_optim.zero_grad()
                loss.backward()
                test_optim.step()
            out = test_net(test_inputs)
            accs.append((out.argmax(dim=1) == test_labels).detach())
        acc = 100.0 * torch.cat(accs).float().mean().item()
        if acc > self.best_acc:
            self.best_acc = acc
        return {"acc": acc, "best_acc": self.best_acc}


# `parent` and `lower` are the upper- and lower-level problems defined above.
# `engine_config` is defined in the full script.
problems = [parent, lower]
u2l = {parent: [lower]}
l2u = {lower: [parent]}
dependencies = {"u2l": u2l, "l2u": l2u}
engine = MAMLEngine(
    config=engine_config, problems=problems, dependencies=dependencies, env=env
)
engine.run()
Finally, users can also define the validation mechanism via the validation method,
and execute MAML training with engine.run().
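The condition in train_step decides at which global steps a new task is sampled. The following pure-Python sketch lists those steps; the helper task_refresh_steps is hypothetical, and we assume that global_step starts at 1 on the first call to train_step.

```python
def task_refresh_steps(total_steps, inner_steps):
    """Global steps at which env.step() would be called under the train_step condition."""
    return [
        step
        for step in range(1, total_steps + 1)
        if step % inner_steps == 1 or inner_steps == 1
    ]


print(task_refresh_steps(total_steps=7, inner_steps=3))  # [1, 4, 7]
print(task_refresh_steps(total_steps=3, inner_steps=1))  # [1, 2, 3]
```

The explicit inner_steps == 1 check is needed because step % 1 is always 0, so without it no task would ever be sampled when the inner loop has a single step.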
In summary, this tutorial described how to handle a unified (entangled)
data loading mechanism for multiple Problem classes via Env. Such use of Env
can also be useful for implementing reinforcement learning algorithms.