Add set_local support for static quant with pt2e #1870

Merged · 9 commits · Jun 20, 2024
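This PR lets a PT2E static-quant recipe carve out per-op exceptions via `set_local`. A minimal usage sketch, mirroring the new test further down (the import path neural_compressor.torch.quantization is assumed, and model / example_inputs are placeholders):

from neural_compressor.torch.quantization import (
    StaticQuantConfig,
    get_default_static_config,
    quantize,
)

def calib_fn(model):
    # placeholder calibration loop over example_inputs
    for _ in range(4):
        model(*example_inputs)

quant_config = get_default_static_config()
# Keep the `fc1` submodule in fp32; the rest of the model is statically quantized.
quant_config.set_local("fc1", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
q_model = quantize(model=model, quant_config=quant_config, run_fn=calib_fn)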
Changes from 7 commits
@@ -137,7 +137,8 @@ def get_unquantized_node_set(gm: torch.fx.GraphModule):
for node in gm.graph.nodes:
if meta := getattr(node, "meta"):
if quantization_annotation := meta.get(xiq.QUANT_ANNOTATION_KEY):
if quantization_annotation._annotated:
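# Note: with this change a node whose annotation is just the bare default
# (_annotated=True, no input/output qspecs) is still treated as unquantized.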
none_annotation = xiq._X86InductorQuantizationAnnotation(_annotated=True)
if quantization_annotation != none_annotation:
continue
unquantized_node_set.add(node)
return unquantized_node_set
@@ -163,11 +164,11 @@ def _parse_node_candidate_set_from_user_config(config, gm):
op_name_filters = []
for op_type_name, config in op_type_configs.items():
op_type = getattr(torch.nn, op_type_name)
if config.act_dtype == "fp16":
if config.act_dtype == "fp16": # pragma: no cover
filter = xpq._get_module_type_filter(op_type)
op_type_filters.append(filter)
for op_name, config in op_name_configs.items():
if config.act_dtype == "fp16":
if config.act_dtype == "fp16": # pragma: no cover
filter = xpq._get_module_name_filter(op_name)
op_name_filters.append(filter)
node_set_from_user_config = set()
24 changes: 23 additions & 1 deletion neural_compressor/torch/algorithms/pt2e_quant/utility.py
@@ -20,6 +20,8 @@
from torch.ao.quantization.quantizer import QuantizationSpec
from torch.ao.quantization.quantizer.x86_inductor_quantizer import QuantizationConfig, X86InductorQuantizer

from neural_compressor.torch.utils import GT_TORCH_VERSION_2_3_2


def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=False) -> QuantizationSpec:
dtype_mapping: Dict[str, torch.dtype] = {"int8": torch.int8, "uint8": torch.uint8}
@@ -53,6 +55,9 @@ def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=False) -> QuantizationSpec:


def _map_inc_config_to_torch_quant_config(inc_config, is_dynamic=False) -> QuantizationConfig:
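# All-float dtypes mean "leave the op unquantized"; returning None is intended to
# make the quantizer that consumes this config skip quantization for the matching ops.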
NOT_QUANT_DTYPES = ["fp32", "fp16", "bf16"]
if inc_config.act_dtype in NOT_QUANT_DTYPES and inc_config.w_dtype in NOT_QUANT_DTYPES:
return None
default_quant_config = xiq.get_default_x86_inductor_quantization_config(is_dynamic=is_dynamic)
input_act_quant_spec = create_quant_spec_from_config(
inc_config.act_dtype, inc_config.act_sym, inc_config.act_granularity, inc_config.act_algo, is_dynamic=is_dynamic
@@ -75,5 +80,22 @@ def create_xiq_quantizer_from_pt2e_config(config, is_dynamic=False) -> X86InductorQuantizer:
# set global
global_config = _map_inc_config_to_torch_quant_config(config, is_dynamic)
quantizer.set_global(global_config)
# Skip the local config for now (need torch 2.4)
# needs torch > 2.3.2
if GT_TORCH_VERSION_2_3_2: # pragma: no cover
op_type_config_dict, op_name_config_dict = config._get_op_name_op_type_config()
if op_type_config_dict:
for op_type, config in op_type_config_dict.items():
_nn_module_type = getattr(torch.nn, op_type, None)
if _nn_module_type:
quantizer.set_module_type_qconfig(
_nn_module_type, _map_inc_config_to_torch_quant_config(config, is_dynamic)
)
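# Some op types may also have a functional counterpart under torch.nn.functional;
# when such an attribute exists, register the same config for it as well.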
_nn_func_type = getattr(torch.nn.functional, op_type, None)
if _nn_func_type:
quantizer.set_function_type_qconfig(
_nn_func_type, _map_inc_config_to_torch_quant_config(config, is_dynamic)
)
if op_name_config_dict:
for op_name, config in op_name_config_dict.items():
quantizer.set_module_name_qconfig(op_name, _map_inc_config_to_torch_quant_config(config, is_dynamic))
return quantizer
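For context, a rough sketch of how a quantizer produced by this helper is typically consumed by the PT2E flow (prepare_pt2e / convert_pt2e are the standard torch.ao.quantization.quantize_pt2e APIs; exported_model, config, and example_inputs are placeholders, and the calibration loop is illustrative only):

from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

quantizer = create_xiq_quantizer_from_pt2e_config(config)
prepared_model = prepare_pt2e(exported_model, quantizer)  # exported_model: an FX graph captured for PT2E
for _ in range(4):
    prepared_model(*example_inputs)  # calibration
converted_model = convert_pt2e(prepared_model)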
3 changes: 3 additions & 0 deletions neural_compressor/torch/utils/environ.py
@@ -91,6 +91,9 @@ def get_torch_version():
return version


GT_TORCH_VERSION_2_3_2 = get_torch_version() > Version("2.3.2")
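# True only when the installed torch is strictly newer than 2.3.2; used to gate the
# per-op-type / per-op-name quantizer configuration in pt2e_quant/utility.py.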


def get_accelerator(device_name="auto"):
global accelerator # update the global accelerator when calling this func
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
46 changes: 40 additions & 6 deletions test/3x/torch/quantization/test_pt2e_quant.py
@@ -17,7 +17,7 @@
prepare,
quantize,
)
from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, get_torch_version
from neural_compressor.torch.utils import GT_TORCH_VERSION_2_3_2, TORCH_VERSION_2_2_2, get_torch_version

torch.manual_seed(0)

@@ -119,6 +119,42 @@ def calib_fn(model):
logger.warning("out shape is %s", out.shape)
assert out is not None

@pytest.mark.skipif(not GT_TORCH_VERSION_2_3_2, reason="Requires torch>2.3.2")
def test_quantize_simple_model_with_set_local(self, force_not_import_ipex):
model, example_inputs = self.build_simple_torch_model_and_example_inputs()
float_model_output = model(*example_inputs)
quant_config = None

def calib_fn(model):
for i in range(4):
model(*example_inputs)

quant_config = get_default_static_config()
quant_config.set_local("fc1", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
q_model = quantize(model=model, quant_config=quant_config, run_fn=calib_fn)

# check the q/dq node count
expected_node_occurrence = {
# Only quantize the `fc2`
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
}
expected_node_occurrence = {
torch_test_quant_common.NodeSpec.call_function(k): v for k, v in expected_node_occurrence.items()
}
node_in_graph = self.get_node_in_graph(q_model)
for node, cnt in expected_node_occurrence.items():
assert node_in_graph.get(node, 0) == cnt, f"Node {node} should occur {cnt} times, but found {node_in_graph.get(node, 0)}"

from torch._inductor import config

config.freezing = True
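# Freezing is assumed to be required here so inductor can constant-fold the
# quantized weights and apply the x86 int8 fusion passes.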
q_model_out = q_model(*example_inputs)
assert torch.allclose(float_model_output, q_model_out, atol=1e-2), "Quantization failed!"
opt_model = torch.compile(q_model)
out = opt_model(*example_inputs)
assert out is not None

@pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0")
@pytest.mark.parametrize("is_dynamic", [False, True])
def test_prepare_and_convert_on_simple_model(self, is_dynamic, force_not_import_ipex):
@@ -193,9 +229,9 @@ def get_node_in_graph(graph_module):
nodes_in_graph[n] += 1
else:
nodes_in_graph[n] = 1
return
return nodes_in_graph

@pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0")
@pytest.mark.skipif(not GT_TORCH_VERSION_2_3_2, reason="Requires torch>2.3.2")
def test_mixed_fp16_and_int8(self, force_not_import_ipex):
model, example_inputs = self.build_model_include_conv_and_linear()
model = export(model, example_inputs=example_inputs)
@@ -221,9 +257,7 @@ def test_mixed_fp16_and_int8(self, force_not_import_ipex):
}
node_in_graph = self.get_node_in_graph(converted_model)
for node, cnt in expected_node_occurrence.items():
assert (
expected_node_occurrence.get(node, 0) == cnt
), f"Node {node} should occur {cnt} times, but {node_in_graph[node]}"
assert node_in_graph.get(node, 0) == cnt, f"Node {node} should occur {cnt} times, but found {node_in_graph.get(node, 0)}"

# inference
from torch._inductor import config