From c5276ac44874eb453485e73a7dca9ed3328d88c2 Mon Sep 17 00:00:00 2001
From: Dmitri Smirnov
Date: Mon, 11 Nov 2024 09:59:05 -0800
Subject: [PATCH 1/4] Revert "enable serialize prepacked weights into data file (#22256)" (#22788)

This reverts commit c5b6be045ff58b706390bd7504e7d865e451a689.

### Description
Revert

### Motivation and Context
This needs a simpler and more robust approach.

---
 .../onnxruntime/core/framework/op_kernel.h | 22 ---
 include/onnxruntime/core/graph/graph.h | 29 +--
 .../onnxruntime_session_options_config_keys.h | 6 -
 onnxruntime/contrib_ops/cpu/bert/attention.cc | 2 -
 .../cpu/quantization/attention_quant.cc | 2 -
 .../cpu/quantization/dynamic_quantize_lstm.cc | 3 +-
 .../cpu/quantization/matmul_nbits.cc | 56 ------
 .../contrib_ops/cpu/skip_layer_norm.cc | 1 -
 onnxruntime/contrib_ops/cpu/skip_layer_norm.h | 2 +-
 .../contrib_ops/cuda/diffusion/group_norm.cc | 1 -
 .../contrib_ops/cuda/diffusion/group_norm.h | 1 -
 .../qordered_ops/qordered_attention.cc | 1 -
 .../qordered_ops/qordered_attention.h | 1 -
 .../qordered_ops/qordered_matmul.cc | 1 -
 .../qordered_ops/qordered_matmul.h | 1 -
 onnxruntime/core/framework/session_options.h | 6 -
 onnxruntime/core/framework/session_state.cc | 85 ++------
 onnxruntime/core/framework/session_state.h | 33 +---
 .../core/framework/session_state_utils.cc | 13 +-
 .../core/framework/session_state_utils.h | 4 +-
 .../framework/tensor_external_data_info.cc | 2 -
 .../framework/tensor_external_data_info.h | 3 -
 .../core/framework/tensorprotoutils.cc | 29 +--
 onnxruntime/core/framework/tensorprotoutils.h | 12 +-
 onnxruntime/core/framework/utils.cc | 6 -
 onnxruntime/core/framework/utils.h | 2 -
 onnxruntime/core/graph/graph.cc | 175 ++++++------------
 onnxruntime/core/graph/model.cc | 29 +--
 onnxruntime/core/graph/model.h | 24 +--
 .../core/providers/cpu/fp16/fp16_conv.cc | 2 -
 onnxruntime/core/providers/cpu/math/gemm.cc | 3 +-
 onnxruntime/core/providers/cpu/math/gemm.h | 1 -
 onnxruntime/core/providers/cpu/math/matmul.cc | 1 -
 onnxruntime/core/providers/cpu/math/matmul.h | 1 -
 .../core/providers/cpu/nn/conv_transpose.cc | 2 -
 .../core/providers/cpu/nn/conv_transpose.h | 1 -
 .../core/providers/cpu/nn/layer_norm_impl.cc | 1 -
 .../core/providers/cpu/nn/layer_norm_impl.h | 2 +-
 .../cpu/quantization/matmul_integer_base.h | 1 -
 .../providers/cpu/quantization/qlinearconv.cc | 2 -
 .../core/providers/cpu/rnn/deep_cpu_gru.cc | 1 -
 .../core/providers/cpu/rnn/deep_cpu_gru.h | 3 +-
 .../core/providers/cpu/rnn/deep_cpu_lstm.cc | 4 +-
 .../core/providers/cpu/rnn/deep_cpu_lstm.h | 1 -
 onnxruntime/core/providers/cuda/nn/conv.cc | 1 -
 onnxruntime/core/providers/cuda/nn/conv.h | 1 -
 .../core/providers/cuda/nn/conv_transpose.cc | 3 +-
 .../core/providers/cuda/nn/conv_transpose.h | 1 -
 .../core/providers/js/operators/conv.h | 1 -
 .../providers/js/operators/conv_transpose.h | 2 -
 .../core/providers/xnnpack/math/gemm.cc | 1 -
 .../core/providers/xnnpack/math/gemm.h | 1 -
 .../core/providers/xnnpack/math/matmul.cc | 1 -
 .../core/providers/xnnpack/math/matmul.h | 1 -
 onnxruntime/core/providers/xnnpack/nn/conv.cc | 1 -
 onnxruntime/core/providers/xnnpack/nn/conv.h | 1 -
 .../providers/xnnpack/nn/conv_transpose.cc | 1 -
 .../providers/xnnpack/nn/conv_transpose.h | 1 -
 onnxruntime/core/session/inference_session.cc | 40 +---
 .../test/framework/inference_session_test.cc | 54 ------
 .../save_model_with_external_initializers.cc | 59 +-----
 .../test/framework/session_state_test.cc | 66 +------
 onnxruntime/test/shared_lib/test_inference.cc | 83 ---------
 .../model_with_external_initializers.onnx | 9
+- .../model_with_external_initializers.py | 3 +- .../testdata/model_with_orig_ext_data.onnx | 9 +- .../test/testdata/prepack/MatMul.Weight.bin | Bin 8 -> 0 bytes ...xternal_initializers_and_prepack_kernel.py | 88 --------- .../prepack/model_with_matmul_nbits.onnx | Bin 333 -> 0 bytes orttraining/orttraining/models/bert/main.cc | 1 - .../orttraining/models/pipeline_poc/main.cc | 1 - .../models/runner/training_runner.cc | 1 - 72 files changed, 137 insertions(+), 872 deletions(-) delete mode 100644 onnxruntime/test/testdata/prepack/MatMul.Weight.bin delete mode 100644 onnxruntime/test/testdata/prepack/model_with_external_initializers_and_prepack_kernel.py delete mode 100644 onnxruntime/test/testdata/prepack/model_with_matmul_nbits.onnx diff --git a/include/onnxruntime/core/framework/op_kernel.h b/include/onnxruntime/core/framework/op_kernel.h index a17da2a19bb99..07625c38d8474 100644 --- a/include/onnxruntime/core/framework/op_kernel.h +++ b/include/onnxruntime/core/framework/op_kernel.h @@ -79,7 +79,6 @@ class OpKernel { // the allocator tied to the session if the kernel owns the pre-packed buffer or an // allocator shared between sessions if the pre-packed buffer is to be shared across sessions // (i.e.) the kernel does not own the buffer. - // @param save_prepacked_initializers: Set it to true if intend to save prepacked initializers to external data file. // @param is_packed: Set it to true if the kernel packed the tensor or to false // The kernel is responsible for keeping the packed data and related metadata if is_packed is true, // and the original initialized constant tensor will be released and not accessible anymore in @@ -89,7 +88,6 @@ class OpKernel { virtual Status PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/, - bool, /*save_prepacked_initializers*/ /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weights*/) { is_packed = false; return Status::OK(); @@ -131,26 +129,6 @@ class OpKernel { return Status::OK(); } - // Override this function to get pre-packed tensors from this kernel. - // Only useful for models run on PC with CPU so ORT could load prepacked weights directly from - // ONNX data file with mmap and no need to do prepacking on fly to save a lot of heap memory. - // @param input_idx : The index of input we prepacked before and intend to get packed tensor back. - // Please refer to matmul_nbits kernel for a complete example. - virtual std::optional GetPrePackTensor(int /*input_idx*/) { - return std::nullopt; - } - - // Override this function to set pre-packed tensors to this kernel and restore prepacked weight buffer. - // Only useful for models run on PC with CPU so ORT could load prepacked weights directly from - // ONNX data file with mmap and no need to do prepacking on fly to save a lot of heap memory. - // Please refer to matmul_nbits kernel for a complete example. - // @param input_idx : The input index of the tensor in this kernel. - // @param pre_packed_tensor: The prepacked tensor read from onnx data file and use the prepacked tensor - // to restore prepacked weight buffer. 
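For orientation, the op_kernel.h hunk above restores the pre-#22256 shape of the `OpKernel::PrePack` virtual. Below is a minimal sketch of a kernel override against that restored signature; `MyGemm`, its choice of input index, and the packing logic are illustrative assumptions, not code from this patch.

```cpp
#include "core/framework/op_kernel.h"

namespace onnxruntime {

// Illustrative kernel showing the restored PrePack contract: no
// save_prepacked_initializers flag, only the is_packed out-parameter and
// the optional PrePackedWeights for cross-session sharing.
class MyGemm final : public OpKernel {
 public:
  explicit MyGemm(const OpKernelInfo& info) : OpKernel(info) {}

  Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
                 /*out*/ bool& is_packed,
                 /*out*/ PrePackedWeights* prepacked_weights) override {
    is_packed = false;
    if (input_idx != 1) {
      return Status::OK();  // only the weight input benefits from packing
    }
    // ... transform `tensor` into a layout-optimized buffer owned by this
    // kernel (allocated from `alloc`); optionally expose it through
    // `prepacked_weights` so other sessions can reuse it ...
    is_packed = true;  // tells ORT the original initializer can be released
    return Status::OK();
  }

  Status Compute(OpKernelContext* context) const override;
};

}  // namespace onnxruntime
```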
- virtual Status SetPrePackTensor(int /*input_idx*/, const Tensor& /*pre_packed_tensor*/) { - return Status::OK(); - } - const OrtDevice GetDevice(OrtMemType mem_type) const; const OpKernelInfo& Info() const { return *op_kernel_info_; diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 69af3c93d7a07..eb9581e8018d1 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -1148,11 +1148,6 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi void FinalizeFuseSubGraph(const IndexedSubGraph& sub_graph, Node& fused_node); #endif - // Since one constant initializer could be used by different kernels - // and prepacked differently, use an unordered_map to store prepacked - // initializer in format of <[initializer_name], <[node_name], [prepacked_initializer]>> - typedef std::unordered_map> PrePackedTensorProtoToSave; - #if !defined(ORT_MINIMAL_BUILD) /** Gets the GraphProto representation of this Graph. */ const ONNX_NAMESPACE::GraphProto& ToGraphProto(); @@ -1187,26 +1182,18 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi @param initializer_size_threshold initializers larger or equal to this threshold (in bytes) are saved in the external file. Initializer smaller than this threshold are included in the onnx file. @param align_info offset alignment info. - @param save_prepacked_constant_initializers whether to save prepacked initializer into external data file. - If set false to this boolean, prepacked initializer will not be saved into onnxruntime data file, - we keep constant initializer as it is. - @param pre_packed_initializers struct used to store all the prepacked initializers. @returns GraphProto serialization of the graph. */ ONNX_NAMESPACE::GraphProto ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_path, const std::filesystem::path& model_file_path, size_t initializer_size_threshold, - const OffsetAlignmentInfo& align_info, - bool save_prepacked_constant_initializers, - PrePackedTensorProtoToSave& pre_packed_initializers) const; + const OffsetAlignmentInfo& align_info) const; ONNX_NAMESPACE::GraphProto ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_path, const std::filesystem::path& model_file_path, size_t initializer_size_threshold) const { OffsetAlignmentInfo default_options; - PrePackedTensorProtoToSave pre_packed_initializers; - return ToGraphProtoWithExternalInitializers(external_file_path, model_file_path, initializer_size_threshold, default_options, - false, pre_packed_initializers); + return ToGraphProtoWithExternalInitializers(external_file_path, model_file_path, initializer_size_threshold, default_options); } /** Gets the ISchemaRegistry instances being used with this Graph. */ @@ -1521,18 +1508,6 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi private: void InitializeStateFromModelFileGraphProto(); - // Private method used to setup external initializer properly during model save, - // this external initializer could be oroginal initializer or prepacked initializer. 
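The graph.h hunk above likewise restores the three-argument convenience overload of `ToGraphProtoWithExternalInitializers`. A hypothetical call site follows, assuming an existing `Graph`; the file names and threshold are invented, and note that the implementation enforces a relative external file path.

```cpp
#include "core/graph/graph.h"

// Serialize `graph`; initializers of 128 bytes or more go to weights.bin
// (relative to the model file), smaller ones stay inline in the GraphProto.
ONNX_NAMESPACE::GraphProto Externalize(const onnxruntime::Graph& graph) {
  return graph.ToGraphProtoWithExternalInitializers(
      "weights.bin",                        // external_file_path (must be relative)
      "/tmp/model.onnx",                    // model_file_path
      /*initializer_size_threshold*/ 128);  // bytes
}
```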
- static void SetUpExternalInitializer(const Graph::OffsetAlignmentInfo& align_info, - size_t tensor_bytes_size, - int64_t& external_offset, - std::ofstream& external_stream, - gsl::span raw_data, - ONNX_NAMESPACE::TensorProto& output_proto, - const std::filesystem::path& external_file_path, - const ONNX_NAMESPACE::TensorProto& initializer, - bool is_prepacked); - // Add node with specified . Node& AddNode(const ONNX_NAMESPACE::NodeProto& node_proto, const ArgNameToTypeMap& name_to_type); diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 086919913cbea..6a01602e634f8 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -246,12 +246,6 @@ static const char* const kOrtSessionOptionsDisableCPUEPFallback = "session.disab static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersFileName = "session.optimized_model_external_initializers_file_name"; -// Use this config when save prepacked constant initializers to onnx external data file. -// Default is not save prepacked initializers to onnx data file. -// Sample usage: sess_options.add_session_config_entry('session.save_prepacked_constant_initializers', "1") -static const char* const kOrtSessionOptionsSavePrePackedConstantInitializers = - "session.save_prepacked_constant_initializers"; - // Use this config to control the minimum size of the initializer when externalizing it during serialization static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes = "session.optimized_model_external_initializers_min_size_in_bytes"; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index b15e865aa423c..ad14fb8258656 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -30,7 +30,6 @@ class Attention : public OpKernel, public AttentionCPUBase { Status Compute(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; @@ -102,7 +101,6 @@ bool Attention::IsPackWeightsSuccessful(int qkv_index, template Status Attention::PrePack(const Tensor& weights, int input_idx, AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { /* The PrePack() massages the weights to speed up Compute(), there is an option to diff --git a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc index 71a66ea368943..2c897f183164f 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc @@ -24,7 +24,6 @@ class QAttention : public OpKernel, public AttentionCPUBase { Status Compute(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, bool& /*out*/ is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; @@ -59,7 +58,6 @@ QAttention::QAttention(const OpKernelInfo& info) : OpKernel(info), AttentionC template Status QAttention::PrePack(const Tensor& weights, int input_idx, AllocatorPtr alloc, - bool 
/*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { if (1 != input_idx) { diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc index 4148aae4b9a35..aa47f365c0005 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc @@ -13,7 +13,7 @@ class DynamicQuantizeLSTM : public OpKernel, public LSTMBase { DynamicQuantizeLSTM(const OpKernelInfo& info) : OpKernel(info), LSTMBase(info) {} Status PrePack(const Tensor& tensor, int input_idx, - AllocatorPtr alloc, bool save_prepacked_initializers, /*out*/ bool& is_packed, + AllocatorPtr alloc, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, @@ -91,7 +91,6 @@ static void UseSharedPrePackedBuffersImpl(std::vector& prepacke } Status DynamicQuantizeLSTM::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index cee3dfc6b3f28..89e96543c4729 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -98,19 +98,12 @@ class MatMulNBits final : public OpKernel { Status Compute(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; - void ConvertPrepackWeightIntoTensor(const onnxruntime::Tensor& tensor, int input_idx); - Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, /*out*/ bool& used_shared_buffers) override; - std::optional GetPrePackTensor(int /*input_idx*/) override; - - Status SetPrePackTensor(int input_idx, const Tensor& pre_packed_tensor) override; - private: const size_t K_; const size_t N_; @@ -126,8 +119,6 @@ class MatMulNBits final : public OpKernel { size_t packed_b_size_{0}; IAllocatorUniquePtr scales_fp32_{}; IAllocatorUniquePtr bias_fp32_{}; - std::optional packed_tensor_{std::nullopt}; - MLDataType prepack_tensor_data_type_; bool has_zp_input_{false}; @@ -157,22 +148,8 @@ class MatMulNBits final : public OpKernel { } }; -template -void MatMulNBits::ConvertPrepackWeightIntoTensor(const onnxruntime::Tensor& tensor, int input_idx) { - if (input_idx == InputIndex::B) { - prepack_tensor_data_type_ = tensor.DataType(); - } - - TensorShapeVector weights_dims = {static_cast((packed_b_size_ - 1) / prepack_tensor_data_type_->Size()) + 1}; - packed_tensor_ = Tensor(prepack_tensor_data_type_, - TensorShape(weights_dims), - packed_b_.get(), - OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator)); -} - template Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { ORT_UNUSED_PARAMETER(prepacked_weights); @@ -208,16 +185,11 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All #endif // MLAS_TARGET_AMD64_IX86 } - if (save_prepacked_initializers) { - ConvertPrepackWeightIntoTensor(tensor, input_idx); - } - 
return Status::OK(); } template <> Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { ORT_UNUSED_PARAMETER(prepacked_weights); @@ -267,34 +239,6 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*ou #endif // MLAS_TARGET_AMD64_IX86 } - if (save_prepacked_initializers) { - ConvertPrepackWeightIntoTensor(tensor, input_idx); - } - - return Status::OK(); -} - -template -std::optional MatMulNBits::GetPrePackTensor(int input_idx) { - // For this kernel, prepack is performed on input_B, and possibly scales, zeros_points. - // During compute process, scales and zeros_points will keep as it is and only use prepacked - // buffer to replace input_B. - // Inorder to cope with this logic, we need to return latest prepacked buffer and only serialize - // the latest one. So, we need to always return packed_tensor_ here not only for input_B. - ORT_UNUSED_PARAMETER(input_idx); - return std::move(packed_tensor_); -} - -template -Status MatMulNBits::SetPrePackTensor(int input_idx, const Tensor& pre_packed_tensor) { - if (input_idx == 1) { - // pre_packed_tensor is constant initialized tensor and its lifecycle is managed by session_state, - // session_state will release memory from pre_packed_tensor. packed_b_ will not release memory so - // pass empty/default buffer deleter here. - // const_cast here is temporary, will fix in follow up PR. - packed_b_ = BufferUniquePtr(const_cast(pre_packed_tensor.DataRaw()), BufferDeleter()); - } - return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc index c9ee9e2cb760d..67b4950af73bf 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc @@ -278,7 +278,6 @@ Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { template Status SkipLayerNorm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, bool& is_packed, PrePackedWeights* prepacked_weights) { ORT_UNUSED_PARAMETER(prepacked_weights); diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h index d904c14857437..08e2276c3d9d5 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h @@ -16,7 +16,7 @@ class SkipLayerNorm final : public OpKernel { SkipLayerNorm(const OpKernelInfo& op_kernel_info); Status Compute(OpKernelContext* p_op_kernel_context) const override; - Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool save_prepacked_initializers, + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, PrePackedWeights* prepacked_weights) override; private: diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc index d190ed389f3e9..dea5391c7629b 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc @@ -95,7 +95,6 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) { } Status GroupNorm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr /*alloc*/, - bool /*save_prepacked_initializers*/, bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { is_packed = false; diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h 
b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h index 4505c066baedb..b408b3c1ee79b 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h @@ -17,7 +17,6 @@ class GroupNorm final : public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, bool& is_packed, PrePackedWeights* prepacked_weights) override; private: diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc index aa2c8755f6536..3e93a527877c5 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc @@ -99,7 +99,6 @@ Status QOrderedAttention::PutIntoMergedBias(const Tensor& tensor, AllocatorPtr a } Status QOrderedAttention::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weights*/) { is_packed = false; diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.h b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.h index 529fd00307d66..9d4e563c1feab 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.h +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.h @@ -20,7 +20,6 @@ class QOrderedAttention final : public CudaKernel, public AttentionBase { public: Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.cc b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.cc index 351e36b884540..a64f628f245e6 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.cc @@ -51,7 +51,6 @@ QOrderedMatMul::QOrderedMatMul(const OpKernelInfo& info) : CudaKernel(info) { } Status QOrderedMatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /* prepacked_weights */) { is_packed = false; diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.h b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.h index d1cef99779e09..dcb6cc6374be1 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.h +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.h @@ -18,7 +18,6 @@ class QOrderedMatMul final : public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index 18405231750ba..8d4db36106f28 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -83,11 +83,6 @@ struct SessionOptions { 
// enable profiling for this session. bool enable_profiling = false; - // save pre-packed constant external initializers instead of original initializers to onnxruntime data file. - // Only useful for models run on PC with CPU so ORT could load prepacked weights directly from - // ONNX data file with mmap and no need to do prepacking on fly to save a lot of heap memory. - bool save_prepacked_constant_initializers = false; - // Non empty filepath enables serialization of the transformed optimized model to the specified filepath. // // Set session config value for ORT_SESSION_OPTIONS_CONFIG_SAVE_MODEL_FORMAT to 'ORT' or 'ONNX' to explicitly @@ -196,7 +191,6 @@ inline std::ostream& operator<<(std::ostream& os, const SessionOptions& session_ << " execution_mode:" << session_options.execution_mode << " execution_order:" << session_options.execution_order << " enable_profiling:" << session_options.enable_profiling - << " save_prepacked_constant_initializers:" << session_options.save_prepacked_constant_initializers << " optimized_model_filepath:" << ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(session_options.optimized_model_filepath) << " enable_mem_pattern:" << session_options.enable_mem_pattern << " enable_mem_reuse:" << session_options.enable_mem_reuse diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 943db091b341f..0d0b22ff61e01 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -14,7 +14,6 @@ #include "core/framework/op_kernel.h" #include "core/framework/ort_value_pattern_planner.h" #include "core/framework/session_state_utils.h" -#include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" #include "core/providers/cpu/controlflow/utils.h" #include "core/session/onnxruntime_session_options_config_keys.h" @@ -398,18 +397,12 @@ static std::string GenerateKeyForPrepackedWeightsMap(const std::string& op_type, } Status SessionState::PrepackConstantInitializedTensors(InlinedHashMap& constant_initializers_use_count, - const std::unordered_map& initializers_to_share_map, - bool save_prepacked_constant_initializers, - PrePackInitializers& pre_packed_initializers) { - auto prepacked_constant_weights = [this, &constant_initializers_use_count, &initializers_to_share_map, - save_prepacked_constant_initializers, &pre_packed_initializers]( + const std::unordered_map& initializers_to_share_map) { + auto prepacked_constant_weights = [this, &constant_initializers_use_count, &initializers_to_share_map]( bool should_cache_prepacked_weights_for_shared_initializers) -> Status { - std::unordered_map pre_packed_kernel_input_map; for (auto& node : GetGraphViewer().Nodes()) { auto kernel = GetMutableKernel(node.Index()); - auto kernel_name = kernel->Info().node().Name(); int input_idx = 0; - bool is_kernel_prepacked = false; for (auto& input_def : node.InputDefs()) { if (input_def->Exists()) { const std::string& input_name = input_def->Name(); @@ -421,27 +414,16 @@ Status SessionState::PrepackConstantInitializedTensors(InlinedHashMapGetOrtValueNameIdxMap().GetIdx(input_name, ort_value_idx).IsOK()) { std::unordered_map& constant_initialized_tensors = st->constant_initialized_tensors_; - if (constant_initialized_tensors.count(ort_value_idx) && !is_kernel_prepacked) { + if (constant_initialized_tensors.count(ort_value_idx)) { bool is_packed = false; const Tensor& const_initialized_tensor = constant_initialized_tensors[ort_value_idx].Get(); auto iter = 
initializers_to_share_map.find(input_name); bool is_shared_initializer = (iter != initializers_to_share_map.end()); - // found pre-packed constant initializers from data file, no need to do pre-packing again - // apply pre-packed tensor to kernel so kernel can use it directly - if (pre_packed_initializers.pre_packed_initializer_names_read_from_file.count(input_name) != 0) { - is_packed = true; - - // kernel like Matmul_nbits will call prepack multiple times with input_B and possibly scales/zero_points. - // If prepacked weights already read from ONNX data file (this happens we ORT reads data file with prepacked - // weights serialized), only need to set prepacked weights once to kernel. - is_kernel_prepacked = true; - ORT_THROW_IF_ERROR(kernel->SetPrePackTensor(input_idx, const_initialized_tensor)); - } // Caching pre-packed weights is limited to shared initializers associated with the CPU EP for now - else if (is_shared_initializer && should_cache_prepacked_weights_for_shared_initializers && - node.GetExecutionProviderType() == kCpuExecutionProvider) { // caching of pre-packed weights' turned ON + if (is_shared_initializer && should_cache_prepacked_weights_for_shared_initializers && + node.GetExecutionProviderType() == kCpuExecutionProvider) { // caching of pre-packed weights' turned ON AllocatorPtr allocator_for_caching = prepacked_weights_container_->GetOrCreateAllocator(CPU); ORT_ENFORCE(allocator_for_caching.get() != nullptr); @@ -453,7 +435,7 @@ Status SessionState::PrepackConstantInitializedTensors(InlinedHashMapPrePack(const_initialized_tensor, input_idx, allocator_for_caching, - save_prepacked_constant_initializers, is_packed, + is_packed, &weights_to_be_filled_in)); if (is_packed) { @@ -500,50 +482,18 @@ Status SessionState::PrepackConstantInitializedTensors(InlinedHashMapInfo().GetDevice(OrtMemType::OrtMemTypeDefault)); ORT_RETURN_IF_ERROR(kernel->PrePack(const_initialized_tensor, input_idx, session_cpu_alloc, // use allocator tied to this session - save_prepacked_constant_initializers, is_packed, nullptr // no caching required )); } if (is_packed) { - // if intended to save prepacked initializers, get prepacked tensors from kernel and save in hashmap, - // will save to data file later - if (save_prepacked_constant_initializers) { - auto tensor = kernel->GetPrePackTensor(input_idx); - - if (tensor != std::nullopt) { - // save prepacked initializers per initializer and kernel since one initializer could - // be used by multiple kernels - pre_packed_initializers.pre_packed_initializers_to_save[input_name][kernel_name] = std::move(tensor.value()); - - pre_packed_kernel_input_map[kernel_name] = input_name; - } - } - ++number_of_prepacks_counter_; - // if constant_initialized_tensor is already pre-packed, don't need to remove it - if (pre_packed_initializers.pre_packed_initializer_names_read_from_file.count(input_name) == 0 && - constant_initializers_use_count.count(input_name) && --constant_initializers_use_count[input_name] == 0) { + if (constant_initializers_use_count.count(input_name) && --constant_initializers_use_count[input_name] == 0) { // release the constant initialized tensor st->initialized_tensors_.erase(ort_value_idx); constant_initialized_tensors.erase(ort_value_idx); } - } else { - // handle prepack for matmul_nbits, it will prepack several times but set is_packed - // to false for scales and zero_points, we keep scales and zero_points as it is only - // update packed_tensor to input_B. 
- // TODO: this logic works with matmul_nbits kernel but if other kernels also call prepack - // multiple times and use different initializers to store prepacked weights, this piece of logic - // might introduce bug and need a per kernel strategy to update prepacked weights. - if (save_prepacked_constant_initializers && pre_packed_kernel_input_map.count(kernel_name)) { - auto tensor = kernel->GetPrePackTensor(input_idx); - - if (tensor != std::nullopt) { - auto existing_input_name = pre_packed_kernel_input_map[kernel_name]; - pre_packed_initializers.pre_packed_initializers_to_save[existing_input_name][kernel_name] = std::move(tensor.value()); - } - } } } // stop searching in 2 cases: @@ -1226,7 +1176,6 @@ static Status VerifyEachNodeIsAssignedToAnEp(const Graph& graph, const logging:: Status SessionState::FinalizeSessionState(const std::basic_string& graph_location, const KernelRegistryManager& kernel_registry_manager, - PrePackInitializers& pre_packed_initializers, bool remove_initializers, bool saving_ort_format) { // recursively create the subgraph session state instances and populate the kernel create info in them. @@ -1240,7 +1189,7 @@ Status SessionState::FinalizeSessionState(const std::basic_string constant_initializers_use_count; ComputeConstantInitializerUseCount(graph_, constant_initializers_use_count); return FinalizeSessionStateImpl(graph_location, kernel_registry_manager, nullptr, sess_options_, - remove_initializers, constant_initializers_use_count, pre_packed_initializers); + remove_initializers, constant_initializers_use_count); } static Status Index(const OrtValueNameIdxMap& ort_value_name_idx_map, @@ -1374,7 +1323,6 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string& constant_initializers_use_count, - PrePackInitializers& pre_packed_initializers, const InlinedHashMap& outer_scope_node_arg_to_location_map, bool graph_info_already_created) { if (!graph_info_already_created) { @@ -1474,8 +1422,6 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string> - typedef std::unordered_map> PrePackedTensorsToSave; - PrePackedTensorsToSave pre_packed_initializers_to_save; - - // This set is used during model load with prepacked initializer serialized in external data file. - // ORT reads prepacked initializers and store their name into this set so we could skip PrePack - // process later to save heap memory. Prepacked tensor itself is saved in session state's constant_initialized_tensors_. - typedef std::unordered_set PrePackedTensorNamesReadFromFile; - PrePackedTensorNamesReadFromFile pre_packed_initializer_names_read_from_file; - }; - Status FinalizeSessionState(const std::basic_string& graph_loc, const KernelRegistryManager& kernel_registry_manager, - PrePackInitializers& pre_packed_initializers, bool remove_initializers = true, bool saving_ort_format = false); @@ -338,15 +321,6 @@ class SessionState { return parent_; } - Status FinalizeSessionState(const std::basic_string& graph_loc, - const KernelRegistryManager& kernel_registry_manager, - bool remove_initializers = true, - bool saving_ort_format = false) { - PrePackInitializers pre_packed_initializers; - return FinalizeSessionState(graph_loc, kernel_registry_manager, pre_packed_initializers, - remove_initializers, saving_ort_format); - } - // Clear all removable attributes if they exists. // The function logs the list of removable attributes for every node. 
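Stepping back from the hunk: the restored `PrepackConstantInitializedTensors` logic above releases each constant initializer once its use count drops to zero, that is, once every consuming kernel has packed it. Here is a self-contained sketch of that bookkeeping pattern, with illustrative names rather than ORT types.

```cpp
#include <string>
#include <unordered_map>

// Stand-alone restatement of the use-count pattern: the session decrements
// a per-initializer counter after each successful PrePack and frees the
// original tensor when the last consumer is done.
struct InitializerUseTracker {
  std::unordered_map<std::string, size_t> use_count;

  // Returns true when the caller should release the underlying tensor.
  bool OnPrePacked(const std::string& initializer_name) {
    auto it = use_count.find(initializer_name);
    if (it == use_count.end()) return false;
    return --(it->second) == 0;  // last consumer has packed this weight
  }
};
```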
void PruneRemovableAttributes(); @@ -406,13 +380,9 @@ class SessionState { /** * Prepack the constant initialized tensors for better performance. * The original constant initialized tensors will be removed to save memory. - * For model with prepacked initializer serialized into ONNX data file, - * PrePack will be skipped to save memory. */ Status PrepackConstantInitializedTensors(InlinedHashMap& constant_initializers_use_count, - const std::unordered_map& initializers_to_share_map, - bool save_prepacked_constant_initializers, - PrePackInitializers& pre_packed_initializers); + const std::unordered_map& initializers_to_share_map); SessionState* GetMutableSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name); @@ -430,7 +400,6 @@ class SessionState { const SessionOptions& session_options, bool remove_initializers, InlinedHashMap& constant_initializers_use_count, - PrePackInitializers& pre_packed_initializers, const InlinedHashMap& outer_scope_node_arg_to_location_map = {}, bool graph_info_already_created = false); diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index 3424f40e79c01..2c74805c57dce 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -21,6 +21,7 @@ #include "core/framework/ort_value_pattern_planner.h" #include "core/framework/ort_value_name_idx_map.h" #include "core/framework/sequential_execution_plan.h" +#include "core/framework/session_state.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" #include "core/framework/bfc_arena.h" @@ -71,7 +72,6 @@ static inline common::Status ExtDataTensorProtoToTensor(const Env& env, const std::basic_string& proto_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, Tensor& tensor, OrtCallback& ext_data_deleter, - SessionState::PrePackInitializers::PrePackedTensorNamesReadFromFile& pre_packed_initializers_name_set, Tensor* buffered_tensor = nullptr) { ORT_ENFORCE(utils::HasExternalData(tensor_proto)); @@ -79,7 +79,7 @@ static inline common::Status ExtDataTensorProtoToTensor(const Env& env, SafeInt ext_data_len = 0; ORT_RETURN_IF_ERROR(utils::GetExtDataFromTensorProto(env, proto_path.c_str(), tensor_proto, ext_data_buf, ext_data_len, ext_data_deleter, - &pre_packed_initializers_name_set, buffered_tensor)); + buffered_tensor)); // NB: creating a do-nothing allocator per tensor is wasteful; can perhaps be // avoided if the Tensor class implements the do-nothing behavior when given a @@ -100,7 +100,6 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st const AllocatorPtr& alloc, const AllocatorPtr& default_cpu_alloc, OrtValue& ort_value, const DataTransferManager& data_transfer_mgr, const ExternalDataLoaderManager& external_data_loader_mgr, - SessionState::PrePackInitializers::PrePackedTensorNamesReadFromFile& pre_packed_initializers_name_set, bool use_device_allocator_for_initializers = false, Tensor* buffered_tensor = nullptr) { if (bool(alloc) == (m != nullptr)) { @@ -140,7 +139,7 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st // TensorProtoToTensor it would copy the data, causing unnecessary overhead OrtCallback ext_data_deleter; ORT_RETURN_IF_ERROR(ExtDataTensorProtoToTensor(env, proto_path, tensor_proto, *p_tensor, - ext_data_deleter, pre_packed_initializers_name_set, buffered_tensor)); + ext_data_deleter, buffered_tensor)); ExtDataValueDeleter deleter{ext_data_deleter, 
p_tensor.get()}; MLDataType ml_tensor_type = DataTypeImpl::GetType(); @@ -164,7 +163,7 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st OrtCallback ext_data_deleter; std::optional scoped_ort_callback_invoker; ORT_RETURN_IF_ERROR(ExtDataTensorProtoToTensor(env, proto_path, tensor_proto, *p_deserialize_tensor, - ext_data_deleter, pre_packed_initializers_name_set, buffered_tensor)); + ext_data_deleter, buffered_tensor)); scoped_ort_callback_invoker = ScopedOrtCallbackInvoker(ext_data_deleter); // TODO!! Need a temp buffer allocator for non-escape buffers that maybe too big for stack allocation. @@ -273,8 +272,7 @@ common::Status SaveInitializedTensors( const ExecutionPlanBase& exec_plan, const SessionOptions& session_options, const MemoryProfileFunction& memory_profile_func, - std::unordered_map>& buffered_tensors, - SessionState::PrePackInitializers::PrePackedTensorNamesReadFromFile& pre_packed_initializers_name_set) { + std::unordered_map>& buffered_tensors) { LOGS(logger, INFO) << "Saving initialized tensors."; ORT_ENFORCE(ort_value_name_idx_map.MaxIdx() > -1, "OrtValue indexes should have been populated."); @@ -403,7 +401,6 @@ common::Status SaveInitializedTensors( Status st = DeserializeTensorProto(env, graph_loc, tensor_proto, (m.has_value()) ? &*m : nullptr, alloc, default_cpu_alloc, ort_value, data_transfer_mgr, external_data_loader_mgr, - pre_packed_initializers_name_set, use_device_allocator_for_initializers, p_tensor); if (!st.IsOK()) { std::ostringstream oss; diff --git a/onnxruntime/core/framework/session_state_utils.h b/onnxruntime/core/framework/session_state_utils.h index 4de501b6f7429..af27f5caba0f4 100644 --- a/onnxruntime/core/framework/session_state_utils.h +++ b/onnxruntime/core/framework/session_state_utils.h @@ -12,7 +12,6 @@ #include "core/framework/tensor.h" #include "core/framework/tensor_allocator.h" #include "core/framework/session_options.h" -#include "core/framework/session_state.h" #include "core/framework/sequential_execution_plan.h" #include "core/platform/path_lib.h" @@ -51,8 +50,7 @@ common::Status SaveInitializedTensors( const ExecutionPlanBase& exec_plan, const SessionOptions& session_options, const MemoryProfileFunction& memory_profile_func, - std::unordered_map>& buffered_tensors, - SessionState::PrePackInitializers::PrePackedTensorNamesReadFromFile& pre_packed_initializers_name_set); + std::unordered_map>& buffered_tensors); common::Status AllocateTensor( const onnxruntime::MemBuffer* m, diff --git a/onnxruntime/core/framework/tensor_external_data_info.cc b/onnxruntime/core/framework/tensor_external_data_info.cc index bcd04effe2bd4..93146e66d9f24 100644 --- a/onnxruntime/core/framework/tensor_external_data_info.cc +++ b/onnxruntime/core/framework/tensor_external_data_info.cc @@ -40,8 +40,6 @@ Status ExternalDataInfo::Create(const RepeatedPtrField& return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "parsing ", stringmap.value(), " failed"); } else if (stringmap.key() == "checksum" && !stringmap.value().empty()) { out->checksum_ = stringmap.value(); - } else if (stringmap.key() == "prepacked" && !stringmap.value().empty()) { - out->prepacked_ = stringmap.value() == "1"; } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "model format error!"); } diff --git a/onnxruntime/core/framework/tensor_external_data_info.h b/onnxruntime/core/framework/tensor_external_data_info.h index c2490f5cc5bc2..afc8fda6c3037 100644 --- a/onnxruntime/core/framework/tensor_external_data_info.h +++ b/onnxruntime/core/framework/tensor_external_data_info.h 
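For context on what `ExternalDataInfo::Create` still accepts after this revert (the tensor_external_data_info.cc hunk above drops the "prepacked" key), here is a sketch of populating the recognized external_data entries on an initializer. The helper name is hypothetical; the key names mirror the ones written by `ToGraphProtoWithExternalInitializers` later in this patch.

```cpp
#include <cstdint>
#include <string>

#include "core/graph/onnx_protobuf.h"

// Hypothetical helper: records where an initializer's payload lives in the
// side file, using the keys the parser recognizes: "location", "offset",
// "length", and optionally "checksum".
void SetExternalDataLocation(ONNX_NAMESPACE::TensorProto& initializer,
                             const std::string& relative_path,
                             int64_t offset, size_t length) {
  initializer.set_data_location(
      ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL);
  auto* location = initializer.add_external_data();
  location->set_key("location");
  location->set_value(relative_path);
  auto* offset_entry = initializer.add_external_data();
  offset_entry->set_key("offset");
  offset_entry->set_value(std::to_string(offset));
  auto* length_entry = initializer.add_external_data();
  length_entry->set_key("length");
  length_entry->set_value(std::to_string(length));
}
```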
@@ -23,8 +23,6 @@ class ExternalDataInfo { const std::string& GetChecksum() const { return checksum_; } - bool GetPrePacked() const noexcept { return prepacked_; } - // If the value of 'offset' or 'length' field is larger the max value of ssize_t, this function will treat it as a // wrong value and return FAIL. static common::Status Create( @@ -38,6 +36,5 @@ class ExternalDataInfo { // 0 means the whole file size_t length_ = 0; std::string checksum_; - bool prepacked_ = false; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 0c69ee11f62bc..2af9f95ad059e 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -230,12 +230,11 @@ Status TensorProtoToOrtValueImpl(const Env& env, const std::filesystem::path& mo namespace utils { -static Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, - const std::filesystem::path& tensor_proto_dir, - std::basic_string& external_file_path, - onnxruntime::FileOffsetType& file_offset, - SafeInt& tensor_byte_size, - bool& pre_packed) { +Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, + const std::filesystem::path& tensor_proto_dir, + std::basic_string& external_file_path, + onnxruntime::FileOffsetType& file_offset, + SafeInt& tensor_byte_size) { ORT_RETURN_IF_NOT(onnxruntime::utils::HasExternalData(tensor_proto), "Tensor does not have external data to read from."); @@ -245,8 +244,6 @@ static Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_prot std::unique_ptr external_data_info; ORT_RETURN_IF_ERROR(onnxruntime::ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info)); - pre_packed = external_data_info->GetPrePacked(); - const auto& location = external_data_info->GetRelPath(); external_file_path = location == onnxruntime::utils::kTensorProtoMemoryAddressTag ? 
std::filesystem::path(location) @@ -268,11 +265,6 @@ void SetRawDataInTensorProto(ONNX_NAMESPACE::TensorProto& tensor_proto, std::str tensor_proto.set_raw_data(std::move(param)); } -Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, const std::filesystem::path& tensor_proto_dir, std::basic_string& external_file_path, onnxruntime::FileOffsetType& file_offset, SafeInt& tensor_byte_size) { - bool pre_packed = false; - return GetExternalDataInfo(tensor_proto, tensor_proto_dir, external_file_path, file_offset, tensor_byte_size, pre_packed); -} - void ConvertRawDataInTensorProto(TensorProto* tensor) { size_t element_size = 1; char* bytes = NULL; @@ -996,7 +988,7 @@ static Status GetFileContent(const Env& env, const std::filesystem::path& file_p Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& model_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, void*& ext_data_buf, SafeInt& ext_data_len, OrtCallback& ext_data_deleter, - SessionState::PrePackInitializers::PrePackedTensorNamesReadFromFile* pre_packed_initializers_name_set, Tensor* buffered_tensor) { + Tensor* buffered_tensor) { ORT_ENFORCE(utils::HasExternalData(tensor_proto)); std::basic_string tensor_proto_dir; if (!model_path.empty()) { @@ -1005,13 +997,8 @@ Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& mo std::basic_string external_data_file_path; FileOffsetType file_offset; SafeInt raw_data_safe_len = 0; - bool pre_packed = false; ORT_RETURN_IF_ERROR( - GetExternalDataInfo(tensor_proto, tensor_proto_dir, external_data_file_path, file_offset, raw_data_safe_len, pre_packed)); - - if (pre_packed && pre_packed_initializers_name_set != nullptr) { - (*pre_packed_initializers_name_set).insert(tensor_proto.name()); - } + GetExternalDataInfo(tensor_proto, tensor_proto_dir, external_data_file_path, file_offset, raw_data_safe_len)); if (external_data_file_path == onnxruntime::utils::kTensorProtoMemoryAddressTag) { // the value in location is the memory address of the data @@ -1121,7 +1108,7 @@ Status TensorProtoToTensor(const Env& env, const std::filesystem::path& model_pa OrtCallback& d = deleter_for_file_data.d; if (utils::HasExternalData(tensor_proto)) { - ORT_RETURN_IF_ERROR(GetExtDataFromTensorProto(env, model_path, tensor_proto, raw_data, raw_data_len, d, nullptr)); + ORT_RETURN_IF_ERROR(GetExtDataFromTensorProto(env, model_path, tensor_proto, raw_data, raw_data_len, d)); } else if (utils::HasRawData(tensor_proto)) { raw_data = const_cast(tensor_proto.raw_data().data()); // TODO The line above has const-correctness issues. Below is a possible fix which copies the tensor_proto data diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index 770132f8e95fc..262f7adaca1cb 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -17,19 +17,26 @@ #include "core/framework/external_data_loader.h" #include "core/framework/ort_value.h" #include "core/framework/mem_buffer.h" -#include "core/framework/session_state.h" #include "core/framework/tensor_external_data_info.h" #include "core/graph/onnx_protobuf.h" #include "core/platform/env.h" namespace onnxruntime { namespace utils { +/** + * This function is used to get the external data info from the given tensor proto. 
+ * @param tensor_proto given initializer tensor + * @param tensor_proto_dir directory of the tensor proto file + * @param external_file_path output external file path + * @param file_offset output tensor offset + * @param tensor_byte_size output tensor byte size + * @returns Status::OK() if the function is executed successfully + */ Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, const std::filesystem::path& tensor_proto_dir, std::basic_string& external_file_path, onnxruntime::FileOffsetType& file_offset, SafeInt& tensor_byte_size); - /** * This function is used to convert the endianess of Tensor data. * Mostly, will be used in big endian system to support the model file @@ -165,7 +172,6 @@ common::Status GetExtDataFromTensorProto(const Env& env, const std::filesystem:: const ONNX_NAMESPACE::TensorProto& tensor_proto, void*& ext_data_buf, SafeInt& ext_data_len, OrtCallback& ext_data_deleter, - SessionState::PrePackInitializers::PrePackedTensorNamesReadFromFile* pre_packed_initializers_name_set, Tensor* buffered_tensor = nullptr); // Given a tensor proto with external data obtain a tensor using the specified custom external data loader. diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 5402345447706..9eed0249711f9 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -1064,11 +1064,5 @@ bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index return false; } -std::string GetPrepackedInitializerName(const std::string& initializer_name, const std::string& node_name) { - const std::string seperator = ":"; - - return initializer_name + seperator + node_name; -} - } // namespace utils } // namespace onnxruntime diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h index db38ef1675595..afdb5a2cb27f5 100644 --- a/onnxruntime/core/framework/utils.h +++ b/onnxruntime/core/framework/utils.h @@ -234,8 +234,6 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { int32_t ONNXTensorElementDataTypeToProtoTensorType(ONNXTensorElementDataType); -std::string GetPrepackedInitializerName(const std::string& initializer_name, const std::string& node_name); - #ifdef ENABLE_TRAINING common::Status VerifyInputTensorsAllocatedContiguously(OpKernelContext* context); #endif diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 3f50841f50913..e8a5855b36496 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -4084,75 +4084,10 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProto() const { return result; } -void Graph::SetUpExternalInitializer(const Graph::OffsetAlignmentInfo& align_info, - size_t tensor_bytes_size, - int64_t& external_offset, - std::ofstream& external_stream, - gsl::span raw_data, - ONNX_NAMESPACE::TensorProto& output_proto, - const std::filesystem::path& external_file_path, - const ONNX_NAMESPACE::TensorProto& initializer, - bool is_prepacked) { - // update external_offset for alignment - // need to do padding before write actual tensor data as we do offset alignment at the begin of - // large tensors (offset need to be page aligned and alloction granularity aligned) like below: - // \242\2557\256\023.\031&0000000000000000\332)k+\253\246\342\246(&\006!\347\232\374\236\325\026\032+\36XXXX - // |<---small tensor---->|<---padding--->|<------------------large tensor----------------------------->| - if (align_info.align_offset && static_cast(tensor_bytes_size) > 
align_info.align_threshold) { - // Align to the larger of the page size or the allocation granularity - int64_t alignment_factor = std::max(static_cast(4096), align_info.allocation_granularity); - // Align to the next page or alloc granularity boundary - int64_t new_external_offset = static_cast( - std::floor((external_offset + alignment_factor - 1) / alignment_factor)) * - alignment_factor; - - // padding tensor with zeros for alignment - InlinedVector paddings; - size_t padding_size = SafeInt(new_external_offset - external_offset); - paddings.reserve(padding_size); - for (size_t index = 0; index != padding_size; ++index) { - paddings.push_back(0x0); - } - external_stream.write(reinterpret_cast(paddings.data()), padding_size); - - external_offset = new_external_offset; - } - - external_stream.write(reinterpret_cast(raw_data.data()), tensor_bytes_size); - - output_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL); - ONNX_NAMESPACE::StringStringEntryProto* location = output_proto.add_external_data(); - location->set_key("location"); - location->set_value(ToUTF8String(external_file_path.native())); - ONNX_NAMESPACE::StringStringEntryProto* offset = output_proto.add_external_data(); - offset->set_key("offset"); - offset->set_value(std::to_string(external_offset)); - ONNX_NAMESPACE::StringStringEntryProto* length = output_proto.add_external_data(); - length->set_key("length"); - length->set_value(std::to_string(tensor_bytes_size)); - - if (is_prepacked) { - ONNX_NAMESPACE::StringStringEntryProto* pre_packed = output_proto.add_external_data(); - pre_packed->set_key("prepacked"); - pre_packed->set_value("1"); - } - - output_proto.set_name(initializer.name()); - output_proto.set_data_type(initializer.data_type()); - for (int i = 0; i != initializer.dims_size(); ++i) { - output_proto.add_dims(initializer.dims(i)); - } - output_proto.set_doc_string(initializer.doc_string()); - - external_offset += tensor_bytes_size; -} - ONNX_NAMESPACE::GraphProto Graph::ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_path, const std::filesystem::path& model_file_path, size_t initializer_size_threshold, - const OffsetAlignmentInfo& align_info, - bool save_prepacked_constant_initializers, - PrePackedTensorProtoToSave& pre_packed_initializers) const { + const OffsetAlignmentInfo& align_info) const { GraphProto result; ToGraphProtoInternal(result); ORT_ENFORCE(external_file_path.is_relative()); @@ -4171,34 +4106,6 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProtoWithExternalInitializers(const std #endif for (const auto& initializer : graph_proto_->initializer()) { - bool use_pre_packed_initializer = false; - InlinedVector pre_packed_initializers_tensor_proto; - // If this initializer has been prepacked, saved prepacked external initializer instead of original one. - // Since one initializer could be used by multiple kernels and been prepacked differently, - // Save each prepacked initializers seperately, chagne the initializer name to [initializer_name]:[node_name] - // to avoid conflict. Change the node input name accordingly. - // IT could potentially make the ONNX data file larger since we store multiple prepacked initializers into disk - // but this could be rare case. 
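The offset-alignment logic being moved in these graph.cc hunks reduces to one round-up computation. A stand-alone restatement follows; it is the integer-division equivalent of the `std::floor` form above and assumes non-negative offsets.

```cpp
#include <algorithm>
#include <cstdint>

// Round the current end-of-file offset up to the next multiple of the
// alignment factor: the larger of the 4096-byte page size and the
// platform allocation granularity. The gap is zero-padded in the file.
int64_t AlignExternalOffset(int64_t external_offset, int64_t allocation_granularity) {
  const int64_t alignment_factor =
      std::max(static_cast<int64_t>(4096), allocation_granularity);
  return ((external_offset + alignment_factor - 1) / alignment_factor) * alignment_factor;
}
```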
- if (save_prepacked_constant_initializers && pre_packed_initializers.count(initializer.name())) { - for (const auto& item : pre_packed_initializers[initializer.name()]) { - auto& node_name = item.first; - std::string prepacked_initializer_name = utils::GetPrepackedInitializerName(initializer.name(), node_name); - pre_packed_initializers_tensor_proto.push_back(item.second); - use_pre_packed_initializer = true; - - for (auto& node : *result.mutable_node()) { - if (node.name() == node_name) { - int input_index = 0; - for (const auto& input : node.input()) { - if (input == initializer.name()) { - node.set_input(input_index, prepacked_initializer_name); - } - input_index += 1; - } - } - } - } - } #if !defined(DISABLE_SPARSE_TENSORS) if (sparse_end != sparse_tensor_names_.find(initializer.name())) { // Sparse tensors are added to the ONNX file. @@ -4207,39 +4114,61 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProtoWithExternalInitializers(const std ORT_ENFORCE(status.IsOK(), "Failed to convert dense initializer to sparse"); } else { #endif - if (use_pre_packed_initializer) { - for (const auto& pre_packed_initializer : pre_packed_initializers_tensor_proto) { - // Dense tensors larger than the threshold are added to the external file. - TensorProto* output_proto = result.add_initializer(); - std::vector raw_data; - size_t tensor_bytes_size = 0; - - ORT_THROW_IF_ERROR(utils::UnpackInitializerData(pre_packed_initializer, model_path, raw_data)); - tensor_bytes_size = raw_data.size(); - if (tensor_bytes_size < initializer_size_threshold) { - *output_proto = pre_packed_initializer; - continue; - } + // Dense tensors larger than the threshold are added to the external file. + TensorProto* output_proto = result.add_initializer(); + + std::vector raw_data; + ORT_THROW_IF_ERROR(utils::UnpackInitializerData(initializer, model_path, raw_data)); + size_t tensor_bytes_size = raw_data.size(); + if (tensor_bytes_size < initializer_size_threshold) { + *output_proto = initializer; + continue; + } - SetUpExternalInitializer(align_info, tensor_bytes_size, external_offset, external_stream, - raw_data, *output_proto, external_file_path, pre_packed_initializer, true); - } - } else { - // Dense tensors larger than the threshold are added to the external file. 
- TensorProto* output_proto = result.add_initializer(); - std::vector raw_data; - size_t tensor_bytes_size = 0; - - ORT_THROW_IF_ERROR(utils::UnpackInitializerData(initializer, model_path, raw_data)); - tensor_bytes_size = raw_data.size(); - if (tensor_bytes_size < initializer_size_threshold) { - *output_proto = initializer; - continue; + // update external_offset for alignment + // need to do padding before write actual tensor data as we do offset alignment at the begin of + // large tensors (offset need to be page aligned and alloction granularity aligned) like below: + // \242\2557\256\023.\031&0000000000000000\332)k+\253\246\342\246(&\006!\347\232\374\236\325\026\032+\36XXXX + // |<---small tensor---->|<---padding--->|<------------------large tensor----------------------------->| + if (align_info.align_offset && static_cast(tensor_bytes_size) > align_info.align_threshold) { + // Align to the larger of the page size or the allocation granularity + int64_t alignment_factor = std::max(static_cast(4096), align_info.allocation_granularity); + // Align to the next page or alloc granularity boundary + int64_t new_external_offset = static_cast( + std::floor((external_offset + alignment_factor - 1) / alignment_factor)) * + alignment_factor; + + // padding tensor with zeros for alignment + for (int64_t index = external_offset; index != new_external_offset; ++index) { + external_stream << '0'; } - SetUpExternalInitializer(align_info, tensor_bytes_size, external_offset, external_stream, - raw_data, *output_proto, external_file_path, initializer, false); + external_offset = new_external_offset; } + + for (size_t index = 0; index != tensor_bytes_size; ++index) { + external_stream << raw_data[index]; + } + + output_proto->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL); + ONNX_NAMESPACE::StringStringEntryProto* location = output_proto->add_external_data(); + location->set_key("location"); + location->set_value(ToUTF8String(external_file_path.native())); + ONNX_NAMESPACE::StringStringEntryProto* offset = output_proto->add_external_data(); + offset->set_key("offset"); + offset->set_value(std::to_string(external_offset)); + ONNX_NAMESPACE::StringStringEntryProto* length = output_proto->add_external_data(); + length->set_key("length"); + length->set_value(std::to_string(tensor_bytes_size)); + + output_proto->set_name(initializer.name()); + output_proto->set_data_type(initializer.data_type()); + for (int i = 0; i != initializer.dims_size(); ++i) { + output_proto->add_dims(initializer.dims(i)); + } + output_proto->set_doc_string(initializer.doc_string()); + + external_offset += tensor_bytes_size; #if !defined(DISABLE_SPARSE_TENSORS) } #endif diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index ad1ec9c8dedb3..1bae63b510563 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -384,17 +384,13 @@ ModelProto Model::ToProto() const { ModelProto Model::ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_name, const std::filesystem::path& file_path, size_t initializer_size_threshold, - const Graph::OffsetAlignmentInfo& align_info, - bool save_prepacked_constant_initializers, - Graph::PrePackedTensorProtoToSave& pre_packed_initializers) const { + const Graph::OffsetAlignmentInfo& align_info) const { ModelProto result(model_proto_); const auto& graph = *graph_; *(result.mutable_graph()) = graph.ToGraphProtoWithExternalInitializers(external_file_name, file_path, 
initializer_size_threshold, - align_info, - save_prepacked_constant_initializers, - pre_packed_initializers); + align_info); return result; } @@ -612,9 +608,7 @@ static Status SaveModelWithExternalInitializers(Model& model, const T& file_path, const std::filesystem::path& external_file_name, size_t initializer_size_threshold, - const Graph::OffsetAlignmentInfo& align_info, - bool save_prepacked_constant_initializers, - Graph::PrePackedTensorProtoToSave& pre_packed_initializers) { + const Graph::OffsetAlignmentInfo& align_info) { int fd = 0; Status status = Env::Default().FileOpenWr(file_path, fd); ORT_RETURN_IF_ERROR(status); @@ -622,8 +616,7 @@ static Status SaveModelWithExternalInitializers(Model& model, ORT_TRY { status = Model::SaveWithExternalInitializers(model, fd, file_path, external_file_name, initializer_size_threshold, - align_info, save_prepacked_constant_initializers, - pre_packed_initializers); + align_info); } ORT_CATCH(const std::exception& ex) { ORT_HANDLE_EXCEPTION([&]() { @@ -654,12 +647,9 @@ Status Model::Load(const PathString& file_path, std::shared_ptr& p_model, Status Model::SaveWithExternalInitializers(Model& model, const std::filesystem::path& file_path, const std::filesystem::path& external_file_name, size_t initializer_size_threshold, - const Graph::OffsetAlignmentInfo& align_info, - bool save_prepacked_constant_initializers, - Graph::PrePackedTensorProtoToSave& pre_packed_initializers) { + const Graph::OffsetAlignmentInfo& align_info) { return SaveModelWithExternalInitializers(model, file_path, external_file_name, initializer_size_threshold, - align_info, save_prepacked_constant_initializers, - pre_packed_initializers); + align_info); } Status Model::LoadFromBytes(int count, const void* p_bytes, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) { @@ -776,9 +766,7 @@ Status Model::SaveWithExternalInitializers(Model& model, const std::filesystem::path& file_path, const std::filesystem::path& external_file_name, size_t initializer_size_threshold, - const Graph::OffsetAlignmentInfo& align_info, - bool save_prepacked_constant_initializers, - Graph::PrePackedTensorProtoToSave& pre_packed_initializers) { + const Graph::OffsetAlignmentInfo& align_info) { if (fd < 0) { return Status(ONNXRUNTIME, INVALID_ARGUMENT, " is less than 0."); } @@ -787,8 +775,7 @@ Status Model::SaveWithExternalInitializers(Model& model, auto model_proto = model.ToGraphProtoWithExternalInitializers(external_file_name, file_path, initializer_size_threshold, - align_info, save_prepacked_constant_initializers, - pre_packed_initializers); + align_info); google::protobuf::io::FileOutputStream output(fd); const bool result = model_proto.SerializeToZeroCopyStream(&output) && output.Flush(); if (result) { diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index 38d9044ff9d31..9bcec6f78ca08 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -191,17 +191,13 @@ class Model { ONNX_NAMESPACE::ModelProto ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_name, const std::filesystem::path& file_path, size_t initializer_size_threshold, - const Graph::OffsetAlignmentInfo& align_info, - bool save_prepacked_constant_initializers, - Graph::PrePackedTensorProtoToSave& pre_packed_initializers) const; + const Graph::OffsetAlignmentInfo& align_info) const; ONNX_NAMESPACE::ModelProto ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_name, const std::filesystem::path& file_path, size_t 
initializer_size_threshold) const { Graph::OffsetAlignmentInfo default_align_info; - Graph::PrePackedTensorProtoToSave pre_packed_initializers; - return ToGraphProtoWithExternalInitializers(external_file_name, file_path, initializer_size_threshold, default_align_info, - false, pre_packed_initializers); + return ToGraphProtoWithExternalInitializers(external_file_name, file_path, initializer_size_threshold, default_align_info); } static common::Status Save(Model& model, const PathString& file_path); @@ -214,18 +210,14 @@ class Model { const std::filesystem::path& file_path, const std::filesystem::path& external_file_path, size_t initializer_size_threshold, - const Graph::OffsetAlignmentInfo& align_info, - bool save_prepacked_constant_initializers, - Graph::PrePackedTensorProtoToSave& pre_packed_initializers); + const Graph::OffsetAlignmentInfo& align_info); static common::Status SaveWithExternalInitializers(Model& model, const std::filesystem::path& file_path, const std::filesystem::path& external_file_path, size_t initializer_size_threshold) { Graph::OffsetAlignmentInfo default_align_info; - Graph::PrePackedTensorProtoToSave pre_packed_initializers; - return SaveWithExternalInitializers(model, file_path, external_file_path, initializer_size_threshold, default_align_info, - false, pre_packed_initializers); + return SaveWithExternalInitializers(model, file_path, external_file_path, initializer_size_threshold, default_align_info); } static common::Status SaveWithExternalInitializers(Model& model, @@ -233,9 +225,7 @@ class Model { const std::filesystem::path& file_path, const std::filesystem::path& external_file_path, size_t initializer_size_threshold, - const Graph::OffsetAlignmentInfo& align_info, - bool save_prepacked_constant_initializers, - Graph::PrePackedTensorProtoToSave& pre_packed_initializers); + const Graph::OffsetAlignmentInfo& align_info); static common::Status SaveWithExternalInitializers(Model& model, int fd, @@ -243,9 +233,7 @@ class Model { const std::filesystem::path& external_file_path, size_t initializer_size_threshold) { Graph::OffsetAlignmentInfo default_align_info; - Graph::PrePackedTensorProtoToSave pre_packed_initializers; - return SaveWithExternalInitializers(model, fd, file_path, external_file_path, initializer_size_threshold, default_align_info, - false, pre_packed_initializers); + return SaveWithExternalInitializers(model, fd, file_path, external_file_path, initializer_size_threshold, default_align_info); } static common::Status Load(std::istream& model_istream, ONNX_NAMESPACE::ModelProto* p_model_proto); diff --git a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc index 0a1a3a5995872..37db095e92570 100644 --- a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc +++ b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc @@ -51,7 +51,6 @@ class FusedConvFp16 final : public OpKernel { Status Compute(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, @@ -102,7 +101,6 @@ class FusedConvFp16 final : public OpKernel { }; Status FusedConvFp16::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; diff --git 
a/onnxruntime/core/providers/cpu/math/gemm.cc b/onnxruntime/core/providers/cpu/math/gemm.cc index dbc7becdf2397..5406dd1a40446 100644 --- a/onnxruntime/core/providers/cpu/math/gemm.cc +++ b/onnxruntime/core/providers/cpu/math/gemm.cc @@ -248,7 +248,6 @@ template void Gemm::ComputeGemm(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE template Status Gemm::PrePack(const Tensor& /* tensor */, int /* input_idx */, AllocatorPtr /*alloc_for_caching*/, - bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weight_for_caching*/) { is_packed = false; @@ -257,7 +256,7 @@ Status Gemm::PrePack(const Tensor& /* tensor */, int /* input_idx */, Allocat template <> Status Gemm::PrePack(const Tensor& tensor, int input_idx, - AllocatorPtr alloc, bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, + AllocatorPtr alloc, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; diff --git a/onnxruntime/core/providers/cpu/math/gemm.h b/onnxruntime/core/providers/cpu/math/gemm.h index 92f05a7921f8b..953949732560d 100644 --- a/onnxruntime/core/providers/cpu/math/gemm.h +++ b/onnxruntime/core/providers/cpu/math/gemm.h @@ -21,7 +21,6 @@ class Gemm : protected GemmBase, public OpKernel { Status Compute(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index 8f2c2c53b188b..2c6d23e4de908 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -173,7 +173,6 @@ bool GemmPackBBfloat16(AllocatorPtr& alloc, #endif Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; diff --git a/onnxruntime/core/providers/cpu/math/matmul.h b/onnxruntime/core/providers/cpu/math/matmul.h index 0bb0e6c2ef596..b9bbe36583879 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.h +++ b/onnxruntime/core/providers/cpu/math/matmul.h @@ -37,7 +37,6 @@ class MatMul final : public OpKernel { } Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; diff --git a/onnxruntime/core/providers/cpu/nn/conv_transpose.cc b/onnxruntime/core/providers/cpu/nn/conv_transpose.cc index 2c7afddf38070..f0c1b0b409831 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cpu/nn/conv_transpose.cc @@ -38,7 +38,6 @@ ONNX_CPU_OPERATOR_KERNEL( template Status ConvTranspose::PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/, - bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weights*/ ) { @@ -48,7 +47,6 @@ Status ConvTranspose::PrePack(const Tensor& /*tensor*/, int /*input_idx*/, Al template <> Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; diff --git a/onnxruntime/core/providers/cpu/nn/conv_transpose.h b/onnxruntime/core/providers/cpu/nn/conv_transpose.h index 
d03b5566e334f..c82cd5ad49d7e 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_transpose.h +++ b/onnxruntime/core/providers/cpu/nn/conv_transpose.h @@ -28,7 +28,6 @@ class ConvTranspose : public OpKernel { ConvTranspose(const OpKernelInfo& info) : OpKernel(info), conv_transpose_attrs_(info) {} Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; diff --git a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc index fe2bf1035bb65..24a5dcab225c4 100644 --- a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc +++ b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc @@ -229,7 +229,6 @@ Status LayerNormImpl::Compute(OpKernelContext* p_ctx) const { } Status LayerNormImpl::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, bool& is_packed, PrePackedWeights* prepacked_weights) { ORT_UNUSED_PARAMETER(prepacked_weights); diff --git a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h index abce87d03c14b..f8b528b398cba 100644 --- a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h +++ b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h @@ -15,7 +15,7 @@ class LayerNormImpl : public OpKernel { LayerNormImpl(const OpKernelInfo& op_kernel_info, bool simplified = false, bool contrib_op = false); Status Compute(OpKernelContext* p_op_kernel_context) const override; - Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool save_prepacked_initializers, + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, PrePackedWeights* prepacked_weights) override; // This method was created so that it can be called directly from `test/onnx/microbenchmark/layer_normalization.cc`. 
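The hunks above and below repeat one mechanical change: the `save_prepacked_initializers` parameter is dropped from every `PrePack` override, restoring the original signature. For reference, here is a minimal sketch of a kernel override under the restored contract; the class name `MyKernel` and its `packed_buffer_` member are hypothetical placeholders, and real kernels (MatMul, Gemm, the conv kernels) re-layout weights with MLAS helpers rather than a plain copy.

```cpp
// Sketch only: assumes the ORT core headers plus a hypothetical kernel member
// IAllocatorUniquePtr<void> packed_buffer_; and <cstring> for std::memcpy.
Status MyKernel::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
                         /*out*/ bool& is_packed,
                         /*out*/ PrePackedWeights* prepacked_weights) {
  is_packed = false;
  if (input_idx != 1) {
    return Status::OK();  // only the weight input gets packed
  }

  // Re-layout (here: plain copy) the weight into a kernel-owned buffer.
  const size_t packed_len = tensor.SizeInBytes();
  packed_buffer_ = IAllocator::MakeUniquePtr<void>(alloc, packed_len, true);
  std::memcpy(packed_buffer_.get(), tensor.DataRaw(), packed_len);

  // When the session requests shareable pre-packed weights, hand over
  // ownership so the buffer can be cached and reused across sessions.
  if (prepacked_weights != nullptr) {
    prepacked_weights->buffers_.push_back(std::move(packed_buffer_));
    prepacked_weights->buffer_sizes_.push_back(packed_len);
  }

  is_packed = true;
  return Status::OK();
}
```

When ownership is transferred to `prepacked_weights`, the kernel later receives the shared buffer back through `UseSharedPrePackedBuffers`, which is the sharing path the hunks above leave untouched.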
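One piece of the code restored earlier in `Graph::ToGraphProtoWithExternalInitializers` benefits from a worked example: before a large tensor is appended to the external data file, the running offset is rounded up to the larger of the page size (4096) and the allocation granularity, and the gap is filled with padding bytes so the tensor can later be memory-mapped. A standalone sketch of that computation, with integer arithmetic in place of the `std::floor` used in the restored code and illustrative example values:

```cpp
#include <algorithm>
#include <cstdint>

// Round external_offset up to the next multiple of the alignment factor,
// mirroring the logic restored in Graph::ToGraphProtoWithExternalInitializers.
int64_t AlignOffset(int64_t external_offset, int64_t allocation_granularity) {
  // Align to the larger of the page size (4096) and the allocation granularity.
  const int64_t alignment_factor = std::max<int64_t>(4096, allocation_granularity);
  return ((external_offset + alignment_factor - 1) / alignment_factor) * alignment_factor;
}

// AlignOffset(5000, 65536) == 65536 -> the writer emits 60536 padding bytes;
// AlignOffset(8192, 4096)  == 8192  -> already aligned, no padding needed.
```

The padding bytes carry no information; only the `offset` and `length` entries recorded in the initializer's `external_data` matter to the reader.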
diff --git a/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h b/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h index 8a8ce27990069..e26eae19b8fd4 100644 --- a/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h +++ b/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h @@ -14,7 +14,6 @@ class MatMulIntegerBase : public OpKernel { MatMulIntegerBase(const OpKernelInfo& info) : OpKernel(info) {} Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override { is_packed = false; diff --git a/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc b/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc index 736cde24591ff..7797cbe678bd4 100644 --- a/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc +++ b/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc @@ -25,7 +25,6 @@ class QLinearConv : public OpKernel { Status Compute(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; @@ -361,7 +360,6 @@ REGISTER_QLINEARCONV_INT8_KERNEL(kMSDomain, 1); template Status QLinearConv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc index 7afd00eacef89..b78c5236e6fab 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc @@ -284,7 +284,6 @@ bool DeepCpuGruOp::TryPackRecurrentWeights(const Tensor& weights, AllocatorPtr& } Status DeepCpuGruOp::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, bool& is_packed, PrePackedWeights* prepacked_weights) { is_packed = false; diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h index 914077b2f2c15..5a6dd97c7c3f2 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h @@ -62,7 +62,6 @@ class DeepCpuGruOp final : public OpKernel { private: Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; @@ -198,4 +197,4 @@ class UniDirectionalGru { }; } // namespace detail -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc index e4082e5d7634a..09bbf6c4c79e6 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc @@ -225,9 +225,7 @@ static void UseSharedPrePackedBuffersImpl(std::vector& prepacke } Status DeepCpuLstmOp::PrePack(const Tensor& tensor, int input_idx, - AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, - /*out*/ bool& is_packed, + AllocatorPtr alloc, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h 
b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h index ff8ab9abf0eed..9c4c12954022a 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h @@ -19,7 +19,6 @@ class DeepCpuLstmOp final : public OpKernel, public LSTMBase { DeepCpuLstmOp(const OpKernelInfo& info) : OpKernel(info), LSTMBase(info) {} Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index 45a1d3bbc0414..3129f519da2e5 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -52,7 +52,6 @@ REGISTER_KERNEL_TYPED(MLFloat16, kMSInternalNHWCDomain, true) // First input (in this case X) is in case NHWC == true also in NHWC format, the other inputs in NCHW template Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { is_packed = false; // only layout of weight input is adjusted via PrePack diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h index 6294566af3cb9..e4047a6af272e 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.h +++ b/onnxruntime/core/providers/cuda/nn/conv.h @@ -219,7 +219,6 @@ class Conv : public CudaKernel { } Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, bool& is_packed, PrePackedWeights* prepacked_weights) override; Status ComputeInternal(OpKernelContext* context) const override; diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc index 9c9a83460daeb..2972ae999adc4 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc @@ -45,8 +45,7 @@ REGISTER_KERNEL_TYPED(MLFloat16, kMSInternalNHWCDomain, true) // First input (in this case X) is in case NHWC == true also in NHWC format, the other inputs in NCHW template -Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, bool& is_packed, +Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, [[maybe_unused]] PrePackedWeights* prepacked_weights) { is_packed = false; // only layout of weight input is adjusted via PrePack diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.h b/onnxruntime/core/providers/cuda/nn/conv_transpose.h index f23c2b94501f2..1a6957164d22f 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.h +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.h @@ -22,7 +22,6 @@ class ConvTranspose : public CudaKernel { ConvTranspose(const OpKernelInfo& info) : CudaKernel(info), conv_transpose_attrs_(info) {}; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, bool& is_packed, [[maybe_unused]] PrePackedWeights* prepacked_weights) override; Status ComputeInternal(OpKernelContext* context) const override; Status DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const; diff --git a/onnxruntime/core/providers/js/operators/conv.h b/onnxruntime/core/providers/js/operators/conv.h index 276b600cf40d2..b04df44954295 100644 --- 
a/onnxruntime/core/providers/js/operators/conv.h +++ b/onnxruntime/core/providers/js/operators/conv.h @@ -78,7 +78,6 @@ class ConvBase : public JsKernel { } Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /* prepacked_weights */) override { is_packed = false; diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.h b/onnxruntime/core/providers/js/operators/conv_transpose.h index baa93f825a203..5ff52e8fda4fa 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.h +++ b/onnxruntime/core/providers/js/operators/conv_transpose.h @@ -126,10 +126,8 @@ class ConvTranspose : public JsKernel { } Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /* prepacked_weights */) override { - ORT_UNUSED_PARAMETER(save_prepacked_initializers); is_packed = false; if (input_idx == 1) { diff --git a/onnxruntime/core/providers/xnnpack/math/gemm.cc b/onnxruntime/core/providers/xnnpack/math/gemm.cc index 68b55030c7363..35a06cb7eb89f 100644 --- a/onnxruntime/core/providers/xnnpack/math/gemm.cc +++ b/onnxruntime/core/providers/xnnpack/math/gemm.cc @@ -117,7 +117,6 @@ Gemm::Gemm(const OpKernelInfo& info) : GemmBase(info), XnnpackKernel(info, /*ena } Status Gemm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr, - bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights*) { is_packed = false; diff --git a/onnxruntime/core/providers/xnnpack/math/gemm.h b/onnxruntime/core/providers/xnnpack/math/gemm.h index d632eef015f9a..954aab0698b9c 100644 --- a/onnxruntime/core/providers/xnnpack/math/gemm.h +++ b/onnxruntime/core/providers/xnnpack/math/gemm.h @@ -23,7 +23,6 @@ class Gemm : protected GemmBase, public XnnpackKernel { static bool IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& graph); Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; diff --git a/onnxruntime/core/providers/xnnpack/math/matmul.cc b/onnxruntime/core/providers/xnnpack/math/matmul.cc index 71a11cb05d9af..44a6fb4ee835a 100644 --- a/onnxruntime/core/providers/xnnpack/math/matmul.cc +++ b/onnxruntime/core/providers/xnnpack/math/matmul.cc @@ -78,7 +78,6 @@ MatMul::MatMul(const OpKernelInfo& info) : XnnpackKernel(info, /*enable_caches*/ } Status MatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*Not used*/) { is_packed = false; diff --git a/onnxruntime/core/providers/xnnpack/math/matmul.h b/onnxruntime/core/providers/xnnpack/math/matmul.h index 31a8c36ad418b..188cc73189af5 100644 --- a/onnxruntime/core/providers/xnnpack/math/matmul.h +++ b/onnxruntime/core/providers/xnnpack/math/matmul.h @@ -23,7 +23,6 @@ class MatMul : public XnnpackKernel { // Required for checking XNNpack restrictions on ORT side static bool IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& graph); Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; diff --git a/onnxruntime/core/providers/xnnpack/nn/conv.cc b/onnxruntime/core/providers/xnnpack/nn/conv.cc index f2e697df475da..4e6b308e28ae5 100644 
--- a/onnxruntime/core/providers/xnnpack/nn/conv.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv.cc @@ -18,7 +18,6 @@ namespace xnnpack { // use PrePack to handle the weight layout change as that's not a simple NCHW -> NHWC transpose Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weights*/) { is_packed = false; diff --git a/onnxruntime/core/providers/xnnpack/nn/conv.h b/onnxruntime/core/providers/xnnpack/nn/conv.h index 762b68c8bd49a..3630aae208d49 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv.h +++ b/onnxruntime/core/providers/xnnpack/nn/conv.h @@ -19,7 +19,6 @@ class Conv : public ConvBase { // use PrePack to handle the weight layout change as that's not a simple NCHW -> NHWC transpose Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; }; diff --git a/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc b/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc index 5729565b2feb9..b6930a5fc92d1 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc @@ -15,7 +15,6 @@ namespace xnnpack { // use PrePack to handle the weight layout change as that's not a simple NCHW -> NHWC transpose Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weights*/) { is_packed = false; diff --git a/onnxruntime/core/providers/xnnpack/nn/conv_transpose.h b/onnxruntime/core/providers/xnnpack/nn/conv_transpose.h index 0313515d10fa1..866b9b6b98365 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv_transpose.h +++ b/onnxruntime/core/providers/xnnpack/nn/conv_transpose.h @@ -18,7 +18,6 @@ class ConvTranspose : public ConvBase { // use PrePack to handle the weight layout change as that's not a simple NCHW -> NHWC transpose Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; }; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index bc5db98e7c595..2ff9fa525fa3b 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -2062,11 +2062,9 @@ common::Status InferenceSession::Initialize() { #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) } - SessionState::PrePackInitializers pre_packed_initializers; ORT_RETURN_IF_ERROR_SESSIONID_( session_state_->FinalizeSessionState(model_location_, kernel_registry_manager_, // need to keep the initializers if saving the optimized model - pre_packed_initializers, !saving_model, saving_ort_format)); @@ -2102,47 +2100,11 @@ common::Status InferenceSession::Initialize() { kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes, "1024")); Graph::OffsetAlignmentInfo align_info; align_info.align_offset = true; - bool save_prepacked_constant_initializers = - session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsSavePrePackedConstantInitializers, "0") == "1" ? 
true : false; - Graph::PrePackedTensorProtoToSave pre_packed_initializers_tensor_proto; - if (save_prepacked_constant_initializers) { - LOGS(*session_logger_, WARNING) << "Serialize prepacked initializers option has been turn on." - << "Use this option only when run model inference on PC with CPU." - << "Make sure to save and load model in same device as prepack is device specific." - << "Note: this feature in only work with ONNX model format." - << "Process of use this option is like below:" - << "1. Optimize model with external data file with save_prepacked_constant_initializers on:" - << " sample: sess_options.add_session_config_entry('session.save_prepacked_constant_initializers', ' 1 ')" - << " With save_prepacked_constant_initializers option, prepacked initializer will be serialized into data file." - << "2. Load optimized model and external data file in same device, no prepack is need." - << "3. Run inference with optimized model."; - - if (fbs::utils::IsOrtFormatModel(session_options_.optimized_model_filepath)) { - ORT_RETURN_IF_ERROR_SESSIONID_( - ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "Unable to serialize prepacked external constant initializer for ORT format model." - "Please use ONNX format model with save_prepacked_constant_initializers.")); - } - - // convert pre_packed_initializers to tensorproto format and save to external data file - for (const auto& name_item_pair : pre_packed_initializers.pre_packed_initializers_to_save) { - auto initializer_name = name_item_pair.first; - - for (const auto& kernel_name_initializer_item_pair : name_item_pair.second) { - auto kernel_name = kernel_name_initializer_item_pair.first; - auto prepacked_initializer_name = utils::GetPrepackedInitializerName(initializer_name, kernel_name); - - pre_packed_initializers_tensor_proto[initializer_name][kernel_name] = utils::TensorToTensorProto(kernel_name_initializer_item_pair.second, prepacked_initializer_name); - } - } - } ORT_RETURN_IF_ERROR_SESSIONID_(Model::SaveWithExternalInitializers(*model_, session_options_.optimized_model_filepath, optimized_model_external_initializers_file_name, optimized_model_external_initializers_min_size_in_bytes, - align_info, - save_prepacked_constant_initializers, - pre_packed_initializers_tensor_proto)); + align_info)); } } } diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 9c7e6e9761728..c6a81e8a1c1ad 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -46,7 +46,6 @@ #include "core/session/environment.h" #include "core/session/IOBinding.h" #include "core/session/inference_session_utils.h" -#include "core/session/onnxruntime_cxx_api.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/onnxruntime_run_options_config_keys.h" #include "dummy_provider.h" @@ -66,8 +65,6 @@ using namespace ONNX_NAMESPACE; using namespace onnxruntime::logging; using namespace onnxruntime::concurrency; -extern std::unique_ptr ort_env; - namespace { struct KernelRegistryAndStatus { std::shared_ptr kernel_registry = std::make_shared(); @@ -500,57 +497,6 @@ TEST(InferenceSessionTests, TestModelSerialization) { ASSERT_TRUE(session_object_emptyValidation.Initialize().IsOK()); } -// Test feature serialize prepack weight is only used in PC with CPU on inference, -// disable this test for training, other device and eps -#if !ENABLE_TRAINING && !defined(USE_CUDA) && !defined(__wasm__) && !defined(USE_DNNL) && 
!defined(USE_QNN) && !defined(__ANDROID__) && !defined(USE_COREML) -// MLAS dispatcher used in matmul_nbits kernels here is 64 bit only -#if defined(__amd64__) || defined(_M_AMD64) || defined(__aarch64__) || defined(_M_ARM64) -TEST(InferenceSessionTests, TestPrePackSerialization) { - SessionOptions so; - std::string model_name = "model_with_matmul_nbits"; - - const std::string test_model = "testdata/prepack/" + model_name + ".onnx"; - const std::string optimized_model = "testdata/prepack/" + model_name + "_opt.onnx"; - - so.session_logid = "InferenceSessionTests.TestPrepackSerialization"; - so.enable_cpu_mem_arena = false; - so.graph_optimization_level = TransformerLevel::Default; - so.optimized_model_filepath = optimized_model; - std::string external_initializer_file_name = model_name + "_opt.onnx.data"; - - // enable serialize prepack initializer to data file - ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsSavePrePackedConstantInitializers, - "1")); - // always save external initializer to data file for test - ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes, - "0")); - ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsOptimizedModelExternalInitializersFileName, - external_initializer_file_name.c_str())); - - // optimize model with serialize prepack constant initializers - InferenceSessionWrapper session_object{so, GetEnvironment()}; - ASSERT_TRUE(session_object.Load(test_model).IsOK()); - ASSERT_TRUE(session_object.Initialize().IsOK()); - - // Verify prepack initializers are serialized into optimized model and data file - // load optimized model and check initializer are prepacked - auto logger = DefaultLoggingManager().CreateLogger("TestPrepackSerialization"); - std::shared_ptr model; - auto load_status = Model::Load(ToWideString(optimized_model), model, nullptr, *logger); - ASSERT_EQ(Status::OK(), load_status); - Graph& graph = model->MainGraph(); - - bool found_prepack_initializer = false; - for (const auto& item : graph.GetAllInitializedTensors()) { - if (item.first.find(':') != std::string::npos) { - found_prepack_initializer = true; - } - } - ASSERT_TRUE(found_prepack_initializer); -} -#endif -#endif - #ifdef ORT_RUN_EXTERNAL_ONNX_TESTS static bool Compare(const InputDefList& f_arg, const InputDefList& s_arg) { if (f_arg.size() != s_arg.size()) { diff --git a/onnxruntime/test/framework/save_model_with_external_initializers.cc b/onnxruntime/test/framework/save_model_with_external_initializers.cc index 0f76cb61ace74..d0bc088175755 100644 --- a/onnxruntime/test/framework/save_model_with_external_initializers.cc +++ b/onnxruntime/test/framework/save_model_with_external_initializers.cc @@ -7,7 +7,6 @@ #include "core/framework/data_types.h" #include "core/graph/model.h" #include "core/framework/tensorprotoutils.h" -#include "core/framework/session_state.h" #include "test/test_environment.h" #include "test_utils.h" #include "test/util/include/asserts.h" @@ -20,34 +19,19 @@ using namespace onnxruntime; namespace onnxruntime { namespace test { -std::vector split(const std::string& str, char delimiter) { - std::vector result; - std::stringstream ss(str); - std::string token; - - // Use getline with a delimiter to split the string - while (std::getline(ss, token, delimiter)) { - result.push_back(token); - } - - return result; -} - Status LoadSaveAndCompareModel(const std::filesystem::path& input_onnx, const std::filesystem::path& input_external_init_file, const std::filesystem::path& 
output_onnx, const std::filesystem::path& output_external_init_file, size_t initializer_size_threshold, - const Graph::OffsetAlignmentInfo& align_info, - Graph::PrePackedTensorProtoToSave& pre_packed_initializers_tensor_proto, - bool save_prepacked_constant_initializers = false) { + const Graph::OffsetAlignmentInfo& align_info) { auto logger = DefaultLoggingManager().CreateLogger("LoadSaveAndCompareModel"); std::shared_ptr model; ORT_RETURN_IF_ERROR(Model::Load(input_onnx, model, nullptr, *logger)); std::filesystem::remove(output_onnx); std::filesystem::remove(output_external_init_file); ORT_RETURN_IF_ERROR(Model::SaveWithExternalInitializers(*model, output_onnx, output_external_init_file, initializer_size_threshold, - align_info, save_prepacked_constant_initializers, pre_packed_initializers_tensor_proto)); + align_info)); std::shared_ptr model_from_external; ORT_RETURN_IF_ERROR(Model::Load(output_onnx.native(), model_from_external, nullptr, *logger)); @@ -66,11 +50,10 @@ Status LoadSaveAndCompareModel(const std::filesystem::path& input_onnx, // Compare the initializers of the two versions. std::filesystem::path model_path{}; std::filesystem::path external_data_path{}; - for (const auto& i : initializers_from_external) { + for (const auto& i : initializers) { const std::string kInitName = i.first; - const ONNX_NAMESPACE::TensorProto* from_external_tensor_proto = i.second; - // prepack initializer will have name as [original name]:[kernel name] in case initializer used by multiple kernels - const ONNX_NAMESPACE::TensorProto* tensor_proto = save_prepacked_constant_initializers ? initializers[split(kInitName, ':')[0]] : initializers[kInitName]; + const ONNX_NAMESPACE::TensorProto* tensor_proto = i.second; + const ONNX_NAMESPACE::TensorProto* from_external_tensor_proto = initializers_from_external[kInitName]; std::vector tensor_proto_data; model_path = input_onnx; @@ -92,12 +75,8 @@ Status LoadSaveAndCompareModel(const std::filesystem::path& input_onnx, ORT_RETURN_IF_NOT(from_external_tensor_proto->data_location() == ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL, "location mismatch"); } - if (!save_prepacked_constant_initializers) { - ORT_RETURN_IF_NOT(tensor_proto_size == from_external_tensor_proto_size, "size mismatch"); - ORT_RETURN_IF_NOT(memcmp(tensor_proto_data.data(), from_external_tensor_proto_data.data(), tensor_proto_size) == 0, "data mismatch"); - } else { - ORT_RETURN_IF_NOT(from_external_tensor_proto_size >= tensor_proto_size, "prepack initializer's size is at least same as original tensor, might be larger"); - } + ORT_RETURN_IF_NOT(tensor_proto_size == from_external_tensor_proto_size, "size mismatch"); + ORT_RETURN_IF_NOT(memcmp(tensor_proto_data.data(), from_external_tensor_proto_data.data(), tensor_proto_size) == 0, "data mismatch"); if (align_info.align_offset) { for (const StringStringEntryProto& entry : from_external_tensor_proto->external_data()) { @@ -110,7 +89,6 @@ Status LoadSaveAndCompareModel(const std::filesystem::path& input_onnx, } } } - // Cleanup. 
ORT_RETURN_IF_NOT(std::filesystem::remove(output_onnx), "delete file failed"); ORT_RETURN_IF_NOT(std::filesystem::remove(external_data_path), "delete file failed"); @@ -120,15 +98,13 @@ Status LoadSaveAndCompareModel(const std::filesystem::path& input_onnx, // Original model does not have external initializers TEST(SaveWithExternalInitializers, Mnist) { Graph::OffsetAlignmentInfo align_info; - Graph::PrePackedTensorProtoToSave pre_packed_initializers_tensor_proto; - ASSERT_STATUS_OK(LoadSaveAndCompareModel(ORT_TSTR("testdata/mnist.onnx"), ORT_TSTR(""), ORT_TSTR("testdata/mnist_with_external_initializers.onnx"), ORT_TSTR("mnist_external_initializers.bin"), 100, align_info, pre_packed_initializers_tensor_proto)); + ASSERT_STATUS_OK(LoadSaveAndCompareModel(ORT_TSTR("testdata/mnist.onnx"), ORT_TSTR(""), ORT_TSTR("testdata/mnist_with_external_initializers.onnx"), ORT_TSTR("mnist_external_initializers.bin"), 100, align_info)); } // Original model has external initializers TEST(SaveWithExternalInitializers, ModelWithOriginalExternalData) { Graph::OffsetAlignmentInfo align_info; - Graph::PrePackedTensorProtoToSave pre_packed_initializers_tensor_proto; - ASSERT_STATUS_OK(LoadSaveAndCompareModel(ORT_TSTR("testdata/model_with_orig_ext_data.onnx"), ORT_TSTR("model_with_orig_ext_data.onnx.data"), ORT_TSTR("testdata/model_with_new_external_initializers.onnx"), ORT_TSTR("model_with_new_external_initializers.bin"), 0, align_info, pre_packed_initializers_tensor_proto)); + ASSERT_STATUS_OK(LoadSaveAndCompareModel(ORT_TSTR("testdata/model_with_orig_ext_data.onnx"), ORT_TSTR("model_with_orig_ext_data.onnx.data"), ORT_TSTR("testdata/model_with_new_external_initializers.onnx"), ORT_TSTR("model_with_new_external_initializers.bin"), 0, align_info)); } // Original model has external initializers, align offset @@ -136,22 +112,7 @@ TEST(SaveWithExternalInitializers, ModelWithOriginalExternalDataAlignOffset) { Graph::OffsetAlignmentInfo align_info; align_info.align_offset = true; align_info.align_threshold = 0; - Graph::PrePackedTensorProtoToSave pre_packed_initializers_tensor_proto; - ASSERT_STATUS_OK(LoadSaveAndCompareModel(ORT_TSTR("testdata/model_with_orig_ext_data.onnx"), ORT_TSTR("model_with_orig_ext_data.onnx.data"), ORT_TSTR("testdata/model_with_new_external_initializers.onnx"), ORT_TSTR("model_with_new_external_initializers.bin"), 0, align_info, pre_packed_initializers_tensor_proto)); -} - -// Original model has external initializers, align offset and serialize prepacked external initializer to model file -TEST(SaveWithExternalInitializers, ModelWithOriginalExternalDataAlignOffsetAndSavePrepackTensors) { - Graph::OffsetAlignmentInfo align_info; - align_info.align_offset = true; - align_info.align_threshold = 0; - std::shared_ptr alloc = std::make_shared(); - TensorShape shape = {178}; - // prepack both initializers for test purpose - Graph::PrePackedTensorProtoToSave pre_packed_initializers_tensor_proto; - pre_packed_initializers_tensor_proto["MatMul.Weight"]["MatMul_0"] = utils::TensorToTensorProto(Tensor(DataTypeImpl::GetType(), shape, alloc), "MatMul.Weight:MatMul_0"); - pre_packed_initializers_tensor_proto["scales"]["MatMul_0"] = utils::TensorToTensorProto(Tensor(DataTypeImpl::GetType(), shape, alloc), "scales:MatMul_0"); - ASSERT_STATUS_OK(LoadSaveAndCompareModel(ORT_TSTR("testdata/prepack/model_with_matmul_nbits.onnx"), ORT_TSTR("model_with_matmul_nbits.onnx.data"), ORT_TSTR("testdata/prepack/model_with_matmul_nbits_opt.onnx"), ORT_TSTR("model_with_matmul_nbits_opt.onnx.data"), 0, align_info, 
pre_packed_initializers_tensor_proto, true)); + ASSERT_STATUS_OK(LoadSaveAndCompareModel(ORT_TSTR("testdata/model_with_orig_ext_data.onnx"), ORT_TSTR("model_with_orig_ext_data.onnx.data"), ORT_TSTR("testdata/model_with_new_external_initializers.onnx"), ORT_TSTR("model_with_new_external_initializers.bin"), 0, align_info)); } } // namespace test diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index 6265eccb7bd9b..b94d24a1b180b 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -372,11 +372,10 @@ class PrePackingTestOpKernel : public OpKernel { return Status::OK(); } - Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool save_prepacked_initializers, + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override { ORT_UNUSED_PARAMETER(tensor); ORT_UNUSED_PARAMETER(input_idx); - ORT_UNUSED_PARAMETER(save_prepacked_initializers); size_t weight_packed_len = 8; weight_packed_ = IAllocator::MakeUniquePtr(alloc, weight_packed_len, true); @@ -394,20 +393,9 @@ class PrePackingTestOpKernel : public OpKernel { return Status::OK(); } - std::optional GetPrePackTensor(int input_idx) override { - ORT_UNUSED_PARAMETER(input_idx); - ++get_prepack_tensors_count; - - TensorShape shape = {2}; - packed_tensor = Tensor(DataTypeImpl::GetType(), shape, std::make_shared()); - return std::move(packed_tensor); - } - int prepack_calls_count = 0; int store_pre_packed_weight_calls_count = 0; - int get_prepack_tensors_count = 0; IAllocatorUniquePtr weight_packed_; - Tensor packed_tensor; }; static void CreateSimpleGraph(Graph& graph) { @@ -542,7 +530,6 @@ static void PlaceAllNodesToCPUEP(Graph& graph) { struct PrepackingTestParam { bool test_subgraph; bool test_prepacking; - bool test_save_prepack_initializer; }; class SessionStatePrepackingTest : public testing::TestWithParam {}; @@ -585,8 +572,6 @@ TEST_P(SessionStatePrepackingTest, PrePackingTest) { sess_options.enable_mem_reuse = true; sess_options.config_options.configurations[kOrtSessionOptionsConfigDisablePrepacking] = test_param.test_prepacking ? "0" : "1"; - sess_options.config_options.configurations[kOrtSessionOptionsSavePrePackedConstantInitializers] = - test_param.test_save_prepack_initializer ? "1" : "0"; SessionState session_state(model.MainGraph(), execution_providers, @@ -612,47 +597,12 @@ TEST_P(SessionStatePrepackingTest, PrePackingTest) { kernel_registry_manager.RegisterKernelRegistry(kernel_registry); PlaceAllNodesToCPUEP(model.MainGraph()); - SessionState::PrePackInitializers pre_packed_initializers; ASSERT_STATUS_OK(session_state.FinalizeSessionState(std::basic_string(), - kernel_registry_manager, - pre_packed_initializers)); + kernel_registry_manager)); const auto& const_initialized_tensors = session_state.GetConstantInitializedTensors(); // check prepacking ASSERT_EQ(const_initialized_tensors.size(), size_t(test_param.test_prepacking ? 0 : 1)); - - // check get prepack tensor method called when set save_prepacked_constant_initializers - if (!test_param.test_subgraph) { - const auto* kernel = reinterpret_cast(session_state.GetKernel(0)); - ASSERT_EQ(kernel->get_prepack_tensors_count, (test_param.test_prepacking && test_param.test_save_prepack_initializer) ? 
1 : 0); - } else { - auto if_index = 1; - if (session_state.GetKernel(0)->Node().OpType() == "If") { - if_index = 0; - } - - const auto& subgraph_session_states = session_state.GetSubgraphSessionStateMap(); - const auto& if_node_session_states = subgraph_session_states.at(if_index); - const auto& session_state_1_then_branch_session_state = *if_node_session_states.at("then_branch"); - const auto& session_state_1_else_branch_session_state = *if_node_session_states.at("else_branch"); - - const auto* kernel_if_0 = reinterpret_cast(session_state_1_then_branch_session_state.GetKernel(0)); - const auto* kernel_if_1 = reinterpret_cast(session_state_1_else_branch_session_state.GetKernel(0)); - ASSERT_EQ(kernel_if_0->get_prepack_tensors_count, (test_param.test_prepacking && test_param.test_save_prepack_initializer) ? 1 : 0); - ASSERT_EQ(kernel_if_1->get_prepack_tensors_count, (test_param.test_prepacking && test_param.test_save_prepack_initializer) ? 1 : 0); - } - - // check pre_packed_initializers_to_save will be set properly when set save_prepacked_constant_initializers - if (!test_param.test_subgraph && test_param.test_prepacking && test_param.test_save_prepack_initializer) { - ASSERT_EQ(pre_packed_initializers.pre_packed_initializers_to_save.size(), size_t(1)); - ASSERT_EQ(pre_packed_initializers.pre_packed_initializers_to_save.count("node_0_input_1"), size_t(1)); - ASSERT_EQ(pre_packed_initializers.pre_packed_initializers_to_save["node_0_input_1"].count("node_0"), size_t(1)); - } else if (test_param.test_subgraph && test_param.test_prepacking && test_param.test_save_prepack_initializer) { - ASSERT_EQ(pre_packed_initializers.pre_packed_initializers_to_save.size(), size_t(1)); - ASSERT_EQ(pre_packed_initializers.pre_packed_initializers_to_save.count("if_shared"), size_t(1)); - ASSERT_EQ(pre_packed_initializers.pre_packed_initializers_to_save["if_shared"].count("if_node_1"), size_t(1)); - ASSERT_EQ(pre_packed_initializers.pre_packed_initializers_to_save["if_shared"].count("if_node_0"), size_t(1)); - } } class SessionStateTestSharedInitalizersWithPrePacking : public ::testing::Test { @@ -1050,14 +1000,10 @@ TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test4) { INSTANTIATE_TEST_SUITE_P(SessionStateTests, SessionStatePrepackingTest, - testing::Values(PrepackingTestParam{false, false, false}, - PrepackingTestParam{false, true, false}, - PrepackingTestParam{true, false, false}, - PrepackingTestParam{true, true, false}, - PrepackingTestParam{false, false, true}, - PrepackingTestParam{false, true, true}, - PrepackingTestParam{true, false, true}, - PrepackingTestParam{true, true, true})); + testing::Values(PrepackingTestParam{false, false}, + PrepackingTestParam{false, true}, + PrepackingTestParam{true, false}, + PrepackingTestParam{true, true})); #endif } // namespace test diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index e19362e0ec32d..0be1c0b1965ac 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -4600,86 +4600,3 @@ TEST(CApiTest, OrtCustomOp_GetInPlace) { ASSERT_EQ(len, static_cast(2)); mock_gqa.ReleaseAliasMap(input_index, output_index); } - -TEST(CApiTest, Serialize_PrePack_Initializers) { - std::string model_name = "model_with_matmul_nbits"; - - const std::string test_model = "testdata/prepack/" + model_name + ".onnx"; - const std::string optimized_model = "testdata/prepack/" + model_name + "_opt.onnx"; - std::string external_initializer_file_name = model_name + 
"_opt.onnx.data"; - - // Generate optimized with prepacked weights serialized - Ort::SessionOptions session_options_opt; - session_options_opt.AddConfigEntry(kOrtSessionOptionsOptimizedModelExternalInitializersFileName, external_initializer_file_name.c_str()); - session_options_opt.AddConfigEntry(kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes, "0"); - session_options_opt.AddConfigEntry(kOrtSessionOptionsSavePrePackedConstantInitializers, "1"); - -#if defined(_WIN32) || defined(_WIN64) - std::wstring test_model_wide = onnxruntime::ToWideString(test_model); - session_options_opt.SetOptimizedModelFilePath(onnxruntime::ToWideString(optimized_model).c_str()); - Ort::Session session_opt_model(*ort_env, test_model_wide.c_str(), session_options_opt); -#else - session_options_opt.SetOptimizedModelFilePath(optimized_model.c_str()); - Ort::Session session_opt_model(*ort_env, test_model.c_str(), session_options_opt); -#endif - - // Do inference with original model and optimized model and check output is identical - // set inputs and session options - Ort::SessionOptions session_options; - const char* input_names[] = {"A"}; - const char* const output_names[] = {"Y"}; - Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); - - std::vector ort_inputs; - std::vector input_0_data = {1.3f}; - std::vector input_0_dims = {1, 1}; - ort_inputs.emplace_back( - Ort::Value::CreateTensor(info, const_cast(input_0_data.data()), - input_0_data.size(), input_0_dims.data(), input_0_dims.size())); - - // run inference with original model - // Convert std::string to std::wstring -#if defined(_WIN32) || defined(_WIN64) - Ort::Session session(*ort_env, test_model_wide.c_str(), session_options); -#else - Ort::Session session(*ort_env, test_model.c_str(), session_options); -#endif - auto ort_outputs = session.Run(Ort::RunOptions{}, input_names, ort_inputs.data(), ort_inputs.size(), - output_names, 1); - - // run inference with optimized model which load serialized prepack initializer -#if defined(_WIN32) || defined(_WIN64) - std::wstring optimized_model_wide = onnxruntime::ToWideString(optimized_model); - Ort::Session session_opt(*ort_env, optimized_model_wide.c_str(), session_options); -#else - Ort::Session session_opt(*ort_env, optimized_model.c_str(), session_options); -#endif - auto ort_outputs_opt = session_opt.Run(Ort::RunOptions{}, input_names, ort_inputs.data(), ort_inputs.size(), - output_names, 1); - - // check output of original model and optimized model are equal - ASSERT_EQ(ort_outputs.size(), ort_outputs_opt.size()); - - for (size_t i = 0; i < ort_outputs.size(); ++i) { - const auto& sequences = ort_outputs[i]; - ASSERT_TRUE(sequences.IsTensor()); - - const auto& sequences_opt = ort_outputs_opt[i]; - ASSERT_TRUE(sequences_opt.IsTensor()); - - auto result_ts = sequences.GetTensorTypeAndShapeInfo(); - auto result_ts_opt = sequences_opt.GetTensorTypeAndShapeInfo(); - - ASSERT_EQ(result_ts.GetElementType(), result_ts_opt.GetElementType()); - - ASSERT_EQ(result_ts.GetShape(), result_ts_opt.GetShape()); - - const auto* result_vals = sequences.GetTensorData(); - auto result_span = gsl::make_span(result_vals, ort_outputs.size()); - - const auto* result_vals_opt = sequences_opt.GetTensorData(); - auto result_span_opt = gsl::make_span(result_vals_opt, ort_outputs_opt.size()); - - ASSERT_TRUE(std::equal(result_span_opt.begin(), result_span_opt.end(), result_span.begin(), result_span.end())); - } -} \ No newline at end of file diff --git 
a/onnxruntime/test/testdata/model_with_external_initializers.onnx b/onnxruntime/test/testdata/model_with_external_initializers.onnx index 3538f01b53c18..f815b4000f98f 100644 --- a/onnxruntime/test/testdata/model_with_external_initializers.onnx +++ b/onnxruntime/test/testdata/model_with_external_initializers.onnx @@ -1,8 +1,7 @@ - - onnx-example:� -, + onnx-example:� +& X -PadsYpad0"Pad* +PadsY"Pad* mode"constant� test-model*"BPadsj locationPads.binpZ @@ -17,4 +16,4 @@ test-model*"BPadsj Y   -B \ No newline at end of file +B \ No newline at end of file diff --git a/onnxruntime/test/testdata/model_with_external_initializers.py b/onnxruntime/test/testdata/model_with_external_initializers.py index dc64d4a41424a..8d2589a9e6564 100644 --- a/onnxruntime/test/testdata/model_with_external_initializers.py +++ b/onnxruntime/test/testdata/model_with_external_initializers.py @@ -35,10 +35,9 @@ def GenerateModel(model_name, external_data_name): # noqa: N802 # Create a node (NodeProto) node_def = helper.make_node( - "Pad", # op type + "Pad", # node name ["X", external_data_name], # inputs ["Y"], # outputs - "pad0", # node name mode="constant", # Attributes ) diff --git a/onnxruntime/test/testdata/model_with_orig_ext_data.onnx b/onnxruntime/test/testdata/model_with_orig_ext_data.onnx index 47d0c68235099..6f9cce0bc5b4f 100644 --- a/onnxruntime/test/testdata/model_with_orig_ext_data.onnx +++ b/onnxruntime/test/testdata/model_with_orig_ext_data.onnx @@ -1,8 +1,7 @@ - - onnx-example:� -@ +  onnx-example:� +: X -model_with_orig_ext_dataYpad0"Pad* +model_with_orig_ext_dataY"Pad* mode"constant� test-model*JBmodel_with_orig_ext_dataj( locationmodel_with_orig_ext_data.binpZ @@ -17,4 +16,4 @@ test-model*JBmodel_with_orig_ext_dataj( Y   -B \ No newline at end of file +B \ No newline at end of file diff --git a/onnxruntime/test/testdata/prepack/MatMul.Weight.bin b/onnxruntime/test/testdata/prepack/MatMul.Weight.bin deleted file mode 100644 index 0f8a571589c1050d3b3e512801441efcb22cdf3c..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8 KcmZ3@00966U;wND diff --git a/onnxruntime/test/testdata/prepack/model_with_external_initializers_and_prepack_kernel.py b/onnxruntime/test/testdata/prepack/model_with_external_initializers_and_prepack_kernel.py deleted file mode 100644 index 86af461edc2c4..0000000000000 --- a/onnxruntime/test/testdata/prepack/model_with_external_initializers_and_prepack_kernel.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
- -import os - -import numpy as np -import onnx -from onnx import TensorProto, helper -from onnx.external_data_helper import set_external_data -from onnx.numpy_helper import from_array - -M = 1 -K = 1 -N = 1 -q_cols = 1 -q_rows = 1 -q_scale_size = 1 - - -def create_external_data_tensor(value, tensor_name, data_type): - tensor = from_array(np.array(value)) - tensor.name = tensor_name - tensor_filename = f"{tensor_name}.bin" - set_external_data(tensor, location=tensor_filename) - - with open(os.path.join(tensor_filename), "wb") as data_file: - data_file.write(tensor.raw_data) - tensor.ClearField("raw_data") - tensor.data_location = onnx.TensorProto.EXTERNAL - tensor.data_type = data_type - return tensor - - -def create_internal_data_tensor(value, tensor_name, data_type): - tensor = helper.make_tensor(name=tensor_name, data_type=data_type, dims=value.shape, vals=value.flatten().tolist()) - print(tensor) - tensor.data_location = onnx.TensorProto.DEFAULT - return tensor - - -def GenerateMatmulNBitsModel(model_name, external_data_name): # noqa: N802 - A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [M, K]) # noqa: N806 - Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [M, N]) # noqa: N806 - - # Create a node (NodeProto) - node_def = helper.make_node( - op_type="MatMulNBits", # op type - inputs=["A", external_data_name, "scales"], # inputs - outputs=["Y"], # outputs - name="MatMul_0", # node name - domain="com.microsoft", # Custom domain for this operator - accuracy_level=4, # Attributes - bits=4, # Attributes - block_size=32, # Attributes - K=K, # Attributes - N=N, # Attributes - ) - - # Create the graph (GraphProto) - graph_def = helper.make_graph( - [node_def], - "test-model-matmul4bits", - [A], - [Y], - [ - create_external_data_tensor([[171]], external_data_name, TensorProto.UINT8), - create_internal_data_tensor(np.array([1.5], dtype=np.float32), "scales", TensorProto.FLOAT), - ], - ) - - # Create the model - model_def = helper.make_model( - graph_def, - producer_name="onnx-example", - opset_imports=[helper.make_operatorsetid("", 14), helper.make_operatorsetid("com.microsoft", 1)], - ) - - print(f"The ir_version in model: {model_def.ir_version}\n") - print(f"The producer_name in model: {model_def.producer_name}\n") - print(f"The graph in model:\n{model_def.graph}") - onnx.checker.check_model(model_def) - print("The model is checked!") - with open(model_name, "wb") as model_file: - model_file.write(model_def.SerializeToString()) - - -if __name__ == "__main__": - GenerateMatmulNBitsModel("model_with_matmul_nbits.onnx", "MatMul.Weight") diff --git a/onnxruntime/test/testdata/prepack/model_with_matmul_nbits.onnx b/onnxruntime/test/testdata/prepack/model_with_matmul_nbits.onnx deleted file mode 100644 index 0e06a75a5a7e84e1fe2f090a4c7c6a513ed6344f..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 333 zcmZ8cO-sZu6znTS!rR@_#t(`h$Q}zV9>l|5#9n3hD(j`BF={kv$|jZ7AK`D%f8tLw zrC#dc!MvF_j~Rk=ZrXNVh&|Jt607eJKLOze7i;F$y(;g7e0p|xU^!F5QrMo7QK>JM zvk`47>1<9AZZr6Ta6p?89b?Qm?{|#9*Gjwzl|{qB45P+d#wA5;l;N+nl^-HI_xftV zjV`t1J7dkGqbE*SS7`GfRH2#Ey}BIi`4s^INmxyzzMLWP|Cp1erRk(a*~qqo{K> o83n=5b@kV)3+@knYZ~L603{d_7^d;$_CHxg7$k9(;xuLgzX5qu+5i9m diff --git a/orttraining/orttraining/models/bert/main.cc b/orttraining/orttraining/models/bert/main.cc index ec7a458237c77..c4c7a98ba116a 100644 --- a/orttraining/orttraining/models/bert/main.cc +++ b/orttraining/orttraining/models/bert/main.cc @@ -42,7 +42,6 @@ static SessionOptions session_options = { 
ExecutionMode::ORT_SEQUENTIAL, // execution_mode ExecutionOrder::PRIORITY_BASED, // execution_order false, // enable_profiling - false, // save prepacked initializer ORT_TSTR(""), // optimized_model_filepath true, // enable_mem_pattern true, // enable_mem_reuse diff --git a/orttraining/orttraining/models/pipeline_poc/main.cc b/orttraining/orttraining/models/pipeline_poc/main.cc index 0e40d04ddac8c..1b7d6b9ea26f6 100644 --- a/orttraining/orttraining/models/pipeline_poc/main.cc +++ b/orttraining/orttraining/models/pipeline_poc/main.cc @@ -89,7 +89,6 @@ int main(int argc, char* argv[]) { ExecutionMode::ORT_SEQUENTIAL, // execution_mode ExecutionOrder::DEFAULT, // execution_order false, // enable_profiling - false, // save prepacked initializer ORT_TSTR(""), // optimized_model_filepath true, // enable_mem_pattern true, // enable_mem_reuse diff --git a/orttraining/orttraining/models/runner/training_runner.cc b/orttraining/orttraining/models/runner/training_runner.cc index 5a2f1cd13683e..dae6f613f4329 100644 --- a/orttraining/orttraining/models/runner/training_runner.cc +++ b/orttraining/orttraining/models/runner/training_runner.cc @@ -37,7 +37,6 @@ static SessionOptions SESSION_OPTION = { ExecutionMode::ORT_SEQUENTIAL, // execution_mode ExecutionOrder::PRIORITY_BASED, // execution_order false, // enable_profiling - false, // save prepacked initializer ORT_TSTR(""), // optimized_model_filepath true, // enable_mem_pattern true, // enable_mem_reuse From 1f3b675453e8412e5c084bfb95997967d0c2eec2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 11 Nov 2024 19:48:25 +0100 Subject: [PATCH 2/4] Fix MatMulBnFusion to exclude cases when tensors are not 2D tensors (#22762) ### Description Fixes #22512: MatMul + Add could be fused into a single Gemm even when the tensor dimensions are > 2. The PR excludes those cases. ### Motivation and Context ORT crashes on valid models because of this unexpected fusion. --- .../core/optimizer/matmul_bn_fusion.cc | 17 ++++++++++ .../test/optimizer/graph_transform_test.cc | 29 ++++++++++++++++++ .../fuse-matmul-bn-directly-dont-fuse.onnx | Bin 0 -> 517 bytes 3 files changed, 46 insertions(+) create mode 100644 onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly-dont-fuse.onnx diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc index e944522c9c338..6b76dc626fba0 100644 --- a/onnxruntime/core/optimizer/matmul_bn_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc @@ -107,6 +107,22 @@ bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, cons return false; } + // Check that the first input of MatMul has 2 dimensions. + // The check for the second input is done in method Apply, as it accesses the constant. + if (node.InputDefs()[0] == nullptr) { + // This should never happen but just in case. + return false; + } + auto shape_a = node.InputDefs()[0]->Shape(); + if (shape_a == nullptr) { + // We cannot determine the rank. It is better to avoid fusing. + return false; + } + if (shape_a->dim_size() != 2) { + // Gemm only supports 2D tensors. + return false; + } + // First output from BN is required. Others are optional. If any optional outputs exist we can't fuse.
   const auto& output_defs = batch_norm_node->OutputDefs();
   if (output_defs.size() > 1) {
@@ -165,6 +181,7 @@ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect&
       bias_tensor->dims_size() != 1 ||
       mean_tensor->dims_size() != 1 ||
       var_tensor->dims_size() != 1 ||
+      matmul_b_tensor->dims_size() != 2 ||
       scale_tensor->dims(0) != matmul_b_tensor->dims(1) ||
       bias_tensor->dims(0) != matmul_b_tensor->dims(1) ||
       mean_tensor->dims(0) != matmul_b_tensor->dims(1) ||
diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc
index ee3a1baade005..67d60ea3a4ff6 100755
--- a/onnxruntime/test/optimizer/graph_transform_test.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test.cc
@@ -1764,6 +1764,35 @@ TEST_F(GraphTransformationTests, FuseMatmulBNDirectly) {
   }
 }
 
+TEST_F(GraphTransformationTests, DoNotApplyFuseMatmulBNDirectly) {
+  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-directly-dont-fuse.onnx";
+
+  std::shared_ptr<Model> p_model;
+  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+  Graph& graph = p_model->MainGraph();
+
+  std::string expected_output_name;
+  GraphViewer graphViewer(graph);
+  for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
+    auto& node = *graph.GetNode(node_index);
+    if (node.OpType() == "BatchNormalization") {
+      expected_output_name = node.OutputDefs()[0]->Name();
+    }
+  }
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
+  ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
+
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
+
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_EQ(op_to_count["BatchNormalization"], 1);
+  ASSERT_EQ(op_to_count["MatMul"], 1);
+  ASSERT_EQ(op_to_count["Gemm"], 0);
+}
+
 TEST_F(GraphTransformationTests, FuseMatmulBNWithOnlyReshape) {
   constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-only-reshape.onnx";

diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly-dont-fuse.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly-dont-fuse.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..8ca8282572db8d474327e5f88d7e8a1f60c47e15
GIT binary patch
literal 517
zcmd
zSDarY#0wS3FD(JeE3x?|miU(DaQSngN^o%`<;52#C+4Jbu>)C2nTf?
zq%5&Whz)9pkW*qwa)w`iQEp;RW>sPd&@n<{FpKkG&I7wutCoY6gH?c0DW&9@y}fs&
zzrATnsojapNc+btLhLz8((JXXI_#F_bK0x1?6p(Ye{Q$n%?jIp7hc)TKdWx{F3r&H
z+?8o_>{

Date: Mon, 11 Nov 2024 16:05:34 -0500
Subject: [PATCH 3/4] Fix warning - LegacyKeyValueFormat: "ENV key=value"
 should be used instead of legacy "ENV key value" format (#22800)

### Description
This PR fixes the warning `LegacyKeyValueFormat: "ENV key=value" should be used
instead of legacy "ENV key value" format` in all Dockerfiles.

### Motivation and Context
---
 dockerfiles/Dockerfile.migraphx                          | 2 +-
 dockerfiles/Dockerfile.openvino                          | 6 +++---
 dockerfiles/Dockerfile.rocm                              | 2 +-
 dockerfiles/Dockerfile.tensorrt                          | 2 +-
 dockerfiles/Dockerfile.vitisai                           | 4 ++--
 orttraining/tools/amdgpu/Dockerfile.rocm4.3.1.pytorch    | 2 +-
 .../github/linux/docker/Dockerfile.manylinux2_28_cuda    | 4 ++--
 .../github/linux/docker/Dockerfile.manylinux2_28_rocm    | 2 +-
 .../docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6  | 2 +-
 .../docker/Dockerfile.package_ubi8_cuda_tensorrt10_0     | 2 +-
 .../Dockerfile.package_ubi8_cuda_tensorrt10_0_torch      | 2 +-
 .../linux/docker/Dockerfile.package_ubuntu_2004_gpu      | 2 +-
 .../docker/Dockerfile.package_ubuntu_2004_gpu_ffmpeg     | 2 +-
 .../linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6  | 4 ++--
 .../linux/docker/Dockerfile.ubuntu_cuda11_tensorrt10     | 4 ++--
 .../linux/docker/Dockerfile.ubuntu_cuda12_3_tensorrt8_6  | 4 ++--
 .../linux/docker/Dockerfile.ubuntu_cuda12_tensorrt10     | 4 ++--
 .../github/linux/docker/Dockerfile.ubuntu_openvino       | 8 ++++----
 .../github/linux/docker/Dockerfile.ubuntu_tensorrt_bin   | 4 ++--
 .../docker/inference/x86_64/default/cuda11/Dockerfile    | 4 ++--
 .../docker/inference/x86_64/default/cuda12/Dockerfile    | 4 ++--
 .../linux/docker/inference/x86_64/python/cuda/Dockerfile | 4 ++--
 22 files changed, 37 insertions(+), 37 deletions(-)

diff --git a/dockerfiles/Dockerfile.migraphx b/dockerfiles/Dockerfile.migraphx
index c5d998d503899..876a07e4ffaf6 100644
--- a/dockerfiles/Dockerfile.migraphx
+++ b/dockerfiles/Dockerfile.migraphx
@@ -10,7 +10,7 @@ FROM rocm/pytorch:rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0
 ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime
 ARG ONNXRUNTIME_BRANCH=main
 
-ENV PATH /code/cmake-3.27.3-linux-x86_64/bin:${PATH}
+ENV PATH=/code/cmake-3.27.3-linux-x86_64/bin:${PATH}
 
 RUN apt-get update &&\
     apt-get install -y migraphx
diff --git a/dockerfiles/Dockerfile.openvino b/dockerfiles/Dockerfile.openvino
index 39e75a68a369f..d1ebdae3cbdd6 100644
--- a/dockerfiles/Dockerfile.openvino
+++ b/dockerfiles/Dockerfile.openvino
@@ -11,7 +11,7 @@ FROM openvino/ubuntu22_runtime:${OPENVINO_VERSION} AS builder
 ENV WORKDIR_PATH=/home/openvino
 WORKDIR $WORKDIR_PATH
-ENV DEBIAN_FRONTEND noninteractive
+ENV DEBIAN_FRONTEND=noninteractive
 
 ARG DEVICE=CPU
 ARG ONNXRUNTIME_REPO=https://github.com/microsoft/onnxruntime.git
@@ -41,7 +41,7 @@ RUN tar cvf GPL_sources.tar.gz /sources
 # Deploy stage
 FROM openvino/ubuntu22_runtime:${OPENVINO_VERSION}
-ENV DEBIAN_FRONTEND noninteractive
+ENV DEBIAN_FRONTEND=noninteractive
 USER root
 COPY --from=builder /home/openvino/onnxruntime/build/Linux/Release/dist/*.whl ./
 COPY --from=builder /GPL_sources.tar.gz ./
@@ -50,7 +50,7 @@ ARG BUILD_UID=1001
 ARG BUILD_USER=onnxruntimedev
 RUN adduser --uid $BUILD_UID $BUILD_USER
 RUN usermod -a -G video,users ${BUILD_USER}
-ENV WORKDIR_PATH /home/${BUILD_USER}
+ENV WORKDIR_PATH=/home/${BUILD_USER}
 WORKDIR ${WORKDIR_PATH}
 
 USER ${BUILD_USER}
diff --git a/dockerfiles/Dockerfile.rocm b/dockerfiles/Dockerfile.rocm
index bef8d7a5f47d2..aca8c3feaff71
100644 --- a/dockerfiles/Dockerfile.rocm +++ b/dockerfiles/Dockerfile.rocm @@ -12,7 +12,7 @@ ARG ONNXRUNTIME_BRANCH=main WORKDIR /code -ENV PATH /code/cmake-3.27.3-linux-x86_64/bin:${PATH} +ENV PATH=/code/cmake-3.27.3-linux-x86_64/bin:${PATH} # Prepare onnxruntime repository & build onnxruntime RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\ diff --git a/dockerfiles/Dockerfile.tensorrt b/dockerfiles/Dockerfile.tensorrt index ef51d41c5ff1b..24947df6308a6 100644 --- a/dockerfiles/Dockerfile.tensorrt +++ b/dockerfiles/Dockerfile.tensorrt @@ -17,7 +17,7 @@ RUN apt-get update &&\ RUN unattended-upgrade WORKDIR /code -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/code/cmake-3.27.3-linux-x86_64/bin:/opt/miniconda/bin:${PATH} +ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/code/cmake-3.27.3-linux-x86_64/bin:/opt/miniconda/bin:${PATH} # Prepare onnxruntime repository & build onnxruntime with TensorRT RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\ diff --git a/dockerfiles/Dockerfile.vitisai b/dockerfiles/Dockerfile.vitisai index e11ab70a61332..c6226155e01e3 100644 --- a/dockerfiles/Dockerfile.vitisai +++ b/dockerfiles/Dockerfile.vitisai @@ -22,8 +22,8 @@ RUN apt-get update && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* -ENV PATH /code/cmake-3.27.3-linux-x86_64/bin:$PATH -ENV LD_LIBRARY_PATH /opt/xilinx/xrt/lib:$LD_LIBRARY_PATH +ENV PATH=/code/cmake-3.27.3-linux-x86_64/bin:$PATH +ENV LD_LIBRARY_PATH=/opt/xilinx/xrt/lib:$LD_LIBRARY_PATH WORKDIR /code RUN . $VAI_ROOT/conda/etc/profile.d/conda.sh &&\ diff --git a/orttraining/tools/amdgpu/Dockerfile.rocm4.3.1.pytorch b/orttraining/tools/amdgpu/Dockerfile.rocm4.3.1.pytorch index 3a408e2265fe7..29b8812c979e4 100644 --- a/orttraining/tools/amdgpu/Dockerfile.rocm4.3.1.pytorch +++ b/orttraining/tools/amdgpu/Dockerfile.rocm4.3.1.pytorch @@ -46,7 +46,7 @@ RUN cd MLNX_OFED_LINUX-${MOFED_VERSION}-${MOFED_OS}-x86_64 && \ rm -r MLNX_OFED_LINUX-${MOFED_VERSION}-${MOFED_OS}-x86_64 ENV PATH=${OLD_PATH} -ENV unset OLD_PATH +ENV unset=OLD_PATH # python env RUN pip3 install --upgrade setuptools diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda index 07885ba65af8a..77dd63298ff3c 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda @@ -42,5 +42,5 @@ ARG BUILD_USER=onnxruntimedev RUN adduser --uid $BUILD_UID $BUILD_USER WORKDIR /home/$BUILD_USER USER $BUILD_USER -ENV PATH /usr/local/dotnet:$PATH -ENV CUDA_MODULE_LOADING "LAZY" \ No newline at end of file +ENV PATH=/usr/local/dotnet:$PATH +ENV CUDA_MODULE_LOADING="LAZY" \ No newline at end of file diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm index e4c3af05053ba..9a265b4249f0b 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm @@ -201,5 +201,5 @@ ARG BUILD_USER=onnxruntimedev RUN adduser --uid $BUILD_UID $BUILD_USER WORKDIR /home/$BUILD_USER USER $BUILD_USER -ENV PATH /usr/local/dotnet:$PATH +ENV PATH=/usr/local/dotnet:$PATH ENV ORTMODULE_ONNX_OPSET_VERSION=$OPSET_VERSION diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 
b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 index a9da75ea87f07..9de88d1664b82 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 @@ -9,7 +9,7 @@ ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 ARG TRT_VERSION=8.6.1.6-1.cuda11.8 FROM $BASEIMAGE AS base ARG TRT_VERSION -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} +ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} RUN dnf install -y bash wget &&\ dnf clean dbcache diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 index 5f21c8cbb5dfa..c2bae5fd7ee59 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 @@ -9,7 +9,7 @@ ARG BASEIMAGE=nvidia/cuda:12.5.1-cudnn-devel-ubi8 ARG TRT_VERSION=10.6.0.26-1.cuda12.6 FROM $BASEIMAGE AS base ARG TRT_VERSION -ENV PATH /opt/python/cp310-cp310/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} +ENV PATH=/opt/python/cp310-cp310/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} RUN dnf install -y bash wget &&\ dnf clean dbcache diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0_torch b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0_torch index eea205797af79..e1203f55106ce 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0_torch +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0_torch @@ -9,7 +9,7 @@ ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 ARG TRT_VERSION=10.6.0.26-1.cuda11.8 FROM $BASEIMAGE AS base ARG TRT_VERSION -ENV PATH /opt/python/cp38-cp38/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} +ENV PATH=/opt/python/cp38-cp38/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} RUN dnf install -y bash wget &&\ dnf clean dbcache diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu index 34c700c22a7c9..81aeada6a4a46 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu @@ -10,7 +10,7 @@ ARG TRT_VERSION=10.6.0.26-1+cuda11.8 ARG LD_LIBRARY_PATH_ARG=/usr/local/lib64:/usr/local/cuda/lib64 FROM $BASEIMAGE AS base ARG TRT_VERSION -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} +ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} ENV DEBIAN_FRONTEND=noninteractive ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH_ARG}:${LD_LIBRARY_PATH} diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu_ffmpeg b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu_ffmpeg index 1665a46d10f43..6ce5a59802641 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu_ffmpeg +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu_ffmpeg @@ -10,7 +10,7 @@ ARG TRT_VERSION=10.6.0.26-1+cuda11.8 ARG LD_LIBRARY_PATH_ARG=/usr/local/lib64:/usr/local/cuda/lib64 FROM $BASEIMAGE AS base ARG 
TRT_VERSION -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} +ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} ENV DEBIAN_FRONTEND=noninteractive ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH_ARG}:${LD_LIBRARY_PATH} diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 index dfc057b129f91..3b4d36a9a8fd8 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 @@ -10,7 +10,7 @@ FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 AS base # The local directory into which to build and install CMAKE ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.30.1-linux-x86_64/bin:/opt/miniconda/bin:${PATH} +ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.30.1-linux-x86_64/bin:/opt/miniconda/bin:${PATH} ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update &&\ @@ -82,7 +82,7 @@ RUN if [ -z "$ONNXRUNTIME_COMMIT_ID" ] ; then echo "Building branch ${ONNXRUNTIM git reset --hard ${ONNXRUNTIME_COMMIT_ID} && git submodule update --recursive ; fi # Build ORT -ENV CUDA_MODULE_LOADING "LAZY" +ENV CUDA_MODULE_LOADING="LAZY" ARG PARSER_CONFIG="" RUN /bin/sh build.sh ${PARSER_CONFIG} --parallel --build_shared_lib --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_tensorrt --tensorrt_home /usr/lib/x86_64-linux-gnu/ --config Release --build_wheel --skip_tests --skip_submodule_sync --cmake_extra_defines '"CMAKE_CUDA_ARCHITECTURES='${CMAKE_CUDA_ARCHITECTURES}'"' diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_tensorrt10 b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_tensorrt10 index 45c854f62cd37..22d5e3b0248a8 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_tensorrt10 +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_tensorrt10 @@ -10,7 +10,7 @@ FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 AS base # The local directory into which to build and install CMAKE ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.30.1-linux-x86_64/bin:/opt/miniconda/bin:${PATH} +ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.30.1-linux-x86_64/bin:/opt/miniconda/bin:${PATH} ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update &&\ @@ -98,7 +98,7 @@ RUN if [ -z "$ONNXRUNTIME_COMMIT_ID" ] ; then echo "Building branch ${ONNXRUNTIM git reset --hard ${ONNXRUNTIME_COMMIT_ID} && git submodule update --recursive ; fi # Build ORT -ENV CUDA_MODULE_LOADING "LAZY" +ENV CUDA_MODULE_LOADING="LAZY" ARG PARSER_CONFIG="" RUN /bin/sh build.sh ${PARSER_CONFIG} --parallel --build_shared_lib --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_tensorrt --tensorrt_home /usr/lib/x86_64-linux-gnu/ --config Release --build_wheel --skip_tests --skip_submodule_sync --cmake_extra_defines '"CMAKE_CUDA_ARCHITECTURES='${CMAKE_CUDA_ARCHITECTURES}'"' diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_3_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_3_tensorrt8_6 index f63112039fe8e..6d35df72894d8 100644 --- 
a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_3_tensorrt8_6 +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_3_tensorrt8_6 @@ -10,7 +10,7 @@ FROM nvidia/cuda:12.3.1-devel-ubuntu20.04 AS base # The local directory into which to build and install CMAKE ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.30.1-linux-x86_64/bin:/opt/miniconda/bin:${PATH} +ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.30.1-linux-x86_64/bin:/opt/miniconda/bin:${PATH} ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update &&\ @@ -85,7 +85,7 @@ RUN if [ -z "$ONNXRUNTIME_COMMIT_ID" ] ; then echo "Building branch ${ONNXRUNTIM git reset --hard ${ONNXRUNTIME_COMMIT_ID} && git submodule update --recursive ; fi # Build ORT -ENV CUDA_MODULE_LOADING "LAZY" +ENV CUDA_MODULE_LOADING="LAZY" ARG PARSER_CONFIG="" RUN /bin/sh build.sh ${PARSER_CONFIG} --parallel --build_shared_lib --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_tensorrt --tensorrt_home /usr/lib/x86_64-linux-gnu/ --config Release --build_wheel --skip_tests --skip_submodule_sync --cmake_extra_defines '"CMAKE_CUDA_ARCHITECTURES='${CMAKE_CUDA_ARCHITECTURES}'"' diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_tensorrt10 b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_tensorrt10 index 53b1072ded8f4..819d9bab7be75 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_tensorrt10 +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_tensorrt10 @@ -10,7 +10,7 @@ FROM nvidia/cuda:12.5.1-cudnn-devel-ubuntu20.04 AS base # The local directory into which to build and install CMAKE ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.30.1-linux-x86_64/bin:/opt/miniconda/bin:${PATH} +ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.30.1-linux-x86_64/bin:/opt/miniconda/bin:${PATH} ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update &&\ @@ -98,7 +98,7 @@ RUN if [ -z "$ONNXRUNTIME_COMMIT_ID" ] ; then echo "Building branch ${ONNXRUNTIM git reset --hard ${ONNXRUNTIME_COMMIT_ID} && git submodule update --recursive ; fi # Build ORT -ENV CUDA_MODULE_LOADING "LAZY" +ENV CUDA_MODULE_LOADING="LAZY" ARG PARSER_CONFIG="" RUN /bin/sh build.sh ${PARSER_CONFIG} --parallel --build_shared_lib --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_tensorrt --tensorrt_home /usr/lib/x86_64-linux-gnu/ --config Release --build_wheel --skip_tests --skip_submodule_sync --cmake_extra_defines '"CMAKE_CUDA_ARCHITECTURES='${CMAKE_CUDA_ARCHITECTURES}'"' diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_openvino b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_openvino index 5f525c1310412..4c80e7a907630 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_openvino +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_openvino @@ -12,10 +12,10 @@ RUN /tmp/scripts/install_python_deps.sh -p $PYTHON_VERSION -d EdgeDevice RUN apt update && apt install -y libnuma1 ocl-icd-libopencl1 && \ rm -rf /var/lib/apt/lists/* /tmp/scripts -ENV INTEL_OPENVINO_DIR /opt/intel/openvino_${OPENVINO_VERSION} -ENV LD_LIBRARY_PATH 
$INTEL_OPENVINO_DIR/runtime/lib/intel64:$INTEL_OPENVINO_DIR/runtime/3rdparty/tbb/lib:/usr/local/openblas/lib:$LD_LIBRARY_PATH -ENV OpenVINO_DIR $INTEL_OPENVINO_DIR/runtime/cmake -ENV IE_PLUGINS_PATH $INTEL_OPENVINO_DIR/runtime/lib/intel64 +ENV INTEL_OPENVINO_DIR=/opt/intel/openvino_${OPENVINO_VERSION} +ENV LD_LIBRARY_PATH=$INTEL_OPENVINO_DIR/runtime/lib/intel64:$INTEL_OPENVINO_DIR/runtime/3rdparty/tbb/lib:/usr/local/openblas/lib:$LD_LIBRARY_PATH +ENV OpenVINO_DIR=$INTEL_OPENVINO_DIR/runtime/cmake +ENV IE_PLUGINS_PATH=$INTEL_OPENVINO_DIR/runtime/lib/intel64 ENV DEBIAN_FRONTEND=noninteractive RUN cd /opt && mkdir -p intel && cd intel && \ diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin index 797495abef57b..4f58dc89333ba 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin @@ -10,7 +10,7 @@ FROM nvidia/cuda:12.5.1-cudnn-devel-ubuntu20.04 AS base # The local directory into which to build and install CMAKE ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.30.1-linux-x86_64/bin:/opt/miniconda/bin:${PATH} +ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.30.1-linux-x86_64/bin:/opt/miniconda/bin:${PATH} ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update &&\ @@ -92,7 +92,7 @@ RUN if [ -z "$ONNXRUNTIME_COMMIT_ID" ] ; then echo "Building branch ${ONNXRUNTIM git reset --hard ${ONNXRUNTIME_COMMIT_ID} && git submodule update --recursive ; fi # Build ORT -ENV CUDA_MODULE_LOADING "LAZY" +ENV CUDA_MODULE_LOADING="LAZY" ARG PARSER_CONFIG="" RUN /bin/sh build.sh ${PARSER_CONFIG} --parallel --build_shared_lib --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_tensorrt --tensorrt_home /usr/lib/x86_64-linux-gnu/ --config Release --build_wheel --skip_tests --skip_submodule_sync --cmake_extra_defines '"CMAKE_CUDA_ARCHITECTURES='${CMAKE_CUDA_ARCHITECTURES}'"' diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda11/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda11/Dockerfile index 6702474d75801..9be2ff7560bae 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda11/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda11/Dockerfile @@ -31,11 +31,11 @@ else \ echo "TRT_VERSION is none skipping Tensor RT Installation" ; \ fi -ENV PATH /usr/lib/jvm/msopenjdk-11/bin:$PATH +ENV PATH=/usr/lib/jvm/msopenjdk-11/bin:$PATH ENV LANG=en_US.UTF-8 ENV LC_ALL=en_US.UTF-8 ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 -ENV CUDAHOSTCXX /opt/rh/gcc-toolset-11/root/usr/bin/g++ +ENV CUDAHOSTCXX=/opt/rh/gcc-toolset-11/root/usr/bin/g++ ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile index 4059de23b2480..c039c641bef27 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile @@ -35,11 +35,11 @@ fi ENV LANG=en_US.UTF-8 ENV LC_ALL=en_US.UTF-8 -ENV CUDAHOSTCXX /opt/rh/gcc-toolset-12/root/usr/bin/g++ +ENV CUDAHOSTCXX=/opt/rh/gcc-toolset-12/root/usr/bin/g++ ADD scripts 
/tmp/scripts
 RUN sed -i 's/enabled\s*=\s*1/enabled = 1\nexclude=dotnet* aspnet* netstandard*/g' /etc/yum.repos.d/ubi.repo && \
     rpm -Uvh https://packages.microsoft.com/config/centos/8/packages-microsoft-prod.rpm && dnf install -y msopenjdk-11 && cd /tmp/scripts && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts
-ENV PATH /usr/lib/jvm/msopenjdk-11/bin:$PATH
+ENV PATH=/usr/lib/jvm/msopenjdk-11/bin:$PATH
 ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11
 ARG BUILD_UID=1001
 ARG BUILD_USER=onnxruntimedev
diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/Dockerfile
index f6f3ad7384592..a69b98f86ba1b 100644
--- a/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/Dockerfile
+++ b/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/Dockerfile
@@ -32,8 +32,8 @@ else \
     echo "TRT_VERSION is x${TRT_VERSION} skipping Tensor RT Installation" ; \
 fi
 
-ENV PATH /usr/local/cuda/bin:$PATH
-ENV CUDA_MODULE_LOADING "LAZY"
+ENV PATH=/usr/local/cuda/bin:$PATH
+ENV CUDA_MODULE_LOADING="LAZY"
 ADD scripts /tmp/scripts
 RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && rm -rf /tmp/scripts

From b1e0930eab19436590505790e331744804b12caa Mon Sep 17 00:00:00 2001
From: Adrian Lizarraga
Date: Mon, 11 Nov 2024 15:20:07 -0800
Subject: [PATCH 4/4] Fix build for linux python wheel (#22801)

### Description
Fixes the command for building Linux Python packages by preventing an empty
`-p` command-line option from being passed to a subsequent build script:
https://github.com/microsoft/onnxruntime/blob/1f3b675453e8412e5c084bfb95997967d0c2eec2/tools/ci_build/github/linux/run_python_dockerbuild.sh#L37

### Motivation and Context
A recent [PR](https://github.com/microsoft/onnxruntime/pull/22773) introduced a
new optional command-line option (`-p`) to pass custom python exe paths. We need
to check if the option is empty before forwarding it to a separate build script.
---
 tools/ci_build/github/linux/run_python_dockerbuild.sh | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/tools/ci_build/github/linux/run_python_dockerbuild.sh b/tools/ci_build/github/linux/run_python_dockerbuild.sh
index 8285776b58e4a..2fec98e569919 100755
--- a/tools/ci_build/github/linux/run_python_dockerbuild.sh
+++ b/tools/ci_build/github/linux/run_python_dockerbuild.sh
@@ -18,6 +18,10 @@ done
 mkdir -p "${HOME}/.onnx"
 DOCKER_SCRIPT_OPTIONS="-d ${DEVICE} -c ${BUILD_CONFIG}"
 
+if [ "${PYTHON_EXES}" != "" ] ; then
+    DOCKER_SCRIPT_OPTIONS+=" -p ${PYTHON_EXES}"
+fi
+
 if [ "${BUILD_EXTR_PAR}" != "" ] ; then
     DOCKER_SCRIPT_OPTIONS+=" -x ${BUILD_EXTR_PAR}"
 fi
@@ -34,7 +38,7 @@ docker run --rm \
     -e ORT_DISABLE_PYTHON_PACKAGE_LOCAL_VERSION \
     -e DEFAULT_TRAINING_PACKAGE_DEVICE \
     $ADDITIONAL_DOCKER_PARAMETER \
-    $DOCKER_IMAGE tools/ci_build/github/linux/build_linux_python_package.sh -c $BUILD_CONFIG -p $PYTHON_EXES $DOCKER_SCRIPT_OPTIONS
+    $DOCKER_IMAGE tools/ci_build/github/linux/build_linux_python_package.sh $DOCKER_SCRIPT_OPTIONS
 
 sudo rm -rf "${BUILD_BINARIESDIRECTORY}/${BUILD_CONFIG}/onnxruntime" "${BUILD_BINARIESDIRECTORY}/${BUILD_CONFIG}/pybind11" \
     "${BUILD_BINARIESDIRECTORY}/${BUILD_CONFIG}/models" "${BUILD_BINARIESDIRECTORY}/${BUILD_CONFIG}/_deps" \
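
The shell fix in PATCH 4/4 follows a general pattern: only append a flag to the
forwarded option string when its value is non-empty, otherwise the downstream
script receives a bare `-p` and misparses whatever follows it. Below is a
minimal Python sketch of that pattern; the function and variable names are
illustrative only and are not part of the ORT build scripts.

```python
# Build the options forwarded to a downstream build script, mirroring the
# run_python_dockerbuild.sh fix: a flag is appended only when its value is
# non-empty, so the downstream option parser never sees a dangling "-p".
def build_docker_script_options(device, build_config,
                                python_exes="", build_extr_par=""):
    options = ["-d", device, "-c", build_config]
    if python_exes:  # forward -p only when a value was actually supplied
        options += ["-p", python_exes]
    if build_extr_par:  # same guard the script already applied to -x
        options += ["-x", build_extr_par]
    return options


# With no custom python paths, "-p" is omitted rather than passed empty.
assert build_docker_script_options("GPU", "Release") == [
    "-d", "GPU", "-c", "Release",
]
assert build_docker_script_options("GPU", "Release", "/usr/bin/python3") == [
    "-d", "GPU", "-c", "Release", "-p", "/usr/bin/python3",
]
```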