# Source code for aihwkit_lightning.optim.analog_optimizer

# -*- coding: utf-8 -*-

# (C) Copyright 2024 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""Analog-aware inference optimizer."""

from types import new_class
from typing import Any, Dict, Type, Generator, Callable

from torch.optim import Optimizer

from aihwkit_lightning.nn import AnalogLayerBase


class AnalogOptimizer(Optimizer):
    """Generic optimizer that wraps an existing ``Optimizer`` for analog inference.

    This class wraps an existing torch ``Optimizer``. All parameters are
    updated by the wrapped optimizer; in addition, a post-step hook is
    registered so that after every ``step()`` the ``clip_weights()`` method of
    each analog layer returned by ``analog_layers()`` is called.

    The ``AnalogOptimizer`` constructor expects the wrapped optimizer class as
    the first parameter and a callable returning the analog layers as the
    second parameter, followed by any arguments required by the wrapped
    optimizer.

    Note:
        The instances returned are of a *new* type that is a subclass of:

        * the wrapped ``Optimizer`` (allowing access to all their methods and
          attributes).
        * this ``AnalogOptimizer``.

        The generated types are cached in ``SUBCLASSES`` under the key
        ``"AnalogOptimizer" + optimizer_cls.__name__`` and are re-used on
        subsequent instantiations.

    Example:
        The following block illustrates how to create an optimizer that wraps
        standard SGD (``rpu_config`` is assumed to be an already created
        configuration object)::

            >>> from torch.optim import SGD
            >>> from aihwkit_lightning.nn import AnalogLinear
            >>> from aihwkit_lightning.optim import AnalogOptimizer
            >>> model = AnalogLinear(3, 4, rpu_config=rpu_config)
            >>> optimizer = AnalogOptimizer(
            ...     SGD, model.analog_layers, model.parameters(), lr=0.02
            ... )
    """

    SUBCLASSES: Dict[str, Type] = {}
    """Registry of the created subclasses."""

    def __new__(cls, optimizer_cls: Type, *_: Any, **__: Any) -> "AnalogOptimizer":
        """Create an instance of the combined analog / torch optimizer type.

        Args:
            optimizer_cls: the torch optimizer class to be wrapped.
            *_: ignored here; consumed by ``__init__``.
            **__: ignored here; consumed by ``__init__``.

        Returns:
            An uninitialized instance of a type that derives from both ``cls``
            and ``optimizer_cls``.
        """
        subclass_name = f"{cls.__name__}{optimizer_cls.__name__}"

        # Retrieve or create a new subclass, that inherits both from
        # `AnalogOptimizer` and from the specific torch optimizer
        # (`optimizer_cls`).
        if subclass_name not in cls.SUBCLASSES:
            cls.SUBCLASSES[subclass_name] = new_class(subclass_name, (cls, optimizer_cls), {})

        return super().__new__(cls.SUBCLASSES[subclass_name])

    # pylint: disable=unused-argument
    def __init__(self, _: Type, analog_layers: Callable[[], Generator], *args: Any, **kwargs: Any):
        """Initialize the wrapped optimizer and register the clipping hook.

        Args:
            _: the wrapped optimizer class (already consumed by ``__new__``).
            analog_layers: zero-argument callable that returns an iterable of
                the analog layers whose weights are clipped after each step.
                It is called anew after every optimizer step.
            *args: positional arguments forwarded to the wrapped optimizer.
            **kwargs: keyword arguments forwarded to the wrapped optimizer.

        Raises:
            TypeError: if ``analog_layers`` is not callable.
        """
        # Fail early with a clear message: a common mistake is to pass the
        # parameters (or an already created generator) in this position, which
        # would otherwise only surface as an obscure error later on.
        if not callable(analog_layers):
            raise TypeError(
                "analog_layers must be a zero-argument callable returning an iterable of "
                f"analog layers (e.g. `model.analog_layers`), got {type(analog_layers).__name__}"
            )

        super().__init__(*args, **kwargs)

        def hook(*_: Any, **__: Any) -> None:
            """Clip the weights of all analog layers after an optimizer step."""
            for analog_layer in analog_layers():
                analog_layer: AnalogLayerBase  # type: ignore[no-redef]
                analog_layer.clip_weights()

        self.register_step_post_hook(hook)