
Data Reweighting

Introduction

Here we re-implement the data reweighting algorithm from Meta-Weight-Net: Learning an Explicit Mapping For Sample Weighting, which is a two-level MLO (multilevel optimization) program. The first level, or lower level, is the classification problem, and the second level, or upper level, is the meta learning problem. These two levels are followed by a validation stage at the end.

Classification Problem: Here we train the weights \(\textbf{w}\) of the classifier by minimizing the loss on the training data set, where each sample loss is scaled by a meta weight. Let the \(i^{th}\) sample loss be \(L_{i}^{train}(\textbf{w}) = l(y_i, y_{predicted})\) and the \(i^{th}\) meta weight be \(\mathcal{V}(L_{i}^{train}(\textbf{w}), \Theta)\); then the objective of this problem is

\[\textbf{w}^{*}(\Theta) = \operatorname*{argmin}_{\textbf{w}} \mathcal{L}^{train}(\textbf{w} ; \Theta), \qquad \mathcal{L}^{train}(\textbf{w} ; \Theta) = \frac{1}{N} \sum_{i=1}^{N}\mathcal{V}(L_{i}^{train}(\textbf{w}), \Theta)\,L_{i}^{train}(\textbf{w})\]

The classifier model is specified to be ResNet32, and we use the SGD algorithm for optimization.

Meta Learning Problem: Here we train the parameters \(\Theta\) of the meta weight net by minimizing the loss calculated on the meta training data set. Let the \(i^{th}\) sample loss on the meta data be \(L_{i}^{meta}(\textbf{w}) = l(y^{(meta)}_i, y^{(meta)}_{predicted})\); then the objective of this problem is

\[\Theta^{*} = \operatorname*{argmin}_{\Theta} \mathcal{L}^{meta}(\textbf{w}^{*}(\Theta)), \qquad \mathcal{L}^{meta}(\textbf{w}^{*}(\Theta)) = \frac{1}{M} \sum_{i=1}^{M}L_{i}^{meta}(\textbf{w}^{*}(\Theta))\]

The meta weight net model is chosen to be an MLP, and we use the Adam algorithm for optimization. For the complete and detailed formulation of the loss functions, see here.
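Putting the two levels together, the overall bilevel program can be written compactly as

\[\Theta^{*} = \operatorname*{argmin}_{\Theta} \frac{1}{M} \sum_{i=1}^{M} L_{i}^{meta}(\textbf{w}^{*}(\Theta)) \quad \text{subject to} \quad \textbf{w}^{*}(\Theta) = \operatorname*{argmin}_{\textbf{w}} \frac{1}{N} \sum_{i=1}^{N}\mathcal{V}(L_{i}^{train}(\textbf{w}), \Theta)\,L_{i}^{train}(\textbf{w})\]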

Note that calculating the loss of the first level requires the forward pass of the second level, and calculating the loss of the second level requires the forward pass of the first level. Hence we define the following dependencies: the first level depends on the second level through a 'u2l' (upper to lower) dependency, and the second level depends on the first level through a 'l2u' (lower to upper) dependency.

Course of Action

In order to implement the data reweighting algorithm, we will go through the following pipeline:

  1. Preparing Data: Prepare the data that will be used for training.

  2. Designing Models: Design models that will be used in the two levels.

  3. Using Betty: Finally, use Betty to implement the two-level MLO program.

Preparing Data

Here we prepare the data that will be used for training the models in the different levels of the algorithm. We will require three different data sets. The first is train_dataloader which will be used in the first level. The second is meta_dataloader which will be used in the second level. Finally we will have a test_dataloader which will be used in the validation stage. These data sets can be prepared as given here.
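As a point of reference, a minimal sketch of how such loaders could be built is shown below. It assumes a plain CIFAR-10 setup with torchvision and glosses over the class-imbalance construction that the actual data.py script performs; the batch size and meta set size are placeholder values, not the ones used in the full example.

from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

#a minimal sketch (assumed CIFAR-10 setup); the actual data.py additionally
#constructs an imbalanced training set and a small balanced meta set
transform = transforms.Compose([transforms.ToTensor()])

train_set = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
test_set = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

#hold out a small, clean subset of the training data as the meta set (placeholder size)
meta_size = 1000
train_subset, meta_subset = random_split(train_set, [len(train_set) - meta_size, meta_size])

train_dataloader = DataLoader(train_subset, batch_size=100, shuffle=True)
meta_dataloader = DataLoader(meta_subset, batch_size=100, shuffle=True)
test_dataloader = DataLoader(test_set, batch_size=100, shuffle=False)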

Designing Models

Here we design the models used in the two levels: the first level uses the ResNet32 model and the second level uses the MLP (meta weight net) model. Both of these models can be designed as given here.
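The meta weight net only needs to map a scalar sample loss to a scalar weight in [0, 1]. As a rough illustration, a minimal sketch of such an MLP is given below; the fully connected architecture with a sigmoid output is an assumption, while the hidden_size and num_layers constructor arguments mirror the parameters used later in configure_module. The actual MLP class for this example is defined in model.py.

import torch.nn as nn

#a minimal sketch of the meta weight net (assumed architecture, not the exact model.py code)
#it maps a scalar sample loss to a scalar weight in [0, 1]
class MLP(nn.Module):
    def __init__(self, hidden_size=100, num_layers=1):
        super().__init__()
        layers = []
        in_features = 1
        for _ in range(num_layers):
            layers.append(nn.Linear(in_features, hidden_size))
            layers.append(nn.ReLU())
            in_features = hidden_size
        layers.append(nn.Linear(in_features, 1))
        layers.append(nn.Sigmoid())
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        #x has shape (batch_size, 1): one loss value per sample
        return self.net(x)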

Using Betty

Now we will train our models using the data reweighting algorithm with the help of Betty. We first import the required libraries. The code blocks used below can be found here.

import torch
import torch.nn.functional as F
import torch.optim as optim

from model import *
from data import *
from utils import *

from betty.engine import Engine
from betty.problems import ImplicitProblem
from betty.configs import Config, EngineConfig

Now we simply need to do two things to implement our algorithm:

  1. Define each level’s optimization problem using the Problem class.

  2. Define the hierarchical problem structure using the Engine class.

Defining Problem

Each level problem can be defined with seven components: (1) module, (2) optimizer, (3) data loader, (4) loss function, (5) problem configuration, (6) name, and (7) other optional components (e.g. learning rate scheduler). The loss function (4) can be defined via the training_step method, while all other components can be provided through the class constructor.

First Level: The first level is characterized by the following code. The comments accompanying the code assist understanding.

#all problem classes are supposed to be a subclass of ImplicitProblem
#the Inner problem class specifies the classifier problem
class Inner(ImplicitProblem):

   #this method defines the forward pass of the classifier with x as an input
   def forward(self, x):
       #the module attribute of a problem class contains its model
       return self.module(x)

   #this method defines the loss function of our problem
   #it takes a batch (subset) of (inputs, labels) from the training data set of the problem as input
   def training_step(self, batch):
       inputs, labels = batch

       #we calculate the predicted labels from the forward pass of the classifier
       outputs = self.forward(inputs)

       #we calculate the cross entropy loss of our classifier problem and reshape it as required
       loss_vector = F.cross_entropy(outputs, labels.long(), reduction="none")
       loss_vector_reshape = torch.reshape(loss_vector, (-1, 1))

       #we calculate the weight that is supposed to be imposed on every sample loss
       #we do so by using the forward pass of the second level problem
       #we can access the forward pass of other problems by using the 'name' attribute
       weight = self.outer(loss_vector_reshape.detach())

       #we calculate the final loss as the mean of the product of the weights and individual
       #sample losses
       loss = torch.mean(weight * loss_vector_reshape)

       return loss

   #this method sets the training data of the problem
   def configure_train_data_loader(self):
       return train_dataloader

   #this method sets the module of the problem to the required model
   def configure_module(self):
       return ResNet32(args.dataset == "cifar10" and 10 or 100).to(device=args.device)

   #this method sets the optimizer of the problem
   #we have used the SGD algorithm for optimization here
   def configure_optimizer(self):
       optimizer = optim.SGD(
           self.module.parameters(),
           lr=args.lr,
           momentum=args.momentum,
           dampening=args.dampening,
           weight_decay=args.weight_decay,
           nesterov=args.nesterov,
       )
       return optimizer

   #this method sets the scheduler specifications of the problem (optional)
   def configure_scheduler(self):
       scheduler = optim.lr_scheduler.MultiStepLR(
           self.optimizer, milestones=[5000, 7500, 9000], gamma=0.1
       )
       return scheduler

Second Level: The second level is characterized by the following code. The comments accompanying the code assist understanding.

#all problem classes are supposed to be a subclass of ImplicitProblem
#the Outer problem class specifies the meta learning problem
class Outer(ImplicitProblem):

   #this method defines the forward pass of the meta learning problem with x as an input
   def forward(self, x):
       #the module attribute of a problem class contains its model
       return self.module(x)

   #this method defines the loss function of our problem
   #it takes a batch (subset) of (inputs, labels) from the meta data set of the problem as input
   def training_step(self, batch):
       inputs, labels = batch

       #we calculate the predicted labels using the forward pass of the classifier
       #we do so by using the forward pass of the first level problem
       #we can access the forward pass of other problems by using their 'name' attribute
       outputs = self.inner(inputs)

       #we calculate the cross entropy loss on the meta batch
       loss = F.cross_entropy(outputs, labels.long())

       #we calculate the accuracy of the predictions made
       acc = (outputs.argmax(dim=1) == labels.long()).float().mean().item() * 100

       #we return the loss and the accuracy in the form of a dictionary
       return {"loss": loss, "acc": acc}

   #this method sets the training data of the problem
   def configure_train_data_loader(self):
       return meta_dataloader

   #this method sets the module of the problem to the required model
   def configure_module(self):
       meta_net = MLP(
           hidden_size=args.meta_net_hidden_size, num_layers=args.meta_net_num_layers
       ).to(device=args.device)
       return meta_net

   #this method sets the optimizer of the problem
   #we have used the Adam algorithm for optimization here
   def configure_optimizer(self):
       meta_optimizer = optim.Adam(
           self.module.parameters(), lr=args.meta_lr, weight_decay=args.meta_weight_decay
       )
       return meta_optimizer

Instantiation: Here we instantiate our problem classes, creating their respective objects and thereby calling their constructors.

#we define the configurations of both problems using the Config class
#the configuration of a problem contains important specifications related to the problem
outer_config = Config(type="darts", fp16=args.fp16, log_step=100)
inner_config = Config(type="darts", fp16=args.fp16, unroll_steps=1)

#we instantiate the Inner and Outer problems and set their 'name', 'config',
#'device' attributes
outer = Outer(name="outer", config=outer_config, device=args.device)
inner = Inner(name="inner", config=inner_config, device=args.device)

With this, our problems are characterized and instantiated. Now we move on to setting up our Engine class.

Defining Engine

The Engine class handles the hierarchical dependencies between problems. In MLO, there are two types of dependencies: upper-to-lower 'u2l' and lower-to-upper 'l2u'. Both types of dependencies can be defined with a Python dictionary, where each key is a starting node and the corresponding value is the list of destination nodes.

Since Engine manages the whole MLO program, you can also perform a global validation stage within it. All involved problems of the MLO program can again be accessed with their ‘name’ attribute.

#initialize the best accuracy
best_acc = -1

#when we have to define a validation stage we make a subclass of Engine to do so
#if a validation stage is not required we do not need this class
class ReweightingEngine(Engine):

    #this method defines the validation stage; the decorator disables gradient tracking
    @torch.no_grad()
    def validation(self):

        #initialize the counts of correct predictions and total predictions
        correct = 0
        total = 0
        global best_acc

        #go through the testing data set for validation
        for x, target in test_dataloader:

            #move the inputs and labels to the desired device
            x, target = x.to(args.device), target.to(args.device)

            #calculate the predicted labels without gradient tracking
            with torch.no_grad():
                out = self.inner(x)

            #update correct if the prediction is correct
            correct += (out.argmax(dim=1) == target).sum().item()

            #update total
            total += x.size(0)

        #calculate accuracy
        acc = correct / total * 100

        #update best accuracy if the new accuracy is greater than the previous accuracy
        if best_acc < acc:
            best_acc = acc

        #return accuracy and best accuracy as a dictionary
        return {"acc": acc, "best_acc": best_acc}

#set up the engine configuration using the EngineConfig class
engine_config = EngineConfig(train_iters=10000, valid_step=100, distributed=args.distributed, roll_back=args.rollback)

#specify all the problems in a list
problems = [outer, inner]

#set dependencies as dictionaries
#level 1(inner) accesses level 2(outer)
u2l = {outer: [inner]}

#level 2(outer) accesses level 1(inner)
l2u = {inner: [outer]}

#set up a dictionary listing the dependencies
dependencies = {"l2u": l2u, "u2l": u2l}

#instantiate engine and set the 'config', 'problems', 'dependencies' attributes
engine = ReweightingEngine(config=engine_config, problems=problems, dependencies=dependencies)

#run the engine
engine.run()
print(f"IF {args.imbalanced_factor} || Best Acc.: {best_acc}")

With this, the dependencies are defined, and the .run() method of the Engine class starts the program.

Conclusion

Once we define all optimization problems and the hierarchical dependencies between them with, respectively, the Problem class and the Engine class, all the complicated internal mechanisms of MLO, such as gradient calculation and optimization execution order, are handled by Betty.