Add UT and remove unused code for torch MX quant (intel#1854)
* Add UT and remove unused code for torch MX quant
---------

Change-Id: I2727aa716fa99467fa2d63b966de4d88470e4bb3
Signed-off-by: Mengni Wang <[email protected]>
mengniwang95 authored and Eran Geva committed Jun 18, 2024
1 parent 647905a commit 31d8bb9
Showing 3 changed files with 48 additions and 12 deletions.
3 changes: 0 additions & 3 deletions neural_compressor/torch/algorithms/mx_quant/mx.py
@@ -62,9 +62,6 @@ def apply_mx_specs(self):
             axes=[-1],
         )
 
-    def append_name(self, postfix):
-        self.name += postfix
-
     def forward(self, input):
         if self.mx_none:
             return super().forward(input)
22 changes: 22 additions & 0 deletions test/3x/torch/algorithms/mx_quant/test_mx_utility.py
@@ -0,0 +1,22 @@
+import pytest
+import torch
+
+from neural_compressor.torch.algorithms.mx_quant import utils
+
+
+def test_mx_quant_utility():
+    tensor = torch.rand((1, 30))
+    assert torch.equal(tensor, utils.quantize_mx_op(tensor, None, "nearest", 32))
+    assert torch.equal(tensor, utils._quantize_fp(tensor))
+    assert torch.equal(tensor, utils._quantize_bfloat(tensor, 0))
+    assert torch.equal(tensor, utils._quantize_mx(tensor, 8, None))
+
+    assert not torch.equal(utils._shared_exponents(tensor, "none"), utils._shared_exponents(tensor))
+    with pytest.raises(Exception):
+        utils._shared_exponents(tensor, None)
+    with pytest.raises(Exception):
+        utils._reshape_to_blocks(tensor, None, 32)
+    with pytest.raises(Exception):
+        utils.quantize_elemwise_op(tensor, "test")
+    with pytest.raises(Exception):
+        utils._round_mantissa(tensor, 3, "test")
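
For orientation, here is a minimal sketch of how one of these utilities can be exercised directly outside pytest. It uses only a call that appears in the new test; the tensor shape and the block size of 32 are illustrative values taken from that test, not recommendations.

import torch

from neural_compressor.torch.algorithms.mx_quant import utils

# Illustrative input; the (1, 30) shape mirrors the test above.
weight = torch.rand((1, 30))

# With the element dtype set to None, quantize_mx_op is a pass-through,
# which is exactly what the first assertion in test_mx_quant_utility checks.
out = utils.quantize_mx_op(weight, None, "nearest", 32)
assert torch.equal(weight, out)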
35 changes: 26 additions & 9 deletions test/3x/torch/quantization/test_mx_quant.py
@@ -3,7 +3,7 @@
 import pytest
 import torch
 
-from neural_compressor.torch.quantization import MXQuantConfig, get_default_mx_config, quantize
+from neural_compressor.torch.quantization import MXQuantConfig, convert, get_default_mx_config, prepare
 
 
 def build_simple_torch_model():
@@ -40,20 +40,35 @@ def teardown_class(self):
     def test_mx_quant_default(self):
         fp32_model = copy.deepcopy(self.fp32_model)
         quant_config = get_default_mx_config()
-        q_model = quantize(fp32_model, quant_config=quant_config)
+        fp32_model = prepare(model=fp32_model, quant_config=quant_config)
+        q_model = convert(model=fp32_model)
         assert q_model is not None, "Quantization failed!"
 
     @pytest.mark.parametrize(
-        "w_dtype, weight_only",
+        "w_dtype, weight_only, round_method, out_dtype",
         [
-            ("fp4", True),
-            ("fp8_e5m2", False),
+            ("fp4", True, "dither", "float32"),
+            ("fp8_e5m2", False, "floor", "bfloat16"),
+            ("int8", False, "even", "float16"),
+            ("int4", False, "nearest", "float32"),
+            ("int2", False, "dither", "bfloat16"),
+            ("fp8_e4m3", False, "floor", "float16"),
+            ("fp6_e3m2", False, "even", "float32"),
+            ("fp6_e2m3", False, "nearest", "bfloat16"),
+            ("float16", False, "dither", "float16"),
+            ("bfloat16", False, "floor", "float32"),
         ],
     )
-    def test_mx_quant_params(self, w_dtype, weight_only):
+    def test_mx_quant_params(self, w_dtype, weight_only, round_method, out_dtype):
         fp32_model = copy.deepcopy(self.fp32_model)
-        quant_config = MXQuantConfig(w_dtype=w_dtype, weight_only=weight_only)
-        q_model = quantize(fp32_model, quant_config=quant_config)
+        quant_config = MXQuantConfig(
+            w_dtype=w_dtype,
+            weight_only=weight_only,
+            round_method=round_method,
+            out_dtype=out_dtype,
+        )
+        fp32_model = prepare(model=fp32_model, quant_config=quant_config)
+        q_model = convert(model=fp32_model)
         assert q_model is not None, "Quantization failed!"
 
     def test_mx_quant_accuracy(self):
@@ -72,8 +87,10 @@ def forward(self, x):
         fp32_model = copy.deepcopy(model)
         fp32_model.linear.weight = torch.nn.Parameter(torch.tensor([[0.0, 1.0], [1.0, 0.0]]))
         example_inputs = torch.zeros(3, 2)
+
         quant_config = MXQuantConfig()
-        q_model = quantize(fp32_model, quant_config=quant_config)
+        fp32_model = prepare(model=fp32_model, quant_config=quant_config)
+        q_model = convert(model=fp32_model)
         output1 = fp32_model(example_inputs)
         output2 = q_model(example_inputs)
         # set a big atol to avoid random issue
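
As a usage reference, the prepare/convert flow that these updated tests exercise might look roughly like the following outside pytest. The model definition and the chosen config values are illustrative assumptions mirroring one row of the parametrized test above; they are not part of this commit.

import copy

import torch

from neural_compressor.torch.quantization import MXQuantConfig, convert, prepare

# Any torch.nn.Module works here; this tiny stack is only for illustration.
model = torch.nn.Sequential(torch.nn.Linear(30, 50), torch.nn.Linear(50, 30))

# Values mirror the ("int8", False, "even", "float16") row of the parametrized test.
quant_config = MXQuantConfig(
    w_dtype="int8",
    weight_only=False,
    round_method="even",
    out_dtype="float16",
)

prepared_model = prepare(model=copy.deepcopy(model), quant_config=quant_config)
q_model = convert(model=prepared_model)
assert q_model is not None, "Quantization failed!"

The two-step prepare/convert calls replace the single quantize() entry point that the previous version of these tests used.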