Skip to content

Commit

Permalink
Make some parameters configurable for calibration (#10204)
Browse files Browse the repository at this point in the history
  • Loading branch information
chilo-ms authored Jan 10, 2022
1 parent 32ee379 commit be9cc40
Showing 1 changed file with 37 additions and 13 deletions.
50 changes: 37 additions & 13 deletions onnxruntime/python/tools/quantization/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,9 @@ def __init__(self,
op_types_to_calibrate=[],
augmented_model_path='augmented_model.onnx',
method='percentile',
num_quantized_bins=128,
percentile=99.99):
num_bins=128,
num_quantized_bins=2048,
percentile=99.999):
'''
:param model: ONNX model to calibrate. It can be a ModelProto or a model path
:param op_types_to_calibrate: operator types to calibrate. By default, calibrate all the float32/float16 tensors.
Expand All @@ -292,6 +293,7 @@ def __init__(self,
self.model_original_outputs = set(output.name for output in self.model.graph.output)
self.collector = None
self.method = method
self.num_bins = num_bins
self.num_quantized_bins = num_quantized_bins
self.percentile = percentile

Expand Down Expand Up @@ -347,6 +349,7 @@ def collect_data(self, data_reader: CalibrationDataReader):

if not self.collector:
self.collector = HistogramCollector(method=self.method,
num_bins=self.num_bins,
num_quantized_bins=self.num_quantized_bins,
percentile=self.percentile)
self.collector.collect(clean_merged_dict)
Expand All @@ -369,24 +372,26 @@ def __init__(self,
op_types_to_calibrate=[],
augmented_model_path='augmented_model.onnx',
method='entropy',
num_bins=128,
num_quantized_bins=128):
'''
:param model: ONNX model to calibrate. It can be a ModelProto or a model path
:param op_types_to_calibrate: operator types to calibrate. By default, calibrate all the float32/float16 tensors.
:param augmented_model_path: save augmented model to this path.
:param method: A string. One of ['entropy', 'percentile'].
:param num_bins: number of bins to create a new histogram for collecting tensor values.
:param num_quantized_bins: number of quantized bins. Default 128.
'''
super(EntropyCalibrater, self).__init__(model, op_types_to_calibrate, augmented_model_path,
method=method, num_quantized_bins=num_quantized_bins)
method=method, num_bins=num_bins, num_quantized_bins=num_quantized_bins)

class PercentileCalibrater(HistogramCalibrater):
def __init__(self,
model,
op_types_to_calibrate=[],
augmented_model_path='augmented_model.onnx',
method='percentile',
num_quantized_bins=2048,
num_bins=2048,
percentile=99.999):
'''
:param model: ONNX model to calibrate. It can be a ModelProto or a model path
Expand All @@ -397,7 +402,7 @@ def __init__(self,
:param percentile: A float number between [0, 100]. Default 99.99.
'''
super(PercentileCalibrater, self).__init__(model, op_types_to_calibrate, augmented_model_path,
method=method, num_quantized_bins=num_quantized_bins,
method=method, num_bins=num_bins,
percentile=percentile)

class CalibrationDataCollector(metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -429,17 +434,20 @@ class HistogramCollector(CalibrationDataCollector):
ref: https://docs.nvidia.com/deeplearning/tensorrt/pytorch-quantization-toolkit/docs/_modules/
pytorch_quantization/calib/histogram.html
"""
def __init__(self, method, num_quantized_bins, percentile):
def __init__(self, method, num_bins, num_quantized_bins, percentile):
self.histogram_dict = {}
self.method = method
self.num_bins = num_bins
self.num_quantized_bins= num_quantized_bins
self.percentile = percentile

def get_histogram_dict(self):
return self.histogram_dict

def collect(self, name_to_arr):
# TODO: Currently we have different collect() for percentile and percentile method respectively.
print("Collecting tensor data and making histogram ...")

# TODO: Currently we have different collect() for entropy and percentile method respectively.
# Need unified collect in the future.
if self.method == 'entropy':
return self.collect_for_entropy(name_to_arr)
Expand All @@ -455,8 +463,8 @@ def collect_for_percentile(self, name_to_arr):
data_arr = np.absolute(data_arr) # only consider absolute value

if tensor not in self.histogram_dict:
# first time it uses num_quantized_bins to compute histogram.
hist, hist_edges = np.histogram(data_arr, bins=self.num_quantized_bins)
# first time it uses num_bins to compute histogram.
hist, hist_edges = np.histogram(data_arr, bins=self.num_bins)
self.histogram_dict[tensor] = (hist, hist_edges)
else:
old_histogram = self.histogram_dict[tensor]
Expand Down Expand Up @@ -491,7 +499,7 @@ def collect_for_entropy(self, name_to_arr):
old_histogram = self.histogram_dict[tensor]
self.histogram_dict[tensor] = self.merge_histogram(old_histogram, data_arr, min_value, max_value, threshold)
else:
hist, hist_edges = np.histogram(data_arr, self.num_quantized_bins, range=(-threshold, threshold))
hist, hist_edges = np.histogram(data_arr, self.num_bins, range=(-threshold, threshold))
self.histogram_dict[tensor] = (hist, hist_edges, min_value, max_value, threshold)

def merge_histogram(self, old_histogram, data_arr, new_min, new_max, new_threshold):
Expand All @@ -518,6 +526,7 @@ def merge_histogram(self, old_histogram, data_arr, new_min, new_max, new_thresho
def compute_collection_result(self):
if not self.histogram_dict or len(self.histogram_dict) == 0:
raise ValueError("Histogram has not been collected. Please run collect() first.")
print("Finding optimal threshold for each tensor using {} algorithm ...".format(self.method))

if self.method == 'entropy':
return self.compute_entropy()
Expand All @@ -535,6 +544,10 @@ def compute_percentile(self):

thresholds_dict = {} # per tensor thresholds

print("Number of tensors : {}".format(len(histogram_dict)))
print("Number of histogram bins : {}".format(self.num_bins))
print("Percentile : {}".format(percentile))

for tensor, histogram in histogram_dict.items():
hist = histogram[0]
hist_edges = histogram[1]
Expand All @@ -551,6 +564,10 @@ def compute_entropy(self):

thresholds_dict = {} # per tensor thresholds

print("Number of tensors : {}".format(len(histogram_dict)))
print("Number of histogram bins : {} (The number may increase depends on the data it collects)".format(self.num_bins))
print("Number of quantized bins : {}".format(self.num_quantized_bins))

for tensor, histogram in histogram_dict.items():
optimal_threshold = self.get_entropy_threshold(histogram, num_quantized_bins)
thresholds_dict[tensor] = optimal_threshold
Expand Down Expand Up @@ -631,12 +648,19 @@ def get_entropy_threshold(self, histogram, num_quantized_bins):
def create_calibrator(model,
op_types_to_calibrate=[],
augmented_model_path='augmented_model.onnx',
calibrate_method=CalibrationMethod.MinMax):
calibrate_method=CalibrationMethod.MinMax,
extra_options={}):
if calibrate_method == CalibrationMethod.MinMax:
return MinMaxCalibrater(model, op_types_to_calibrate, augmented_model_path)
elif calibrate_method == CalibrationMethod.Entropy:
return EntropyCalibrater(model, op_types_to_calibrate, augmented_model_path)
# default settings for entropy algorithm
num_bins = 128 if 'num_bins' not in extra_options else extra_options['num_bins']
num_quantized_bins = 128 if 'num_quantized_bins' not in extra_options else extra_options['num_quantized_bins']
return EntropyCalibrater(model, op_types_to_calibrate, augmented_model_path, num_bins=num_bins, num_quantized_bins=num_quantized_bins)
elif calibrate_method == CalibrationMethod.Percentile:
return PercentileCalibrater(model, op_types_to_calibrate, augmented_model_path)
# default settings for percentile algorithm
num_bins = 2048 if 'num_bins' not in extra_options else extra_options['num_bins']
percentile = 99.999 if 'percentile' not in extra_options else extra_options['percentile']
return PercentileCalibrater(model, op_types_to_calibrate, augmented_model_path, num_bins=num_bins, percentile=percentile)

raise ValueError('Unsupported calibration method {}'.format(calibrate_method))

0 comments on commit be9cc40

Please sign in to comment.