From 562b94adf7695b0015a8faa4ad5a1d504e1f6043 Mon Sep 17 00:00:00 2001
From: Mengni Wang
Date: Thu, 6 Jun 2024 17:26:23 +0800
Subject: [PATCH 1/3] Add UT and remove unused code for torch MX quant

Signed-off-by: Mengni Wang
---
 .../algorithms/mx_quant/test_mx_utility.py  | 21 +++++++++++
 test/3x/torch/quantization/test_mx_quant.py | 36 +++++++++++++------
 2 files changed, 47 insertions(+), 10 deletions(-)
 create mode 100644 test/3x/torch/algorithms/mx_quant/test_mx_utility.py

diff --git a/test/3x/torch/algorithms/mx_quant/test_mx_utility.py b/test/3x/torch/algorithms/mx_quant/test_mx_utility.py
new file mode 100644
index 00000000000..0e633a8a5d6
--- /dev/null
+++ b/test/3x/torch/algorithms/mx_quant/test_mx_utility.py
@@ -0,0 +1,21 @@
+import pytest
+import torch
+
+from neural_compressor.torch.algorithms.mx_quant import utils
+
+def test_mx_quant_utility():
+    tensor = torch.rand((1, 30))
+    assert torch.equal(tensor, utils.quantize_mx_op(tensor, None, "nearest", 32))
+    assert torch.equal(tensor, utils._quantize_fp(tensor))
+    assert torch.equal(tensor, utils._quantize_bfloat(tensor, 0))
+    assert torch.equal(tensor, utils._quantize_mx(tensor, 8, None))
+
+    assert not torch.equal(utils._shared_exponents(tensor, "none"), utils._shared_exponents(tensor))
+    with pytest.raises(Exception):
+        utils._shared_exponents(tensor, None)
+    with pytest.raises(Exception):
+        utils._reshape_to_blocks(tensor, None, 32)
+    with pytest.raises(Exception):
+        utils.quantize_elemwise_op(tensor, "test")
+    with pytest.raises(Exception):
+        utils._round_mantissa(tensor, 3, "test")
diff --git a/test/3x/torch/quantization/test_mx_quant.py b/test/3x/torch/quantization/test_mx_quant.py
index 9122e371235..e348797f507 100644
--- a/test/3x/torch/quantization/test_mx_quant.py
+++ b/test/3x/torch/quantization/test_mx_quant.py
@@ -3,8 +3,7 @@
 import pytest
 import torch
 
-from neural_compressor.torch.quantization import MXQuantConfig, get_default_mx_config, quantize
-
+from neural_compressor.torch.quantization import MXQuantConfig, prepare, convert, get_default_mx_config
 
 def build_simple_torch_model():
     class Model(torch.nn.Module):
@@ -40,20 +39,35 @@ def teardown_class(self):
     def test_mx_quant_default(self):
         fp32_model = copy.deepcopy(self.fp32_model)
         quant_config = get_default_mx_config()
-        q_model = quantize(fp32_model, quant_config=quant_config)
+        fp32_model = prepare(model=fp32_model, quant_config=quant_config)
+        q_model = convert(model=fp32_model)
         assert q_model is not None, "Quantization failed!"
 
     @pytest.mark.parametrize(
-        "w_dtype, weight_only",
+        "w_dtype, weight_only, round_method, out_dtype",
         [
-            ("fp4", True),
-            ("fp8_e5m2", False),
+            ("fp4", True, "dither", "float32"),
+            ("fp8_e5m2", False, "floor", "bfloat16"),
+            ("int8", False, "even", "float16"),
+            ("int4", False, "nearest", "float32"),
+            ("int2", False, "dither", "bfloat16"),
+            ("fp8_e4m3", False, "floor", "float16"),
+            ("fp6_e3m2", False, "even", "float32"),
+            ("fp6_e2m3", False, "nearest", "bfloat16"),
+            ("float16", False, "dither", "float16"),
+            ("bfloat16", False, "floor", "float32"),
         ],
     )
-    def test_mx_quant_params(self, w_dtype, weight_only):
+    def test_mx_quant_params(self, w_dtype, weight_only, round_method, out_dtype):
         fp32_model = copy.deepcopy(self.fp32_model)
-        quant_config = MXQuantConfig(w_dtype=w_dtype, weight_only=weight_only)
-        q_model = quantize(fp32_model, quant_config=quant_config)
+        quant_config = MXQuantConfig(
+            w_dtype=w_dtype,
+            weight_only=weight_only,
+            round_method=round_method,
+            out_dtype=out_dtype,
+        )
+        fp32_model = prepare(model=fp32_model, quant_config=quant_config)
+        q_model = convert(model=fp32_model)
         assert q_model is not None, "Quantization failed!"
 
     def test_mx_quant_accuracy(self):
@@ -72,8 +86,10 @@ def forward(self, x):
         fp32_model = copy.deepcopy(model)
         fp32_model.linear.weight = torch.nn.Parameter(torch.tensor([[0.0, 1.0], [1.0, 0.0]]))
         example_inputs = torch.zeros(3, 2)
+
         quant_config = MXQuantConfig()
-        q_model = quantize(fp32_model, quant_config=quant_config)
+        fp32_model = prepare(model=fp32_model, quant_config=quant_config)
+        q_model = convert(model=fp32_model)
         output1 = fp32_model(example_inputs)
         output2 = q_model(example_inputs)
         # set a big atol to avoid random issue

From 8cd538a64c4e55a035d5c6c8fe9b5562a6307117 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 6 Jun 2024 09:30:32 +0000
Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 test/3x/torch/algorithms/mx_quant/test_mx_utility.py | 1 +
 test/3x/torch/quantization/test_mx_quant.py          | 3 ++-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/test/3x/torch/algorithms/mx_quant/test_mx_utility.py b/test/3x/torch/algorithms/mx_quant/test_mx_utility.py
index 0e633a8a5d6..f5cab3093fd 100644
--- a/test/3x/torch/algorithms/mx_quant/test_mx_utility.py
+++ b/test/3x/torch/algorithms/mx_quant/test_mx_utility.py
@@ -3,6 +3,7 @@
 
 from neural_compressor.torch.algorithms.mx_quant import utils
 
+
 def test_mx_quant_utility():
     tensor = torch.rand((1, 30))
     assert torch.equal(tensor, utils.quantize_mx_op(tensor, None, "nearest", 32))
diff --git a/test/3x/torch/quantization/test_mx_quant.py b/test/3x/torch/quantization/test_mx_quant.py
index e348797f507..88cb8923ebc 100644
--- a/test/3x/torch/quantization/test_mx_quant.py
+++ b/test/3x/torch/quantization/test_mx_quant.py
@@ -3,7 +3,8 @@
 import pytest
 import torch
 
-from neural_compressor.torch.quantization import MXQuantConfig, prepare, convert, get_default_mx_config
+from neural_compressor.torch.quantization import MXQuantConfig, convert, get_default_mx_config, prepare
+
 
 def build_simple_torch_model():
     class Model(torch.nn.Module):

From 802c7a445f6f7a323079cdcb6b4459885995678b Mon Sep 17 00:00:00 2001
From: Mengni Wang
Date: Thu, 6 Jun 2024 17:31:02 +0800
Subject: [PATCH 3/3] update

Signed-off-by: Mengni Wang
---
 neural_compressor/torch/algorithms/mx_quant/mx.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/neural_compressor/torch/algorithms/mx_quant/mx.py b/neural_compressor/torch/algorithms/mx_quant/mx.py
index c98a47f09f4..76af3511e20 100644
--- a/neural_compressor/torch/algorithms/mx_quant/mx.py
+++ b/neural_compressor/torch/algorithms/mx_quant/mx.py
@@ -62,9 +62,6 @@ def apply_mx_specs(self):
             axes=[-1],
         )
 
-    def append_name(self, postfix):
-        self.name += postfix
-
     def forward(self, input):
         if self.mx_none:
             return super().forward(input)
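
Note for reviewers: beyond adding utility-function unit tests and dropping the
unused append_name helper, these patches migrate the MX quant tests from the
old one-shot quantize() call to the two-step prepare()/convert() flow. The
sketch below illustrates that flow end to end; it is a minimal example, not
part of the patch. The ToyModel class and its layer shapes are hypothetical,
while the prepare/convert signatures and the MXQuantConfig keyword arguments
mirror exactly what the updated tests exercise.

import torch

from neural_compressor.torch.quantization import MXQuantConfig, convert, prepare


class ToyModel(torch.nn.Module):
    # Hypothetical two-layer model, standing in for self.fp32_model in the tests.
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(30, 50)
        self.fc2 = torch.nn.Linear(50, 30)

    def forward(self, x):
        return self.fc2(self.fc1(x))


model = ToyModel()
# The keyword arguments below are the ones swept by the new parametrized
# test_mx_quant_params (w_dtype, weight_only, round_method, out_dtype);
# the specific values chosen here are one combination from the diff above.
quant_config = MXQuantConfig(w_dtype="fp4", weight_only=True, round_method="dither", out_dtype="float32")
# Two-step flow that replaces the former quantize(fp32_model, quant_config=...) call:
model = prepare(model=model, quant_config=quant_config)
q_model = convert(model=model)
output = q_model(torch.rand(1, 30))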