diff --git a/onnxruntime/python/tools/quantization/calibrate.py b/onnxruntime/python/tools/quantization/calibrate.py
index d0db57c392961..77b3dce9fb004 100644
--- a/onnxruntime/python/tools/quantization/calibrate.py
+++ b/onnxruntime/python/tools/quantization/calibrate.py
@@ -5,6 +5,7 @@
 # license information.
 # --------------------------------------------------------------------------
 import abc
+import copy
 import itertools
 import os
 import uuid
@@ -21,6 +22,48 @@
 from .quant_utils import apply_plot, load_model_with_shape_infer, smooth_distribution
 
 
+def rel_entr(pk: np.ndarray, qk: np.ndarray) -> np.ndarray:
+    """
+    See https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.rel_entr.html#scipy.special.rel_entr.
+    Python implementation.
+    """
+    res = np.empty(pk.shape, dtype=pk.dtype)
+    res[:] = pk[:] * np.log(pk[:] / qk[:])
+    c1 = (pk > 0) & (qk > 0)
+    res[~c1] = np.inf
+    c2 = (pk == 0) & (qk >= 0)
+    res[c2] = 0
+    return res
+
+
+def entropy(
+    pk: np.ndarray,
+    qk: np.ndarray,
+    base: Optional[float] = None,
+    axis: int = 0,
+) -> np.ndarray:
+    """
+    Simplified version of entropy.
+    Source: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html.
+    This avoids taking a dependency on scipy just for this function.
+    """
+    assert base is None or base > 0, f"base={base} must be a positive number or `None`."
+    assert qk is not None, "qk is None"
+
+    pk = np.asarray(pk).astype(np.float32)
+    pk = 1.0 * pk / np.sum(pk, axis=axis, keepdims=True)
+
+    qk = np.asarray(qk).astype(np.float32)
+    pk, qk = np.broadcast_arrays(pk, qk)
+    qk = 1.0 * qk / np.sum(qk, axis=axis, keepdims=True)
+    vec = rel_entr(pk, qk)
+
+    s = np.sum(vec, axis=axis)
+    if base is not None:
+        s /= np.log(base)
+    return s.astype(pk.dtype)
+
+
 class TensorData:
     _allowed = frozenset(["avg", "std", "lowest", "highest", "hist", "hist_edges", "bins"])
     _floats = frozenset(["avg", "std", "lowest", "highest", "hist_edges"])
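Note: the `rel_entr`/`entropy` pair above is intended as a drop-in replacement for `scipy.stats.entropy` in the KL-divergence calibration path, so the tooling no longer needs scipy at runtime. A quick sanity check of the equivalence, with made-up `pk`/`qk` values (illustrative only, not part of the diff):

```python
import numpy as np

from onnxruntime.quantization.calibrate import entropy

# Two made-up, strictly positive distributions over 4 bins.
pk = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32)
qk = np.array([0.4, 0.3, 0.2, 0.1], dtype=np.float32)

# KL divergence D(pk || qk) = sum(pk * log(pk / qk)), roughly 0.456 here.
print(entropy(pk, qk))

try:
    # Optional cross-check against scipy, when it happens to be installed.
    from scipy.stats import entropy as scipy_entropy

    assert abs(float(entropy(pk, qk)) - float(scipy_entropy(pk, qk))) < 1e-5
except ImportError:
    pass  # scipy being absent is exactly the case this change enables
```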
Please run collect() first.") - print(f"Finding optimal threshold for each tensor using {self.method} algorithm ...") + print(f"Finding optimal threshold for each tensor using {self.method!r} algorithm ...") if self.method == "entropy": return self.compute_entropy() @@ -938,7 +983,14 @@ def compute_distribution(self): assert avg_coef.dtype != np.float64 assert std_coef.dtype != np.float64 assert hist_edges.dtype != np.float64 - thresholds_dict[tensor] = TensorData(avg=avg_coef, std=std_coef, hist=hist, hist_edges=hist_edges) + thresholds_dict[tensor] = TensorData( + avg=avg_coef, + std=std_coef, + hist=hist, + hist_edges=hist_edges, + lowest=hist_edges.min(), + highest=hist_edges.max(), + ) # Plot histogram for debug only if os.environ.get("QUANTIZATION_DEBUG", 0) in (1, "1"): @@ -952,18 +1004,15 @@ def get_entropy_threshold(self, histogram, num_quantized_bins): `q` is a truncated version of the original distribution. Ref: http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf """ - import copy - - from scipy.stats import entropy - hist = histogram[0] hist_edges = histogram[1] num_bins = hist.size zero_bin_index = num_bins // 2 num_half_quantized_bin = num_quantized_bins // 2 + dtype = histogram[1].dtype kl_divergence = np.zeros(zero_bin_index - num_half_quantized_bin + 1) - thresholds = [(0, 0) for i in range(kl_divergence.size)] + thresholds = [(np.array(0, dtype=dtype), np.array(0, dtype=dtype)) for i in range(kl_divergence.size)] # <------------ num bins ----------------> # <--- quantized bins ----> @@ -983,10 +1032,7 @@ def get_entropy_threshold(self, histogram, num_quantized_bins): start_index = zero_bin_index - i end_index = zero_bin_index + i + 1 if (zero_bin_index + i + 1) <= num_bins else num_bins - thresholds[i - num_half_quantized_bin] = ( - float(hist_edges[start_index]), - float(hist_edges[end_index]), - ) + thresholds[i - num_half_quantized_bin] = (hist_edges[start_index], hist_edges[end_index]) sliced_distribution = copy.deepcopy(hist[start_index:end_index]) @@ -1020,15 +1066,15 @@ def get_entropy_threshold(self, histogram, num_quantized_bins): norm = sum(nonzeros[start:end]) if norm != 0: - q[start:end] = float(quantized_bins[index]) / float(norm) + q[start:end] = quantized_bins[index] / norm p = smooth_distribution(p) q = smooth_distribution(q) - - if isinstance(q, np.ndarray): - kl_divergence[i - num_half_quantized_bin] = entropy(p, q) + if p is None or q is None: + div = np.array(np.inf, dtype=dtype) else: - kl_divergence[i - num_half_quantized_bin] = float("inf") + div = np.array(entropy(p, q), dtype=dtype) + kl_divergence[i - num_half_quantized_bin] = div min_kl_divergence_idx = np.argmin(kl_divergence) optimal_threshold = thresholds[min_kl_divergence_idx] @@ -1038,6 +1084,8 @@ def get_entropy_threshold(self, histogram, num_quantized_bins): optimal_threshold = (min_value, optimal_threshold[1]) if optimal_threshold[1] > max_value: optimal_threshold = (optimal_threshold[0], max_value) + assert hasattr(optimal_threshold[0], "dtype") + assert hasattr(optimal_threshold[1], "dtype") return optimal_threshold diff --git a/onnxruntime/python/tools/quantization/quant_utils.py b/onnxruntime/python/tools/quantization/quant_utils.py index 68c2b3bf79c8b..036f49b420734 100644 --- a/onnxruntime/python/tools/quantization/quant_utils.py +++ b/onnxruntime/python/tools/quantization/quant_utils.py @@ -653,7 +653,7 @@ def smooth_distribution(p, eps=0.0001): if not n_nonzeros: # raise ValueError('The discrete probability distribution is malformed. 
diff --git a/onnxruntime/python/tools/quantization/quant_utils.py b/onnxruntime/python/tools/quantization/quant_utils.py
index 68c2b3bf79c8b..036f49b420734 100644
--- a/onnxruntime/python/tools/quantization/quant_utils.py
+++ b/onnxruntime/python/tools/quantization/quant_utils.py
@@ -653,7 +653,7 @@ def smooth_distribution(p, eps=0.0001):
 
     if not n_nonzeros:
         # raise ValueError('The discrete probability distribution is malformed. All entries are 0.')
-        return -1
+        return None
     eps1 = eps * float(n_zeros) / float(n_nonzeros)
     assert eps1 < 1.0, "n_zeros=%d, n_nonzeros=%d, eps1=%f" % (
         n_zeros,
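Since `smooth_distribution` now returns `None` rather than the sentinel `-1` for an all-zero input, callers have to branch on `None` explicitly, as `get_entropy_threshold` above does. A minimal sketch of the contract (the all-zero input is contrived, and the `else` branch is a placeholder computation):

```python
import numpy as np

from onnxruntime.quantization.quant_utils import smooth_distribution

p = np.zeros(16, dtype=np.float32)  # contrived: every histogram bin empty

smoothed = smooth_distribution(p)
if smoothed is None:
    # Degenerate distribution: treat this candidate threshold as unusable,
    # mirroring the `div = np.array(np.inf, dtype=dtype)` branch above.
    kl_divergence = np.inf
else:
    kl_divergence = float(np.sum(smoothed * np.log(smoothed)))  # placeholder use
print(kl_divergence)
```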
diff --git a/onnxruntime/test/python/quantization/test_op_matmul.py b/onnxruntime/test/python/quantization/test_op_matmul.py
index 344583aa7c624..91368bd643158 100644
--- a/onnxruntime/test/python/quantization/test_op_matmul.py
+++ b/onnxruntime/test/python/quantization/test_op_matmul.py
@@ -10,13 +10,39 @@
 import numpy as np
 import onnx
 import packaging.version as pv
+from numpy.testing import assert_almost_equal
 from onnx import TensorProto, helper
 from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_qtype_by_node_type
 
+from onnxruntime.capi.onnxruntime_pybind11_state import Fail
 from onnxruntime.quantization import CalibrationMethod, QuantFormat, QuantType, quantize_dynamic, quantize_static
+from onnxruntime.quantization.calibrate import entropy
+
+
+def skip_if_new_opset_exception_raised(func):
+    def wrapper(*args, **kwargs):
+        try:
+            func(*args, **kwargs)
+        except Fail as e:
+            if "is under development and support for this is limited" in str(e):
+                raise unittest.SkipTest(f"Skipped {func} due to opset under development.")  # noqa: B904
+            raise
+
+    return wrapper
 
 
 class TestOpMatMul(unittest.TestCase):
+    def test_entropy(self):
+        try:
+            from scipy.stats import entropy as scipy_entropy
+        except ImportError:
+            raise unittest.SkipTest("scipy not installed.")  # noqa: B904
+        pk = (np.arange(10) - 5).astype(np.float32) / 10
+        qk = -(np.arange(10) - 5).astype(np.float32) / 10
+        ent = scipy_entropy(pk, qk)
+        get = entropy(pk, qk)
+        assert_almost_equal(ent, get)
+
     def input_feeds(self, n, name2shape, dtype):
         input_data_list = []
         for _i in range(n):
@@ -324,10 +350,11 @@ def test_quantize_matmul_u8u8(self):
     @unittest.skipIf(
         pv.Version(onnx.__version__) < pv.Version("1.15.1"), reason="Shape inference bug, see onnx PR #5709"
     )
+    @skip_if_new_opset_exception_raised
     def test_quantize_matmul_u8u8_f16(self):
-        self.quantize_matmul_u8u8(onnx.TensorProto.FLOAT16, 19, 9)
+        self.quantize_matmul_u8u8(onnx.TensorProto.FLOAT16, 21, 9)
 
-    def quantize_matmul_s8s8(self, tt, opset, ir_version):
+    def quantize_matmul_s8s8(self, tt, opset, ir_version, calibrate_method=CalibrationMethod.MinMax):
         np.random.seed(1)
         model_fp_path = "matmul_fp.onnx"
         self.construct_model_matmul(model_fp_path, tensor_type=tt, opset=opset, ir_version=ir_version)
@@ -341,6 +368,7 @@ def quantize_matmul_s8s8(self, tt, opset, ir_version):
             activation_type=QuantType.QInt8,
             weight_type=QuantType.QInt8,
             extra_options={"ActivationSymmetric": True},
+            calibrate_method=calibrate_method,
         )
         self.static_quant_test_qdq(
             model_fp_path,
@@ -348,6 +376,7 @@ def quantize_matmul_s8s8(self, tt, opset, ir_version):
             activation_type=QuantType.QInt8,
             weight_type=QuantType.QInt8,
             extra_options={"ActivationSymmetric": True},
+            calibrate_method=calibrate_method,
         )
 
         # dynamic quantization doesn't support activation:int8
@@ -357,11 +386,42 @@ def quantize_matmul_s8s8(self, tt, opset, ir_version):
     def test_quantize_matmul_s8s8(self):
         self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT, 18, 8)
 
+    def test_quantize_matmul_s8s8_entropy(self):
+        self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT, 18, 8, calibrate_method=CalibrationMethod.Entropy)
+
+    def test_quantize_matmul_s8s8_percentile(self):
+        self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT, 18, 8, calibrate_method=CalibrationMethod.Percentile)
+
+    def test_quantize_matmul_s8s8_distribution(self):
+        self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT, 18, 8, calibrate_method=CalibrationMethod.Distribution)
+
     @unittest.skipIf(
         pv.Version(onnx.__version__) < pv.Version("1.15.1"), reason="Shape inference bug, see onnx PR #5709"
     )
+    @skip_if_new_opset_exception_raised
     def test_quantize_matmul_s8s8_f16(self):
-        self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT16, 19, 9)
+        self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT16, 21, 9)
+
+    @unittest.skipIf(
+        pv.Version(onnx.__version__) < pv.Version("1.15.1"), reason="Shape inference bug, see onnx PR #5709"
+    )
+    @skip_if_new_opset_exception_raised
+    def test_quantize_matmul_s8s8_f16_entropy(self):
+        self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT16, 21, 9, calibrate_method=CalibrationMethod.Entropy)
+
+    @unittest.skipIf(
+        pv.Version(onnx.__version__) < pv.Version("1.15.1"), reason="Shape inference bug, see onnx PR #5709"
+    )
+    @skip_if_new_opset_exception_raised
+    def test_quantize_matmul_s8s8_f16_percentile(self):
+        self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT16, 21, 9, calibrate_method=CalibrationMethod.Percentile)
+
+    @unittest.skipIf(
+        pv.Version(onnx.__version__) < pv.Version("1.15.1"), reason="Shape inference bug, see onnx PR #5709"
+    )
+    @skip_if_new_opset_exception_raised
+    def test_quantize_matmul_s8s8_f16_distribution(self):
+        self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT16, 21, 9, calibrate_method=CalibrationMethod.Distribution)
 
     def quantize_matmul_e4m3fn_same(self, tt, opset, ir_version):
         np.random.seed(1)
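For context, the pattern the new tests exercise, written out as a standalone script. `RandomDataReader` and the model paths are placeholders, not part of this PR; only `calibrate_method` and the `CalibrationMethod` values are the actual API under test:

```python
import numpy as np

from onnxruntime.quantization import (
    CalibrationDataReader,
    CalibrationMethod,
    QuantType,
    quantize_static,
)


class RandomDataReader(CalibrationDataReader):
    """Placeholder reader feeding a few random batches for calibration."""

    def __init__(self, input_name="input", shape=(1, 8), n=5):
        self._feeds = iter({input_name: np.random.rand(*shape).astype(np.float32)} for _ in range(n))

    def get_next(self):
        return next(self._feeds, None)


# Every calibration method now runs through the same dtype-preserving path.
for method in (
    CalibrationMethod.MinMax,
    CalibrationMethod.Entropy,
    CalibrationMethod.Percentile,
    CalibrationMethod.Distribution,
):
    quantize_static(
        "matmul_fp.onnx",  # hypothetical input model, as in the tests
        f"matmul_{method.name}_qdq.onnx",
        RandomDataReader(),
        activation_type=QuantType.QInt8,
        weight_type=QuantType.QInt8,
        extra_options={"ActivationSymmetric": True},
        calibrate_method=method,
    )
```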