Skip to content

Commit

Permalink
update SmoothQuant algorithm with folding choice (#799)
Browse files Browse the repository at this point in the history
Signed-off-by: Xin He <[email protected]>
Co-authored-by: wenhuach21 <[email protected]>
  • Loading branch information
xin3he and wenhuach21 authored Apr 13, 2023
1 parent 7a7cfe5 commit 6a39f64
Show file tree
Hide file tree
Showing 9 changed files with 745 additions and 176 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ def evaluate(self, model):
total += label.size(0)
hit += (pred == label).sum().item()
if index % args.log_frequency == 0:
print(hit / total)
print(hit / total, flush=True)
index += 1
acc = hit / total
print(acc)
print(acc, flush=True)
return acc


Expand Down Expand Up @@ -145,7 +145,7 @@ def eval_func(model):
recipes = {}
if args.sq:
recipes = {"smooth_quant": True, "smooth_quant_args": {'alpha': args.alpha}}
op_type_dict = None
op_type_dict = {}
if args.kl:
op_type_dict = {'linear': {'activation': {'algorithm': ['kl']}}}
if args.fallback_add:
Expand Down
3 changes: 2 additions & 1 deletion neural_compressor/adaptor/onnxrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def __init__(self, framework_specific_info):

self.optype_statistics = None

def smooth_quant(self, model, dataloader, iterations, tune_cfg, alpha=0.5,
def smooth_quant(self, model, dataloader, iterations, tune_cfg, alpha=0.5, folding=False,
percentile=99.999, op_types=['MatMul', 'Linear', 'Conv'], scales_per_op=True):
"""Get augmented model with smooth quant.
Expand All @@ -162,6 +162,7 @@ def smooth_quant(self, model, dataloader, iterations, tune_cfg, alpha=0.5,
iterations: iterations
tune_cfg: quantization config
alpha: smooth alpha in SmoothQuant, 1.0 will fallback to SPIQ
folding: whether insert mul(False) or just allow foldable layers(True) for SmoothQuant
percentile:Percentile of calibration to remove outliers
op_types: The op types whose input tensor will be dumped
scales_per_op: True, each op will have an individual scale, mainly for accuracy
Expand Down
272 changes: 236 additions & 36 deletions neural_compressor/adaptor/pytorch.py

Large diffs are not rendered by default.

106 changes: 106 additions & 0 deletions neural_compressor/adaptor/torch_utils/model_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
#
# -*- coding: utf-8 -*-
#
# Copyright (c) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Torch.nn.Module Class Defination."""
# Note: Do not import this file unless you have already imported torch,
# since the model classes inherit torch.nn.Module.
import torch
from packaging.version import Version


def get_torch_version():
try:
torch_version = torch.__version__.split('+')[0]
except ValueError as e: # pragma: no cover
assert False, 'Got an unknown version of torch: {}'.format(e)
version = Version(torch_version)
return version

PT_VERSION = get_torch_version().release


class QDQLinear(torch.nn.Module):
def __init__(self, module, scale, zero_point, dtype):
super().__init__()
if PT_VERSION < Version("1.13.0").release:
import torch.nn.quantized as nnq
else:
import torch.ao.nn.quantized as nnq
self.add_module('quant', nnq.Quantize(scale, zero_point, dtype))
self.add_module('dequant', nnq.DeQuantize())
self.add_module('module', module)
self.qdq_weight()

def forward(self, X):
X = self.quant(X)
X = self.dequant(X)
X = self.module(X)
return X

def qdq_weight(self):
# update weight w/ QDQ
from .smooth_quant import quant_dequant_w
weith_qdq = quant_dequant_w(self.module)
self.module.weight = torch.nn.Parameter(weith_qdq)


class SQLinearWrapper(torch.nn.Module):
def __init__(self, module, input_scale, input_minmax, dtype=torch.quint8):
super().__init__()
self.input_scale = input_scale
self.dtype = dtype
# calculate and only save scale, zero_point to avoid memory usage
self.scale, self.zero_point = self._calculate_qparams(input_scale, input_minmax, dtype)
self.add_module('sq_linear', module)
self.ipex = False # a flag used for ipex inference

def forward(self, X):
if self.ipex:
X = self.sq_linear(X)
else:
X = torch.mul(X, self.input_scale)
X = self.sq_linear(X)
return X

def _calculate_qparams(self, input_scale, input_minmax, dtype=torch.quint8):
# calculate scale and zero_point
if dtype == torch.quint8:
quant_min, quant_max = 0, 255
min_val = torch.min(input_minmax[0] * input_scale)
max_val = torch.max(input_minmax[1] * input_scale)
# work when min_val bigger than zero.
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
scale = torch.max(scale, torch.tensor([torch.finfo(torch.float32).eps]))
zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
zero_point = torch.clamp(zero_point, quant_min, quant_max)
return scale, zero_point

def _get_weight_scale(self):
# get weight scale and zero_point
from torch.ao.quantization.observer import default_per_channel_weight_observer
obs = default_per_channel_weight_observer()
obs(self.sq_linear.weight)
scale, _ = obs.calculate_qparams()
return scale

def _recover_sq_linear(self):
# remove mul and reset sq_linear for ipex inference
scale = self.input_scale.view(1, self.input_scale.shape[0])
with torch.no_grad():
self.sq_linear.weight *= scale
Loading

0 comments on commit 6a39f64

Please sign in to comment.