diff --git a/include/onnxruntime/core/framework/op_kernel.h b/include/onnxruntime/core/framework/op_kernel.h index 07625c38d8474..a17da2a19bb99 100644 --- a/include/onnxruntime/core/framework/op_kernel.h +++ b/include/onnxruntime/core/framework/op_kernel.h @@ -79,6 +79,7 @@ class OpKernel { // the allocator tied to the session if the kernel owns the pre-packed buffer or an // allocator shared between sessions if the pre-packed buffer is to be shared across sessions // (i.e.) the kernel does not own the buffer. + // @param save_prepacked_initializers: Set it to true if intend to save prepacked initializers to external data file. // @param is_packed: Set it to true if the kernel packed the tensor or to false // The kernel is responsible for keeping the packed data and related metadata if is_packed is true, // and the original initialized constant tensor will be released and not accessible anymore in @@ -88,6 +89,7 @@ class OpKernel { virtual Status PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/, + bool, /*save_prepacked_initializers*/ /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weights*/) { is_packed = false; return Status::OK(); @@ -129,6 +131,26 @@ class OpKernel { return Status::OK(); } + // Override this function to get pre-packed tensors from this kernel. + // Only useful for models run on PC with CPU so ORT could load prepacked weights directly from + // ONNX data file with mmap and no need to do prepacking on fly to save a lot of heap memory. + // @param input_idx : The index of input we prepacked before and intend to get packed tensor back. + // Please refer to matmul_nbits kernel for a complete example. + virtual std::optional GetPrePackTensor(int /*input_idx*/) { + return std::nullopt; + } + + // Override this function to set pre-packed tensors to this kernel and restore prepacked weight buffer. + // Only useful for models run on PC with CPU so ORT could load prepacked weights directly from + // ONNX data file with mmap and no need to do prepacking on fly to save a lot of heap memory. + // Please refer to matmul_nbits kernel for a complete example. + // @param input_idx : The input index of the tensor in this kernel. + // @param pre_packed_tensor: The prepacked tensor read from onnx data file and use the prepacked tensor + // to restore prepacked weight buffer. + virtual Status SetPrePackTensor(int /*input_idx*/, const Tensor& /*pre_packed_tensor*/) { + return Status::OK(); + } + const OrtDevice GetDevice(OrtMemType mem_type) const; const OpKernelInfo& Info() const { return *op_kernel_info_; diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index eb9581e8018d1..69af3c93d7a07 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -1148,6 +1148,11 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi void FinalizeFuseSubGraph(const IndexedSubGraph& sub_graph, Node& fused_node); #endif + // Since one constant initializer could be used by different kernels + // and prepacked differently, use an unordered_map to store prepacked + // initializer in format of <[initializer_name], <[node_name], [prepacked_initializer]>> + typedef std::unordered_map> PrePackedTensorProtoToSave; + #if !defined(ORT_MINIMAL_BUILD) /** Gets the GraphProto representation of this Graph. */ const ONNX_NAMESPACE::GraphProto& ToGraphProto(); @@ -1182,18 +1187,26 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi @param initializer_size_threshold initializers larger or equal to this threshold (in bytes) are saved in the external file. Initializer smaller than this threshold are included in the onnx file. @param align_info offset alignment info. + @param save_prepacked_constant_initializers whether to save prepacked initializer into external data file. + If set false to this boolean, prepacked initializer will not be saved into onnxruntime data file, + we keep constant initializer as it is. + @param pre_packed_initializers struct used to store all the prepacked initializers. @returns GraphProto serialization of the graph. */ ONNX_NAMESPACE::GraphProto ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_path, const std::filesystem::path& model_file_path, size_t initializer_size_threshold, - const OffsetAlignmentInfo& align_info) const; + const OffsetAlignmentInfo& align_info, + bool save_prepacked_constant_initializers, + PrePackedTensorProtoToSave& pre_packed_initializers) const; ONNX_NAMESPACE::GraphProto ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_path, const std::filesystem::path& model_file_path, size_t initializer_size_threshold) const { OffsetAlignmentInfo default_options; - return ToGraphProtoWithExternalInitializers(external_file_path, model_file_path, initializer_size_threshold, default_options); + PrePackedTensorProtoToSave pre_packed_initializers; + return ToGraphProtoWithExternalInitializers(external_file_path, model_file_path, initializer_size_threshold, default_options, + false, pre_packed_initializers); } /** Gets the ISchemaRegistry instances being used with this Graph. */ @@ -1508,6 +1521,18 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi private: void InitializeStateFromModelFileGraphProto(); + // Private method used to setup external initializer properly during model save, + // this external initializer could be oroginal initializer or prepacked initializer. + static void SetUpExternalInitializer(const Graph::OffsetAlignmentInfo& align_info, + size_t tensor_bytes_size, + int64_t& external_offset, + std::ofstream& external_stream, + gsl::span raw_data, + ONNX_NAMESPACE::TensorProto& output_proto, + const std::filesystem::path& external_file_path, + const ONNX_NAMESPACE::TensorProto& initializer, + bool is_prepacked); + // Add node with specified . Node& AddNode(const ONNX_NAMESPACE::NodeProto& node_proto, const ArgNameToTypeMap& name_to_type); diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 6a01602e634f8..086919913cbea 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -246,6 +246,12 @@ static const char* const kOrtSessionOptionsDisableCPUEPFallback = "session.disab static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersFileName = "session.optimized_model_external_initializers_file_name"; +// Use this config when save prepacked constant initializers to onnx external data file. +// Default is not save prepacked initializers to onnx data file. +// Sample usage: sess_options.add_session_config_entry('session.save_prepacked_constant_initializers', "1") +static const char* const kOrtSessionOptionsSavePrePackedConstantInitializers = + "session.save_prepacked_constant_initializers"; + // Use this config to control the minimum size of the initializer when externalizing it during serialization static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes = "session.optimized_model_external_initializers_min_size_in_bytes"; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index ad14fb8258656..b15e865aa423c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -30,6 +30,7 @@ class Attention : public OpKernel, public AttentionCPUBase { Status Compute(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; @@ -101,6 +102,7 @@ bool Attention::IsPackWeightsSuccessful(int qkv_index, template Status Attention::PrePack(const Tensor& weights, int input_idx, AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { /* The PrePack() massages the weights to speed up Compute(), there is an option to diff --git a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc index 2c897f183164f..71a66ea368943 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc @@ -24,6 +24,7 @@ class QAttention : public OpKernel, public AttentionCPUBase { Status Compute(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, bool& /*out*/ is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; @@ -58,6 +59,7 @@ QAttention::QAttention(const OpKernelInfo& info) : OpKernel(info), AttentionC template Status QAttention::PrePack(const Tensor& weights, int input_idx, AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { if (1 != input_idx) { diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc index aa47f365c0005..4148aae4b9a35 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc @@ -13,7 +13,7 @@ class DynamicQuantizeLSTM : public OpKernel, public LSTMBase { DynamicQuantizeLSTM(const OpKernelInfo& info) : OpKernel(info), LSTMBase(info) {} Status PrePack(const Tensor& tensor, int input_idx, - AllocatorPtr alloc, /*out*/ bool& is_packed, + AllocatorPtr alloc, bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, @@ -91,6 +91,7 @@ static void UseSharedPrePackedBuffersImpl(std::vector& prepacke } Status DynamicQuantizeLSTM::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 89e96543c4729..cee3dfc6b3f28 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -98,12 +98,19 @@ class MatMulNBits final : public OpKernel { Status Compute(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; + void ConvertPrepackWeightIntoTensor(const onnxruntime::Tensor& tensor, int input_idx); + Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, /*out*/ bool& used_shared_buffers) override; + std::optional GetPrePackTensor(int /*input_idx*/) override; + + Status SetPrePackTensor(int input_idx, const Tensor& pre_packed_tensor) override; + private: const size_t K_; const size_t N_; @@ -119,6 +126,8 @@ class MatMulNBits final : public OpKernel { size_t packed_b_size_{0}; IAllocatorUniquePtr scales_fp32_{}; IAllocatorUniquePtr bias_fp32_{}; + std::optional packed_tensor_{std::nullopt}; + MLDataType prepack_tensor_data_type_; bool has_zp_input_{false}; @@ -148,8 +157,22 @@ class MatMulNBits final : public OpKernel { } }; +template +void MatMulNBits::ConvertPrepackWeightIntoTensor(const onnxruntime::Tensor& tensor, int input_idx) { + if (input_idx == InputIndex::B) { + prepack_tensor_data_type_ = tensor.DataType(); + } + + TensorShapeVector weights_dims = {static_cast((packed_b_size_ - 1) / prepack_tensor_data_type_->Size()) + 1}; + packed_tensor_ = Tensor(prepack_tensor_data_type_, + TensorShape(weights_dims), + packed_b_.get(), + OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator)); +} + template Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { ORT_UNUSED_PARAMETER(prepacked_weights); @@ -185,11 +208,16 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All #endif // MLAS_TARGET_AMD64_IX86 } + if (save_prepacked_initializers) { + ConvertPrepackWeightIntoTensor(tensor, input_idx); + } + return Status::OK(); } template <> Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { ORT_UNUSED_PARAMETER(prepacked_weights); @@ -239,6 +267,34 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*ou #endif // MLAS_TARGET_AMD64_IX86 } + if (save_prepacked_initializers) { + ConvertPrepackWeightIntoTensor(tensor, input_idx); + } + + return Status::OK(); +} + +template +std::optional MatMulNBits::GetPrePackTensor(int input_idx) { + // For this kernel, prepack is performed on input_B, and possibly scales, zeros_points. + // During compute process, scales and zeros_points will keep as it is and only use prepacked + // buffer to replace input_B. + // Inorder to cope with this logic, we need to return latest prepacked buffer and only serialize + // the latest one. So, we need to always return packed_tensor_ here not only for input_B. + ORT_UNUSED_PARAMETER(input_idx); + return std::move(packed_tensor_); +} + +template +Status MatMulNBits::SetPrePackTensor(int input_idx, const Tensor& pre_packed_tensor) { + if (input_idx == 1) { + // pre_packed_tensor is constant initialized tensor and its lifecycle is managed by session_state, + // session_state will release memory from pre_packed_tensor. packed_b_ will not release memory so + // pass empty/default buffer deleter here. + // const_cast here is temporary, will fix in follow up PR. + packed_b_ = BufferUniquePtr(const_cast(pre_packed_tensor.DataRaw()), BufferDeleter()); + } + return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc index 67b4950af73bf..c9ee9e2cb760d 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc @@ -278,6 +278,7 @@ Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { template Status SkipLayerNorm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, bool& is_packed, PrePackedWeights* prepacked_weights) { ORT_UNUSED_PARAMETER(prepacked_weights); diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h index 08e2276c3d9d5..d904c14857437 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h @@ -16,7 +16,7 @@ class SkipLayerNorm final : public OpKernel { SkipLayerNorm(const OpKernelInfo& op_kernel_info); Status Compute(OpKernelContext* p_op_kernel_context) const override; - Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool save_prepacked_initializers, bool& is_packed, PrePackedWeights* prepacked_weights) override; private: diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc index dea5391c7629b..d190ed389f3e9 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc @@ -95,6 +95,7 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) { } Status GroupNorm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr /*alloc*/, + bool /*save_prepacked_initializers*/, bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { is_packed = false; diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h index b408b3c1ee79b..4505c066baedb 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h @@ -17,6 +17,7 @@ class GroupNorm final : public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, bool& is_packed, PrePackedWeights* prepacked_weights) override; private: diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc index 3e93a527877c5..aa2c8755f6536 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc @@ -99,6 +99,7 @@ Status QOrderedAttention::PutIntoMergedBias(const Tensor& tensor, AllocatorPtr a } Status QOrderedAttention::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weights*/) { is_packed = false; diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.h b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.h index 9d4e563c1feab..529fd00307d66 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.h +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.h @@ -20,6 +20,7 @@ class QOrderedAttention final : public CudaKernel, public AttentionBase { public: Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.cc b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.cc index a64f628f245e6..351e36b884540 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.cc @@ -51,6 +51,7 @@ QOrderedMatMul::QOrderedMatMul(const OpKernelInfo& info) : CudaKernel(info) { } Status QOrderedMatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /* prepacked_weights */) { is_packed = false; diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.h b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.h index dcb6cc6374be1..d1cef99779e09 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.h +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.h @@ -18,6 +18,7 @@ class QOrderedMatMul final : public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index 8d4db36106f28..18405231750ba 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -83,6 +83,11 @@ struct SessionOptions { // enable profiling for this session. bool enable_profiling = false; + // save pre-packed constant external initializers instead of original initializers to onnxruntime data file. + // Only useful for models run on PC with CPU so ORT could load prepacked weights directly from + // ONNX data file with mmap and no need to do prepacking on fly to save a lot of heap memory. + bool save_prepacked_constant_initializers = false; + // Non empty filepath enables serialization of the transformed optimized model to the specified filepath. // // Set session config value for ORT_SESSION_OPTIONS_CONFIG_SAVE_MODEL_FORMAT to 'ORT' or 'ONNX' to explicitly @@ -191,6 +196,7 @@ inline std::ostream& operator<<(std::ostream& os, const SessionOptions& session_ << " execution_mode:" << session_options.execution_mode << " execution_order:" << session_options.execution_order << " enable_profiling:" << session_options.enable_profiling + << " save_prepacked_constant_initializers:" << session_options.save_prepacked_constant_initializers << " optimized_model_filepath:" << ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(session_options.optimized_model_filepath) << " enable_mem_pattern:" << session_options.enable_mem_pattern << " enable_mem_reuse:" << session_options.enable_mem_reuse diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 0d0b22ff61e01..943db091b341f 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -14,6 +14,7 @@ #include "core/framework/op_kernel.h" #include "core/framework/ort_value_pattern_planner.h" #include "core/framework/session_state_utils.h" +#include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" #include "core/providers/cpu/controlflow/utils.h" #include "core/session/onnxruntime_session_options_config_keys.h" @@ -397,12 +398,18 @@ static std::string GenerateKeyForPrepackedWeightsMap(const std::string& op_type, } Status SessionState::PrepackConstantInitializedTensors(InlinedHashMap& constant_initializers_use_count, - const std::unordered_map& initializers_to_share_map) { - auto prepacked_constant_weights = [this, &constant_initializers_use_count, &initializers_to_share_map]( + const std::unordered_map& initializers_to_share_map, + bool save_prepacked_constant_initializers, + PrePackInitializers& pre_packed_initializers) { + auto prepacked_constant_weights = [this, &constant_initializers_use_count, &initializers_to_share_map, + save_prepacked_constant_initializers, &pre_packed_initializers]( bool should_cache_prepacked_weights_for_shared_initializers) -> Status { + std::unordered_map pre_packed_kernel_input_map; for (auto& node : GetGraphViewer().Nodes()) { auto kernel = GetMutableKernel(node.Index()); + auto kernel_name = kernel->Info().node().Name(); int input_idx = 0; + bool is_kernel_prepacked = false; for (auto& input_def : node.InputDefs()) { if (input_def->Exists()) { const std::string& input_name = input_def->Name(); @@ -414,16 +421,27 @@ Status SessionState::PrepackConstantInitializedTensors(InlinedHashMapGetOrtValueNameIdxMap().GetIdx(input_name, ort_value_idx).IsOK()) { std::unordered_map& constant_initialized_tensors = st->constant_initialized_tensors_; - if (constant_initialized_tensors.count(ort_value_idx)) { + if (constant_initialized_tensors.count(ort_value_idx) && !is_kernel_prepacked) { bool is_packed = false; const Tensor& const_initialized_tensor = constant_initialized_tensors[ort_value_idx].Get(); auto iter = initializers_to_share_map.find(input_name); bool is_shared_initializer = (iter != initializers_to_share_map.end()); + // found pre-packed constant initializers from data file, no need to do pre-packing again + // apply pre-packed tensor to kernel so kernel can use it directly + if (pre_packed_initializers.pre_packed_initializer_names_read_from_file.count(input_name) != 0) { + is_packed = true; + + // kernel like Matmul_nbits will call prepack multiple times with input_B and possibly scales/zero_points. + // If prepacked weights already read from ONNX data file (this happens we ORT reads data file with prepacked + // weights serialized), only need to set prepacked weights once to kernel. + is_kernel_prepacked = true; + ORT_THROW_IF_ERROR(kernel->SetPrePackTensor(input_idx, const_initialized_tensor)); + } // Caching pre-packed weights is limited to shared initializers associated with the CPU EP for now - if (is_shared_initializer && should_cache_prepacked_weights_for_shared_initializers && - node.GetExecutionProviderType() == kCpuExecutionProvider) { // caching of pre-packed weights' turned ON + else if (is_shared_initializer && should_cache_prepacked_weights_for_shared_initializers && + node.GetExecutionProviderType() == kCpuExecutionProvider) { // caching of pre-packed weights' turned ON AllocatorPtr allocator_for_caching = prepacked_weights_container_->GetOrCreateAllocator(CPU); ORT_ENFORCE(allocator_for_caching.get() != nullptr); @@ -435,7 +453,7 @@ Status SessionState::PrepackConstantInitializedTensors(InlinedHashMapPrePack(const_initialized_tensor, input_idx, allocator_for_caching, - is_packed, + save_prepacked_constant_initializers, is_packed, &weights_to_be_filled_in)); if (is_packed) { @@ -482,18 +500,50 @@ Status SessionState::PrepackConstantInitializedTensors(InlinedHashMapInfo().GetDevice(OrtMemType::OrtMemTypeDefault)); ORT_RETURN_IF_ERROR(kernel->PrePack(const_initialized_tensor, input_idx, session_cpu_alloc, // use allocator tied to this session + save_prepacked_constant_initializers, is_packed, nullptr // no caching required )); } if (is_packed) { + // if intended to save prepacked initializers, get prepacked tensors from kernel and save in hashmap, + // will save to data file later + if (save_prepacked_constant_initializers) { + auto tensor = kernel->GetPrePackTensor(input_idx); + + if (tensor != std::nullopt) { + // save prepacked initializers per initializer and kernel since one initializer could + // be used by multiple kernels + pre_packed_initializers.pre_packed_initializers_to_save[input_name][kernel_name] = std::move(tensor.value()); + + pre_packed_kernel_input_map[kernel_name] = input_name; + } + } + ++number_of_prepacks_counter_; - if (constant_initializers_use_count.count(input_name) && --constant_initializers_use_count[input_name] == 0) { + // if constant_initialized_tensor is already pre-packed, don't need to remove it + if (pre_packed_initializers.pre_packed_initializer_names_read_from_file.count(input_name) == 0 && + constant_initializers_use_count.count(input_name) && --constant_initializers_use_count[input_name] == 0) { // release the constant initialized tensor st->initialized_tensors_.erase(ort_value_idx); constant_initialized_tensors.erase(ort_value_idx); } + } else { + // handle prepack for matmul_nbits, it will prepack several times but set is_packed + // to false for scales and zero_points, we keep scales and zero_points as it is only + // update packed_tensor to input_B. + // TODO: this logic works with matmul_nbits kernel but if other kernels also call prepack + // multiple times and use different initializers to store prepacked weights, this piece of logic + // might introduce bug and need a per kernel strategy to update prepacked weights. + if (save_prepacked_constant_initializers && pre_packed_kernel_input_map.count(kernel_name)) { + auto tensor = kernel->GetPrePackTensor(input_idx); + + if (tensor != std::nullopt) { + auto existing_input_name = pre_packed_kernel_input_map[kernel_name]; + pre_packed_initializers.pre_packed_initializers_to_save[existing_input_name][kernel_name] = std::move(tensor.value()); + } + } } } // stop searching in 2 cases: @@ -1176,6 +1226,7 @@ static Status VerifyEachNodeIsAssignedToAnEp(const Graph& graph, const logging:: Status SessionState::FinalizeSessionState(const std::basic_string& graph_location, const KernelRegistryManager& kernel_registry_manager, + PrePackInitializers& pre_packed_initializers, bool remove_initializers, bool saving_ort_format) { // recursively create the subgraph session state instances and populate the kernel create info in them. @@ -1189,7 +1240,7 @@ Status SessionState::FinalizeSessionState(const std::basic_string constant_initializers_use_count; ComputeConstantInitializerUseCount(graph_, constant_initializers_use_count); return FinalizeSessionStateImpl(graph_location, kernel_registry_manager, nullptr, sess_options_, - remove_initializers, constant_initializers_use_count); + remove_initializers, constant_initializers_use_count, pre_packed_initializers); } static Status Index(const OrtValueNameIdxMap& ort_value_name_idx_map, @@ -1323,6 +1374,7 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string& constant_initializers_use_count, + PrePackInitializers& pre_packed_initializers, const InlinedHashMap& outer_scope_node_arg_to_location_map, bool graph_info_already_created) { if (!graph_info_already_created) { @@ -1422,6 +1474,8 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string> + typedef std::unordered_map> PrePackedTensorsToSave; + PrePackedTensorsToSave pre_packed_initializers_to_save; + + // This set is used during model load with prepacked initializer serialized in external data file. + // ORT reads prepacked initializers and store their name into this set so we could skip PrePack + // process later to save heap memory. Prepacked tensor itself is saved in session state's constant_initialized_tensors_. + typedef std::unordered_set PrePackedTensorNamesReadFromFile; + PrePackedTensorNamesReadFromFile pre_packed_initializer_names_read_from_file; + }; + Status FinalizeSessionState(const std::basic_string& graph_loc, const KernelRegistryManager& kernel_registry_manager, + PrePackInitializers& pre_packed_initializers, bool remove_initializers = true, bool saving_ort_format = false); @@ -321,6 +338,15 @@ class SessionState { return parent_; } + Status FinalizeSessionState(const std::basic_string& graph_loc, + const KernelRegistryManager& kernel_registry_manager, + bool remove_initializers = true, + bool saving_ort_format = false) { + PrePackInitializers pre_packed_initializers; + return FinalizeSessionState(graph_loc, kernel_registry_manager, pre_packed_initializers, + remove_initializers, saving_ort_format); + } + // Clear all removable attributes if they exists. // The function logs the list of removable attributes for every node. void PruneRemovableAttributes(); @@ -380,9 +406,13 @@ class SessionState { /** * Prepack the constant initialized tensors for better performance. * The original constant initialized tensors will be removed to save memory. + * For model with prepacked initializer serialized into ONNX data file, + * PrePack will be skipped to save memory. */ Status PrepackConstantInitializedTensors(InlinedHashMap& constant_initializers_use_count, - const std::unordered_map& initializers_to_share_map); + const std::unordered_map& initializers_to_share_map, + bool save_prepacked_constant_initializers, + PrePackInitializers& pre_packed_initializers); SessionState* GetMutableSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name); @@ -400,6 +430,7 @@ class SessionState { const SessionOptions& session_options, bool remove_initializers, InlinedHashMap& constant_initializers_use_count, + PrePackInitializers& pre_packed_initializers, const InlinedHashMap& outer_scope_node_arg_to_location_map = {}, bool graph_info_already_created = false); diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index 2c74805c57dce..3424f40e79c01 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -21,7 +21,6 @@ #include "core/framework/ort_value_pattern_planner.h" #include "core/framework/ort_value_name_idx_map.h" #include "core/framework/sequential_execution_plan.h" -#include "core/framework/session_state.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" #include "core/framework/bfc_arena.h" @@ -72,6 +71,7 @@ static inline common::Status ExtDataTensorProtoToTensor(const Env& env, const std::basic_string& proto_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, Tensor& tensor, OrtCallback& ext_data_deleter, + SessionState::PrePackInitializers::PrePackedTensorNamesReadFromFile& pre_packed_initializers_name_set, Tensor* buffered_tensor = nullptr) { ORT_ENFORCE(utils::HasExternalData(tensor_proto)); @@ -79,7 +79,7 @@ static inline common::Status ExtDataTensorProtoToTensor(const Env& env, SafeInt ext_data_len = 0; ORT_RETURN_IF_ERROR(utils::GetExtDataFromTensorProto(env, proto_path.c_str(), tensor_proto, ext_data_buf, ext_data_len, ext_data_deleter, - buffered_tensor)); + &pre_packed_initializers_name_set, buffered_tensor)); // NB: creating a do-nothing allocator per tensor is wasteful; can perhaps be // avoided if the Tensor class implements the do-nothing behavior when given a @@ -100,6 +100,7 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st const AllocatorPtr& alloc, const AllocatorPtr& default_cpu_alloc, OrtValue& ort_value, const DataTransferManager& data_transfer_mgr, const ExternalDataLoaderManager& external_data_loader_mgr, + SessionState::PrePackInitializers::PrePackedTensorNamesReadFromFile& pre_packed_initializers_name_set, bool use_device_allocator_for_initializers = false, Tensor* buffered_tensor = nullptr) { if (bool(alloc) == (m != nullptr)) { @@ -139,7 +140,7 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st // TensorProtoToTensor it would copy the data, causing unnecessary overhead OrtCallback ext_data_deleter; ORT_RETURN_IF_ERROR(ExtDataTensorProtoToTensor(env, proto_path, tensor_proto, *p_tensor, - ext_data_deleter, buffered_tensor)); + ext_data_deleter, pre_packed_initializers_name_set, buffered_tensor)); ExtDataValueDeleter deleter{ext_data_deleter, p_tensor.get()}; MLDataType ml_tensor_type = DataTypeImpl::GetType(); @@ -163,7 +164,7 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st OrtCallback ext_data_deleter; std::optional scoped_ort_callback_invoker; ORT_RETURN_IF_ERROR(ExtDataTensorProtoToTensor(env, proto_path, tensor_proto, *p_deserialize_tensor, - ext_data_deleter, buffered_tensor)); + ext_data_deleter, pre_packed_initializers_name_set, buffered_tensor)); scoped_ort_callback_invoker = ScopedOrtCallbackInvoker(ext_data_deleter); // TODO!! Need a temp buffer allocator for non-escape buffers that maybe too big for stack allocation. @@ -272,7 +273,8 @@ common::Status SaveInitializedTensors( const ExecutionPlanBase& exec_plan, const SessionOptions& session_options, const MemoryProfileFunction& memory_profile_func, - std::unordered_map>& buffered_tensors) { + std::unordered_map>& buffered_tensors, + SessionState::PrePackInitializers::PrePackedTensorNamesReadFromFile& pre_packed_initializers_name_set) { LOGS(logger, INFO) << "Saving initialized tensors."; ORT_ENFORCE(ort_value_name_idx_map.MaxIdx() > -1, "OrtValue indexes should have been populated."); @@ -401,6 +403,7 @@ common::Status SaveInitializedTensors( Status st = DeserializeTensorProto(env, graph_loc, tensor_proto, (m.has_value()) ? &*m : nullptr, alloc, default_cpu_alloc, ort_value, data_transfer_mgr, external_data_loader_mgr, + pre_packed_initializers_name_set, use_device_allocator_for_initializers, p_tensor); if (!st.IsOK()) { std::ostringstream oss; diff --git a/onnxruntime/core/framework/session_state_utils.h b/onnxruntime/core/framework/session_state_utils.h index af27f5caba0f4..4de501b6f7429 100644 --- a/onnxruntime/core/framework/session_state_utils.h +++ b/onnxruntime/core/framework/session_state_utils.h @@ -12,6 +12,7 @@ #include "core/framework/tensor.h" #include "core/framework/tensor_allocator.h" #include "core/framework/session_options.h" +#include "core/framework/session_state.h" #include "core/framework/sequential_execution_plan.h" #include "core/platform/path_lib.h" @@ -50,7 +51,8 @@ common::Status SaveInitializedTensors( const ExecutionPlanBase& exec_plan, const SessionOptions& session_options, const MemoryProfileFunction& memory_profile_func, - std::unordered_map>& buffered_tensors); + std::unordered_map>& buffered_tensors, + SessionState::PrePackInitializers::PrePackedTensorNamesReadFromFile& pre_packed_initializers_name_set); common::Status AllocateTensor( const onnxruntime::MemBuffer* m, diff --git a/onnxruntime/core/framework/tensor_external_data_info.cc b/onnxruntime/core/framework/tensor_external_data_info.cc index 93146e66d9f24..bcd04effe2bd4 100644 --- a/onnxruntime/core/framework/tensor_external_data_info.cc +++ b/onnxruntime/core/framework/tensor_external_data_info.cc @@ -40,6 +40,8 @@ Status ExternalDataInfo::Create(const RepeatedPtrField& return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "parsing ", stringmap.value(), " failed"); } else if (stringmap.key() == "checksum" && !stringmap.value().empty()) { out->checksum_ = stringmap.value(); + } else if (stringmap.key() == "prepacked" && !stringmap.value().empty()) { + out->prepacked_ = stringmap.value() == "1"; } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "model format error!"); } diff --git a/onnxruntime/core/framework/tensor_external_data_info.h b/onnxruntime/core/framework/tensor_external_data_info.h index afc8fda6c3037..c2490f5cc5bc2 100644 --- a/onnxruntime/core/framework/tensor_external_data_info.h +++ b/onnxruntime/core/framework/tensor_external_data_info.h @@ -23,6 +23,8 @@ class ExternalDataInfo { const std::string& GetChecksum() const { return checksum_; } + bool GetPrePacked() const noexcept { return prepacked_; } + // If the value of 'offset' or 'length' field is larger the max value of ssize_t, this function will treat it as a // wrong value and return FAIL. static common::Status Create( @@ -36,5 +38,6 @@ class ExternalDataInfo { // 0 means the whole file size_t length_ = 0; std::string checksum_; + bool prepacked_ = false; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 2af9f95ad059e..0c69ee11f62bc 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -230,11 +230,12 @@ Status TensorProtoToOrtValueImpl(const Env& env, const std::filesystem::path& mo namespace utils { -Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, - const std::filesystem::path& tensor_proto_dir, - std::basic_string& external_file_path, - onnxruntime::FileOffsetType& file_offset, - SafeInt& tensor_byte_size) { +static Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, + const std::filesystem::path& tensor_proto_dir, + std::basic_string& external_file_path, + onnxruntime::FileOffsetType& file_offset, + SafeInt& tensor_byte_size, + bool& pre_packed) { ORT_RETURN_IF_NOT(onnxruntime::utils::HasExternalData(tensor_proto), "Tensor does not have external data to read from."); @@ -244,6 +245,8 @@ Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, std::unique_ptr external_data_info; ORT_RETURN_IF_ERROR(onnxruntime::ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info)); + pre_packed = external_data_info->GetPrePacked(); + const auto& location = external_data_info->GetRelPath(); external_file_path = location == onnxruntime::utils::kTensorProtoMemoryAddressTag ? std::filesystem::path(location) @@ -265,6 +268,11 @@ void SetRawDataInTensorProto(ONNX_NAMESPACE::TensorProto& tensor_proto, std::str tensor_proto.set_raw_data(std::move(param)); } +Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, const std::filesystem::path& tensor_proto_dir, std::basic_string& external_file_path, onnxruntime::FileOffsetType& file_offset, SafeInt& tensor_byte_size) { + bool pre_packed = false; + return GetExternalDataInfo(tensor_proto, tensor_proto_dir, external_file_path, file_offset, tensor_byte_size, pre_packed); +} + void ConvertRawDataInTensorProto(TensorProto* tensor) { size_t element_size = 1; char* bytes = NULL; @@ -988,7 +996,7 @@ static Status GetFileContent(const Env& env, const std::filesystem::path& file_p Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& model_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, void*& ext_data_buf, SafeInt& ext_data_len, OrtCallback& ext_data_deleter, - Tensor* buffered_tensor) { + SessionState::PrePackInitializers::PrePackedTensorNamesReadFromFile* pre_packed_initializers_name_set, Tensor* buffered_tensor) { ORT_ENFORCE(utils::HasExternalData(tensor_proto)); std::basic_string tensor_proto_dir; if (!model_path.empty()) { @@ -997,8 +1005,13 @@ Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& mo std::basic_string external_data_file_path; FileOffsetType file_offset; SafeInt raw_data_safe_len = 0; + bool pre_packed = false; ORT_RETURN_IF_ERROR( - GetExternalDataInfo(tensor_proto, tensor_proto_dir, external_data_file_path, file_offset, raw_data_safe_len)); + GetExternalDataInfo(tensor_proto, tensor_proto_dir, external_data_file_path, file_offset, raw_data_safe_len, pre_packed)); + + if (pre_packed && pre_packed_initializers_name_set != nullptr) { + (*pre_packed_initializers_name_set).insert(tensor_proto.name()); + } if (external_data_file_path == onnxruntime::utils::kTensorProtoMemoryAddressTag) { // the value in location is the memory address of the data @@ -1108,7 +1121,7 @@ Status TensorProtoToTensor(const Env& env, const std::filesystem::path& model_pa OrtCallback& d = deleter_for_file_data.d; if (utils::HasExternalData(tensor_proto)) { - ORT_RETURN_IF_ERROR(GetExtDataFromTensorProto(env, model_path, tensor_proto, raw_data, raw_data_len, d)); + ORT_RETURN_IF_ERROR(GetExtDataFromTensorProto(env, model_path, tensor_proto, raw_data, raw_data_len, d, nullptr)); } else if (utils::HasRawData(tensor_proto)) { raw_data = const_cast(tensor_proto.raw_data().data()); // TODO The line above has const-correctness issues. Below is a possible fix which copies the tensor_proto data diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index 262f7adaca1cb..770132f8e95fc 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -17,26 +17,19 @@ #include "core/framework/external_data_loader.h" #include "core/framework/ort_value.h" #include "core/framework/mem_buffer.h" +#include "core/framework/session_state.h" #include "core/framework/tensor_external_data_info.h" #include "core/graph/onnx_protobuf.h" #include "core/platform/env.h" namespace onnxruntime { namespace utils { -/** - * This function is used to get the external data info from the given tensor proto. - * @param tensor_proto given initializer tensor - * @param tensor_proto_dir directory of the tensor proto file - * @param external_file_path output external file path - * @param file_offset output tensor offset - * @param tensor_byte_size output tensor byte size - * @returns Status::OK() if the function is executed successfully - */ Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, const std::filesystem::path& tensor_proto_dir, std::basic_string& external_file_path, onnxruntime::FileOffsetType& file_offset, SafeInt& tensor_byte_size); + /** * This function is used to convert the endianess of Tensor data. * Mostly, will be used in big endian system to support the model file @@ -172,6 +165,7 @@ common::Status GetExtDataFromTensorProto(const Env& env, const std::filesystem:: const ONNX_NAMESPACE::TensorProto& tensor_proto, void*& ext_data_buf, SafeInt& ext_data_len, OrtCallback& ext_data_deleter, + SessionState::PrePackInitializers::PrePackedTensorNamesReadFromFile* pre_packed_initializers_name_set, Tensor* buffered_tensor = nullptr); // Given a tensor proto with external data obtain a tensor using the specified custom external data loader. diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 9eed0249711f9..5402345447706 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -1064,5 +1064,11 @@ bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index return false; } +std::string GetPrepackedInitializerName(const std::string& initializer_name, const std::string& node_name) { + const std::string seperator = ":"; + + return initializer_name + seperator + node_name; +} + } // namespace utils } // namespace onnxruntime diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h index afdb5a2cb27f5..db38ef1675595 100644 --- a/onnxruntime/core/framework/utils.h +++ b/onnxruntime/core/framework/utils.h @@ -234,6 +234,8 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { int32_t ONNXTensorElementDataTypeToProtoTensorType(ONNXTensorElementDataType); +std::string GetPrepackedInitializerName(const std::string& initializer_name, const std::string& node_name); + #ifdef ENABLE_TRAINING common::Status VerifyInputTensorsAllocatedContiguously(OpKernelContext* context); #endif diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index e8a5855b36496..3f50841f50913 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -4084,10 +4084,75 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProto() const { return result; } +void Graph::SetUpExternalInitializer(const Graph::OffsetAlignmentInfo& align_info, + size_t tensor_bytes_size, + int64_t& external_offset, + std::ofstream& external_stream, + gsl::span raw_data, + ONNX_NAMESPACE::TensorProto& output_proto, + const std::filesystem::path& external_file_path, + const ONNX_NAMESPACE::TensorProto& initializer, + bool is_prepacked) { + // update external_offset for alignment + // need to do padding before write actual tensor data as we do offset alignment at the begin of + // large tensors (offset need to be page aligned and alloction granularity aligned) like below: + // \242\2557\256\023.\031&0000000000000000\332)k+\253\246\342\246(&\006!\347\232\374\236\325\026\032+\36XXXX + // |<---small tensor---->|<---padding--->|<------------------large tensor----------------------------->| + if (align_info.align_offset && static_cast(tensor_bytes_size) > align_info.align_threshold) { + // Align to the larger of the page size or the allocation granularity + int64_t alignment_factor = std::max(static_cast(4096), align_info.allocation_granularity); + // Align to the next page or alloc granularity boundary + int64_t new_external_offset = static_cast( + std::floor((external_offset + alignment_factor - 1) / alignment_factor)) * + alignment_factor; + + // padding tensor with zeros for alignment + InlinedVector paddings; + size_t padding_size = SafeInt(new_external_offset - external_offset); + paddings.reserve(padding_size); + for (size_t index = 0; index != padding_size; ++index) { + paddings.push_back(0x0); + } + external_stream.write(reinterpret_cast(paddings.data()), padding_size); + + external_offset = new_external_offset; + } + + external_stream.write(reinterpret_cast(raw_data.data()), tensor_bytes_size); + + output_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL); + ONNX_NAMESPACE::StringStringEntryProto* location = output_proto.add_external_data(); + location->set_key("location"); + location->set_value(ToUTF8String(external_file_path.native())); + ONNX_NAMESPACE::StringStringEntryProto* offset = output_proto.add_external_data(); + offset->set_key("offset"); + offset->set_value(std::to_string(external_offset)); + ONNX_NAMESPACE::StringStringEntryProto* length = output_proto.add_external_data(); + length->set_key("length"); + length->set_value(std::to_string(tensor_bytes_size)); + + if (is_prepacked) { + ONNX_NAMESPACE::StringStringEntryProto* pre_packed = output_proto.add_external_data(); + pre_packed->set_key("prepacked"); + pre_packed->set_value("1"); + } + + output_proto.set_name(initializer.name()); + output_proto.set_data_type(initializer.data_type()); + for (int i = 0; i != initializer.dims_size(); ++i) { + output_proto.add_dims(initializer.dims(i)); + } + output_proto.set_doc_string(initializer.doc_string()); + + external_offset += tensor_bytes_size; +} + ONNX_NAMESPACE::GraphProto Graph::ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_path, const std::filesystem::path& model_file_path, size_t initializer_size_threshold, - const OffsetAlignmentInfo& align_info) const { + const OffsetAlignmentInfo& align_info, + bool save_prepacked_constant_initializers, + PrePackedTensorProtoToSave& pre_packed_initializers) const { GraphProto result; ToGraphProtoInternal(result); ORT_ENFORCE(external_file_path.is_relative()); @@ -4106,6 +4171,34 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProtoWithExternalInitializers(const std #endif for (const auto& initializer : graph_proto_->initializer()) { + bool use_pre_packed_initializer = false; + InlinedVector pre_packed_initializers_tensor_proto; + // If this initializer has been prepacked, saved prepacked external initializer instead of original one. + // Since one initializer could be used by multiple kernels and been prepacked differently, + // Save each prepacked initializers seperately, chagne the initializer name to [initializer_name]:[node_name] + // to avoid conflict. Change the node input name accordingly. + // IT could potentially make the ONNX data file larger since we store multiple prepacked initializers into disk + // but this could be rare case. + if (save_prepacked_constant_initializers && pre_packed_initializers.count(initializer.name())) { + for (const auto& item : pre_packed_initializers[initializer.name()]) { + auto& node_name = item.first; + std::string prepacked_initializer_name = utils::GetPrepackedInitializerName(initializer.name(), node_name); + pre_packed_initializers_tensor_proto.push_back(item.second); + use_pre_packed_initializer = true; + + for (auto& node : *result.mutable_node()) { + if (node.name() == node_name) { + int input_index = 0; + for (const auto& input : node.input()) { + if (input == initializer.name()) { + node.set_input(input_index, prepacked_initializer_name); + } + input_index += 1; + } + } + } + } + } #if !defined(DISABLE_SPARSE_TENSORS) if (sparse_end != sparse_tensor_names_.find(initializer.name())) { // Sparse tensors are added to the ONNX file. @@ -4114,61 +4207,39 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProtoWithExternalInitializers(const std ORT_ENFORCE(status.IsOK(), "Failed to convert dense initializer to sparse"); } else { #endif - // Dense tensors larger than the threshold are added to the external file. - TensorProto* output_proto = result.add_initializer(); - - std::vector raw_data; - ORT_THROW_IF_ERROR(utils::UnpackInitializerData(initializer, model_path, raw_data)); - size_t tensor_bytes_size = raw_data.size(); - if (tensor_bytes_size < initializer_size_threshold) { - *output_proto = initializer; - continue; - } + if (use_pre_packed_initializer) { + for (const auto& pre_packed_initializer : pre_packed_initializers_tensor_proto) { + // Dense tensors larger than the threshold are added to the external file. + TensorProto* output_proto = result.add_initializer(); + std::vector raw_data; + size_t tensor_bytes_size = 0; + + ORT_THROW_IF_ERROR(utils::UnpackInitializerData(pre_packed_initializer, model_path, raw_data)); + tensor_bytes_size = raw_data.size(); + if (tensor_bytes_size < initializer_size_threshold) { + *output_proto = pre_packed_initializer; + continue; + } - // update external_offset for alignment - // need to do padding before write actual tensor data as we do offset alignment at the begin of - // large tensors (offset need to be page aligned and alloction granularity aligned) like below: - // \242\2557\256\023.\031&0000000000000000\332)k+\253\246\342\246(&\006!\347\232\374\236\325\026\032+\36XXXX - // |<---small tensor---->|<---padding--->|<------------------large tensor----------------------------->| - if (align_info.align_offset && static_cast(tensor_bytes_size) > align_info.align_threshold) { - // Align to the larger of the page size or the allocation granularity - int64_t alignment_factor = std::max(static_cast(4096), align_info.allocation_granularity); - // Align to the next page or alloc granularity boundary - int64_t new_external_offset = static_cast( - std::floor((external_offset + alignment_factor - 1) / alignment_factor)) * - alignment_factor; - - // padding tensor with zeros for alignment - for (int64_t index = external_offset; index != new_external_offset; ++index) { - external_stream << '0'; + SetUpExternalInitializer(align_info, tensor_bytes_size, external_offset, external_stream, + raw_data, *output_proto, external_file_path, pre_packed_initializer, true); + } + } else { + // Dense tensors larger than the threshold are added to the external file. + TensorProto* output_proto = result.add_initializer(); + std::vector raw_data; + size_t tensor_bytes_size = 0; + + ORT_THROW_IF_ERROR(utils::UnpackInitializerData(initializer, model_path, raw_data)); + tensor_bytes_size = raw_data.size(); + if (tensor_bytes_size < initializer_size_threshold) { + *output_proto = initializer; + continue; } - external_offset = new_external_offset; - } - - for (size_t index = 0; index != tensor_bytes_size; ++index) { - external_stream << raw_data[index]; - } - - output_proto->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL); - ONNX_NAMESPACE::StringStringEntryProto* location = output_proto->add_external_data(); - location->set_key("location"); - location->set_value(ToUTF8String(external_file_path.native())); - ONNX_NAMESPACE::StringStringEntryProto* offset = output_proto->add_external_data(); - offset->set_key("offset"); - offset->set_value(std::to_string(external_offset)); - ONNX_NAMESPACE::StringStringEntryProto* length = output_proto->add_external_data(); - length->set_key("length"); - length->set_value(std::to_string(tensor_bytes_size)); - - output_proto->set_name(initializer.name()); - output_proto->set_data_type(initializer.data_type()); - for (int i = 0; i != initializer.dims_size(); ++i) { - output_proto->add_dims(initializer.dims(i)); + SetUpExternalInitializer(align_info, tensor_bytes_size, external_offset, external_stream, + raw_data, *output_proto, external_file_path, initializer, false); } - output_proto->set_doc_string(initializer.doc_string()); - - external_offset += tensor_bytes_size; #if !defined(DISABLE_SPARSE_TENSORS) } #endif diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index 1bae63b510563..ad1ec9c8dedb3 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -384,13 +384,17 @@ ModelProto Model::ToProto() const { ModelProto Model::ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_name, const std::filesystem::path& file_path, size_t initializer_size_threshold, - const Graph::OffsetAlignmentInfo& align_info) const { + const Graph::OffsetAlignmentInfo& align_info, + bool save_prepacked_constant_initializers, + Graph::PrePackedTensorProtoToSave& pre_packed_initializers) const { ModelProto result(model_proto_); const auto& graph = *graph_; *(result.mutable_graph()) = graph.ToGraphProtoWithExternalInitializers(external_file_name, file_path, initializer_size_threshold, - align_info); + align_info, + save_prepacked_constant_initializers, + pre_packed_initializers); return result; } @@ -608,7 +612,9 @@ static Status SaveModelWithExternalInitializers(Model& model, const T& file_path, const std::filesystem::path& external_file_name, size_t initializer_size_threshold, - const Graph::OffsetAlignmentInfo& align_info) { + const Graph::OffsetAlignmentInfo& align_info, + bool save_prepacked_constant_initializers, + Graph::PrePackedTensorProtoToSave& pre_packed_initializers) { int fd = 0; Status status = Env::Default().FileOpenWr(file_path, fd); ORT_RETURN_IF_ERROR(status); @@ -616,7 +622,8 @@ static Status SaveModelWithExternalInitializers(Model& model, ORT_TRY { status = Model::SaveWithExternalInitializers(model, fd, file_path, external_file_name, initializer_size_threshold, - align_info); + align_info, save_prepacked_constant_initializers, + pre_packed_initializers); } ORT_CATCH(const std::exception& ex) { ORT_HANDLE_EXCEPTION([&]() { @@ -647,9 +654,12 @@ Status Model::Load(const PathString& file_path, std::shared_ptr& p_model, Status Model::SaveWithExternalInitializers(Model& model, const std::filesystem::path& file_path, const std::filesystem::path& external_file_name, size_t initializer_size_threshold, - const Graph::OffsetAlignmentInfo& align_info) { + const Graph::OffsetAlignmentInfo& align_info, + bool save_prepacked_constant_initializers, + Graph::PrePackedTensorProtoToSave& pre_packed_initializers) { return SaveModelWithExternalInitializers(model, file_path, external_file_name, initializer_size_threshold, - align_info); + align_info, save_prepacked_constant_initializers, + pre_packed_initializers); } Status Model::LoadFromBytes(int count, const void* p_bytes, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) { @@ -766,7 +776,9 @@ Status Model::SaveWithExternalInitializers(Model& model, const std::filesystem::path& file_path, const std::filesystem::path& external_file_name, size_t initializer_size_threshold, - const Graph::OffsetAlignmentInfo& align_info) { + const Graph::OffsetAlignmentInfo& align_info, + bool save_prepacked_constant_initializers, + Graph::PrePackedTensorProtoToSave& pre_packed_initializers) { if (fd < 0) { return Status(ONNXRUNTIME, INVALID_ARGUMENT, " is less than 0."); } @@ -775,7 +787,8 @@ Status Model::SaveWithExternalInitializers(Model& model, auto model_proto = model.ToGraphProtoWithExternalInitializers(external_file_name, file_path, initializer_size_threshold, - align_info); + align_info, save_prepacked_constant_initializers, + pre_packed_initializers); google::protobuf::io::FileOutputStream output(fd); const bool result = model_proto.SerializeToZeroCopyStream(&output) && output.Flush(); if (result) { diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index 9bcec6f78ca08..38d9044ff9d31 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -191,13 +191,17 @@ class Model { ONNX_NAMESPACE::ModelProto ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_name, const std::filesystem::path& file_path, size_t initializer_size_threshold, - const Graph::OffsetAlignmentInfo& align_info) const; + const Graph::OffsetAlignmentInfo& align_info, + bool save_prepacked_constant_initializers, + Graph::PrePackedTensorProtoToSave& pre_packed_initializers) const; ONNX_NAMESPACE::ModelProto ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_name, const std::filesystem::path& file_path, size_t initializer_size_threshold) const { Graph::OffsetAlignmentInfo default_align_info; - return ToGraphProtoWithExternalInitializers(external_file_name, file_path, initializer_size_threshold, default_align_info); + Graph::PrePackedTensorProtoToSave pre_packed_initializers; + return ToGraphProtoWithExternalInitializers(external_file_name, file_path, initializer_size_threshold, default_align_info, + false, pre_packed_initializers); } static common::Status Save(Model& model, const PathString& file_path); @@ -210,14 +214,18 @@ class Model { const std::filesystem::path& file_path, const std::filesystem::path& external_file_path, size_t initializer_size_threshold, - const Graph::OffsetAlignmentInfo& align_info); + const Graph::OffsetAlignmentInfo& align_info, + bool save_prepacked_constant_initializers, + Graph::PrePackedTensorProtoToSave& pre_packed_initializers); static common::Status SaveWithExternalInitializers(Model& model, const std::filesystem::path& file_path, const std::filesystem::path& external_file_path, size_t initializer_size_threshold) { Graph::OffsetAlignmentInfo default_align_info; - return SaveWithExternalInitializers(model, file_path, external_file_path, initializer_size_threshold, default_align_info); + Graph::PrePackedTensorProtoToSave pre_packed_initializers; + return SaveWithExternalInitializers(model, file_path, external_file_path, initializer_size_threshold, default_align_info, + false, pre_packed_initializers); } static common::Status SaveWithExternalInitializers(Model& model, @@ -225,7 +233,9 @@ class Model { const std::filesystem::path& file_path, const std::filesystem::path& external_file_path, size_t initializer_size_threshold, - const Graph::OffsetAlignmentInfo& align_info); + const Graph::OffsetAlignmentInfo& align_info, + bool save_prepacked_constant_initializers, + Graph::PrePackedTensorProtoToSave& pre_packed_initializers); static common::Status SaveWithExternalInitializers(Model& model, int fd, @@ -233,7 +243,9 @@ class Model { const std::filesystem::path& external_file_path, size_t initializer_size_threshold) { Graph::OffsetAlignmentInfo default_align_info; - return SaveWithExternalInitializers(model, fd, file_path, external_file_path, initializer_size_threshold, default_align_info); + Graph::PrePackedTensorProtoToSave pre_packed_initializers; + return SaveWithExternalInitializers(model, fd, file_path, external_file_path, initializer_size_threshold, default_align_info, + false, pre_packed_initializers); } static common::Status Load(std::istream& model_istream, ONNX_NAMESPACE::ModelProto* p_model_proto); diff --git a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc index 37db095e92570..0a1a3a5995872 100644 --- a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc +++ b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc @@ -51,6 +51,7 @@ class FusedConvFp16 final : public OpKernel { Status Compute(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, @@ -101,6 +102,7 @@ class FusedConvFp16 final : public OpKernel { }; Status FusedConvFp16::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; diff --git a/onnxruntime/core/providers/cpu/math/gemm.cc b/onnxruntime/core/providers/cpu/math/gemm.cc index 5406dd1a40446..dbc7becdf2397 100644 --- a/onnxruntime/core/providers/cpu/math/gemm.cc +++ b/onnxruntime/core/providers/cpu/math/gemm.cc @@ -248,6 +248,7 @@ template void Gemm::ComputeGemm(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE template Status Gemm::PrePack(const Tensor& /* tensor */, int /* input_idx */, AllocatorPtr /*alloc_for_caching*/, + bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weight_for_caching*/) { is_packed = false; @@ -256,7 +257,7 @@ Status Gemm::PrePack(const Tensor& /* tensor */, int /* input_idx */, Allocat template <> Status Gemm::PrePack(const Tensor& tensor, int input_idx, - AllocatorPtr alloc, /*out*/ bool& is_packed, + AllocatorPtr alloc, bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; diff --git a/onnxruntime/core/providers/cpu/math/gemm.h b/onnxruntime/core/providers/cpu/math/gemm.h index 953949732560d..92f05a7921f8b 100644 --- a/onnxruntime/core/providers/cpu/math/gemm.h +++ b/onnxruntime/core/providers/cpu/math/gemm.h @@ -21,6 +21,7 @@ class Gemm : protected GemmBase, public OpKernel { Status Compute(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index 2c6d23e4de908..8f2c2c53b188b 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -173,6 +173,7 @@ bool GemmPackBBfloat16(AllocatorPtr& alloc, #endif Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; diff --git a/onnxruntime/core/providers/cpu/math/matmul.h b/onnxruntime/core/providers/cpu/math/matmul.h index b9bbe36583879..0bb0e6c2ef596 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.h +++ b/onnxruntime/core/providers/cpu/math/matmul.h @@ -37,6 +37,7 @@ class MatMul final : public OpKernel { } Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; diff --git a/onnxruntime/core/providers/cpu/nn/conv_transpose.cc b/onnxruntime/core/providers/cpu/nn/conv_transpose.cc index f0c1b0b409831..2c7afddf38070 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cpu/nn/conv_transpose.cc @@ -38,6 +38,7 @@ ONNX_CPU_OPERATOR_KERNEL( template Status ConvTranspose::PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/, + bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weights*/ ) { @@ -47,6 +48,7 @@ Status ConvTranspose::PrePack(const Tensor& /*tensor*/, int /*input_idx*/, Al template <> Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; diff --git a/onnxruntime/core/providers/cpu/nn/conv_transpose.h b/onnxruntime/core/providers/cpu/nn/conv_transpose.h index c82cd5ad49d7e..d03b5566e334f 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_transpose.h +++ b/onnxruntime/core/providers/cpu/nn/conv_transpose.h @@ -28,6 +28,7 @@ class ConvTranspose : public OpKernel { ConvTranspose(const OpKernelInfo& info) : OpKernel(info), conv_transpose_attrs_(info) {} Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; diff --git a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc index 24a5dcab225c4..fe2bf1035bb65 100644 --- a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc +++ b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc @@ -229,6 +229,7 @@ Status LayerNormImpl::Compute(OpKernelContext* p_ctx) const { } Status LayerNormImpl::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, bool& is_packed, PrePackedWeights* prepacked_weights) { ORT_UNUSED_PARAMETER(prepacked_weights); diff --git a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h index f8b528b398cba..abce87d03c14b 100644 --- a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h +++ b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h @@ -15,7 +15,7 @@ class LayerNormImpl : public OpKernel { LayerNormImpl(const OpKernelInfo& op_kernel_info, bool simplified = false, bool contrib_op = false); Status Compute(OpKernelContext* p_op_kernel_context) const override; - Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool save_prepacked_initializers, bool& is_packed, PrePackedWeights* prepacked_weights) override; // This method was created so that it can be called directly from `test/onnx/microbenchmark/layer_normalization.cc`. diff --git a/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h b/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h index e26eae19b8fd4..8a8ce27990069 100644 --- a/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h +++ b/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h @@ -14,6 +14,7 @@ class MatMulIntegerBase : public OpKernel { MatMulIntegerBase(const OpKernelInfo& info) : OpKernel(info) {} Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override { is_packed = false; diff --git a/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc b/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc index 7797cbe678bd4..736cde24591ff 100644 --- a/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc +++ b/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc @@ -25,6 +25,7 @@ class QLinearConv : public OpKernel { Status Compute(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; @@ -360,6 +361,7 @@ REGISTER_QLINEARCONV_INT8_KERNEL(kMSDomain, 1); template Status QLinearConv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc index b78c5236e6fab..7afd00eacef89 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc @@ -284,6 +284,7 @@ bool DeepCpuGruOp::TryPackRecurrentWeights(const Tensor& weights, AllocatorPtr& } Status DeepCpuGruOp::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, bool& is_packed, PrePackedWeights* prepacked_weights) { is_packed = false; diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h index 5a6dd97c7c3f2..914077b2f2c15 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h @@ -62,6 +62,7 @@ class DeepCpuGruOp final : public OpKernel { private: Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; @@ -197,4 +198,4 @@ class UniDirectionalGru { }; } // namespace detail -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc index 09bbf6c4c79e6..e4082e5d7634a 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc @@ -225,7 +225,9 @@ static void UseSharedPrePackedBuffersImpl(std::vector& prepacke } Status DeepCpuLstmOp::PrePack(const Tensor& tensor, int input_idx, - AllocatorPtr alloc, /*out*/ bool& is_packed, + AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, + /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h index 9c4c12954022a..ff8ab9abf0eed 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h @@ -19,6 +19,7 @@ class DeepCpuLstmOp final : public OpKernel, public LSTMBase { DeepCpuLstmOp(const OpKernelInfo& info) : OpKernel(info), LSTMBase(info) {} Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index 3129f519da2e5..45a1d3bbc0414 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -52,6 +52,7 @@ REGISTER_KERNEL_TYPED(MLFloat16, kMSInternalNHWCDomain, true) // First input (in this case X) is in case NHWC == true also in NHWC format, the other inputs in NCHW template Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { is_packed = false; // only layout of weight input is adjusted via PrePack diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h index e4047a6af272e..6294566af3cb9 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.h +++ b/onnxruntime/core/providers/cuda/nn/conv.h @@ -219,6 +219,7 @@ class Conv : public CudaKernel { } Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, bool& is_packed, PrePackedWeights* prepacked_weights) override; Status ComputeInternal(OpKernelContext* context) const override; diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc index 2972ae999adc4..9c9a83460daeb 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc @@ -45,7 +45,8 @@ REGISTER_KERNEL_TYPED(MLFloat16, kMSInternalNHWCDomain, true) // First input (in this case X) is in case NHWC == true also in NHWC format, the other inputs in NCHW template -Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, +Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, bool& is_packed, [[maybe_unused]] PrePackedWeights* prepacked_weights) { is_packed = false; // only layout of weight input is adjusted via PrePack diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.h b/onnxruntime/core/providers/cuda/nn/conv_transpose.h index 1a6957164d22f..f23c2b94501f2 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.h +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.h @@ -22,6 +22,7 @@ class ConvTranspose : public CudaKernel { ConvTranspose(const OpKernelInfo& info) : CudaKernel(info), conv_transpose_attrs_(info) {}; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, bool& is_packed, [[maybe_unused]] PrePackedWeights* prepacked_weights) override; Status ComputeInternal(OpKernelContext* context) const override; Status DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const; diff --git a/onnxruntime/core/providers/js/operators/conv.h b/onnxruntime/core/providers/js/operators/conv.h index b04df44954295..276b600cf40d2 100644 --- a/onnxruntime/core/providers/js/operators/conv.h +++ b/onnxruntime/core/providers/js/operators/conv.h @@ -78,6 +78,7 @@ class ConvBase : public JsKernel { } Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /* prepacked_weights */) override { is_packed = false; diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.h b/onnxruntime/core/providers/js/operators/conv_transpose.h index 5ff52e8fda4fa..baa93f825a203 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.h +++ b/onnxruntime/core/providers/js/operators/conv_transpose.h @@ -126,8 +126,10 @@ class ConvTranspose : public JsKernel { } Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /* prepacked_weights */) override { + ORT_UNUSED_PARAMETER(save_prepacked_initializers); is_packed = false; if (input_idx == 1) { diff --git a/onnxruntime/core/providers/xnnpack/math/gemm.cc b/onnxruntime/core/providers/xnnpack/math/gemm.cc index 35a06cb7eb89f..68b55030c7363 100644 --- a/onnxruntime/core/providers/xnnpack/math/gemm.cc +++ b/onnxruntime/core/providers/xnnpack/math/gemm.cc @@ -117,6 +117,7 @@ Gemm::Gemm(const OpKernelInfo& info) : GemmBase(info), XnnpackKernel(info, /*ena } Status Gemm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr, + bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights*) { is_packed = false; diff --git a/onnxruntime/core/providers/xnnpack/math/gemm.h b/onnxruntime/core/providers/xnnpack/math/gemm.h index 954aab0698b9c..d632eef015f9a 100644 --- a/onnxruntime/core/providers/xnnpack/math/gemm.h +++ b/onnxruntime/core/providers/xnnpack/math/gemm.h @@ -23,6 +23,7 @@ class Gemm : protected GemmBase, public XnnpackKernel { static bool IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& graph); Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; diff --git a/onnxruntime/core/providers/xnnpack/math/matmul.cc b/onnxruntime/core/providers/xnnpack/math/matmul.cc index 44a6fb4ee835a..71a11cb05d9af 100644 --- a/onnxruntime/core/providers/xnnpack/math/matmul.cc +++ b/onnxruntime/core/providers/xnnpack/math/matmul.cc @@ -78,6 +78,7 @@ MatMul::MatMul(const OpKernelInfo& info) : XnnpackKernel(info, /*enable_caches*/ } Status MatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*Not used*/) { is_packed = false; diff --git a/onnxruntime/core/providers/xnnpack/math/matmul.h b/onnxruntime/core/providers/xnnpack/math/matmul.h index 188cc73189af5..31a8c36ad418b 100644 --- a/onnxruntime/core/providers/xnnpack/math/matmul.h +++ b/onnxruntime/core/providers/xnnpack/math/matmul.h @@ -23,6 +23,7 @@ class MatMul : public XnnpackKernel { // Required for checking XNNpack restrictions on ORT side static bool IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& graph); Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; diff --git a/onnxruntime/core/providers/xnnpack/nn/conv.cc b/onnxruntime/core/providers/xnnpack/nn/conv.cc index 4e6b308e28ae5..f2e697df475da 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv.cc @@ -18,6 +18,7 @@ namespace xnnpack { // use PrePack to handle the weight layout change as that's not a simple NCHW -> NHWC transpose Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weights*/) { is_packed = false; diff --git a/onnxruntime/core/providers/xnnpack/nn/conv.h b/onnxruntime/core/providers/xnnpack/nn/conv.h index 3630aae208d49..762b68c8bd49a 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv.h +++ b/onnxruntime/core/providers/xnnpack/nn/conv.h @@ -19,6 +19,7 @@ class Conv : public ConvBase { // use PrePack to handle the weight layout change as that's not a simple NCHW -> NHWC transpose Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; }; diff --git a/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc b/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc index b6930a5fc92d1..5729565b2feb9 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc @@ -15,6 +15,7 @@ namespace xnnpack { // use PrePack to handle the weight layout change as that's not a simple NCHW -> NHWC transpose Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool /*save_prepacked_initializers*/, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weights*/) { is_packed = false; diff --git a/onnxruntime/core/providers/xnnpack/nn/conv_transpose.h b/onnxruntime/core/providers/xnnpack/nn/conv_transpose.h index 866b9b6b98365..0313515d10fa1 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv_transpose.h +++ b/onnxruntime/core/providers/xnnpack/nn/conv_transpose.h @@ -18,6 +18,7 @@ class ConvTranspose : public ConvBase { // use PrePack to handle the weight layout change as that's not a simple NCHW -> NHWC transpose Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; }; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index f5f12c206ebad..e6aafaa1f2283 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -2027,9 +2027,11 @@ common::Status InferenceSession::Initialize() { #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) } + SessionState::PrePackInitializers pre_packed_initializers; ORT_RETURN_IF_ERROR_SESSIONID_( session_state_->FinalizeSessionState(model_location_, kernel_registry_manager_, // need to keep the initializers if saving the optimized model + pre_packed_initializers, !saving_model, saving_ort_format)); @@ -2065,11 +2067,47 @@ common::Status InferenceSession::Initialize() { kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes, "1024")); Graph::OffsetAlignmentInfo align_info; align_info.align_offset = true; + bool save_prepacked_constant_initializers = + session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsSavePrePackedConstantInitializers, "0") == "1" ? true : false; + Graph::PrePackedTensorProtoToSave pre_packed_initializers_tensor_proto; + if (save_prepacked_constant_initializers) { + LOGS(*session_logger_, WARNING) << "Serialize prepacked initializers option has been turn on." + << "Use this option only when run model inference on PC with CPU." + << "Make sure to save and load model in same device as prepack is device specific." + << "Note: this feature in only work with ONNX model format." + << "Process of use this option is like below:" + << "1. Optimize model with external data file with save_prepacked_constant_initializers on:" + << " sample: sess_options.add_session_config_entry('session.save_prepacked_constant_initializers', ' 1 ')" + << " With save_prepacked_constant_initializers option, prepacked initializer will be serialized into data file." + << "2. Load optimized model and external data file in same device, no prepack is need." + << "3. Run inference with optimized model."; + + if (fbs::utils::IsOrtFormatModel(session_options_.optimized_model_filepath)) { + ORT_RETURN_IF_ERROR_SESSIONID_( + ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Unable to serialize prepacked external constant initializer for ORT format model." + "Please use ONNX format model with save_prepacked_constant_initializers.")); + } + + // convert pre_packed_initializers to tensorproto format and save to external data file + for (const auto& name_item_pair : pre_packed_initializers.pre_packed_initializers_to_save) { + auto initializer_name = name_item_pair.first; + + for (const auto& kernel_name_initializer_item_pair : name_item_pair.second) { + auto kernel_name = kernel_name_initializer_item_pair.first; + auto prepacked_initializer_name = utils::GetPrepackedInitializerName(initializer_name, kernel_name); + + pre_packed_initializers_tensor_proto[initializer_name][kernel_name] = utils::TensorToTensorProto(kernel_name_initializer_item_pair.second, prepacked_initializer_name); + } + } + } ORT_RETURN_IF_ERROR_SESSIONID_(Model::SaveWithExternalInitializers(*model_, session_options_.optimized_model_filepath, optimized_model_external_initializers_file_name, optimized_model_external_initializers_min_size_in_bytes, - align_info)); + align_info, + save_prepacked_constant_initializers, + pre_packed_initializers_tensor_proto)); } } } diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 61a8f7e23fe87..da5fa2c3a5a24 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -45,6 +45,7 @@ #include "core/session/environment.h" #include "core/session/IOBinding.h" #include "core/session/inference_session_utils.h" +#include "core/session/onnxruntime_cxx_api.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/onnxruntime_run_options_config_keys.h" #include "dummy_provider.h" @@ -64,6 +65,8 @@ using namespace ONNX_NAMESPACE; using namespace onnxruntime::logging; using namespace onnxruntime::concurrency; +extern std::unique_ptr ort_env; + namespace { struct KernelRegistryAndStatus { std::shared_ptr kernel_registry = std::make_shared(); @@ -496,6 +499,57 @@ TEST(InferenceSessionTests, TestModelSerialization) { ASSERT_TRUE(session_object_emptyValidation.Initialize().IsOK()); } +// Test feature serialize prepack weight is only used in PC with CPU on inference, +// disable this test for training, other device and eps +#if !ENABLE_TRAINING && !defined(USE_CUDA) && !defined(__wasm__) && !defined(USE_DNNL) && !defined(USE_QNN) && !defined(__ANDROID__) && !defined(USE_COREML) +// MLAS dispatcher used in matmul_nbits kernels here is 64 bit only +#if defined(__amd64__) || defined(_M_AMD64) || defined(__aarch64__) || defined(_M_ARM64) +TEST(InferenceSessionTests, TestPrePackSerialization) { + SessionOptions so; + std::string model_name = "model_with_matmul_nbits"; + + const std::string test_model = "testdata/prepack/" + model_name + ".onnx"; + const std::string optimized_model = "testdata/prepack/" + model_name + "_opt.onnx"; + + so.session_logid = "InferenceSessionTests.TestPrepackSerialization"; + so.enable_cpu_mem_arena = false; + so.graph_optimization_level = TransformerLevel::Default; + so.optimized_model_filepath = optimized_model; + std::string external_initializer_file_name = model_name + "_opt.onnx.data"; + + // enable serialize prepack initializer to data file + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsSavePrePackedConstantInitializers, + "1")); + // always save external initializer to data file for test + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes, + "0")); + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsOptimizedModelExternalInitializersFileName, + external_initializer_file_name.c_str())); + + // optimize model with serialize prepack constant initializers + InferenceSessionWrapper session_object{so, GetEnvironment()}; + ASSERT_TRUE(session_object.Load(test_model).IsOK()); + ASSERT_TRUE(session_object.Initialize().IsOK()); + + // Verify prepack initializers are serialized into optimized model and data file + // load optimized model and check initializer are prepacked + auto logger = DefaultLoggingManager().CreateLogger("TestPrepackSerialization"); + std::shared_ptr model; + auto load_status = Model::Load(ToWideString(optimized_model), model, nullptr, *logger); + ASSERT_EQ(Status::OK(), load_status); + Graph& graph = model->MainGraph(); + + bool found_prepack_initializer = false; + for (const auto& item : graph.GetAllInitializedTensors()) { + if (item.first.find(':') != std::string::npos) { + found_prepack_initializer = true; + } + } + ASSERT_TRUE(found_prepack_initializer); +} +#endif +#endif + #ifdef ORT_RUN_EXTERNAL_ONNX_TESTS static bool Compare(const InputDefList& f_arg, const InputDefList& s_arg) { if (f_arg.size() != s_arg.size()) { diff --git a/onnxruntime/test/framework/save_model_with_external_initializers.cc b/onnxruntime/test/framework/save_model_with_external_initializers.cc index d0bc088175755..0f76cb61ace74 100644 --- a/onnxruntime/test/framework/save_model_with_external_initializers.cc +++ b/onnxruntime/test/framework/save_model_with_external_initializers.cc @@ -7,6 +7,7 @@ #include "core/framework/data_types.h" #include "core/graph/model.h" #include "core/framework/tensorprotoutils.h" +#include "core/framework/session_state.h" #include "test/test_environment.h" #include "test_utils.h" #include "test/util/include/asserts.h" @@ -19,19 +20,34 @@ using namespace onnxruntime; namespace onnxruntime { namespace test { +std::vector split(const std::string& str, char delimiter) { + std::vector result; + std::stringstream ss(str); + std::string token; + + // Use getline with a delimiter to split the string + while (std::getline(ss, token, delimiter)) { + result.push_back(token); + } + + return result; +} + Status LoadSaveAndCompareModel(const std::filesystem::path& input_onnx, const std::filesystem::path& input_external_init_file, const std::filesystem::path& output_onnx, const std::filesystem::path& output_external_init_file, size_t initializer_size_threshold, - const Graph::OffsetAlignmentInfo& align_info) { + const Graph::OffsetAlignmentInfo& align_info, + Graph::PrePackedTensorProtoToSave& pre_packed_initializers_tensor_proto, + bool save_prepacked_constant_initializers = false) { auto logger = DefaultLoggingManager().CreateLogger("LoadSaveAndCompareModel"); std::shared_ptr model; ORT_RETURN_IF_ERROR(Model::Load(input_onnx, model, nullptr, *logger)); std::filesystem::remove(output_onnx); std::filesystem::remove(output_external_init_file); ORT_RETURN_IF_ERROR(Model::SaveWithExternalInitializers(*model, output_onnx, output_external_init_file, initializer_size_threshold, - align_info)); + align_info, save_prepacked_constant_initializers, pre_packed_initializers_tensor_proto)); std::shared_ptr model_from_external; ORT_RETURN_IF_ERROR(Model::Load(output_onnx.native(), model_from_external, nullptr, *logger)); @@ -50,10 +66,11 @@ Status LoadSaveAndCompareModel(const std::filesystem::path& input_onnx, // Compare the initializers of the two versions. std::filesystem::path model_path{}; std::filesystem::path external_data_path{}; - for (const auto& i : initializers) { + for (const auto& i : initializers_from_external) { const std::string kInitName = i.first; - const ONNX_NAMESPACE::TensorProto* tensor_proto = i.second; - const ONNX_NAMESPACE::TensorProto* from_external_tensor_proto = initializers_from_external[kInitName]; + const ONNX_NAMESPACE::TensorProto* from_external_tensor_proto = i.second; + // prepack initializer will have name as [original name]:[kernel name] in case initializer used by multiple kernels + const ONNX_NAMESPACE::TensorProto* tensor_proto = save_prepacked_constant_initializers ? initializers[split(kInitName, ':')[0]] : initializers[kInitName]; std::vector tensor_proto_data; model_path = input_onnx; @@ -75,8 +92,12 @@ Status LoadSaveAndCompareModel(const std::filesystem::path& input_onnx, ORT_RETURN_IF_NOT(from_external_tensor_proto->data_location() == ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL, "location mismatch"); } - ORT_RETURN_IF_NOT(tensor_proto_size == from_external_tensor_proto_size, "size mismatch"); - ORT_RETURN_IF_NOT(memcmp(tensor_proto_data.data(), from_external_tensor_proto_data.data(), tensor_proto_size) == 0, "data mismatch"); + if (!save_prepacked_constant_initializers) { + ORT_RETURN_IF_NOT(tensor_proto_size == from_external_tensor_proto_size, "size mismatch"); + ORT_RETURN_IF_NOT(memcmp(tensor_proto_data.data(), from_external_tensor_proto_data.data(), tensor_proto_size) == 0, "data mismatch"); + } else { + ORT_RETURN_IF_NOT(from_external_tensor_proto_size >= tensor_proto_size, "prepack initializer's size is at least same as original tensor, might be larger"); + } if (align_info.align_offset) { for (const StringStringEntryProto& entry : from_external_tensor_proto->external_data()) { @@ -89,6 +110,7 @@ Status LoadSaveAndCompareModel(const std::filesystem::path& input_onnx, } } } + // Cleanup. ORT_RETURN_IF_NOT(std::filesystem::remove(output_onnx), "delete file failed"); ORT_RETURN_IF_NOT(std::filesystem::remove(external_data_path), "delete file failed"); @@ -98,13 +120,15 @@ Status LoadSaveAndCompareModel(const std::filesystem::path& input_onnx, // Original model does not have external initializers TEST(SaveWithExternalInitializers, Mnist) { Graph::OffsetAlignmentInfo align_info; - ASSERT_STATUS_OK(LoadSaveAndCompareModel(ORT_TSTR("testdata/mnist.onnx"), ORT_TSTR(""), ORT_TSTR("testdata/mnist_with_external_initializers.onnx"), ORT_TSTR("mnist_external_initializers.bin"), 100, align_info)); + Graph::PrePackedTensorProtoToSave pre_packed_initializers_tensor_proto; + ASSERT_STATUS_OK(LoadSaveAndCompareModel(ORT_TSTR("testdata/mnist.onnx"), ORT_TSTR(""), ORT_TSTR("testdata/mnist_with_external_initializers.onnx"), ORT_TSTR("mnist_external_initializers.bin"), 100, align_info, pre_packed_initializers_tensor_proto)); } // Original model has external initializers TEST(SaveWithExternalInitializers, ModelWithOriginalExternalData) { Graph::OffsetAlignmentInfo align_info; - ASSERT_STATUS_OK(LoadSaveAndCompareModel(ORT_TSTR("testdata/model_with_orig_ext_data.onnx"), ORT_TSTR("model_with_orig_ext_data.onnx.data"), ORT_TSTR("testdata/model_with_new_external_initializers.onnx"), ORT_TSTR("model_with_new_external_initializers.bin"), 0, align_info)); + Graph::PrePackedTensorProtoToSave pre_packed_initializers_tensor_proto; + ASSERT_STATUS_OK(LoadSaveAndCompareModel(ORT_TSTR("testdata/model_with_orig_ext_data.onnx"), ORT_TSTR("model_with_orig_ext_data.onnx.data"), ORT_TSTR("testdata/model_with_new_external_initializers.onnx"), ORT_TSTR("model_with_new_external_initializers.bin"), 0, align_info, pre_packed_initializers_tensor_proto)); } // Original model has external initializers, align offset @@ -112,7 +136,22 @@ TEST(SaveWithExternalInitializers, ModelWithOriginalExternalDataAlignOffset) { Graph::OffsetAlignmentInfo align_info; align_info.align_offset = true; align_info.align_threshold = 0; - ASSERT_STATUS_OK(LoadSaveAndCompareModel(ORT_TSTR("testdata/model_with_orig_ext_data.onnx"), ORT_TSTR("model_with_orig_ext_data.onnx.data"), ORT_TSTR("testdata/model_with_new_external_initializers.onnx"), ORT_TSTR("model_with_new_external_initializers.bin"), 0, align_info)); + Graph::PrePackedTensorProtoToSave pre_packed_initializers_tensor_proto; + ASSERT_STATUS_OK(LoadSaveAndCompareModel(ORT_TSTR("testdata/model_with_orig_ext_data.onnx"), ORT_TSTR("model_with_orig_ext_data.onnx.data"), ORT_TSTR("testdata/model_with_new_external_initializers.onnx"), ORT_TSTR("model_with_new_external_initializers.bin"), 0, align_info, pre_packed_initializers_tensor_proto)); +} + +// Original model has external initializers, align offset and serialize prepacked external initializer to model file +TEST(SaveWithExternalInitializers, ModelWithOriginalExternalDataAlignOffsetAndSavePrepackTensors) { + Graph::OffsetAlignmentInfo align_info; + align_info.align_offset = true; + align_info.align_threshold = 0; + std::shared_ptr alloc = std::make_shared(); + TensorShape shape = {178}; + // prepack both initializers for test purpose + Graph::PrePackedTensorProtoToSave pre_packed_initializers_tensor_proto; + pre_packed_initializers_tensor_proto["MatMul.Weight"]["MatMul_0"] = utils::TensorToTensorProto(Tensor(DataTypeImpl::GetType(), shape, alloc), "MatMul.Weight:MatMul_0"); + pre_packed_initializers_tensor_proto["scales"]["MatMul_0"] = utils::TensorToTensorProto(Tensor(DataTypeImpl::GetType(), shape, alloc), "scales:MatMul_0"); + ASSERT_STATUS_OK(LoadSaveAndCompareModel(ORT_TSTR("testdata/prepack/model_with_matmul_nbits.onnx"), ORT_TSTR("model_with_matmul_nbits.onnx.data"), ORT_TSTR("testdata/prepack/model_with_matmul_nbits_opt.onnx"), ORT_TSTR("model_with_matmul_nbits_opt.onnx.data"), 0, align_info, pre_packed_initializers_tensor_proto, true)); } } // namespace test diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index b94d24a1b180b..6265eccb7bd9b 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -372,10 +372,11 @@ class PrePackingTestOpKernel : public OpKernel { return Status::OK(); } - Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool save_prepacked_initializers, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override { ORT_UNUSED_PARAMETER(tensor); ORT_UNUSED_PARAMETER(input_idx); + ORT_UNUSED_PARAMETER(save_prepacked_initializers); size_t weight_packed_len = 8; weight_packed_ = IAllocator::MakeUniquePtr(alloc, weight_packed_len, true); @@ -393,9 +394,20 @@ class PrePackingTestOpKernel : public OpKernel { return Status::OK(); } + std::optional GetPrePackTensor(int input_idx) override { + ORT_UNUSED_PARAMETER(input_idx); + ++get_prepack_tensors_count; + + TensorShape shape = {2}; + packed_tensor = Tensor(DataTypeImpl::GetType(), shape, std::make_shared()); + return std::move(packed_tensor); + } + int prepack_calls_count = 0; int store_pre_packed_weight_calls_count = 0; + int get_prepack_tensors_count = 0; IAllocatorUniquePtr weight_packed_; + Tensor packed_tensor; }; static void CreateSimpleGraph(Graph& graph) { @@ -530,6 +542,7 @@ static void PlaceAllNodesToCPUEP(Graph& graph) { struct PrepackingTestParam { bool test_subgraph; bool test_prepacking; + bool test_save_prepack_initializer; }; class SessionStatePrepackingTest : public testing::TestWithParam {}; @@ -572,6 +585,8 @@ TEST_P(SessionStatePrepackingTest, PrePackingTest) { sess_options.enable_mem_reuse = true; sess_options.config_options.configurations[kOrtSessionOptionsConfigDisablePrepacking] = test_param.test_prepacking ? "0" : "1"; + sess_options.config_options.configurations[kOrtSessionOptionsSavePrePackedConstantInitializers] = + test_param.test_save_prepack_initializer ? "1" : "0"; SessionState session_state(model.MainGraph(), execution_providers, @@ -597,12 +612,47 @@ TEST_P(SessionStatePrepackingTest, PrePackingTest) { kernel_registry_manager.RegisterKernelRegistry(kernel_registry); PlaceAllNodesToCPUEP(model.MainGraph()); + SessionState::PrePackInitializers pre_packed_initializers; ASSERT_STATUS_OK(session_state.FinalizeSessionState(std::basic_string(), - kernel_registry_manager)); + kernel_registry_manager, + pre_packed_initializers)); const auto& const_initialized_tensors = session_state.GetConstantInitializedTensors(); // check prepacking ASSERT_EQ(const_initialized_tensors.size(), size_t(test_param.test_prepacking ? 0 : 1)); + + // check get prepack tensor method called when set save_prepacked_constant_initializers + if (!test_param.test_subgraph) { + const auto* kernel = reinterpret_cast(session_state.GetKernel(0)); + ASSERT_EQ(kernel->get_prepack_tensors_count, (test_param.test_prepacking && test_param.test_save_prepack_initializer) ? 1 : 0); + } else { + auto if_index = 1; + if (session_state.GetKernel(0)->Node().OpType() == "If") { + if_index = 0; + } + + const auto& subgraph_session_states = session_state.GetSubgraphSessionStateMap(); + const auto& if_node_session_states = subgraph_session_states.at(if_index); + const auto& session_state_1_then_branch_session_state = *if_node_session_states.at("then_branch"); + const auto& session_state_1_else_branch_session_state = *if_node_session_states.at("else_branch"); + + const auto* kernel_if_0 = reinterpret_cast(session_state_1_then_branch_session_state.GetKernel(0)); + const auto* kernel_if_1 = reinterpret_cast(session_state_1_else_branch_session_state.GetKernel(0)); + ASSERT_EQ(kernel_if_0->get_prepack_tensors_count, (test_param.test_prepacking && test_param.test_save_prepack_initializer) ? 1 : 0); + ASSERT_EQ(kernel_if_1->get_prepack_tensors_count, (test_param.test_prepacking && test_param.test_save_prepack_initializer) ? 1 : 0); + } + + // check pre_packed_initializers_to_save will be set properly when set save_prepacked_constant_initializers + if (!test_param.test_subgraph && test_param.test_prepacking && test_param.test_save_prepack_initializer) { + ASSERT_EQ(pre_packed_initializers.pre_packed_initializers_to_save.size(), size_t(1)); + ASSERT_EQ(pre_packed_initializers.pre_packed_initializers_to_save.count("node_0_input_1"), size_t(1)); + ASSERT_EQ(pre_packed_initializers.pre_packed_initializers_to_save["node_0_input_1"].count("node_0"), size_t(1)); + } else if (test_param.test_subgraph && test_param.test_prepacking && test_param.test_save_prepack_initializer) { + ASSERT_EQ(pre_packed_initializers.pre_packed_initializers_to_save.size(), size_t(1)); + ASSERT_EQ(pre_packed_initializers.pre_packed_initializers_to_save.count("if_shared"), size_t(1)); + ASSERT_EQ(pre_packed_initializers.pre_packed_initializers_to_save["if_shared"].count("if_node_1"), size_t(1)); + ASSERT_EQ(pre_packed_initializers.pre_packed_initializers_to_save["if_shared"].count("if_node_0"), size_t(1)); + } } class SessionStateTestSharedInitalizersWithPrePacking : public ::testing::Test { @@ -1000,10 +1050,14 @@ TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test4) { INSTANTIATE_TEST_SUITE_P(SessionStateTests, SessionStatePrepackingTest, - testing::Values(PrepackingTestParam{false, false}, - PrepackingTestParam{false, true}, - PrepackingTestParam{true, false}, - PrepackingTestParam{true, true})); + testing::Values(PrepackingTestParam{false, false, false}, + PrepackingTestParam{false, true, false}, + PrepackingTestParam{true, false, false}, + PrepackingTestParam{true, true, false}, + PrepackingTestParam{false, false, true}, + PrepackingTestParam{false, true, true}, + PrepackingTestParam{true, false, true}, + PrepackingTestParam{true, true, true})); #endif } // namespace test diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 0be1c0b1965ac..e19362e0ec32d 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -4600,3 +4600,86 @@ TEST(CApiTest, OrtCustomOp_GetInPlace) { ASSERT_EQ(len, static_cast(2)); mock_gqa.ReleaseAliasMap(input_index, output_index); } + +TEST(CApiTest, Serialize_PrePack_Initializers) { + std::string model_name = "model_with_matmul_nbits"; + + const std::string test_model = "testdata/prepack/" + model_name + ".onnx"; + const std::string optimized_model = "testdata/prepack/" + model_name + "_opt.onnx"; + std::string external_initializer_file_name = model_name + "_opt.onnx.data"; + + // Generate optimized with prepacked weights serialized + Ort::SessionOptions session_options_opt; + session_options_opt.AddConfigEntry(kOrtSessionOptionsOptimizedModelExternalInitializersFileName, external_initializer_file_name.c_str()); + session_options_opt.AddConfigEntry(kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes, "0"); + session_options_opt.AddConfigEntry(kOrtSessionOptionsSavePrePackedConstantInitializers, "1"); + +#if defined(_WIN32) || defined(_WIN64) + std::wstring test_model_wide = onnxruntime::ToWideString(test_model); + session_options_opt.SetOptimizedModelFilePath(onnxruntime::ToWideString(optimized_model).c_str()); + Ort::Session session_opt_model(*ort_env, test_model_wide.c_str(), session_options_opt); +#else + session_options_opt.SetOptimizedModelFilePath(optimized_model.c_str()); + Ort::Session session_opt_model(*ort_env, test_model.c_str(), session_options_opt); +#endif + + // Do inference with original model and optimized model and check output is identical + // set inputs and session options + Ort::SessionOptions session_options; + const char* input_names[] = {"A"}; + const char* const output_names[] = {"Y"}; + Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); + + std::vector ort_inputs; + std::vector input_0_data = {1.3f}; + std::vector input_0_dims = {1, 1}; + ort_inputs.emplace_back( + Ort::Value::CreateTensor(info, const_cast(input_0_data.data()), + input_0_data.size(), input_0_dims.data(), input_0_dims.size())); + + // run inference with original model + // Convert std::string to std::wstring +#if defined(_WIN32) || defined(_WIN64) + Ort::Session session(*ort_env, test_model_wide.c_str(), session_options); +#else + Ort::Session session(*ort_env, test_model.c_str(), session_options); +#endif + auto ort_outputs = session.Run(Ort::RunOptions{}, input_names, ort_inputs.data(), ort_inputs.size(), + output_names, 1); + + // run inference with optimized model which load serialized prepack initializer +#if defined(_WIN32) || defined(_WIN64) + std::wstring optimized_model_wide = onnxruntime::ToWideString(optimized_model); + Ort::Session session_opt(*ort_env, optimized_model_wide.c_str(), session_options); +#else + Ort::Session session_opt(*ort_env, optimized_model.c_str(), session_options); +#endif + auto ort_outputs_opt = session_opt.Run(Ort::RunOptions{}, input_names, ort_inputs.data(), ort_inputs.size(), + output_names, 1); + + // check output of original model and optimized model are equal + ASSERT_EQ(ort_outputs.size(), ort_outputs_opt.size()); + + for (size_t i = 0; i < ort_outputs.size(); ++i) { + const auto& sequences = ort_outputs[i]; + ASSERT_TRUE(sequences.IsTensor()); + + const auto& sequences_opt = ort_outputs_opt[i]; + ASSERT_TRUE(sequences_opt.IsTensor()); + + auto result_ts = sequences.GetTensorTypeAndShapeInfo(); + auto result_ts_opt = sequences_opt.GetTensorTypeAndShapeInfo(); + + ASSERT_EQ(result_ts.GetElementType(), result_ts_opt.GetElementType()); + + ASSERT_EQ(result_ts.GetShape(), result_ts_opt.GetShape()); + + const auto* result_vals = sequences.GetTensorData(); + auto result_span = gsl::make_span(result_vals, ort_outputs.size()); + + const auto* result_vals_opt = sequences_opt.GetTensorData(); + auto result_span_opt = gsl::make_span(result_vals_opt, ort_outputs_opt.size()); + + ASSERT_TRUE(std::equal(result_span_opt.begin(), result_span_opt.end(), result_span.begin(), result_span.end())); + } +} \ No newline at end of file diff --git a/onnxruntime/test/testdata/model_with_external_initializers.onnx b/onnxruntime/test/testdata/model_with_external_initializers.onnx index f815b4000f98f..3538f01b53c18 100644 --- a/onnxruntime/test/testdata/model_with_external_initializers.onnx +++ b/onnxruntime/test/testdata/model_with_external_initializers.onnx @@ -1,7 +1,8 @@ - onnx-example:� -& + + onnx-example:� +, X -PadsY"Pad* +PadsYpad0"Pad* mode"constant� test-model*"BPadsj locationPads.binpZ @@ -16,4 +17,4 @@ test-model*"BPadsj Y   -B \ No newline at end of file +B \ No newline at end of file diff --git a/onnxruntime/test/testdata/model_with_external_initializers.py b/onnxruntime/test/testdata/model_with_external_initializers.py index 8d2589a9e6564..dc64d4a41424a 100644 --- a/onnxruntime/test/testdata/model_with_external_initializers.py +++ b/onnxruntime/test/testdata/model_with_external_initializers.py @@ -35,9 +35,10 @@ def GenerateModel(model_name, external_data_name): # noqa: N802 # Create a node (NodeProto) node_def = helper.make_node( - "Pad", # node name + "Pad", # op type ["X", external_data_name], # inputs ["Y"], # outputs + "pad0", # node name mode="constant", # Attributes ) diff --git a/onnxruntime/test/testdata/model_with_orig_ext_data.onnx b/onnxruntime/test/testdata/model_with_orig_ext_data.onnx index 6f9cce0bc5b4f..47d0c68235099 100644 --- a/onnxruntime/test/testdata/model_with_orig_ext_data.onnx +++ b/onnxruntime/test/testdata/model_with_orig_ext_data.onnx @@ -1,7 +1,8 @@ -  onnx-example:� -: + + onnx-example:� +@ X -model_with_orig_ext_dataY"Pad* +model_with_orig_ext_dataYpad0"Pad* mode"constant� test-model*JBmodel_with_orig_ext_dataj( locationmodel_with_orig_ext_data.binpZ @@ -16,4 +17,4 @@ test-model*JBmodel_with_orig_ext_dataj( Y   -B \ No newline at end of file +B \ No newline at end of file diff --git a/onnxruntime/test/testdata/prepack/MatMul.Weight.bin b/onnxruntime/test/testdata/prepack/MatMul.Weight.bin new file mode 100644 index 0000000000000..0f8a571589c10 Binary files /dev/null and b/onnxruntime/test/testdata/prepack/MatMul.Weight.bin differ diff --git a/onnxruntime/test/testdata/prepack/model_with_external_initializers_and_prepack_kernel.py b/onnxruntime/test/testdata/prepack/model_with_external_initializers_and_prepack_kernel.py new file mode 100644 index 0000000000000..86af461edc2c4 --- /dev/null +++ b/onnxruntime/test/testdata/prepack/model_with_external_initializers_and_prepack_kernel.py @@ -0,0 +1,88 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import numpy as np +import onnx +from onnx import TensorProto, helper +from onnx.external_data_helper import set_external_data +from onnx.numpy_helper import from_array + +M = 1 +K = 1 +N = 1 +q_cols = 1 +q_rows = 1 +q_scale_size = 1 + + +def create_external_data_tensor(value, tensor_name, data_type): + tensor = from_array(np.array(value)) + tensor.name = tensor_name + tensor_filename = f"{tensor_name}.bin" + set_external_data(tensor, location=tensor_filename) + + with open(os.path.join(tensor_filename), "wb") as data_file: + data_file.write(tensor.raw_data) + tensor.ClearField("raw_data") + tensor.data_location = onnx.TensorProto.EXTERNAL + tensor.data_type = data_type + return tensor + + +def create_internal_data_tensor(value, tensor_name, data_type): + tensor = helper.make_tensor(name=tensor_name, data_type=data_type, dims=value.shape, vals=value.flatten().tolist()) + print(tensor) + tensor.data_location = onnx.TensorProto.DEFAULT + return tensor + + +def GenerateMatmulNBitsModel(model_name, external_data_name): # noqa: N802 + A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [M, K]) # noqa: N806 + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [M, N]) # noqa: N806 + + # Create a node (NodeProto) + node_def = helper.make_node( + op_type="MatMulNBits", # op type + inputs=["A", external_data_name, "scales"], # inputs + outputs=["Y"], # outputs + name="MatMul_0", # node name + domain="com.microsoft", # Custom domain for this operator + accuracy_level=4, # Attributes + bits=4, # Attributes + block_size=32, # Attributes + K=K, # Attributes + N=N, # Attributes + ) + + # Create the graph (GraphProto) + graph_def = helper.make_graph( + [node_def], + "test-model-matmul4bits", + [A], + [Y], + [ + create_external_data_tensor([[171]], external_data_name, TensorProto.UINT8), + create_internal_data_tensor(np.array([1.5], dtype=np.float32), "scales", TensorProto.FLOAT), + ], + ) + + # Create the model + model_def = helper.make_model( + graph_def, + producer_name="onnx-example", + opset_imports=[helper.make_operatorsetid("", 14), helper.make_operatorsetid("com.microsoft", 1)], + ) + + print(f"The ir_version in model: {model_def.ir_version}\n") + print(f"The producer_name in model: {model_def.producer_name}\n") + print(f"The graph in model:\n{model_def.graph}") + onnx.checker.check_model(model_def) + print("The model is checked!") + with open(model_name, "wb") as model_file: + model_file.write(model_def.SerializeToString()) + + +if __name__ == "__main__": + GenerateMatmulNBitsModel("model_with_matmul_nbits.onnx", "MatMul.Weight") diff --git a/onnxruntime/test/testdata/prepack/model_with_matmul_nbits.onnx b/onnxruntime/test/testdata/prepack/model_with_matmul_nbits.onnx new file mode 100644 index 0000000000000..0e06a75a5a7e8 Binary files /dev/null and b/onnxruntime/test/testdata/prepack/model_with_matmul_nbits.onnx differ diff --git a/orttraining/orttraining/models/bert/main.cc b/orttraining/orttraining/models/bert/main.cc index c4c7a98ba116a..ec7a458237c77 100644 --- a/orttraining/orttraining/models/bert/main.cc +++ b/orttraining/orttraining/models/bert/main.cc @@ -42,6 +42,7 @@ static SessionOptions session_options = { ExecutionMode::ORT_SEQUENTIAL, // execution_mode ExecutionOrder::PRIORITY_BASED, // execution_order false, // enable_profiling + false, // save prepacked initializer ORT_TSTR(""), // optimized_model_filepath true, // enable_mem_pattern true, // enable_mem_reuse diff --git a/orttraining/orttraining/models/pipeline_poc/main.cc b/orttraining/orttraining/models/pipeline_poc/main.cc index 1b7d6b9ea26f6..0e40d04ddac8c 100644 --- a/orttraining/orttraining/models/pipeline_poc/main.cc +++ b/orttraining/orttraining/models/pipeline_poc/main.cc @@ -89,6 +89,7 @@ int main(int argc, char* argv[]) { ExecutionMode::ORT_SEQUENTIAL, // execution_mode ExecutionOrder::DEFAULT, // execution_order false, // enable_profiling + false, // save prepacked initializer ORT_TSTR(""), // optimized_model_filepath true, // enable_mem_pattern true, // enable_mem_reuse diff --git a/orttraining/orttraining/models/runner/training_runner.cc b/orttraining/orttraining/models/runner/training_runner.cc index dae6f613f4329..5a2f1cd13683e 100644 --- a/orttraining/orttraining/models/runner/training_runner.cc +++ b/orttraining/orttraining/models/runner/training_runner.cc @@ -37,6 +37,7 @@ static SessionOptions SESSION_OPTION = { ExecutionMode::ORT_SEQUENTIAL, // execution_mode ExecutionOrder::PRIORITY_BASED, // execution_order false, // enable_profiling + false, // save prepacked initializer ORT_TSTR(""), // optimized_model_filepath true, // enable_mem_pattern true, // enable_mem_reuse