# Source code for aihwkit_lightning.inference.calibration.calibration

# -*- coding: utf-8 -*-

# (C) Copyright 2024 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""Calibration for inference."""

from typing import Optional, Dict, Tuple, List, Union
from copy import deepcopy
from collections.abc import Iterator
from functools import partial
from enum import Enum

from tqdm import tqdm

from torch import tensor, Tensor, cat, randperm, no_grad, empty, zeros, full, int32
from torch.nn.functional import unfold
from torch.nn import Parameter
from torch.nn.modules.module import RemovableHandle

from aihwkit_lightning.exceptions import ConfigError
from aihwkit_lightning.simulator.configs import WeightNoiseInjectionType
from aihwkit_lightning.nn import AnalogLinear, AnalogConv2d
from aihwkit_lightning.simulator.configs import TorchInferenceRPUConfig
from aihwkit_lightning.nn.modules.container import AnalogWrapper

# mypy: disable-error-code="attr-defined"


class InputRangeCalibrationType(Enum):
    """Input range post-training calibration type.

    Different styles of calibrating the DAC ranges post-training.
    """

    NONE = "None"
    """No Calibration."""

    MOVING_STD = "MovingStd"
    """Computes a moving average of ``std_alpha`` times the standard deviation of the inputs."""

    MOVING_QUANTILE = "MovingQuantile"
    """Computes the moving average of the quantiles. Saves memory."""

    CACHE_QUANTILE = "CacheQuantile"
    """Caches inputs that are then used to compute the Xth quantile for the input range."""

    MAX = "Max"
    """Takes the abs().max() over the inputs."""
def _calibration_pre_forward(
    mod: Union[AnalogLinear, AnalogConv2d],
    input_args: Tuple,
    input_kwargs: Dict,
    calibration_type: InputRangeCalibrationType,
    cache_key: str,
    global_cache: Dict[str, Tensor],
    max_samples: int = 1000,
    ir_quantile: float = 0.99,
) -> None:
    """Forward pre-hook that collects input statistics for input range calibration.

    Depending on ``calibration_type`` the hook either stores information in
    ``global_cache[cache_key]`` (``CACHE_QUANTILE``: a random subset of at most
    ``max_samples`` input rows, ``MAX``: the running absolute maximum per input
    slice) or updates ``mod.input_range`` in place with a running average
    (``MOVING_QUANTILE`` and ``MOVING_STD``).

    Args:
        mod: The analog layer the hook is registered on.
        input_args: Positional forward inputs. The first entry is the input.
        input_kwargs: Keyword forward inputs. ``input_kwargs["inp"]`` is used
            when no positional input was passed.
        calibration_type: Type used for calibration.
        cache_key: Key of this layer in the global cache.
        global_cache: Cache shared between all hooks, indexed by ``cache_key``.
        max_samples: Maximal number of cached rows (``CACHE_QUANTILE`` only).
        ir_quantile: Quantile used for ``MOVING_QUANTILE``.

    Raises:
        ConfigError: If the calibration type is not one of the supported types.
    """
    # pylint: disable=too-many-locals

    caching_types = (InputRangeCalibrationType.CACHE_QUANTILE, InputRangeCalibrationType.MAX)
    moving_types = (
        InputRangeCalibrationType.MOVING_QUANTILE,
        InputRangeCalibrationType.MOVING_STD,
    )
    if calibration_type not in caching_types + moving_types:
        raise ConfigError(f"Unknown InputRangeCalibrationType {calibration_type}")

    x_input: Tensor
    x_input = input_args[0] if len(input_args) > 0 else input_kwargs["inp"]
    if isinstance(mod, AnalogConv2d):
        assert isinstance(mod.padding, tuple), "Padding must be a tuple"
        # Unfold so that every row is one flattened patch, i.e. the conv
        # input is treated like the input of a linear layer.
        x_input = unfold(
            x_input,
            kernel_size=mod.kernel_size,
            dilation=mod.dilation,
            padding=mod.padding,
            stride=mod.stride,
        ).transpose(-1, -2)

    x_input = x_input.reshape(-1, x_input.size(-1))
    # get rid of entries that are all-zeros
    x_input = x_input[~(x_input == 0.0).all(-1)]
    if x_input.size(0) == 0:
        # Nothing left to collect statistics from. max() and quantile() would
        # raise on an empty tensor, so leave the cache / ranges untouched.
        return

    if calibration_type == InputRangeCalibrationType.CACHE_QUANTILE:
        # Add new samples to the cache, then shuffle and limit the number.
        cache = cat([global_cache[cache_key], x_input.float().detach().cpu()])
        global_cache[cache_key] = cache[randperm(cache.size(0))[:max_samples]]
        return

    if calibration_type == InputRangeCalibrationType.MAX:
        cache = global_cache[cache_key]
        if cache.numel() == 0:
            # First batch seen by this layer: one running maximum per slice.
            cache = full(
                (len(mod.in_sizes),), fill_value=float("-Inf"), dtype=x_input.dtype, device="cpu"
            )
        current_upper = 0
        for slice_idx, inp_size in enumerate(mod.in_sizes):
            inp_slice = x_input[..., current_upper : current_upper + inp_size]  # noqa: E203
            # The cache lives on the CPU, so bring the new value there first.
            slice_max = inp_slice.abs().max().detach().cpu()
            if slice_max > cache[slice_idx]:
                cache[slice_idx] = slice_max
            current_upper += inp_size
        global_cache[cache_key] = cache
        return

    # MOVING_QUANTILE / MOVING_STD: running average stored in the layer itself.
    ir_params = mod.rpu_config.pre_post.input_range  # type: ignore
    assert mod.input_range_update_idx is not None, "Input range update idx is None"
    current_upper = 0
    for slice_idx, inp_size in enumerate(mod.in_sizes):
        inp_slice = x_input[..., current_upper : current_upper + inp_size]  # noqa: E203
        idx = mod.input_range_update_idx[slice_idx]
        if calibration_type == InputRangeCalibrationType.MOVING_QUANTILE:
            # NOTE(review): the quantile is taken over the signed values while
            # the ir_quantile == 1.0 shortcut uses the absolute maximum.
            # Confirm which one is intended.
            val = (
                inp_slice.abs().max()
                if ir_quantile == 1.0
                else inp_slice.float().flatten().quantile(ir_quantile)
            ).item()
        else:
            std = inp_slice.std().item()
            val = ir_params.init_std_alpha * std
        old_val = mod.input_range[slice_idx].item()
        # idx is a tensor element, hence new_val is a tensor as well.
        new_val = (old_val * idx + val) / (idx + 1)
        mod.input_range.data[slice_idx] = new_val.type_as(mod.input_range)
        mod.input_range_update_idx[slice_idx] += 1
        current_upper += inp_size
@no_grad()
def calibrate_input_ranges(
    model: AnalogWrapper,
    calibration_type: InputRangeCalibrationType,
    dataloader: Iterator,
    quantile: float = 0.99995,
    max_samples: int = 1000,
    std_alpha: Optional[float] = None,
    verbose: bool = True,
) -> None:
    """Calibrate the input ranges according to the defined strategy.

    The checks below are performed on the first analog layer of the model
    only. During calibration the output noise and the weight modifier noise of
    every analog layer are switched off. Afterwards the saved RPU
    configurations are re-assigned with the input range enabled.

    Note:
        This implementation transiently registers a new
        `forward_pre_hook` on every analog layer. The hooks are removed
        again even if a forward pass raises.

    Args:
        model: The analog model for which to calibrate the input ranges.
        calibration_type: Strategy of the calibration. See
            :class:`~InputRangeCalibrationType`
        dataloader: Iterable that yields ``(args, kwargs)`` pairs. Is used like this
            ``for args, kwargs in dataloader: model(*args, **kwargs)``
        quantile: Quantile used for the quantile based strategies.
            Defaults to 0.99995.
        max_samples: Max number of input rows to cache for each layer.
            Defaults to 1000.
        std_alpha: Number of standard deviations for moving standard
            deviation strategy. Defaults to ``init_std_alpha`` from RPUConfig
        verbose: Whether to show a progress bar.

    Raises:
        ConfigError: If the input range is already enabled, if the input
            resolution is not infinite, if the model is in training mode or
            if the calibration type is unknown (raised by the hook).
    """
    # pylint: disable=too-many-statements, too-many-locals, too-many-branches

    sample_layer: Union[AnalogLinear, AnalogConv2d]
    sample_layer = next(model.analog_layers())  # type: ignore[assignment]
    is_training = sample_layer.training
    rpu_config = sample_layer.rpu_config
    if rpu_config.pre_post.input_range.enable:
        raise ConfigError(
            "You can only calibrate input ranges for models that don't have input ranges."
        )
    if rpu_config.forward.inp_res > 0:
        raise ConfigError(
            "When calibrating the input ranges, the input res must be infinite (-1 or 0)"
        )
    if is_training:
        raise ConfigError("Calibration can only be done in test mode.")

    moving_types = (
        InputRangeCalibrationType.MOVING_QUANTILE,
        InputRangeCalibrationType.MOVING_STD,
    )
    caching_types = (InputRangeCalibrationType.CACHE_QUANTILE, InputRangeCalibrationType.MAX)

    cache: Dict[str, Tensor]
    cache = {}
    old_rpu_config: Dict[str, TorchInferenceRPUConfig]
    old_rpu_config = {}
    handles: List[RemovableHandle]
    handles = []

    try:
        for layer_name, layer in model.named_analog_layers():
            layer: Union[AnalogLinear, AnalogConv2d]  # type: ignore[no-redef]
            if calibration_type in moving_types:
                # we actually first change the rpu_config to enable the input range
                ir_params = layer.rpu_config.pre_post.input_range
                ir_params.enable = True
                ir_params.dynamic = False
                ir_params.learn_input_range = True
                ir_params.init_value = 3.0
                if std_alpha is not None:
                    ir_params.init_std_alpha = std_alpha
                # Read the settings from this layer's own config instead of
                # the sample layer's so the result does not depend on the
                # layers sharing one config object.
                layer.input_range = Parameter(  # type: ignore[assignment]
                    data=full(
                        (len(layer.in_sizes),),
                        fill_value=ir_params.init_value,
                        dtype=layer.weight.dtype,
                        device=layer.weight.device,
                    ),
                    requires_grad=ir_params.learn_input_range,
                )
                layer.input_range_update_idx = Parameter(  # type: ignore[assignment]
                    data=zeros((len(layer.in_sizes),), dtype=int32, device=layer.weight.device),
                    requires_grad=False,
                )

            # Save the config (for the moving strategies this already includes
            # the enabled input range), then turn off output noise and the
            # weight modifier for the calibration passes.
            old_rpu_config[layer_name] = deepcopy(layer.rpu_config)
            layer.rpu_config.forward.out_noise = 0.0
            layer.rpu_config.modifier.noise_type = WeightNoiseInjectionType.NONE

            # generate hook
            cache[layer_name] = tensor([])
            hook = partial(
                _calibration_pre_forward,
                ir_quantile=quantile,
                calibration_type=calibration_type,
                cache_key=layer_name,
                global_cache=cache,
                max_samples=max_samples,
            )
            handles.append(layer.register_forward_pre_hook(hook, with_kwargs=True))

        # Pass through the samples
        progress_bar = tqdm if verbose else lambda x: x
        for args, kwargs in progress_bar(dataloader):  # type: ignore[operator]
            model(*args, **kwargs)  # type: ignore[operator]
    finally:
        # Remove hooks, also when registration or a forward pass failed.
        for handle in handles:
            handle.remove()

    # Re-assign the rpu-configs but also enable the IR range
    # and create the according Parameters
    for layer_name, layer in model.named_analog_layers():
        layer: Union[AnalogLinear, AnalogConv2d]  # type: ignore[no-redef]
        layer.input_range_update_idx = Parameter(  # type: ignore[assignment]
            data=full((len(layer.in_sizes),), fill_value=float("-Inf"), device=layer.weight.device),
            requires_grad=False,
        )
        restored_config: TorchInferenceRPUConfig = old_rpu_config[layer_name]
        if calibration_type in caching_types:
            ir_params = restored_config.pre_post.input_range
            ir_params.enable = True
            layer.input_range = Parameter(
                data=full(
                    (len(layer.in_sizes),),
                    fill_value=ir_params.init_value,
                    dtype=layer.weight.dtype,
                    device=layer.weight.device,
                ),
                requires_grad=ir_params.learn_input_range,
            )
            cached_inputs = cache[layer_name]
            if cached_inputs.numel() == 0:
                # The layer never saw a usable input. Keep init_value instead
                # of failing in quantile() or assigning an empty input range.
                pass
            elif calibration_type == InputRangeCalibrationType.CACHE_QUANTILE:
                current_upper = 0
                for slice_idx, inp_size in enumerate(layer.in_sizes):
                    inp_slice = cached_inputs[
                        ..., current_upper : current_upper + inp_size  # noqa: E203
                    ]
                    layer.input_range.data[slice_idx] = (
                        inp_slice.flatten().quantile(quantile).item()
                    )
                    current_upper += inp_size
            else:
                # MAX: the cache is kept on the CPU in the input dtype, so
                # move it to where the layer's weights live.
                layer.input_range.data = cached_inputs.to(
                    device=layer.weight.device, dtype=layer.weight.dtype
                )
        layer.rpu_config = restored_config