
Source code for betty.problems.iterative_problem

# Copyright Sang Keun Choe
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from copy import deepcopy

import torch

try:
    import functorch

    HAS_FUNCTORCH = True
except ImportError:
    HAS_FUNCTORCH = False

from betty.problems import Problem
import betty.optim as optim


# pylint: disable=W0223
class IterativeProblem(Problem):
    """
    ``IterativeProblem`` is subclassed from ``Problem``.
    """

    def __init__(
        self,
        name,
        config,
        module=None,
        optimizer=None,
        scheduler=None,
        train_data_loader=None,
        extra_config=None,
    ):
        super().__init__(
            name,
            config,
            module,
            optimizer,
            scheduler,
            train_data_loader,
            extra_config,
        )
        # functorch installation check
        if not HAS_FUNCTORCH:
            raise ImportError(
                "IterativeProblem requires functorch and PyTorch 1.11. "
                "Run 'pip install functorch'. "
                "The functorch dependency will be deprecated in the future."
            )

        # functional modules
        self.fmodule = None
        self.params = None
        self.buffers = None

        self.params_cache = None
        self.buffers_cache = None
        self.optimizer_state_dict_cache = None

    def initialize(self, engine_config):
        super().initialize(engine_config=engine_config)

        # patch module to be functional so that gradient flows through param update;
        # optimizer & scheduler should accordingly be patched as module gets patched
        self.initialize_optimizer_state()
        self.patch_modules()
        self.patch_optimizer()
        self.patch_scheduler()

    def optimizer_step(self, *args, **kwargs):
        assert (
            not self._fp16
        ), "[!] FP16 training is not supported for IterativeProblem."
        if self.is_implemented("custom_optimizer_step"):
            self.params = self.custom_optimizer_step(*args, **kwargs)
        else:
            self.params = self.optimizer.step(self.params)

    def initialize_optimizer_state(self):
        for param_group in self.optimizer.param_groups:
            for param in param_group["params"]:
                param.grad = torch.zeros_like(param.data)
        self.optimizer.step()

    def patch_modules(self):
        """
        Patch PyTorch's native stateful module into a stateless module so as to
        support a functional forward pass that takes params as its input
        (see the minimal functorch sketch after this listing).
        """
        fmodule, params, buffers = functorch.make_functional_with_buffers(self.module)
        self.fmodule = fmodule
        self.params = params
        self.buffers = buffers

    def patch_optimizer(self):
        """
        Patch PyTorch's native optimizer by replacing all involved in-place
        operations, so that gradients can flow through the parameter update process
        (see the out-of-place update sketch after this listing).
        """
        if self.optimizer is not None:
            self.optimizer = optim.patch_optimizer(self.optimizer, self.module)

    def patch_scheduler(self):
        """
        Patch the scheduler to be compatible with the patched optimizer.
        """
        if self.scheduler is not None:
            self.scheduler = optim.patch_scheduler(self.scheduler, self.optimizer)

    def cache_states(self):
        # TODO: replace deepcopy with state_dict
        self.params_cache = deepcopy(self.params)
        self.buffers_cache = deepcopy(self.buffers)
        if self.optimizer is not None:
            self.optimizer_state_dict_cache = deepcopy(self.optimizer.state)

    def recover_states(self):
        # TODO: change loading mechanism based on state_dict
        self.params, self.params_cache = self.params_cache, None
        self.buffers, self.buffers_cache = self.buffers_cache, None
        if self.optimizer is not None:
            self.optimizer.state = self.optimizer_state_dict_cache
            self.optimizer_state_dict_cache = None

    def parameters(self):
        return self.params

    def trainable_parameters(self):
        return self.params
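
The stateless-module conversion in patch_modules is what later allows gradients to flow through parameter updates. The sketch below is not part of Betty; it uses an illustrative torch.nn.Linear module to show how functorch.make_functional_with_buffers (the functorch release targeting PyTorch 1.11) splits a stateful module into a pure function plus explicit params and buffers, which can then be supplied, or replaced, at call time.

# Minimal functorch sketch; the toy module and tensor shapes are illustrative only.
import torch
import functorch

module = torch.nn.Linear(4, 2)
fmodule, params, buffers = functorch.make_functional_with_buffers(module)

x = torch.randn(8, 4)
out = fmodule(params, buffers, x)  # forward pass with explicitly supplied parameters

# New parameters (e.g. produced by a differentiable update) can be swapped in
# without mutating the original module.
new_params = tuple(p + 0.1 for p in params)
out_new = fmodule(new_params, buffers, x)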
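
The reason patch_optimizer replaces in-place operations is to keep the parameter update itself differentiable. The sketch below is not Betty's optim.patch_optimizer; it is a hand-written, out-of-place SGD step on functorch parameters (with made-up shapes and learning rate) that illustrates the general technique of backpropagating through an update.

# Illustrative out-of-place SGD step; not Betty's patched optimizer.
import torch
import functorch

module = torch.nn.Linear(4, 1)
fmodule, params, buffers = functorch.make_functional_with_buffers(module)
x, y = torch.randn(8, 4), torch.randn(8, 1)

train_loss = torch.nn.functional.mse_loss(fmodule(params, buffers, x), y)
# create_graph=True retains the graph of the gradient computation so that a
# later loss can differentiate through this update.
grads = torch.autograd.grad(train_loss, params, create_graph=True)
# Out-of-place update: new tensors are created instead of mutating params,
# which keeps the update step on the autograd graph.
new_params = tuple(p - 0.1 * g for p, g in zip(params, grads))

val_loss = torch.nn.functional.mse_loss(fmodule(new_params, buffers, x), y)
val_loss.backward()  # gradients reach the original params through the update

In Betty, this out-of-place behavior is supplied by optim.patch_optimizer, so a standard torch.optim optimizer passed to IterativeProblem is patched automatically, as the listing above shows.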