Feat (quant-channel-splitting): support channel splitting after quantization #912

Open · wants to merge 1 commit into base: dev
365 changes: 302 additions & 63 deletions src/brevitas/graph/channel_splitting.py

Large diffs are not rendered by default.

65 changes: 57 additions & 8 deletions src/brevitas/graph/equalize.py
@@ -20,6 +20,19 @@
from brevitas.graph.base import ModuleInstanceToModuleInstance
from brevitas.graph.utils import get_module
from brevitas.graph.utils import get_node
from brevitas.nn import QuantConv1d
from brevitas.nn import QuantConv2d
from brevitas.nn import QuantConv3d
from brevitas.nn import QuantConvTranspose1d
from brevitas.nn import QuantConvTranspose2d
from brevitas.nn import QuantConvTranspose3d
from brevitas.nn import QuantHardTanh
from brevitas.nn import QuantIdentity
from brevitas.nn import QuantLinear
from brevitas.nn import QuantMultiheadAttention
from brevitas.nn import QuantReLU
from brevitas.nn import QuantSigmoid
from brevitas.nn import QuantTanh
from brevitas.nn.equalized_layer import EqualizedModule
from brevitas.nn.quant_scale_bias import ScaleBias
from brevitas.utils.torch_utils import KwargsForwardHook
@@ -35,11 +48,19 @@
nn.ConvTranspose1d,
nn.ConvTranspose2d,
nn.ConvTranspose3d,
QuantConvTranspose1d,
QuantConvTranspose2d,
QuantConvTranspose3d,
nn.MultiheadAttention,
QuantMultiheadAttention,
nn.Conv1d,
nn.Conv2d,
nn.Conv3d,
QuantConv1d,
QuantConv2d,
QuantConv3d,
nn.Linear,
QuantLinear,
nn.LayerNorm,
nn.BatchNorm1d,
nn.BatchNorm2d,
@@ -85,6 +106,8 @@

_ignore_ops = (getattr, 'size')

_quant_modules = (QuantReLU, QuantIdentity, QuantHardTanh, QuantTanh, QuantSigmoid)


# Start and End identify the starting and ending channels of the weight matrix that need to be
# equalized.
@@ -644,9 +667,26 @@ def _is_supported_module(graph_model: GraphModule, node: Node) -> bool:
return False


def _is_quant_module(graph_model: GraphModule, node: Node):
module = get_module(graph_model, node.target)
return isinstance(module, _quant_modules)


def _get_act_impl(graph_module: GraphModule, node: Node):
module = get_module(graph_module, node.target)
# we know it is a quant module, so just access the proxies
# should be the act_quant.fused_activation_quant_proxy.activation_impl
return module.act_quant.fused_activation_quant_proxy.activation_impl


def _is_scale_invariant_module(graph_model: GraphModule, node: Node) -> bool:
return node.op == 'call_module' and isinstance(
get_module(graph_model, node.target), _scale_invariant_layers)
# if quant module, we need to check the activation impl
if node.op == 'call_module':
if _is_quant_module(graph_model, node):
# if its quant, check the call impl
act_impl = _get_act_impl(graph_model, node)
return isinstance(act_impl, _scale_invariant_layers)
return isinstance(get_module(graph_model, node.target), _scale_invariant_layers)


def _is_scale_varying_activation(graph_model, node):
@@ -667,20 +707,29 @@ def _is_reshaping_op(node: Node) -> bool:

def get_weight_source(module):
transpose = lambda weight, axis: weight if axis == 0 else weight.transpose(0, 1)
if isinstance(module, nn.MultiheadAttention) and not hasattr(module, 'out_proj'):
if isinstance(
module,
(nn.MultiheadAttention, QuantMultiheadAttention)) and not hasattr(module, 'out_proj'):
raise RuntimeError("Configuration for Multiheadattention not supported")
weight = module.out_proj.weight if isinstance(module, nn.MultiheadAttention) else module.weight
weight = module.out_proj.weight if isinstance(
module, (nn.MultiheadAttention, QuantMultiheadAttention)) else module.weight
axis = _get_output_axis(module)
weight = transpose(weight, axis)
return weight


def get_weight_sink(module):
transpose = lambda weight, axis: weight if axis == 0 else weight.transpose(0, 1)
if isinstance(module, nn.MultiheadAttention) and not hasattr(module, 'in_proj_weight'):
raise RuntimeError("Configuration for Multiheadattention not supported")
weight = WeightBiasWrapper(module.in_proj_weight).weight if isinstance(
module, nn.MultiheadAttention) else module.weight
if isinstance(module, nn.MultiheadAttention):
if not hasattr(module, 'in_proj_weight'):
raise RuntimeError("Configuration for Multiheadattention not supported")
weight = WeightBiasWrapper(module.in_proj_weight).weight
elif isinstance(module, QuantMultiheadAttention):
if not hasattr(module.in_proj, 'weight'):
raise RuntimeError("Configuration for Multiheadattention not supported")
weight = WeightBiasWrapper(module.in_proj.weight).weight
else:
weight = module.weight
axis = _get_input_axis(module)
weight = transpose(weight, axis)
return weight
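The equalize.py changes above teach region detection to look through Brevitas quant activation wrappers: `_is_quant_module` spots them and `_get_act_impl` pulls out the wrapped float activation, so a `QuantReLU` counts as scale-invariant exactly when its underlying `nn.ReLU` would. A minimal sketch of that resolution, assuming the proxy layout referenced in the diff (not code from this PR):

```python
import torch.nn as nn

from brevitas.nn import QuantReLU

quant_act = QuantReLU()
# The wrapped float activation sits behind the activation quant proxy,
# as in _get_act_impl above.
act_impl = quant_act.act_quant.fused_activation_quant_proxy.activation_impl
# With this PR, the node is treated as scale-invariant iff the wrapped impl is.
print(isinstance(act_impl, nn.ReLU))  # expected: True
```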
14 changes: 8 additions & 6 deletions src/brevitas/graph/gpfq.py
@@ -170,6 +170,8 @@ def update_batch(self, module, input, current_layer):
if len(inp.shape) > 2:
inp = inp.reshape((-1, sum(inp.shape[2:])))
# For QuantLinear layer, groups will be 1
if isinstance(inp, QuantTensor):
inp = inp.value
inp_processed = inp.unsqueeze(0)

if isinstance(self.layer, SUPPORTED_CONV_OP):
@@ -252,14 +254,14 @@ def single_layer_update(self):
self.float_input = self.float_input.to(dev)
self.quantized_input = self.quantized_input.to(dev)
# We don't need full Hessian, we just need the diagonal
self.H_diag = self.quantized_input.transpose(2, 1).square().sum(
2) # summing over Batch dimension
H_diag = self.quantized_input.transpose(2,
1).square().sum(2) # summing over Batch dimension
permutation_list = []
for group_index in range(self.groups):
if self.act_order:
# Re-order Hessian_diagonal so that weights associated to
# higher magnitude activations are quantized first
perm = torch.argsort(self.H_diag[group_index, :], descending=True)
perm = torch.argsort(H_diag[group_index, :], descending=True)
else:
# No permutation, permutation tensor is a ordered index
perm = torch.tensor(range(weight.shape[-1]), device=dev)
@@ -368,14 +370,14 @@ def single_layer_update(self):
z = torch.zeros(weight.shape[:-1], device=dev)

# We don't need full Hessian, we just need the diagonal
self.H_diag = self.quantized_input.transpose(2, 1).square().sum(
2) # summing over Batch dimension
H_diag = self.quantized_input.transpose(2,
1).square().sum(2) # summing over Batch dimension
permutation_list = []
for group_index in range(self.groups):
if self.act_order:
# Re-order Hessian_diagonal so that weights associated to
# higher magnitude activations are quantized first
perm = torch.argsort(self.H_diag[group_index, :], descending=True)
perm = torch.argsort(H_diag[group_index, :], descending=True)
else:
# No permutation, permutation tensor is a ordered index
perm = torch.tensor(range(weight.shape[-1]), device=dev)
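In gpfq.py the Hessian diagonal becomes a local `H_diag` instead of an attribute, and `QuantTensor` inputs are unwrapped to their `.value` before batching. The diagonal itself is just the per-channel sum of squared activations, which is all the `act_order` reordering needs. A small numeric check of that identity, with made-up shapes rather than the PR's code path:

```python
import torch

# (groups, batch, channels); sizes are illustrative
X = torch.randn(1, 8, 4)
# Same expression as in the diff: sum the squares over the batch dimension
H_diag = X.transpose(2, 1).square().sum(2)           # shape (groups, channels)
# It matches the diagonal of the full Hessian proxy X^T X for each group
full_diag = torch.diagonal(X[0].t() @ X[0])
print(torch.allclose(H_diag[0], full_diag))          # expected: True
# act_order: quantize the weights tied to the largest activations first
perm = torch.argsort(H_diag[0], descending=True)
```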
23 changes: 17 additions & 6 deletions src/brevitas/graph/quantize.py
@@ -1,6 +1,8 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Callable, Optional

from torch import nn

from brevitas import config
@@ -9,6 +11,8 @@
from brevitas.fx.brevitas_tracer import symbolic_trace
from brevitas.graph.base import ModuleToModuleByClass
from brevitas.graph.channel_splitting import GraphChannelSplitting
from brevitas.graph.channel_splitting import split_evenly
from brevitas.graph.equalize import _channel_maxabs
from brevitas.graph.equalize import EqualizeGraph
from brevitas.graph.fixed_point import CollapseConsecutiveConcats
from brevitas.graph.fixed_point import MergeBatchNorm
@@ -291,9 +295,13 @@ def preprocess_for_quantize(
merge_bn=True,
equalize_bias_shrinkage: str = 'vaiq',
equalize_scale_computation: str = 'maxabs',
channel_splitting_ratio: float = 0.0,
apply_channel_splitting: bool = False,
channel_splitting_layer_split_perc_func: Callable = lambda x: 0.02,
channel_splitting_region_filter_func: Callable = lambda x,
y: True,
channel_splitting_split_input: bool = True,
channel_splitting_criterion: str = 'maxabs'):
channel_splitting_split_func: Callable = split_evenly,
channel_splitting_split_criterion_func: Callable = _channel_maxabs):

training_state = model.training
model.eval()
@@ -315,11 +323,14 @@
merge_bias=equalize_merge_bias,
bias_shrinkage=equalize_bias_shrinkage,
scale_computation_type=equalize_scale_computation).apply(model)
if channel_splitting_ratio > 0:
if apply_channel_splitting:
# not setting quant_split_func since we're preprocessing for quantization
model = GraphChannelSplitting(
split_ratio=channel_splitting_ratio,
split_criterion=channel_splitting_criterion,
split_input=channel_splitting_split_input).apply(model)
layer_split_perc_func=channel_splitting_layer_split_perc_func,
region_filter_func=channel_splitting_region_filter_func,
split_criterion_func=channel_splitting_split_criterion_func,
split_input=channel_splitting_split_input,
split_func=channel_splitting_split_func).apply(model)
model.train(training_state)
return model

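With the quantize.py changes, `preprocess_for_quantize` takes callables instead of the old `channel_splitting_ratio` / `channel_splitting_criterion` scalars. A hedged usage sketch; the torchvision model and the 2% per-layer split are illustrative choices, not taken from the PR:

```python
from torchvision.models import resnet18

from brevitas.graph.quantize import preprocess_for_quantize

model = preprocess_for_quantize(
    resnet18(),
    apply_channel_splitting=True,
    # split 2% of the channels of every eligible layer
    channel_splitting_layer_split_perc_func=lambda module: 0.02,
    # keep every candidate (sources, sinks) region
    channel_splitting_region_filter_func=lambda srcs, sinks: True,
    channel_splitting_split_input=True)
```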
@@ -20,6 +20,7 @@
from brevitas import __version__ as brevitas_version
from brevitas import config
from brevitas import torch_version
from brevitas.graph.channel_splitting import GraphChannelSplitting
from brevitas.graph.quantize import preprocess_for_quantize
from brevitas.graph.target.flexml import preprocess_for_flexml_quantize
from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_act_equalization
@@ -101,12 +102,13 @@ def unique(sequence):
'gpfq': [False], # Enable/Disable GPFQ
'gpfa2q': [False], # Enable/Disable GPFA2Q
'gpfq_p': [1.0], # GPFQ P
'gpxq_act_order': [False], # Use act_order euristics for GPxQ
'gpxq_act_order': [True], # Use act_order euristics for GPxQ
'accumulator_bit_width': [16], # Accumulator bit width, only in combination with GPFA2Q
'act_quant_percentile': [99.999], # Activation Quantization Percentile
'uint_sym_act_for_unsigned_values': [True], # Whether to use unsigned act quant when possible
'channel_splitting_ratio': [0.0], # Channel Splitting ratio, 0.0 means no splitting
'split_input': [True], # Whether to split the input channels when applying channel splitting
'quant_channel_splitting_ratio': [0.0], # channel splitting after quantizing
'split_input': [False], # Whether to split the input channels when applying channel splitting
'merge_bn': [True]} # Whether to merge BN layers

parser = argparse.ArgumentParser(description='PyTorch ImageNet PTQ Validation')
@@ -214,7 +216,9 @@ def ptq_torchvision_models(args):
equalize_iters=config_namespace.graph_eq_iterations,
equalize_merge_bias=config_namespace.graph_eq_merge_bias,
merge_bn=config_namespace.merge_bn,
channel_splitting_ratio=config_namespace.channel_splitting_ratio,
apply_channel_splitting=config_namespace.channel_splitting_ratio > 0,
channel_splitting_layer_split_perc_func=lambda x: config_namespace.
channel_splitting_ratio,
channel_splitting_split_input=config_namespace.split_input)
else:
raise RuntimeError(f"{config_namespace.target_backend} backend not supported.")
@@ -252,6 +256,14 @@ def ptq_torchvision_models(args):
quant_model = quant_model.cuda(args.gpu)
cudnn.benchmark = False

# apply channel splitting here after quantizing the model
if config_namespace.quant_channel_splitting_ratio > 0:
print("Applying Quant Channel Splitting...")
quant_model = GraphChannelSplitting(
region_filter_func=(lambda x, y: True),
layer_split_perc_func=(lambda x: config_namespace.quant_channel_splitting_ratio),
split_input=config_namespace.split_input).apply(quant_model)

# Calibrate the quant_model on the calibration dataloader
print("Starting calibration")
calibrate(calib_loader, quant_model)
@@ -271,7 +283,7 @@
quant_model,
p=config_namespace.gpfq_p,
act_order=config_namespace.gpxq_act_order,
gpfa2q=config_namespace.gpfa2q,
use_gpfa2q=config_namespace.gpfa2q,
accumulator_bit_width=config_namespace.accumulator_bit_width)

if config_namespace.gptq:
@@ -372,8 +384,10 @@ def validate_config(config_namespace):
is_valid = False
if config_namespace.act_exponent_bit_width + config_namespace.act_mantissa_bit_width != config_namespace.act_bit_width - 1:
is_valid = False
if config_namespace.channel_splitting_ratio > 0 and config_namespace.quant_channel_splitting_ratio > 0:
is_valid = False
# if channel splitting is disabled, no need for split input
if not config_namespace.channel_splitting_ratio:
if config_namespace.channel_splitting_ratio == 0 and config_namespace.quant_channel_splitting_ratio == 0:
config_namespace.split_input = None

config_namespace.is_valid = is_valid
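The benchmark script exercises the new post-quantization path: when `quant_channel_splitting_ratio` is set, `GraphChannelSplitting` is applied to the already-quantized model, and `validate_config` rejects configurations that enable both float-time and quant-time splitting. A hedged sketch of that flow outside the benchmark harness, using the default graph quantization entry point and an illustrative 5% split (neither is prescribed by the PR):

```python
from torchvision.models import resnet18

from brevitas.graph.channel_splitting import GraphChannelSplitting
from brevitas.graph.quantize import preprocess_for_quantize
from brevitas.graph.quantize import quantize

# Quantize first (default maps), then split channels on the quantized graph
quant_model = quantize(preprocess_for_quantize(resnet18()))
quant_model = GraphChannelSplitting(
    region_filter_func=lambda srcs, sinks: True,       # keep all regions
    layer_split_perc_func=lambda module: 0.05,         # split 5% of channels per layer
    split_input=False).apply(quant_model)
```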