fix example
Signed-off-by: Mengni Wang <[email protected]>
mengniwang95 committed Jun 27, 2024
1 parent ecbb392 commit 9d3b082
Showing 10 changed files with 42 additions and 45 deletions.
Changed file 1 of 10 (path not shown):

@@ -203,6 +203,8 @@ def eval_func(model, dataloader, metric):
 parser.add_argument("--tune", action="store_true", default=False, help="whether quantize the model")
 parser.add_argument("--output_model", type=str, help="output model path")
 parser.add_argument("--mode", type=str, help="benchmark mode of performance or accuracy")
+parser.add_argument(
+    "--intra_op_num_threads", type=int, default=4, help="intra_op_num_threads for performance benchmark")
 parser.add_argument(
     "--quant_format", type=str, default="QOperator", choices=["QDQ", "QOperator"], help="quantization format"
 )
@@ -213,7 +215,6 @@ def eval_func(model, dataloader, metric):
 )
 args = parser.parse_args()

-model = onnx.load(args.model_path)
 top1 = TopK()
 dataloader = DataReader(args.model_path, args.dataset_location, args.label_path, args.batch_size)

@@ -230,7 +231,7 @@ def eval(onnx_model):
         sess_options = ort.SessionOptions()
         sess_options.intra_op_num_threads = args.intra_op_num_threads
         session = ort.InferenceSession(
-            model.SerializeToString(), sess_options, providers=ort.get_available_providers()
+            args.model_path, sess_options, providers=ort.get_available_providers()
         )
         ort_inputs = {}
         len_inputs = len(session.get_inputs())
@@ -250,7 +251,7 @@ def eval(onnx_model):
         throughput = (num_iter - num_warmup) / total_time
         print("Throughput: {} samples/s".format(throughput))
     elif args.mode == "accuracy":
-        acc_result = eval_func(model, dataloader, top1)
+        acc_result = eval_func(args.model_path, dataloader, top1)
         print("Batch size = %d" % dataloader.batch_size)
         print("Accuracy: %.5f" % acc_result)
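A note on the recurring fix in these examples: ort.InferenceSession accepts a model file path directly, so the benchmark no longer needs onnx.load plus SerializeToString just to start a session. A minimal sketch of the new pattern, with a placeholder model path:

import onnxruntime as ort

# Build the session straight from the model path; ONNX Runtime loads the file
# itself, avoiding an extra in-memory copy of the serialized model.
sess_options = ort.SessionOptions()
sess_options.intra_op_num_threads = 4  # mirrors the new flag's default
session = ort.InferenceSession(
    "model.onnx",  # placeholder path
    sess_options,
    providers=ort.get_available_providers(),
)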
Changed file 2 of 10 (path not shown; the matching run_benchmark.sh):

@@ -24,6 +24,9 @@ function init_params {
       --mode=*)
           mode=$(echo "$var" |cut -f2 -d=)
       ;;
+      --intra_op_num_threads=*)
+          intra_op_num_threads=$(echo "$var" |cut -f2 -d=)
+      ;;
     esac
   done

@@ -36,8 +39,9 @@ function run_benchmark {
         --model_path "${input_model}" \
         --dataset_location "${dataset_location}" \
         --label_path "${label_path-${dataset_location}/../val.txt}" \
-        --mode="${mode}" \
+        --mode "${mode}" \
         --batch_size 1 \
+        --intra_op_num_threads "${intra_op_num_threads-4}" \
         --benchmark

 }
11 changes: 5 additions & 6 deletions examples/nlp/bert/quantization/ptq_dynamic/main.py

@@ -356,7 +356,7 @@ def eval_func(model):
     for idx, batch in enumerate(dataloader):
         label = batch[-1]
         batch = tuple(t.detach().cpu().numpy() if not isinstance(t, np.ndarray) else t for t in batch[0])
-        batch_seq_length = args.max_seq_length if not args.dynamic_length else torch.max(batch[-2], 0)[0].item()
+        batch_seq_length = args.max_seq_length if not args.dynamic_length else batch[0].shape[-1]
         data = [
             batch[0][:, :batch_seq_length],
             batch[1][:, :batch_seq_length],
@@ -369,7 +369,6 @@ def eval_func(model):
     return metric.result()

 if args.benchmark:
-    model = onnx.load(args.model_path)
     if args.mode == "performance":
         total_time = 0.0
         num_iter = 100
@@ -378,7 +377,7 @@ def eval_func(model):
         sess_options = onnxruntime.SessionOptions()
         sess_options.intra_op_num_threads = args.intra_op_num_threads
         session = onnxruntime.InferenceSession(
-            model.SerializeToString(), sess_options, providers=onnxruntime.get_available_providers()
+            args.model_path, sess_options, providers=onnxruntime.get_available_providers()
         )
         ort_inputs = {}
         len_inputs = len(session.get_inputs())
@@ -388,8 +387,8 @@ def eval_func(model):
            if idx + 1 > num_iter:
                break

-            batch = tuple(t.detach().cpu().numpy() if not isinstance(t, np.ndarray) else t for t in batch)
-            batch_seq_length = args.max_seq_length if not args.dynamic_length else torch.max(batch[-2], 0)[0].item()
+            batch = tuple(t.detach().cpu().numpy() if not isinstance(t, np.ndarray) else t for t in batch[0])
+            batch_seq_length = args.max_seq_length if not args.dynamic_length else batch[0].shape[-1]
             data = [
                 batch[0][:, :batch_seq_length],
                 batch[1][:, :batch_seq_length],
@@ -408,7 +407,7 @@ def eval_func(model):
         throughput = (num_iter - num_warmup) / total_time
         print("Throughput: {} samples/s".format(throughput))
     elif args.mode == "accuracy":
-        acc_result = eval_func(model)
+        acc_result = eval_func(args.model_path)
         print("Batch size = %d" % args.batch_size)
         print("Accuracy: %.5f" % acc_result)
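The dynamic-length change in both BERT examples rests on the same observation: once each batch is padded to that batch's longest sequence, every input array already carries the batch sequence length as its last dimension, so no torch.max over a length column is needed. A small illustration with made-up shapes:

import numpy as np

# A batch of 8 sequences, padded to this batch's max length of 97 tokens.
input_ids = np.zeros((8, 97), dtype=np.int64)
batch_seq_length = input_ids.shape[-1]     # 97, no torch needed
trimmed = input_ids[:, :batch_seq_length]  # slicing is then a no-op
assert trimmed.shape == (8, 97)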
4 changes: 4 additions & 0 deletions examples/nlp/bert/quantization/ptq_dynamic/run_benchmark.sh

@@ -25,6 +25,9 @@ function init_params {
       --batch_size=*)
           batch_size=$(echo "$var" |cut -f2 -d=)
       ;;
+      --intra_op_num_threads=*)
+          intra_op_num_threads=$(echo "$var" |cut -f2 -d=)
+      ;;
     esac
   done

@@ -52,6 +55,7 @@ function run_benchmark {
         --batch_size "${batch_size}" \
         --mode "${mode}" \
         --dynamic_length "${dynamic_length}" \
+        --intra_op_num_threads "${intra_op_num_threads-4}" \
         --benchmark

 }
17 changes: 10 additions & 7 deletions examples/nlp/bert/quantization/ptq_static/main.py

@@ -68,6 +68,9 @@
 parser.add_argument(
     "--quant_format", type=str, default="QOperator", choices=["QDQ", "QOperator"], help="quantization format"
 )
+parser.add_argument(
+    "--intra_op_num_threads", type=int, default=4, help="intra_op_num_threads for performance benchmark"
+)
 parser.add_argument("--dynamic_length", type=bool, default=False, help="dynamic length")
 parser.add_argument("--max_seq_length", type=int, default=128, help="max sequence length")
 parser.add_argument(
@@ -407,7 +410,7 @@ def eval_func(model):
     for idx, batch in enumerate(dataloader):
         label = batch[-1]
         batch = tuple(t.detach().cpu().numpy() if not isinstance(t, np.ndarray) else t for t in batch[0])
-        batch_seq_length = args.max_seq_length if not args.dynamic_length else torch.max(batch[-2], 0)[0].item()
+        batch_seq_length = args.max_seq_length if not args.dynamic_length else batch[0].shape[-1]
         inputs = [
             batch[0][:, :batch_seq_length],
             batch[1][:, :batch_seq_length],
@@ -420,7 +423,6 @@ def eval_func(model):
     return metric.result()

 if args.benchmark:
-    model = onnx.load(args.model_path)
     if args.mode == "performance":
         total_time = 0.0
         num_iter = 100
@@ -429,7 +431,7 @@ def eval_func(model):
         sess_options = onnxruntime.SessionOptions()
         sess_options.intra_op_num_threads = args.intra_op_num_threads
         session = onnxruntime.InferenceSession(
-            model.SerializeToString(), sess_options, providers=onnxruntime.get_available_providers()
+            args.model_path, sess_options, providers=onnxruntime.get_available_providers()
         )
         ort_inputs = {}
         len_inputs = len(session.get_inputs())
@@ -438,8 +440,8 @@ def eval_func(model):
         for idx, batch in enumerate(dataloader):
             if idx + 1 > num_iter:
                 break
-            batch = tuple(t.detach().cpu().numpy() if not isinstance(t, np.ndarray) else t for t in batch)
-            batch_seq_length = args.max_seq_length if not args.dynamic_length else torch.max(batch[-2], 0)[0].item()
+            batch = tuple(t.detach().cpu().numpy() if not isinstance(t, np.ndarray) else t for t in batch[0])
+            batch_seq_length = args.max_seq_length if not args.dynamic_length else batch[0].shape[-1]
             inputs = [
                 batch[0][:, :batch_seq_length],
                 batch[1][:, :batch_seq_length],
@@ -458,7 +460,7 @@ def eval_func(model):
         throughput = (num_iter - num_warmup) / total_time
         print("Throughput: {} samples/s".format(throughput))
     elif args.mode == "accuracy":
-        acc_result = eval_func(model)
+        acc_result = eval_func(args.model_path)
         print("Batch size = %d" % args.batch_size)
         print("Accuracy: %.5f" % acc_result)
@@ -492,7 +494,8 @@ def eval_func(model):
                 else quantization.QuantFormat.QDQ
             ),
             calibration_sampling_size=8,
-            extra_options={"optypes_to_exclude_output_quant": ["MatMul", "Gemm", "Attention", "FusedGemm"]},
+            op_types_to_quantize=["MatMul"],
+            extra_options={"OpTypesToExcludeOutputQuantization": ["MatMul"]},
             execution_provider=provider,
         )
     )
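The last hunk narrows the example's quantization setup: only MatMul ops are quantized, and MatMul outputs are excluded from output quantization via the CamelCase key that config.py now recognizes (see the final file below). A sketch of the resulting call, assuming the fragment sits inside a config.StaticQuantConfig(...) construction; data_reader and provider stand for objects defined elsewhere in the example:

cfg = config.StaticQuantConfig(  # class name assumed from the config module
    data_reader,                 # calibration data reader, defined elsewhere
    quant_format=quantization.QuantFormat.QOperator,
    calibration_sampling_size=8,
    op_types_to_quantize=["MatMul"],                                   # quantize only MatMul
    extra_options={"OpTypesToExcludeOutputQuantization": ["MatMul"]},  # keep MatMul outputs float
    execution_provider=provider,
)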
4 changes: 4 additions & 0 deletions examples/nlp/bert/quantization/ptq_static/run_benchmark.sh

@@ -25,6 +25,9 @@ function init_params {
       --batch_size=*)
           batch_size=$(echo "$var" |cut -f2 -d=)
       ;;
+      --intra_op_num_threads=*)
+          intra_op_num_threads=$(echo "$var" |cut -f2 -d=)
+      ;;
     esac
   done

@@ -51,6 +54,7 @@ function run_benchmark {
         --task "${task_name}" \
         --batch_size "${batch_size}" \
         --mode "${mode}" \
+        --intra_op_num_threads "${intra_op_num_threads-4}" \
         --dynamic_length "${dynamic_length}" \
         --benchmark
Changed file 7 of 10 (path not shown):

@@ -51,7 +51,7 @@ def __init__(self, onnx_quantizer, onnx_node):
         if node_name in self.quantizer.config:
             self.dtype = self.quantizer.config[node_name]
         self.disable_qdq_for_node_output = (
-            True if onnx_node.op_type in onnx_quantizer.op_types_to_exclude_output_quantization else False
+            True if onnx_node.op_type in onnx_quantizer.optypes_to_exclude_output_quant else False
         )
         self.per_channel = False
         self.calibrate_method = 0  # minmax
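Only the attribute name changes here, matching the keyword the quantizer now receives from algorithm_entry.py. Incidentally, the True if ... else False wrapper is redundant, since a membership test already yields a bool; an equivalent spelling of the line above would be:

self.disable_qdq_for_node_output = onnx_node.op_type in onnx_quantizer.optypes_to_exclude_output_quant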
23 changes: 0 additions & 23 deletions onnx_neural_compressor/algorithms/post_training_quant/quantizer.py

@@ -260,29 +260,6 @@ def merge_dedicated_qdq_pair(self):
             self.model.replace_node_input(node, old_input_name, new_input_name)
         self.model.update()

-        if self.quant_format == "qdq":
-            #      node                node
-            #     /  |  \                |
-            #    A   q   B      ->       q
-            #        |                   |
-            #        dq                  dq
-            #                           /  \
-            #                          A    B
-            for node in self.model.nodes():
-                if node.op_type in ["QuantizeLinear"] and len(self.model.get_parents(node)) > 0:
-                    if "QuantizeLinear" in [sibling.op_type for sibling in self.model.get_siblings(node)]:
-                        continue
-                    for sibling in self.model.get_siblings(node):
-                        if not self.should_quantize(sibling) and sibling.op_type in base_op.OPERATORS[self.mode]:
-                            for inp_idx in range(len(sibling.input)):
-                                if sibling.input[inp_idx] == node.input[0]:
-                                    self.replace_input.append(
-                                        [sibling, sibling.input[inp_idx], self.model.get_children(node)[0].output[0]]
-                                    )
-            for node, old_input_name, new_input_name in self.replace_input:
-                self.model.replace_node_input(node, old_input_name, new_input_name)
-            self.model.update()
-
     def remove_duplicate_qdq_paris(self):
         """Remove duplicated qdq pairs."""
         self.remove_nodes = []
4 changes: 3 additions & 1 deletion onnx_neural_compressor/quantization/algorithm_entry.py

@@ -151,7 +151,9 @@ def static_quantize_entry(
         quantization_params=quantize_params,
         op_types_to_quantize=quant_config.op_types_to_quantize,
         execution_provider=quant_config.execution_provider,
-        optypes_to_exclude_output_quant=quant_config.extra_options.get("optypes_to_exclude_output_quant", []),
+        optypes_to_exclude_output_quant=quant_config.optypes_to_exclude_output_quant,
+        dedicated_qdq_pair=quant_config.dedicated_qdq_pair,
+        add_qdq_pair_to_weight=quant_config.add_qdq_pair_to_weight,
     )
     _quantizer.quantize_model()
     if model_output is not None:
9 changes: 6 additions & 3 deletions onnx_neural_compressor/quantization/config.py

@@ -1599,9 +1599,9 @@ def __init__(
             os.environ["ORT_TENSORRT_INT8_ENABLE"] = "0"
             self.extra_options.update(
                 {
-                    "add_qdq_pair_to_weight": True,
-                    "dedicated_qdq_pair": True,
-                    "optypes_to_exclude_output_quant": ["Conv", "Gemm", "Add", "MatMul"],
+                    "AddQDQPairToWeight": True,
+                    "DedicatedQDQPair": True,
+                    "OpTypesToExcludeOutputQuantization": ["Conv", "Gemm", "Add", "MatMul"],
                 }
             )
         else:
@@ -1614,6 +1614,9 @@ def __init__(
         _extra_options = ExtraOptions(**self.extra_options)
         self.weight_sym = _extra_options.WeightSymmetric
         self.activation_sym = _extra_options.ActivationSymmetric
+        self.optypes_to_exclude_output_quant = _extra_options.OpTypesToExcludeOutputQuantization
+        self.dedicated_qdq_pair = _extra_options.DedicatedQDQPair
+        self.add_qdq_pair_to_weight = _extra_options.AddQDQPairToWeight
         self.white_list = white_list
         self._post_init()
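config.py is the anchor for the renames above: ExtraOptions spells its fields in ONNX Runtime's CamelCase extra_options style, and the config re-exposes them as snake_case attributes that algorithm_entry.py forwards to the quantizer. A self-contained sketch of that mapping pattern; the field defaults here are assumptions, not the repo's actual values:

import dataclasses
from typing import List

@dataclasses.dataclass
class ExtraOptions:
    # CamelCase names follow ONNX Runtime's extra_options convention.
    WeightSymmetric: bool = True
    ActivationSymmetric: bool = False
    AddQDQPairToWeight: bool = False
    DedicatedQDQPair: bool = False
    OpTypesToExcludeOutputQuantization: List[str] = dataclasses.field(default_factory=list)

# The config unpacks the user-supplied dict once, then exposes snake_case attributes.
opts = ExtraOptions(**{"AddQDQPairToWeight": True, "DedicatedQDQPair": True})
add_qdq_pair_to_weight = opts.AddQDQPairToWeight  # True
dedicated_qdq_pair = opts.DedicatedQDQPair        # True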
