3.x SQ autotune supports calib_func w/ capture input (#1821)
Signed-off-by: Cheng, Zixuan <[email protected]>
violetch24 authored May 30, 2024
1 parent 7120dd4 commit 5dafe5f
Showing 3 changed files with 137 additions and 22 deletions.
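
In short, this change lets SmoothQuant auto-alpha tuning run from a user calibration function alone: the inputs that the function feeds to the model are captured and replayed as a dataloader. Below is a minimal usage sketch mirroring the test added in this commit; the Model and run_fn definitions are placeholders taken from the tests, and the neural_compressor.torch.quantization import path is assumed from the surrounding test files rather than shown in this diff.

import torch

# Import path assumed to match what the tests below use; not part of this diff.
from neural_compressor.torch.quantization import SmoothQuantConfig, quantize


class Model(torch.nn.Module):  # toy model, mirroring the test fixture
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(3, 4)
        self.fc2 = torch.nn.Linear(4, 3)

    def forward(self, x):
        return self.fc2(self.fc1(x))


example_inputs = torch.rand([1, 3])


def run_fn(model):
    # Calibration function: no dataloader is ever handed to the quantizer.
    for _ in range(100):
        model(example_inputs)


quant_config = SmoothQuantConfig(alpha="auto", do_blockwise=False, folding=False)
q_model = quantize(Model(), quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
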
68 changes: 57 additions & 11 deletions neural_compressor/torch/algorithms/smooth_quant/utility.py
@@ -21,8 +21,8 @@
import intel_extension_for_pytorch as ipex
import numpy
import torch
import tqdm
from packaging.version import Version
from tqdm import tqdm

from neural_compressor.torch.algorithms.static_quant import (
    CpuInfo,
@@ -78,6 +78,9 @@ def get_quantizable_ops_recursively(model, example_inputs, alpha, act_algo, inpl

    from torch.ao.quantization import MinMaxObserver

    if alpha == "auto": # for quantize API
        alpha = 0.5

    if ipex_ver.release >= Version("2.1.1").release:
        static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(
            alpha=alpha, act_observer=MinMaxObserver
@@ -390,6 +393,9 @@ def forward_wrapper(model, input, device=torch.device("cpu")): # pragma: no cov
            output = model(*input)
        except:
            output = model(input)
    elif isinstance(input, zip):
        for args, kwargs in input:
            output = model(*args, **kwargs)
    else:
        output = model(input)
    return output
@@ -412,6 +418,43 @@ def model_forward(model, dataloader, iters, device): # pragma: no cover
                break


def build_captured_dataloader(model, run_fn, calib_num=None):
    class CapturedDataloader:
        def __init__(self, args_list, kwargs_list) -> None:
            self.args_list = args_list
            self.kwargs_list = kwargs_list

        def __iter__(self):
            for args, kwargs in zip(self.args_list, self.kwargs_list):
                if not args:
                    yield kwargs
                elif not kwargs:
                    yield args
                else:
                    yield args, kwargs

    class InputCaptureModule(torch.nn.Module):
        def __init__(self, model) -> None:
            super().__init__()
            self.args_list = []
            self.kwargs_list = []
            self.orig_model = model
            self.iters = 0
            self.calib_num = calib_num

        def forward(self, *args, **kwargs):
            if self.iters < self.calib_num:
                self.args_list.append(args)
                self.kwargs_list.append(kwargs)
                self.iters += 1

    captured_model = InputCaptureModule(model)
    run_fn(captured_model)
    dataloader = CapturedDataloader(captured_model.args_list, captured_model.kwargs_list)
    model = captured_model.orig_model
    return model, dataloader


def cal_scale(input_max_abs, weights, alpha, weight_max_lb=1e-5): # pragma: no cover
    weights = torch.cat(weights, dim=0)
    weight_max = torch.max(torch.abs(weights), dim=0)[0]
@@ -1349,14 +1392,15 @@ def _auto_tune_alpha(self):
        best_alphas = self.init_alpha

        if not self.dataloader:
            logger.info(f"Auto-tuning failed due to no dataloader, using {best_alphas} instead.")
            self._qdq_model_unwrapper_for_auto()
            return best_alphas
            logger.info("No dataloader, performing auto-tuning with calibration function instead.")
            self.model, self.dataloader = build_captured_dataloader(self.model, self.q_func, self.calib_sample_num)

        bar = tqdm(self.dataloader, total=self.calib_sample_num, desc="auto tune alpha") # pylint: disable=E1102
        for input in bar:
            if isinstance(input, tuple) or isinstance(input, list):
                if len(input) == 2:
                    input, _ = input # Extract input when both input and label are yielded by dataloader.

            loss_alphas = {}
            best_alphas_per_module = best_alphas
            if isinstance(best_alphas, dict):
@@ -1374,8 +1418,9 @@ def _auto_tune_alpha(self):
                    cur_loss = loss_alphas[key]
                    for alpha_key in cur_loss.keys():
                        cur_loss[alpha_key] += loss_tmp[key][alpha_key]
            total_cnt += self.dataloader.batch_size
            tmp_cnt += self.dataloader.batch_size

            total_cnt += 1
            tmp_cnt += 1
            if tmp_cnt // multiply_factor >= 1:
                alpha_update_iter += 1
                tmp_cnt = 0
@@ -1418,13 +1463,14 @@ def _auto_tune_alpha_blockwise(self):
        best_alphas = self.init_alpha

        if not self.dataloader:
            logger.info(f"Auto-tuning failed due to no dataloader, using {best_alphas} instead.")
            self._qdq_model_unwrapper_for_auto()
            return best_alphas
            logger.info("No dataloader, performing auto-tuning with calibration function instead.")
            self.model, self.dataloader = build_captured_dataloader(self.model, self.q_func, self.calib_sample_num)

        bar = tqdm(self.dataloader, total=self.calib_sample_num, desc="auto tune alpha") # pylint: disable=E1102
        for input in bar:
            if isinstance(input, tuple): # Extract input when both input and label are yielded by dataloader.
                input = input[0]

            loss_alphas = {}
            best_alphas_per_module = best_alphas
            if isinstance(best_alphas, dict):
@@ -1446,8 +1492,8 @@ def _auto_tune_alpha_blockwise(self):
                        for alpha_key in cur_loss.keys():
                            cur_loss[alpha_key] += loss_tmp[block_name][alpha_key]

            total_cnt += self.dataloader.batch_size
            tmp_cnt += self.dataloader.batch_size
            total_cnt += 1
            tmp_cnt += 1
            if tmp_cnt // multiply_factor >= 1:
                alpha_update_iter += 1
                tmp_cnt = 0
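
For reference, the capture-and-replay flow implemented by build_captured_dataloader above can be exercised on its own. The following standalone sketch uses an illustrative toy model and calibration function that are not part of the commit; only the build_captured_dataloader import is taken from the code added here.

import torch

from neural_compressor.torch.algorithms.smooth_quant import build_captured_dataloader


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(3, 3)

    def forward(self, x, scale=1.0):
        return self.fc(x) * scale


def calib_func(model):
    # The calibration function only calls the model; there is no dataloader.
    for _ in range(4):
        model(torch.randn(1, 3), scale=0.5)


toy = ToyModel()
# The wrapper records every (args, kwargs) pair that calib_func passes to forward(),
# then returns the original model together with a replayable dataloader.
toy, dataloader = build_captured_dataloader(toy, calib_func, calib_num=32)

for item in dataloader:
    # Both positional and keyword inputs were captured, so each item is an
    # (args, kwargs) pair. (The forward_wrapper change above adds a zip branch
    # so zipped args/kwargs lists can be replayed as model(*args, **kwargs).)
    args, kwargs = item
    print(args[0].shape, kwargs)  # torch.Size([1, 3]) {'scale': 0.5}
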
37 changes: 37 additions & 0 deletions test/3x/torch/algorithms/smooth_quant/test_sq_utility.py
@@ -0,0 +1,37 @@
import copy

import pytest
import torch


class Model(torch.nn.Module):
    device = torch.device("cpu")

    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = torch.nn.Linear(3, 4)
        self.fc2 = torch.nn.Linear(4, 3)

    def forward(self, x):
        out = self.fc1(x)
        out = self.fc2(out)
        return out


model = Model()


def test_captured_dataloader():
    from neural_compressor.torch.algorithms.smooth_quant import build_captured_dataloader

    fp32_model = copy.deepcopy(model)

    def run_fn(model):
        for i in range(10):
            example_inputs = torch.randn([1, 3])
            model(example_inputs)

    tmp_model, dataloader = build_captured_dataloader(fp32_model, run_fn, calib_num=32)
    assert tmp_model == fp32_model, "Model should be same after building dataloader. Please check."
    assert isinstance(dataloader.args_list[0][0], torch.Tensor), "Args list should contain tensors. Please check."
    assert not dataloader.kwargs_list[0], "Kwargs list should be empty. Please check."
54 changes: 43 additions & 11 deletions test/3x/torch/quantization/test_smooth_quant.py
@@ -26,10 +26,12 @@ def forward(self, x):


model = Model()
example_inputs = torch.rand([1, 3])


def run_fn(model):
    model(torch.randn([1, 3]))
    for i in range(10):
        model(example_inputs)


class TestSmoothQuant:
@@ -40,7 +42,6 @@ def teardown_class(self):
    def test_smooth_quant_default(self):
        fp32_model = copy.deepcopy(model)
        quant_config = get_default_sq_config()
        example_inputs = torch.randn([1, 3])
        prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
        run_fn(prepared_model)
        q_model = convert(prepared_model)
@@ -57,7 +58,6 @@ def test_smooth_quant_default(self):
    def test_smooth_quant_fallback(self):
        fp32_model = copy.deepcopy(model)
        quant_config = get_default_sq_config()
        example_inputs = torch.randn([1, 3])
        # fallback by op_type
        quant_config.set_local(torch.nn.Linear, SmoothQuantConfig(w_dtype="fp32", act_dtype="fp32"))
        prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
@@ -87,10 +87,6 @@ def test_sq_linear_params(self, act_sym, act_algo, alpha, folding, scale_sharing
        quant_config = SmoothQuantConfig(
            act_sym=act_sym, act_algo=act_algo, alpha=alpha, folding=folding, scale_sharing=scale_sharing
        )
        example_inputs = torch.zeros([1, 3])

        def run_fn(model):
            model(example_inputs)

        prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
        run_fn(prepared_model)
@@ -102,7 +98,6 @@ def run_fn(model):

    @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
    def test_sq_ipex_accuracy(self):
        example_inputs = torch.zeros([1, 3])
        qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5)
        user_model = copy.deepcopy(model)
        user_model = ipex.quantization.prepare(user_model.eval(), qconfig, example_inputs=example_inputs, inplace=True)
@@ -144,7 +139,6 @@ def run_fn(model):
    def test_sq_save_load(self):
        fp32_model = copy.deepcopy(model)
        quant_config = get_default_sq_config()
        example_inputs = torch.zeros([1, 3])
        prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
        run_fn(prepared_model)
        q_model = convert(prepared_model)
@@ -171,7 +165,6 @@ def test_sq_save_load(self):
    def test_smooth_quant_with_quantize_API(self):
        fp32_model = copy.deepcopy(model)
        quant_config = get_default_sq_config()
        example_inputs = torch.randn([1, 3])
        q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
        assert q_model is not None, "Quantization failed!"

@@ -184,7 +177,6 @@ def test_smooth_quant_with_quantize_API(self):
    def test_smooth_quant_mixed_precision(self):
        fp32_model = copy.deepcopy(model)
        quant_config = get_default_sq_config() # do mixed_precison by default.
        example_inputs = torch.randn([1, 3])

        # prepare/convert API
        prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
@@ -203,3 +195,43 @@ def test_smooth_quant_mixed_precision(self):
        quant_config.folding = True
        q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
        assert q_model is not None, "Quantization failed!"

    @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
    def test_smooth_quant_auto(self):
        fp32_model = copy.deepcopy(model)
        example_inputs = torch.rand([1, 3])

        def run_fn(model):
            for i in range(100):
                model(example_inputs)

        # block-wise
        quant_config = SmoothQuantConfig(
            alpha="auto",
            alpha_min=0.45,
            alpha_max=0.55,
            alpha_step=0.01,
            shared_criterion="mean",
            do_blockwise=True,
            folding=False,
        )
        q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
        assert q_model is not None, "Quantization failed!"
        output1 = fp32_model(example_inputs)
        output2 = q_model(example_inputs)
        assert torch.allclose(output1, output2, atol=2e-2), "Accuracy gap atol > 0.02 is unexpected. Please check."

        # layer-wise
        quant_config = SmoothQuantConfig(
            alpha="auto",
            alpha_min=0.45,
            alpha_max=0.55,
            alpha_step=0.01,
            shared_criterion="max",
            do_blockwise=False,
            folding=False,
        )
        q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
        assert q_model is not None, "Quantization failed!"
        output2 = q_model(example_inputs)
        assert torch.allclose(output1, output2, atol=2e-2), "Accuracy gap atol > 0.02 is unexpected. Please check."
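
The new paths can be exercised end to end by running the added tests. A minimal invocation sketch using pytest's Python entry point follows; the file paths are those listed in this commit and are assumed to be relative to the repository root.

import pytest

# Run the captured-dataloader unit test plus the new auto-tune smooth-quant test.
pytest.main(
    [
        "-q",
        "test/3x/torch/algorithms/smooth_quant/test_sq_utility.py",
        "test/3x/torch/quantization/test_smooth_quant.py::TestSmoothQuant::test_smooth_quant_auto",
    ]
)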
