diff --git a/neural_compressor/torch/algorithms/mx_quant/mx.py b/neural_compressor/torch/algorithms/mx_quant/mx.py
index c98a47f09f4..76af3511e20 100644
--- a/neural_compressor/torch/algorithms/mx_quant/mx.py
+++ b/neural_compressor/torch/algorithms/mx_quant/mx.py
@@ -62,9 +62,6 @@ def apply_mx_specs(self):
             axes=[-1],
         )
 
-    def append_name(self, postfix):
-        self.name += postfix
-
     def forward(self, input):
         if self.mx_none:
             return super().forward(input)
diff --git a/test/3x/torch/algorithms/mx_quant/test_mx_utility.py b/test/3x/torch/algorithms/mx_quant/test_mx_utility.py
new file mode 100644
index 00000000000..f5cab3093fd
--- /dev/null
+++ b/test/3x/torch/algorithms/mx_quant/test_mx_utility.py
@@ -0,0 +1,22 @@
+import pytest
+import torch
+
+from neural_compressor.torch.algorithms.mx_quant import utils
+
+
+def test_mx_quant_utility():
+    tensor = torch.rand((1, 30))
+    assert torch.equal(tensor, utils.quantize_mx_op(tensor, None, "nearest", 32))
+    assert torch.equal(tensor, utils._quantize_fp(tensor))
+    assert torch.equal(tensor, utils._quantize_bfloat(tensor, 0))
+    assert torch.equal(tensor, utils._quantize_mx(tensor, 8, None))
+
+    assert not torch.equal(utils._shared_exponents(tensor, "none"), utils._shared_exponents(tensor))
+    with pytest.raises(Exception):
+        utils._shared_exponents(tensor, None)
+    with pytest.raises(Exception):
+        utils._reshape_to_blocks(tensor, None, 32)
+    with pytest.raises(Exception):
+        utils.quantize_elemwise_op(tensor, "test")
+    with pytest.raises(Exception):
+        utils._round_mantissa(tensor, 3, "test")
diff --git a/test/3x/torch/quantization/test_mx_quant.py b/test/3x/torch/quantization/test_mx_quant.py
index 9122e371235..88cb8923ebc 100644
--- a/test/3x/torch/quantization/test_mx_quant.py
+++ b/test/3x/torch/quantization/test_mx_quant.py
@@ -3,7 +3,7 @@
 import pytest
 import torch
 
-from neural_compressor.torch.quantization import MXQuantConfig, get_default_mx_config, quantize
+from neural_compressor.torch.quantization import MXQuantConfig, convert, get_default_mx_config, prepare
 
 
 def build_simple_torch_model():
@@ -40,20 +40,35 @@ def teardown_class(self):
     def test_mx_quant_default(self):
         fp32_model = copy.deepcopy(self.fp32_model)
         quant_config = get_default_mx_config()
-        q_model = quantize(fp32_model, quant_config=quant_config)
+        fp32_model = prepare(model=fp32_model, quant_config=quant_config)
+        q_model = convert(model=fp32_model)
         assert q_model is not None, "Quantization failed!"
 
     @pytest.mark.parametrize(
-        "w_dtype, weight_only",
+        "w_dtype, weight_only, round_method, out_dtype",
         [
-            ("fp4", True),
-            ("fp8_e5m2", False),
+            ("fp4", True, "dither", "float32"),
+            ("fp8_e5m2", False, "floor", "bfloat16"),
+            ("int8", False, "even", "float16"),
+            ("int4", False, "nearest", "float32"),
+            ("int2", False, "dither", "bfloat16"),
+            ("fp8_e4m3", False, "floor", "float16"),
+            ("fp6_e3m2", False, "even", "float32"),
+            ("fp6_e2m3", False, "nearest", "bfloat16"),
+            ("float16", False, "dither", "float16"),
+            ("bfloat16", False, "floor", "float32"),
         ],
     )
-    def test_mx_quant_params(self, w_dtype, weight_only):
+    def test_mx_quant_params(self, w_dtype, weight_only, round_method, out_dtype):
         fp32_model = copy.deepcopy(self.fp32_model)
-        quant_config = MXQuantConfig(w_dtype=w_dtype, weight_only=weight_only)
-        q_model = quantize(fp32_model, quant_config=quant_config)
+        quant_config = MXQuantConfig(
+            w_dtype=w_dtype,
+            weight_only=weight_only,
+            round_method=round_method,
+            out_dtype=out_dtype,
+        )
+        fp32_model = prepare(model=fp32_model, quant_config=quant_config)
+        q_model = convert(model=fp32_model)
         assert q_model is not None, "Quantization failed!"
 
     def test_mx_quant_accuracy(self):
@@ -72,8 +87,10 @@ def forward(self, x):
         fp32_model = copy.deepcopy(model)
         fp32_model.linear.weight = torch.nn.Parameter(torch.tensor([[0.0, 1.0], [1.0, 0.0]]))
         example_inputs = torch.zeros(3, 2)
+
         quant_config = MXQuantConfig()
-        q_model = quantize(fp32_model, quant_config=quant_config)
+        fp32_model = prepare(model=fp32_model, quant_config=quant_config)
+        q_model = convert(model=fp32_model)
         output1 = fp32_model(example_inputs)
         output2 = q_model(example_inputs)
         # set a big atol to avoid random issue
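Note: a minimal sketch of the two-step prepare/convert flow that these updated tests exercise in place of the old quantize() call (the toy model and config values below are illustrative, not part of the patch):

    import torch

    from neural_compressor.torch.quantization import MXQuantConfig, convert, prepare

    # small throwaway model used only to demonstrate the API
    model = torch.nn.Sequential(torch.nn.Linear(30, 50), torch.nn.ReLU(), torch.nn.Linear(50, 30))

    # single-call quantize() is replaced by prepare() followed by convert()
    quant_config = MXQuantConfig(w_dtype="fp4", weight_only=True)
    model = prepare(model=model, quant_config=quant_config)
    q_model = convert(model=model)
    assert q_model is not None, "Quantization failed!"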