From cf78d01546ca059a2ab487e01626e38029a3e8fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 10 Jan 2024 16:36:50 +0100 Subject: [PATCH 1/9] remove use of ai.onnx.ml in test for custom ops and local functions (#19043) ### Description QNN_Nuget_Windows does not allow ai.onnx.ml operators but the test test_custom_op_local_function is using LabelEncoder. The operator can be removed as the test is only checking custom ops api. ### Motivation and Context Fix test test_custom_op_local_function in QNN_Nuget_Windows pipeline. --- .../custom_ops_type_inference_fails_0.onnx | Bin 2086 -> 1977 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/onnxruntime/test/testdata/custom_op_local_function/custom_ops_type_inference_fails_0.onnx b/onnxruntime/test/testdata/custom_op_local_function/custom_ops_type_inference_fails_0.onnx index 8116ec338064567cea06fafe45168567813071ed..3a43a7378a9123b1f51a30cc2270d2edaa1c0f76 100644 GIT binary patch delta 512 zcmZ1`u#=yigVSmo3zzdmLH&u|pBcR-+cQe&8*;IhCFYc-YN>KC3$QyuI9?J=3=n`$ zPcCPiE-cK&7$L;R#lyiU#KFZR!~n#TWtmc0PB9*1oji@nbnC(6Nyu#e+E6il{aR$x?~9L+4r#mU7~k}n~+fRSl(Bl8apbs?7Gl9V(h z4(I&5;*!L?5-m|KZm`qi({l0?OGNu*PdF`LWSpGATFRu#HCdWb$xwjNON>h>CqFqc zN3XCHs7Eg)KQ}QmPk=E9i|8rF*OQsqbn1?WfqiEv#FT4DSPRUb ipa21BF~Za$k*b9nHvgcSHt%5zV`TP%_)U*AzX1T^6p=>& delta 679 zcmdnVzf3@ogF}cxib07%gTaWwY8?xg-$ZtO9xYBTrjmRK!3B&=R+AN3g(rI5x3m`G zPA)Eq&&@5)NGwQI;&9H-D=taQE74NoVl7L|DNWT9<6stGw_;{uVsb)OI=PT>IzK-b zV}uYF7Y7HU5DOOr*JNp?R8c2BF0RB(z5Kkq3ccJM0j5)ouh_UIDu_*<%_PTS1a!dU zolFYeN-PXO;KW*$T9lu*fRV+@Mu;OXwH)0QYA{zwaWDd1!Sx>o7@gqCC(AP{Fse;< zV3y?N0eKMUy;Y1Ldy<)d)T;?kwTn6Zb@oh zaek2!dr&IaRaV77oiI@$uHurEG^h$5pTwlp9M`<${FKxpEeS5}?9|F)kojgNA;k`Y z4nhvX4k8N}*|pTTc)@lT$EW4wCzcf3_Q{@bvfU>T>13<=SH#Ws|63uq1&nM~Na1It z$Ax5-5DUoK$*-7YC%3Sa2%v>@5GYbvPcbfH;hH>=QD$;BD_^~~5DO@Xkpe&x7Nor3 yAZ2850;>SUfq@WHt|5|A1(;G1jL-)vgy}HCql51sBEpz*jW!Fgg)uUM-3b8JD$Ni8 From df116b82c743f9104bd090a310d5777a4475e539 Mon Sep 17 00:00:00 2001 From: RandySheriffH <48490400+RandySheriffH@users.noreply.github.com> Date: Wed, 10 Jan 2024 14:13:25 -0800 Subject: [PATCH 2/9] Custom op API for thread pool (#18980) Allow custom op to invoke internal thread-pool for parallelism. --------- Co-authored-by: Randy Shuai --- .../core/session/onnxruntime_c_api.h | 13 +++++++ .../core/session/onnxruntime_cxx_api.h | 1 + .../core/session/onnxruntime_cxx_inline.h | 4 ++ onnxruntime/core/session/custom_ops.cc | 26 +++++++++++++ onnxruntime/core/session/onnxruntime_c_api.cc | 1 + onnxruntime/core/session/ort_apis.h | 2 + .../testdata/custom_op_library/cpu/cpu_ops.cc | 37 +++++++++++++++++-- 7 files changed, 80 insertions(+), 4 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 06fef6bf72cc9..8cd0d0051d1eb 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -4528,6 +4528,19 @@ struct OrtApi { * \since Version 1.17. */ ORT_API2_STATUS(SetDeterministicCompute, _Inout_ OrtSessionOptions* options, bool value); + + /** + * Run fn in parallel + * + * \param[in] context + * \param[in] fn Function accepting usr_data and an integer as iterator + * \param[in] total The number of times fn is to be invoked + * \param[in] num_batch Number of batches by which the "total" is to be divided in maximum. When zero, there is no limit + * \param[in] usr_data User data to be passed back to fn + * + * \since Version 1.17. 
+ */ + ORT_API2_STATUS(KernelContext_ParallelFor, _In_ const OrtKernelContext* context, _In_ void (*fn)(void*, size_t), _In_ size_t total, _In_ size_t num_batch, _In_ void* usr_data); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 16d9451624533..3773a01cb65a8 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -2057,6 +2057,7 @@ struct KernelContext { Logger GetLogger() const; OrtAllocator* GetAllocator(const OrtMemoryInfo& memory_info) const; OrtKernelContext* GetOrtKernelContext() const { return ctx_; } + void ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const; private: OrtKernelContext* ctx_; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 63e55603736b6..db4619eeeae62 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -1658,6 +1658,10 @@ inline Logger KernelContext::GetLogger() const { return Logger{out}; } +inline void KernelContext::ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const { + ThrowOnError(GetApi().KernelContext_ParallelFor(ctx_, fn, total, num_batch, usr_data)); +} + inline OpAttr::OpAttr(const char* name, const void* data, int len, OrtOpAttrType type) { Ort::ThrowOnError(GetApi().CreateOpAttr(name, data, len, type, &p_)); } diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index eea675eb0193a..984fdd6bce325 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -24,6 +24,7 @@ #include "core/session/custom_ops.h" #include "core/session/inference_session.h" #include "core/session/ort_apis.h" +#include "core/platform/threadpool.h" #if !defined(ORT_MINIMAL_BUILD) static constexpr uint32_t min_ort_version_with_optional_io_support = 8; @@ -380,6 +381,31 @@ ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetResource, _In_ const OrtKernelCont API_IMPL_END }; +ORT_API_STATUS_IMPL(OrtApis::KernelContext_ParallelFor, _In_ const OrtKernelContext* context, _In_ void (*fn)(void*, size_t), _In_ size_t total, _In_ size_t num_batch, _In_ void* usr_data) { + API_IMPL_BEGIN + if (!context) { + return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, "Invalid context"); + } + if (fn && total) { + const auto* ctx = reinterpret_cast(context); + auto* tp = ctx->GetOperatorThreadPool(); + if (num_batch) { + onnxruntime::concurrency::ThreadPool::TryBatchParallelFor( + tp, + static_cast(total), + [fn, usr_data](std::ptrdiff_t ith) { fn(usr_data, static_cast(ith)); }, + static_cast(num_batch)); + } else { + onnxruntime::concurrency::ThreadPool::TrySimpleParallelFor( + tp, + static_cast(total), + [fn, usr_data](std::ptrdiff_t ith) { fn(usr_data, static_cast(ith)); }); + } + } + return nullptr; + API_IMPL_END +}; + #ifdef _WIN32 #pragma warning(pop) #endif diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 76a8a778025e1..08bfb618f55b4 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2722,6 +2722,7 @@ static constexpr OrtApi ort_api_1_to_17 = { &OrtApis::SetSymbolicDimensions, &OrtApis::ReadOpAttr, &OrtApis::SetDeterministicCompute, + &OrtApis::KernelContext_ParallelFor, }; // 
OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index c9e4074a1afe2..6df5e4145b416 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -502,4 +502,6 @@ ORT_API_STATUS_IMPL(SetSymbolicDimensions, _In_ OrtTensorTypeAndShapeInfo* info, ORT_API_STATUS_IMPL(ReadOpAttr, _In_ const OrtOpAttr* op_attr, _In_ OrtOpAttrType type, _Inout_ void* data, _In_ size_t len, _Out_ size_t* out); ORT_API_STATUS_IMPL(SetDeterministicCompute, _Inout_ OrtSessionOptions* options, bool value); +ORT_API_STATUS_IMPL(KernelContext_ParallelFor, _In_ const OrtKernelContext* context, _In_ void (*fn)(void*, size_t), _In_ size_t total, _In_ size_t num_batch, _In_ void* user_data); + } // namespace OrtApis diff --git a/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc b/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc index 85edfa0e59f1d..ebef441350d4c 100644 --- a/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc +++ b/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc @@ -49,16 +49,45 @@ struct KernelOne { } }; +struct DataI { + const float* from = {}; + float* to = {}; +}; + +struct DataII { + const float* from = {}; + int32_t* to = {}; +}; + +// floats to floats +void CopyI(void* raw_data, size_t ith) { + auto data = reinterpret_cast(raw_data); + data->to[ith] = data->from[ith]; +} + +// floats to int32_t +void CopyII(void* raw_data, size_t ith) { + auto data = reinterpret_cast(raw_data); + data->to[ith] = static_cast(round(data->from[ith])); +} + // lite custom op as a function -void KernelTwo(const Ort::Custom::Tensor& X, +void KernelTwo(OrtKernelContext* context, + const Ort::Custom::Tensor& X, Ort::Custom::Tensor& Y) { const auto& shape = X.Shape(); auto X_raw = X.Data(); auto Y_raw = Y.Allocate(shape); + std::vector floats(static_cast(X.NumberOfElement()), 0.f); + + DataI data_i = {X_raw, floats.data()}; auto total = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies()); - for (int64_t i = 0; i < total; i++) { - Y_raw[i] = static_cast(round(X_raw[i])); - } + + Ort::KernelContext ctx(context); + ctx.ParallelFor(CopyI, static_cast(total), 0, &data_i); // test simple parallel for + + DataII data_ii = {floats.data(), Y_raw}; + ctx.ParallelFor(CopyII, static_cast(total), 2, &data_ii); // test batch parallel for } template From 731b50dfc4f8074185dc70f3a10236fa4fdfc0aa Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Wed, 10 Jan 2024 15:13:04 -0800 Subject: [PATCH 3/9] Support INT4 weight only quantize, including RTN and GPTQ 2 algorithms (#17390) ### Description Support INT4 weight only quantize (WOQ) via Intel Neural Compressor, including RTN and GPTQ 2 algorithms. **Note:** Please install `neural-compressor==2.3` for weight only quantize. ### Motivation and Context As large language models (LLMs) become more prevalent, there is a growing need for new and improved quantization methods that can meet the computational demands of these modern architectures while maintaining the accuracy. Compared to normal quantization like W8A8, weight only quantization is probably a better trade-off to balance the performance and the accuracy. RTN is the most straightforward way to quantize weight. GPTQ algorithm provides more accurate quantization but requires more computational resources. 
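For intuition, here is a generic round-to-nearest sketch in NumPy. It is only an illustration under simple assumptions (the helper name and group size are made up for this example), not the Intel® Neural Compressor implementation that this PR integrates:

```python
# Generic 4-bit asymmetric RTN quantization of one weight group.
# Illustration only; not the neural-compressor implementation used by this PR.
import numpy as np

def rtn_quantize_group(w: np.ndarray, bits: int = 4):
    qmax = 2**bits - 1                                    # quantized values span 0..15 for 4-bit
    w_min, w_max = float(w.min()), float(w.max())
    scale = (w_max - w_min) / qmax if w_max > w_min else 1.0
    zero_point = round(-w_min / scale)
    q = np.clip(np.round(w / scale) + zero_point, 0, qmax).astype(np.uint8)
    return q, scale, zero_point

group = np.random.randn(32).astype(np.float32)            # group_size=32, as in W4G32Asym
q, scale, zp = rtn_quantize_group(group)
dequant = (q.astype(np.float32) - zp) * scale             # reconstruction error is the accuracy cost
```

GPTQ follows the same group-wise scheme but additionally uses calibration data (a Hessian-based update) to pick quantized values, which is why it costs more compute.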
### Evaluation results

The following table shows the accuracy results of Llama-2 models evaluated on the [lambada_openai](https://huggingface.co/datasets/lambada) task. `GPTQ W4G32Asym` in the Configuration column means the GPTQ algorithm is used for 4-bit weight-only quantization with group_size=32 and scheme=asym.
| Model name | Configuration | Lambada_openai Accuracy | Lambada_openai Perplexity | Accuracy Ratio [WOQ/FP32] |
|---|---|---|---|---|
| meta-llama/Llama-2-7b-chat-hf | FP32 | 0.7058 | 3.2788 | / |
| meta-llama/Llama-2-7b-chat-hf | GPTQ W4G32Asym | 0.7025 | 3.4489 | 99.53% |
| meta-llama/Llama-2-7b-hf | FP32 | 0.7392 | 3.3950 | / |
| meta-llama/Llama-2-7b-hf | GPTQ W4G32Asym | 0.7326 | 3.5286 | 99.11% |
| meta-llama/Llama-2-13b-chat-hf | FP32 | 0.7312 | 2.9163 | / |
| meta-llama/Llama-2-13b-chat-hf | GPTQ W4G128Asym | 0.7289 | 3.0061 | 99.56% |
| meta-llama/Llama-2-13b-hf | FP32 | 0.7677 | 3.0438 | / |
| meta-llama/Llama-2-13b-hf | GPTQ W4G32Asym | 0.7607 | 3.1562 | 99.09% |
| meta-llama/Llama-2-70b-chat-hf | FP32 | 0.7543 | 2.6181 | / |
| meta-llama/Llama-2-70b-chat-hf | RTN W4G32Sym | 0.7489 | 2.6850 | 99.28% |
| meta-llama/Llama-2-70b-hf | FP32 | 0.7964 | 2.6612 | / |
| meta-llama/Llama-2-70b-hf | RTN W4G32Sym | 0.7896 | 2.7546 | 99.15% |
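For reference, a minimal usage sketch of the new weight-only quantization path, loosely based on the unit test added in this PR; the model paths and block size are placeholders, and `neural-compressor` must be installed:

```python
# Sketch of the new INT4 weight-only quantization flow added in this PR.
# "model_fp32.onnx" / "model_int4.onnx" are placeholder paths.
from onnxruntime.quantization import matmul_4bits_quantizer

# RTN needs no calibration data; GPTQ takes a CalibrationDataReader instead:
# algo_config = matmul_4bits_quantizer.GPTQWeightOnlyQuantConfig(calibration_data_reader=data_reader)
algo_config = matmul_4bits_quantizer.RTNWeightOnlyQuantConfig()

quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(
    "model_fp32.onnx",          # a ModelProto or a model path (str) is accepted
    block_size=32,
    is_symmetric=False,
    algo_config=algo_config,
)
quant.process()
quant.model.save_model_to_file("model_int4.onnx", False)
```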
--------- Signed-off-by: yuwenzho Co-authored-by: Wang, Mengni --- .../quantization/matmul_4bits_quantizer.py | 189 ++++++++++++++++-- .../python/tools/quantization/quantize.py | 11 +- .../quantization/test_op_matmul_4bits.py | 70 ++++++- 3 files changed, 246 insertions(+), 24 deletions(-) diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index 6293bcbbf95bd..3e9f9a6544a71 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -7,6 +7,8 @@ from __future__ import annotations import argparse +import copy +import importlib import logging import os @@ -14,9 +16,11 @@ import numpy.typing as npt import onnx from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto +from packaging import version from onnxruntime.capi._pybind_state import quantize_matmul_4bits +from .calibrate import CalibrationDataReader from .onnx_model import ONNXModel from .quant_utils import attribute_to_kwarg @@ -24,24 +28,98 @@ logger = logging.getLogger(__name__) +class WeightOnlyQuantConfig: + def __init__(self, algorithm): + """This is the Base class for Weight Only Quant Configuration. + + Args: + algorithm: + weight only quantize algorithm name. + """ + self.algorithm = algorithm + + +class RTNWeightOnlyQuantConfig(WeightOnlyQuantConfig): + def __init__( + self, + ratios=None, + ): + """ + This is a class for round-to-nearest (RTN) algorithm Weight Only Quant Configuration. + RTN is the most straightforward way to quantize weight using scale maps. + + Args: + ratios: + percentile of clip. Defaults to {}. + """ + if ratios is None: + ratios = {} + super().__init__( + algorithm="RTN", + ) + self.ratios = ratios + + +class GPTQWeightOnlyQuantConfig(WeightOnlyQuantConfig): + def __init__( + self, + calibration_data_reader: CalibrationDataReader, + percdamp=0.01, + blocksize=128, + actorder=False, + mse=False, + perchannel=True, + ): + """ + This is a class for GPTQ algorithm Weight Only Quant Configuration. + GPTQ algorithm provides more accurate quantization but requires more computational resources. + + Args: + calibration_data_reader: + a calibration data reader. It enumerates calibration data and generates inputs for the original model. + percdamp: + percent of the average Hessian diagonal to use for dampening. + blocksize (int, optional): + channel number in one block to execute a GPTQ quantization iteration. + actorder (bool, optional): + whether rearrange Hessian matrix considering the diag's value. + mse (bool, optional): + whether get scale and zero point with mse error. + perchannel (bool, optional): + whether quantize weight per-channel. 
+ """ + super().__init__( + algorithm="GPTQ", + ) + self.calibration_data_reader = calibration_data_reader + self.percdamp = percdamp + self.blocksize = blocksize + self.actorder = actorder + self.mse = mse + self.perchannel = perchannel + + class MatMul4BitsQuantizer: """Perform 4b quantization of constant MatMul weights""" def __init__( self, - model: ModelProto, + model: ModelProto | str, block_size: int, is_symmetric: bool, accuracy_level: int | None = None, - nodes_to_exclude: list[str] | None = None, + nodes_to_exclude=None, + algo_config: WeightOnlyQuantConfig = None, ): if nodes_to_exclude is None: nodes_to_exclude = [] - self.model = ONNXModel(model) + self.model = ONNXModel(onnx.load(model)) if isinstance(model, str) else ONNXModel(model) + self.model_path = model if isinstance(model, str) else None self.block_size = block_size self.is_symmetric = is_symmetric self.accuracy_level = accuracy_level self.nodes_to_exclude = set(nodes_to_exclude) + self.algo_config = algo_config @staticmethod def __get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]: @@ -176,20 +254,99 @@ def _process_subgraph(self, graph_stack: list[GraphProto]): graph_stack.pop() return graph + def _generate_q4_node_config(self): + """Generate weight only quant configuration for nodes.""" + q4_node_config = {} + template_config_q4 = { + "bits": 4, + "group_size": self.block_size, + "scheme": "sym" if self.is_symmetric else "asym", + } + for node in self.model.model.graph.node: + if node.op_type in ["MatMul"]: + if not all([self.model.get_initializer(i) is None for i in node.input]): + q4_node_config[node.name] = template_config_q4 + return q4_node_config + + def int4_quant_algo(self): + """4b quantize a model with RTN or GPTQ algorithm. Please refer to + https://github.com/intel/neural-compressor/blob/master/docs/source/quantization_weight_only.md + for more details on weight only quantization using IntelĀ® Neural Compressor. 
+ """ + + def inc_dataloader(): + data_reader = copy.deepcopy(self.algo_config.calibration_data_reader) + for data in data_reader: + yield data, None + + kwargs = {} + if self.accuracy_level is not None: + kwargs["accuracy_level"] = self.accuracy_level + weight_only_node_config = self._generate_q4_node_config() + + algorithm = self.algo_config.algorithm + logger.info(f"start to quantize model with {algorithm} algorithm...") + if algorithm == "RTN": + from neural_compressor.adaptor.ox_utils.weight_only import rtn_quantize + + kwargs["ratios"] = self.algo_config.ratios + + self.model = rtn_quantize( + model=self.model_path if self.model_path is not None else self.model.model, + weight_config=weight_only_node_config, + **kwargs, + ) + elif algorithm == "GPTQ": + from neural_compressor.adaptor.ox_utils.weight_only import gptq_quantize + + kwargs["percdamp"] = self.algo_config.percdamp + kwargs["blocksize"] = self.algo_config.blocksize + kwargs["actorder"] = self.algo_config.actorder + kwargs["mse"] = self.algo_config.mse + kwargs["perchannel"] = self.algo_config.perchannel + kwargs["n_samples"] = -1 + dataloader = inc_dataloader() + + self.model = gptq_quantize( + model=self.model_path if self.model_path is not None else self.model.model, + weight_config=weight_only_node_config, + dataloader=dataloader, + **kwargs, + ) + logger.info(f"complete quantization of model with {algorithm} algorithm.") + def process(self): - # use a stack to keep track of sub-graphs - graph_stack = [self.model.graph()] - opset_import = self.model.opset_import() - - has_ms_domain = False - for opset in opset_import: - if opset.domain == "com.microsoft": - has_ms_domain = True - if not has_ms_domain: - opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)]) - - self._process_subgraph(graph_stack) - self.model.clean_initializers() + if self.algo_config is None: + # use a stack to keep track of sub-graphs + graph_stack = [self.model.graph()] + opset_import = self.model.opset_import() + + has_ms_domain = False + for opset in opset_import: + if opset.domain == "com.microsoft": + has_ms_domain = True + if not has_ms_domain: + opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)]) + + self._process_subgraph(graph_stack) + self.model.clean_initializers() + else: + # use IntelĀ® Neural Compressor for RTN or GPTQ weight-only quantize algorithm + try: + importlib.import_module("neural_compressor") + except Exception as e: + logging.error(f"{e}.") + raise RuntimeError( + "neural-compressor is not correctly installed. Please check your environment." + ) from e + + import neural_compressor + + assert version.parse(neural_compressor.__version__) >= version.parse( + "2.3.2" + ), "Require neural-compressor >= 2.3.2 to support weight only quantization!" 
+ + self.int4_quant_algo() def parse_args(): diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py index aed46563c2764..1bd2ef42151d0 100644 --- a/onnxruntime/python/tools/quantization/quantize.py +++ b/onnxruntime/python/tools/quantization/quantize.py @@ -466,7 +466,6 @@ def quantize_static( import copy - import onnx from neural_compressor.adaptor.ox_utils.smooth_quant import ORTSmoothQuant def inc_dataloader(): @@ -478,13 +477,11 @@ def inc_dataloader(): dataloader = inc_dataloader() sq = ORTSmoothQuant(model_input, dataloader, reduce_range) del dataloader - model = sq.transform( - extra_options.get("SmoothQuantAlpha", 0.5), extra_options.get("SmoothQuantFolding", True) - ).model - nodes_to_exclude.extend([i.name for i in model.graph.node if i.name not in orig_nodes]) + model = sq.transform(extra_options.get("SmoothQuantAlpha", 0.5), extra_options.get("SmoothQuantFolding", True)) sq_path = tempfile.TemporaryDirectory(prefix="ort.quant.") - model_input = Path(sq_path.name).joinpath("sq_model.onnx").as_posix() - onnx.save_model(model, model_input, save_as_external_data=True) + model_input = Path(sq_path).joinpath("sq_model.onnx").as_posix() + model.save(model_input) + nodes_to_exclude.extend([i.name for i in model.model.graph.node if i.name not in orig_nodes]) model = load_model_with_shape_infer(Path(model_input)) # use smooth quant model for calibration with tempfile.TemporaryDirectory(prefix="ort.quant.") as quant_tmp_dir: diff --git a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py index 02f51cc4fa809..73dae08af8ece 100644 --- a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py +++ b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py @@ -71,13 +71,16 @@ def construct_model_matmul(self, output_model_path: str, symmetric: bool) -> Non output_name = "output" initializers = [] - def make_matmul(input_name, weight_shape: Union[int, Tuple[int, ...]], weight_name: str, output_name: str): + def make_matmul( + input_name, weight_shape: Union[int, Tuple[int, ...]], weight_name: str, output_name: str, node_name: str + ): weight_data = self.fill_int4_data(weight_shape, symmetric).astype(np.float32) initializers.append(onnx.numpy_helper.from_array(weight_data, name=weight_name)) return onnx.helper.make_node( "MatMul", [input_name, weight_name], [output_name], + node_name, ) in_features = 52 @@ -88,6 +91,7 @@ def make_matmul(input_name, weight_shape: Union[int, Tuple[int, ...]], weight_na [in_features, out_features], "linear1.weight", output_name, + "MatMul_0", ) # make graph @@ -139,6 +143,48 @@ def quant_test( else: raise exception + def quant_test_with_algo( + self, + algorithm: str, + model_fp32_path: str, + data_reader: TestDataFeeds, + block_size: int, + is_symmetric: bool, + ): + model_int4_path = str( + Path(self._tmp_model_dir.name).joinpath(f"MatMulNBits_{block_size}_{is_symmetric}.onnx").absolute() + ) + + # Quantize fp32 model to int4 model + from onnxruntime.quantization import matmul_4bits_quantizer + + algo_config = None + if algorithm == "RTN": + # test RTN algorithm + algo_config = matmul_4bits_quantizer.RTNWeightOnlyQuantConfig() + elif algorithm == "GPTQ": + # test GPTQ algorithm + algo_config = matmul_4bits_quantizer.GPTQWeightOnlyQuantConfig(calibration_data_reader=data_reader) + + model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path)) + quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, block_size, 
is_symmetric, algo_config=algo_config) + quant.process() + quant.model.save_model_to_file(model_int4_path, False) + + quant_nodes = {"MatMulNBits": 1} + check_op_type_count(self, model_int4_path, **quant_nodes) + + data_reader.rewind() + + try: + check_model_correctness(self, model_fp32_path, model_int4_path, data_reader.get_next()) + except Exception as exception: + if "4b quantization not yet supported on this hardware platform!" in exception.args[0]: + # Currently we don't have int4 quantization support on all platforms, has to tolerate this exception + pass + else: + raise exception + @unittest.skipIf( find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" ) @@ -159,6 +205,28 @@ def test_quantize_matmul_int4_offsets(self): data_reader = self.input_feeds(1, {"input": [100, 52]}) self.quant_test(model_fp32_path, data_reader, 32, False) + @unittest.skipIf( + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" + ) + def test_quantize_matmul_int4_using_rtn_algo(self): + if not find_spec("neural_compressor"): + self.skipTest("skip test_smooth_quant since neural_compressor is not installed") + model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute()) + self.construct_model_matmul(model_fp32_path, symmetric=False) + data_reader = self.input_feeds(1, {"input": [100, 52]}) + self.quant_test_with_algo("RTN", model_fp32_path, data_reader, 32, False) + + @unittest.skipIf( + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" + ) + def test_quantize_matmul_int4_using_gptq_algo(self): + if not find_spec("neural_compressor"): + self.skipTest("skip test_smooth_quant since neural_compressor is not installed") + model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute()) + self.construct_model_matmul(model_fp32_path, symmetric=False) + data_reader = self.input_feeds(1, {"input": [100, 52]}) + self.quant_test_with_algo("GPTQ", model_fp32_path, data_reader, 32, False) + if __name__ == "__main__": unittest.main() From e58319ebfc344419b94ab5f8f27f7ce5eabe56f5 Mon Sep 17 00:00:00 2001 From: Yifan Li <109183385+yf711@users.noreply.github.com> Date: Wed, 10 Jan 2024 15:29:34 -0800 Subject: [PATCH 4/9] [TensorRT EP] Fix memleak (#19053) ### Description To fix memleak: ```bash 192 bytes in 1 blocks are definitely lost in loss record 1,254 of 1,999 at 0x483BE63: operator new(unsigned long) (in /usr/lib/x86_64-linux-gnu/valgrind/vgpreload_memcheck-amd64-linux.so) by 0x4A93FD5: OrtApis::CreateTensorRTProviderOptions(OrtTensorRTProviderOptionsV2**) (in /code/onnxruntime/build/Linux/Release/libonnxruntime.so.1.17.0) by 0x1502E1: onnxruntime::perftest::OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env&, std::random_device&, onnxruntime::perftest::PerformanceTestConfig const&, TestModelInfo const&) (in /code/onnxruntime/build/Linux/Release/onnxruntime_perf_test) by 0x15A404: onnxruntime::perftest::PerformanceRunner::PerformanceRunner(Ort::Env&, onnxruntime::perftest::PerformanceTestConfig const&, std::random_device&) (in /code/onnxruntime/build/Linux/Release/onnxruntime_perf_test) by 0x14C6D9: real_main(int, char**) (in /code/onnxruntime/build/Linux/Release/onnxruntime_perf_test) by 0x145A2A: main (in /code/onnxruntime/build/Linux/Release/onnxruntime_perf_test) ``` add ptr to help release trtep provider options ### Motivation and Context --- onnxruntime/test/perftest/ort_test_session.cc | 2 
++ 1 file changed, 2 insertions(+) diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index ac25c98b15758..13082fe69cf48 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -170,6 +170,8 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device const auto& api = Ort::GetApi(); OrtTensorRTProviderOptionsV2* tensorrt_options; Ort::ThrowOnError(api.CreateTensorRTProviderOptions(&tensorrt_options)); + std::unique_ptr rel_trt_options( + tensorrt_options, api.ReleaseTensorRTProviderOptions); std::vector option_keys, option_values; // used to keep all option keys and value strings alive std::list buffer; From fd6bab4250c41a7f6498e6fa02ba446bc74e0a8d Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Thu, 11 Jan 2024 08:12:43 +0800 Subject: [PATCH 5/9] [js/webgpu] Provide a vectorized algorithm for GroupedConv (#18884) ### Description This PR provides a vectorized algorithm for NHWC GroupedConv to improve performance. The aggregate time of GroupedConv in mobilenetv2-12 becomes ~1ms from ~4ms on Intel Alder Lake machine. About 20% improvement for the whole model. --- .../lib/wasm/jsep/webgpu/ops/conv-grouped.ts | 99 +++++++++++- js/web/lib/wasm/jsep/webgpu/ops/conv.ts | 26 ++- js/web/test/data/ops/conv.jsonc | 152 +++++++++++++++++- 3 files changed, 271 insertions(+), 6 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index 14482272bad38..21b4953d3f90c 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -3,9 +3,9 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {ProgramInfo} from '../types'; +import {ProgramInfo, ProgramUniform} from '../types'; -import {inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common'; import {calculateOutputShape, ConvAttributes} from './conv'; import {getActivationSnippet} from './fuse-utils'; @@ -95,3 +95,98 @@ export const createGroupedConvProgramInfo = getShaderSource, }; }; + +export const createGroupedConvVectorizeProgramInfo = + (inputs: readonly TensorView[], attributes: ConvAttributes, outputShape: readonly number[]): ProgramInfo => { + const hasBias = inputs.length > 2; + const components = getMaxComponents(outputShape[3]); + const outputNumber = getMaxComponents(outputShape[2]); + const outputSize = ShapeUtil.size(outputShape) / components / outputNumber; + const xShape = [inputs[0].dims[0], inputs[0].dims[1], inputs[0].dims[2], inputs[0].dims[3] / components]; + const wShape = [inputs[1].dims[0], inputs[1].dims[1], inputs[1].dims[2], inputs[1].dims[3] / components]; + const outputShapeInShader = [outputShape[0], outputShape[1], outputShape[2], outputShape[3] / components]; + + const programUniforms: ProgramUniform[] = [ + {type: 'uint32', data: outputSize}, {type: 'int32', data: attributes.strides}, + {type: 'int32', data: attributes.pads}, ...createTensorShapeVariables(xShape), + ...createTensorShapeVariables(wShape), ...createTensorShapeVariables(outputShapeInShader) + ]; + const xNumber = (outputNumber - 1) * attributes.strides[1] + wShape[1]; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); + const 
{activationFunction, applyActivation} = getActivationSnippet(attributes, output.type.value); + const x = inputVariable('x', inputs[0].dataType, xShape.length, components); + const w = inputVariable('w', inputs[1].dataType, wShape.length, components); + const inputVars = [x, w]; + if (hasBias) { + inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims, components)); + } + const processBias = hasBias ? 'value += b[output_channel];' : ''; + + return ` + ${ + shaderHelper.registerUniform('output_size', 'u32') + .registerUniform('strides', 'i32', 2) + .registerUniform('pads', 'i32', 2) + .declareVariables(...inputVars, output)} + ${activationFunction} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + let width0 = uniforms.output_shape[3]; + let output_channel = global_idx % width0; + var index1 = global_idx / width0; + let width1 = uniforms.output_shape[2] / ${outputNumber}u; + let col = (index1 % width1) * ${outputNumber}u; + index1 = index1 / width1; + let row = index1 % uniforms.output_shape[1]; + let batch = index1 / uniforms.output_shape[1]; + + let x_corner = vec2(i32(row), i32(col)) * uniforms.strides - uniforms.pads; + + var x_vals: array<${x.type.value}, ${xNumber}>; + var values: array<${output.type.value}, ${outputNumber}>; + let input_channel = output_channel; + // Use constant instead of uniform can give better performance for w's height/width. + for (var w_height: u32 = 0u; w_height < ${wShape[0]}; w_height++) { + let x_height = x_corner.x + i32(w_height); + if (x_height >= 0 || u32(x_height) < uniforms.x_shape[1]) { + for (var i = 0; i < ${xNumber}; i++) { + let x_width = x_corner.y + i; + if (x_width >= 0 && u32(x_width) < uniforms.x_shape[2]) { + x_vals[i] = ${x.get('batch', 'u32(x_height)', 'u32(x_width)', 'input_channel')}; + } else { + x_vals[i] = ${x.type.value}(0); + } + } + for (var w_width: u32 = 0u; w_width < ${wShape[1]}; w_width++) { + let w_val = ${w.get('w_height', 'w_width', '0', 'output_channel')}; + for (var i = 0u; i < ${outputNumber}u; i++) { + values[i] = fma(x_vals[i * ${attributes.strides[1]}u + w_width], w_val, values[i]); + } + } + } + } + + for (var i = 0u; i < ${outputNumber}u; i++) { + var value = values[i]; + ${processBias} + ${applyActivation} + ${output.set('batch', 'row', 'col + i', 'output_channel', 'value')}; + } + }`; + }; + + return { + name: 'GroupedConv-Vectorize', + shaderCache: { + hint: `${attributes.activationCacheKey};${components};${outputNumber};${xNumber};${wShape[0]};${wShape[1]}`, + inputDependencies: hasBias ? 
['rank', 'rank', 'type'] : ['rank', 'rank'] + }, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms + }), + getShaderSource, + }; + }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index 33a5db7ff6b25..cb40a9f08d2d7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -8,7 +8,7 @@ import {ComputeContext} from '../types'; import {createConv2DMatMulProgramInfo} from './3rd-party/conv2d_mm_webgpu'; import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu'; -import {createGroupedConvProgramInfo} from './conv-grouped'; +import {createGroupedConvProgramInfo, createGroupedConvVectorizeProgramInfo} from './conv-grouped'; import {InternalActivationAttributes, parseInternalActivationAttributes} from './fuse-utils'; import {createNaiveMatmulProgramInfo} from './matmul'; import {createTransposeProgramInfo} from './transpose'; @@ -136,12 +136,32 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut // check attributes // const hasPreluActivationWeights = false; /* TODO: add support for prelu activation weights */ + const isChannelsLast = attributes.format === 'NHWC'; if (attributes.group !== 1) { - context.compute(createGroupedConvProgramInfo(inputs, adjustedAttributes)); + if (isChannelsLast && inputs[1].dims[0] === attributes.group && inputs[1].dims[1] === 1 && + attributes.dilations[0] === 1 && attributes.dilations[1] === 1) { + const outputShape = calculateOutputShape( + inputs[0].dims, inputs[1].dims, attributes.dilations, adjustedAttributes.pads, attributes.strides, + isChannelsLast); + const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? + context.compute( + createTransposeProgramInfo(inputs[1], weightTransposeAttribute), + {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0]; + if (attributes.wIsConst && !context.kernelCustomData.wT) { + context.kernelCustomData.wT = transposedWeight; + } + const convInputs = [inputs[0], transposedWeight]; + if (inputs.length === 3) { + convInputs.push(inputs[2]); + } + context.compute( + createGroupedConvVectorizeProgramInfo(convInputs, adjustedAttributes, outputShape), {inputs: convInputs}); + } else { + context.compute(createGroupedConvProgramInfo(inputs, adjustedAttributes)); + } return; } - const isChannelsLast = attributes.format === 'NHWC'; const hasBias = inputs.length === 3; const inputHeight = inputs[0].dims[isChannelsLast ? 1 : 2]; const inputWidth = inputs[0].dims[isChannelsLast ? 
2 : 3]; diff --git a/js/web/test/data/ops/conv.jsonc b/js/web/test/data/ops/conv.jsonc index 2e8eaaba191d0..cc10df5864233 100644 --- a/js/web/test/data/ops/conv.jsonc +++ b/js/web/test/data/ops/conv.jsonc @@ -298,7 +298,157 @@ } ] }, - + { + "name": "conv - vectorize group - A", + "operator": "Conv", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [ + { "name": "kernel_shape", "data": [1, 1], "type": "ints" }, + { "name": "group", "data": 2, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0], + "dims": [1, 2, 3, 3], + "type": "float32" + }, + { + "data": [1.0, 2.0], + "dims": [2, 1, 1, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 18.0, 20.0, 22.0, 24.0, 26.0, 28.0, 30.0, 32.0, 34.0], + "dims": [1, 2, 3, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "conv - vectorize group - B", + "operator": "Conv", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [ + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, + 19.0, 20.0, 21.0, 22.0, 23.0, 0, 0, 0 + ], + "dims": [1, 3, 3, 3], + "type": "float32" + }, + { + "data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], + "dims": [3, 1, 2, 2], + "type": "float32" + }, + { + "data": [0.1, 0.2, 0.3], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [27.1, 37.1, 57.1, 67.1, 293.2, 319.2, 371.2, 397.2, 847.3, 889.3, 409.3, 428.3], + "dims": [1, 3, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "conv - vectorize group - C", + "operator": "Conv", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [ + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, + 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0 + ], + "dims": [1, 3, 3, 4], + "type": "float32" + }, + { + "data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], + "dims": [3, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [34, 44, 54, 74, 84, 94, 386, 412, 438, 490, 516, 542, 1122, 1164, 1206, 1290, 1332, 1374], + "dims": [1, 3, 2, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "conv - vectorize group - D", + "operator": "Conv", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [ + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" }, + { "name": "strides", "data": [2, 2], "type": "ints" } + ], + "cases": [ + { + "name": "T[0] strides = [2, 2]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, + 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0 + ], + "dims": [1, 3, 3, 4], + "type": "float32" + }, + { + 
"data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], + "dims": [3, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [34, 54, 386, 438, 1122, 1206], + "dims": [1, 3, 1, 2], + "type": "float32" + } + ] + } + ] + }, { "name": "conv - pointwise", "operator": "Conv", From 5678317bafd219e2b71c72431905b776460e11a4 Mon Sep 17 00:00:00 2001 From: Yvonne Chen Date: Thu, 11 Jan 2024 10:36:33 +0800 Subject: [PATCH 6/9] Fix the duplicated QDQ attributes setup issue (#18039) ### Description The copied QDQ node should have exactly the same attributes as the original QDQ node. Otherwise, it might cause errors when the original node has attributes that use non default values (such as axis != 1 case). An example user case is like: A DequantizeLinear node has more than 1 consumer in the graph, and its attributes axis is 0. ### Motivation and Context I see the errors like https://github.com/microsoft/onnxruntime/issues/16188 and this fix could solve the issue. --- .../ensure_unique_dq_for_node_unit.cc | 2 +- .../ensure_unique_dq_for_node_unit_test.cc | 40 ++++++++++++++++++ .../qdq_with_multi_consumer_q_dq_axis.onnx | Bin 0 -> 9361 bytes 3 files changed, 41 insertions(+), 1 deletion(-) create mode 100644 onnxruntime/test/testdata/qdq_with_multi_consumer_q_dq_axis.onnx diff --git a/onnxruntime/core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.cc b/onnxruntime/core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.cc index cc0f7854791d4..9d53e28921784 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.cc @@ -53,7 +53,7 @@ Status DuplicateDQForOutputEdge(const graph_utils::GraphEdge& original_dq_output MakeString("Added by ", kTransformerName), dq_inputs, {&new_dq_output_nodearg}, - nullptr, // attributes + &original_dq_node.GetAttributes(), original_dq_node.Domain()); // set up edges diff --git a/onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc b/onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc index 7a67747f7cf4c..89ffb8ec87dcb 100644 --- a/onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc +++ b/onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc @@ -234,4 +234,44 @@ TEST(EnsureUniqueDQForNodeUnitTests, QDQWithMultiConsumerDQNodes) { EXPECT_EQ(OpCount(op_count_before, "DequantizeLinear") + 4, OpCount(op_count_after, "DequantizeLinear")); } +TEST(EnsureUniqueDQForNodeUnitTests, QDQWithMultiConsumerDQNodesPreservingAttributes) { + constexpr auto model_uri = ORT_TSTR("testdata/qdq_with_multi_consumer_q_dq_axis.onnx"); + + SessionOptions session_options{}; + // test interaction with level 1 transformers + session_options.graph_optimization_level = TransformerLevel::Level1; + + InferenceSessionWrapper session{session_options, GetEnvironment()}; + + ASSERT_STATUS_OK(session.Load(model_uri)); + + const auto op_count_before = CountOpsInGraph(session.GetGraph()); + + ASSERT_STATUS_OK(session.Initialize()); + + const auto op_count_after = CountOpsInGraph(session.GetGraph()); + + EXPECT_EQ(OpCount(op_count_before, "DequantizeLinear") + 8, OpCount(op_count_after, "DequantizeLinear")); + + int64_t given_axis = 0; // all the following 4 DQ nodes and their duplicated one should have axis = 0 + std::string axis_dq_name0 = "Convolution28_Output_0/fusedmuladd_B/DequantizeLinear"; + std::string axis_dq_name1 = "Parameter5/DequantizeLinear"; + std::string axis_dq_name2 = 
"Convolution110_Output_0/fusedmuladd_B/DequantizeLinear"; + std::string axis_dq_name3 = "Parameter87/DequantizeLinear"; + for (const auto& node : session.GetGraph().Nodes()) { + if (node.OpType() == "DequantizeLinear") { + if (node.Name().find(axis_dq_name0) == 0 || + node.Name().find(axis_dq_name1) == 0 || + node.Name().find(axis_dq_name2) == 0 || + node.Name().find(axis_dq_name3) == 0) { + const auto& attrs = node.GetAttributes(); + ASSERT_TRUE(attrs.find("axis") != attrs.end()); + const auto& axis_attr = attrs.at("axis"); + int64_t axis = axis_attr.i(); + EXPECT_EQ(axis, given_axis); + } + } + } +} + } // namespace onnxruntime::test diff --git a/onnxruntime/test/testdata/qdq_with_multi_consumer_q_dq_axis.onnx b/onnxruntime/test/testdata/qdq_with_multi_consumer_q_dq_axis.onnx new file mode 100644 index 0000000000000000000000000000000000000000..4f575ebb2841a02802754e2a449a05ad58220afc GIT binary patch literal 9361 zcmcgy2~<;O+Rh5$W)CDJA?&LG*&)W0gi3^9sdLUR=N`E4eYfv@p6|Q& z`(9D*B+sm^+p_|;?Ap3DVB68W1MeUBDEAm0vTA%2a0?2Mw`U*Aem^%a_gD%m=7SU2 zhx5S4`v9V|`1m{7hjIamvj}u?Wbe@<2M*_v7|!x_xj*}O<-p|^w2J`Ex))Mugr zMD~XVj#DA|P*^-6b0(pkzdxG0DgX?X!d?Xd3ZQ`3lZZ)9DHEKevl-8iEh)w+VIL>O zh4JandiXr)7Q%p-wMP#B?#Q7Nc?XUhj*E-kyX{2Y(Gz)lV`JVuaXfe5`zH=%@7uRG z-DxC|%QKiAWbE>koyMQMJc~u8c}}}usv7ZWkQ3r(Q2q%9L_4J!pYqF9V=ha7wrI|f zEbZH=W=YF(zNlo+XZBnb%Ld#w9|n1q0AOZ+$IK3uICtkX7}T67lm6tdAlzdVz_~v= zZ(UBk_kbt*zZ}G4F37q}`uv}ZN9HonC={5t0iHD_f#<6*3BM!vl~)tupRkIKeG1?p z#LO*|2B5M}Yt5GRdtcqYef^GnD4a5VD6rV_c~T-RObMie3)wPx@lPn=!k9$be&ooZ1Bc&B zPM*ip$2Qo1ZC&&|9q7u6>w^Z7F*yKo1>>RP-^WNkBF zB8#(JmkB)c0|I!?cM^#(!^^n1#3z8!aZU>W_?G^EUID&>0*g`tC<~LEW021xxf6dL zPY~yk#mvpPr)A6H&dIs?_O$8q<>k33X^=Cngv(?MxJ0?KCV9@$C@y*aON$Q73xD-; z08fMe@^Ub%1TF_Zf@#eMdqOXO=0Z3Jr=P1wPKf?IF`P#Pg6dklEL<GMo}T*n%0RIG$YnFBp!MR< zuo(+~(8BeLelBnzX!?v-{)~@c(8tp3nkz&o=;s%BHskAc6an>Fpcus33p**z_9vIlk?VgKcC3GA=p|C{}n;JxfU0jt;|>Tb4Q9)``kafp3b z7Rk0%Y-2N{LD%r~B{q>6geff9HjH$KTBZQ4n?j3Xo~NM*C=?2bLV{-VJ2~fFY6h1! 
zDxd81@oFqE;@vGnE#3Gj&sy>Hx`dHVZoyBUWcvd(bTlf>*gmd~ixnbT8FT={ZzuV- zi^^Ju`m2je`f)dVPqpjet!RDFV1mc5}o<>aJ;r*+_Uh)kBgHHtuI@ZlBgt z#+wP!;-7}FhJ+Z$3y}wQ?DXj#RX5&JVdF?d8N%M4#4Q@3q2a9xgEc}x@XJzs@LG$a zW821XRBgm>12)`ifJm4=5F}O_OX|beQv8RsoWtu_n_k7a;S&uBFvA!m$|rGU`lDV* z@b>nKv*XaW;}N60cY?w*HpMXPQSQ?)?KI6_*?dK|I>Dm)VI?bgN1_)+8aw4a2)MgN z_-G}yal7L@S~P(>R~&}IiMvwZb_%Ri(>_>J+3D!1sh95ZAyXsJUE(f$2*qs=EOg&T znWvjgDBqa2R&T<7tYoCoT11lyM;pBfBa4N|qLG>c1hAU(y60DSqi?WPW=%hEhJIadW z;z#0P#zRZt!KbtSnWsbA=_3Fin>?oG$RR#gH4QnkItIjM{R?b zR-}5O$|6|y4c+{{<6hb;1BB6b_5IfK;&Bl+>J|Sg^|-O;-lu23jiKK;%@v5?VWHhO zZt+yD8P${FFlN&ll>Wjar`nzyq=NeCp8g7&} zNzCmYLW`$b)G8VH?$dusu~8xE);J`)eW0#hQ;H3#|Mcv*rT&(nN@{LwDlHtf=}9$a z*u;3{w^OK-HC1(e(N@L7f=J@P!|%!(OWNyy653HEZf^(olxFRVc#ESCQfUD>?E<&f zQr}Q{!Q3SnIUn+7fU-XVQCIc@=C8R zVfuuqh?Jaz^6tW#)dBJewM`8d=xuie-vZSTQG4Ir$;2>U-s$KKI9LUQl|U7F;B9so z(LJ>7hfy88{a@89==~1{4g>}wJwtVLB4kp;Tnh)HXg!{Zto|MyQ#GYE)pSN65(`ZuR^iDM5-{){DVa>hK zyL{uQLJwC2cZ9E0U;&J>E2PCu2VW}`8wVn#!^S|4s;ygLZn-N*Cs3MXRvLt_NPGRX zdqQkiqY|&J;j8uB@<3D9=q7x>ikga1kGW0bjYB@}D|KW8k%mFK#f-t+NRh21#`H@T zMVB8j(px%-3-oWl-t}_m;3LaEL_bs#DJC9t8yElH582gO(b|N+)7K5{xq{j9k*GVC z?0Ey`4@{wh?UQ0spuFd*`%!FlL=sz)<%_fLA;yVFZNj&c-OAdx$6ZuLa9$|$)#qL7 zos30{>QTOiqB9h@^r4bGaW08Iar?m8P{dh)_I4=(-K-6;H+Yu3fnqSn1q-=d{@mW^ER`d(3frc!hMPk$3O zYRjy_s6r1TUtJL!0^#?ve1}aBs11p@Q+`-f`{3&IlyXLr?aufH=;bD3#OIALR++jI zj&C*#|uJFrbvXvZ?DRvJEQS65b$4&J5uJO_tR0)3d6P|?wo}>2~y;|TTf>APL zlb!Mo2(PY*^lU3^R7wkqgcVkVskLduVdBaA3~F<|6&bOM7QcFh(m*ol){Mu6wRrS4 zC`Yju2_ks=uxivOF3O8aZ|DkRCjUg}G9+%;ZliKL)QQS6XuNuWLbq+C!8 zk$%Ay8~tM3YHe!2ydGL}Xcmlb4df6y+NV?OG*hp--#^h{s_WNRivA()m4)*k2wJF+ zG#{Nt6&zXVC->lspghE-n^3Xlkw$gHy|z!P20U8@ba7vSK3XmonD$Es&9D%%HNU*m z7cp{mR4y4TZLznvw)*O8dvWn)MPJpem6|o`OfAz+c%8OAjQL(Xf!Yu89-$ynqimB9 z;*YXgtBJ*((Y3noN*kyB!)(Zb*q2(>y2)NoFGsoO3kj3n`=ir1ME!`^64ZerSYsZl zwf>vojL`8Udi**M4ZpD(qHG&86Rgk)VqkCRI`4<6dp@e?0+#;1WX$wbu$w=0g;(UL zr&(vbE54;V3hHIp$AQUi2%QAWN27B27?f{>;U9%?Ot(f z>=HexsJBGkD3~-=Q&(=#A5* z=0VhCQ^ORTBD}1JMHyuVbH%mdPLZYX%luo7Llsp*{B%&?0138#%0fNF8MBzx4SFoB z)6#;Pm||-?WxW`Kt!z{l>JVxi6;NEg#pFgFARF6lQv%JTwvJ2hX2y=8r;HQRQcY`} z3{uN$Xe45!VnJ*y@?4qGMl3t8ZSNM|gidJdZ&(Iy@+{OB`S-=SzcDa03VBY{n{COuo;QK zkMKzN@+%#X5_EGzH*bwQPO6lvsvQI0D0m*jl=|YL|Jg75K~^S)k!YK5|NCy+7_x0J zj*y6$4yU|oGwfO07K}?<$LrEuyKuGE92V2xAnwx-6zBi#8V{o7^9gF60a8?Oxz*3o z_7~okuf-a zs(5(+YQ&Wr`r@9+o{4h#8Zx1)Uc1*zYQb?3H}O8q3|2^W`I*}%FY#m&Wn=RNQ8nKF zRbhW2gw*YIK4}ZIv2XQD5MxT3i8m0zz)u=F-VL+V8906#P2|BTp(ZE|c6Taq;!&Gn z3`1(cQV02#=wiWrS!5$$E^oR#VnX-ldo!rgvFM1>A^ZNo5&>*C>gHH3cCZwF>d!Vu z$4_D(pvFw1YP3BRavw}XQeyb!7tRl1Mjt{t(O3$ckSIhdqx269N_9_vd5J~}Ttbkb z;dSJs(B|Y9|*c9g84KRHCE)`+pH*oz7l3}%*d+l&7Re~8{rl3&Zj%iWvDI-8s zEM(ug`M+#@eFvKsx`+M2rhN9k!Q1TUl8@Pk4t>IAhJu#PI~(U7RR&6iAma^g1`Bfp z3hsiSbNe8dov(~~XkYO^J57K7SJ!TjuZB3R({-#DZ*vFXkG{D*#rp6)yrvq?$&o+k z+x-Ujs^oEvxmbv&b?H1PhjrvbD-~CD=;^41@^SGlc+?- z&PYxsX8i{rQn#L{>o(N1B7NIxxT=lXFS$L0sA-)s^f)J2-Jl=@g#XArReoT77x)dw z+3cq$z8Lpr42Dz6eYExc&;EM%jn&*sy)O{zPd8@-A;~I(7#FkIe@FOQjwGB4r#eJE z4cu|wmxVm$@y8p`U(*1~NNb4qNvT+aNO?<1OSxBsO<6N(Wvxu4rc{_hQr_j7Sy)DV zN)?}$^4-gjl(-466mEV6>tRT6${M|fbvY5A!u${n>X!bi8)+G7kpJ)>^S>TPcJ4+k zr$LyJmT7i4{-k#}rnL@7rrO~+170oQbq|zL1|5zpmBYaqa5y%Y9FE^XA&{3G4#(@A z;5{h+ddlHo?s!bwU(qtl2!Mam%G$Z=Ggj=$P8P?yhxMW|pY`rHcUYvPJkbs5Qb=u*uV<8Zu)#0cFi(>){KtnmI5TBanK|A|9Ombc7N2te9PsHg<{Yxi)7=@d zi(^KofKQ)i=1{@^wlLdPcPFgrQK^+X5r^w@U!41t^sx#?6CjP--C+awj>8aUqaK$1@El`~) z!K}YX7gSpid|o-q6$O`c4QekYp$lysm<;YMp47LVbo>h=cP3eQKs;x=3f%Aa<##wN!|71OOqB{tq3W Bly?9C literal 0 HcmV?d00001 From 053ddfe3fd52135742567940750b7bf6ccffe166 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Wed, 10 
Jan 2024 18:45:49 -0800 Subject: [PATCH 7/9] Disable per-session thread pool for web (#18480) ### Description ORT web prefers to use a global thread pool for all inference sessions. See how OrtCreateSession is implemented in https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/wasm/api.cc#L183 . Application code can only the global thread poo. However, internal testing code still often use per-session threadpool. This pr is to fix the inconsistency. ### Motivation and Context Replace PR #18476 --- onnxruntime/core/framework/session_options.h | 8 +++++++- .../test/framework/inference_session_test.cc | 18 ++++++++++++++++++ .../cpu/activation/activation_op_test.cc | 3 +++ 3 files changed, 28 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index 40c59cfcf699d..796a018ac0f68 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -65,6 +65,11 @@ struct FreeDimensionOverride { * Configuration information for a session. */ struct SessionOptions { +#if defined(__wasm__) && defined(__EMSCRIPTEN_PTHREADS__) + static constexpr bool DEFAULT_USE_PER_SESSION_THREADS = false; +#else + static constexpr bool DEFAULT_USE_PER_SESSION_THREADS = true; +#endif ExecutionMode execution_mode = ExecutionMode::ORT_SEQUENTIAL; // set the execution order of the graph @@ -129,7 +134,8 @@ struct SessionOptions { // By default the session uses its own set of threadpools, unless this is set to false. // Use this in conjunction with the CreateEnvWithGlobalThreadPools API. - bool use_per_session_threads = true; + bool use_per_session_threads = DEFAULT_USE_PER_SESSION_THREADS; + bool thread_pool_allow_spinning = true; // Deterministic compute is likely not as performant. This option is default to false. diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 486ec37d1eebd..2522ee3b496f6 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -578,6 +578,9 @@ TEST(InferenceSessionTests, ModelMetadata) { } #endif TEST(InferenceSessionTests, CheckRunLogger) { + if constexpr (!SessionOptions::DEFAULT_USE_PER_SESSION_THREADS) { + GTEST_SKIP() << "Skipping the test"; + } SessionOptions so; so.session_logid = "CheckRunLogger"; @@ -837,6 +840,9 @@ TEST(InferenceSessionTests, PreAllocateOutputVector) { } TEST(InferenceSessionTests, ConfigureVerbosityLevel) { + if constexpr (!SessionOptions::DEFAULT_USE_PER_SESSION_THREADS) { + GTEST_SKIP() << "Skipping the test"; + } SessionOptions so; so.session_logid = "ConfigureVerbosityLevel"; @@ -2661,6 +2667,9 @@ class InferenceSessionTestSharingAllocator : public InferenceSessionWrapper { // Ensure sessions use the same allocator. It uses ORT created allocator. TEST(InferenceSessionTests, AllocatorSharing_EnsureSessionsUseSameOrtCreatedAllocator) { + if constexpr (!SessionOptions::DEFAULT_USE_PER_SESSION_THREADS) { + GTEST_SKIP() << "Skipping the test"; + } auto logging_manager = std::make_unique( std::unique_ptr(new CLogSink()), logging::Severity::kVERBOSE, false, LoggingManager::InstanceType::Temporal); @@ -2706,6 +2715,9 @@ TEST(InferenceSessionTests, AllocatorSharing_EnsureSessionsUseSameOrtCreatedAllo // Ensure sessions don't use the same allocator. It uses ORT created allocator. 
TEST(InferenceSessionTests, AllocatorSharing_EnsureSessionsDontUseSameOrtCreatedAllocator) { + if constexpr (!SessionOptions::DEFAULT_USE_PER_SESSION_THREADS) { + GTEST_SKIP() << "Skipping the test"; + } auto logging_manager = std::make_unique( std::unique_ptr(new CLogSink()), logging::Severity::kVERBOSE, false, LoggingManager::InstanceType::Temporal); @@ -2758,6 +2770,9 @@ class InferenceSessionTestSharingInitializer : public InferenceSessionWrapper { }; TEST(InferenceSessionTests, InitializerSharing_EnsureSessionsUseUserAddedInitializer) { + if constexpr (!SessionOptions::DEFAULT_USE_PER_SESSION_THREADS) { + GTEST_SKIP() << "Skipping the test"; + } auto logging_manager = std::make_unique( std::unique_ptr(new CLogSink()), logging::Severity::kVERBOSE, false, LoggingManager::InstanceType::Temporal); @@ -2942,6 +2957,9 @@ TEST(InferenceSessionTests, GlobalThreadPoolWithDenormalAsZero) { // test inter thread pool with setting denormal as zero TEST(InferenceSessionTests, InterThreadPoolWithDenormalAsZero) { + if constexpr (!SessionOptions::DEFAULT_USE_PER_SESSION_THREADS) { + GTEST_SKIP() << "Skipping the test"; + } // test if denormal-as-zero mode is supported if (!SetDenormalAsZero(false)) { return; diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc index 7ec9e0f345187..ddb0a6620619c 100644 --- a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc +++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc @@ -588,6 +588,9 @@ TEST_F(ActivationOpTest, Softplus) { } TEST_F(ActivationOpNoInfTest, Softsign) { + if constexpr (!SessionOptions::DEFAULT_USE_PER_SESSION_THREADS) { + GTEST_SKIP() << "Skipping the test"; + } // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 1, which exceeds threshold"; From 0a0ef958eb94d94a17fd8c09a6d217e0589827a2 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 10 Jan 2024 19:26:01 -0800 Subject: [PATCH 8/9] update .vscode/settings.json (#19084) ### Description `"explicit"` now replaced `true` to config entry "source.organizeImports". Latest VSCode will automatically modify this config. --- .vscode/settings.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 2f2adc78f6de9..3e2b1f31dd6cf 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -11,7 +11,7 @@ // Auto sort imports "editor.formatOnSave": true, "editor.codeActionsOnSave": { - "source.organizeImports": true + "source.organizeImports": "explicit" }, "editor.defaultFormatter": "ms-python.black-formatter" }, From d03e477b9026a97d22dba64cd00b4614603671e5 Mon Sep 17 00:00:00 2001 From: pengwa Date: Thu, 11 Jan 2024 12:50:55 +0800 Subject: [PATCH 9/9] Fix missing subgraph candidates for recompute (#19077) ### Fix missing subgraph candidates for recompute For subgraphs for example `MatMul+Transpose+Reshape`, since the ending node is a Reshape, in ORT, it is reusing input buffers. Currently, the subgraph detection logic has defect, as a result, those subgraphs will be missing as recompute candidates. Also append a few more node types for recompute support. TODO: add unit test later. This PR is needed for a customer model now. 
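For illustration, a hypothetical PyTorch module (not taken from this PR) whose exported ONNX graph ends a recomputable subgraph with a Reshape, the buffer-reuse case described above:

```python
# Hypothetical module (not from this PR): its exported ONNX graph contains a
# MatMul -> Transpose -> Reshape chain, where the Reshape output reuses its input buffer in ORT.
import torch


class MatMulTransposeReshape(torch.nn.Module):
    def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        y = torch.matmul(x, w)               # MatMul
        y = y.transpose(1, 2)                # Transpose
        return y.reshape(y.shape[0], -1)     # Reshape ends the subgraph
```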
--- .../memory_optimizer/memory_insight.cc | 34 +++++++++----- .../memory_optimizer/optimization_planner.cc | 29 ------------ .../memory_optimizer/optimization_planner.h | 2 +- .../memory_optimizer/recompute_analysis.cc | 3 ++ .../memory_optimizer/recompute_analysis.h | 45 +++++++++++++++++++ .../ortmodule/_graph_execution_manager.py | 2 +- .../python/training/ortmodule/options.py | 9 ++++ 7 files changed, 83 insertions(+), 41 deletions(-) diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc index 9b77832abb6f1..3fbdd5da7b768 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc @@ -485,12 +485,15 @@ void ListAllCombinations(const InlinedVector> new_combination = current_combination; - new_combination.push_back(plan); - ListAllCombinations(all_possible_node_optimization_plans, index + 1, new_combination, logger, all_combinations); - } + const InlinedVector>>& + plan_combination_list_at_cur_index = all_possible_node_optimization_plans[index]; + // For the index-th reused buffer, iterate all possible complete plans. + for (size_t i = 0; i < plan_combination_list_at_cur_index.size(); ++i) { + const auto& plan_combination = plan_combination_list_at_cur_index[i]; + InlinedVector> new_combination = current_combination; + // Append the chosen complete plan and continue exploring the next reused buffer by index + 1. + new_combination.insert(new_combination.end(), plan_combination.begin(), plan_combination.end()); + ListAllCombinations(all_possible_node_optimization_plans, index + 1, new_combination, logger, all_combinations); } MO_LOG_DEBUG_INFO(logger, "Exit ListAllCombinations"); @@ -520,17 +523,28 @@ void IterateNodeOptimizationPlan(const std::shared_ptr } InlinedVector>>> - all_possible_node_optimization_plans; - all_possible_node_optimization_plans.resize(plan->reuse_buffers.size()); + all_possible_node_optimization_plans(plan->reuse_buffers.size()); size_t i = 0; for (const auto& p : plan->reuse_buffers) { MO_LOG_DEBUG_INFO(logger, ">>>reuse buffer: " + std::to_string(p.first)); - IterateNode(p.second.first, node_to_optimization_plans_map, {}, logger, all_possible_node_optimization_plans[i]); + // If the resued node is part of current node optimization plan, then we just add current combination to the result. 
+ if (plan->GetOptimizationType() == OptimizationType::RecomputeWithCompromise || plan->GetOptimizationType() == OptimizationType::Recompute) { + const auto& recompute_subgraph = + dynamic_cast(plan.get())->GetNodesInTopoOrder(); + if (std::find(recompute_subgraph.begin(), recompute_subgraph.end(), p.second.first) != recompute_subgraph.end()) { + all_possible_node_optimization_plans[i].push_back(current_combination); + } + } + + if (all_possible_node_optimization_plans[i].size() == 0) { + IterateNode(p.second.first, node_to_optimization_plans_map, current_combination, logger, all_possible_node_optimization_plans[i]); + } + ++i; } - ListAllCombinations(all_possible_node_optimization_plans, 0, current_combination, logger, all_combinations); + ListAllCombinations(all_possible_node_optimization_plans, 0, {}, logger, all_combinations); MO_LOG_DEBUG_INFO(logger, "Exit IterateNodeOptimizationPlan: " + plan->GetClusterId()); } diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.cc index 64e99a4a0bca5..4ce896c5350b0 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.cc @@ -15,35 +15,6 @@ namespace onnxruntime::optimizer::memory_optimizer { -std::string NodeOptimizationPlanBase::GetMemorySavingSymbolicString() const { - std::string saving_str; - for (auto output_index : activation_output_indices_) { - // If the output is reusing other node's buffer, then no memory saving. - if (reuse_buffers.find(output_index) != reuse_buffers.end()) { - continue; - } - - const auto& output_def = node->OutputDefs()[output_index]; - MLDataType ml_data_type = DataTypeImpl::TypeFromProto(*output_def->TypeAsProto()); - ORT_ENFORCE(ml_data_type->IsTensorType(), "ml_type must be a tensor type, but it is ", - DataTypeImpl::ToString(ml_data_type)); - const TensorTypeBase* tensor_type_base = ml_data_type->AsTensorType(); - ORT_ENFORCE(nullptr != tensor_type_base); - MLDataType elt_type = tensor_type_base->GetElementType(); - const auto byte_count_per_element = elt_type->Size(); - if (!saving_str.empty()) { - saving_str += " + "; - } - saving_str = "(" + GetActivationOutputDimParamString(output_index) + " * " + - std::to_string(byte_count_per_element) + " * " + - std::to_string(GetSaveRatio()) + ")"; - } - if (saving_str.empty()) { - return saving_str; - } - return "(" + saving_str + ")"; -} - Status MemoryOptimizationPlanner::UpdateNodePlansFromExecutionPlan(const GraphViewer& graph_viewer, const OrtValueNameIdxMap& ortvalue_name_to_idx_map, const SequentialExecutionPlan& p_seq_exec_plan) { diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.h b/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.h index c585b2810b39d..789f530b29f1d 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.h @@ -83,7 +83,7 @@ class NodeOptimizationPlanBase { /** * Get a symbolic string to represent the memory saving for this optimization plan. 
   */
-  std::string GetMemorySavingSymbolicString() const;
+  virtual std::string GetMemorySavingSymbolicString() const = 0;
 
   std::string GetActivationOutputDimParamString(size_t index) const {
     ORT_ENFORCE(activation_output_dim_params_.find(index) != activation_output_dim_params_.end(),
diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc
index 52dea571a1eaf..12c83591c0036 100644
--- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc
+++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc
@@ -72,12 +72,14 @@ const InlinedHashMap<std::string, AllowedRecomputeNodeConfig>& GetAllowedRecompu
     {"Add", AllowedRecomputeNodeConfig{{0, 1}}},
     {"BiasGelu", AllowedRecomputeNodeConfig{{0, 1}}},
     {"Div", AllowedRecomputeNodeConfig{{0, 1}}},
+    {"Equal", AllowedRecomputeNodeConfig{{0, 1}}},
     {"Mul", AllowedRecomputeNodeConfig{{0, 1}}},
     {"Sub", AllowedRecomputeNodeConfig{{0, 1}}},
 
     // Data layout
     /// The shape input is trivial whether it exists or not in backward.
     {"Reshape", AllowedRecomputeNodeConfig{{0}}},
+    {"Shape", AllowedRecomputeNodeConfig{{0}}},
     {"Squeeze", AllowedRecomputeNodeConfig{{0}}},
     {"Transpose", AllowedRecomputeNodeConfig{{0}}},
     {"Unsqueeze", AllowedRecomputeNodeConfig{{0}}},
@@ -92,6 +94,7 @@ const InlinedHashMap<std::string, AllowedRecomputeNodeConfig>& GetAllowedRecompu
     {"Expand", AllowedRecomputeNodeConfig{{0}}},
     {"FastGelu", AllowedRecomputeNodeConfig{{0}}},
     {"Gelu", AllowedRecomputeNodeConfig{{0}}},
+    {"QuickGelu", AllowedRecomputeNodeConfig{{0}}},
 
     // Ternary elementwise
     {"Where", AllowedRecomputeNodeConfig{{0, 1, 2}}},
diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h
index d9693835313b8..ab114d970191e 100644
--- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h
+++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h
@@ -86,6 +86,51 @@ class NodeRecomputePlan : public NodeOptimizationPlanBase {
 
   std::string GetNodesInTopoOrderStr() const;
 
+  std::string GetMemorySavingSymbolicString() const override {
+    std::string saving_str;
+    for (auto output_index : GetActivationOutputIndices()) {
+      // If the output reuses another node's buffer, there is no memory saving.
+      std::string cur_output_saving_str;
+
+      bool is_reused = reuse_buffers.find(output_index) != reuse_buffers.end();
+      bool is_src_node_in_cur_node_subgraph = false;
+      if (is_reused) {
+        // Here we assume src_node is the real owner of the buffer, so we don't need to trace further.
+        const auto* src_node = reuse_buffers.at(output_index).first;
+        is_src_node_in_cur_node_subgraph = std::find(nodes_in_topological_order_.begin(),
+                                                     nodes_in_topological_order_.end(),
+                                                     src_node) != nodes_in_topological_order_.end();
+      }
+
+      if (!is_reused || is_src_node_in_cur_node_subgraph) {
+        // When is_src_node_in_cur_node_subgraph is true, still use this output to calculate the saving, because
+        // the reused buffer has the same size.
+        const auto& output_def = node->OutputDefs()[output_index];
+        MLDataType ml_data_type = DataTypeImpl::TypeFromProto(*output_def->TypeAsProto());
+        ORT_ENFORCE(ml_data_type->IsTensorType(), "ml_type must be a tensor type, but it is ",
+                    DataTypeImpl::ToString(ml_data_type));
+        const TensorTypeBase* tensor_type_base = ml_data_type->AsTensorType();
+        ORT_ENFORCE(nullptr != tensor_type_base);
+        MLDataType elt_type = tensor_type_base->GetElementType();
+        const auto byte_count_per_element = elt_type->Size();
+        cur_output_saving_str = GetActivationOutputDimParamString(output_index) + " * " +
+                                std::to_string(byte_count_per_element) + " * " +
+                                std::to_string(GetSaveRatio());
+      } else {
+        cur_output_saving_str = "0";
+      }
+
+      if (!saving_str.empty()) {
+        saving_str += " + ";
+      }
+
+      saving_str += "(" + cur_output_saving_str + ")";
+    }
+
+    ORT_ENFORCE(!saving_str.empty(), "saving_str should not be empty for node: ", node->OpType(), " ", node->Name());
+    return "(" + saving_str + ")";
+  }
+
  private:
   bool compromise_recompute_;
   InlinedVector<const Node*> nodes_in_topological_order_;
diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
index 76943b954837b..853eab61b4bd6 100755
--- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
+++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
@@ -243,7 +243,7 @@ def _get_session_config(self):
         # requires PRIORITY_BASED order to work properly. So we use PRIORITY_BASED order when recompute is enabled.
         session_options.execution_order = (
             onnxruntime.ExecutionOrder.PRIORITY_BASED
-            if self._runtime_options.memory_optimizer_config != ""
+            if self._runtime_options.memory_optimizer_is_enabled()
            else onnxruntime.ExecutionOrder.DEFAULT
         )
         # 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
diff --git a/orttraining/orttraining/python/training/ortmodule/options.py b/orttraining/orttraining/python/training/ortmodule/options.py
index a93f6413b7ab4..bfa38efb349ae 100644
--- a/orttraining/orttraining/python/training/ortmodule/options.py
+++ b/orttraining/orttraining/python/training/ortmodule/options.py
@@ -399,3 +399,12 @@ def _override_from_env_vars(self):
 
         if "ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT" in os.environ:
             self.deepcopy_before_model_export = int(os.getenv("ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT")) == 1
+
+    def memory_optimizer_is_enabled(self) -> bool:
+        """Check whether memory optimizer is enabled."""
+        if self.memory_optimization_level == _MemoryOptimizationLevel.USER_SPECIFIED:
+            return len(self.memory_optimizer_config) > 0
+        elif self.memory_optimization_level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE:
+            return True
+
+        return False
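Note: the sketch below is not part of the diff. It illustrates, under stated assumptions, how the new memory_optimizer_is_enabled() helper is meant to drive the execution-order choice in _get_session_config(). Only the option and enum names come from the patch; the concrete enum values, the _RuntimeOptionsSketch class, and the pick_execution_order helper are hypothetical placeholders for illustration.

# Illustrative sketch only -- mirrors the logic added in options.py and consumed
# by _graph_execution_manager.py. Enum values are assumed for the example.
import onnxruntime


class _MemoryOptimizationLevel:
    USER_SPECIFIED = 1                   # assumed value
    TRANSFORMER_LAYERWISE_RECOMPUTE = 2  # assumed value


class _RuntimeOptionsSketch:
    def __init__(self, level, memory_optimizer_config=""):
        self.memory_optimization_level = level
        self.memory_optimizer_config = memory_optimizer_config

    def memory_optimizer_is_enabled(self) -> bool:
        """Same decision logic as the method added to options.py."""
        if self.memory_optimization_level == _MemoryOptimizationLevel.USER_SPECIFIED:
            # User-specified mode is only active when a config string was provided.
            return len(self.memory_optimizer_config) > 0
        if self.memory_optimization_level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE:
            # Layer-wise recompute is always on once this level is selected.
            return True
        return False


def pick_execution_order(options) -> "onnxruntime.ExecutionOrder":
    # Recompute plans rely on PRIORITY_BASED ordering, as in _get_session_config().
    return (
        onnxruntime.ExecutionOrder.PRIORITY_BASED
        if options.memory_optimizer_is_enabled()
        else onnxruntime.ExecutionOrder.DEFAULT
    )


# Example: layer-wise recompute enables the memory optimizer without any config string.
opts = _RuntimeOptionsSketch(_MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE)
assert opts.memory_optimizer_is_enabled()
print(pick_execution_order(opts))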