diff --git a/src/brevitas/graph/channel_splitting.py b/src/brevitas/graph/channel_splitting.py index 8978366eb..859289002 100644 --- a/src/brevitas/graph/channel_splitting.py +++ b/src/brevitas/graph/channel_splitting.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause import math -from typing import Dict, List, Set, Tuple, Union +from typing import Callable, Dict, List, Optional, Set, Tuple, Union import torch import torch.nn as nn @@ -15,6 +15,7 @@ from brevitas.graph.equalize import _get_output_axis from brevitas.graph.equalize import Region from brevitas.graph.equalize import transpose +from brevitas.nn.mixin.base import QuantLayerMixin __all__ = ['GraphChannelSplitting'] @@ -24,12 +25,74 @@ _unsupported_layers = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.LayerNorm) +# example criterion function +def compressibility_loss(inp: torch.Tensor, dim: int = 1) -> torch.Tensor: + out = torch.norm(inp, dim=dim, p=1) / torch.norm(inp, dim=dim, p=2) + return out + + +# example quant split function +def quant_split_evenly( + channel: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + bias: torch.Tensor, + module: nn.Module): + return channel / 2., channel / 2., scale, scale, zero_point, zero_point, bias / 2., bias / 2. + + +def quant_split_quant_error( + channel: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + bias: torch.Tensor, + module: nn.Module): + bit_width = module.weight_quant.bit_width() + int_threshold = module.weight_quant.tensor_quant.int_scaling_impl(bit_width) + # TODO: insert assertion about the int_quant + channel_0: torch.Tensor = module.weight_quant.tensor_quant.int_quant( + scale / int_threshold, zero_point, bit_width, channel) + channel_1 = channel - channel_0 + # leave scales untouched and assign the full bias to the first split channel (1:0) + device = bias.device + assert torch.allclose(channel_0 + channel_1, channel) + return ( + channel_0.clone(), + channel_1.clone(), + scale, + scale, + zero_point, + zero_point, + bias, + torch.tensor(0.0, device=device), + ) + + +def quant_duplicate_channel( + channel: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + bias: torch.Tensor, + module: nn.Module): + # duplicate channel, scale, zero point and bias unchanged + return channel, channel, scale, scale, zero_point, zero_point, bias, bias + + +# unquantized split and duplicate functions +def split_evenly(channel: torch.Tensor, bias: torch.Tensor): + return channel / 2., channel / 2., bias / 2., bias / 2. + + +def duplicate_channel(channel: torch.Tensor, bias: torch.Tensor): + return channel, channel, bias, bias + + + def _channels_to_split( sources: Dict[str, nn.Module], sinks: Dict[str, nn.Module], - split_criterion: str, - split_ratio: float, - split_input: bool) -> Dict[nn.Module, List[torch.Tensor]]: + layer_split_perc_func: Callable, + split_input: bool, + split_criterion_func: Callable) -> Dict[nn.Module, List[torch.Tensor]]: """ This method computes the channels that will be split based on `split_criterion_func`.
""" @@ -38,19 +101,21 @@ def _channels_to_split( # the modules are all of the same shape so we can just take the first one single_module = next(iter(modules)) num_channels = single_module.weight.shape[_get_axis(single_module)] - splits_per_layer = int(math.ceil(split_ratio * num_channels)) + split_perc = layer_split_perc_func(single_module) + splits_per_layer = int(math.ceil(split_perc * num_channels)) all_channels = [] - if split_criterion == 'maxabs': - for module in modules: - # get input/output axis of module - axis = _get_axis(module) - # transpose to have axis as first dimension - weight_t = transpose(module.weight, axis) - # flatten all but first dimension and get max per channel - max_per_channel = _channel_maxabs(weight_t.reshape(weight_t.size(0), -1)) - channels_sorted = torch.argsort(max_per_channel, descending=True) - all_channels.append(channels_sorted[:splits_per_layer]) + for module in modules: + # get input/output axis of module + axis = _get_axis(module) + # transpose to have axis as first dimension + weight_t = transpose(module.weight, axis) + # flatten all but first dimension and get max per channel + weight_t = weight_t.reshape(weight_t.size(0), -1) + # order values based on criterion + values_sorted = split_criterion_func(weight_t) + channels_sorted = torch.argsort(values_sorted, descending=True) + all_channels.append(channels_sorted[:splits_per_layer]) # return tensor with the unique indices to split channels_to_split = torch.cat(all_channels) @@ -58,12 +123,122 @@ def _channels_to_split( # decorator is needed to modify the weights in-place using a view -@torch.no_grad() -def _split_channels( - module: nn.Module, - channels_to_split: torch.Tensor, - split_input: bool = False, - split_factor: float = 0.5) -> None: +# @torch.no_grad() +def _split_quantized_channels( + module: nn.Module, channels_to_split: torch.Tensor, split_input: bool, + split_func: Callable) -> None: + """ + Given a QuantModule, this method splits the weight channels and scales in case of per_channel + quantization. It differs from _split_quantized_channels as the actual splitting of channels and scales + might needs access to the quantization methods and parameters. + """ + weight = module.weight.data + bias = module.bias.data if module.bias is not None else None + num_added_channels = len(channels_to_split) + + ic_axis = _get_input_axis(module) + oc_axis = _get_output_axis(module) + axis = ic_axis if split_input else oc_axis + # save shape of the module weights + weight_shape = list(weight.shape) + + # check for per_channel quantization + is_per_channel_quant = False + is_asym_quant = False + # we can only split scales etc. 
for output channels, so check if we are splitting output + if module.weight_quant.scale().shape[oc_axis] == weight_shape[oc_axis]: + is_per_channel_quant = True + # if scales are in the log domain, then arithmetic manipulations need to take that into account + scales = module.weight_quant.tensor_quant.scaling_impl(weight) + try: + # get zero_points, for sym quantization the zero point is 1D + if module.weight_quant.tensor_quant.zero_point_impl.value.shape == scales.shape: + is_asym_quant = True + zero_points = module.weight_quant.tensor_quant.zero_point_impl.value.data + except AttributeError: + # nothing to do with the zero points, that's 0 anyways + pass + for id in channels_to_split: + # get channel to split + channel_to_split = weight.index_select(dim=axis, index=id) + # get scale for channel + scale_to_split = scales.index_select(dim=axis, index=id) if not split_input else scales + # get zero point + zero_point_to_split = zero_points.index_select( + dim=axis, index=id) if is_asym_quant else torch.tensor(0.0) + # get bias + bias_to_split = bias[id] if bias is not None and not split_input else torch.tensor(0.0) + # split channel/scale/zero_point/bias based on custom method, i.e. halfing/duplicating it + split_values = split_func( + channel_to_split, scale_to_split, zero_point_to_split, bias_to_split, module) + assert len(split_values) == 8, 'split method needs to return 8 values: 2x channel, 2x scale, 2x zero_point, 2x bias' + # unpack all split_values + split_channel_0, split_channel_1, split_scale_0, split_scale_1, zero_point_0, zero_point_1, split_bias_0, split_bias_1 = split_values + # set orig channel to split_channel_0 using fill & add since no counterpart to index_select + weight = weight.index_fill(dim=axis, index=id, value=0.0) + weight = weight.index_add(dim=axis, index=id, source=split_channel_0) + # stack the second channel + weight = torch.cat([weight, split_channel_1], dim=axis) + + # if per_channel quant, we need to create a new scale for the added channel + if is_per_channel_quant and not split_input: + scales = scales.index_fill(dim=oc_axis, index=id, value=0.0) + scales = scales.index_add(dim=oc_axis, index=id, source=split_scale_0) + # stacking the newly created scale, always per OC + scales = torch.cat([scales, split_scale_1], dim=oc_axis) + + # zero points in case of asym + if is_asym_quant: + zero_points = zero_points.index_fill(dim=oc_axis, index=id, value=0.0) + zero_points = zero_points.index_add(dim=oc_axis, index=id, source=zero_point_0) + zero_points = torch.cat([zero_points, zero_point_1], dim=oc_axis) + + if bias is not None and not split_input: + bias[id] = split_bias_0 + split_bias_1 = split_bias_1.unsqueeze(0) + bias = torch.cat([bias, split_bias_1], dim=0) + + # set weights to module's weights + module.weight.data = weight.clone().contiguous() + + if bias is not None: + module.bias.data = bias.clone().contiguous() + + if is_per_channel_quant and not split_input: + # set value for scaling_impl to scales + scaling_impl = module.weight_quant.tensor_quant.scaling_impl + try: + scales = scaling_impl.restrict_clamp_scaling.restrict_value_impl.restrict_init_tensor( + scales) + except AttributeError: + # no restrict_clamp_scaling, so pass + pass + finally: + # TODO: this is the wrong attribute to set, ask Giuseppe for the right place to set them to + module.weight_quant.tensor_quant.scaling_impl.value.data = scales.clone().contiguous() + + if is_asym_quant: + # set zero_points to module + module.weight_quant.tensor_quant.zero_point_impl.value.data = 
zero_points.clone( ).contiguous() + + if isinstance(module, _conv): + if split_input: + module.in_channels += num_added_channels + else: + module.out_channels += num_added_channels + elif isinstance(module, nn.Linear): + if split_input: + module.in_features += num_added_channels + else: + module.out_features += num_added_channels + + +# decorator is needed to modify the weights in-place using a view +# @torch.no_grad() +def _split_unquantized_channels( + module: nn.Module, channels_to_split: torch.Tensor, split_input: bool, + split_func: Callable) -> None: """ Given a module, this method splits the weight channels as proposed in https://arxiv.org/abs/1901.09504. `split_func` determines how to split the channels, `channels_to_split` is a list of channel indices. @@ -75,29 +250,28 @@ def _split_channels( _get_axis = _get_input_axis if split_input else _get_output_axis axis = _get_axis(module) - # save shape of the module weights - orig_shape = list(weight.shape) - weight_t = transpose(weight, axis) - # flatten to 2d - weight_t = weight_t.reshape(weight_t.size(0), -1) + for id in channels_to_split: - # split and get channel to stack - weight_t[id, :] *= split_factor - split_channel = weight_t[id, :] - # expand so we can stack - split_channel = split_channel.expand(1, split_channel.size(0)) - weight_t = torch.cat([weight_t, split_channel], dim=0) + # get channel and bias to split + channel_to_split = weight.index_select(dim=axis, index=id) + bias_to_split = bias[id] if bias is not None and not split_input else torch.tensor(0.0) + # split channel with user-specified method + split_values = split_func(channel_to_split, bias_to_split) + assert len(split_values) == 4, 'split method needs to return 4 values: 2x channel, 2x bias' + split_channel_0, split_channel_1, split_bias_0, split_bias_1 = split_values + # set orig channel to split_channel_0 using fill & add since no counterpart to index_select + weight = weight.index_fill(dim=axis, index=id, value=0.0) + weight = weight.index_add(dim=axis, index=id, source=split_channel_0) + # stack the second channel + weight = torch.cat([weight, split_channel_1], dim=axis) if bias is not None and not split_input: - bias[id] *= split_factor - split_channel = bias[id:id + 1] - bias = torch.cat((bias, split_channel)) - - # reshape weight_t back to orig shape with the added channels - del orig_shape[axis] - weight_t = weight_t.reshape(weight_t.size(0), *orig_shape) - weight_t = transpose(weight_t, axis) - module.weight.data = weight_t + bias[id] = split_bias_0 + split_bias_1 = split_bias_1.unsqueeze(0) + bias = torch.cat([bias, split_bias_1], dim=0) + + # assign the updated weight tensor back to the module + module.weight.data = weight if bias is not None: module.bias.data = bias @@ -113,26 +287,73 @@ def _split_channels( module.out_features += num_added_channels +def _duplicate_channels( + module: nn.Module, + channels_to_split: torch.Tensor, + duplicate_input: bool, + duplicating_func: Callable = duplicate_channel, + quant_duplicating_func: Callable = quant_duplicate_channel): + # wrapper that dispatches to the quantized or unquantized duplicating function + if isinstance(module, QuantLayerMixin): + # duplicate using duplicate func + _split_quantized_channels( + module, + channels_to_split, + split_input=duplicate_input, + split_func=quant_duplicating_func) + else: + # duplicate the channels as before + _split_unquantized_channels( + module, channels_to_split, split_input=duplicate_input, split_func=duplicating_func) + + +def _split_channels( + module: nn.Module, + channels_to_split: torch.Tensor, +
split_input: bool, + split_func: Callable, + quant_split_func: Callable): + # wrapper for splitting channels in quant/unquant modules + if isinstance(module, QuantLayerMixin): + # split quantized channels using the specified splitting mechanism + _split_quantized_channels( + module, channels_to_split, split_input, split_func=quant_split_func) + else: + # split channels regularly + _split_unquantized_channels(module, channels_to_split, split_input, split_func=split_func) + + def _split_channels_region( sources: Dict[str, nn.Module], sinks: Dict[str, nn.Module], channels_to_split: torch.tensor, - split_input: bool) -> None: + split_input: bool, + split_func: Callable, + quant_split_func: Callable) -> None: if not split_input: # splitting output channels for module in sources: - _split_channels(module, channels_to_split, split_input=False) + _split_channels( + module, + channels_to_split, + split_input=False, + split_func=split_func, + quant_split_func=quant_split_func) for module in sinks: # duplicating input_channels for all modules in the sink - _split_channels(module, channels_to_split, split_factor=1, split_input=True) + _duplicate_channels(module, channels_to_split, duplicate_input=True) else: # input channels are split in half, output channels duplicated for module in sinks: - _split_channels(module, channels_to_split, split_input=True) - + _split_channels( + module, + channels_to_split, + split_input=True, + split_func=split_func, + quant_split_func=quant_split_func) for module in sources: # duplicating output_channels for all modules in the source - _split_channels(module, channels_to_split, split_factor=1, split_input=False) + _duplicate_channels(module, channels_to_split, duplicate_input=False) def _is_groupwise(module: nn.Module) -> bool: @@ -187,9 +408,11 @@ def _unwrap_mha(sources: List[nn.Module]) -> List[nn.Module]: def _split( model: GraphModule, regions: List[Region], - split_ratio: float, + layer_split_perc_func: Callable, split_input: bool, - split_criterion: str = 'maxabs') -> GraphModule: + split_func: Callable = split_evenly, + quant_split_func: Callable = quant_split_quant_error, + split_criterion_func: Callable = compressibility_loss) -> GraphModule: for i, region in enumerate(regions): sources = [region.get_module_from_name(src) for src in region.srcs_names] sinks = [region.get_module_from_name(sink) for sink in region.sinks_names] @@ -202,20 +425,22 @@ def _split( channels_to_split = _channels_to_split( sources=sources, sinks=sinks, - split_criterion=split_criterion, - split_ratio=split_ratio, + split_criterion_func=split_criterion_func, + layer_split_perc_func=layer_split_perc_func, split_input=split_input) # splitting/duplicating channels _split_channels_region( sources=sources, sinks=sinks, channels_to_split=channels_to_split, - split_input=split_input) + split_input=split_input, + split_func=split_func, + quant_split_func=quant_split_func) return model -def _clean_regions(regions: List[Region]) -> List[Region]: +def _clean_regions(regions: List[Region], region_filter_func: Callable) -> List[Region]: """ Given a list of regions, this method removes all regions that are not compatible with channel splitting. 
""" @@ -247,6 +472,10 @@ def _clean_regions(regions: List[Region]) -> List[Region]: if not _is_supported(srcs=sources, sinks=sinks): # add region to be deleted regions_to_del.add(i) + # check if user filters out this region + if not region_filter_func(sources, sinks): + # user doesn't want to split this region + regions_to_del.add(i) regions = [regions[i] for i, _ in enumerate(regions) if i not in regions_to_del] return regions @@ -255,15 +484,23 @@ def _clean_regions(regions: List[Region]) -> List[Region]: class GraphChannelSplitting(GraphTransform): def __init__( - self, - split_ratio: float = 0.02, - split_criterion: str = 'maxabs', - split_input: bool = True): + self, + split_input: bool = True, + split_criterion_func: Callable[[torch.Tensor, int], torch.Tensor] = _channel_maxabs, + split_func: Callable = split_evenly, + quant_split_func: Callable = quant_split_quant_error, + layer_split_perc_func: Optional[Callable[[nn.Module], float]] = lambda x: 0.02, + region_filter_func: Optional[Callable[[List[nn.Module], List[nn.Module]], + bool]] = lambda sources, + sinks: True): super(GraphChannelSplitting, self).__init__() - self.split_ratio = split_ratio - self.split_criterion = split_criterion self.split_input = split_input + self.layer_split_perc_func = layer_split_perc_func + self.split_criterion_func = split_criterion_func + self.split_func = split_func + self.quant_split_func = quant_split_func + self.region_filter_func = region_filter_func def apply( self, @@ -271,14 +508,16 @@ def apply( return_regions: bool = False ) -> Union[Tuple[GraphModule, Set[Tuple[str]]], GraphModule]: regions = _extract_regions(model) - regions = _clean_regions(regions) + regions = _clean_regions(regions, region_filter_func=self.region_filter_func) if len(regions) > 0: model = _split( model=model, regions=regions, - split_ratio=self.split_ratio, - split_criterion=self.split_criterion, - split_input=self.split_input) + layer_split_perc_func=self.layer_split_perc_func, + split_criterion_func=self.split_criterion_func, + split_input=self.split_input, + split_func=self.split_func, + quant_split_func=self.quant_split_func) if return_regions: return model, regions else: diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index e6538421e..a728df9c9 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -20,6 +20,19 @@ from brevitas.graph.base import ModuleInstanceToModuleInstance from brevitas.graph.utils import get_module from brevitas.graph.utils import get_node +from brevitas.nn import QuantConv1d +from brevitas.nn import QuantConv2d +from brevitas.nn import QuantConv3d +from brevitas.nn import QuantConvTranspose1d +from brevitas.nn import QuantConvTranspose2d +from brevitas.nn import QuantConvTranspose3d +from brevitas.nn import QuantHardTanh +from brevitas.nn import QuantIdentity +from brevitas.nn import QuantLinear +from brevitas.nn import QuantMultiheadAttention +from brevitas.nn import QuantReLU +from brevitas.nn import QuantSigmoid +from brevitas.nn import QuantTanh from brevitas.nn.equalized_layer import EqualizedModule from brevitas.nn.quant_scale_bias import ScaleBias from brevitas.utils.torch_utils import KwargsForwardHook @@ -35,11 +48,19 @@ nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d, + QuantConvTranspose1d, + QuantConvTranspose2d, + QuantConvTranspose3d, nn.MultiheadAttention, + QuantMultiheadAttention, nn.Conv1d, nn.Conv2d, nn.Conv3d, + QuantConv1d, + QuantConv2d, + QuantConv3d, nn.Linear, + QuantLinear, nn.LayerNorm, 
nn.BatchNorm1d, nn.BatchNorm2d, @@ -85,6 +106,8 @@ _ignore_ops = (getattr, 'size') +_quant_modules = (QuantReLU, QuantIdentity, QuantHardTanh, QuantTanh, QuantSigmoid) + # Start and End identify the starting and ending channels of the weight matrix that need to be # equalized. @@ -644,9 +667,26 @@ def _is_supported_module(graph_model: GraphModule, node: Node) -> bool: return False +def _is_quant_module(graph_model: GraphModule, node: Node): + module = get_module(graph_model, node.target) + return isinstance(module, _quant_modules) + + +def _get_act_impl(graph_module: GraphModule, node: Node): + module = get_module(graph_module, node.target) + # we know it is a quant module, so just access the proxies + # should be the act_quant.fused_activation_quant_proxy.activation_impl + return module.act_quant.fused_activation_quant_proxy.activation_impl + + def _is_scale_invariant_module(graph_model: GraphModule, node: Node) -> bool: - return node.op == 'call_module' and isinstance( - get_module(graph_model, node.target), _scale_invariant_layers) + # if quant module, we need to check the activation impl + if node.op == 'call_module': + if _is_quant_module(graph_model, node): + # if its quant, check the call impl + act_impl = _get_act_impl(graph_model, node) + return isinstance(act_impl, _scale_invariant_layers) + return isinstance(get_module(graph_model, node.target), _scale_invariant_layers) def _is_scale_varying_activation(graph_model, node): @@ -667,9 +707,12 @@ def _is_reshaping_op(node: Node) -> bool: def get_weight_source(module): transpose = lambda weight, axis: weight if axis == 0 else weight.transpose(0, 1) - if isinstance(module, nn.MultiheadAttention) and not hasattr(module, 'out_proj'): + if isinstance( + module, + (nn.MultiheadAttention, QuantMultiheadAttention)) and not hasattr(module, 'out_proj'): raise RuntimeError("Configuration for Multiheadattention not supported") - weight = module.out_proj.weight if isinstance(module, nn.MultiheadAttention) else module.weight + weight = module.out_proj.weight if isinstance( + module, (nn.MultiheadAttention, QuantMultiheadAttention)) else module.weight axis = _get_output_axis(module) weight = transpose(weight, axis) return weight @@ -677,10 +720,16 @@ def get_weight_source(module): def get_weight_sink(module): transpose = lambda weight, axis: weight if axis == 0 else weight.transpose(0, 1) - if isinstance(module, nn.MultiheadAttention) and not hasattr(module, 'in_proj_weight'): - raise RuntimeError("Configuration for Multiheadattention not supported") - weight = WeightBiasWrapper(module.in_proj_weight).weight if isinstance( - module, nn.MultiheadAttention) else module.weight + if isinstance(module, nn.MultiheadAttention): + if not hasattr(module, 'in_proj_weight'): + raise RuntimeError("Configuration for Multiheadattention not supported") + weight = WeightBiasWrapper(module.in_proj_weight).weight + elif isinstance(module, QuantMultiheadAttention): + if not hasattr(module.in_proj, 'weight'): + raise RuntimeError("Configuration for Multiheadattention not supported") + weight = WeightBiasWrapper(module.in_proj.weight).weight + else: + weight = module.weight axis = _get_input_axis(module) weight = transpose(weight, axis) return weight diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py index 73a08727b..b2d0bcf1f 100644 --- a/src/brevitas/graph/gpfq.py +++ b/src/brevitas/graph/gpfq.py @@ -170,6 +170,8 @@ def update_batch(self, module, input, current_layer): if len(inp.shape) > 2: inp = inp.reshape((-1, sum(inp.shape[2:]))) # For 
QuantLinear layer, groups will be 1 + if isinstance(inp, QuantTensor): + inp = inp.value inp_processed = inp.unsqueeze(0) if isinstance(self.layer, SUPPORTED_CONV_OP): @@ -252,14 +254,14 @@ def single_layer_update(self): self.float_input = self.float_input.to(dev) self.quantized_input = self.quantized_input.to(dev) # We don't need full Hessian, we just need the diagonal - self.H_diag = self.quantized_input.transpose(2, 1).square().sum( - 2) # summing over Batch dimension + H_diag = self.quantized_input.transpose(2, + 1).square().sum(2) # summing over Batch dimension permutation_list = [] for group_index in range(self.groups): if self.act_order: # Re-order Hessian_diagonal so that weights associated to # higher magnitude activations are quantized first - perm = torch.argsort(self.H_diag[group_index, :], descending=True) + perm = torch.argsort(H_diag[group_index, :], descending=True) else: # No permutation, permutation tensor is a ordered index perm = torch.tensor(range(weight.shape[-1]), device=dev) @@ -368,14 +370,14 @@ def single_layer_update(self): z = torch.zeros(weight.shape[:-1], device=dev) # We don't need full Hessian, we just need the diagonal - self.H_diag = self.quantized_input.transpose(2, 1).square().sum( - 2) # summing over Batch dimension + H_diag = self.quantized_input.transpose(2, + 1).square().sum(2) # summing over Batch dimension permutation_list = [] for group_index in range(self.groups): if self.act_order: # Re-order Hessian_diagonal so that weights associated to # higher magnitude activations are quantized first - perm = torch.argsort(self.H_diag[group_index, :], descending=True) + perm = torch.argsort(H_diag[group_index, :], descending=True) else: # No permutation, permutation tensor is a ordered index perm = torch.tensor(range(weight.shape[-1]), device=dev) diff --git a/src/brevitas/graph/quantize.py b/src/brevitas/graph/quantize.py index b1b94b5da..1373d5a4a 100644 --- a/src/brevitas/graph/quantize.py +++ b/src/brevitas/graph/quantize.py @@ -1,6 +1,8 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause +from typing import Callable, Optional + from torch import nn from brevitas import config @@ -9,6 +11,8 @@ from brevitas.fx.brevitas_tracer import symbolic_trace from brevitas.graph.base import ModuleToModuleByClass from brevitas.graph.channel_splitting import GraphChannelSplitting +from brevitas.graph.channel_splitting import split_evenly +from brevitas.graph.equalize import _channel_maxabs from brevitas.graph.equalize import EqualizeGraph from brevitas.graph.fixed_point import CollapseConsecutiveConcats from brevitas.graph.fixed_point import MergeBatchNorm @@ -291,9 +295,13 @@ def preprocess_for_quantize( merge_bn=True, equalize_bias_shrinkage: str = 'vaiq', equalize_scale_computation: str = 'maxabs', - channel_splitting_ratio: float = 0.0, + apply_channel_splitting: bool = False, + channel_splitting_layer_split_perc_func: Callable = lambda x: 0.02, + channel_splitting_region_filter_func: Callable = lambda x, + y: True, channel_splitting_split_input: bool = True, - channel_splitting_criterion: str = 'maxabs'): + channel_splitting_split_func: Callable = split_evenly, + channel_splitting_split_criterion_func: Callable = _channel_maxabs): training_state = model.training model.eval() @@ -315,11 +323,14 @@ def preprocess_for_quantize( merge_bias=equalize_merge_bias, bias_shrinkage=equalize_bias_shrinkage, scale_computation_type=equalize_scale_computation).apply(model) - if channel_splitting_ratio > 0: + if apply_channel_splitting: + # not setting quant_split_func since we're preprocessing for quantization model = GraphChannelSplitting( - split_ratio=channel_splitting_ratio, - split_criterion=channel_splitting_criterion, - split_input=channel_splitting_split_input).apply(model) + layer_split_perc_func=channel_splitting_layer_split_perc_func, + region_filter_func=channel_splitting_region_filter_func, + split_criterion_func=channel_splitting_split_criterion_func, + split_input=channel_splitting_split_input, + split_func=channel_splitting_split_func).apply(model) model.train(training_state) return model diff --git a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py index 668eee22c..821635596 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py +++ b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py @@ -20,6 +20,7 @@ from brevitas import __version__ as brevitas_version from brevitas import config from brevitas import torch_version +from brevitas.graph.channel_splitting import GraphChannelSplitting from brevitas.graph.quantize import preprocess_for_quantize from brevitas.graph.target.flexml import preprocess_for_flexml_quantize from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_act_equalization @@ -101,12 +102,13 @@ def unique(sequence): 'gpfq': [False], # Enable/Disable GPFQ 'gpfa2q': [False], # Enable/Disable GPFA2Q 'gpfq_p': [1.0], # GPFQ P - 'gpxq_act_order': [False], # Use act_order euristics for GPxQ + 'gpxq_act_order': [True], # Use act_order euristics for GPxQ 'accumulator_bit_width': [16], # Accumulator bit width, only in combination with GPFA2Q 'act_quant_percentile': [99.999], # Activation Quantization Percentile 'uint_sym_act_for_unsigned_values': [True], # Whether to use unsigned act quant when possible 'channel_splitting_ratio': [0.0], # Channel Splitting ratio, 0.0 means no splitting - 'split_input': 
[True], # Whether to split the input channels when applying channel splitting + 'quant_channel_splitting_ratio': [0.0], # channel splitting after quantizing + 'split_input': [False], # Whether to split the input channels when applying channel splitting 'merge_bn': [True]} # Whether to merge BN layers parser = argparse.ArgumentParser(description='PyTorch ImageNet PTQ Validation') @@ -214,7 +216,9 @@ def ptq_torchvision_models(args): equalize_iters=config_namespace.graph_eq_iterations, equalize_merge_bias=config_namespace.graph_eq_merge_bias, merge_bn=config_namespace.merge_bn, - channel_splitting_ratio=config_namespace.channel_splitting_ratio, + apply_channel_splitting=config_namespace.channel_splitting_ratio > 0, + channel_splitting_layer_split_perc_func=lambda x: config_namespace. + channel_splitting_ratio, channel_splitting_split_input=config_namespace.split_input) else: raise RuntimeError(f"{config_namespace.target_backend} backend not supported.") @@ -252,6 +256,14 @@ def ptq_torchvision_models(args): quant_model = quant_model.cuda(args.gpu) cudnn.benchmark = False + # apply channel splitting here after quantizing the model + if config_namespace.quant_channel_splitting_ratio > 0: + print("Applying Quant Channel Splitting...") + quant_model = GraphChannelSplitting( + region_filter_func=(lambda x, y: True), + layer_split_perc_func=(lambda x: config_namespace.quant_channel_splitting_ratio), + split_input=config_namespace.split_input).apply(quant_model) + # Calibrate the quant_model on the calibration dataloader print("Starting calibration") calibrate(calib_loader, quant_model) @@ -271,7 +283,7 @@ def ptq_torchvision_models(args): quant_model, p=config_namespace.gpfq_p, act_order=config_namespace.gpxq_act_order, - gpfa2q=config_namespace.gpfa2q, + use_gpfa2q=config_namespace.gpfa2q, accumulator_bit_width=config_namespace.accumulator_bit_width) if config_namespace.gptq: @@ -372,8 +384,10 @@ def validate_config(config_namespace): is_valid = False if config_namespace.act_exponent_bit_width + config_namespace.act_mantissa_bit_width != config_namespace.act_bit_width - 1: is_valid = False + if config_namespace.channel_splitting_ratio > 0 and config_namespace.quant_channel_splitting_ratio > 0: + is_valid = False # if channel splitting is disabled, no need for split input - if not config_namespace.channel_splitting_ratio: + if config_namespace.channel_splitting_ratio == 0 and config_namespace.quant_channel_splitting_ratio == 0: config_namespace.split_input = None config_namespace.is_valid = is_valid diff --git a/tests/brevitas/graph/test_channel_splitting.py b/tests/brevitas/graph/test_channel_splitting.py index 30a3dd8d3..a11a2225d 100644 --- a/tests/brevitas/graph/test_channel_splitting.py +++ b/tests/brevitas/graph/test_channel_splitting.py @@ -6,8 +6,11 @@ from brevitas.fx import symbolic_trace from brevitas.graph.channel_splitting import _clean_regions from brevitas.graph.channel_splitting import _split +from brevitas.graph.channel_splitting import GraphChannelSplitting from brevitas.graph.equalize import _extract_regions from brevitas.graph.fixed_point import MergeBatchNorm +from brevitas.graph.quantize import preprocess_for_quantize +from brevitas_examples.imagenet_classification.ptq.ptq_common import quantize_model from .equalization_fixtures import * @@ -28,13 +31,13 @@ @pytest.mark.parametrize('split_input', [False, True]) def test_toymodels(toy_model, split_input, request): - test_id = request.node.callspec.id + model_name = request.node.callspec.id.split('-')[0] torch.manual_seed(SEED) 
model_class = toy_model model = model_class() - if 'mha' in test_id: + if 'mha' in model_name: inp = torch.randn(IN_SIZE_LINEAR) else: inp = torch.randn(IN_SIZE_CONV) @@ -50,11 +53,12 @@ def test_toymodels(toy_model, split_input, request): old_state_dict = model.state_dict() regions = _extract_regions(model) - regions = _clean_regions(regions) - if model_class in no_split_models: + regions = _clean_regions(regions, region_filter_func=lambda x, y: True) + if model_name in no_split_models: assert len(regions) == 0 else: - model = _split(model, regions, split_ratio=SPLIT_RATIO, split_input=split_input) + model = _split( + model, regions, split_input=split_input, layer_split_perc_func=lambda x: SPLIT_RATIO) out = model(inp) assert torch.allclose(expected_out, out, atol=ATOL) @@ -79,7 +83,7 @@ def test_toymodels(toy_model, split_input, request): @pytest.mark.parametrize('split_input', [False, True]) def test_torchvision_models(model_coverage: tuple, split_input: bool, request): - model_class = request.node.callspec.id.split('-')[0] + model_name = request.node.callspec.id.split('-')[0] model, coverage = model_coverage @@ -96,11 +100,12 @@ def test_torchvision_models(model_coverage: tuple, split_input: bool, request): old_state_dict = model.state_dict() regions = _extract_regions(model) - regions = _clean_regions(regions) - if model_class in no_split_models: + regions = _clean_regions(regions, region_filter_func=lambda x, y: True) + if model_name in no_split_models: assert len(regions) == 0 else: - model = _split(model, regions, split_ratio=SPLIT_RATIO, split_input=split_input) + model = _split( + model, regions, split_input=split_input, layer_split_perc_func=lambda x: SPLIT_RATIO) out = model(inp) assert torch.allclose(expected_out, out, atol=ATOL) @@ -119,3 +124,135 @@ def test_torchvision_models(model_coverage: tuple, split_input: bool, request): for module in modified_sinks: weight_name = module + '.weight' assert not torch.equal(old_state_dict[weight_name], model.state_dict()[weight_name]) + + +@pytest.mark.parametrize('split_input', [False, True]) +def test_quant_toymodels(toy_model, split_input, request): + model_name = request.node.callspec.id.split('-')[0] + + torch.manual_seed(SEED) + + model_class = toy_model + model = model_class() + if 'mha' in model_name: + pytest.skip('MHA not supported with this quantization method') + else: + inp = torch.randn(IN_SIZE_CONV) + + # preprocess model for quantization, like merge BN etc. 
+ model = preprocess_for_quantize(model) + # save regions + regions = _extract_regions(model) + # quantize model pretty basic + quant_model = quantize_model( + model, + backend='layerwise', + weight_bit_width=8, + act_bit_width=8, + bias_bit_width=32, + scale_factor_type='float_scale', + weight_narrow_range=False, + weight_param_method='mse', + weight_quant_granularity='per_channel', + weight_quant_type='sym', + layerwise_first_last_bit_width=8, + act_param_method='stats', + act_quant_percentile=99.999, + act_quant_type='sym', + quant_format='int') + + expected_out = quant_model(inp) + + # save model's state dict to check if channel splitting was done or not + old_state_dict = quant_model.state_dict() + + # quant_regions should be the same + quant_regions = _extract_regions(quant_model) + quant_regions = _clean_regions(quant_regions, region_filter_func=lambda x, y: True) + + if model_name in no_split_models: + assert len(quant_regions) == 0 + else: + # check regions + assert len(quant_regions) == len(regions) + + # pass custom split function here + quant_model = _split( + quant_model, + quant_regions, + split_input=split_input, + layer_split_perc_func=lambda x: SPLIT_RATIO) + + out = quant_model(inp) + # checking if the outputs are all close, doesn't work for split_input = True + assert torch.allclose(expected_out, out, atol=0.1) + + modified_sources = {source for region in quant_regions for source in region.srcs_names} + # avoiding checking the same module multiple times + modified_sinks = { + sink for region in quant_regions for sink in region.sinks_names} - modified_sources + for module in modified_sources: + if 'mha' in module: + module += '.out_proj' + weight_name = module + '.weight' + assert not torch.equal( + old_state_dict[weight_name], quant_model.state_dict()[weight_name]) + bias_name = module + '.bias' + # not all modules have bias and they only differ when splitting output channels + if bias_name in old_state_dict.keys(): + assert not torch.equal( + old_state_dict[bias_name], quant_model.state_dict()[bias_name]) + for module in modified_sinks: + weight_name = module + '.weight' + assert not torch.equal( + old_state_dict[weight_name], quant_model.state_dict()[weight_name]) + + +@pytest.mark.parametrize('split_input', [False, True]) +def test_torchvision_models_preprocessing(model_coverage: tuple, split_input: bool, request): + model_name = request.node.callspec.id.split('-')[0] + + model, coverage = model_coverage + + torch.manual_seed(SEED) + inp = torch.randn(IN_SIZE_CONV) + + model.eval() + expected_out = model(inp) + + model = symbolic_trace(model) + # merge BN before applying channel splitting + model = MergeBatchNorm().apply(model) + + old_state_dict = model.state_dict() + regions = _extract_regions(model) + regions = _clean_regions(regions, region_filter_func=lambda x, y: True) + + # use default channel absmax for criterion and split evenly for split_func + model, split_regions = GraphChannelSplitting( + layer_split_perc_func=lambda x: SPLIT_RATIO, + region_filter_func=lambda x, y: True, + split_input=split_input).apply(model, return_regions=True) + if model_name in no_split_models: + assert len(regions) == 0 + else: + # check if regions are the same + assert len(regions) == len(split_regions) + + out = model(inp) + assert torch.allclose(expected_out, out, atol=ATOL) + + modified_sources = {source for region in split_regions for source in region.srcs_names} + # avoiding checking the same module multiple times + modified_sinks = { + sink for region in split_regions for 
sink in region.sinks_names} - modified_sources + for module in modified_sources: + weight_name = module + '.weight' + assert not torch.equal(old_state_dict[weight_name], model.state_dict()[weight_name]) + bias_name = module + '.bias' + # not all modules have bias and they only differ when splitting output channels + if bias_name in old_state_dict.keys() and not split_input: + assert not torch.equal(old_state_dict[bias_name], model.state_dict()[bias_name]) + for module in modified_sinks: + weight_name = module + '.weight' + assert not torch.equal(old_state_dict[weight_name], model.state_dict()[weight_name]) diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py index 89759b41a..7d5bdc3d1 100644 --- a/tests/brevitas/graph/test_equalization.py +++ b/tests/brevitas/graph/test_equalization.py @@ -14,6 +14,7 @@ from brevitas.graph.standardize import DuplicateSharedStatelessModule from brevitas.graph.standardize import TorchFunctionalToModule from brevitas.graph.utils import get_module +from brevitas_examples.imagenet_classification.ptq.ptq_common import quantize_model from .equalization_fixtures import * @@ -225,3 +226,87 @@ def test_act_equalization_torchvision_models(model_dict: dict, layerwise: bool): # Check that at least one region performs "true" equalization # If all shapes are scalar, no equalization has been performed assert any([shape != () for shape in shape_scale_regions]) + + +@pytest_cases.parametrize("backend", ['layerwise', 'fx']) +def test_regions_quantized_models(toy_model, backend, request): + test_id = request.node.callspec.id + + # mha produces torch error when quantizing + if 'mha' in test_id: + pytest.skip('MHA not supported for this test.') + + model_class = toy_model + model = model_class() + + model = symbolic_trace(model) + regions = _extract_regions(model) + + # TODO: consider covering other quantization configurations + quant_model = quantize_model( + model, + backend=backend, + weight_bit_width=8, + act_bit_width=8, + bias_bit_width=32, + scale_factor_type='float_scale', + weight_narrow_range=False, + weight_param_method='stats', + weight_quant_granularity='per_tensor', + weight_quant_type='sym', + layerwise_first_last_bit_width=8, + act_param_method='stats', + act_quant_percentile=99.999, + act_quant_type='sym', + quant_format='int') + quant_regions = _extract_regions(quant_model) + + # check that the same regions were extracted for the quant_model + assert len(regions) == len(quant_regions) + for region, quant_region in zip(regions, quant_regions): + # we need to check the names, the modules will be different as they're quantized + assert region.srcs_names == quant_region.srcs_names + assert region.sinks_names == quant_region.sinks_names + + +@pytest_cases.parametrize("backend", ['layerwise', 'fx']) +def test_regions_quantized_torchvision_models(model_coverage, backend): + model, coverage = model_coverage + + # mobilenet uses ReLU6, fx quantization replaces those modules with ReLU, yielding more regions + if model._get_name() == 'MobileNetV2' and backend == 'fx': + pytest.skip('Mobilenet_v2 quantized with fx not compatible with region extracting') + + torch.manual_seed(SEED) + model.eval() + # The isinstance check does not work after symbolic trace + model = symbolic_trace(model) + model = TorchFunctionalToModule().apply(model) + + regions = _extract_regions(model) + + # TODO: consider covering other quantization configurations + quant_model = quantize_model( + model, + backend=backend, + weight_bit_width=8, + act_bit_width=8, + bias_bit_width=32, +
scale_factor_type='float_scale', + weight_narrow_range=False, + weight_param_method='stats', + weight_quant_granularity='per_tensor', + weight_quant_type='sym', + layerwise_first_last_bit_width=8, + act_param_method='stats', + act_quant_percentile=99.999, + act_quant_type='sym', + quant_format='int') + quant_regions = _extract_regions(quant_model) + + # check that the same regions were extracted for the quant_model + assert len(regions) == len(quant_regions) + for region, quant_region in zip(regions, quant_regions): + # we need to check the names, the modules will be different as they're quantized + assert region.srcs_names == quant_region.srcs_names + assert region.sinks_names == quant_region.sinks_names
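
Usage sketch for the callable-based channel-splitting API introduced above: a minimal example, assuming a hypothetical FX-traceable toy model. `split_evenly`, `quant_split_quant_error` and `_channel_maxabs` are the example callables and defaults defined in this patch, while the per-layer percentages and the region filter below are made up purely for illustration.

import torch.nn as nn

from brevitas.fx import symbolic_trace
from brevitas.graph.channel_splitting import GraphChannelSplitting
from brevitas.graph.channel_splitting import quant_split_quant_error
from brevitas.graph.channel_splitting import split_evenly
from brevitas.graph.equalize import _channel_maxabs

# hypothetical toy model, only for illustration; any FX-traceable float or
# Brevitas-quantized model can be passed instead
model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 32, 3)).eval()
model = symbolic_trace(model)

# split 5% of the channels of convolutions, 2% of the channels of any other layer
layer_split_perc = lambda module: 0.05 if isinstance(module, nn.Conv2d) else 0.02


# illustrative filter: only split regions whose sinks are all convolutions
def region_filter(sources, sinks):
    return all(isinstance(m, nn.Conv2d) for m in sinks)


model = GraphChannelSplitting(
    split_input=False,  # split output channels of sources, duplicate input channels of sinks
    split_criterion_func=_channel_maxabs,  # rank channels by max absolute value
    split_func=split_evenly,  # float layers: halve the selected channel and bias
    quant_split_func=quant_split_quant_error,  # quant layers: split off the quantization error
    layer_split_perc_func=layer_split_perc,
    region_filter_func=region_filter).apply(model)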