
Quick Start

Throughout our tutorials, we will use Data Reweighting for Long-Tailed Image Classification as our running example. The basic context is that we aim to mitigate a class imbalance problem (or long-tailed distribution problem) by re-assigning higher/lower weights to data from rare/common classes. In particular, Meta-Weight-Net (MWN) formulates data reweighting as bilevel optimization as follows:

\[\begin{split}w^*&=\underset{w}{\mathrm{argmin}}\;\mathcal{L}_{reweight}(\theta^*(w))\qquad\text{(Reweighting)}\\ \text{s.t. }\theta^*(w)&=\underset{\theta}{\mathrm{argmin}}\;\frac{1}{N}\sum_{i=1}^{N}\mathcal{R}(L^i_{class}(\theta);w)\cdot L^i_{class}(\theta)\qquad\text{(Classification)}\end{split}\]

where \(\theta\) denotes the classifier network’s parameters, \(N\) is the number of training samples, \(L_{class}^i\) is the classification loss (cross-entropy) for the \(i\)-th training sample, and \(\mathcal{L}_{reweight}\) is the cross-entropy loss for the reweighting level, computed on a small class-balanced dataset. Finally, \(w\) denotes the parameters of MWN \(\mathcal{R}\), which reweights each training sample given its training loss \(L^i_{class}\).

Now that we have a problem formulation, we need to (1) define the problem at each level with the Problem class, and (2) define the dependencies between problems with the Engine class.


Basic setup

Before diving into multilevel optimization (MLO), we do some basic setup, such as importing dependencies and constructing the imbalanced (or long-tailed) dataset. Here, we set the data imbalance factor to 100. This means the most common class has 100 times more training data than the least common class. This part is not directly relevant to MLO, so users can simply copy and paste the following code.

Preparation code
# import dependencies
import copy
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST


device = "cuda" if torch.cuda.is_available() else "cpu"

# Construct imbalanced (or long-tailed) dataset
def build_dataset(reweight_size=1000, imbalanced_factor=100):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    dataset = MNIST(root="./data", train=True, download=True, transform=transform)

    num_classes = len(dataset.classes)
    num_meta = int(reweight_size / num_classes)

    index_to_meta = []
    index_to_train = []

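    # Per-class training counts decay geometrically from sample_num (first entry)
    # down to sample_num / imbalanced_factor (last entry)
    imbalanced_num_list = []
    sample_num = int((len(dataset.targets) - reweight_size) / num_classes)
    for class_index in range(num_classes):
        imbalanced_num = sample_num / (imbalanced_factor ** (class_index / (num_classes - 1)))
        imbalanced_num_list.append(int(imbalanced_num))
    # Randomly assign the counts to classes, so any digit can end up as a rare class
    np.random.shuffle(imbalanced_num_list)

    for class_index in range(num_classes):
        index_to_class = [
            index for index, label in enumerate(dataset.targets) if label == class_index
        ]
        np.random.shuffle(index_to_class)
        # The first num_meta samples of each class go to the balanced reweight set
        index_to_meta.extend(index_to_class[:num_meta])
        index_to_class_for_train = index_to_class[num_meta:]

        # Keep at most imbalanced_num_list[class_index] remaining samples for training
        index_to_class_for_train = index_to_class_for_train[: imbalanced_num_list[class_index]]

        index_to_train.extend(index_to_class_for_train)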

    reweight_dataset = copy.deepcopy(dataset)
    dataset.data = dataset.data[index_to_train]
    dataset.targets = list(np.array(dataset.targets)[index_to_train])
    reweight_dataset.data = reweight_dataset.data[index_to_meta]
    reweight_dataset.targets = list(np.array(reweight_dataset.targets)[index_to_meta])

    return dataset, reweight_dataset

classifier_dataset, reweight_dataset = build_dataset(imbalanced_factor=100)
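As a quick sanity check of the imbalance factor, the per-class training counts computed in build_dataset decay geometrically from sample_num down to sample_num / imbalanced_factor. The sketch below reproduces only that arithmetic. It applies before the counts are shuffled across classes and before truncation by the number of images each class actually has. It assumes the default reweight_size=1000 and the 60,000-image MNIST training set; the helper name class_counts is hypothetical.

```python
def class_counts(num_train=60000, reweight_size=1000, num_classes=10, imbalanced_factor=100):
    # Hypothetical helper mirroring the count formula inside build_dataset
    sample_num = int((num_train - reweight_size) / num_classes)
    return [
        int(sample_num / (imbalanced_factor ** (c / (num_classes - 1))))
        for c in range(num_classes)
    ]

counts = class_counts()
print(counts[0], counts[-1], counts[0] / counts[-1])  # 5900 59 100.0
```

A class may have fewer available images than its target count once num_meta images are reserved. In that case it simply keeps all of them, because the list slice in build_dataset truncates silently. As a result, the realized ratio can be slightly below the nominal factor.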

Problem

In this example, we have an MLO program consisting of two problem levels: upper and lower. We refer to the upper-level problem as Reweight and the lower-level problem as Classifier, and create a Problem class for each of them. As introduced in the Software Design chapter, each problem is defined by six components:

(1) a module,
(2) an optimizer,
(3) a data loader,
(4) a loss function,
(5) a training configuration, and
(6) other optional components (e.g. a learning rate scheduler).

Everything except (4) the loss function can be provided through the class constructor, while (4) is provided via the training_step method. In the following subsections, we provide a step-by-step guide to implementing each of these components in the Problem class, for both the lower-level and upper-level problems.

Lower-level Problem (Classifier)

In our data reweighting example, the lower-level problem corresponds to the long-tailed MNIST image classification task. The data loader code is adapted from here. We define the module, optimizer, data loader, loss function, and training configuration as follows.

Module, Optimizer, Data Loader, (optional) Scheduler

# Module
classifier_module = nn.Sequential(
    nn.Flatten(), nn.Linear(784, 200), nn.ReLU(), nn.Linear(200, 10)
)

# Optimizer
classifier_optimizer = optim.SGD(classifier_module.parameters(), lr=0.1, momentum=0.9)

# Data Loader
classifier_dataloader = DataLoader(
    classifier_dataset, batch_size=100, shuffle=True, pin_memory=True
)

# LR Scheduler
classifier_scheduler = optim.lr_scheduler.MultiStepLR(
    classifier_optimizer, milestones=[1500, 2500], gamma=0.1
)
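MultiStepLR multiplies the learning rate by gamma each time the number of scheduler.step() calls reaches a milestone. The tutorial does not say how often Betty steps the scheduler. If it is stepped once per Classifier update, the learning rate is 0.1 for the first 1,500 iterations, 0.01 until iteration 2,500, and 0.001 afterwards. The sketch below checks this schedule in plain PyTorch with a dummy parameter.

```python
import torch
import torch.optim as optim

# Dummy parameter so the optimizer has something to manage
param = torch.nn.Parameter(torch.zeros(1))
opt = optim.SGD([param], lr=0.1, momentum=0.9)
sched = optim.lr_scheduler.MultiStepLR(opt, milestones=[1500, 2500], gamma=0.1)

lrs = []
for step in range(3000):
    lrs.append(opt.param_groups[0]["lr"])  # learning rate used at this step
    opt.step()    # no gradients are set, so the parameter stays unchanged
    sched.step()  # advance the schedule by one iteration

print(lrs[0], lrs[1500], lrs[2500])
```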

Loss Function

Unlike other components, the loss function should be directly implemented in the Problem class via the training_step method.

from betty.problems import ImplicitProblem

class Classifier(ImplicitProblem):
    def training_step(self, batch):
        inputs, labels = batch
        outputs = self.forward(inputs)
        loss_vector = F.cross_entropy(outputs, labels.long(), reduction="none")

        # Reweight
        loss_vector_reshape = torch.reshape(loss_vector, (-1, 1))
        weight = self.reweight(loss_vector_reshape.detach())
        loss = torch.mean(weight * loss_vector_reshape)

        return loss
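The reweighted loss in training_step can be sanity-checked outside Betty. With reduction="none", F.cross_entropy returns one loss per sample. Multiplying these losses by a (batch, 1) column of weights and averaging gives the weighted objective of the lower-level equation. The sketch below uses random logits and labels as stand-ins for a real batch. It confirms that uniform weights of 1 recover the ordinary mean cross-entropy.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Random stand-ins for classifier outputs and labels (batch of 8, 10 classes)
outputs = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))

loss_vector = F.cross_entropy(outputs, labels, reduction="none")  # shape (8,)
loss_vector_reshape = torch.reshape(loss_vector, (-1, 1))          # shape (8, 1)

# Uniform weights: the weighted mean should equal the standard mean loss
weight = torch.ones_like(loss_vector_reshape)
weighted_loss = torch.mean(weight * loss_vector_reshape)
print(torch.allclose(weighted_loss, F.cross_entropy(outputs, labels)))  # True
```

The .detach() in the original code feeds the loss value into the weighting network as a plain input. As a result, the classifier's gradient flows only through the loss_vector_reshape factor, not through the weight's dependence on the loss.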

In this example, we aim to overcome a long-tailed distribution by reweighting each data sample (e.g. increasing weights for data from rare classes while decreasing weights for data from common classes). This is achieved by interacting with the upper-level Reweight problem. The Engine class gives the Classifier problem access to the Reweight problem via its name (i.e. in the line weight = self.reweight(loss_vector_reshape.detach())). Thus, when writing the loss function, users should know the names of the other problems that the current problem interacts with.
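Conceptually, name-based access means the Classifier can use self.reweight as if it were the Reweight problem itself. The toy sketch below illustrates that idea with Python's __getattr__ and a plain dictionary of named peers. ToyProblem and connect are hypothetical and are not Betty's actual implementation.

```python
class ToyProblem:
    """Hypothetical stand-in for a Problem that can look up peers by name."""

    def __init__(self, name, fn):
        self.name = name
        self.fn = fn
        self._peers = {}  # filled in by the (hypothetical) connect() below

    def __call__(self, x):
        return self.fn(x)

    def __getattr__(self, attr):
        # Only called when normal attribute lookup fails: fall back to peer names
        peers = self.__dict__.get("_peers", {})
        if attr in peers:
            return peers[attr]
        raise AttributeError(attr)


def connect(problems):
    # Give every problem access to every other problem under its name
    for p in problems:
        p._peers = {q.name: q for q in problems if q is not p}


reweight = ToyProblem("reweight", lambda loss: 1.0 / (1.0 + loss))
classifier = ToyProblem("classifier", lambda x: x)
connect([reweight, classifier])
print(classifier.reweight(1.0))  # 0.5
```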

Training Configuration

The Reweight parameters \(w\) affect the optimization of the Classifier parameters \(\theta\), which in turn affect the Reweight loss. Thus, the best-response Jacobian \(\partial\theta^*(w)/\partial w\) of the Classifier problem must be calculated. In this tutorial, we adopt implicit differentiation with finite difference (a.k.a. DARTS) as the best-response Jacobian calculation algorithm. Furthermore, since Classifier is the lower-level problem, we need to specify how many steps to unroll before updating the upper-level Reweight problem. We choose the simplest option, one-step unrolling. All of these can be easily specified with Config.
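For reference, the standard DARTS approximation can be written with the symbols of the bilevel formulation above. This sketches the general technique; the exact constants Betty uses internally are not stated in this tutorial. The notation is:

- \(\mathcal{L}_{class}(\theta;w)=\frac{1}{N}\sum_{i=1}^{N}\mathcal{R}(L^i_{class}(\theta);w)\cdot L^i_{class}(\theta)\) is the lower-level objective.
- \(\xi\) is the lower-level learning rate.
- \(\epsilon\) is a small finite-difference step.

The best response is approximated by one unrolled gradient step. The resulting second-order term is then replaced by a central difference, so no explicit Hessian-vector product is needed. Because \(\mathcal{L}_{reweight}\) does not depend on \(w\) directly, no first-order term appears.

```latex
\begin{aligned}
\theta' &= \theta - \xi\,\nabla_\theta \mathcal{L}_{class}(\theta; w),\\
\nabla_w \mathcal{L}_{reweight}(\theta^*(w)) &\approx \nabla_w \mathcal{L}_{reweight}(\theta'(w))
  = -\xi\,\nabla^2_{w,\theta}\mathcal{L}_{class}(\theta; w)\,\nabla_{\theta'}\mathcal{L}_{reweight}(\theta')\\
&\approx -\xi\,\frac{\nabla_w \mathcal{L}_{class}(\theta^+; w) - \nabla_w \mathcal{L}_{class}(\theta^-; w)}{2\epsilon},
\qquad \theta^{\pm} = \theta \pm \epsilon\,\nabla_{\theta'}\mathcal{L}_{reweight}(\theta').
\end{aligned}
```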

from betty.configs import Config

classifier_config = Config(type='darts', unroll_steps=1)
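The finite-difference idea behind type='darts' can be checked numerically on a toy problem. The sketch below is a plain-PyTorch illustration, not Betty code. It assumes a hypothetical quadratic per-sample loss and replaces MWN with one free weight per sample, sigmoid(w_i). After one unrolled SGD step, it compares two hypergradients:

- the exact value, obtained by differentiating through the unrolled step, and
- the DARTS-style central difference, with step size 0.01 divided by the gradient norm.

Because the toy loss is quadratic in θ, the two agree up to floating-point error. For real networks, the central difference is only an approximation.

```python
import torch

torch.manual_seed(0)
dt = torch.float64

# Hypothetical toy data: a "training" set and a balanced "reweight" set
x_tr, y_tr = torch.randn(8, 3, dtype=dt), torch.randn(8, dtype=dt)
x_rw, y_rw = torch.randn(4, 3, dtype=dt), torch.randn(4, dtype=dt)

def per_sample_loss(theta, x, y):
    return 0.5 * (x @ theta - y) ** 2

def class_loss(theta, w):
    # Lower level: per-sample weights sigmoid(w_i) stand in for MWN
    return torch.mean(torch.sigmoid(w) * per_sample_loss(theta, x_tr, y_tr))

def reweight_loss(theta):
    # Upper level: unweighted loss on the balanced set
    return torch.mean(per_sample_loss(theta, x_rw, y_rw))

theta = torch.randn(3, dtype=dt, requires_grad=True)
w = torch.zeros(8, dtype=dt, requires_grad=True)
xi = 0.1  # lower-level learning rate

# One unrolled SGD step, keeping the graph so we can differentiate w.r.t. w
g = torch.autograd.grad(class_loss(theta, w), theta, create_graph=True)[0]
theta_prime = theta - xi * g

# Exact hypergradient through the unrolled step
exact = torch.autograd.grad(reweight_loss(theta_prime), w)[0]

# DARTS-style central finite difference
tp = theta_prime.detach().requires_grad_(True)
v = torch.autograd.grad(reweight_loss(tp), tp)[0]  # gradient w.r.t. theta'
eps = 0.01 / v.norm()                              # step size scaled by the gradient norm
with torch.no_grad():
    theta_plus, theta_minus = theta + eps * v, theta - eps * v
g_plus = torch.autograd.grad(class_loss(theta_plus, w), w)[0]
g_minus = torch.autograd.grad(class_loss(theta_minus, w), w)[0]
approx = -xi * (g_plus - g_minus) / (2 * eps)

print(torch.allclose(exact, approx))  # True
```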

Problem Instantiation

Now that we have all the components to define the Classifier problem, we can instantiate the Problem class. We use ‘classifier’ as the name for this problem.

classifier = Classifier(
    name='classifier',
    module=classifier_module,
    optimizer=classifier_optimizer,
    scheduler=classifier_scheduler,
    train_data_loader=classifier_dataloader,
    config=classifier_config,
    device=device
)

Upper-level Problem (Reweight)

While the lower-level problem is a classification problem, the upper-level problem is a reweighting problem. Specifically, Meta-Weight-Net (MWN) proposes to reweight each data sample using an MLP with a single hidden layer, which takes a loss value as an input and outputs an importance weight.

Module, Optimizer, Data Loader

# Module
reweight_module = nn.Sequential(
    nn.Linear(1, 100), nn.ReLU(), nn.Linear(100, 1), nn.Sigmoid()
)
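To see what the MWN module does, feed it a column of per-sample losses, shaped (batch, 1) as in Classifier.training_step. The sketch below rebuilds the same architecture in plain PyTorch, with random initialization and hypothetical loss values. The final Sigmoid keeps every importance weight strictly between 0 and 1. The weights also remain differentiable with respect to the MWN parameters \(w\).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Same architecture as reweight_module: 1 -> 100 -> 1 MLP with a sigmoid output
mwn = nn.Sequential(nn.Linear(1, 100), nn.ReLU(), nn.Linear(100, 1), nn.Sigmoid())

# Hypothetical per-sample losses for a batch of 5, shaped (batch, 1)
losses = torch.tensor([[0.1], [0.5], [1.0], [2.0], [4.0]])
weights = mwn(losses)
print(weights.shape)  # torch.Size([5, 1])

# Gradients of the weights reach the MWN parameters (the "w" in the formulation)
weights.sum().backward()
print(mwn[0].weight.grad is not None)  # True
```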

# Optimizer
reweight_optimizer = optim.Adam(reweight_module.parameters(), lr=1e-5)

# Data Loader
reweight_dataloader = DataLoader(
    reweight_dataset, batch_size=100, shuffle=True, pin_memory=True
)

Loss Function

The upper-level Reweight problem aims to minimize the loss on the balanced validation dataset (i.e. reweight_dataloader), evaluated at the optimal parameters of the Classifier problem. As before, users can access the lower-level Classifier problem via its name (i.e. self.classifier).

class Reweight(ImplicitProblem):
    def training_step(self, batch):
        inputs, labels = batch
        outputs = self.classifier(inputs)
        loss = F.cross_entropy(outputs, labels.long())
        print('Reweight Loss:', loss.item())

        return loss

Training Configuration

Since Reweight is the uppermost problem, there is no need to calculate a best-response Jacobian for it. Thus, we don’t need to specify any training configuration for the Reweight problem, and the default Config() suffices.

reweight_config = Config()

Problem Instantiation

We can now instantiate the Problem class for the Reweight problem. We use ‘reweight’ as the name for this problem.

reweight = Reweight(
    name='reweight',
    module=reweight_module,
    optimizer=reweight_optimizer,
    train_data_loader=reweight_dataloader,
    config=reweight_config,
    device=device
)

Engine

As described in the Software Design chapter, the Engine class handles problem dependencies and the execution of multilevel optimization. Let’s again take a step-by-step dive into each of these components.

Problem Dependencies

Dependencies between problems are split into two categories, upper-to-lower (u2l) and lower-to-upper (l2u), both of which are defined using a Python dictionary. In our example, reweight is the upper-level problem and classifier is the lower-level problem.

u2l = {reweight: [classifier]}
l2u = {classifier: [reweight]}
dependencies = {'l2u': l2u, 'u2l': u2l}
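In this two-problem example, l2u is simply u2l with its edges reversed. The plain-Python sketch below shows that inversion, using problem names as strings instead of the Problem objects. It only illustrates this example. In general, u2l and l2u are specified separately and need not mirror each other in every program.

```python
from collections import defaultdict

# Names stand in for the Problem objects used above
u2l = {"reweight": ["classifier"]}

def invert(deps):
    # Reverse every edge: {upper: [lowers]} -> {lower: [uppers]}
    inverted = defaultdict(list)
    for src, dsts in deps.items():
        for dst in dsts:
            inverted[dst].append(src)
    return dict(inverted)

l2u = invert(u2l)
print(l2u)  # {'classifier': ['reweight']}
```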

Engine Instantiation

To instantiate the Engine class, we need to provide all involved problems as well as the Engine configuration. Since we have already defined all problems, we can simply combine them in a Python list. In addition, we run multilevel optimization for 3,000 iterations, which is specified in EngineConfig.

from betty.configs import EngineConfig
from betty.engine import Engine

problems = [reweight, classifier]
engine_config = EngineConfig(train_iters=3000)
engine = Engine(config=engine_config, problems=problems, dependencies=dependencies)

Execution of Multilevel Optimization

Finally, multilevel optimization can be executed by running engine.run(). This calls the step method of the lowermost problem (i.e. Classifier), and each call corresponds to a single step of gradient descent. The lowermost problem unrolls gradient descent for a predetermined number of steps (the unroll_steps attribute in classifier_config). After that, the step method of Classifier automatically calls the step method of Reweight according to the provided dependencies.
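The control flow above can be summarized with a simplified, hypothetical scheduler. The lower-level problem takes one gradient step per iteration. After every unroll_steps lower-level steps, the upper-level problem takes one step. This toy model is for intuition only; Betty's actual step logic, and exactly how it counts train_iters, may differ.

```python
def simulate(train_iters, unroll_steps):
    # Hypothetical toy schedule, not Betty's implementation
    log = []
    for it in range(1, train_iters + 1):
        log.append("classifier")      # one lower-level gradient step
        if it % unroll_steps == 0:
            log.append("reweight")    # upper-level step after unrolling
    return log

print(simulate(4, 1))
# ['classifier', 'reweight', 'classifier', 'reweight', 'classifier', 'reweight', 'classifier', 'reweight']
print(simulate(4, 2))
# ['classifier', 'classifier', 'reweight', 'classifier', 'classifier', 'reweight']
```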

engine.run()

Results

Once training is done, we evaluate the classifier on the MNIST test set manually as below:

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
valid_dataset = MNIST(root="./data", train=False, download=True, transform=transform)
valid_dataloader = DataLoader(valid_dataset, batch_size=100, pin_memory=True)

correct = 0
total = 0
with torch.no_grad():  # gradients are not needed for evaluation
    for x, target in valid_dataloader:
        x, target = x.to(device), target.to(device)
        out = classifier(x)  # forward pass through the trained classifier
        correct += (out.argmax(dim=1) == target).sum().item()
        total += x.size(0)
acc = correct / total * 100
print("Classification Accuracy:", acc)

The full code of the above example can be found in this link. If everything runs correctly, you should see something like below on your screen:

[2022-06-20 13:01:48] [INFO] Initializing Multilevel Optimization...

[2022-06-20 13:01:51] [INFO] *** Problem Information ***
[2022-06-20 13:01:51] [INFO] Name: reweight
[2022-06-20 13:01:51] [INFO] Uppers: []
[2022-06-20 13:01:51] [INFO] Lowers: ['classifier']
[2022-06-20 13:01:51] [INFO] Paths: [['reweight', 'classifier', 'reweight']]

[2022-06-20 13:01:51] [INFO] *** Problem Information ***
[2022-06-20 13:01:51] [INFO] Name: classifier
[2022-06-20 13:01:51] [INFO] Uppers: ['reweight']
[2022-06-20 13:01:51] [INFO] Lowers: []
[2022-06-20 13:01:51] [INFO] Paths: []

[2022-06-20 13:01:51] [INFO] Time spent on initialization: 3.124 (s)

Classification Accuracy: 95.41

Finally, we compare our data reweighting result with the baseline without reweighting in the table below:

Method         Test Accuracy
Baseline       91.82%
Reweighting    95.41%

The above result shows that, in this example, data reweighting improves the test accuracy of long-tailed MNIST classification from 91.82% to 95.41%!

Happy Multilevel Optimization!