Source code for aihwkit_lightning.simulator.configs.configs

# -*- coding: utf-8 -*-

# (C) Copyright 2024 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""Configurations for resistive processing units."""

# pylint: disable=too-few-public-methods

from dataclasses import dataclass, field
from aihwkit_lightning.simulator.parameters import (
    IOParameters,
    WeightClipParameter,
    WeightModifierParameter,
    MappingParameter,
    PrePostProcessingParameter,
)


@dataclass
class TorchInferenceRPUConfig:
    """Configuration for an analog tile that is used only for inference.

    Training is done in a *hardware-aware* manner, thus using only the
    non-idealities of the forward pass, while the backward and update
    passes are ideal.

    During inference, statistical models of programming, drift and read
    noise can be used.
    """

    # pylint: disable=too-many-instance-attributes

    forward: IOParameters = field(
        default_factory=IOParameters, metadata={"bindings_include": True}
    )
    """Input-output parameter setting for the forward direction.

    These parameters govern the hardware definitions specifying analog
    MVM non-idealities.

    Note:
        This forward pass is applied equally in training and inference.
        In addition, material effects such as drift and programming
        noise can be enabled during inference by specifying the
        ``noise_model``.
    """

    clip: WeightClipParameter = field(default_factory=WeightClipParameter)
    """Parameter for weight clip.

    If a clipping type is set, the weights are clipped according to the
    type specified.

    Caution:
        The clipping type is set to ``None`` by default; setting
        parameters of the clipping will not be taken into account if the
        clipping type is not specified.
    """

    modifier: WeightModifierParameter = field(
        default_factory=WeightModifierParameter
    )
    """Parameter for weight modifier.

    If a modifier type is set, it is called once per mini-batch in the
    ``post_update_step`` and modifies the weight in forward and backward
    direction for the next mini-batch during training, but updates hidden
    reference weights. In eval mode, the reference weights are used
    instead for forward.

    The modifier is used to do hardware-aware training, so that the model
    becomes more noise robust during inference (e.g. when the
    ``noise_model`` is employed).
    """

    mapping: MappingParameter = field(default_factory=MappingParameter)
    """Parameter related to mapping weights to tiles for supporting modules."""

    pre_post: PrePostProcessingParameter = field(
        default_factory=PrePostProcessingParameter
    )
    """Parameter related to digital pre- and post-processing."""