diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs index b374371446a90..86b44a6784817 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs @@ -743,7 +743,7 @@ internal static OrtValue CreateFromTensorObject(TensorBase value, out TensorElem /// /// Creates an OrtValue that contains a string tensor of specified shape, and /// containing empty strings. String tensors are always on CPU. - /// Use FillStringTensorElement to assign individual elements values. + /// Use StringTensorSetElementAt to assign individual elements values. /// /// /// disposable OrtValue diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs index c52ca4d1a4631..ac790242409e3 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs @@ -15,6 +15,7 @@ public struct OrtTrainingApi public IntPtr LoadCheckpoint; public IntPtr SaveCheckpoint; public IntPtr CreateTrainingSession; + public IntPtr CreateTrainingSessionFromBuffer; public IntPtr TrainingSessionGetTrainingModelOutputCount; public IntPtr TrainingSessionGetEvalModelOutputCount; public IntPtr TrainingSessionGetTrainingModelOutputName; diff --git a/include/onnxruntime/core/framework/ort_value.h b/include/onnxruntime/core/framework/ort_value.h index 48c4e4320dfd7..a071f3182faad 100644 --- a/include/onnxruntime/core/framework/ort_value.h +++ b/include/onnxruntime/core/framework/ort_value.h @@ -68,11 +68,7 @@ struct OrtValue { } bool IsSparseTensor() const { -#if !defined(DISABLE_SPARSE_TENSORS) return (type_ != nullptr && type_->IsSparseTensorType()); -#else - ORT_THROW("Sparse tensor is not supported in this build."); -#endif } onnxruntime::MLDataType Type() const { diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 4a1109b9ec5dc..e33854819c5db 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -38,7 +38,7 @@ Do not modify directly.* | Floor | ai.onnx(6-12,13+) | | | Gather | ai.onnx(1-10,11-12,13+) | | | Gelu | com.microsoft(1+) | | -| Gemm | ai.onnx(7-8,9-10,11+) | | +| Gemm | ai.onnx(7-8,9-10,11-12,13+) | | | GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | | GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | | InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts index 54f493422816f..f5b8a7e3b0ef9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -23,10 +23,12 @@ const validateInputs = (inputs: readonly TensorView[]): void => { const createSplitAttributesFromInputs = (inputs: readonly TensorView[], attributes: SplitAttributes): SplitAttributes => { const splitSizes: number[] = []; + let numOutputs: number = attributes.numOutputs; if (inputs[1].dims[0] > 0) { inputs[1].getBigInt64Array().forEach(v => splitSizes.push(Number(v))); + numOutputs = splitSizes.length; } - return createAttributeWithCacheKey({numOutputs: attributes.numOutputs, axis: attributes.axis, splitSizes}); + return createAttributeWithCacheKey({numOutputs, axis: attributes.axis, splitSizes}); }; const calculateOutputIndexImpl = (numberOfTensors: number): string => ` @@ -114,7 +116,7 
@@ const createSplitProgramInfoLoader
       const updatedAttributes = inputs.length === 1 ? attributes : createSplitAttributesFromInputs(inputs, attributes);
   const metadata: ProgramMetadata =
       {name: 'Split', inputTypes: [GpuDataType.default], cacheHint: updatedAttributes.cacheKey};
-  return {...metadata, get: () => createSplitProgramInfo(metadata, [inputs[0]], attributes)};
+  return {...metadata, get: () => createSplitProgramInfo(metadata, [inputs[0]], updatedAttributes)};
 };
 
 export const split = (context: ComputeContext, attributes: SplitAttributes): void => {
diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc
index 300db24a986f4..0bf27fdf5e5dc 100644
--- a/onnxruntime/core/framework/allocation_planner.cc
+++ b/onnxruntime/core/framework/allocation_planner.cc
@@ -1715,31 +1715,39 @@ class PlannerImpl {
   void PartitionIntoStreams(const logging::Logger& /*logger*/, const ExecutionProviders& /*execution_providers*/,
                             const PathString& /*partition_config_file*/) {
-    stream_nodes_.push_back({});
-    node_stream_map_.resize(SafeInt<size_t>(graph_viewer_.MaxNodeIndex()) + 1);
-    for (auto node_index : graph_viewer_.GetNodesInTopologicalOrder()) {
-      stream_nodes_[0].push_back(node_index);
-      node_stream_map_[node_index] = 0;
+    if (graph_viewer_.NumberOfNodes() > 0) {
+      stream_nodes_.push_back({});
+      node_stream_map_.resize(SafeInt<size_t>(graph_viewer_.MaxNodeIndex()) + 1);
+      for (auto node_index : graph_viewer_.GetNodesInTopologicalOrder()) {
+        stream_nodes_[0].push_back(node_index);
+        node_stream_map_[node_index] = 0;
+      }
+      num_logic_streams_ = 1;
     }
-    num_logic_streams_ = 1;
   }
 
   Status BuildExecutionPlan(const ExecutionProviders& execution_providers) {
     // 1. create logic stream instance
     auto& execution_plan = plan_.execution_plan;
-    ORT_ENFORCE(num_logic_streams_ == 1 && !stream_nodes_[0].empty());
-    execution_plan.reserve(1);
-    auto first_node_index = stream_nodes_[0][0];
-    auto* node = graph_viewer_.GetNode(first_node_index);
-    onnxruntime::ProviderType exec_provider_name = node->GetExecutionProviderType();
-    const IExecutionProvider* ep = execution_providers.Get(exec_provider_name);
-    ORT_ENFORCE(ep);
-    auto node_device_mem_location = ep->GetOrtDeviceByMemType(OrtMemType::OrtMemTypeDefault);
-    execution_plan.emplace_back(std::make_unique<SequentialExecutionPlan::LogicStream>(node_device_mem_location));
-    // 2. add steps to the execution plan
-    for (auto node_index : stream_nodes_[0]) {
-      execution_plan[0]->steps_.emplace_back(std::make_unique<LaunchKernelStep>(node_index));
+
+    if (graph_viewer_.NumberOfNodes() > 0) {
+      ORT_ENFORCE(num_logic_streams_ == 1 && !stream_nodes_[0].empty());
+      execution_plan.reserve(1);
+      auto first_node_index = stream_nodes_[0][0];
+      auto* node = graph_viewer_.GetNode(first_node_index);
+      onnxruntime::ProviderType exec_provider_name = node->GetExecutionProviderType();
+      const IExecutionProvider* ep = execution_providers.Get(exec_provider_name);
+      ORT_ENFORCE(ep);
+      auto node_device_mem_location = ep->GetOrtDeviceByMemType(OrtMemType::OrtMemTypeDefault);
+      execution_plan.emplace_back(std::make_unique<SequentialExecutionPlan::LogicStream>(node_device_mem_location));
+      // 2. add steps to the execution plan
+      for (auto node_index : stream_nodes_[0]) {
+        execution_plan[0]->steps_.emplace_back(std::make_unique<LaunchKernelStep>(node_index));
+      }
+    } else {
+      // Graph with no nodes, e.g. a subgraph of an If node may return the input as-is or a constant value from an initializer.
     }
+
     return Status::OK();
   }
diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc
index fc2f14263f7a7..df3a7afebc176 100644
--- a/onnxruntime/core/framework/session_state_utils.cc
+++ b/onnxruntime/core/framework/session_state_utils.cc
@@ -254,10 +254,11 @@ common::Status SaveInitializedTensors(
   auto initialized_tensors_to_allocate = id_to_initialized_tensor;
   for (int ort_value_index : initializer_allocation_order) {
     const auto entry = initialized_tensors_to_allocate.find(ort_value_index);
+    ORT_ENFORCE(entry != initialized_tensors_to_allocate.end(),
+                "OrtValue index: ", ort_value_index, " from initializer_allocation_order not found among initialized tensors");
     if (!(utils::HasExternalData(*entry->second) && exec_plan.GetLocation(ort_value_index).Type() == OrtDevice::CPU)) {
       // can not trace string tensor
-      ORT_ENFORCE(entry != initialized_tensors_to_allocate.end() &&
-                  entry->second->data_type() != ONNX_NAMESPACE::TensorProto_DataType_STRING);
+      ORT_ENFORCE(entry->second->data_type() != ONNX_NAMESPACE::TensorProto_DataType_STRING, "Can not trace string tensor");
       ORT_RETURN_IF_ERROR(planner.Trace(entry->first, entry->second));
     }
     initialized_tensors_to_allocate.erase(entry);
diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h
index 56f41154b719c..ea6a629f87cb8 100644
--- a/onnxruntime/core/framework/utils.h
+++ b/onnxruntime/core/framework/utils.h
@@ -223,7 +223,7 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType
 
 int32_t ONNXTensorElementDataTypeToProtoTensorType(ONNXTensorElementDataType);
 
-#ifdef ENABLE_TRAINING_CORE
+#ifdef ENABLE_TRAINING
 common::Status VerifyInputTensorsAllocatedContiguously(OpKernelContext* context);
 #endif
diff --git a/onnxruntime/core/graph/function_utils.cc b/onnxruntime/core/graph/function_utils.cc
index 4b7900194488d..aa0727e3750b0 100644
--- a/onnxruntime/core/graph/function_utils.cc
+++ b/onnxruntime/core/graph/function_utils.cc
@@ -344,10 +344,15 @@ std::unique_ptr CreateSchema(const std::string& functi
     std::unordered_map<std::string, const ONNX_NAMESPACE::FunctionProto*> map_copy(model_local_functions.begin(),
                                                                                    model_local_functions.end());
     std::unordered_map<std::string, ONNX_NAMESPACE::TensorShapeProto> empty_map;
-    ONNX_NAMESPACE::shape_inference::SymbolTableImpl symbolTable;
+
+    // https://github.com/microsoft/onnxruntime/issues/17061
+    // We are passing a nullptr for the symbol table, because the symbol table must be global
+    // for all of the shape inferencing to work correctly. Otherwise, unrelated shapes get
+    // the same symbolic shapes and are marked for memory re-use. This is a temporary fix.
+    constexpr ONNX_NAMESPACE::shape_inference::SymbolTableImpl* symbolTable = nullptr;
     ONNX_NAMESPACE::shape_inference::InferShapeForFunctionNode(*onnx_func_proto, func_domain_to_version,
                                                                schema_registry, ctx, options, map_copy,
-                                                               &symbolTable, &empty_map);
+                                                               symbolTable, &empty_map);
   });
 
   op_schema->Finalize();
diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.cc b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.cc
index a3ac4312053aa..dd38ee9b07ee6 100644
--- a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.cc
+++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.cc
@@ -462,6 +462,27 @@ bool LayerNormalizationGatherActor::PreCheck(const Graph& /* graph */,
   return true;
 }
 
+bool LayerNormalizationGatherActor::PostProcess(Graph& /*graph*/, Node& current_node,
+                                                const SliceInfo& info_without_node,
+                                                const logging::Logger& /*logger*/,
+                                                const std::unordered_map<int, int>& /*propagate_input_indices*/,
+                                                const std::unordered_map<int, std::vector<DimCompare>>&
+                                                /*all_input_cmp_rets*/,
+                                                const std::unordered_map<int, SliceInfo>& /*new_gather_infos*/) {
+  // Update LayerNormalization's axis attribute if it is a scalar slice.
+  if (info_without_node.is_scalar_slice) {
+    auto axis = static_cast<int64_t>(current_node.GetAttributes().at("axis").i());
+    auto original_ln_input_rank = info_without_node.input_rank;
+    axis = axis < 0 ? axis + original_ln_input_rank : axis;
+    auto new_axis = axis - 1;
+
+    auto& attributes = current_node.GetMutableAttributes();
+    attributes["axis"] = ONNX_NAMESPACE::MakeAttribute("axis", static_cast<int64_t>(new_axis));
+  }
+
+  return true;
+}
+
 bool SoftmaxGatherActor::PreCheck(const Graph& graph, const Node& current_node, const SliceInfo& info,
                                   const logging::Logger& logger,
                                   std::unordered_map<int, int>& propagate_input_indices,
@@ -479,6 +500,28 @@ bool SoftmaxGatherActor::PreCheck(const Graph& graph, const Node& current_node,
                                                 propagate_input_indices, all_input_cmp_rets, shape_update_func);
 }
 
+bool SoftmaxGatherActor::PostProcess(Graph& graph, Node& current_node, const SliceInfo& info_without_node,
+                                     const logging::Logger& logger,
+                                     const std::unordered_map<int, int>& propagate_input_indices,
+                                     const std::unordered_map<int, std::vector<DimCompare>>& all_input_cmp_rets,
+                                     const std::unordered_map<int, SliceInfo>& new_gather_infos) {
+  SimplePointwiseGatherActor<true>::PostProcess(graph, current_node, info_without_node, logger,
+                                                propagate_input_indices, all_input_cmp_rets, new_gather_infos);
+
+  // Update Softmax's axis attribute if it is a scalar slice.
+  if (info_without_node.is_scalar_slice) {
+    auto axis = static_cast<int64_t>(current_node.GetAttributes().at("axis").i());
+    auto original_ln_input_rank = info_without_node.input_rank;
+    axis = axis < 0 ? axis + original_ln_input_rank : axis;
+    auto new_axis = axis - 1;
+
+    auto& attributes = current_node.GetMutableAttributes();
+    attributes["axis"] = ONNX_NAMESPACE::MakeAttribute("axis", static_cast<int64_t>(new_axis));
+  }
+
+  return true;
+}
+
 bool ReshapeGatherActor::PreCheck(const Graph& graph, const Node& current_node, const SliceInfo& info,
                                   const logging::Logger& logger,
                                   std::unordered_map<int, int>& propagate_input_indices,
@@ -566,6 +609,11 @@ bool ReshapeGatherActor::PreCheck(const Graph& graph, const Node& current_node,
     return true;
   }
 
+  LOG_DEBUG_INFO(logger, "Skip handling the Reshape, new_shape_const_values[info.non_negative_axis]:" +
+                             std::to_string(new_shape_const_values[info.non_negative_axis]) +
+                             ", info.output_dim_on_axis.has_dim_value(): " +
+                             std::to_string(info.output_dim_on_axis.has_dim_value()) + ".");
+
   return false;
 }
 
@@ -604,11 +652,12 @@ bool ReshapeGatherActor::PostProcess(
     return true;
   }
 
-  // If it selected shape is a dim value, we can update the shape tensor directory.
+  // If the selected shape is a dim value, we can update the shape tensor directly.
   if (info_without_node.output_dim_on_axis.has_dim_value()) {
     new_shape_const_values[slice_axis] = info_without_node.output_dim_on_axis.dim_value();
     auto new_shape_arg =
-        CreateInitializerFromVector(graph, {static_cast<int64_t>(new_shape_const_values.size())}, new_shape_const_values,
+        CreateInitializerFromVector(graph, {static_cast<int64_t>(new_shape_const_values.size())},
+                                    new_shape_const_values,
                                     graph.GenerateNodeArgName(current_node.MutableInputDefs()[1]->Name()));
     graph_utils::ReplaceNodeInput(current_node, 1, *new_shape_arg);
     return true;
diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.h b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.h
index f6715e4bb1f32..0c21be1397636 100644
--- a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.h
+++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.h
@@ -189,7 +189,7 @@ class LayerNormalizationGatherActor : public UpStreamGatherOperatorActorBase {
                    const logging::Logger& /* logger */,
                    const std::unordered_map<int, int>& /* propagate_input_indices */,
                    const std::unordered_map<int, std::vector<DimCompare>>& /* all_input_cmp_rets */,
-                   const std::unordered_map<int, SliceInfo>& /* new_gather_infos */) override { return true; }
+                   const std::unordered_map<int, SliceInfo>& /* new_gather_infos */) override;
 };
 
 class SoftmaxGatherActor : public SimplePointwiseGatherActor<true> {
@@ -202,6 +202,12 @@ class SoftmaxGatherActor : public SimplePointwiseGatherActor<true> {
                 std::unordered_map<int, int>& propagate_input_indices,
                 std::unordered_map<int, std::vector<DimCompare>>& all_input_cmp_rets,
                 std::function<void(Node& node)>& shape_update_func) override;
+
+  bool PostProcess(Graph& /* graph */, Node& /* current_node */, const SliceInfo& /* info_without_node */,
+                   const logging::Logger& /* logger */,
+                   const std::unordered_map<int, int>& /* propagate_input_indices */,
+                   const std::unordered_map<int, std::vector<DimCompare>>& /* all_input_cmp_rets */,
+                   const std::unordered_map<int, SliceInfo>& /* new_gather_infos */) override;
 };
 
 class ReshapeGatherActor : public UpStreamGatherOperatorActorBase {
diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc
index d9c6126d4bf36..3fe1980141ca5 100644
--- a/onnxruntime/core/providers/js/js_execution_provider.cc
+++ b/onnxruntime/core/providers/js/js_execution_provider.cc
@@ -231,7 +231,8 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnn
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, ConvTranspose);
 class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, float, Gemm); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, float, Gemm); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, Gemm); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Gemm); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Gemm); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, MatMul); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, MatMul); @@ -464,7 +465,8 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/js/operators/gemm.cc b/onnxruntime/core/providers/js/operators/gemm.cc index f579d62bdfb5f..04700d0f54705 100644 --- a/onnxruntime/core/providers/js/operators/gemm.cc +++ b/onnxruntime/core/providers/js/operators/gemm.cc @@ -12,7 +12,15 @@ namespace js { ONNX_OPERATOR_TYPED_KERNEL_EX( \ Gemm, \ kOnnxDomain, \ - 11, \ + 13, \ + T, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Gemm); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Gemm, \ + kOnnxDomain, \ + 11, 12, \ T, \ kJsExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ diff --git a/onnxruntime/core/providers/js/operators/split.h b/onnxruntime/core/providers/js/operators/split.h index 691af48711a56..cfacc1aa6a363 100644 --- a/onnxruntime/core/providers/js/operators/split.h +++ b/onnxruntime/core/providers/js/operators/split.h @@ -25,8 +25,9 @@ class Split : public JsKernel, public SplitBase { if (num_outputs_ < 0) { num_outputs_ = split_sizes.size(); } - } else if (split_sizes_.size() == 0) { - // Compute split_sizes from input shape and num_outputs + } else if (split_sizes_.size() == 0 && info.GetInputCount() < 2) { + // Compute split_sizes from input shape and num_outputs. + // TODO: Shape might not be known at this point, better to handle this in javascript auto total_split_size = info.node().InputDefs()[0]->Shape()->dim(gsl::narrow_cast(axis_)).dim_value(); int64_t split_size_sum = 0; if (num_outputs_ < 0) { @@ -44,6 +45,7 @@ class Split : public JsKernel, public SplitBase { ORT_ENFORCE(split_size_sum == total_split_size, "Sum of split sizes (", split_size_sum, ") does not match input size (", total_split_size, ")"); } + // else: let javascript handle all other cases, ie. 
split_sizes come as input[1]
 
     JSEP_INIT_KERNEL_ATTRIBUTE(Split, ({"axis" : $1,
                                         "numOutputs" : $2,
diff --git a/onnxruntime/core/providers/vitisai/imp/graph.cc b/onnxruntime/core/providers/vitisai/imp/graph.cc
index 4df11c2224e27..b5f45b15a5992 100644
--- a/onnxruntime/core/providers/vitisai/imp/graph.cc
+++ b/onnxruntime/core/providers/vitisai/imp/graph.cc
@@ -138,6 +138,7 @@ void graph_save(const Graph& graph, const std::string& filename, const std::stri
     model_proto = model.ToProto();
   } else {
     model_proto = model.ToGraphProtoWithExternalInitializers(filename_dat,
+                                                             ToPathString(filename),
                                                              initializer_size_threshold);
   }
   auto& metadata = model.MetaData();
diff --git a/onnxruntime/test/framework/ort_model_only_test.cc b/onnxruntime/test/framework/ort_model_only_test.cc
index f8da4e895913a..e2cb82e47f32b 100644
--- a/onnxruntime/test/framework/ort_model_only_test.cc
+++ b/onnxruntime/test/framework/ort_model_only_test.cc
@@ -4,6 +4,7 @@
 #include "core/flatbuffers/schema/ort.fbs.h"
 #include "core/framework/data_types.h"
 #include "core/framework/tensorprotoutils.h"
+#include "core/framework/TensorSeq.h"
 #include "core/graph/model.h"
 #include "core/graph/onnx_protobuf.h"
 #include "core/session/onnxruntime_cxx_api.h"
@@ -556,6 +557,41 @@ TEST(OrtModelOnlyTests, LoadOrtFormatModelFromBufferNoCopyInitializersUseBuffer)
   RunOrtModel(test_info);
 }
 
+// regression test for 2 issues covered by PR #17000 (internally reported issue).
+// 1) allocation planner broke in minimal build when subgraph had no nodes.
+// 2) usage of a sequence data type caused an exception due to IsSparseTensor() throwing
+//    instead of allowing the calling code to have #ifdef'd code to handle when IsSparseTensor
+//    returned true and sparse tensors were disabled.
+TEST(OrtModelOnlyTests, GithubIssue17000) {
+  // need to run the model to trigger both issues
+  auto model_uri = ORT_TSTR("testdata/ort_github_issue_17000.ort");
+
+  auto allocator = TestCPUExecutionProvider()->CreatePreferredAllocators()[0];
+
+  OrtValue item0, item1;
+  CreateMLValue<float>(allocator, {1}, {1.f}, &item0);
+  CreateMLValue<float>(allocator, {2}, {2.f, 3.f}, &item1);
+
+  auto elem_type = DataTypeImpl::GetType<float>();
+  auto tensor_seq = std::make_unique<TensorSeq>(elem_type);
+  tensor_seq->SetElements({item0, item1});
+
+  auto mltype = DataTypeImpl::GetType<TensorSeq>();
+  OrtValue value(tensor_seq.release(), mltype, mltype->GetDeleteFunc());
+
+  OrtModelTestInfo test_info;
+  test_info.model_filename = model_uri;
+  test_info.inputs.insert(std::make_pair("seq_in", value));
+  test_info.output_names = {"still_has_elements"};
+  test_info.output_verifier = [](const std::vector<OrtValue>& fetches) {
+    const auto& output = fetches[0].Get<Tensor>();
+    ASSERT_EQ(output.Shape().Size(), 1);
+    ASSERT_EQ(output.Data<bool>()[0], true);  // removed one item from seq so should still have elements
+  };
+
+  RunOrtModel(test_info);
+}
+
 #if !defined(DISABLE_ML_OPS)
 // test that we can deserialize and run a previously saved ORT format model
 // for a model with sequence and map outputs
diff --git a/onnxruntime/test/optimizer/compute_optimizer_test.cc b/onnxruntime/test/optimizer/compute_optimizer_test.cc
index 01016774288e4..a03d0da2538d4 100644
--- a/onnxruntime/test/optimizer/compute_optimizer_test.cc
+++ b/onnxruntime/test/optimizer/compute_optimizer_test.cc
@@ -638,7 +638,8 @@ TEST(ComputeOptimizerTests, GatherMatMul_ScalarSlicingOnSecondLastDim) {
   std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
   onnxruntime::GraphTransformerManager graph_transformation_mgr{1};
-  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), 
TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), + TransformerLevel::Level1)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger)); GraphViewer graph_viewer(graph); @@ -737,7 +738,8 @@ TEST(ComputeOptimizerTests, GatherMatMul_SlicingOnSecondLastDim) { std::map op_to_count = CountOpsInGraph(graph); onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), + TransformerLevel::Level1)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger)); GraphViewer graph_viewer(graph); @@ -826,6 +828,345 @@ TEST(ComputeOptimizerTests, GatherMatMul_SlicingOnSecondLastDim) { } } +/* +Test graph includes multiple equivalent subgraphs as below. + graph input [2, 32, 256] (float) + | + LayerNormalization[axis=-1 (as example)] + | + [2, 32, 256] + | + | 0 (scalar) + | / + Gather[axis=1] + | + Identity + | + graph output [2, 256] (float) + +Add an Identity node because currently, we don't allow Gather generates graph output. +*/ +TEST(ComputeOptimizerTests, GatherLayerNormalization) { + std::vector> test_config_pairs{ + // { + // is_scalar_slice, + // ln_axis_before_propagation, + // expected_ln_axis_after_propagation, + // expected to propagate + // } + {true, 0, 0, false}, + {true, 1, 1, false}, + {true, 2, 1, true}, + {true, -3, -3, false}, + {true, -2, -2, false}, + {true, -1, 1, true}, + {false, 0, 0, false}, + {false, 1, 1, false}, + {false, 2, 2, true}, + {false, -3, -3, false}, + {false, -2, -2, false}, + {false, -1, -1, true}, + }; + + constexpr static int64_t gather_axis = 1; + constexpr static int64_t slice_data_value = 0; + + for (auto p : test_config_pairs) { + bool is_scalar_slice = std::get<0>(p); + int64_t ln_axis_before = std::get<1>(p); + int64_t ln_axis_after = std::get<2>(p); + bool expected_to_propagate = std::get<3>(p); + + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + + InlinedVector indices; + auto pre_graph_checker = [&indices](Graph& graph) -> Status { + auto op_count_pre = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_pre.size() == 3U); + TEST_RETURN_IF_NOT(op_count_pre["LayerNormalization"] == 1); + TEST_RETURN_IF_NOT(op_count_pre["Gather"] == 1); + TEST_RETURN_IF_NOT(op_count_pre["Identity"] == 1); + + for (Node& node : graph.Nodes()) { + if (node.OpType() == "Gather") { + TEST_RETURN_IF_NOT(indices.empty()); + constexpr bool require_constant = true; + NodeArg* initializer_node_arg = graph.GetNodeArg(node.InputDefs()[1]->Name()); + TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *initializer_node_arg, + indices, require_constant)); + } + } + return Status::OK(); + }; + + auto post_graph_checker = [is_scalar_slice, ln_axis_after, + &indices, expected_to_propagate](Graph& graph) { + auto op_count_post = CountOpsInGraph(graph); + + TEST_RETURN_IF_NOT(op_count_post.size() == 3U); + TEST_RETURN_IF_NOT(op_count_post["LayerNormalization"] == 1); + TEST_RETURN_IF_NOT(op_count_post["Gather"] == 1); + TEST_RETURN_IF_NOT(op_count_post["Identity"] == 1); + + for (Node& node : graph.Nodes()) { + if (node.OpType() == "LayerNormalization") { + const auto& input_defs = node.InputDefs(); + + auto producer_node = graph.GetProducerNode(input_defs[0]->Name()); + if (expected_to_propagate) { + 
TEST_RETURN_IF_NOT(producer_node != nullptr); + TEST_RETURN_IF_NOT(producer_node->OpType() == "Gather"); + + InlinedVector values; + constexpr bool require_constant = true; + NodeArg* initializer_node_arg = graph.GetNodeArg(producer_node->InputDefs()[1]->Name()); + TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *initializer_node_arg, + values, require_constant)); + for (size_t i = 0; i < values.size(); i++) { + TEST_RETURN_IF_NOT(values[i] == indices[i]); + } + + const ONNX_NAMESPACE::TensorShapeProto* slice_out_shape = producer_node->OutputDefs()[0]->Shape(); + TEST_RETURN_IF_NOT(slice_out_shape != nullptr); + + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + + auto& axis_attr = attrs.at("axis"); + auto axis_value = (int)axis_attr.i(); + TEST_RETURN_IF_NOT(axis_value == ln_axis_after); + + if (is_scalar_slice) { + TEST_RETURN_IF_NOT(slice_out_shape->dim_size() == 2); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(0)) && + slice_out_shape->dim(0).dim_value() == 2); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(1)) && + slice_out_shape->dim(1).dim_value() == 256); + } else { + TEST_RETURN_IF_NOT(slice_out_shape->dim_size() == 3); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(0)) && + slice_out_shape->dim(0).dim_value() == 2); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(1)) && + slice_out_shape->dim(1).dim_value() == 1); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(2)) && + slice_out_shape->dim(2).dim_value() == 256); + } + + } else { + TEST_RETURN_IF_NOT(producer_node == nullptr); + } + } + } + + return Status::OK(); + }; + + auto build_test_case = [is_scalar_slice, ln_axis_before](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput({{2, 32, 256}}); + auto* input2_arg = builder.MakeInput({{256}}); + auto* input3_arg = builder.MakeInput({{256}}); + auto* ln_out = builder.MakeIntermediate(); + builder.AddNode("LayerNormalization", {input1_arg, input2_arg, input3_arg}, {ln_out}) + .AddAttribute("axis", ln_axis_before); + + std::vector slice_inputs; + NodeArg* indices_initializer = nullptr; + + if (is_scalar_slice) { + indices_initializer = builder.MakeScalarInitializer(slice_data_value); + } else { + indices_initializer = builder.MakeInitializer({1}, {slice_data_value}); + } + + slice_inputs = {ln_out, indices_initializer}; + + auto* gather_out = builder.MakeIntermediate(); + builder.AddNode("Gather", slice_inputs, + {gather_out}) + .AddAttribute("axis", gather_axis); + + auto* identity_out = builder.MakeOutput(); + builder.AddNode("Identity", {gather_out}, {identity_out}); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger, std::move(transformer), + TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); + } +} + +/* +Test graph includes multiple equivalent subgraphs as below. + graph input [2, 4, 32, 256] (float) + | + Softmax[axis=3 (as example)] + | + [2, 4, 32, 256] + | + | 0 (scalar) + | / + Gather[axis=1] + | + Identity + | + graph output [2, 32, 256] (float) + +Add an Identity node because currently, we don't allow Gather generates graph output. 
+*/ +TEST(ComputeOptimizerTests, GatherSoftmax) { + std::vector> test_config_pairs{ + // {is_scalar_slice, softmax_axis_before_propagation, + // expected_softmax_axis_after_propagation, expected to propagate} + {true, 0, 0, false}, + {true, 1, 1, false}, + {true, 2, 1, true}, + {true, 3, 2, true}, + {true, -4, -4, false}, + {true, -3, -3, false}, + {true, -2, 1, true}, + {true, -1, 2, true}, + {false, 0, 0, false}, + {false, 1, 1, false}, + {false, 2, 2, true}, + {false, 3, 3, true}, + {false, -4, -4, false}, + {false, -3, -3, false}, + {false, -2, -2, true}, + {false, -1, -1, true}, + }; + + constexpr static int64_t gather_axis = 1; + constexpr static int64_t slice_data_value = 0; + + for (auto p : test_config_pairs) { + bool is_scalar_slice = std::get<0>(p); + int64_t softmax_axis_before = std::get<1>(p); + int64_t softmax_axis_after = std::get<2>(p); + bool expected_to_propagate = std::get<3>(p); + + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + + InlinedVector indices; + auto pre_graph_checker = [&indices](Graph& graph) -> Status { + auto op_count_pre = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_pre.size() == 3U); + TEST_RETURN_IF_NOT(op_count_pre["Softmax"] == 1); + TEST_RETURN_IF_NOT(op_count_pre["Gather"] == 1); + TEST_RETURN_IF_NOT(op_count_pre["Identity"] == 1); + + for (Node& node : graph.Nodes()) { + if (node.OpType() == "Gather") { + TEST_RETURN_IF_NOT(indices.empty()); + constexpr bool require_constant = true; + NodeArg* initializer_node_arg = graph.GetNodeArg(node.InputDefs()[1]->Name()); + TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *initializer_node_arg, + indices, require_constant)); + } + } + return Status::OK(); + }; + + auto post_graph_checker = [is_scalar_slice, softmax_axis_after, + &indices, expected_to_propagate](Graph& graph) { + auto op_count_post = CountOpsInGraph(graph); + + TEST_RETURN_IF_NOT(op_count_post.size() == 3U); + TEST_RETURN_IF_NOT(op_count_post["Softmax"] == 1); + TEST_RETURN_IF_NOT(op_count_post["Gather"] == 1); + TEST_RETURN_IF_NOT(op_count_post["Identity"] == 1); + + for (Node& node : graph.Nodes()) { + if (node.OpType() == "Softmax") { + const auto& input_defs = node.InputDefs(); + + auto producer_node = graph.GetProducerNode(input_defs[0]->Name()); + if (expected_to_propagate) { + TEST_RETURN_IF_NOT(producer_node != nullptr); + TEST_RETURN_IF_NOT(producer_node->OpType() == "Gather"); + + InlinedVector values; + constexpr bool require_constant = true; + NodeArg* initializer_node_arg = graph.GetNodeArg(producer_node->InputDefs()[1]->Name()); + TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *initializer_node_arg, values, + require_constant)); + for (size_t i = 0; i < values.size(); i++) { + TEST_RETURN_IF_NOT(values[i] == indices[i]); + } + + const ONNX_NAMESPACE::TensorShapeProto* slice_out_shape = producer_node->OutputDefs()[0]->Shape(); + TEST_RETURN_IF_NOT(slice_out_shape != nullptr); + + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + + auto& axis_attr = attrs.at("axis"); + auto axis_value = (int)axis_attr.i(); + TEST_RETURN_IF_NOT(axis_value == softmax_axis_after); + + if (is_scalar_slice) { + TEST_RETURN_IF_NOT(slice_out_shape->dim_size() == 3); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(0)) && + slice_out_shape->dim(0).dim_value() == 2); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(1)) && + slice_out_shape->dim(1).dim_value() == 32); + 
TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(2)) && + slice_out_shape->dim(2).dim_value() == 256); + } else { + TEST_RETURN_IF_NOT(slice_out_shape->dim_size() == 4); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(0)) && + slice_out_shape->dim(0).dim_value() == 2); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(1)) && + slice_out_shape->dim(1).dim_value() == 1); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(2)) && + slice_out_shape->dim(2).dim_value() == 32); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(3)) && + slice_out_shape->dim(3).dim_value() == 256); + } + + } else { + TEST_RETURN_IF_NOT(producer_node == nullptr); + } + } + } + + return Status::OK(); + }; + + auto build_test_case = [is_scalar_slice, softmax_axis_before](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput({{2, 4, 32, 256}}); + auto* softmax_out = builder.MakeIntermediate(); + builder.AddNode("Softmax", {input1_arg}, {softmax_out}) + .AddAttribute("axis", softmax_axis_before); + + std::vector slice_inputs; + + NodeArg* indices_initializer = nullptr; + + if (is_scalar_slice) { + indices_initializer = builder.MakeScalarInitializer(slice_data_value); + } else { + indices_initializer = builder.MakeInitializer({1}, {slice_data_value}); + } + + slice_inputs = {softmax_out, indices_initializer}; + + auto* gather_out = builder.MakeIntermediate(); + builder.AddNode("Gather", slice_inputs, + {gather_out}) + .AddAttribute("axis", gather_axis); + + auto* identity_out = builder.MakeOutput(); + builder.AddNode("Identity", {gather_out}, {identity_out}); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger, std::move(transformer), + TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); + } +} + TEST(ComputeOptimizerTests, GatherReshape_ScalarSlicingOnBatchDim) { const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); auto model_uri = MODEL_FOLDER "computation_reduction/gather/gather_reshape_scalar_batch_dim.onnx"; @@ -835,7 +1176,8 @@ TEST(ComputeOptimizerTests, GatherReshape_ScalarSlicingOnBatchDim) { std::map op_to_count = CountOpsInGraph(graph); onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), + TransformerLevel::Level1)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger)); GraphViewer graph_viewer(graph); @@ -928,7 +1270,8 @@ TEST(ComputeOptimizerTests, GatherReshape_SlicingOnBatchDim) { std::map op_to_count = CountOpsInGraph(graph); onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), + TransformerLevel::Level1)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger)); GraphViewer graph_viewer(graph); diff --git a/onnxruntime/test/testdata/ort_github_issue_17000.onnx b/onnxruntime/test/testdata/ort_github_issue_17000.onnx new file mode 100644 index 0000000000000..8320c19cb6de4 Binary files /dev/null and b/onnxruntime/test/testdata/ort_github_issue_17000.onnx differ diff --git a/onnxruntime/test/testdata/ort_github_issue_17000.ort 
b/onnxruntime/test/testdata/ort_github_issue_17000.ort new file mode 100644 index 0000000000000..08d9826dd5346 Binary files /dev/null and b/onnxruntime/test/testdata/ort_github_issue_17000.ort differ diff --git a/onnxruntime/test/testdata/ort_github_issue_17000.py b/onnxruntime/test/testdata/ort_github_issue_17000.py new file mode 100644 index 0000000000000..43c10f5590212 --- /dev/null +++ b/onnxruntime/test/testdata/ort_github_issue_17000.py @@ -0,0 +1,77 @@ +import numpy as np +import onnx +from onnx import TensorProto, helper, numpy_helper + + +def order_repeated_field(repeated_proto, key_name, order): + order = list(order) + repeated_proto.sort(key=lambda x: order.index(getattr(x, key_name))) + + +def make_node(op_type, inputs, outputs, name=None, doc_string=None, domain=None, **kwargs): + node = helper.make_node(op_type, inputs, outputs, name, doc_string, domain, **kwargs) + if doc_string == "": + node.doc_string = "" + order_repeated_field(node.attribute, "name", kwargs.keys()) + return node + + +def make_graph(*args, doc_string=None, **kwargs): + graph = helper.make_graph(*args, doc_string=doc_string, **kwargs) + if doc_string == "": + graph.doc_string = "" + return graph + + +test_graph = make_graph( + name="test_graph", + # model input of a sequence type to test IsSparseTensor issue + inputs=[ + helper.make_tensor_sequence_value_info("seq_in", TensorProto.FLOAT, shape=None), + ], + outputs=[ + helper.make_tensor_value_info("still_has_elements", TensorProto.BOOL, shape=[]), + ], + initializer=[ + numpy_helper.from_array(np.array(0, dtype="int64"), name="i0"), + ], + nodes=[ + make_node("SequenceLength", inputs=["seq_in"], outputs=["seq_len"], name="get_seq_len"), + make_node("Greater", inputs=["seq_len", "i0"], outputs=["has_elements"], name="get_has_elements"), + # If node with one branch that has no nodes to test the allocation planner issue + # if sequence has elements: + # remove one + # output bool of whether it still has elements + # else: + # output false (gives us branch with no nodes) + make_node( + "If", + name="test_if", + inputs=["has_elements"], + outputs=["still_has_elements"], + then_branch=make_graph( + name="then", + inputs=[], + outputs=[helper.make_tensor_value_info("then_bool_out", TensorProto.BOOL, shape=[])], + nodes=[ + make_node("SequenceErase", inputs=["seq_in", "i0"], outputs=["seq_less_one"]), + make_node("SequenceLength", inputs=["seq_less_one"], outputs=["new_seq_len"]), + make_node("Greater", inputs=["new_seq_len", "i0"], outputs=["then_bool_out"]), + ], + ), + else_branch=make_graph( + name="else", + initializer=[numpy_helper.from_array(np.array(False, dtype="bool"), name="else_bool_out")], + inputs=[], + outputs=[helper.make_tensor_value_info("else_bool_out", TensorProto.BOOL, shape=[])], + nodes=[], + ), + ), + ], +) + +# Graph with Sequence operations and an If node that has a subgraph with no nodes +model = helper.make_model(opset_imports=[helper.make_operatorsetid("ai.onnx", 14)], ir_version=7, graph=test_graph) + +onnx.shape_inference.infer_shapes(model, strict_mode=True) +onnx.save(model, "ort_github_issue_17000.onnx") diff --git a/onnxruntime/test/testdata/required_ops.config b/onnxruntime/test/testdata/required_ops.config index ac9d46666e1b6..e70362bab4017 100644 --- a/onnxruntime/test/testdata/required_ops.config +++ b/onnxruntime/test/testdata/required_ops.config @@ -3,9 +3,9 @@ ai.onnx;7;Abs,Add,And,BatchNormalization,Concat,Conv,Dropout,Flatten,Foo,Gather, ai.onnx;8;Add,Conv,Flatten,Gemm,MatMul,MaxPool,Mul,Relu,Reshape 
ai.onnx;9;Abs,Add,BatchNormalization,Cast,Clip,Concat,Constant,ConstantOfShape,Conv,Div,Equal,Gather,Gemm,Identity,If,LayerNormalization,LeakyRelu,Loop,MatMul,Mul,Pow,ReduceMean,Relu,Reshape,Scan,Shape,Sigmoid,Slice,Softmax,Softsign,Sqrt,Squeeze,Sub,Tanh,Transpose,Unsqueeze ai.onnx;10;Add,Cast,Concat,ConstantOfShape,Div,Dropout,Erf,Expand,Gather,Greater,Identity,If,LayerNormalization,Loop,MatMul,Mul,Neg,NonZero,Pow,ReduceMean,ReduceSum,Shape,Sqrt,Squeeze,Sub,Tanh,Transpose,Unsqueeze -ai.onnx;11;Abs,Add,ArgMax,BatchNormalization,Cast,Clip,Concat,Constant,ConstantOfShape,Conv,Div,Equal,Exp,Expand,Flatten,Gather,Gemm,Identity,If,LayerNormalization,Log,Loop,MatMul,MatMulInteger,Max,Min,Mul,Neg,Pow,RandomUniform,Range,ReduceMean,ReduceSum,ReduceSumSquare,Relu,Reshape,Scan,SequenceConstruct,SequenceInsert,SequenceLength,Shape,Sigmoid,Slice,Softmax,Split,Sqrt,Squeeze,Sub,Sum,Tanh,Transpose,Unsqueeze,Where +ai.onnx;11;Abs,Add,ArgMax,BatchNormalization,Cast,Clip,Concat,Constant,ConstantOfShape,Conv,Div,Equal,Exp,Expand,Flatten,Gather,Gemm,Identity,If,LayerNormalization,Log,Loop,MatMul,MatMulInteger,Max,Min,Mul,Neg,Pow,RandomUniform,Range,ReduceMean,ReduceSum,ReduceSumSquare,Relu,Reshape,Scan,SequenceConstruct,SequenceErase,SequenceInsert,SequenceLength,Shape,Sigmoid,Slice,Softmax,Split,Sqrt,Squeeze,Sub,Sum,Tanh,Transpose,Unsqueeze,Where ai.onnx;12;Add,And,Cast,Concat,Constant,ConstantOfShape,Conv,CumSum,Div,Dropout,DynamicQuantizeLinear,Equal,Erf,Expand,Flatten,Gather,GatherND,Gemm,GlobalAveragePool,Greater,Identity,If,IsInf,LayerNormalization,Less,Loop,MatMul,MatMulInteger,Min,Mul,Not,Pad,Pow,RandomNormalLike,RandomUniform,ReduceMean,ReduceSum,Relu,Reshape,Shape,Slice,Softmax,SoftmaxCrossEntropyLoss,SparseSoftmaxCrossEntropy,Split,Sqrt,Squeeze,Sub,Tanh,Transpose,Unsqueeze,Where -ai.onnx;13;Abs,Add,Cast,Concat,ConstantOfShape,Conv,DequantizeLinear,DynamicQuantizeLinear,Equal,Expand,FooBar,FooBar_Attr,Gather,Identity,LayerNormalization,MatMul,MatMulInteger,Mul,Pad,Pow,QuantizeLinear,Range,ReduceSum,Reshape,Shape,Tanh,Transpose,Unsqueeze,Where +ai.onnx;13;Abs,Add,Cast,Concat,ConstantOfShape,Conv,DequantizeLinear,DynamicQuantizeLinear,Equal,Expand,FooBar,FooBar_Attr,Gather,Greater,Identity,If,LayerNormalization,MatMul,MatMulInteger,Mul,Pad,Pow,QuantizeLinear,Range,ReduceSum,Reshape,Shape,Tanh,Transpose,Unsqueeze,Where ai.onnx;14;Add,ArgMax,Cast,Conv,Identity,Relu,Sigmoid,Sub ai.onnx;314159;Add ai.onnx.contrib;1;StringLower diff --git a/onnxruntime/test/testdata/required_ops_and_types.config b/onnxruntime/test/testdata/required_ops_and_types.config index 17687906d7250..41f374214747b 100644 --- a/onnxruntime/test/testdata/required_ops_and_types.config +++ b/onnxruntime/test/testdata/required_ops_and_types.config @@ -1,9 +1,12 @@ # required ops and types for ORT format models in testdata -ai.onnx;1;Conv{"inputs": {"0": ["float"]}},Foo,Identity +ai.onnx;1;Conv{"inputs": {"0": ["float"]}} ai.onnx;5;Reshape ai.onnx;6;Relu{"inputs": {"0": ["float"]}} ai.onnx;7;Add{"inputs": {"0": ["float"]}},Gemm{"inputs": {"0": ["float"]}},Mul{"inputs": {"0": ["float"]}} ai.onnx;8;MaxPool{"inputs": {"0": ["float"]}},Sum{"inputs": {"0": ["float"]}} ai.onnx;9;Cast{"inputs": {"0": ["float"]}, "outputs": {"0": ["bool"]}} -ai.onnx;11;ArgMax{"inputs": {"0": ["float"]}},If,Loop +ai.onnx;10;QLinearConv{"inputs": {"0": ["uint8_t"]}} +ai.onnx;11;ArgMax{"inputs": {"0": ["float"]}},Clip{"inputs": {"0": ["float"]}},Conv{"inputs": {"0": ["float"]}},If,Loop,SequenceErase,SequenceLength +ai.onnx;13;DequantizeLinear{"inputs": {"0": 
["int32_t", "uint8_t"]}},Greater{"inputs": {"0": ["int64_t"]}},If,QuantizeLinear{"outputs": {"0": ["uint8_t"]}} ai.onnx.ml;1;ArrayFeatureExtractor,LinearClassifier,Normalizer,ZipMap +test;1;Foo diff --git a/onnxruntime/test/testdata/training_api/ort_format/checkpoint b/onnxruntime/test/testdata/training_api/ort_format/checkpoint new file mode 100644 index 0000000000000..ab35c9ad5acde Binary files /dev/null and b/onnxruntime/test/testdata/training_api/ort_format/checkpoint differ diff --git a/onnxruntime/test/testdata/training_api/ort_format/eval_model.ort b/onnxruntime/test/testdata/training_api/ort_format/eval_model.ort new file mode 100644 index 0000000000000..69b2c7e029de0 Binary files /dev/null and b/onnxruntime/test/testdata/training_api/ort_format/eval_model.ort differ diff --git a/onnxruntime/test/testdata/training_api/ort_format/optimizer_model.ort b/onnxruntime/test/testdata/training_api/ort_format/optimizer_model.ort new file mode 100644 index 0000000000000..88f192462362d Binary files /dev/null and b/onnxruntime/test/testdata/training_api/ort_format/optimizer_model.ort differ diff --git a/onnxruntime/test/testdata/training_api/ort_format/prepare_artifacts.py b/onnxruntime/test/testdata/training_api/ort_format/prepare_artifacts.py new file mode 100644 index 0000000000000..70e8c4ac011a9 --- /dev/null +++ b/onnxruntime/test/testdata/training_api/ort_format/prepare_artifacts.py @@ -0,0 +1,70 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""This file is used to generate test data for ort format model tests in + orttraining/orttraining/test/training_api/core/training_capi_tests.cc.""" + +import onnx +import torch +import torch.nn as nn + +from onnxruntime.training import artifacts + + +class SimpleNet(nn.Module): + def __init__(self, input_size, hidden_size, output_size): + super().__init__() + self.fc1 = nn.Linear(input_size, hidden_size) + self.relu = nn.ReLU() + self.fc2 = nn.Linear(hidden_size, output_size) + + def forward(self, x): + out = self.fc1(x) + out = self.relu(out) + out = self.fc2(out) + return out + + +def model_export(pt_model, model_path, input_size): + # Generate random input data + input_data = torch.randn(32, input_size) + torch.onnx.export( + pt_model, + input_data, + model_path, + input_names=["input"], + output_names=["output"], + dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}, + ) + + +def main(): + # Set the dimensions for input, hidden, and output layers + input_size = 10 + hidden_size = 20 + output_size = 5 + + # Create an instance of the neural network + pt_model = SimpleNet(input_size, hidden_size, output_size) + + train_model_path = "simplenet_training.onnx" + model_export(pt_model, train_model_path, input_size) + + onnx_model = onnx.load(train_model_path) + + requires_grad = ["fc2.weight", "fc2.bias"] + frozen_params = [param.name for param in onnx_model.graph.initializer if param.name not in requires_grad] + + # Generate the training artifacts. 
+ artifacts.generate_artifacts( + onnx_model, + requires_grad=requires_grad, + frozen_params=frozen_params, + loss=artifacts.LossType.CrossEntropyLoss, + optimizer=artifacts.OptimType.AdamW, + ort_format=True, + ) + + +if __name__ == "__main__": + main() diff --git a/onnxruntime/test/testdata/training_api/ort_format/training_model.ort b/onnxruntime/test/testdata/training_api/ort_format/training_model.ort new file mode 100644 index 0000000000000..94bda328a9f9f Binary files /dev/null and b/onnxruntime/test/testdata/training_api/ort_format/training_model.ort differ diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index eac17f3d4d2e8..3f3aa396e6ca0 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -174,10 +174,11 @@ struct PyOptimizer { PyOptimizer(const std::string optimizer_model_uri, onnxruntime::training::api::CheckpointState* state, std::vector> providers, PySessionOptions* session_options) : optimizer_() { + auto model_identifiers = onnxruntime::training::api::ModelIdentifiers("", std::nullopt, optimizer_model_uri); auto env = GetTrainingEnv().GetORTEnv(); // XXX: We hope that env will be around when optimizer needs it. optimizer_ = std::make_shared( - optimizer_model_uri, state, session_options->value, *env, providers, session_options->custom_op_domains_); + model_identifiers, state, session_options->value, *env, providers, session_options->custom_op_domains_); } std::shared_ptr optimizer_; @@ -941,9 +942,10 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn OrtDevice device, PySessionOptions* session_options) { std::vector> provider = GetExecutionProvidersForTrainingApis(device); auto env = GetTrainingEnv().GetORTEnv(); - return std::make_unique( - model_uri, state, session_options->value, *env, provider, eval_model_uri, - session_options->custom_op_domains_); + auto model_identifiers = onnxruntime::training::api::ModelIdentifiers(model_uri, eval_model_uri, std::nullopt); + return std::make_unique(model_identifiers, + state, session_options->value, *env, provider, + session_options->custom_op_domains_); })) .def("train_step", [](onnxruntime::training::api::Module* model, diff --git a/orttraining/orttraining/python/training/onnxblock/_training_graph_utils.py b/orttraining/orttraining/python/training/onnxblock/_training_graph_utils.py index 7b24bb400b162..1213342004d48 100644 --- a/orttraining/orttraining/python/training/onnxblock/_training_graph_utils.py +++ b/orttraining/orttraining/python/training/onnxblock/_training_graph_utils.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import copy +import logging import os from typing import List, Optional, Set, Tuple, Union @@ -70,13 +71,16 @@ def _move_initializers_to_inputs(model: onnx.ModelProto, initializer_names: Opti def _gradient_model_for( model: onnx.ModelProto, requires_grad: Set[str], - output_names: List[str], loss_name: str, options: Optional[SessionOptions] = None, ) -> onnx.ModelProto: """Builds the gradient graph on top of the given input forward only graph.""" - builder = GradientGraphBuilder(model.SerializeToString(), set(output_names), requires_grad, loss_name, options) + logging.debug( + "The loss output is %s. 
The gradient graph will be built starting from %s_grad.", loss_name, loss_name
+    )
+
+    builder = GradientGraphBuilder(model.SerializeToString(), {loss_name}, requires_grad, loss_name, options)
     builder.build()
 
     return onnx.load_from_string(builder.get_model())
@@ -123,7 +127,7 @@ def build_gradient_graph(
     optimized_model = onnx.load_from_string(get_optimized_model(model.SerializeToString(), requires_grad, options))
 
     # Assumption is that the first graph output is the loss output
-    gradient_model = _gradient_model_for(optimized_model, requires_grad, output_names, output_names[0], options)
+    gradient_model = _gradient_model_for(optimized_model, requires_grad, output_names[0], options)
 
     _reorder_outputs(gradient_model, output_names, requires_grad)
 
diff --git a/orttraining/orttraining/python/training/onnxblock/onnxblock.py b/orttraining/orttraining/python/training/onnxblock/onnxblock.py
index 9f90a5a0c30cd..a2922353ac70e 100644
--- a/orttraining/orttraining/python/training/onnxblock/onnxblock.py
+++ b/orttraining/orttraining/python/training/onnxblock/onnxblock.py
@@ -205,6 +205,8 @@ def __call__(self, *args, **kwargs):
             model, self._requires_grad, self._frozen_params, output, accessor._GLOBAL_CUSTOM_OP_LIBRARY
         )
 
+        logging.debug("Adding gradient accumulation nodes for training block %s", self.__class__.__name__)
+
         _training_graph_utils.build_gradient_accumulation_graph(self._training_model, self._requires_grad)
 
         accessor._GLOBAL_ACCESSOR.model.CopyFrom(self._training_model)
diff --git a/orttraining/orttraining/test/training_api/core/checkpoint_test.cc b/orttraining/orttraining/test/training_api/core/checkpoint_test.cc
index 4fa3844717ef9..1369c9c69865a 100644
--- a/orttraining/orttraining/test/training_api/core/checkpoint_test.cc
+++ b/orttraining/orttraining/test/training_api/core/checkpoint_test.cc
@@ -331,9 +331,12 @@ TEST(CheckpointApiTest, SaveOptimizerStateAsCheckpoint_ThenLoad) {
 #if defined(USE_CUDA)
   providers.push_back(onnxruntime::test::DefaultCudaExecutionProvider());
 #endif
-  auto model = std::make_unique<Module>(model_uri, &state, session_option,
+  auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri),
+                                           std::nullopt,
+                                           std::optional<std::string>(onnxruntime::ToUTF8String(optim_uri)));
+  auto model = std::make_unique<Module>(model_identifier, &state, session_option,
                                         *env, providers);
-  auto optimizer = std::make_unique<Optimizer>(optim_uri, &state, session_option,
+  auto optimizer = std::make_unique<Optimizer>(model_identifier, &state, session_option,
                                                *env, providers);
 
   // Remove the temporary directory if it already exists.
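For orientation, the training-API changes in this and the following test files all construct a `ModelIdentifiers` value that bundles the training model with optional eval and optimizer models. The following is a minimal sketch of the shape such a helper takes, inferred only from its usage in this diff rather than taken from the ORT headers (the real type, per the buffer-loading tests further down, also accepts in-memory model bytes):

```cpp
// Sketch inferred from usage in this PR; not the verbatim ORT definition.
#include <optional>
#include <string>
#include <utility>

namespace onnxruntime::training::api {
struct ModelIdentifiers {
  std::string train_model;                 // required ("" when only an optimizer session is created)
  std::optional<std::string> eval_model;   // std::nullopt when there is no eval model
  std::optional<std::string> optim_model;  // std::nullopt when there is no optimizer model

  ModelIdentifiers(std::string train,
                   std::optional<std::string> eval,
                   std::optional<std::string> optim)
      : train_model(std::move(train)),
        eval_model(std::move(eval)),
        optim_model(std::move(optim)) {}
};
}  // namespace onnxruntime::training::api
```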
diff --git a/orttraining/orttraining/test/training_api/core/training_api_tests.cc b/orttraining/orttraining/test/training_api/core/training_api_tests.cc index ec0c7a1968ba4..2170f7957e6a6 100644 --- a/orttraining/orttraining/test/training_api/core/training_api_tests.cc +++ b/orttraining/orttraining/test/training_api/core/training_api_tests.cc @@ -76,9 +76,12 @@ void TestModuleExport(const std::vector>& pr std::unique_ptr env; ASSERT_STATUS_OK(Environment::Create(nullptr, env)); + auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(training_model_uri), + std::optional(onnxruntime::ToUTF8String(eval_model_uri)), + std::nullopt); auto model = std::make_unique( - ToUTF8String(training_model_uri), &state, onnxruntime::SessionOptions(), - *env, providers, ToUTF8String(eval_model_uri)); + model_identifier, &state, onnxruntime::SessionOptions(), + *env, providers); auto test_dir = ORT_TSTR("export_model_for_inferencing_test_dir"); if (Env::Default().FolderExists(test_dir)) { @@ -141,7 +144,9 @@ TEST(TrainingApiTest, ModuleParametersSize) { onnxruntime::SessionOptions session_option; std::unique_ptr env; ASSERT_STATUS_OK(Environment::Create(nullptr, env)); - auto model = std::make_unique(ToUTF8String(model_uri), + auto model_identifiers = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri), + std::nullopt, std::nullopt); + auto model = std::make_unique(model_identifiers, &state, session_option, *env, std::vector>()); size_t params_size = 0; @@ -164,7 +169,10 @@ TEST(TrainingApiTest, ModuleCopyBufferToParameters) { onnxruntime::SessionOptions session_option; std::unique_ptr env; ASSERT_STATUS_OK(Environment::Create(nullptr, env)); - auto model = std::make_unique(ToUTF8String(model_uri), + auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri), + std::nullopt, + std::nullopt); + auto model = std::make_unique(model_identifier, &state, session_option, *env, std::vector>()); int64_t params_size = static_cast(model->GetParametersSize()); @@ -202,7 +210,10 @@ TEST(TrainingApiTest, ModuleTrainStep) { onnxruntime::SessionOptions session_option; std::unique_ptr env; ASSERT_STATUS_OK(Environment::Create(nullptr, env)); - auto model = std::make_unique(ToUTF8String(model_uri), + auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri), + std::nullopt, + std::nullopt); + auto model = std::make_unique(model_identifier, &state, session_option, *env, std::vector>()); ASSERT_EQ(model->GetTrainingModelOutputCount(), 1); @@ -274,8 +285,12 @@ TEST(TrainingApiTest, OptimizerCreatedWithOptimizerCheckpointState) { ASSERT_STATUS_OK(Environment::Create(nullptr, env)); + auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri), + std::nullopt, + std::optional(onnxruntime::ToUTF8String(optim_uri))); + std::shared_ptr model = std::make_shared( - ToUTF8String(model_uri), &state, session_option, + model_identifier, &state, session_option, *env, providers); // Load state dict from faked optimizer checkpoint state. 
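Note the pattern in the hunks above and below: a single `ModelIdentifiers` carries the train, eval, and optimizer model paths, and the same instance is handed to both `Module` and `Optimizer`, replacing the previous per-URI constructors. Roughly, assuming the `state`/`session_option`/`env`/`providers` variables from the surrounding test code:

```cpp
// Construction pattern used by these tests (sketch; types simplified).
auto ids = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri),
                            std::nullopt,  // no eval model in this test
                            std::optional<std::string>(onnxruntime::ToUTF8String(optim_uri)));
auto model = std::make_shared<Module>(ids, &state, session_option, *env, providers);
auto optim = std::make_shared<Optimizer>(ids, &state, session_option, *env, providers);
```

Both objects share the identifier set rather than each receiving a single model URI.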
@@ -285,7 +300,7 @@ TEST(TrainingApiTest, OptimizerCreatedWithOptimizerCheckpointState) { {"momentum0", "momentum1"}, external_optimizer_checkpoint_state)); std::shared_ptr optim = std::make_shared( - ToUTF8String(optim_uri), &new_state, session_option, *env, providers); + model_identifier, &new_state, session_option, *env, providers); ASSERT_TRUE(optim.get() != nullptr); } @@ -320,8 +335,12 @@ void TestLRSchduler(const std::basic_string& test_file_name, ASSERT_STATUS_OK(Environment::Create(nullptr, env)); + auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri), + std::nullopt, + std::optional(onnxruntime::ToUTF8String(optim_uri))); + std::shared_ptr model = std::make_shared( - ToUTF8String(model_uri), &state, session_option, + model_identifier, &state, session_option, *env, providers); OrtValue input, target; @@ -351,7 +370,7 @@ void TestLRSchduler(const std::basic_string& test_file_name, } std::shared_ptr optim = std::make_shared( - ToUTF8String(optim_uri), &state, session_option, + model_identifier, &state, session_option, *env, providers); // KNOWN ISSUE: LinearLRScheduler by default use optim's states to calculate the first step's learning rate. @@ -445,11 +464,15 @@ TEST(TrainingApiTest, OptimStep) { providers.push_back(onnxruntime::test::DefaultCudaExecutionProvider()); #endif ASSERT_STATUS_OK(Environment::Create(nullptr, env)); + + auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri), + std::nullopt, + std::optional(onnxruntime::ToUTF8String(optim_uri))); auto model = std::make_unique( - ToUTF8String(model_uri), &state, session_option, + model_identifier, &state, session_option, *env, providers); auto optim = std::make_unique( - ToUTF8String(optim_uri), &state, session_option, + model_identifier, &state, session_option, *env, providers); OrtValue input, target; diff --git a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc index e864f3b8632de..d734be8e3474b 100644 --- a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc +++ b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. 
#include "gtest/gtest.h" +#include "gmock/gmock.h" #include "onnxruntime_c_api.h" #include "onnxruntime_training_c_api.h" @@ -16,6 +17,7 @@ namespace onnxruntime::training::test { #define MODEL_FOLDER ORT_TSTR("testdata/training_api/") +#define ORT_FORMAT_MODEL_FOLDER ORT_TSTR("testdata/training_api/ort_format/") TEST(TrainingCApiTest, SaveCheckpoint) { auto model_uri = MODEL_FOLDER "training_model.onnx"; @@ -220,4 +222,100 @@ TEST(TrainingCApiTest, RegisterCustomOps) { ASSERT_TRUE(loss.front().IsTensor()); } +TEST(TrainingCApiTest, LoadModelsAndCreateSession) { + auto model_path = MODEL_FOLDER "training_model.onnx"; + + Ort::Env env; + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); + Ort::TrainingSession training_session = Ort::TrainingSession(env, + Ort::SessionOptions(), + checkpoint_state, + model_path); +} + +TEST(TrainingCApiTest, LoadModelsAndCreateSession_ORTFormat) { + auto train_model_path = ORT_FORMAT_MODEL_FOLDER "training_model.ort"; + auto eval_train_model_path = ORT_FORMAT_MODEL_FOLDER "eval_model.ort"; + auto optimizer_model_path = ORT_FORMAT_MODEL_FOLDER "optimizer_model.ort"; + + Ort::Env env; + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(ORT_FORMAT_MODEL_FOLDER "checkpoint"); + Ort::TrainingSession training_session = Ort::TrainingSession(env, + Ort::SessionOptions(), + checkpoint_state, + train_model_path, + eval_train_model_path, + optimizer_model_path); +} + +TEST(TrainingCApiTest, LoadONNXModelsFromBuffer) { + auto model_path = MODEL_FOLDER "training_model.onnx"; + size_t model_data_len = 0; + ASSERT_STATUS_OK(Env::Default().GetFileLength(model_path, model_data_len)); + std::vector train_model_data(model_data_len); + std::ifstream bytes_stream(model_path, std::ifstream::in | std::ifstream::binary); + bytes_stream.read(reinterpret_cast(train_model_data.data()), model_data_len); + ASSERT_TRUE(train_model_data.size() == model_data_len); + + Ort::Env env; + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); + Ort::TrainingSession training_session = Ort::TrainingSession(env, + Ort::SessionOptions(), + checkpoint_state, + train_model_data); +} + +TEST(TrainingCApiTest, LoadORTFormatModelsFromBuffer) { + auto train_model_path = ORT_FORMAT_MODEL_FOLDER "training_model.ort"; + auto eval_model_path = ORT_FORMAT_MODEL_FOLDER "eval_model.ort"; + auto optimizer_model_path = ORT_FORMAT_MODEL_FOLDER "optimizer_model.ort"; + size_t model_data_len = 0; + ASSERT_STATUS_OK(Env::Default().GetFileLength(train_model_path, model_data_len)); + std::vector train_model_data(model_data_len); + { + std::ifstream bytes_stream(train_model_path, std::ifstream::in | std::ifstream::binary); + bytes_stream.read(reinterpret_cast(train_model_data.data()), model_data_len); + ASSERT_TRUE(train_model_data.size() == model_data_len); + } + + model_data_len = 0; + ASSERT_STATUS_OK(Env::Default().GetFileLength(eval_model_path, model_data_len)); + std::vector eval_model_data(model_data_len); + { + std::ifstream bytes_stream(eval_model_path, std::ifstream::in | std::ifstream::binary); + bytes_stream.read(reinterpret_cast(eval_model_data.data()), model_data_len); + ASSERT_TRUE(eval_model_data.size() == model_data_len); + } + + model_data_len = 0; + ASSERT_STATUS_OK(Env::Default().GetFileLength(optimizer_model_path, model_data_len)); + std::vector optimizer_model_data(model_data_len); + { + std::ifstream bytes_stream(optimizer_model_path, std::ifstream::in | 
std::ifstream::binary); + bytes_stream.read(reinterpret_cast(optimizer_model_data.data()), model_data_len); + ASSERT_TRUE(optimizer_model_data.size() == model_data_len); + } + + Ort::Env env; + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(ORT_FORMAT_MODEL_FOLDER "checkpoint"); + Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), + checkpoint_state, train_model_data, + eval_model_data, optimizer_model_data); +} + +TEST(TrainingCApiTest, LoadModelsFromBufferThrows) { + Ort::Env env; + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); + + try { + std::vector train_model_data; + Ort::TrainingSession training_session = Ort::TrainingSession(env, + Ort::SessionOptions(), + checkpoint_state, + train_model_data); + } catch (const std::exception& ex) { + ASSERT_THAT(ex.what(), + testing::HasSubstr("Training Session Creation failed. Train model data cannot be NULL.")); + } +} } // namespace onnxruntime::training::test diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h index b3042c449a50b..0af737074964d 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h @@ -190,7 +190,29 @@ struct OrtTrainingApi { ORT_API2_STATUS(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path, _In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path, - _Outptr_ OrtTrainingSession** out); + _Outptr_result_maybenull_ OrtTrainingSession** out); + + /** \brief Create a training session that can be used to begin or resume training. + * This api provides a way to load all the training artifacts from buffers instead of files. + * + * \param[in] env Environment to be used for the training session. + * \param[in] options Session options that the user can customize for this training session. + * \param[in] checkpoint_state Training states that the training session uses as a starting point for training. + * \param[in] train_model_data Buffer containing the model data to be used to perform training + * \param[in] train_data_length Length of the buffer containing train_model_data + * \param[in] eval_model_data Buffer containing the model data to be used to perform evaluation + * \param[in] eval_data_length Length of the buffer containing eval_model_data + * \param[in] optim_model_data Buffer containing the model data to be used to perform weight update + * \param[in] optim_data_length Length of the buffer containing optim_model_data + * \param[out] out Created training session. 
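A usage sketch for the API documented above (a sketch only: buffer names are placeholders, error handling is elided, and the optional eval/optimizer models are skipped by passing nullptr and 0, as the parameter list implies):

    // Assumes train_bytes is a std::vector<uint8_t> holding a serialized model,
    // and env, session_options, checkpoint_state were created via the C API.
    const OrtTrainingApi* training_api = Ort::GetApi().GetTrainingApi(ORT_API_VERSION);
    OrtTrainingSession* session = nullptr;
    OrtStatus* status = training_api->CreateTrainingSessionFromBuffer(
        env, session_options, checkpoint_state,
        train_bytes.data(), train_bytes.size(),  // required
        nullptr, 0,                              // no eval model
        nullptr, 0,                              // no optimizer model
        &session);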
+ * + */ + ORT_API2_STATUS(CreateTrainingSessionFromBuffer, _In_ const OrtEnv* env, + _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state, + _In_ const void* train_model_data, size_t train_data_length, + _In_ const void* eval_model_data, size_t eval_data_length, + _In_ const void* optim_model_data, size_t optim_data_length, + _Outptr_result_maybenull_ OrtTrainingSession** out); /// @} diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h index 5bfdfcc74e817..0edef20ba6da8 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h @@ -176,6 +176,20 @@ class TrainingSession : public detail::Base { const std::optional>& eval_model_path = std::nullopt, const std::optional>& optimizer_model_path = std::nullopt); + /** \brief Create a training session that can be used to begin or resume training. + * This constructor allows the users to load the models from buffers instead of files. + * + * \param[in] env Env to be used for the training session. + * \param[in] session_options SessionOptions that the user can customize for this training session. + * \param[in] checkpoint_state Training states that the training session uses as a starting point for training. + * \param[in] train_model_data Buffer containing training model data. + * \param[in] eval_model_data Buffer containing evaluation model data. + * \param[in] optim_model_data Buffer containing optimizer model (used for performing weight/parameter update). + * + */ + TrainingSession(const Env& env, const SessionOptions& session_options, CheckpointState& checkpoint_state, + const std::vector& train_model_data, const std::vector& eval_model_data = {}, + const std::vector& optim_model_data = {}); /// @} /// \name Implementing The Training Loop diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h index 393e5b01f7f85..066147708863f 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h @@ -24,6 +24,23 @@ inline TrainingSession::TrainingSession(const Env& env, const SessionOptions& se ThrowOnError(GetTrainingApi().TrainingSessionGetEvalModelOutputCount(p_, &eval_model_output_count_)); } +inline TrainingSession::TrainingSession(const Env& env, const SessionOptions& session_options, + CheckpointState& checkpoint_state, + const std::vector& train_model_data, + const std::vector& eval_model_data, + const std::vector& optim_model_data) { + ThrowOnError(GetTrainingApi().CreateTrainingSessionFromBuffer( + env, session_options, checkpoint_state, + train_model_data.data(), train_model_data.size(), + eval_model_data.data(), eval_model_data.size(), + optim_model_data.data(), optim_model_data.size(), + &p_)); + + ThrowOnError(GetTrainingApi().TrainingSessionGetTrainingModelOutputCount(p_, &training_model_output_count_)); + + ThrowOnError(GetTrainingApi().TrainingSessionGetEvalModelOutputCount(p_, &eval_model_output_count_)); +} + inline std::vector TrainingSession::TrainStep(const std::vector& input_values) { std::vector output_values; output_values.reserve(training_model_output_count_); diff --git a/orttraining/orttraining/training_api/module.cc 
b/orttraining/orttraining/training_api/module.cc index 29300bbb7e8ec..d1775e358163c 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -12,7 +12,6 @@ #include "core/graph/graph_utils.h" #include "orttraining/training_api/checkpoint.h" -#include "orttraining/training_api/utils.h" using namespace onnxruntime; @@ -150,12 +149,11 @@ Status Parameter::ResetGrad() { return Status::OK(); } -Module::Module(const std::string& train_model_path_or_bytes, +Module::Module(const ModelIdentifiers& model_identifiers, CheckpointState* state, const onnxruntime::SessionOptions& session_options, const Environment& env, const std::vector>& providers, - const std::optional& eval_model_path_or_bytes, [[maybe_unused]] gsl::span op_domains) : state_{state} { // Enforce weight prepacking is disabled @@ -176,7 +174,12 @@ Module::Module(const std::string& train_model_path_or_bytes, } #endif - ORT_THROW_IF_ERROR(train_sess_->Load(train_model_path_or_bytes)); + // Load the training model + ORT_THROW_IF_ERROR(std::holds_alternative(model_identifiers.train_model) + ? train_sess_->Load(std::get(model_identifiers.train_model)) + : train_sess_->Load(std::get>(model_identifiers.train_model).data(), + static_cast(std::get>(model_identifiers.train_model).size()))); + for (const auto& provider : providers) { ORT_THROW_IF_ERROR(train_sess_->RegisterExecutionProvider(provider)); } @@ -239,7 +242,6 @@ Module::Module(const std::string& train_model_path_or_bytes, // Copy ortvalue buffer from CPU to target_device for this "param_name" (based on graph partitioning) // Only copies data if the target device is not the same as the current device the buffer is placed on - OrtValue& param_data = params_iter->second->Data(); ORT_ENFORCE(param_data.IsTensor()); const Tensor& param_data_tensor = param_data.Get(); @@ -278,47 +280,57 @@ Module::Module(const std::string& train_model_path_or_bytes, } } - if (eval_model_path_or_bytes.has_value()) { + if (model_identifiers.IsEvalModelAvailable()) { eval_sess_ = std::make_unique(session_options, env); #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) if (!op_domains.empty()) { ORT_THROW_IF_ERROR(eval_sess_->AddCustomOpDomains(op_domains)); } #endif - - ORT_THROW_IF_ERROR(eval_sess_->Load(eval_model_path_or_bytes.value())); - for (const auto& provider : providers) { - ORT_THROW_IF_ERROR(eval_sess_->RegisterExecutionProvider(provider)); - } - ORT_THROW_IF_ERROR(eval_sess_->Initialize()); - utils::GetGraphInputOutputNames(eval_sess_, eval_input_names_, eval_output_names_); - - // Eval model validation - // We are making certain assumptions: Like the order in which parameters occur will be same between train and eval - // graphs, and all the weights present in both graphs match. - // TODO: Add the checks instead of making assumptions?? - InlinedVector eval_user_input_names, eval_param_input_names; - for (const auto& input_name : eval_input_names_) { - if (state_->module_checkpoint_state.named_parameters.find(input_name) != - state_->module_checkpoint_state.named_parameters.end()) { - // it is a parameter - eval_param_input_names.emplace_back(input_name); - continue; - } else { - // It is user input. We handle user inputs separately in the eval - // because the eval graph might have different user inputs. 
- // Eg if loss is not a part of the eval graph, it won't have - // certain inputs like targets - eval_user_input_names.emplace_back(input_name); - } + if (std::holds_alternative>(model_identifiers.eval_model)) { + ORT_THROW_IF_ERROR(eval_sess_->Load(std::get>(model_identifiers.eval_model).value())); + } else { + auto model_data = std::get>(model_identifiers.eval_model); + ORT_THROW_IF_ERROR(eval_sess_->Load(model_data.data(), static_cast(model_data.size()))); } - eval_input_names_ = eval_user_input_names; - eval_user_input_count_ = eval_user_input_names.size(); - eval_input_names_.insert(eval_input_names_.end(), eval_param_input_names.begin(), eval_param_input_names.end()); + } else { + return; + } - // Keep a copy of the eval model path to be able to later export the model for inferencing. - // The inference model will be reconstructed from the eval model. - eval_model_path_ = eval_model_path_or_bytes.value(); + for (const auto& provider : providers) { + ORT_THROW_IF_ERROR(eval_sess_->RegisterExecutionProvider(provider)); + } + ORT_THROW_IF_ERROR(eval_sess_->Initialize()); + utils::GetGraphInputOutputNames(eval_sess_, eval_input_names_, eval_output_names_); + + // Eval model validation + // We are making certain assumptions: Like the order in which parameters occur will be same between train and eval + // graphs, and all the weights present in both graphs match. + // TODO(askhade): Add the checks instead of making assumptions?? + InlinedVector eval_user_input_names, eval_param_input_names; + for (const auto& input_name : eval_input_names_) { + if (state_->module_checkpoint_state.named_parameters.find(input_name) != + state_->module_checkpoint_state.named_parameters.end()) { + // it is a parameter + eval_param_input_names.emplace_back(input_name); + continue; + } else { + // It is user input. We handle user inputs separately in the eval + // because the eval graph might have different user inputs. + // Eg if loss is not a part of the eval graph, it won't have + // certain inputs like targets + eval_user_input_names.emplace_back(input_name); + } + } + eval_input_names_ = eval_user_input_names; + eval_user_input_count_ = eval_user_input_names.size(); + eval_input_names_.insert(eval_input_names_.end(), eval_param_input_names.begin(), eval_param_input_names.end()); + + // Keep a copy of the eval model path to be able to later export the model for inferencing. + // The inference model will be reconstructed from the eval model. + // TODO(askhade): Find a fix to export model for inference when the eval model is loaded from a buffer. + if (std::holds_alternative>(model_identifiers.eval_model)) { + eval_model_path_ = std::get>(model_identifiers.eval_model); } } @@ -486,14 +498,14 @@ Status Module::EvalStep(const std::vector& inputs, std::vector graph_output_names) const { - ORT_RETURN_IF(!eval_sess_ || eval_model_path_.empty(), + ORT_RETURN_IF(!eval_sess_ || !eval_model_path_.has_value(), "Eval model was not provided. Cannot export a model for inferencing."); ONNX_NAMESPACE::ModelProto eval_model; - ORT_THROW_IF_ERROR(Model::Load(ToPathString(eval_model_path_), eval_model)); + ORT_THROW_IF_ERROR(Model::Load(ToPathString(eval_model_path_.value()), eval_model)); // Clone the eval mode into an inference onnxruntime::Model. 
std::shared_ptr inference_model; diff --git a/orttraining/orttraining/training_api/module.h b/orttraining/orttraining/training_api/module.h index 9013ab22c124f..adb633343263e 100644 --- a/orttraining/orttraining/training_api/module.h +++ b/orttraining/orttraining/training_api/module.h @@ -3,7 +3,9 @@ #pragma once +#include #include "core/session/inference_session.h" +#include "orttraining/training_api/utils.h" namespace onnxruntime { namespace training { @@ -73,12 +75,12 @@ struct Module { public: // Initialize a module from an ORT inference session with loaded // training ONNX model and load parameters - Module(const std::string& train_model_path_or_bytes, + // The model and checkpoint state can be provided as a file path or a byte array + Module(const ModelIdentifiers& model_identifiers, CheckpointState* state, const onnxruntime::SessionOptions& session_options, const Environment& env, const std::vector>& providers, - const std::optional& eval_model_path_or_bytes = std::nullopt, gsl::span op_domains = gsl::span()); // Return the trainable/nontrainable parameters @@ -159,7 +161,7 @@ struct Module { CheckpointState* state_; // Non owning pointer to the state. bool accumulate_gradient_ = false; - std::string eval_model_path_; + std::optional eval_model_path_; size_t train_user_input_count_{0U}; size_t eval_user_input_count_{0U}; }; diff --git a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc index b84009e7f3591..6693bba348648 100644 --- a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc +++ b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc @@ -13,6 +13,8 @@ #include "orttraining/training_api/ort_training_apis.h" #include "orttraining/training_api/training_session.h" +using namespace onnxruntime::training::api; + namespace { std::vector> CreateProviders( @@ -26,44 +28,85 @@ std::vector> CreateProviders( return execution_providers; } +static OrtStatus* CreateSessionAndLoadModel(_In_ const OrtEnv* env, _In_ const OrtSessionOptions* options, + _Inout_ OrtCheckpointState* checkpoint_state, + const ModelIdentifiers& model_identifiers, + std::unique_ptr& train_sess) { + auto chkpt_state = reinterpret_cast(checkpoint_state); + + using ProvidersType = std::vector>; + train_sess = std::make_unique(env->GetEnvironment(), + options == nullptr ? onnxruntime::SessionOptions() : options->value, + options == nullptr + ? ProvidersType() + : CreateProviders(options->provider_factories), + chkpt_state, + model_identifiers, + options == nullptr + ? 
gsl::span() + : options->custom_op_domains_); + + return nullptr; +} + } // namespace ORT_API_STATUS_IMPL(OrtTrainingApis::CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path, _In_ const ORTCHAR_T* eval_model_path, - _In_ const ORTCHAR_T* optimizer_model_path, _Outptr_ OrtTrainingSession** out) { + _In_ const ORTCHAR_T* optimizer_model_path, _Outptr_result_maybenull_ OrtTrainingSession** out) { API_IMPL_BEGIN if (options != nullptr && options->value.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigUseEnvAllocators, "0") == "1") { return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Use Env Allocators is not supported for on device training."); } std::unique_ptr train_sess; - auto chkpt_state = reinterpret_cast(checkpoint_state); OrtStatus* status = nullptr; *out = nullptr; - ORT_TRY { - using ProvidersType = std::vector>; - train_sess = std::make_unique( - env->GetEnvironment(), - options == nullptr ? onnxruntime::SessionOptions() : options->value, - options == nullptr ? ProvidersType() : CreateProviders(options->provider_factories), - chkpt_state, - onnxruntime::training::api::ModelIdentifiers( - onnxruntime::ToUTF8String(train_model_path), - eval_model_path ? std::optional(onnxruntime::ToUTF8String(eval_model_path)) - : std::nullopt, - optimizer_model_path ? std::optional(onnxruntime::ToUTF8String(optimizer_model_path)) - : std::nullopt), - options == nullptr ? gsl::span() : options->custom_op_domains_); - - *out = reinterpret_cast(train_sess.release()); - } - ORT_CATCH(const std::exception& e) { - ORT_HANDLE_EXCEPTION([&]() { - status = OrtApis::CreateStatus(ORT_FAIL, e.what()); - }); - } + ORT_ENFORCE(train_model_path != nullptr, + "Train model path is required to create TrainingSession, it cannot be empty."); + + auto model_identifiers = onnxruntime::training::api::ModelIdentifiers( + onnxruntime::ToUTF8String(train_model_path), + eval_model_path ? std::optional(onnxruntime::ToUTF8String(eval_model_path)) + : std::nullopt, + optimizer_model_path ? std::optional(onnxruntime::ToUTF8String(optimizer_model_path)) + : std::nullopt); + + ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(env, options, checkpoint_state, model_identifiers, train_sess)); + *out = reinterpret_cast(train_sess.release()); + + return status; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtTrainingApis::CreateTrainingSessionFromBuffer, _In_ const OrtEnv* env, + _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state, + _In_ const void* train_model_data, size_t train_data_length, + _In_ const void* eval_model_data, size_t eval_data_length, + _In_ const void* optim_model_data, size_t optim_data_length, + _Outptr_result_maybenull_ OrtTrainingSession** out) { + API_IMPL_BEGIN + std::unique_ptr train_sess; + OrtStatus* status = nullptr; + *out = nullptr; + ORT_ENFORCE(train_model_data != nullptr && train_data_length != 0, + "Training Session Creation failed. Train model data cannot be NULL."); + + auto model_identifiers = ModelIdentifiers(gsl::make_span(reinterpret_cast(train_model_data), + train_data_length), + eval_data_length == 0 || eval_model_data == nullptr + ? gsl::span() + : gsl::make_span(reinterpret_cast(eval_model_data), + eval_data_length), + optim_data_length == 0 || optim_model_data == nullptr + ? 
gsl::span() + : gsl::make_span(reinterpret_cast(optim_model_data), + optim_data_length)); + + ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(env, options, checkpoint_state, model_identifiers, train_sess)); + *out = reinterpret_cast(train_sess.release()); return status; API_IMPL_END } @@ -523,6 +566,7 @@ static constexpr OrtTrainingApi ort_training_api = { &OrtTrainingApis::LoadCheckpoint, &OrtTrainingApis::SaveCheckpoint, &OrtTrainingApis::CreateTrainingSession, + &OrtTrainingApis::CreateTrainingSessionFromBuffer, &OrtTrainingApis::TrainingSessionGetTrainingModelOutputCount, &OrtTrainingApis::TrainingSessionGetEvalModelOutputCount, &OrtTrainingApis::TrainingSessionGetTrainingModelOutputName, diff --git a/orttraining/orttraining/training_api/optimizer.cc b/orttraining/orttraining/training_api/optimizer.cc index a6b82f1d50fc0..7f583ce8f6e76 100644 --- a/orttraining/orttraining/training_api/optimizer.cc +++ b/orttraining/orttraining/training_api/optimizer.cc @@ -61,19 +61,10 @@ Status GraphInputsAreExpected(gsl::span actual_graph_inputs, } // namespace std::unique_ptr OptimizerAlorithmFactory::CreateInstance( - const std::string& optim_path, int32_t& group_count) { + std::shared_ptr model, int32_t& group_count) { std::map, int32_t> opt_type_to_freq_map; #if !defined(ORT_MINIMAL_BUILD) - if (const auto optim_path_str = ToPathString(optim_path); - fbs::utils::IsOrtFormatModel(optim_path_str)) { - // TODO (baijumeswani): Figure out the best way to extract the optimizer type - // from an ort format model. - opt_type_to_freq_map[std::make_pair(kMSDomain, "AdamWOptimizer")] = 1; - } else { - std::shared_ptr model; - ORT_ENFORCE(Model::Load(optim_path_str, model, nullptr, - logging::LoggingManager::DefaultLogger()) - .IsOK()); + if (model != nullptr) { Graph& graph = model->MainGraph(); for (auto& node : graph.Nodes()) { if (node.Domain() == kMSDomain && (node.OpType() == "AdamWOptimizer" || node.OpType() == "SGDOptimizerV2")) { @@ -85,33 +76,71 @@ std::unique_ptr OptimizerAlorithmFactory::CreateInstance opt_type_to_freq_map[domain_type_pair] += 1; } } - } + } else { #else - // TODO (baijumeswani): Figure out the best way to extract the optimizer type - // from the model (either onnx model or ort format model) or from the checkpoint. - // For now, assume that the optimizer type is AdamWOptimizer in a minimal build. - ORT_UNUSED_PARAMETER(optim_path); - - opt_type_to_freq_map[std::make_pair(kMSDomain, "AdamWOptimizer")] = 1; + ORT_UNUSED_PARAMETER(model); +#endif + // TODO(baijumeswani): Figure out the best way to extract the optimizer type + // from the model (either onnx model or ort format model) or from the checkpoint. + // For now, assume that the optimizer type is AdamWOptimizer when using ort format models. + opt_type_to_freq_map[std::make_pair(kMSDomain, "AdamWOptimizer")] = 1; +#if !defined(ORT_MINIMAL_BUILD) + } #endif ORT_ENFORCE(opt_type_to_freq_map.size() == 1U, "Only support one type of optimizer algorithm, but got: " + std::to_string(opt_type_to_freq_map.size())); auto opt_it = opt_type_to_freq_map.begin(); + auto& op_type = opt_it->first.second; group_count = opt_it->second; - auto& domain = opt_it->first.first; - auto& type = opt_it->first.second; + ORT_ENFORCE(group_count == 1, "Group count can only be 1, but got: " + std::to_string(group_count)); // TODO: to support multiple groups, need to create a mapping between each group to its parameter list. 
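The span construction above is the contract worth noting: an optional model the caller omits arrives as (nullptr, 0) and is normalized to an empty span, which ModelIdentifiers::IsEvalModelAvailable/IsOptimizerModelAvailable later report as "not provided". A condensed restatement of that mapping (names mirror the implementation above):

    // An absent eval model becomes an empty span rather than an error;
    // only the training model is mandatory.
    gsl::span<const uint8_t> eval_span =
        (eval_model_data == nullptr || eval_data_length == 0)
            ? gsl::span<const uint8_t>()
            : gsl::make_span(reinterpret_cast<const uint8_t*>(eval_model_data), eval_data_length);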
- if (domain == kMSDomain && type == "AdamWOptimizer") { + if (op_type == "AdamWOptimizer") { return std::make_unique(); - } else if (domain == kMSDomain && type == "SGDOptimizerV2") { + } else if (op_type == "SGDOptimizerV2") { return std::make_unique(); } else { ORT_NOT_IMPLEMENTED("Not implemented for optimizer algo: " + opt_it->first.second); } } +std::unique_ptr OptimizerAlorithmFactory::CreateInstance( + const PathString& optim_path, int32_t& group_count) { + std::shared_ptr model = nullptr; +#if !defined(ORT_MINIMAL_BUILD) + if (!fbs::utils::IsOrtFormatModel(optim_path)) { + ORT_ENFORCE(Model::Load(optim_path, model, nullptr, + logging::LoggingManager::DefaultLogger()) + .IsOK()); + } +#else + ORT_UNUSED_PARAMETER(optim_path); +#endif + return CreateInstance(model, group_count); +} + +std::unique_ptr OptimizerAlorithmFactory::CreateInstance( + const uint8_t* optim_model_data, size_t optim_model_data_len, int32_t& group_count) { + std::shared_ptr model = nullptr; +#if !defined(ORT_MINIMAL_BUILD) + if (!fbs::utils::IsOrtFormatModelBytes(optim_model_data, static_cast(optim_model_data_len))) { + ONNX_NAMESPACE::ModelProto model_proto; + ORT_ENFORCE(model_proto.ParseFromArray(optim_model_data, static_cast(optim_model_data_len)) == true, + "Failed to load model because protobuf parsing failed."); + + ORT_ENFORCE(Model::Load(std::move(model_proto), model, nullptr, + logging::LoggingManager::DefaultLogger(), ModelOptions(true, true)) + .IsOK()); + } +#else + ORT_UNUSED_PARAMETER(optim_model_data); + ORT_UNUSED_PARAMETER(optim_model_data_len); +#endif + + return CreateInstance(model, group_count); +} + Status Optimizer::GenerateMomentumNamedStates(OptimizerCheckpointState& optimizer_checkpoint_states) { auto group_optimizer_state_it = optimizer_checkpoint_states.group_named_optimizer_states.find(GROUP_ZERO_NAME); @@ -200,14 +229,14 @@ Status Optimizer::ConstructInputs() { return Status::OK(); } // namespace api -Optimizer::Optimizer(const std::string& optim_path_or_bytes, +Optimizer::Optimizer(const ModelIdentifiers& model_identifiers, CheckpointState* state, const onnxruntime::SessionOptions& session_options, const Environment& env, const std::vector>& providers, gsl::span op_domains) : optim_sess_(std::make_unique(session_options, env)), state_(state) { - Initialize(optim_path_or_bytes, providers, op_domains); + Initialize(model_identifiers, providers, op_domains); ORT_ENFORCE(state != nullptr, "Checkpoint state cannot be null."); auto g_it = state_->optimizer_checkpoint_state.group_named_optimizer_states.find(GROUP_ZERO_NAME); @@ -223,7 +252,7 @@ Optimizer::Optimizer(const std::string& optim_path_or_bytes, } } -void Optimizer::Initialize(const std::string& optim_path_or_bytes, +void Optimizer::Initialize(const ModelIdentifiers& model_identifiers, const std::vector>& providers, [[maybe_unused]] gsl::span op_domains) { #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) @@ -236,7 +265,22 @@ void Optimizer::Initialize(const std::string& optim_path_or_bytes, ORT_THROW_IF_ERROR(optim_sess_->RegisterExecutionProvider(execution_provider)); } - ORT_THROW_IF_ERROR(optim_sess_->Load(optim_path_or_bytes)); + ORT_ENFORCE(model_identifiers.IsOptimizerModelAvailable(), "Optimizer model is not available."); + + if (std::holds_alternative>(model_identifiers.optim_model)) { + auto optimizer_model = std::get>(model_identifiers.optim_model); + // The above call to IsOptimizerModelAvailable() ensures that optimizer_model is not nullopt + 
ORT_THROW_IF_ERROR(optim_sess_->Load(optimizer_model.value())); + optimizer_algo_ptr_ = OptimizerAlorithmFactory::CreateInstance(ToWideString(optimizer_model.value()), group_count_); + } else { + auto optimizer_model = std::get>(model_identifiers.optim_model); + ORT_THROW_IF_ERROR(optim_sess_->Load(optimizer_model.data(), + static_cast(optimizer_model.size()))); + optimizer_algo_ptr_ = OptimizerAlorithmFactory::CreateInstance(optimizer_model.data(), + optimizer_model.size(), + group_count_); + } + ORT_THROW_IF_ERROR(optim_sess_->Initialize()); // Make sure that the checkpoint state can copy tensors @@ -244,10 +288,6 @@ void Optimizer::Initialize(const std::string& optim_path_or_bytes, utils::GetGraphInputOutputNames(optim_sess_, input_names_, output_names_); - optimizer_algo_ptr_ = OptimizerAlorithmFactory::CreateInstance(optim_path_or_bytes, group_count_); - ORT_ENFORCE(group_count_ == 1, "Group count can only be 1, but got: " + std::to_string(group_count_)); - ORT_ENFORCE(optimizer_algo_ptr_, "optimizer_algo_ptr_ should not be nullptr."); - InlinedVector all_input_names; all_input_names.reserve(CommonOptimizerInputs.size() + optimizer_algo_ptr_->optimizer_states_inputs.size()); all_input_names.insert(all_input_names.end(), CommonOptimizerInputs.begin(), diff --git a/orttraining/orttraining/training_api/optimizer.h b/orttraining/orttraining/training_api/optimizer.h index 36ce3297fe3c4..d9bc4870bb7ed 100644 --- a/orttraining/orttraining/training_api/optimizer.h +++ b/orttraining/orttraining/training_api/optimizer.h @@ -64,8 +64,11 @@ struct SGDOptimizerV2Algorithm : public OptimizerAlgorithmBase { }; struct OptimizerAlorithmFactory { - static std::unique_ptr CreateInstance(const std::string& optim_path_or_bytes, + static std::unique_ptr CreateInstance(const PathString& optim_path, int32_t& group_count); + static std::unique_ptr CreateInstance(const uint8_t* optim_model_data, + size_t optim_model_data_len, int32_t& group_count); + static std::unique_ptr CreateInstance(std::shared_ptr model, int32_t& group_count); }; struct CheckpointState; @@ -96,7 +99,7 @@ struct Optimizer { // Initialize an optimizer module from an ORT inference session with loaded // training ONNX model For each parameter, initialize the OptimizerState based // on the graph input's ValueInfoProto if the parameter doesn't have it already. 
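Optimizer::Initialize above dispatches on the variant with holds_alternative/get; an equivalent formulation with std::visit is shown below as a design alternative (a sketch, not the code in this change):

    // std::visit picks the loader from the variant's active member:
    // a path (std::optional<std::string>) or an in-memory byte span.
    std::visit([&](auto&& model) {
      using T = std::decay_t<decltype(model)>;
      if constexpr (std::is_same_v<T, std::optional<std::string>>) {
        ORT_THROW_IF_ERROR(optim_sess_->Load(model.value()));
      } else {
        ORT_THROW_IF_ERROR(optim_sess_->Load(model.data(), static_cast<int>(model.size())));
      }
    }, model_identifiers.optim_model);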
- Optimizer(const std::string& optim_path_or_bytes, + Optimizer(const ModelIdentifiers& model_identifiers, CheckpointState* state, const onnxruntime::SessionOptions& session_options, const Environment& env, @@ -121,7 +124,7 @@ struct Optimizer { } private: - void Initialize(const std::string& optim_path_or_bytes, + void Initialize(const ModelIdentifiers& model_identifiers, const std::vector>& providers, gsl::span op_domains); diff --git a/orttraining/orttraining/training_api/ort_training_apis.h b/orttraining/orttraining/training_api/ort_training_apis.h index 2b383f3b9782a..c87108957c975 100644 --- a/orttraining/orttraining/training_api/ort_training_apis.h +++ b/orttraining/orttraining/training_api/ort_training_apis.h @@ -8,7 +8,14 @@ ORT_API(const OrtTrainingApi*, GetTrainingApi, uint32_t version); ORT_API_STATUS_IMPL(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path, _In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path, - _Outptr_ OrtTrainingSession** out); + _Outptr_result_maybenull_ OrtTrainingSession** out); + +ORT_API_STATUS_IMPL(CreateTrainingSessionFromBuffer, _In_ const OrtEnv* env, + _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state, + _In_ const void* train_model_data, size_t train_data_length, + _In_ const void* eval_model_data, size_t eval_data_length, + _In_ const void* optim_model_data, size_t optim_data_length, + _Outptr_result_maybenull_ OrtTrainingSession** out); ORT_API_STATUS_IMPL(TrainingSessionGetTrainingModelOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out); diff --git a/orttraining/orttraining/training_api/training_session.cc b/orttraining/orttraining/training_api/training_session.cc index 6915193a8ff7c..45f0f0ddcf7f4 100644 --- a/orttraining/orttraining/training_api/training_session.cc +++ b/orttraining/orttraining/training_api/training_session.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "orttraining/training_api/training_session.h" +#include "orttraining/training_api/utils.h" namespace onnxruntime::training::api { @@ -12,13 +13,12 @@ TrainingSession::TrainingSession(const Environment& session_env, const ModelIdentifiers& model_identifiers, gsl::span custom_op_domains) : state_{state}, - module_{std::make_unique(model_identifiers.train_model, state_, - session_options, session_env, providers, - model_identifiers.eval_model, custom_op_domains)}, - optimizer_{model_identifiers.optim_model.has_value() + module_{std::make_unique(model_identifiers, state_, + session_options, session_env, providers, custom_op_domains)}, + optimizer_{model_identifiers.IsOptimizerModelAvailable() ? 
std::make_unique<Optimizer>(
-                     model_identifiers.optim_model.value(), state_,
-                     session_options, session_env, providers, custom_op_domains)
+                     model_identifiers, state_,
+                     session_options, session_env, providers)
                 : std::unique_ptr<Optimizer>()} {}

 Status TrainingSession::RegisterScheduler(
diff --git a/orttraining/orttraining/training_api/training_session.h b/orttraining/orttraining/training_api/training_session.h
index 1a16acd5115f0..13b0ae79093de 100644
--- a/orttraining/orttraining/training_api/training_session.h
+++ b/orttraining/orttraining/training_api/training_session.h
@@ -3,25 +3,17 @@
 #pragma once

 #include "core/common/common.h"
-#include "module.h"
-#include "optimizer.h"
-#include "lr_scheduler.h"
-#include "checkpoint.h"
+#include "orttraining/training_api/module.h"
+#include "orttraining/training_api/optimizer.h"
+#include "orttraining/training_api/lr_scheduler.h"
+#include "orttraining/training_api/checkpoint.h"
+#include "orttraining/training_api/utils.h"

 namespace onnxruntime {
 namespace training {
 namespace api {
 using namespace common;

-struct ModelIdentifiers {
-  const std::string train_model;
-  const std::optional<std::string> eval_model, optim_model;
-  ModelIdentifiers(const std::string& train_model_uri,
-                   const std::optional<std::string>& eval_model_uri,
-                   const std::optional<std::string>& optim_model_uri)
-      : train_model(train_model_uri), eval_model(eval_model_uri), optim_model(optim_model_uri) {}
-};
-
 // Wrapper on top of module and optimizer classes and is the only class exposed via capis
 class TrainingSession {
  public:
diff --git a/orttraining/orttraining/training_api/utils.h b/orttraining/orttraining/training_api/utils.h
index e856554c971ec..f16f0f947fbd5 100644
--- a/orttraining/orttraining/training_api/utils.h
+++ b/orttraining/orttraining/training_api/utils.h
@@ -10,6 +10,40 @@ namespace onnxruntime {
 namespace training {
 namespace api {
+
+struct ModelIdentifiers {
+  // ModelIdentifiers struct enables an easy way to store and identify the models used for training, evaluation
+  // and model updates (optimizer model).
+  // The model can be specified by a path to the model file or by a span of bytes containing the model data.
+  // The training model is required; the evaluation and optimizer models are optional.
+  std::variant<std::string, gsl::span<const uint8_t>> train_model;
+  std::variant<std::optional<std::string>, gsl::span<const uint8_t>> eval_model;
+  std::variant<std::optional<std::string>, gsl::span<const uint8_t>> optim_model;
+
+  ModelIdentifiers(std::variant<std::string, gsl::span<const uint8_t>> training_model,
+                   std::variant<std::optional<std::string>, gsl::span<const uint8_t>> evaluation_model,
+                   std::variant<std::optional<std::string>, gsl::span<const uint8_t>> optimizer_model)
+      : train_model(training_model), eval_model(evaluation_model), optim_model(optimizer_model) {}
+
+  bool IsModelAvailable(const std::variant<std::optional<std::string>, gsl::span<const uint8_t>>& model) const {
+    if ((std::holds_alternative<std::optional<std::string>>(model) &&
+         std::get<std::optional<std::string>>(model).has_value()) ||
+        (std::holds_alternative<gsl::span<const uint8_t>>(model) &&
+         std::get<gsl::span<const uint8_t>>(model).size() > 0)) {
+      return true;
+    }
+    return false;
+  }
+
+  bool IsEvalModelAvailable() const {
+    return IsModelAvailable(eval_model);
+  }
+
+  bool IsOptimizerModelAvailable() const {
+    return IsModelAvailable(optim_model);
+  }
+};
+
 namespace utils {

 // Get names of graph inputs and outputs
diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml
index 8806707d21317..ac551a53cddaa 100644
--- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml
+++ b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml
@@ -1,14 +1,5 @@
 trigger: none

-variables:
-  - name: isMain
-    value: ${{ eq(variables['Build.SourceBranch'], 'refs/heads/main') }}
-  - name: finalStorage
-    ${{ if eq(variables['isMain'], 'true') }}:
-      value: '--final_storage'
-    ${{ else }}:
-      value: ''
-
 resources:
   repositories:
   - repository: manylinux
@@ -39,14 +30,6 @@ stages:
           PythonVersion: '3.11'
     steps:
-    - task: CmdLine@2
-      displayName: 'check variables'
-      inputs:
-        script: |
-          echo "Branch is "${{ variables['Build.SourceBranch'] }} && \
-          echo "isMain is "${{ variables['isMain'] }} && \
-          echo "final_storage is "${{ variables['finalStorage'] }}
-
     - checkout: self
       clean: true
       submodules: recursive
@@ -102,17 +85,6 @@ stages:
       inputs:
         ArtifactName: onnxruntime_training_cpu

-    - task: CmdLine@2
-      condition: succeeded()
-      displayName: 'Upload wheel'
-      inputs:
-        script: |
-          files=($(Build.ArtifactStagingDirectory)/Release/dist/*.whl) && \
-          echo ${files[0]} && \
-          echo ${{ variables['finalStorage'] }} && \
-          tools/ci_build/upload_python_package_to_azure_storage.py \
-            --python_wheel_path ${files[0]} ${{ variables['finalStorage'] }}
-
     - template: templates/component-governance-component-detection-steps.yml
       parameters:
         condition: 'succeeded'
diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml
index 74007d9b55084..21cd3a44e8924 100644
--- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml
@@ -495,7 +495,7 @@ stages:
       PackageType: 'nuget'
       PackagePath: '$(Build.ArtifactStagingDirectory)'
       PackageName: 'Microsoft.ML.OnnxRuntime.*nupkg'
-      PlatformsSupported: 'win-x64,win-x86,linux-x64,linux-arm64,osx.10.14-x64'
+      PlatformsSupported: 'win-x64,win-x86,linux-x64,linux-arm64,osx-x64'
       VerifyNugetSigning: false

   - task: PublishPipelineArtifact@0
diff --git a/tools/ci_build/github/azure-pipelines/templates/publish-nuget.yml b/tools/ci_build/github/azure-pipelines/templates/publish-nuget.yml
index 76fbf55331b07..79feae8cf517c 100644
--- a/tools/ci_build/github/azure-pipelines/templates/publish-nuget.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/publish-nuget.yml
@@ -57,7 +57,7 @@ stages:
       REM use a single .csv file to put the data
       echo os,arch,build_config,size > 
$(Build.BinariesDirectory)\binary_size_data.txt 7z.exe l -slt %%~ni.zip runtimes\linux-arm64\native\libonnxruntime.so | findstr /R /C:"^Size = [0-9]*" | for /F "tokens=3" %%a in ('more') do if not "%%a" == "" echo linux,aarch64,default,%%a >> $(Build.BinariesDirectory)\binary_size_data.txt - 7z.exe l -slt %%~ni.zip runtimes\osx.10.14-x64\native\libonnxruntime.dylib | findstr /R /C:"^Size = [0-9]*" | for /F "tokens=3" %%a in ('more') do if not "%%a" == "" echo osx,x64,default,%%a >> $(Build.BinariesDirectory)\binary_size_data.txt + 7z.exe l -slt %%~ni.zip runtimes\osx-x64\native\libonnxruntime.dylib | findstr /R /C:"^Size = [0-9]*" | for /F "tokens=3" %%a in ('more') do if not "%%a" == "" echo osx,x64,default,%%a >> $(Build.BinariesDirectory)\binary_size_data.txt 7z.exe l -slt %%~ni.zip runtimes\win-x64\native\onnxruntime.dll | findstr /R /C:"^Size = [0-9]*" | for /F "tokens=3" %%a in ('more') do if not "%%a" == "" echo win,x64,default,%%a >> $(Build.BinariesDirectory)\binary_size_data.txt 7z.exe l -slt %%~ni.zip runtimes\win-x86\native\onnxruntime.dll | findstr /R /C:"^Size = [0-9]*" | for /F "tokens=3" %%a in ('more') do if not "%%a" == "" echo win,x86,default,%%a >> $(Build.BinariesDirectory)\binary_size_data.txt ) diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index 9dc36633a553e..3aba1d0577f9c 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -67,7 +67,7 @@ def generate_file_list_for_ep(nuget_artifacts_dir, ep, files_list, include_pdbs, is_versioned_dylib = re.match(r".*[\.\d+]+\.dylib$", child_file.name) if child_file.is_file() and child_file.suffix == ".dylib" and not is_versioned_dylib: files_list.append( - '' % cpu_arch + '' % cpu_arch ) for cpu_arch in ["x64", "aarch64"]: if child.name == get_package_name("linux", cpu_arch, ep, is_training_package):