Skip to content

Commit

Permalink
Feat (quant): decoupled PerChannel/PerTensor quantization (#1025)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Ian Colbert <[email protected]>
  • Loading branch information
Giuseppe5 and i-colbert authored Oct 8, 2024
1 parent 0b18761 commit 9048ecb
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 22 deletions.
93 changes: 76 additions & 17 deletions src/brevitas/quant/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,62 @@ class WeightPerChannelFloatDecoupled(SolveStatsReduceDimFromEnum,
scaling_per_output_type = ScalingPerOutputType.CHANNEL


class WeightNormPerChannelFloatDecoupled(SolveStatsReduceDimFromEnum,
class PerChannelL2Norm(ExtendedInjector):
stats_reduce_dim = SCALING_STATS_REDUCE_DIM
normalize_stats_impl = L2Norm


class PerChannelL1Norm(ExtendedInjector):
stats_reduce_dim = SCALING_STATS_REDUCE_DIM
normalize_stats_impl = L1Norm


class PerChannelPreNorm(ExtendedInjector):
pre_scaling_impl = ParameterPreScalingWeightNorm
scaling_stats_input_view_shape_impl = OverOutputChannelView
scaling_impl = (this << 1).scaling_impl
normalize_stats_impl = (this << 1).normalize_stats_impl
tracked_parameter_list = (this << 1).tracked_parameter_list
pre_scaling_shape = (this << 1).pre_scaling_shape
permute_dims = (this << 1).permute_dims


class AccumulatorAwarePerChannelPreNorm(PerChannelPreNorm):

pre_scaling_impl = AccumulatorAwareParameterPreScaling
accumulator_bit_width = (this << 1).accumulator_bit_width
accumulator_bit_width_impl = (this << 1).accumulator_bit_width_impl


class AccumulatorAwareZeroCenterPerChannelPreNorm(AccumulatorAwarePerChannelPreNorm):

pre_scaling_impl = AccumulatorAwareZeroCenterParameterPreScaling
pre_zero_point_impl = PreZeroCenterZeroPoint
pre_zero_point_shape = this.pre_scaling_shape # TODO: decouple zero_point from scaling
pre_zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl
stats_reduce_dim = SCALING_STATS_REDUCE_DIM
scaling_shape = (this << 1).scaling_shape


class SolvePostScaleGranularity(ExtendedInjector):

@value
def scaling_stats_input_view_shape_impl(scaling_per_output_type):
if scaling_per_output_type == ScalingPerOutputType.TENSOR:
return StatsInputViewShapeImpl.OVER_TENSOR
elif scaling_per_output_type == ScalingPerOutputType.CHANNEL:
return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS

@value
def stats_reduce_dim(scaling_per_output_type):
if scaling_per_output_type == ScalingPerOutputType.TENSOR:
return None
elif scaling_per_output_type == ScalingPerOutputType.CHANNEL:
return SCALING_STATS_REDUCE_DIM


class WeightNormPerChannelFloatDecoupled(SolvePostScaleGranularity,
SolveStatsReduceDimFromEnum,
SolveWeightScalingStatsInputDimsFromModule,
SolveWeightScalingPerOutputChannelShapeFromModule,
SolveParameterScalingShape,
Expand All @@ -361,6 +416,8 @@ def scaling_init(scaling_init_impl, bit_width):
scales = scaling_init_impl.parameter_list_stats() / (pow(2., bit_width - 1.) - 1.)
return scales

per_channel_pre_norm = PerChannelPreNorm

proxy_class = DecoupledWeightQuantProxyFromInjector
tensor_quant = DecoupledRescalingIntQuant
decoupled_int_quant = DecoupledIntQuant
Expand All @@ -369,22 +426,23 @@ def scaling_init(scaling_init_impl, bit_width):
scaling_init_impl = StatsFromParameterScaling
restrict_scaling_impl = LogFloatRestrictValue
scaling_stats_impl = AbsMax
pre_scaling_impl = ParameterPreScalingWeightNorm
restrict_pre_scaling_impl = LogFloatRestrictValue
normalize_stats_impl = L2Norm
normalize_stats_impl = PerChannelL2Norm.normalize_stats_impl
scaling_per_output_type = ScalingPerOutputType.CHANNEL
pre_scaling_shape = this.scaling_shape # TODO: decouple pre_scaling_shape from scaling_shape
pre_scaling_shape = this.scaling_per_output_channel_shape
int_scaling_impl = SingleArgStatelessBuffer(1.)
zero_point_impl = ZeroZeroPoint
pre_zero_point_impl = ZeroZeroPoint
bit_width_impl = BitWidthConst
narrow_range = True
signed = True
scaling_stats_input_view_shape_impl = OverOutputChannelView
stats_reduce_dim = SCALING_STATS_REDUCE_DIM
scaling_min_val = 1e-10
pre_scaling_min_val = 1e-10

@value
def pre_scaling_impl():
return this.per_channel_pre_norm.pre_scaling_impl


class AccumulatorAwareWeightQuant(WeightNormPerChannelFloatDecoupled):
"""Experimental accumulator-aware weight quantizer based on `Quantized Neural Networks
Expand All @@ -403,16 +461,16 @@ class AccumulatorAwareWeightQuant(WeightNormPerChannelFloatDecoupled):
details on the arithmetic, see `AccumulatorAwareParameterPreScalingWeightNorm`. For further
details on accumulator-aware quantization (A2Q) technique, see the referenced paper."""

@value
def accumulator_bit_width_impl(accumulator_bit_width):
return BitWidthStatefulConst(accumulator_bit_width)

proxy_class = DecoupledWeightQuantWithInputProxyFromInjector
tensor_quant = DecoupledRescalingIntQuantWithInput
pre_scaling_impl = AccumulatorAwareParameterPreScaling
accumulator_bit_width = 32 # default maximum accumulator width is 32 bits
normalize_stats_impl = L1Norm # required to align with derivations in paper
per_channel_pre_norm = AccumulatorAwarePerChannelPreNorm
normalize_stats_impl = PerChannelL1Norm.normalize_stats_impl # required to align with derivations in paper
float_to_int_impl = RoundToZeroSte # required to ensure no upwards rounding violates constraints
accumulator_bit_width = 32 # default maximum accumulator width is 32 bits

@value
def accumulator_bit_width_impl(accumulator_bit_width):
return BitWidthStatefulConst(accumulator_bit_width)


class AccumulatorAwareZeroCenterWeightQuant(AccumulatorAwareWeightQuant):
Expand All @@ -423,10 +481,11 @@ class AccumulatorAwareZeroCenterWeightQuant(AccumulatorAwareWeightQuant):
(1) added zero-centering constraint on the weights (i.e., `PreZeroCenterZeroPoint`)
(2) a more relaxed l1-norm bound that is derived in the referenced paper
"""
pre_scaling_impl = AccumulatorAwareZeroCenterParameterPreScaling
pre_zero_point_impl = PreZeroCenterZeroPoint
pre_zero_point_shape = this.scaling_shape # TODO: decouple zero_point from scaling
pre_zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl
per_channel_pre_norm = AccumulatorAwareZeroCenterPerChannelPreNorm

@value
def pre_zero_point_impl():
return this.per_channel_pre_norm.pre_zero_point_impl


class MSESymmetricScaleSubInjector(ExtendedInjector):
Expand Down
25 changes: 23 additions & 2 deletions tests/brevitas/export/quant_module_fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch
from torch import nn

from brevitas.inject.enum import ScalingPerOutputType
from brevitas.nn import QuantConv1d
from brevitas.nn import QuantConv2d
from brevitas.nn import QuantConv3d
Expand All @@ -20,6 +21,7 @@
from brevitas.quant.fixed_point import Int8WeightPerChannelFixedPoint
from brevitas.quant.fixed_point import Int8WeightPerTensorFixedPoint
from brevitas.quant.scaled_int import Int8AccumulatorAwareWeightQuant
from brevitas.quant.scaled_int import Int8AccumulatorAwareZeroCenterWeightQuant
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
from brevitas.quant.scaled_int import Int8BiasPerTensorFloatInternalScaling
from brevitas.quant.scaled_int import Int8WeightPerChannelFloat
Expand All @@ -39,21 +41,33 @@
KERNEL_SIZE = 3
TOLERANCE = 1


class Int8AccumulatorawareZeroCenterWeightQuantPerTensorFloat(
Int8AccumulatorAwareZeroCenterWeightQuant):
scaling_per_output_type = ScalingPerOutputType.TENSOR


A2Q_QUANTIZERS = {
'a2q_per_channel_float': (Int8AccumulatorAwareWeightQuant, Int8ActPerTensorFloat),
'a2q_plus_per_tensor_float':
(Int8AccumulatorawareZeroCenterWeightQuantPerTensorFloat, Int8ActPerTensorFloat)}

QUANTIZERS = {
'asymmetric_per_tensor_float':
(ShiftedUint8WeightPerTensorFloat, ShiftedUint8ActPerTensorFloat),
'symmetric_per_tensor_float': (Int8WeightPerTensorFloat, Int8ActPerTensorFloat),
'asymmetric_per_channel_float':
(ShiftedUint8WeightPerChannelFloat, ShiftedUint8ActPerTensorFloat),
'symmetric_per_channel_float': (Int8WeightPerChannelFloat, Int8ActPerTensorFloat),
'a2q': (Int8AccumulatorAwareWeightQuant, Int8ActPerTensorFloat),
'symmetric_per_tensor_fixed_point': (Int8WeightPerTensorFixedPoint, Int8ActPerTensorFixedPoint),
'symmetric_per_channel_fixed_point':
(Int8WeightPerChannelFixedPoint, Int8ActPerTensorFixedPoint)}
(Int8WeightPerChannelFixedPoint, Int8ActPerTensorFixedPoint),
**A2Q_QUANTIZERS}

BIAS_QUANTIZERS = {
'bias_external_scale': (Int32Bias,),
'bias_internal_scale': (Int8BiasPerTensorFloatInternalScaling,)}

QUANT_WBIOL_IMPL = [
QuantLinear,
QuantConv1d,
Expand All @@ -62,6 +76,7 @@
QuantConvTranspose1d,
QuantConvTranspose2d,
QuantConvTranspose3d,]

BIT_WIDTHS = [4, 8, 10] # below 8, equal 8, above 8
BIAS_BIT_WIDTHS = [8, 16, 32]

Expand Down Expand Up @@ -102,6 +117,12 @@ def weight_act_quantizers(quantizers):
return quantizers


@fixture
@parametrize('quantizers', A2Q_QUANTIZERS.items(), ids=list(A2Q_QUANTIZERS.keys()))
def a2q_weight_act_quantizers(quantizers):
return quantizers


@fixture
@parametrize('quantizer', BIAS_QUANTIZERS.items(), ids=list(BIAS_QUANTIZERS.keys()))
def bias_quantizer(quantizer):
Expand Down
37 changes: 36 additions & 1 deletion tests/brevitas/export/test_qonnx_export.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

import os

import torch

from brevitas.export import enable_debug
Expand All @@ -9,14 +11,15 @@
from brevitas.nn import QuantConv2d
from brevitas.nn import QuantIdentity
from brevitas.nn import QuantLinear
from brevitas.nn import QuantReLU
from brevitas.nn import TruncAvgPool2d
from brevitas.quant.scaled_int import Int4WeightPerTensorFloatDecoupled
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
from brevitas.quant.scaled_int import Int16Bias
from brevitas_examples import imagenet_classification
from tests.marker import jit_disabled_for_export

from .quant_module_fixture import *

OUT_CH = 50
IN_CH = 40
TOLERANCE = 1.1
Expand Down Expand Up @@ -48,6 +51,7 @@ def forward(self, x):
model(inp) # collect scale factors
model.eval()
export_qonnx(model, inp, export_path='generic_quant_linear.onnx')
os.remove('generic_quant_linear.onnx')


@jit_disabled_for_export()
Expand Down Expand Up @@ -79,6 +83,37 @@ def forward(self, x):
export_qonnx(model, inp, export_path='generic_decoupled_quant_linear.onnx')


@jit_disabled_for_export()
def test_a2q_quant_linear_export(a2q_weight_act_quantizers):
IN_SIZE = (2, IN_CH)

_, (weight_quant, io_quant) = a2q_weight_act_quantizers

class Model(torch.nn.Module):

def __init__(self):
super().__init__()
self.linear = QuantLinear(
out_features=OUT_CH,
in_features=IN_CH,
bias=True,
input_quant=io_quant,
output_quant=io_quant,
weight_quant=weight_quant,
bias_quant=Int16Bias,
return_quant_tensor=False)
self.linear.weight.data.uniform_(-0.1, 0.1)

def forward(self, x):
return self.linear(x)

inp = torch.randn(IN_SIZE)
model = Model()
model(inp) # collect scale factors
model.eval()
export_qonnx(model, inp, export_path='a2q_quant_linear.onnx')


@jit_disabled_for_export()
def test_generic_quant_conv_export():
IN_SIZE = (2, IN_CH, IN_CH, IN_CH)
Expand Down
20 changes: 19 additions & 1 deletion tests/brevitas/nn/nn_quantizers_fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from brevitas import torch_version
import brevitas.config as config
from brevitas.inject.enum import ScalingPerOutputType
from brevitas.nn import QuantConv1d
from brevitas.nn import QuantConv2d
from brevitas.nn import QuantConv3d
Expand Down Expand Up @@ -48,20 +49,37 @@
EMBED_DIM = 9
NUM_HEADS = 3


class Int8WeightNormL2PerChannelPerTensorFixedPoint(Int8WeightNormL2PerChannelFixedPoint):
scaling_per_output_type = ScalingPerOutputType.TENSOR


class Int8AccumulatorAwareWeightQuantPerTensorFloat(Int8AccumulatorAwareWeightQuant):
scaling_per_output_type = ScalingPerOutputType.TENSOR


class Int8AccumulatorawareZeroCenterWeightQuantPerTensorFloat(
Int8AccumulatorAwareZeroCenterWeightQuant):
scaling_per_output_type = ScalingPerOutputType.TENSOR


LSTM_WEIGHT_QUANTIZER = {
'None': None,
'quant_sym': Int8WeightPerTensorFloat,
'quant_asym': ShiftedUint8WeightPerTensorFloat}

A2Q_WBIOL_WEIGHT_QUANTIZER = {
'quant_a2q': Int8AccumulatorAwareWeightQuant,
'quant_a2q_plus': Int8AccumulatorAwareZeroCenterWeightQuant}
'quant_a2q_per_tensor': Int8AccumulatorAwareWeightQuantPerTensorFloat,
'quant_a2q_plus': Int8AccumulatorAwareZeroCenterWeightQuant,
'quant_a2q_plus_per_tensor': Int8AccumulatorawareZeroCenterWeightQuantPerTensorFloat}

WBIOL_WEIGHT_QUANTIZER = {
'None': None,
'quant_sym': Int8WeightPerTensorFloat,
'quant_asym': ShiftedUint8WeightPerTensorFloat,
'quant_decoupled': Int8WeightNormL2PerChannelFixedPoint,
'quant_decoupled_per_tensor': Int8WeightNormL2PerChannelPerTensorFixedPoint,
'quant_mx': MXInt8Weight,
'quant_float': Fp8e4m3WeightPerTensorFloat,
**A2Q_WBIOL_WEIGHT_QUANTIZER}
Expand Down
13 changes: 12 additions & 1 deletion tests/brevitas/nn/test_a2q.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,11 @@ def calc_a2q_plus_acc_bit_width(
return min_bit_width


calc_fnc = {"quant_a2q": calc_a2q_acc_bit_width, "quant_a2q_plus": calc_a2q_plus_acc_bit_width}
calc_fnc = {
"quant_a2q": calc_a2q_acc_bit_width,
"quant_a2q_per_tensor": calc_a2q_acc_bit_width,
"quant_a2q_plus": calc_a2q_plus_acc_bit_width,
"quant_a2q_plus_per_tensor": calc_a2q_plus_acc_bit_width}


@pytest_cases.parametrize_with_cases('model_input', cases=case_model_a2q)
Expand Down Expand Up @@ -94,6 +98,13 @@ def test_quant_wbiol_a2q(model_input, current_cases):

# the tensor quantizer requires a QuantTensor with specified bit-width and sign
quant_weight = model.conv.quant_weight(quant_input)

# test that the scaling factor is per-tensor or per-channel
if kwargs['weight_quant'].endswith('per_tensor'):
assert quant_weight.scale.numel() == 1
else:
assert quant_weight.scale.numel() == model.conv.out_channels

quant_weight = quant_weight.int().float()
if kwargs['model_type'] == 'QuantLinear': # shape = (out_features, in_features)
quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=1)
Expand Down

0 comments on commit 9048ecb

Please sign in to comment.