Quick Start¶
Throughout our tutorials, we will use Data Reweighting for Long-Tailed Image Classification as our running example. The basic context is that we aim to mitigate a class imbalance (or long-tailed distribution) problem by assigning higher weights to data from rare classes and lower weights to data from common classes. In particular, Meta-Weight-Net (MWN) formulates data reweighting as the following bilevel optimization problem:
\[
\min_{w} \; \mathcal{L}_{reweight}\big(\theta^{*}(w)\big)
\qquad \text{s.t.} \qquad
\theta^{*}(w) = \operatorname*{arg\,min}_{\theta} \; \sum_{i} \mathcal{R}\big(L_{class}^{i}(\theta);\, w\big)\, L_{class}^{i}(\theta),
\]
where \(\theta\) denotes the classifier network’s parameters, \(L_{class}^i\) is the classification loss (cross-entropy) for the \(i\)-th training sample, \(\mathcal{L}_{reweight}\) is the loss for the reweighting level (cross-entropy on a balanced validation set), and \(w\) denotes the parameters of MWN \(\mathcal{R}\), which reweights each training sample given its training loss \(L^i_{class}\).
Now that we have a problem formulation, we need to (1) define the problem at each level with the Problem class, and (2) define the dependencies between problems with the Engine class.
Basic setup¶
Before diving into MLO, we do some basic setup such as importing dependencies and constructing the imbalanced (or long-tailed) dataset. Here, we set the data imbalance factor to 100, meaning that the most common class has 100 times more data than the least common class. This part is not directly relevant to MLO, so users can simply copy and paste the following code.
Preparation code
# import dependencies
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
device = "cuda" if torch.cuda.is_available() else "cpu"
# Construct imbalanced (or long-tailed) dataset
def build_dataset(reweight_size=1000, imbalanced_factor=100):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    dataset = MNIST(root="./data", train=True, download=True, transform=transform)
    num_classes = len(dataset.classes)
    num_meta = int(reweight_size / num_classes)

    index_to_meta = []
    index_to_train = []

    imbalanced_num_list = []
    sample_num = int((len(dataset.targets) - reweight_size) / num_classes)
    for class_index in range(num_classes):
        imbalanced_num = sample_num / (imbalanced_factor ** (class_index / (num_classes - 1)))
        imbalanced_num_list.append(int(imbalanced_num))
    np.random.shuffle(imbalanced_num_list)

    for class_index in range(num_classes):
        index_to_class = [
            index for index, label in enumerate(dataset.targets) if label == class_index
        ]
        np.random.shuffle(index_to_class)
        index_to_meta.extend(index_to_class[:num_meta])
        index_to_class_for_train = index_to_class[num_meta:]
        index_to_class_for_train = index_to_class_for_train[: imbalanced_num_list[class_index]]
        index_to_train.extend(index_to_class_for_train)

    reweight_dataset = copy.deepcopy(dataset)
    dataset.data = dataset.data[index_to_train]
    dataset.targets = list(np.array(dataset.targets)[index_to_train])
    reweight_dataset.data = reweight_dataset.data[index_to_meta]
    reweight_dataset.targets = list(np.array(reweight_dataset.targets)[index_to_meta])

    return dataset, reweight_dataset
classifier_dataset, reweight_dataset = build_dataset(imbalanced_factor=100)
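To see the imbalance this produces, you can optionally count how many training samples remain per class. This quick check assumes the code above has already been run:
from collections import Counter

# per-class sample counts; the largest count should be roughly 100x the smallest
# (up to rounding and the shuffled class-to-count assignment)
print(sorted(Counter(int(t) for t in classifier_dataset.targets).items()))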
Problem¶
In this example, we have an MLO program consisting of two problem levels: upper and lower. We refer to these two problems as Reweight and Classifier, respectively, and create a Problem class for each of them. As introduced in the Software Design chapter, each problem is defined by (1) a module, (2) an optimizer, (3) a data loader, (4) a loss function, (5) a training configuration, and (6) other optional components (e.g. a learning rate scheduler). Everything except (4) the loss function can be provided through the class constructor, while (4) is provided via the training_step method. In the following subsections, we provide a step-by-step guide for implementing each of these components in the Problem class, for both the lower-level and upper-level problems.
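To make the mapping concrete, here is a bare-bones sketch of how these six components plug into a Problem subclass. The my_* names are purely illustrative; the real definitions for our example follow in the next subsections.
import torch.nn.functional as F

from betty.problems import ImplicitProblem

class MyProblem(ImplicitProblem):
    def training_step(self, batch):
        # (4) loss function: compute and return a scalar loss for this batch
        x, y = batch
        return F.cross_entropy(self.forward(x), y)

# (1)-(3), (5), (6) are passed through the constructor:
# my_problem = MyProblem(
#     name="my_problem",            # used by other problems to access this one
#     module=my_module,             # (1) torch.nn.Module
#     optimizer=my_optimizer,       # (2) torch.optim.Optimizer
#     train_data_loader=my_loader,  # (3) data loader
#     config=Config(),              # (5) training configuration
#     scheduler=my_scheduler,       # (6) optional, e.g. learning rate scheduler
# )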
Lower-level Problem (Classifier)¶
In our data reweighting example, the lower-level problem corresponds to the long-tailed MNIST image classification task. The data loader code is adapted from here. We define the module, optimizer, data loader, loss function, and training configuration as follows.
Module, Optimizer, Data Loader, (optional) Scheduler
# Module
classifier_module = nn.Sequential(
    nn.Flatten(), nn.Linear(784, 200), nn.ReLU(), nn.Linear(200, 10)
)
# Optimizer
classifier_optimizer = optim.SGD(classifier_module.parameters(), lr=0.1, momentum=0.9)
# Data Loader
classifier_dataloader = DataLoader(
    classifier_dataset, batch_size=100, shuffle=True, pin_memory=True
)
# LR Scheduler
classifier_scheduler = optim.lr_scheduler.MultiStepLR(
    classifier_optimizer, milestones=[1500, 2500], gamma=0.1
)
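As an optional sanity check (not part of the tutorial pipeline), the module expects MNIST-shaped images and produces 10 class logits:
dummy_images = torch.randn(4, 1, 28, 28)      # fake batch of 4 MNIST images
print(classifier_module(dummy_images).shape)  # expected: torch.Size([4, 10])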
Loss Function
Unlike other components, the loss function should be directly implemented in the Problem class via the training_step method.
from betty.problems import ImplicitProblem
class Classifier(ImplicitProblem):
    def training_step(self, batch):
        inputs, labels = batch
        outputs = self.forward(inputs)
        loss_vector = F.cross_entropy(outputs, labels.long(), reduction="none")

        # Reweight
        loss_vector_reshape = torch.reshape(loss_vector, (-1, 1))
        weight = self.reweight(loss_vector_reshape.detach())
        loss = torch.mean(weight * loss_vector_reshape)

        return loss
In this example, we aim to overcome a long-tailed distribution by reweighting each data sample (e.g. increasing weights for data from rare classes while decreasing weights for data from common classes). This is achieved by interacting with the upper-level Reweight problem. The Engine class gives the Classifier problem access to the Reweight problem via its name (i.e. in the line weight = self.reweight(loss_vector_reshape.detach())). Thus, when writing the loss function, users should be aware of the names of the other problems with which the current problem interacts.
Training Configuration
The Reweight parameters affect the optimization of the Classifier parameters, which in turn affects the Reweight loss function. Thus, the best-response Jacobian of the Classifier optimization process must be calculated. In this tutorial, we adopt implicit differentiation with finite difference (a.k.a. DARTS) as the best-response Jacobian calculation algorithm. Furthermore, since Classifier is the lower-level problem, we need to specify how many steps to unroll before updating the upper-level Reweight problem. We choose the simplest one-step unrolling for our example. All of this can easily be specified with Config.
from betty.configs import Config
classifier_config = Config(type='darts', unroll_steps=1)
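For reference, with one-step unrolling the DARTS-style hypergradient of the upper-level loss is, roughly,
\[
\nabla_{w}\,\mathcal{L}_{reweight}\big(\theta'\big)
\;\approx\;
-\,\xi\,\frac{\nabla_{w}\mathcal{L}_{train}(\theta^{+}, w) \;-\; \nabla_{w}\mathcal{L}_{train}(\theta^{-}, w)}{2\epsilon},
\qquad
\theta^{\pm} \;=\; \theta \,\pm\, \epsilon\,\nabla_{\theta'}\mathcal{L}_{reweight}(\theta'),
\]
where \(\theta' = \theta - \xi\,\nabla_{\theta}\mathcal{L}_{train}(\theta, w)\) is the one-step unrolled classifier, \(\mathcal{L}_{train}\) is the reweighted training loss from Classifier’s training_step, \(\xi\) is the unrolling step size, and \(\epsilon\) is a small finite-difference perturbation. (The direct term \(\nabla_{w}\mathcal{L}_{reweight}\) vanishes here because the validation loss depends on \(w\) only through \(\theta\).) This is the standard DARTS approximation, sketched only for intuition; Betty computes it for you once type='darts' is specified.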
Problem Instantiation
Now that we have all the components to define the Classifier problem, we can instantiate the Problem class. We use ‘classifier’ as the name for this problem.
classifier = Classifier(
    name='classifier',
    module=classifier_module,
    optimizer=classifier_optimizer,
    scheduler=classifier_scheduler,
    train_data_loader=classifier_dataloader,
    config=classifier_config,
    device=device
)
Upper-level Problem (Reweight)¶
While the lower-level problem is a classification problem, the upper-level problem is a reweighting problem. Specifically, Meta-Weight-Net (MWN) proposes to reweight each data sample using an MLP with a single hidden layer, which takes a loss value as an input and outputs an importance weight.
Module, Optimizer, Data Loader
# Module
reweight_module = nn.Sequential(
    nn.Linear(1, 100), nn.ReLU(), nn.Linear(100, 1), nn.Sigmoid()
)
# Optimizer
reweight_optimizer = optim.Adam(reweight_module.parameters(), lr=1e-5)
# Data Loader
reweight_dataloader = DataLoader(
    reweight_dataset, batch_size=100, shuffle=True, pin_memory=True
)
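Because the MWN module above ends with a Sigmoid, it maps any per-sample loss value to a weight in (0, 1). A quick, optional sanity check:
dummy_losses = torch.rand(5, 1) * 3.0   # fake per-sample losses, shape [5, 1]
with torch.no_grad():
    dummy_weights = reweight_module(dummy_losses)
print(dummy_weights.squeeze())          # five weights, each in (0, 1)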
Loss Function
The upper-level reweight problem aims to optimize the loss value on the balanced validation dataset (i.e. reweight_dataloader) with respect to the optimal parameters of the Classifier problem. As before, users can access the lower-level classifier problem via its name (i.e. self.classifier).
class Reweight(ImplicitProblem):
    def training_step(self, batch):
        inputs, labels = batch
        outputs = self.classifier(inputs)
        loss = F.cross_entropy(outputs, labels.long())
        print('Reweight Loss:', loss.item())

        return loss
Training Configuration
Since the Reweight problem is the uppermost problem, there is no need to calculate a best-response Jacobian. Thus, we don’t need to specify any training configuration for the Reweight problem.
reweight_config = Config()
Problem Instantiation
We can now instantiate the Problem class for the Reweight problem. We use ‘reweight’ as the name for this problem.
reweight = Reweight(
    name='reweight',
    module=reweight_module,
    optimizer=reweight_optimizer,
    train_data_loader=reweight_dataloader,
    config=reweight_config,
    device=device
)
Engine¶
Recalling the Software Design chapter, the Engine class handles problem dependencies and the execution of multilevel optimization. Let’s again take a step-by-step dive into each of these components.
Problem Dependencies
Dependencies between problems are split into two categories — upper-to-lower (u2l) and lower-to-upper (l2u) — both of which are defined using a Python dictionary. In our example, reweight is the upper-level problem and classifier is the lower-level problem.
u2l = {reweight: [classifier]}
l2u = {classifier: [reweight]}
dependencies = {'l2u': l2u, 'u2l': u2l}
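The same dictionary pattern extends to deeper hierarchies. For instance, if there were (hypothetically) a third problem hyper sitting above reweight, the dependencies would look like the following — purely illustrative, since our example has only two levels:
# hypothetical three-level hierarchy: hyper > reweight > classifier
u2l = {hyper: [reweight], reweight: [classifier]}
l2u = {classifier: [reweight], reweight: [hyper]}
dependencies = {'l2u': l2u, 'u2l': u2l}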
Engine Instantiation
To instantiate the Engine class, we need to provide all involved problems as well as the Engine configuration. Since we already defined all problems, we can simply combine them in a Python list. In addition, we perform our multilevel optimization for 3,000 iterations, which can be specified in EngineConfig.
from betty.configs import EngineConfig
from betty.engine import Engine
problems = [reweight, classifier]
engine_config = EngineConfig(train_iters=3000)
engine = Engine(config=engine_config, problems=problems, dependencies=dependencies)
Execution of Multilevel Optimization
Finally, multilevel optimization can be executed by running engine.run(), which calls the step method of the lowermost problem (i.e. Classifier); each call corresponds to a single step of gradient descent. After unrolling gradient descent for the lowermost problem for a pre-determined number of steps (the unroll_steps attribute in classifier_config), the step method of Classifier will automatically call the step method of Reweight according to the provided dependencies.
engine.run()
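As a rough mental model (pseudocode only, not Betty’s actual internals), one training iteration proceeds as follows:
# repeat until train_iters is reached:
#     classifier takes a gradient step on its reweighted training loss
#     every unroll_steps classifier steps (here, every step):
#         reweight takes a gradient step on the balanced validation loss,
#         differentiating through the classifier update via the DARTS
#         best-response Jacobian approximation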
Results¶
Once the training is done, we perform the validation procedure manually as below:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
valid_dataset = MNIST(root="./data", train=False, transform=transform)
valid_dataloader = DataLoader(valid_dataset, batch_size=100, pin_memory=True)
correct = 0
total = 0
for x, target in valid_dataloader:
    x, target = x.to(device), target.to(device)
    out = classifier(x)
    correct += (out.argmax(dim=1) == target).sum().item()
    total += x.size(0)
acc = correct / total * 100
print("Imbalanced Classification Accuracy:", acc)
The full code for the above example can be found at this link. If everything runs correctly, you should see output like the following on your screen:
[2022-06-20 13:01:48] [INFO] Initializing Multilevel Optimization...
[2022-06-20 13:01:51] [INFO] *** Problem Information ***
[2022-06-20 13:01:51] [INFO] Name: reweight
[2022-06-20 13:01:51] [INFO] Uppers: []
[2022-06-20 13:01:51] [INFO] Lowers: ['classifier']
[2022-06-20 13:01:51] [INFO] Paths: [['reweight', 'classifier', 'reweight']]
[2022-06-20 13:01:51] [INFO] *** Problem Information ***
[2022-06-20 13:01:51] [INFO] Name: classifier
[2022-06-20 13:01:51] [INFO] Uppers: ['reweight']
[2022-06-20 13:01:51] [INFO] Lowers: []
[2022-06-20 13:01:51] [INFO] Paths: []
[2022-06-20 13:01:51] [INFO] Time spent on initialization: 3.124 (s)
Classification Accuracy: 95.41
Finally, we compare our data reweighting result with the no-reweighting baseline in the table below:
| Method      | Test Accuracy |
|-------------|---------------|
| Baseline    | 91.82%        |
| Reweighting | 95.41%        |
The above result shows that long-tailed image classification can clearly benefit from data reweighting!
Happy Multilevel Optimization!