Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Quantization tool: support float 8 with MatMul, support float 16 weights #18043

Merged
merged 47 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
f09ad3c
Update quantization tools to support MatMul with float 8
xadupre Oct 20, 2023
bd47221
support float16
xadupre Oct 20, 2023
23c9e39
more consistent with types
xadupre Oct 23, 2023
7e7d0fe
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Oct 24, 2023
c073b88
fix types
xadupre Oct 24, 2023
35db149
fix many unit tests
xadupre Oct 24, 2023
799f7c5
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Oct 25, 2023
3a14a31
fix conversion, rounding
xadupre Oct 25, 2023
8199196
new fixes
xadupre Oct 25, 2023
87e2ba1
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Oct 26, 2023
6003524
fix softmax qdq
xadupre Oct 26, 2023
aa11d25
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Oct 26, 2023
0e96668
fix shape info
xadupre Oct 26, 2023
170e5c9
update test
xadupre Oct 26, 2023
47b41b6
fix remaining unit tests
xadupre Oct 26, 2023
b15ffcd
add value_info
xadupre Oct 26, 2023
0163699
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Oct 27, 2023
34b7a38
add subtest
xadupre Oct 27, 2023
8711792
refactoring onnxruntime/test/python/quantization/test_op_matmul.py
xadupre Oct 27, 2023
4806a04
disable f16 for old onnx package
xadupre Oct 27, 2023
260dd59
disable f16 unit tests
xadupre Oct 27, 2023
d376f66
support for Conv and float 16
xadupre Oct 27, 2023
76d9284
extend unit test for Conv
xadupre Oct 27, 2023
d2f9294
fix lint
xadupre Oct 27, 2023
6702b81
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Oct 30, 2023
a6433e8
change the disable condition
xadupre Oct 30, 2023
301c435
lint
xadupre Oct 30, 2023
1e70b3e
final fix
xadupre Oct 30, 2023
9049bcd
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Nov 6, 2023
e4f0415
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Nov 7, 2023
47efb3e
fix merge conflicts
xadupre Nov 20, 2023
fcf40b4
fix merge conflict
xadupre Nov 20, 2023
bf44eba
fix missing min_real_range
xadupre Nov 20, 2023
81528cc
merge conflicts
xadupre Dec 21, 2023
63b84d1
fix constant
xadupre Dec 21, 2023
a90f9f5
fix missing dtype
xadupre Dec 21, 2023
0164a38
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Dec 22, 2023
e6c39f4
use np arrays
xadupre Dec 22, 2023
1c8ae86
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Dec 25, 2023
1185de0
improve robustness
xadupre Dec 25, 2023
aee75ff
fix type issue
xadupre Dec 25, 2023
63a8ea9
fix wrong types
xadupre Dec 25, 2023
1948c4a
fix one bug
xadupre Dec 26, 2023
67bab54
fix dtype issue
xadupre Dec 26, 2023
fe1d0fe
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Jan 3, 2024
1d49bc5
better error message
xadupre Jan 10, 2024
fc406f9
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
xadupre Jan 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 58 additions & 30 deletions onnxruntime/python/tools/quantization/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,17 @@

class TensorData:
    """Container for per-tensor calibration statistics.

    Accepts only a fixed set of keyword statistics; the floating-point ones
    must be numpy values with a float16 or float32 dtype (float64 is rejected
    so calibration stays in the model's own precision).
    """

    # Every statistic name an instance may carry.
    _allowed = frozenset(["avg", "std", "lowest", "highest", "hist", "hist_edges", "bins"])
    # Subset of statistics that must be numpy float16/float32 values.
    _floats = frozenset(["avg", "std", "lowest", "highest", "hist_edges"])

    def __init__(self, **kwargs):
        for name, stat in kwargs.items():
            # Reject unknown statistic names outright.
            if name not in TensorData._allowed:
                raise ValueError(f"Unexpected value {name!r} not in {TensorData._allowed}.")
            if name in TensorData._floats:
                # Float statistics must be numpy-like (carry a dtype) ...
                if not hasattr(stat, "dtype"):
                    raise ValueError(f"Unexpected type {type(stat)} for k={name!r}")
                # ... and must not silently promote to float64.
                if stat.dtype not in (np.float16, np.float32):
                    raise ValueError(f"Unexpected dtype {stat.dtype} for k={name!r}")
            setattr(self, name, stat)

@property
Expand Down Expand Up @@ -171,7 +177,7 @@ def select_tensors_to_calibrate(self, model: ModelProto):
initializer = {init.name for init in model.graph.initializer}

tensors_to_calibrate = set()
tensor_type_to_calibrate = {TensorProto.FLOAT}
tensor_type_to_calibrate = {TensorProto.FLOAT, TensorProto.FLOAT16}

for node in model.graph.node:
if not self.op_types_to_calibrate or node.op_type in self.op_types_to_calibrate:
Expand Down Expand Up @@ -284,7 +290,14 @@ def add_reduce_min_max(tensor_name, reduce_op_name):
)

self.model.graph.node.extend([reduce_node, reshape_node])
self.model.graph.output.append(helper.make_tensor_value_info(reduce_output, TensorProto.FLOAT, [1]))
value_infos = {vi.name: vi for vi in self.model.graph.value_info}
value_infos.update({o.name: o for o in self.model.graph.output})
value_infos.update({i.name: i for i in self.model.graph.input})
if tensor_name in value_infos:
onnx_type = value_infos[tensor_name].type.tensor_type.elem_type
else:
raise ValueError(f"Unable to guess tensor type for tensor {tensor_name!r}")
xadupre marked this conversation as resolved.
Show resolved Hide resolved
self.model.graph.output.append(helper.make_tensor_value_info(reduce_output, onnx_type, [1]))

for tensor in tensors:
add_reduce_min_max(tensor, "ReduceMin")
Expand Down Expand Up @@ -364,24 +377,18 @@ def compute_data(self) -> TensorsData:

pairs = []
for i in range(0, len(added_output_names), 2):
min_value = 0
max_value = 0
if self.moving_average:
min_value_array = np.mean(merged_added_output_dict[added_output_names[i]], axis=0)
max_value_array = np.mean(merged_added_output_dict[added_output_names[i + 1]], axis=0)
else:
min_value_array = min(merged_added_output_dict[added_output_names[i]])
max_value_array = max(merged_added_output_dict[added_output_names[i + 1]])
if isinstance(min_value_array, int) or min_value_array.size > 0:
min_value = float(min_value_array)
if isinstance(max_value_array, int) or max_value_array.size > 0:
max_value = float(max_value_array)
min_value_array = np.min(merged_added_output_dict[added_output_names[i]], axis=0)
max_value_array = np.max(merged_added_output_dict[added_output_names[i + 1]], axis=0)

if self.symmetric:
max_absolute_value = max(abs(min_value), abs(max_value))
max_absolute_value = max(np.abs(min_value_array), np.abs(max_value_array))
pairs.append(tuple([-max_absolute_value, max_absolute_value]))
else:
pairs.append(tuple([min_value, max_value]))
pairs.append(tuple([min_value_array, max_value_array]))

new_calibrate_tensors_range = TensorsData(CalibrationMethod.MinMax, dict(zip(calibrate_tensor_names, pairs)))
if self.calibrate_tensors_range:
Expand Down Expand Up @@ -679,36 +686,53 @@ def collect_absolute_value(self, name_to_arr):
Collect histogram on absolute value
"""
for tensor, data_arr in name_to_arr.items():
data_arr = np.asarray(data_arr) # noqa: PLW2901
data_arr = data_arr.flatten() # noqa: PLW2901
if data_arr.size > 0:
min_value = np.min(data_arr)
max_value = np.max(data_arr)
if isinstance(data_arr, list):
for arr in data_arr:
if not isinstance(arr, np.ndarray):
raise ValueError(f"Unexpected type {type(arr)} for tensor={tensor!r}")
dtypes = set(a.dtype for a in arr)
if len(dtypes) != 1:
raise ValueError(
f"The calibration expects only one element type but got {dtypes} for tensor={tensor!r}"
)
data_arr_np = np.asarray(data_arr)
elif not isinstance(data_arr, np.ndarray):
raise ValueError(f"Unexpected type {type(data_arr)} for tensor={tensor!r}")
else:
data_arr_np = data_arr
data_arr_np = data_arr_np.flatten()
if data_arr_np.size > 0:
min_value = np.min(data_arr_np)
max_value = np.max(data_arr_np)
else:
min_value = 0
max_value = 0

data_arr = np.absolute(data_arr) # only consider absolute value # noqa: PLW2901
data_arr_np = np.absolute(data_arr_np) # only consider absolute value

if tensor not in self.histogram_dict:
# first time it uses num_bins to compute histogram.
hist, hist_edges = np.histogram(data_arr, bins=self.num_bins)
hist, hist_edges = np.histogram(data_arr_np, bins=self.num_bins)
hist_edges = hist_edges.astype(data_arr_np.dtype)
assert data_arr_np.dtype != np.float64
xadupre marked this conversation as resolved.
Show resolved Hide resolved
self.histogram_dict[tensor] = (hist, hist_edges, min_value, max_value)
else:
old_histogram = self.histogram_dict[tensor]
old_min = old_histogram[2]
old_max = old_histogram[3]
old_hist = old_histogram[0]
old_hist_edges = old_histogram[1]
temp_amax = np.max(data_arr)
temp_amax = np.max(data_arr_np)
if temp_amax > old_hist_edges[-1]:
# increase the number of bins
width = old_hist_edges[1] - old_hist_edges[0]
# NOTE: np.arange may create an extra bin after the one containing temp_amax
new_bin_edges = np.arange(old_hist_edges[-1] + width, temp_amax + width, width)
old_hist_edges = np.hstack((old_hist_edges, new_bin_edges))
hist, hist_edges = np.histogram(data_arr, bins=old_hist_edges)
hist, hist_edges = np.histogram(data_arr_np, bins=old_hist_edges)
hist_edges = hist_edges.astype(data_arr_np.dtype)
hist[: len(old_hist)] += old_hist
assert data_arr_np.dtype != np.float64
self.histogram_dict[tensor] = (hist, hist_edges, min(old_min, min_value), max(old_max, max_value))

def collect_value(self, name_to_arr):
Expand All @@ -723,8 +747,8 @@ def collect_value(self, name_to_arr):
min_value = np.min(data_arr)
max_value = np.max(data_arr)
else:
min_value = 0
max_value = 0
min_value = np.array(0, dtype=data_arr.dtype)
max_value = np.array(0, dtype=data_arr.dtype)

threshold = max(abs(min_value), abs(max_value))

Expand Down Expand Up @@ -811,16 +835,16 @@ def compute_percentile(self):
idx_right = np.searchsorted(cdf, percentile / 100.0)

thresholds_dict[tensor] = (
-float(hist_edges[idx_right]),
float(hist_edges[idx_right]),
-np.array(hist_edges[idx_right], dtype=hist_edges.dtype),
np.array(hist_edges[idx_right], dtype=hist_edges.dtype),
)
else:
percent_to_cut_one_side = (100.0 - percentile) / 200.0
idx_right = np.searchsorted(cdf, 1.0 - percent_to_cut_one_side)
idx_left = np.searchsorted(cdf, percent_to_cut_one_side)
thresholds_dict[tensor] = (
float(hist_edges[idx_left]),
float(hist_edges[idx_right]),
np.array(hist_edges[idx_left], dtype=hist_edges.dtype),
np.array(hist_edges[idx_right], dtype=hist_edges.dtype),
)
min_value = histogram[2]
max_value = histogram[3]
Expand Down Expand Up @@ -868,19 +892,19 @@ def _avg_std(hist, hist_edges, power=1):
if power == 1:
avg = (hist * values).sum() / hist.sum()
std = ((hist * values**2).sum() / hist.sum() - avg**2) ** 0.5
return avg, std
return np.array(avg, dtype=hist_edges.dtype), np.array(std, dtype=hist_edges.dtype)
if int(power) == power and int(power) % 2 == 1:
avg = (hist * values**power).sum() / hist.sum()
std = ((hist * (values**power - avg) ** 2).sum() / hist.sum()) ** 0.5
return avg, std
return np.array(avg, dtype=hist_edges.dtype), np.array(std, dtype=hist_edges.dtype)

fact = np.abs(values) / values
fact[np.isnan(fact)] = 1
fact[np.isinf(fact)] = 1
values = np.abs(values) ** power * fact
avg = (hist * values).sum() / hist.sum()
std = ((hist * values**2).sum() / hist.sum() - avg**2) ** 0.5
return avg, std
return np.array(avg, dtype=hist_edges.dtype), np.array(std, dtype=hist_edges.dtype)

def compute_distribution(self):
if self.num_bins < 512:
Expand All @@ -897,12 +921,16 @@ def compute_distribution(self):
hist = histogram[0]
hist_edges = histogram[1]

assert hist_edges.dtype != np.float64
if self.scenario == "same":
avg_coef, std_coef = self._avg_std(hist, hist_edges, power=1)
elif self.scenario == "p3":
avg_coef, std_coef = self._avg_std(hist, hist_edges, power=1.0 / 3.0)
else:
raise ValueError("Invalid scenario. Must be in {'same', 'p3'}.")
assert avg_coef.dtype != np.float64
assert std_coef.dtype != np.float64
assert hist_edges.dtype != np.float64
thresholds_dict[tensor] = TensorData(avg=avg_coef, std=std_coef, hist=hist, hist_edges=hist_edges)

# Plot histogram for debug only
Expand Down
Loading
Loading