diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs
index b374371446a90..86b44a6784817 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs
@@ -743,7 +743,7 @@ internal static OrtValue CreateFromTensorObject(TensorBase value, out TensorElem
///
/// Creates an OrtValue that contains a string tensor of specified shape, and
/// containing empty strings. String tensors are always on CPU.
- /// Use FillStringTensorElement to assign individual elements values.
+ /// Use StringTensorSetElementAt to assign values to individual elements.
///
///
/// disposable OrtValue
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs
index c52ca4d1a4631..ac790242409e3 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs
@@ -15,6 +15,7 @@ public struct OrtTrainingApi
public IntPtr LoadCheckpoint;
public IntPtr SaveCheckpoint;
public IntPtr CreateTrainingSession;
+ public IntPtr CreateTrainingSessionFromBuffer;
public IntPtr TrainingSessionGetTrainingModelOutputCount;
public IntPtr TrainingSessionGetEvalModelOutputCount;
public IntPtr TrainingSessionGetTrainingModelOutputName;
diff --git a/include/onnxruntime/core/framework/ort_value.h b/include/onnxruntime/core/framework/ort_value.h
index 48c4e4320dfd7..a071f3182faad 100644
--- a/include/onnxruntime/core/framework/ort_value.h
+++ b/include/onnxruntime/core/framework/ort_value.h
@@ -68,11 +68,7 @@ struct OrtValue {
}
bool IsSparseTensor() const {
-#if !defined(DISABLE_SPARSE_TENSORS)
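+    // Note: no longer throws in builds where sparse tensors are disabled; it simply
+    // returns false there, so callers can branch on the result and keep any #ifdef'd
+    // sparse handling on their side.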
return (type_ != nullptr && type_->IsSparseTensorType());
-#else
- ORT_THROW("Sparse tensor is not supported in this build.");
-#endif
}
onnxruntime::MLDataType Type() const {
diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md
index 4a1109b9ec5dc..e33854819c5db 100644
--- a/js/web/docs/webgpu-operators.md
+++ b/js/web/docs/webgpu-operators.md
@@ -38,7 +38,7 @@ Do not modify directly.*
| Floor | ai.onnx(6-12,13+) | |
| Gather | ai.onnx(1-10,11-12,13+) | |
| Gelu | com.microsoft(1+) | |
-| Gemm | ai.onnx(7-8,9-10,11+) | |
+| Gemm | ai.onnx(7-8,9-10,11-12,13+) | |
| GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
| GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
| InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts
index 54f493422816f..f5b8a7e3b0ef9 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts
@@ -23,10 +23,12 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
const createSplitAttributesFromInputs =
(inputs: readonly TensorView[], attributes: SplitAttributes): SplitAttributes => {
const splitSizes: number[] = [];
+ let numOutputs: number = attributes.numOutputs;
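+  // if the optional 'split' input is provided, it determines the split sizes and
+  // overrides the attribute-supplied number of outputs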
if (inputs[1].dims[0] > 0) {
inputs[1].getBigInt64Array().forEach(v => splitSizes.push(Number(v)));
+ numOutputs = splitSizes.length;
}
- return createAttributeWithCacheKey({numOutputs: attributes.numOutputs, axis: attributes.axis, splitSizes});
+ return createAttributeWithCacheKey({numOutputs, axis: attributes.axis, splitSizes});
};
const calculateOutputIndexImpl = (numberOfTensors: number): string => `
@@ -114,7 +116,7 @@ const createSplitProgramInfoLoader =
const updatedAttributes = inputs.length === 1 ? attributes : createSplitAttributesFromInputs(inputs, attributes);
const metadata:
ProgramMetadata = {name: 'Split', inputTypes: [GpuDataType.default], cacheHint: updatedAttributes.cacheKey};
- return {...metadata, get: () => createSplitProgramInfo(metadata, [inputs[0]], attributes)};
+ return {...metadata, get: () => createSplitProgramInfo(metadata, [inputs[0]], updatedAttributes)};
};
export const split = (context: ComputeContext, attributes: SplitAttributes): void => {
diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc
index 300db24a986f4..0bf27fdf5e5dc 100644
--- a/onnxruntime/core/framework/allocation_planner.cc
+++ b/onnxruntime/core/framework/allocation_planner.cc
@@ -1715,31 +1715,39 @@ class PlannerImpl {
void PartitionIntoStreams(const logging::Logger& /*logger*/,
const ExecutionProviders& /*execution_providers*/,
const PathString& /*partition_config_file*/) {
- stream_nodes_.push_back({});
- node_stream_map_.resize(SafeInt<size_t>(graph_viewer_.MaxNodeIndex()) + 1);
- for (auto node_index : graph_viewer_.GetNodesInTopologicalOrder()) {
- stream_nodes_[0].push_back(node_index);
- node_stream_map_[node_index] = 0;
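+    // a graph may legitimately have no nodes (e.g. an If branch subgraph that only
+    // returns an input or a constant), in which case there is nothing to partition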
+ if (graph_viewer_.NumberOfNodes() > 0) {
+ stream_nodes_.push_back({});
+ node_stream_map_.resize(SafeInt<size_t>(graph_viewer_.MaxNodeIndex()) + 1);
+ for (auto node_index : graph_viewer_.GetNodesInTopologicalOrder()) {
+ stream_nodes_[0].push_back(node_index);
+ node_stream_map_[node_index] = 0;
+ }
+ num_logic_streams_ = 1;
}
- num_logic_streams_ = 1;
}
Status BuildExecutionPlan(const ExecutionProviders& execution_providers) {
// 1. create logic stream instance
auto& execution_plan = plan_.execution_plan;
- ORT_ENFORCE(num_logic_streams_ == 1 && !stream_nodes_[0].empty());
- execution_plan.reserve(1);
- auto first_node_index = stream_nodes_[0][0];
- auto* node = graph_viewer_.GetNode(first_node_index);
- onnxruntime::ProviderType exec_provider_name = node->GetExecutionProviderType();
- const IExecutionProvider* ep = execution_providers.Get(exec_provider_name);
- ORT_ENFORCE(ep);
- auto node_device_mem_location = ep->GetOrtDeviceByMemType(OrtMemType::OrtMemTypeDefault);
- execution_plan.emplace_back(std::make_unique<SequentialExecutionPlan::LogicStream>(node_device_mem_location));
- // 2. add steps to the execution plan
- for (auto node_index : stream_nodes_[0]) {
- execution_plan[0]->steps_.emplace_back(std::make_unique<LaunchKernelStep>(node_index));
+
+ if (graph_viewer_.NumberOfNodes() > 0) {
+ ORT_ENFORCE(num_logic_streams_ == 1 && !stream_nodes_[0].empty());
+ execution_plan.reserve(1);
+ auto first_node_index = stream_nodes_[0][0];
+ auto* node = graph_viewer_.GetNode(first_node_index);
+ onnxruntime::ProviderType exec_provider_name = node->GetExecutionProviderType();
+ const IExecutionProvider* ep = execution_providers.Get(exec_provider_name);
+ ORT_ENFORCE(ep);
+ auto node_device_mem_location = ep->GetOrtDeviceByMemType(OrtMemType::OrtMemTypeDefault);
+ execution_plan.emplace_back(std::make_unique<SequentialExecutionPlan::LogicStream>(node_device_mem_location));
+ // 2. add steps to the execution plan
+ for (auto node_index : stream_nodes_[0]) {
+ execution_plan[0]->steps_.emplace_back(std::make_unique<LaunchKernelStep>(node_index));
+ }
+ } else {
+ // graph with no nodes. e.g. subgraph of If might return the input as-is or a constant value from an initializer
}
+
return Status::OK();
}
diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc
index fc2f14263f7a7..df3a7afebc176 100644
--- a/onnxruntime/core/framework/session_state_utils.cc
+++ b/onnxruntime/core/framework/session_state_utils.cc
@@ -254,10 +254,11 @@ common::Status SaveInitializedTensors(
auto initialized_tensors_to_allocate = id_to_initialized_tensor;
for (int ort_value_index : initializer_allocation_order) {
const auto entry = initialized_tensors_to_allocate.find(ort_value_index);
+ ORT_ENFORCE(entry != initialized_tensors_to_allocate.end(),
+ "OrtValue index: ", ort_value_index, " from initializer_allocation_order not found among initialized tensors");
if (!(utils::HasExternalData(*entry->second) && exec_plan.GetLocation(ort_value_index).Type() == OrtDevice::CPU)) {
// can not trace string tensor
- ORT_ENFORCE(entry != initialized_tensors_to_allocate.end() &&
- entry->second->data_type() != ONNX_NAMESPACE::TensorProto_DataType_STRING);
+ ORT_ENFORCE(entry->second->data_type() != ONNX_NAMESPACE::TensorProto_DataType_STRING, "Can not trace string tensor");
ORT_RETURN_IF_ERROR(planner.Trace(entry->first, entry->second));
}
initialized_tensors_to_allocate.erase(entry);
diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h
index 56f41154b719c..ea6a629f87cb8 100644
--- a/onnxruntime/core/framework/utils.h
+++ b/onnxruntime/core/framework/utils.h
@@ -223,7 +223,7 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType
int32_t ONNXTensorElementDataTypeToProtoTensorType(ONNXTensorElementDataType);
-#ifdef ENABLE_TRAINING_CORE
+#ifdef ENABLE_TRAINING
common::Status VerifyInputTensorsAllocatedContiguously(OpKernelContext* context);
#endif
diff --git a/onnxruntime/core/graph/function_utils.cc b/onnxruntime/core/graph/function_utils.cc
index 4b7900194488d..aa0727e3750b0 100644
--- a/onnxruntime/core/graph/function_utils.cc
+++ b/onnxruntime/core/graph/function_utils.cc
@@ -344,10 +344,15 @@ std::unique_ptr<ONNX_NAMESPACE::OpSchema> CreateSchema(const std::string& functi
std::unordered_map<std::string, const ONNX_NAMESPACE::FunctionProto*> map_copy(model_local_functions.begin(),
                                                                               model_local_functions.end());
std::unordered_map<std::string, ONNX_NAMESPACE::TensorShapeProto> empty_map;
- ONNX_NAMESPACE::shape_inference::SymbolTableImpl symbolTable;
+
+ // https://github.com/microsoft/onnxruntime/issues/17061
+ // We pass a nullptr for the symbol table because the symbol table must be global
+ // for all of the shape inferencing to work correctly. Otherwise, unrelated shapes get
+ // the same symbolic shapes and are marked for memory re-use. This is a temporary fix.
+ constexpr ONNX_NAMESPACE::shape_inference::SymbolTableImpl* symbolTable = nullptr;
ONNX_NAMESPACE::shape_inference::InferShapeForFunctionNode(*onnx_func_proto, func_domain_to_version,
schema_registry, ctx, options, map_copy,
- &symbolTable, &empty_map);
+ symbolTable, &empty_map);
});
op_schema->Finalize();
diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.cc b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.cc
index a3ac4312053aa..dd38ee9b07ee6 100644
--- a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.cc
+++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.cc
@@ -462,6 +462,27 @@ bool LayerNormalizationGatherActor::PreCheck(const Graph& /* graph */,
return true;
}
+bool LayerNormalizationGatherActor::PostProcess(Graph& /*graph*/, Node& current_node,
+ const SliceInfo& info_without_node,
+ const logging::Logger& /*logger*/,
+ const std::unordered_map<int, int>& /*propagate_input_indices*/,
+ const std::unordered_map<int, std::vector<DimCompare>>&
+ /*all_input_cmp_rets*/,
+ const std::unordered_map<int, SliceInfo>& /*new_gather_infos*/) {
+ // Update LayerNormalization's axis attribute if it is scalar slice.
+ if (info_without_node.is_scalar_slice) {
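+    // the scalar Gather removes one dimension ahead of the normalization axis, so the
+    // axis shifts left by one once the Gather is hoisted above this node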
+ auto axis = static_cast<int64_t>(current_node.GetAttributes().at("axis").i());
+ auto original_ln_input_rank = info_without_node.input_rank;
+ axis = axis < 0 ? axis + original_ln_input_rank : axis;
+ auto new_axis = axis - 1;
+
+ auto& attributes = current_node.GetMutableAttributes();
+ attributes["axis"] = ONNX_NAMESPACE::MakeAttribute("axis", static_cast(new_axis));
+ }
+
+ return true;
+}
+
bool SoftmaxGatherActor::PreCheck(const Graph& graph, const Node& current_node, const SliceInfo& info,
const logging::Logger& logger,
std::unordered_map<int, int>& propagate_input_indices,
@@ -479,6 +500,28 @@ bool SoftmaxGatherActor::PreCheck(const Graph& graph, const Node& current_node,
propagate_input_indices, all_input_cmp_rets, shape_update_func);
}
+bool SoftmaxGatherActor::PostProcess(Graph& graph, Node& current_node, const SliceInfo& info_without_node,
+ const logging::Logger& logger,
+ const std::unordered_map<int, int>& propagate_input_indices,
+ const std::unordered_map<int, std::vector<DimCompare>>& all_input_cmp_rets,
+ const std::unordered_map<int, SliceInfo>& new_gather_infos) {
+ SimplePointwiseGatherActor<true>::PostProcess(graph, current_node, info_without_node, logger,
+ propagate_input_indices, all_input_cmp_rets, new_gather_infos);
+
+ // Update Softmax's axis attribute if it is scalar slice.
+ if (info_without_node.is_scalar_slice) {
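+    // as with LayerNormalization above, the scalar Gather removes one dimension, so
+    // the softmax axis shifts left by one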
+ auto axis = static_cast<int64_t>(current_node.GetAttributes().at("axis").i());
+ auto original_ln_input_rank = info_without_node.input_rank;
+ axis = axis < 0 ? axis + original_ln_input_rank : axis;
+ auto new_axis = axis - 1;
+
+ auto& attributes = current_node.GetMutableAttributes();
+ attributes["axis"] = ONNX_NAMESPACE::MakeAttribute("axis", static_cast(new_axis));
+ }
+
+ return true;
+}
+
bool ReshapeGatherActor::PreCheck(const Graph& graph, const Node& current_node, const SliceInfo& info,
const logging::Logger& logger,
std::unordered_map& propagate_input_indices,
@@ -566,6 +609,11 @@ bool ReshapeGatherActor::PreCheck(const Graph& graph, const Node& current_node,
return true;
}
+ LOG_DEBUG_INFO(logger, "Skip handling the Reshape, new_shape_const_values[info.non_negative_axis]:" +
+ std::to_string(new_shape_const_values[info.non_negative_axis]) +
+ ", info.output_dim_on_axis.has_dim_value(): " +
+ std::to_string(info.output_dim_on_axis.has_dim_value()) + ".");
+
return false;
}
@@ -604,11 +652,12 @@ bool ReshapeGatherActor::PostProcess(
return true;
}
- // If it selected shape is a dim value, we can update the shape tensor directory.
+ // If the selected shape is a dim value, we can update the shape tensor directly.
if (info_without_node.output_dim_on_axis.has_dim_value()) {
new_shape_const_values[slice_axis] = info_without_node.output_dim_on_axis.dim_value();
auto new_shape_arg =
- CreateInitializerFromVector(graph, {static_cast<int64_t>(new_shape_const_values.size())}, new_shape_const_values,
+ CreateInitializerFromVector(graph, {static_cast<int64_t>(new_shape_const_values.size())},
+ new_shape_const_values,
graph.GenerateNodeArgName(current_node.MutableInputDefs()[1]->Name()));
graph_utils::ReplaceNodeInput(current_node, 1, *new_shape_arg);
return true;
diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.h b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.h
index f6715e4bb1f32..0c21be1397636 100644
--- a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.h
+++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.h
@@ -189,7 +189,7 @@ class LayerNormalizationGatherActor : public UpStreamGatherOperatorActorBase {
const logging::Logger& /* logger */,
const std::unordered_map<int, int>& /* propagate_input_indices */,
const std::unordered_map<int, std::vector<DimCompare>>& /* all_input_cmp_rets */,
- const std::unordered_map<int, SliceInfo>& /* new_gather_infos */) override { return true; }
+ const std::unordered_map<int, SliceInfo>& /* new_gather_infos */) override;
};
class SoftmaxGatherActor : public SimplePointwiseGatherActor<true> {
@@ -202,6 +202,12 @@ class SoftmaxGatherActor : public SimplePointwiseGatherActor<true> {
std::unordered_map<int, int>& propagate_input_indices,
std::unordered_map<int, std::vector<DimCompare>>& all_input_cmp_rets,
std::function<void(Node& node)>& shape_update_func) override;
+
+ bool PostProcess(Graph& /* graph */, Node& /* current_node */, const SliceInfo& /* info_without_node */,
+ const logging::Logger& /* logger */,
+ const std::unordered_map<int, int>& /* propagate_input_indices */,
+ const std::unordered_map<int, std::vector<DimCompare>>& /* all_input_cmp_rets */,
+ const std::unordered_map<int, SliceInfo>& /* new_gather_infos */) override;
};
class ReshapeGatherActor : public UpStreamGatherOperatorActorBase {
diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc
index d9c6126d4bf36..3fe1980141ca5 100644
--- a/onnxruntime/core/providers/js/js_execution_provider.cc
+++ b/onnxruntime/core/providers/js/js_execution_provider.cc
@@ -231,7 +231,8 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnn
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, ConvTranspose);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, float, Gemm);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, float, Gemm);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, Gemm);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Gemm);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Gemm);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, MatMul);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, MatMul);
@@ -464,7 +465,8 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, ConvTranspose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, float, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, float, Gemm)>,
- BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, Gemm)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Gemm)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, MatMul)>,
diff --git a/onnxruntime/core/providers/js/operators/gemm.cc b/onnxruntime/core/providers/js/operators/gemm.cc
index f579d62bdfb5f..04700d0f54705 100644
--- a/onnxruntime/core/providers/js/operators/gemm.cc
+++ b/onnxruntime/core/providers/js/operators/gemm.cc
@@ -12,7 +12,15 @@ namespace js {
ONNX_OPERATOR_TYPED_KERNEL_EX( \
Gemm, \
kOnnxDomain, \
- 11, \
+ 13, \
+ T, \
+ kJsExecutionProvider, \
+ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
+ Gemm<T>); \
+ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
+ Gemm, \
+ kOnnxDomain, \
+ 11, 12, \
T, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
diff --git a/onnxruntime/core/providers/js/operators/split.h b/onnxruntime/core/providers/js/operators/split.h
index 691af48711a56..cfacc1aa6a363 100644
--- a/onnxruntime/core/providers/js/operators/split.h
+++ b/onnxruntime/core/providers/js/operators/split.h
@@ -25,8 +25,9 @@ class Split : public JsKernel, public SplitBase {
if (num_outputs_ < 0) {
num_outputs_ = split_sizes.size();
}
- } else if (split_sizes_.size() == 0) {
- // Compute split_sizes from input shape and num_outputs
+ } else if (split_sizes_.size() == 0 && info.GetInputCount() < 2) {
+ // Compute split_sizes from the input shape and num_outputs.
+ // TODO: the shape might not be known at this point; it would be better to handle this in JavaScript
auto total_split_size = info.node().InputDefs()[0]->Shape()->dim(gsl::narrow_cast<int>(axis_)).dim_value();
int64_t split_size_sum = 0;
if (num_outputs_ < 0) {
@@ -44,6 +45,7 @@ class Split : public JsKernel, public SplitBase {
ORT_ENFORCE(split_size_sum == total_split_size,
"Sum of split sizes (", split_size_sum, ") does not match input size (", total_split_size, ")");
}
+ // else: let JavaScript handle all other cases, i.e. split_sizes arriving as input[1]
JSEP_INIT_KERNEL_ATTRIBUTE(Split, ({"axis" : $1,
"numOutputs" : $2,
diff --git a/onnxruntime/core/providers/vitisai/imp/graph.cc b/onnxruntime/core/providers/vitisai/imp/graph.cc
index 4df11c2224e27..b5f45b15a5992 100644
--- a/onnxruntime/core/providers/vitisai/imp/graph.cc
+++ b/onnxruntime/core/providers/vitisai/imp/graph.cc
@@ -138,6 +138,7 @@ void graph_save(const Graph& graph, const std::string& filename, const std::stri
model_proto = model.ToProto();
} else {
model_proto = model.ToGraphProtoWithExternalInitializers(filename_dat,
+ ToPathString(filename),
initializer_size_threshold);
}
auto& metadata = model.MetaData();
diff --git a/onnxruntime/test/framework/ort_model_only_test.cc b/onnxruntime/test/framework/ort_model_only_test.cc
index f8da4e895913a..e2cb82e47f32b 100644
--- a/onnxruntime/test/framework/ort_model_only_test.cc
+++ b/onnxruntime/test/framework/ort_model_only_test.cc
@@ -4,6 +4,7 @@
#include "core/flatbuffers/schema/ort.fbs.h"
#include "core/framework/data_types.h"
#include "core/framework/tensorprotoutils.h"
+#include "core/framework/TensorSeq.h"
#include "core/graph/model.h"
#include "core/graph/onnx_protobuf.h"
#include "core/session/onnxruntime_cxx_api.h"
@@ -556,6 +557,41 @@ TEST(OrtModelOnlyTests, LoadOrtFormatModelFromBufferNoCopyInitializersUseBuffer)
RunOrtModel(test_info);
}
+// regression test for 2 issues covered by PR #17000 (internally reported issue).
+// 1) allocation planner broke in minimal build when subgraph had no nodes.
+// 2) usage of a sequence data type caused an exception due to IsSparseTensor() throwing
+// instead of allowing the calling code to have #ifdef'd code to handle when IsSparseTensor
+// returned true and sparse tensors were disabled.
+TEST(OrtModelOnlyTests, GithubIssue17000) {
+ // need to run the model to exercise both issues
+ auto model_uri = ORT_TSTR("testdata/ort_github_issue_17000.ort");
+
+ auto allocator = TestCPUExecutionProvider()->CreatePreferredAllocators()[0];
+
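+ // build a two-element tensor sequence input; the If branch in the model erases one
+ // element and reports whether any remain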
+ OrtValue item0, item1;
+ CreateMLValue<float>(allocator, {1}, {1.f}, &item0);
+ CreateMLValue<float>(allocator, {2}, {2.f, 3.f}, &item1);
+
+ auto elem_type = DataTypeImpl::GetType<float>();
+ auto tensor_seq = std::make_unique<TensorSeq>(elem_type);
+ tensor_seq->SetElements({item0, item1});
+
+ auto mltype = DataTypeImpl::GetType<TensorSeq>();
+ OrtValue value(tensor_seq.release(), mltype, mltype->GetDeleteFunc());
+
+ OrtModelTestInfo test_info;
+ test_info.model_filename = model_uri;
+ test_info.inputs.insert(std::make_pair("seq_in", value));
+ test_info.output_names = {"still_has_elements"};
+ test_info.output_verifier = [](const std::vector<OrtValue>& fetches) {
+ const auto& output = fetches[0].Get<Tensor>();
+ ASSERT_EQ(output.Shape().Size(), 1);
+ ASSERT_EQ(output.Data<bool>()[0], true);  // removed one item from seq so should still have elements
+ };
+
+ RunOrtModel(test_info);
+}
+
#if !defined(DISABLE_ML_OPS)
// test that we can deserialize and run a previously saved ORT format model
// for a model with sequence and map outputs
diff --git a/onnxruntime/test/optimizer/compute_optimizer_test.cc b/onnxruntime/test/optimizer/compute_optimizer_test.cc
index 01016774288e4..a03d0da2538d4 100644
--- a/onnxruntime/test/optimizer/compute_optimizer_test.cc
+++ b/onnxruntime/test/optimizer/compute_optimizer_test.cc
@@ -638,7 +638,8 @@ TEST(ComputeOptimizerTests, GatherMatMul_ScalarSlicingOnSecondLastDim) {
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
onnxruntime::GraphTransformerManager graph_transformation_mgr{1};
- ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<UpStreamGatherGraphTransformer>(), TransformerLevel::Level1));
+ ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<UpStreamGatherGraphTransformer>(),
+ TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger));
GraphViewer graph_viewer(graph);
@@ -737,7 +738,8 @@ TEST(ComputeOptimizerTests, GatherMatMul_SlicingOnSecondLastDim) {
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
onnxruntime::GraphTransformerManager graph_transformation_mgr{1};
- ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<UpStreamGatherGraphTransformer>(), TransformerLevel::Level1));
+ ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<UpStreamGatherGraphTransformer>(),
+ TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger));
GraphViewer graph_viewer(graph);
@@ -826,6 +828,345 @@ TEST(ComputeOptimizerTests, GatherMatMul_SlicingOnSecondLastDim) {
}
}
+/*
+Test graph includes multiple equivalent subgraphs as below.
+ graph input [2, 32, 256] (float)
+ |
+ LayerNormalization[axis=-1 (as example)]
+ |
+ [2, 32, 256]
+ |
+ | 0 (scalar)
+ | /
+ Gather[axis=1]
+ |
+ Identity
+ |
+ graph output [2, 256] (float)
+
+Add an Identity node because currently we don't allow Gather to produce a graph output.
+*/
+TEST(ComputeOptimizerTests, GatherLayerNormalization) {
+ std::vector<std::tuple<bool, int64_t, int64_t, bool>> test_config_pairs{
+ // {
+ // is_scalar_slice,
+ // ln_axis_before_propagation,
+ // expected_ln_axis_after_propagation,
+ // expected to propagate
+ // }
+ {true, 0, 0, false},
+ {true, 1, 1, false},
+ {true, 2, 1, true},
+ {true, -3, -3, false},
+ {true, -2, -2, false},
+ {true, -1, 1, true},
+ {false, 0, 0, false},
+ {false, 1, 1, false},
+ {false, 2, 2, true},
+ {false, -3, -3, false},
+ {false, -2, -2, false},
+ {false, -1, -1, true},
+ };
+
+ constexpr static int64_t gather_axis = 1;
+ constexpr static int64_t slice_data_value = 0;
+
+ for (auto p : test_config_pairs) {
+ bool is_scalar_slice = std::get<0>(p);
+ int64_t ln_axis_before = std::get<1>(p);
+ int64_t ln_axis_after = std::get<2>(p);
+ bool expected_to_propagate = std::get<3>(p);
+
+ const logging::Logger* logger = &logging::LoggingManager::DefaultLogger();
+
+ InlinedVector<int64_t> indices;
+ auto pre_graph_checker = [&indices](Graph& graph) -> Status {
+ auto op_count_pre = CountOpsInGraph(graph);
+ TEST_RETURN_IF_NOT(op_count_pre.size() == 3U);
+ TEST_RETURN_IF_NOT(op_count_pre["LayerNormalization"] == 1);
+ TEST_RETURN_IF_NOT(op_count_pre["Gather"] == 1);
+ TEST_RETURN_IF_NOT(op_count_pre["Identity"] == 1);
+
+ for (Node& node : graph.Nodes()) {
+ if (node.OpType() == "Gather") {
+ TEST_RETURN_IF_NOT(indices.empty());
+ constexpr bool require_constant = true;
+ NodeArg* initializer_node_arg = graph.GetNodeArg(node.InputDefs()[1]->Name());
+ TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *initializer_node_arg,
+ indices, require_constant));
+ }
+ }
+ return Status::OK();
+ };
+
+ auto post_graph_checker = [is_scalar_slice, ln_axis_after,
+ &indices, expected_to_propagate](Graph& graph) {
+ auto op_count_post = CountOpsInGraph(graph);
+
+ TEST_RETURN_IF_NOT(op_count_post.size() == 3U);
+ TEST_RETURN_IF_NOT(op_count_post["LayerNormalization"] == 1);
+ TEST_RETURN_IF_NOT(op_count_post["Gather"] == 1);
+ TEST_RETURN_IF_NOT(op_count_post["Identity"] == 1);
+
+ for (Node& node : graph.Nodes()) {
+ if (node.OpType() == "LayerNormalization") {
+ const auto& input_defs = node.InputDefs();
+
+ auto producer_node = graph.GetProducerNode(input_defs[0]->Name());
+ if (expected_to_propagate) {
+ TEST_RETURN_IF_NOT(producer_node != nullptr);
+ TEST_RETURN_IF_NOT(producer_node->OpType() == "Gather");
+
+ InlinedVector<int64_t> values;
+ constexpr bool require_constant = true;
+ NodeArg* initializer_node_arg = graph.GetNodeArg(producer_node->InputDefs()[1]->Name());
+ TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *initializer_node_arg,
+ values, require_constant));
+ for (size_t i = 0; i < values.size(); i++) {
+ TEST_RETURN_IF_NOT(values[i] == indices[i]);
+ }
+
+ const ONNX_NAMESPACE::TensorShapeProto* slice_out_shape = producer_node->OutputDefs()[0]->Shape();
+ TEST_RETURN_IF_NOT(slice_out_shape != nullptr);
+
+ auto& attrs = node.GetAttributes();
+ TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
+
+ auto& axis_attr = attrs.at("axis");
+ auto axis_value = (int)axis_attr.i();
+ TEST_RETURN_IF_NOT(axis_value == ln_axis_after);
+
+ if (is_scalar_slice) {
+ TEST_RETURN_IF_NOT(slice_out_shape->dim_size() == 2);
+ TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(0)) &&
+ slice_out_shape->dim(0).dim_value() == 2);
+ TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(1)) &&
+ slice_out_shape->dim(1).dim_value() == 256);
+ } else {
+ TEST_RETURN_IF_NOT(slice_out_shape->dim_size() == 3);
+ TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(0)) &&
+ slice_out_shape->dim(0).dim_value() == 2);
+ TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(1)) &&
+ slice_out_shape->dim(1).dim_value() == 1);
+ TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(2)) &&
+ slice_out_shape->dim(2).dim_value() == 256);
+ }
+
+ } else {
+ TEST_RETURN_IF_NOT(producer_node == nullptr);
+ }
+ }
+ }
+
+ return Status::OK();
+ };
+
+ auto build_test_case = [is_scalar_slice, ln_axis_before](ModelTestBuilder& builder) {
+ auto* input1_arg = builder.MakeInput<float>({{2, 32, 256}});
+ auto* input2_arg = builder.MakeInput<float>({{256}});
+ auto* input3_arg = builder.MakeInput<float>({{256}});
+ auto* ln_out = builder.MakeIntermediate();
+ builder.AddNode("LayerNormalization", {input1_arg, input2_arg, input3_arg}, {ln_out})
+ .AddAttribute("axis", ln_axis_before);
+
+ std::vector<NodeArg*> slice_inputs;
+ NodeArg* indices_initializer = nullptr;
+
+ if (is_scalar_slice) {
+ indices_initializer = builder.MakeScalarInitializer<int64_t>(slice_data_value);
+ } else {
+ indices_initializer = builder.MakeInitializer<int64_t>({1}, {slice_data_value});
+ }
+
+ slice_inputs = {ln_out, indices_initializer};
+
+ auto* gather_out = builder.MakeIntermediate();
+ builder.AddNode("Gather", slice_inputs,
+ {gather_out})
+ .AddAttribute("axis", gather_axis);
+
+ auto* identity_out = builder.MakeOutput();
+ builder.AddNode("Identity", {gather_out}, {identity_out});
+ };
+
+ std::unique_ptr<GraphTransformer> transformer = std::make_unique<UpStreamGatherGraphTransformer>();
+ ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger, std::move(transformer),
+ TransformerLevel::Level1,
+ 1, pre_graph_checker, post_graph_checker));
+ }
+}
+
+/*
+Test graph includes multiple equivalent subgraphs as below.
+ graph input [2, 4, 32, 256] (float)
+ |
+ Softmax[axis=3 (as example)]
+ |
+ [2, 4, 32, 256]
+ |
+ | 0 (scalar)
+ | /
+ Gather[axis=1]
+ |
+ Identity
+ |
+ graph output [2, 32, 256] (float)
+
+Add an Identity node because currently we don't allow Gather to produce a graph output.
+*/
+TEST(ComputeOptimizerTests, GatherSoftmax) {
+ std::vector<std::tuple<bool, int64_t, int64_t, bool>> test_config_pairs{
+ // {is_scalar_slice, softmax_axis_before_propagation,
+ // expected_softmax_axis_after_propagation, expected to propagate}
+ {true, 0, 0, false},
+ {true, 1, 1, false},
+ {true, 2, 1, true},
+ {true, 3, 2, true},
+ {true, -4, -4, false},
+ {true, -3, -3, false},
+ {true, -2, 1, true},
+ {true, -1, 2, true},
+ {false, 0, 0, false},
+ {false, 1, 1, false},
+ {false, 2, 2, true},
+ {false, 3, 3, true},
+ {false, -4, -4, false},
+ {false, -3, -3, false},
+ {false, -2, -2, true},
+ {false, -1, -1, true},
+ };
+
+ constexpr static int64_t gather_axis = 1;
+ constexpr static int64_t slice_data_value = 0;
+
+ for (auto p : test_config_pairs) {
+ bool is_scalar_slice = std::get<0>(p);
+ int64_t softmax_axis_before = std::get<1>(p);
+ int64_t softmax_axis_after = std::get<2>(p);
+ bool expected_to_propagate = std::get<3>(p);
+
+ const logging::Logger* logger = &logging::LoggingManager::DefaultLogger();
+
+ InlinedVector<int64_t> indices;
+ auto pre_graph_checker = [&indices](Graph& graph) -> Status {
+ auto op_count_pre = CountOpsInGraph(graph);
+ TEST_RETURN_IF_NOT(op_count_pre.size() == 3U);
+ TEST_RETURN_IF_NOT(op_count_pre["Softmax"] == 1);
+ TEST_RETURN_IF_NOT(op_count_pre["Gather"] == 1);
+ TEST_RETURN_IF_NOT(op_count_pre["Identity"] == 1);
+
+ for (Node& node : graph.Nodes()) {
+ if (node.OpType() == "Gather") {
+ TEST_RETURN_IF_NOT(indices.empty());
+ constexpr bool require_constant = true;
+ NodeArg* initializer_node_arg = graph.GetNodeArg(node.InputDefs()[1]->Name());
+ TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *initializer_node_arg,
+ indices, require_constant));
+ }
+ }
+ return Status::OK();
+ };
+
+ auto post_graph_checker = [is_scalar_slice, softmax_axis_after,
+ &indices, expected_to_propagate](Graph& graph) {
+ auto op_count_post = CountOpsInGraph(graph);
+
+ TEST_RETURN_IF_NOT(op_count_post.size() == 3U);
+ TEST_RETURN_IF_NOT(op_count_post["Softmax"] == 1);
+ TEST_RETURN_IF_NOT(op_count_post["Gather"] == 1);
+ TEST_RETURN_IF_NOT(op_count_post["Identity"] == 1);
+
+ for (Node& node : graph.Nodes()) {
+ if (node.OpType() == "Softmax") {
+ const auto& input_defs = node.InputDefs();
+
+ auto producer_node = graph.GetProducerNode(input_defs[0]->Name());
+ if (expected_to_propagate) {
+ TEST_RETURN_IF_NOT(producer_node != nullptr);
+ TEST_RETURN_IF_NOT(producer_node->OpType() == "Gather");
+
+ InlinedVector<int64_t> values;
+ constexpr bool require_constant = true;
+ NodeArg* initializer_node_arg = graph.GetNodeArg(producer_node->InputDefs()[1]->Name());
+ TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *initializer_node_arg, values,
+ require_constant));
+ for (size_t i = 0; i < values.size(); i++) {
+ TEST_RETURN_IF_NOT(values[i] == indices[i]);
+ }
+
+ const ONNX_NAMESPACE::TensorShapeProto* slice_out_shape = producer_node->OutputDefs()[0]->Shape();
+ TEST_RETURN_IF_NOT(slice_out_shape != nullptr);
+
+ auto& attrs = node.GetAttributes();
+ TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
+
+ auto& axis_attr = attrs.at("axis");
+ auto axis_value = (int)axis_attr.i();
+ TEST_RETURN_IF_NOT(axis_value == softmax_axis_after);
+
+ if (is_scalar_slice) {
+ TEST_RETURN_IF_NOT(slice_out_shape->dim_size() == 3);
+ TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(0)) &&
+ slice_out_shape->dim(0).dim_value() == 2);
+ TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(1)) &&
+ slice_out_shape->dim(1).dim_value() == 32);
+ TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(2)) &&
+ slice_out_shape->dim(2).dim_value() == 256);
+ } else {
+ TEST_RETURN_IF_NOT(slice_out_shape->dim_size() == 4);
+ TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(0)) &&
+ slice_out_shape->dim(0).dim_value() == 2);
+ TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(1)) &&
+ slice_out_shape->dim(1).dim_value() == 1);
+ TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(2)) &&
+ slice_out_shape->dim(2).dim_value() == 32);
+ TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(3)) &&
+ slice_out_shape->dim(3).dim_value() == 256);
+ }
+
+ } else {
+ TEST_RETURN_IF_NOT(producer_node == nullptr);
+ }
+ }
+ }
+
+ return Status::OK();
+ };
+
+ auto build_test_case = [is_scalar_slice, softmax_axis_before](ModelTestBuilder& builder) {
+ auto* input1_arg = builder.MakeInput<float>({{2, 4, 32, 256}});
+ auto* softmax_out = builder.MakeIntermediate();
+ builder.AddNode("Softmax", {input1_arg}, {softmax_out})
+ .AddAttribute("axis", softmax_axis_before);
+
+ std::vector<NodeArg*> slice_inputs;
+
+ NodeArg* indices_initializer = nullptr;
+
+ if (is_scalar_slice) {
+ indices_initializer = builder.MakeScalarInitializer<int64_t>(slice_data_value);
+ } else {
+ indices_initializer = builder.MakeInitializer<int64_t>({1}, {slice_data_value});
+ }
+
+ slice_inputs = {softmax_out, indices_initializer};
+
+ auto* gather_out = builder.MakeIntermediate();
+ builder.AddNode("Gather", slice_inputs,
+ {gather_out})
+ .AddAttribute("axis", gather_axis);
+
+ auto* identity_out = builder.MakeOutput();
+ builder.AddNode("Identity", {gather_out}, {identity_out});
+ };
+
+ std::unique_ptr<GraphTransformer> transformer = std::make_unique<UpStreamGatherGraphTransformer>();
+ ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger, std::move(transformer),
+ TransformerLevel::Level1,
+ 1, pre_graph_checker, post_graph_checker));
+ }
+}
+
TEST(ComputeOptimizerTests, GatherReshape_ScalarSlicingOnBatchDim) {
const logging::Logger* logger = &logging::LoggingManager::DefaultLogger();
auto model_uri = MODEL_FOLDER "computation_reduction/gather/gather_reshape_scalar_batch_dim.onnx";
@@ -835,7 +1176,8 @@ TEST(ComputeOptimizerTests, GatherReshape_ScalarSlicingOnBatchDim) {
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
onnxruntime::GraphTransformerManager graph_transformation_mgr{1};
- ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<UpStreamGatherGraphTransformer>(), TransformerLevel::Level1));
+ ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<UpStreamGatherGraphTransformer>(),
+ TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger));
GraphViewer graph_viewer(graph);
@@ -928,7 +1270,8 @@ TEST(ComputeOptimizerTests, GatherReshape_SlicingOnBatchDim) {
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
onnxruntime::GraphTransformerManager graph_transformation_mgr{1};
- ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<UpStreamGatherGraphTransformer>(), TransformerLevel::Level1));
+ ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<UpStreamGatherGraphTransformer>(),
+ TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger));
GraphViewer graph_viewer(graph);
diff --git a/onnxruntime/test/testdata/ort_github_issue_17000.onnx b/onnxruntime/test/testdata/ort_github_issue_17000.onnx
new file mode 100644
index 0000000000000..8320c19cb6de4
Binary files /dev/null and b/onnxruntime/test/testdata/ort_github_issue_17000.onnx differ
diff --git a/onnxruntime/test/testdata/ort_github_issue_17000.ort b/onnxruntime/test/testdata/ort_github_issue_17000.ort
new file mode 100644
index 0000000000000..08d9826dd5346
Binary files /dev/null and b/onnxruntime/test/testdata/ort_github_issue_17000.ort differ
diff --git a/onnxruntime/test/testdata/ort_github_issue_17000.py b/onnxruntime/test/testdata/ort_github_issue_17000.py
new file mode 100644
index 0000000000000..43c10f5590212
--- /dev/null
+++ b/onnxruntime/test/testdata/ort_github_issue_17000.py
@@ -0,0 +1,77 @@
+import numpy as np
+import onnx
+from onnx import TensorProto, helper, numpy_helper
+
+
+def order_repeated_field(repeated_proto, key_name, order):
+ order = list(order)
+ repeated_proto.sort(key=lambda x: order.index(getattr(x, key_name)))
+
+
+def make_node(op_type, inputs, outputs, name=None, doc_string=None, domain=None, **kwargs):
+ node = helper.make_node(op_type, inputs, outputs, name, doc_string, domain, **kwargs)
+ if doc_string == "":
+ node.doc_string = ""
+ order_repeated_field(node.attribute, "name", kwargs.keys())
+ return node
+
+
+def make_graph(*args, doc_string=None, **kwargs):
+ graph = helper.make_graph(*args, doc_string=doc_string, **kwargs)
+ if doc_string == "":
+ graph.doc_string = ""
+ return graph
+
+
+test_graph = make_graph(
+ name="test_graph",
+ # model input of a sequence type to test IsSparseTensor issue
+ inputs=[
+ helper.make_tensor_sequence_value_info("seq_in", TensorProto.FLOAT, shape=None),
+ ],
+ outputs=[
+ helper.make_tensor_value_info("still_has_elements", TensorProto.BOOL, shape=[]),
+ ],
+ initializer=[
+ numpy_helper.from_array(np.array(0, dtype="int64"), name="i0"),
+ ],
+ nodes=[
+ make_node("SequenceLength", inputs=["seq_in"], outputs=["seq_len"], name="get_seq_len"),
+ make_node("Greater", inputs=["seq_len", "i0"], outputs=["has_elements"], name="get_has_elements"),
+ # If node with one branch that has no nodes to test the allocation planner issue
+ # if sequence has elements:
+ # remove one
+ # output bool of whether it still has elements
+ # else:
+ # output false (gives us branch with no nodes)
+ make_node(
+ "If",
+ name="test_if",
+ inputs=["has_elements"],
+ outputs=["still_has_elements"],
+ then_branch=make_graph(
+ name="then",
+ inputs=[],
+ outputs=[helper.make_tensor_value_info("then_bool_out", TensorProto.BOOL, shape=[])],
+ nodes=[
+ make_node("SequenceErase", inputs=["seq_in", "i0"], outputs=["seq_less_one"]),
+ make_node("SequenceLength", inputs=["seq_less_one"], outputs=["new_seq_len"]),
+ make_node("Greater", inputs=["new_seq_len", "i0"], outputs=["then_bool_out"]),
+ ],
+ ),
+ else_branch=make_graph(
+ name="else",
+ initializer=[numpy_helper.from_array(np.array(False, dtype="bool"), name="else_bool_out")],
+ inputs=[],
+ outputs=[helper.make_tensor_value_info("else_bool_out", TensorProto.BOOL, shape=[])],
+ nodes=[],
+ ),
+ ),
+ ],
+)
+
+# Graph with Sequence operations and an If node that has a subgraph with no nodes
+model = helper.make_model(opset_imports=[helper.make_operatorsetid("ai.onnx", 14)], ir_version=7, graph=test_graph)
+
+onnx.shape_inference.infer_shapes(model, strict_mode=True)
+onnx.save(model, "ort_github_issue_17000.onnx")
diff --git a/onnxruntime/test/testdata/required_ops.config b/onnxruntime/test/testdata/required_ops.config
index ac9d46666e1b6..e70362bab4017 100644
--- a/onnxruntime/test/testdata/required_ops.config
+++ b/onnxruntime/test/testdata/required_ops.config
@@ -3,9 +3,9 @@ ai.onnx;7;Abs,Add,And,BatchNormalization,Concat,Conv,Dropout,Flatten,Foo,Gather,
ai.onnx;8;Add,Conv,Flatten,Gemm,MatMul,MaxPool,Mul,Relu,Reshape
ai.onnx;9;Abs,Add,BatchNormalization,Cast,Clip,Concat,Constant,ConstantOfShape,Conv,Div,Equal,Gather,Gemm,Identity,If,LayerNormalization,LeakyRelu,Loop,MatMul,Mul,Pow,ReduceMean,Relu,Reshape,Scan,Shape,Sigmoid,Slice,Softmax,Softsign,Sqrt,Squeeze,Sub,Tanh,Transpose,Unsqueeze
ai.onnx;10;Add,Cast,Concat,ConstantOfShape,Div,Dropout,Erf,Expand,Gather,Greater,Identity,If,LayerNormalization,Loop,MatMul,Mul,Neg,NonZero,Pow,ReduceMean,ReduceSum,Shape,Sqrt,Squeeze,Sub,Tanh,Transpose,Unsqueeze
-ai.onnx;11;Abs,Add,ArgMax,BatchNormalization,Cast,Clip,Concat,Constant,ConstantOfShape,Conv,Div,Equal,Exp,Expand,Flatten,Gather,Gemm,Identity,If,LayerNormalization,Log,Loop,MatMul,MatMulInteger,Max,Min,Mul,Neg,Pow,RandomUniform,Range,ReduceMean,ReduceSum,ReduceSumSquare,Relu,Reshape,Scan,SequenceConstruct,SequenceInsert,SequenceLength,Shape,Sigmoid,Slice,Softmax,Split,Sqrt,Squeeze,Sub,Sum,Tanh,Transpose,Unsqueeze,Where
+ai.onnx;11;Abs,Add,ArgMax,BatchNormalization,Cast,Clip,Concat,Constant,ConstantOfShape,Conv,Div,Equal,Exp,Expand,Flatten,Gather,Gemm,Identity,If,LayerNormalization,Log,Loop,MatMul,MatMulInteger,Max,Min,Mul,Neg,Pow,RandomUniform,Range,ReduceMean,ReduceSum,ReduceSumSquare,Relu,Reshape,Scan,SequenceConstruct,SequenceErase,SequenceInsert,SequenceLength,Shape,Sigmoid,Slice,Softmax,Split,Sqrt,Squeeze,Sub,Sum,Tanh,Transpose,Unsqueeze,Where
ai.onnx;12;Add,And,Cast,Concat,Constant,ConstantOfShape,Conv,CumSum,Div,Dropout,DynamicQuantizeLinear,Equal,Erf,Expand,Flatten,Gather,GatherND,Gemm,GlobalAveragePool,Greater,Identity,If,IsInf,LayerNormalization,Less,Loop,MatMul,MatMulInteger,Min,Mul,Not,Pad,Pow,RandomNormalLike,RandomUniform,ReduceMean,ReduceSum,Relu,Reshape,Shape,Slice,Softmax,SoftmaxCrossEntropyLoss,SparseSoftmaxCrossEntropy,Split,Sqrt,Squeeze,Sub,Tanh,Transpose,Unsqueeze,Where
-ai.onnx;13;Abs,Add,Cast,Concat,ConstantOfShape,Conv,DequantizeLinear,DynamicQuantizeLinear,Equal,Expand,FooBar,FooBar_Attr,Gather,Identity,LayerNormalization,MatMul,MatMulInteger,Mul,Pad,Pow,QuantizeLinear,Range,ReduceSum,Reshape,Shape,Tanh,Transpose,Unsqueeze,Where
+ai.onnx;13;Abs,Add,Cast,Concat,ConstantOfShape,Conv,DequantizeLinear,DynamicQuantizeLinear,Equal,Expand,FooBar,FooBar_Attr,Gather,Greater,Identity,If,LayerNormalization,MatMul,MatMulInteger,Mul,Pad,Pow,QuantizeLinear,Range,ReduceSum,Reshape,Shape,Tanh,Transpose,Unsqueeze,Where
ai.onnx;14;Add,ArgMax,Cast,Conv,Identity,Relu,Sigmoid,Sub
ai.onnx;314159;Add
ai.onnx.contrib;1;StringLower
diff --git a/onnxruntime/test/testdata/required_ops_and_types.config b/onnxruntime/test/testdata/required_ops_and_types.config
index 17687906d7250..41f374214747b 100644
--- a/onnxruntime/test/testdata/required_ops_and_types.config
+++ b/onnxruntime/test/testdata/required_ops_and_types.config
@@ -1,9 +1,12 @@
# required ops and types for ORT format models in testdata
-ai.onnx;1;Conv{"inputs": {"0": ["float"]}},Foo,Identity
+ai.onnx;1;Conv{"inputs": {"0": ["float"]}}
ai.onnx;5;Reshape
ai.onnx;6;Relu{"inputs": {"0": ["float"]}}
ai.onnx;7;Add{"inputs": {"0": ["float"]}},Gemm{"inputs": {"0": ["float"]}},Mul{"inputs": {"0": ["float"]}}
ai.onnx;8;MaxPool{"inputs": {"0": ["float"]}},Sum{"inputs": {"0": ["float"]}}
ai.onnx;9;Cast{"inputs": {"0": ["float"]}, "outputs": {"0": ["bool"]}}
-ai.onnx;11;ArgMax{"inputs": {"0": ["float"]}},If,Loop
+ai.onnx;10;QLinearConv{"inputs": {"0": ["uint8_t"]}}
+ai.onnx;11;ArgMax{"inputs": {"0": ["float"]}},Clip{"inputs": {"0": ["float"]}},Conv{"inputs": {"0": ["float"]}},If,Loop,SequenceErase,SequenceLength
+ai.onnx;13;DequantizeLinear{"inputs": {"0": ["int32_t", "uint8_t"]}},Greater{"inputs": {"0": ["int64_t"]}},If,QuantizeLinear{"outputs": {"0": ["uint8_t"]}}
ai.onnx.ml;1;ArrayFeatureExtractor,LinearClassifier,Normalizer,ZipMap
+test;1;Foo
diff --git a/onnxruntime/test/testdata/training_api/ort_format/checkpoint b/onnxruntime/test/testdata/training_api/ort_format/checkpoint
new file mode 100644
index 0000000000000..ab35c9ad5acde
Binary files /dev/null and b/onnxruntime/test/testdata/training_api/ort_format/checkpoint differ
diff --git a/onnxruntime/test/testdata/training_api/ort_format/eval_model.ort b/onnxruntime/test/testdata/training_api/ort_format/eval_model.ort
new file mode 100644
index 0000000000000..69b2c7e029de0
Binary files /dev/null and b/onnxruntime/test/testdata/training_api/ort_format/eval_model.ort differ
diff --git a/onnxruntime/test/testdata/training_api/ort_format/optimizer_model.ort b/onnxruntime/test/testdata/training_api/ort_format/optimizer_model.ort
new file mode 100644
index 0000000000000..88f192462362d
Binary files /dev/null and b/onnxruntime/test/testdata/training_api/ort_format/optimizer_model.ort differ
diff --git a/onnxruntime/test/testdata/training_api/ort_format/prepare_artifacts.py b/onnxruntime/test/testdata/training_api/ort_format/prepare_artifacts.py
new file mode 100644
index 0000000000000..70e8c4ac011a9
--- /dev/null
+++ b/onnxruntime/test/testdata/training_api/ort_format/prepare_artifacts.py
@@ -0,0 +1,70 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""This file is used to generate test data for ort format model tests in
+ orttraining/orttraining/test/training_api/core/training_capi_tests.cc."""
+
+import onnx
+import torch
+import torch.nn as nn
+
+from onnxruntime.training import artifacts
+
+
+class SimpleNet(nn.Module):
+ def __init__(self, input_size, hidden_size, output_size):
+ super().__init__()
+ self.fc1 = nn.Linear(input_size, hidden_size)
+ self.relu = nn.ReLU()
+ self.fc2 = nn.Linear(hidden_size, output_size)
+
+ def forward(self, x):
+ out = self.fc1(x)
+ out = self.relu(out)
+ out = self.fc2(out)
+ return out
+
+
+def model_export(pt_model, model_path, input_size):
+ # Generate random input data
+ input_data = torch.randn(32, input_size)
+ torch.onnx.export(
+ pt_model,
+ input_data,
+ model_path,
+ input_names=["input"],
+ output_names=["output"],
+ dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
+ )
+
+
+def main():
+ # Set the dimensions for input, hidden, and output layers
+ input_size = 10
+ hidden_size = 20
+ output_size = 5
+
+ # Create an instance of the neural network
+ pt_model = SimpleNet(input_size, hidden_size, output_size)
+
+ train_model_path = "simplenet_training.onnx"
+ model_export(pt_model, train_model_path, input_size)
+
+ onnx_model = onnx.load(train_model_path)
+
+ requires_grad = ["fc2.weight", "fc2.bias"]
+ frozen_params = [param.name for param in onnx_model.graph.initializer if param.name not in requires_grad]
+
+ # Generate the training artifacts.
+ artifacts.generate_artifacts(
+ onnx_model,
+ requires_grad=requires_grad,
+ frozen_params=frozen_params,
+ loss=artifacts.LossType.CrossEntropyLoss,
+ optimizer=artifacts.OptimType.AdamW,
+ ort_format=True,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/onnxruntime/test/testdata/training_api/ort_format/training_model.ort b/onnxruntime/test/testdata/training_api/ort_format/training_model.ort
new file mode 100644
index 0000000000000..94bda328a9f9f
Binary files /dev/null and b/onnxruntime/test/testdata/training_api/ort_format/training_model.ort differ
diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc
index eac17f3d4d2e8..3f3aa396e6ca0 100644
--- a/orttraining/orttraining/python/orttraining_pybind_state.cc
+++ b/orttraining/orttraining/python/orttraining_pybind_state.cc
@@ -174,10 +174,11 @@ struct PyOptimizer {
PyOptimizer(const std::string optimizer_model_uri, onnxruntime::training::api::CheckpointState* state,
std::vector<std::shared_ptr<IExecutionProvider>> providers, PySessionOptions* session_options)
: optimizer_() {
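+    // ModelIdentifiers holds (training model, eval model, optimizer model);
+    // only the optimizer model applies here, so the other slots are left empty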
+ auto model_identifiers = onnxruntime::training::api::ModelIdentifiers("", std::nullopt, optimizer_model_uri);
auto env = GetTrainingEnv().GetORTEnv();
// XXX: We hope that env will be around when optimizer needs it.
optimizer_ = std::make_shared<onnxruntime::training::api::Optimizer>(
- optimizer_model_uri, state, session_options->value, *env, providers, session_options->custom_op_domains_);
+ model_identifiers, state, session_options->value, *env, providers, session_options->custom_op_domains_);
}
std::shared_ptr<onnxruntime::training::api::Optimizer> optimizer_;
@@ -941,9 +942,10 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
OrtDevice device, PySessionOptions* session_options) {
std::vector<std::shared_ptr<IExecutionProvider>> provider = GetExecutionProvidersForTrainingApis(device);
auto env = GetTrainingEnv().GetORTEnv();
- return std::make_unique<onnxruntime::training::api::Module>(
- model_uri, state, session_options->value, *env, provider, eval_model_uri,
- session_options->custom_op_domains_);
+ auto model_identifiers = onnxruntime::training::api::ModelIdentifiers(model_uri, eval_model_uri, std::nullopt);
+ return std::make_unique<onnxruntime::training::api::Module>(model_identifiers,
+ state, session_options->value, *env, provider,
+ session_options->custom_op_domains_);
}))
.def("train_step",
[](onnxruntime::training::api::Module* model,
diff --git a/orttraining/orttraining/python/training/onnxblock/_training_graph_utils.py b/orttraining/orttraining/python/training/onnxblock/_training_graph_utils.py
index 7b24bb400b162..1213342004d48 100644
--- a/orttraining/orttraining/python/training/onnxblock/_training_graph_utils.py
+++ b/orttraining/orttraining/python/training/onnxblock/_training_graph_utils.py
@@ -2,6 +2,7 @@
# Licensed under the MIT License.
import copy
+import logging
import os
from typing import List, Optional, Set, Tuple, Union
@@ -70,13 +71,16 @@ def _move_initializers_to_inputs(model: onnx.ModelProto, initializer_names: Opti
def _gradient_model_for(
model: onnx.ModelProto,
requires_grad: Set[str],
- output_names: List[str],
loss_name: str,
options: Optional[SessionOptions] = None,
) -> onnx.ModelProto:
"""Builds the gradient graph on top of the given input forward only graph."""
- builder = GradientGraphBuilder(model.SerializeToString(), set(output_names), requires_grad, loss_name, options)
+ logging.debug(
+ "The loss output is %s. The gradient graph will be built starting from %s_grad.", loss_name, loss_name
+ )
+
+ builder = GradientGraphBuilder(model.SerializeToString(), {loss_name}, requires_grad, loss_name, options)
builder.build()
return onnx.load_from_string(builder.get_model())
@@ -123,7 +127,7 @@ def build_gradient_graph(
optimized_model = onnx.load_from_string(get_optimized_model(model.SerializeToString(), requires_grad, options))
# Assumption is that the first graph output is the loss output
- gradient_model = _gradient_model_for(optimized_model, requires_grad, output_names, output_names[0], options)
+ gradient_model = _gradient_model_for(optimized_model, requires_grad, output_names[0], options)
_reorder_outputs(gradient_model, output_names, requires_grad)
diff --git a/orttraining/orttraining/python/training/onnxblock/onnxblock.py b/orttraining/orttraining/python/training/onnxblock/onnxblock.py
index 9f90a5a0c30cd..a2922353ac70e 100644
--- a/orttraining/orttraining/python/training/onnxblock/onnxblock.py
+++ b/orttraining/orttraining/python/training/onnxblock/onnxblock.py
@@ -205,6 +205,8 @@ def __call__(self, *args, **kwargs):
model, self._requires_grad, self._frozen_params, output, accessor._GLOBAL_CUSTOM_OP_LIBRARY
)
+ logging.debug("Adding gradient accumulation nodes for training block %s", self.__class__.__name__)
+
_training_graph_utils.build_gradient_accumulation_graph(self._training_model, self._requires_grad)
accessor._GLOBAL_ACCESSOR.model.CopyFrom(self._training_model)
diff --git a/orttraining/orttraining/test/training_api/core/checkpoint_test.cc b/orttraining/orttraining/test/training_api/core/checkpoint_test.cc
index 4fa3844717ef9..1369c9c69865a 100644
--- a/orttraining/orttraining/test/training_api/core/checkpoint_test.cc
+++ b/orttraining/orttraining/test/training_api/core/checkpoint_test.cc
@@ -331,9 +331,12 @@ TEST(CheckpointApiTest, SaveOptimizerStateAsCheckpoint_ThenLoad) {
#if defined(USE_CUDA)
providers.push_back(onnxruntime::test::DefaultCudaExecutionProvider());
#endif
- auto model = std::make_unique<Module>(model_uri, &state, session_option,
+ auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri),
+ std::nullopt,
+ std::optional<std::string>(onnxruntime::ToUTF8String(optim_uri)));
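+ // a single ModelIdentifiers carries both the training and optimizer model paths,
+ // so the same instance can be used to construct the Module and the Optimizer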
+ auto model = std::make_unique<Module>(model_identifier, &state, session_option,
*env, providers);
- auto optimizer = std::make_unique<Optimizer>(optim_uri, &state, session_option,
+ auto optimizer = std::make_unique<Optimizer>(model_identifier, &state, session_option,
*env, providers);
// Remove the temporary directory if it already exists.
diff --git a/orttraining/orttraining/test/training_api/core/training_api_tests.cc b/orttraining/orttraining/test/training_api/core/training_api_tests.cc
index ec0c7a1968ba4..2170f7957e6a6 100644
--- a/orttraining/orttraining/test/training_api/core/training_api_tests.cc
+++ b/orttraining/orttraining/test/training_api/core/training_api_tests.cc
@@ -76,9 +76,12 @@ void TestModuleExport(const std::vector<std::shared_ptr<IExecutionProvider>>& pr
std::unique_ptr<Environment> env;
ASSERT_STATUS_OK(Environment::Create(nullptr, env));
+ auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(training_model_uri),
+ std::optional<std::string>(onnxruntime::ToUTF8String(eval_model_uri)),
+ std::nullopt);
auto model = std::make_unique<Module>(
- ToUTF8String(training_model_uri), &state, onnxruntime::SessionOptions(),
- *env, providers, ToUTF8String(eval_model_uri));
+ model_identifier, &state, onnxruntime::SessionOptions(),
+ *env, providers);
auto test_dir = ORT_TSTR("export_model_for_inferencing_test_dir");
if (Env::Default().FolderExists(test_dir)) {
@@ -141,7 +144,9 @@ TEST(TrainingApiTest, ModuleParametersSize) {
onnxruntime::SessionOptions session_option;
std::unique_ptr<Environment> env;
ASSERT_STATUS_OK(Environment::Create(nullptr, env));
- auto model = std::make_unique<Module>(ToUTF8String(model_uri),
+ auto model_identifiers = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri),
+ std::nullopt, std::nullopt);
+ auto model = std::make_unique<Module>(model_identifiers,
&state, session_option,
*env, std::vector<std::shared_ptr<IExecutionProvider>>());
size_t params_size = 0;
@@ -164,7 +169,10 @@ TEST(TrainingApiTest, ModuleCopyBufferToParameters) {
onnxruntime::SessionOptions session_option;
std::unique_ptr<Environment> env;
ASSERT_STATUS_OK(Environment::Create(nullptr, env));
- auto model = std::make_unique<Module>(ToUTF8String(model_uri),
+ auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri),
+ std::nullopt,
+ std::nullopt);
+ auto model = std::make_unique<Module>(model_identifier,
&state, session_option,
*env, std::vector<std::shared_ptr<IExecutionProvider>>());
int64_t params_size = static_cast<int64_t>(model->GetParametersSize());
@@ -202,7 +210,10 @@ TEST(TrainingApiTest, ModuleTrainStep) {
onnxruntime::SessionOptions session_option;
std::unique_ptr<Environment> env;
ASSERT_STATUS_OK(Environment::Create(nullptr, env));
- auto model = std::make_unique<Module>(ToUTF8String(model_uri),
+ auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri),
+ std::nullopt,
+ std::nullopt);
+ auto model = std::make_unique<Module>(model_identifier,
&state, session_option,
*env, std::vector<std::shared_ptr<IExecutionProvider>>());
ASSERT_EQ(model->GetTrainingModelOutputCount(), 1);
@@ -274,8 +285,12 @@ TEST(TrainingApiTest, OptimizerCreatedWithOptimizerCheckpointState) {
ASSERT_STATUS_OK(Environment::Create(nullptr, env));
+ auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri),
+ std::nullopt,
+ std::optional<std::string>(onnxruntime::ToUTF8String(optim_uri)));
+
std::shared_ptr<Module> model = std::make_shared<Module>(
- ToUTF8String(model_uri), &state, session_option,
+ model_identifier, &state, session_option,
*env, providers);
// Load state dict from faked optimizer checkpoint state.
@@ -285,7 +300,7 @@ TEST(TrainingApiTest, OptimizerCreatedWithOptimizerCheckpointState) {
{"momentum0", "momentum1"},
external_optimizer_checkpoint_state));
std::shared_ptr<Optimizer> optim = std::make_shared<Optimizer>(
- ToUTF8String(optim_uri), &new_state, session_option, *env, providers);
+ model_identifier, &new_state, session_option, *env, providers);
ASSERT_TRUE(optim.get() != nullptr);
}
@@ -320,8 +335,12 @@ void TestLRSchduler(const std::basic_string<ORTCHAR_T>& test_file_name,
ASSERT_STATUS_OK(Environment::Create(nullptr, env));
+ auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri),
+ std::nullopt,
+ std::optional<std::string>(onnxruntime::ToUTF8String(optim_uri)));
+
std::shared_ptr<Module> model = std::make_shared<Module>(
- ToUTF8String(model_uri), &state, session_option,
+ model_identifier, &state, session_option,
*env, providers);
OrtValue input, target;
@@ -351,7 +370,7 @@ void TestLRSchduler(const std::basic_string& test_file_name,
}
std::shared_ptr<Optimizer> optim = std::make_shared<Optimizer>(
- ToUTF8String(optim_uri), &state, session_option,
+ model_identifier, &state, session_option,
*env, providers);
// KNOWN ISSUE: LinearLRScheduler by default use optim's states to calculate the first step's learning rate.
@@ -445,11 +464,15 @@ TEST(TrainingApiTest, OptimStep) {
providers.push_back(onnxruntime::test::DefaultCudaExecutionProvider());
#endif
ASSERT_STATUS_OK(Environment::Create(nullptr, env));
+
+ auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri),
+ std::nullopt,
+ std::optional<std::string>(onnxruntime::ToUTF8String(optim_uri)));
auto model = std::make_unique<Module>(
- ToUTF8String(model_uri), &state, session_option,
+ model_identifier, &state, session_option,
*env, providers);
auto optim = std::make_unique<Optimizer>(
- ToUTF8String(optim_uri), &state, session_option,
+ model_identifier, &state, session_option,
*env, providers);
OrtValue input, target;
diff --git a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc
index e864f3b8632de..d734be8e3474b 100644
--- a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc
+++ b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc
@@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "gtest/gtest.h"
+#include "gmock/gmock.h"
#include "onnxruntime_c_api.h"
#include "onnxruntime_training_c_api.h"
@@ -16,6 +17,7 @@
namespace onnxruntime::training::test {
#define MODEL_FOLDER ORT_TSTR("testdata/training_api/")
+#define ORT_FORMAT_MODEL_FOLDER ORT_TSTR("testdata/training_api/ort_format/")
TEST(TrainingCApiTest, SaveCheckpoint) {
auto model_uri = MODEL_FOLDER "training_model.onnx";
@@ -220,4 +222,100 @@ TEST(TrainingCApiTest, RegisterCustomOps) {
ASSERT_TRUE(loss.front().IsTensor());
}
+TEST(TrainingCApiTest, LoadModelsAndCreateSession) {
+ auto model_path = MODEL_FOLDER "training_model.onnx";
+
+ Ort::Env env;
+ Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
+ Ort::TrainingSession training_session = Ort::TrainingSession(env,
+ Ort::SessionOptions(),
+ checkpoint_state,
+ model_path);
+}
+
+TEST(TrainingCApiTest, LoadModelsAndCreateSession_ORTFormat) {
+ auto train_model_path = ORT_FORMAT_MODEL_FOLDER "training_model.ort";
+ auto eval_train_model_path = ORT_FORMAT_MODEL_FOLDER "eval_model.ort";
+ auto optimizer_model_path = ORT_FORMAT_MODEL_FOLDER "optimizer_model.ort";
+
+ Ort::Env env;
+ Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(ORT_FORMAT_MODEL_FOLDER "checkpoint");
+ Ort::TrainingSession training_session = Ort::TrainingSession(env,
+ Ort::SessionOptions(),
+ checkpoint_state,
+ train_model_path,
+ eval_train_model_path,
+ optimizer_model_path);
+}
+
+TEST(TrainingCApiTest, LoadONNXModelsFromBuffer) {
+ auto model_path = MODEL_FOLDER "training_model.onnx";
+ size_t model_data_len = 0;
+ ASSERT_STATUS_OK(Env::Default().GetFileLength(model_path, model_data_len));
+ std::vector<uint8_t> train_model_data(model_data_len);
+ std::ifstream bytes_stream(model_path, std::ifstream::in | std::ifstream::binary);
+ bytes_stream.read(reinterpret_cast<char*>(train_model_data.data()), model_data_len);
+ ASSERT_TRUE(train_model_data.size() == model_data_len);
+
+ Ort::Env env;
+ Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
+ Ort::TrainingSession training_session = Ort::TrainingSession(env,
+ Ort::SessionOptions(),
+ checkpoint_state,
+ train_model_data);
+}
+
+TEST(TrainingCApiTest, LoadORTFormatModelsFromBuffer) {
+ auto train_model_path = ORT_FORMAT_MODEL_FOLDER "training_model.ort";
+ auto eval_model_path = ORT_FORMAT_MODEL_FOLDER "eval_model.ort";
+ auto optimizer_model_path = ORT_FORMAT_MODEL_FOLDER "optimizer_model.ort";
+ size_t model_data_len = 0;
+ ASSERT_STATUS_OK(Env::Default().GetFileLength(train_model_path, model_data_len));
+ std::vector<uint8_t> train_model_data(model_data_len);
+ {
+ std::ifstream bytes_stream(train_model_path, std::ifstream::in | std::ifstream::binary);
+ bytes_stream.read(reinterpret_cast<char*>(train_model_data.data()), model_data_len);
+ ASSERT_TRUE(train_model_data.size() == model_data_len);
+ }
+
+ model_data_len = 0;
+ ASSERT_STATUS_OK(Env::Default().GetFileLength(eval_model_path, model_data_len));
+ std::vector<uint8_t> eval_model_data(model_data_len);
+ {
+ std::ifstream bytes_stream(eval_model_path, std::ifstream::in | std::ifstream::binary);
+ bytes_stream.read(reinterpret_cast<char*>(eval_model_data.data()), model_data_len);
+ ASSERT_TRUE(eval_model_data.size() == model_data_len);
+ }
+
+ model_data_len = 0;
+ ASSERT_STATUS_OK(Env::Default().GetFileLength(optimizer_model_path, model_data_len));
+ std::vector<uint8_t> optimizer_model_data(model_data_len);
+ {
+ std::ifstream bytes_stream(optimizer_model_path, std::ifstream::in | std::ifstream::binary);
+ bytes_stream.read(reinterpret_cast<char*>(optimizer_model_data.data()), model_data_len);
+ ASSERT_TRUE(optimizer_model_data.size() == model_data_len);
+ }
+
+ Ort::Env env;
+ Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(ORT_FORMAT_MODEL_FOLDER "checkpoint");
+ Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(),
+ checkpoint_state, train_model_data,
+ eval_model_data, optimizer_model_data);
+}
+
+TEST(TrainingCApiTest, LoadModelsFromBufferThrows) {
+ Ort::Env env;
+ Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
+
+ try {
+ std::vector train_model_data;
+ Ort::TrainingSession training_session = Ort::TrainingSession(env,
+ Ort::SessionOptions(),
+ checkpoint_state,
+ train_model_data);
+ } catch (const std::exception& ex) {
+ ASSERT_THAT(ex.what(),
+ testing::HasSubstr("Training Session Creation failed. Train model data cannot be NULL."));
+ }
+}
} // namespace onnxruntime::training::test
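The three file-read stanzas in LoadORTFormatModelsFromBuffer above repeat the same pattern. A small helper would keep the tests shorter; a minimal sketch under the test file's existing includes (the name ReadFileIntoBuffer is ours, not part of this change):

    // Hypothetical helper: read an entire model file into a byte buffer,
    // mirroring the pattern used by the buffer-loading tests above.
    static std::vector<uint8_t> ReadFileIntoBuffer(const std::basic_string<ORTCHAR_T>& path) {
      size_t len = 0;
      ORT_THROW_IF_ERROR(Env::Default().GetFileLength(path.c_str(), len));
      std::vector<uint8_t> buffer(len);
      // MSVC accepts wide-character paths for std::ifstream; elsewhere ORTCHAR_T is char.
      std::ifstream bytes_stream(path, std::ifstream::in | std::ifstream::binary);
      bytes_stream.read(reinterpret_cast<char*>(buffer.data()), len);
      return buffer;
    }

With it, each load collapses to a single line such as auto train_model_data = ReadFileIntoBuffer(train_model_path);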
diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h
index b3042c449a50b..0af737074964d 100644
--- a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h
+++ b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h
@@ -190,7 +190,29 @@ struct OrtTrainingApi {
ORT_API2_STATUS(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options,
_Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path,
_In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path,
- _Outptr_ OrtTrainingSession** out);
+ _Outptr_result_maybenull_ OrtTrainingSession** out);
+
+ /** \brief Create a training session that can be used to begin or resume training.
+ * This API provides a way to load all the training artifacts from buffers instead of files.
+ *
+ * \param[in] env Environment to be used for the training session.
+ * \param[in] options Session options that the user can customize for this training session.
+ * \param[in] checkpoint_state Training states that the training session uses as a starting point for training.
+ * \param[in] train_model_data Buffer containing the model data to be used to perform training
+ * \param[in] train_data_length Length of the buffer containing train_model_data
+ * \param[in] eval_model_data Buffer containing the model data to be used to perform evaluation
+ * \param[in] eval_data_length Length of the buffer containing eval_model_data
+ * \param[in] optim_model_data Buffer containing the model data to be used to perform weight update
+ * \param[in] optim_data_length Length of the buffer containing optim_model_data
+ * \param[out] out Created training session.
+ *
+ */
+ ORT_API2_STATUS(CreateTrainingSessionFromBuffer, _In_ const OrtEnv* env,
+ _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state,
+ _In_ const void* train_model_data, size_t train_data_length,
+ _In_ const void* eval_model_data, size_t eval_data_length,
+ _In_ const void* optim_model_data, size_t optim_data_length,
+ _Outptr_result_maybenull_ OrtTrainingSession** out);
/// @}
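For orientation, a sketch of how this new entry point might be called through the raw C API; env, options, checkpoint_state, and the train_bytes/train_len buffer are assumed to have been created by the caller, and error handling is reduced to the minimum:

    const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
    const OrtTrainingApi* training_api = api->GetTrainingApi(ORT_API_VERSION);

    OrtTrainingSession* session = nullptr;
    OrtStatus* status = training_api->CreateTrainingSessionFromBuffer(
        env, options, checkpoint_state,
        train_bytes, train_len,  // training model, required
        nullptr, 0,              // eval model omitted
        nullptr, 0,              // optimizer model omitted
        &session);
    if (status != nullptr) {
      // Inspect api->GetErrorMessage(status) before releasing it.
      api->ReleaseStatus(status);
    }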
diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h
index 5bfdfcc74e817..0edef20ba6da8 100644
--- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h
+++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h
@@ -176,6 +176,20 @@ class TrainingSession : public detail::Base<OrtTrainingSession> {
const std::optional<std::basic_string<ORTCHAR_T>>& eval_model_path = std::nullopt,
const std::optional<std::basic_string<ORTCHAR_T>>& optimizer_model_path = std::nullopt);
+ /** \brief Create a training session that can be used to begin or resume training.
+ * This constructor allows users to load the models from buffers instead of files.
+ *
+ * \param[in] env Env to be used for the training session.
+ * \param[in] session_options SessionOptions that the user can customize for this training session.
+ * \param[in] checkpoint_state Training states that the training session uses as a starting point for training.
+ * \param[in] train_model_data Buffer containing training model data.
+ * \param[in] eval_model_data Buffer containing evaluation model data.
+ * \param[in] optim_model_data Buffer containing optimizer model (used for performing weight/parameter update).
+ *
+ */
+ TrainingSession(const Env& env, const SessionOptions& session_options, CheckpointState& checkpoint_state,
+ const std::vector<uint8_t>& train_model_data, const std::vector<uint8_t>& eval_model_data = {},
+ const std::vector<uint8_t>& optim_model_data = {});
/// @}
/// \name Implementing The Training Loop
diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h
index 393e5b01f7f85..066147708863f 100644
--- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h
+++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h
@@ -24,6 +24,23 @@ inline TrainingSession::TrainingSession(const Env& env, const SessionOptions& se
ThrowOnError(GetTrainingApi().TrainingSessionGetEvalModelOutputCount(p_, &eval_model_output_count_));
}
+inline TrainingSession::TrainingSession(const Env& env, const SessionOptions& session_options,
+ CheckpointState& checkpoint_state,
+ const std::vector<uint8_t>& train_model_data,
+ const std::vector<uint8_t>& eval_model_data,
+ const std::vector<uint8_t>& optim_model_data) {
+ ThrowOnError(GetTrainingApi().CreateTrainingSessionFromBuffer(
+ env, session_options, checkpoint_state,
+ train_model_data.data(), train_model_data.size(),
+ eval_model_data.data(), eval_model_data.size(),
+ optim_model_data.data(), optim_model_data.size(),
+ &p_));
+
+ ThrowOnError(GetTrainingApi().TrainingSessionGetTrainingModelOutputCount(p_, &training_model_output_count_));
+
+ ThrowOnError(GetTrainingApi().TrainingSessionGetEvalModelOutputCount(p_, &eval_model_output_count_));
+}
+
inline std::vector<Value> TrainingSession::TrainStep(const std::vector<Value>& input_values) {
std::vector<Value> output_values;
output_values.reserve(training_model_output_count_);
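From user code the new overload mirrors the path-based constructor. A sketch, assuming train_bytes, eval_bytes, and optim_bytes are std::vector<uint8_t> buffers the application filled earlier:

    Ort::Env env;
    Ort::CheckpointState state = Ort::CheckpointState::LoadCheckpoint(ORT_TSTR("checkpoint.ckpt"));
    Ort::TrainingSession session(env, Ort::SessionOptions(), state,
                                 train_bytes,   // required training model
                                 eval_bytes,    // optional, defaults to {}
                                 optim_bytes);  // optional, defaults to {}

An omitted buffer arrives at the C layer with size() == 0, which the null/zero-length checks in onnxruntime_training_c_api.cc below treat as "model not provided".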
diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc
index 29300bbb7e8ec..d1775e358163c 100644
--- a/orttraining/orttraining/training_api/module.cc
+++ b/orttraining/orttraining/training_api/module.cc
@@ -12,7 +12,6 @@
#include "core/graph/graph_utils.h"
#include "orttraining/training_api/checkpoint.h"
-#include "orttraining/training_api/utils.h"
using namespace onnxruntime;
@@ -150,12 +149,11 @@ Status Parameter::ResetGrad() {
return Status::OK();
}
-Module::Module(const std::string& train_model_path_or_bytes,
+Module::Module(const ModelIdentifiers& model_identifiers,
CheckpointState* state,
const onnxruntime::SessionOptions& session_options,
const Environment& env,
const std::vector<std::shared_ptr<IExecutionProvider>>& providers,
- const std::optional<std::string>& eval_model_path_or_bytes,
[[maybe_unused]] gsl::span<OrtCustomOpDomain* const> op_domains)
: state_{state} {
// Enforce weight prepacking is disabled
@@ -176,7 +174,12 @@ Module::Module(const std::string& train_model_path_or_bytes,
}
#endif
- ORT_THROW_IF_ERROR(train_sess_->Load(train_model_path_or_bytes));
+ // Load the training model
+ ORT_THROW_IF_ERROR(std::holds_alternative<std::string>(model_identifiers.train_model)
+ ? train_sess_->Load(std::get<std::string>(model_identifiers.train_model))
+ : train_sess_->Load(std::get<gsl::span<const uint8_t>>(model_identifiers.train_model).data(),
+ static_cast<int>(std::get<gsl::span<const uint8_t>>(model_identifiers.train_model).size())));
+
for (const auto& provider : providers) {
ORT_THROW_IF_ERROR(train_sess_->RegisterExecutionProvider(provider));
}
@@ -239,7 +242,6 @@ Module::Module(const std::string& train_model_path_or_bytes,
// Copy ortvalue buffer from CPU to target_device for this "param_name" (based on graph partitioning)
// Only copies data if the target device is not the same as the current device the buffer is placed on
-
OrtValue& param_data = params_iter->second->Data();
ORT_ENFORCE(param_data.IsTensor());
const Tensor& param_data_tensor = param_data.Get<Tensor>();
@@ -278,47 +280,57 @@ Module::Module(const std::string& train_model_path_or_bytes,
}
}
- if (eval_model_path_or_bytes.has_value()) {
+ if (model_identifiers.IsEvalModelAvailable()) {
eval_sess_ = std::make_unique(session_options, env);
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
if (!op_domains.empty()) {
ORT_THROW_IF_ERROR(eval_sess_->AddCustomOpDomains(op_domains));
}
#endif
-
- ORT_THROW_IF_ERROR(eval_sess_->Load(eval_model_path_or_bytes.value()));
- for (const auto& provider : providers) {
- ORT_THROW_IF_ERROR(eval_sess_->RegisterExecutionProvider(provider));
- }
- ORT_THROW_IF_ERROR(eval_sess_->Initialize());
- utils::GetGraphInputOutputNames(eval_sess_, eval_input_names_, eval_output_names_);
-
- // Eval model validation
- // We are making certain assumptions: Like the order in which parameters occur will be same between train and eval
- // graphs, and all the weights present in both graphs match.
- // TODO: Add the checks instead of making assumptions??
- InlinedVector<std::string> eval_user_input_names, eval_param_input_names;
- for (const auto& input_name : eval_input_names_) {
- if (state_->module_checkpoint_state.named_parameters.find(input_name) !=
- state_->module_checkpoint_state.named_parameters.end()) {
- // it is a parameter
- eval_param_input_names.emplace_back(input_name);
- continue;
- } else {
- // It is user input. We handle user inputs separately in the eval
- // because the eval graph might have different user inputs.
- // Eg if loss is not a part of the eval graph, it won't have
- // certain inputs like targets
- eval_user_input_names.emplace_back(input_name);
- }
+ if (std::holds_alternative<std::optional<std::string>>(model_identifiers.eval_model)) {
+ ORT_THROW_IF_ERROR(eval_sess_->Load(std::get<std::optional<std::string>>(model_identifiers.eval_model).value()));
+ } else {
+ auto model_data = std::get<gsl::span<const uint8_t>>(model_identifiers.eval_model);
+ ORT_THROW_IF_ERROR(eval_sess_->Load(model_data.data(), static_cast<int>(model_data.size())));
}
- eval_input_names_ = eval_user_input_names;
- eval_user_input_count_ = eval_user_input_names.size();
- eval_input_names_.insert(eval_input_names_.end(), eval_param_input_names.begin(), eval_param_input_names.end());
+ } else {
+ return;
+ }
- // Keep a copy of the eval model path to be able to later export the model for inferencing.
- // The inference model will be reconstructed from the eval model.
- eval_model_path_ = eval_model_path_or_bytes.value();
+ for (const auto& provider : providers) {
+ ORT_THROW_IF_ERROR(eval_sess_->RegisterExecutionProvider(provider));
+ }
+ ORT_THROW_IF_ERROR(eval_sess_->Initialize());
+ utils::GetGraphInputOutputNames(eval_sess_, eval_input_names_, eval_output_names_);
+
+ // Eval model validation
+ // We are making certain assumptions: the order in which parameters occur will be the same between the train and eval
+ // graphs, and all the weights present in both graphs match.
+ // TODO(askhade): Add the checks instead of making assumptions??
+ InlinedVector<std::string> eval_user_input_names, eval_param_input_names;
+ for (const auto& input_name : eval_input_names_) {
+ if (state_->module_checkpoint_state.named_parameters.find(input_name) !=
+ state_->module_checkpoint_state.named_parameters.end()) {
+ // it is a parameter
+ eval_param_input_names.emplace_back(input_name);
+ continue;
+ } else {
+ // It is user input. We handle user inputs separately in the eval
+ // because the eval graph might have different user inputs.
+ // Eg if loss is not a part of the eval graph, it won't have
+ // certain inputs like targets
+ eval_user_input_names.emplace_back(input_name);
+ }
+ }
+ eval_input_names_ = eval_user_input_names;
+ eval_user_input_count_ = eval_user_input_names.size();
+ eval_input_names_.insert(eval_input_names_.end(), eval_param_input_names.begin(), eval_param_input_names.end());
+
+ // Keep a copy of the eval model path to be able to later export the model for inferencing.
+ // The inference model will be reconstructed from the eval model.
+ // TODO(askhade): Find a fix to export model for inference when the eval model is loaded from a buffer.
+ if (std::holds_alternative<std::optional<std::string>>(model_identifiers.eval_model)) {
+ eval_model_path_ = std::get<std::optional<std::string>>(model_identifiers.eval_model);
}
}
@@ -486,14 +498,14 @@ Status Module::EvalStep(const std::vector<OrtValue>& inputs, std::vector<OrtVal
Status Module::ExportModelForInferencing(const std::string& inference_model_path,
gsl::span<const std::string> graph_output_names) const {
- ORT_RETURN_IF(!eval_sess_ || eval_model_path_.empty(),
+ ORT_RETURN_IF(!eval_sess_ || !eval_model_path_.has_value(),
"Eval model was not provided. Cannot export a model for inferencing.");
ONNX_NAMESPACE::ModelProto eval_model;
- ORT_THROW_IF_ERROR(Model::Load(ToPathString(eval_model_path_), eval_model));
+ ORT_THROW_IF_ERROR(Model::Load(ToPathString(eval_model_path_.value()), eval_model));
// Clone the eval mode into an inference onnxruntime::Model.
std::shared_ptr<Model> inference_model;
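The holds_alternative/get pairs above are the standard way to branch on a std::variant. A self-contained illustration of the same path-or-bytes dispatch, using std::span (C++20) in place of gsl::span so it compiles without GSL; the names are illustrative, not ORT APIs:

    #include <cstdint>
    #include <iostream>
    #include <span>
    #include <string>
    #include <variant>
    #include <vector>

    using ModelSource = std::variant<std::string, std::span<const uint8_t>>;

    void Load(const ModelSource& source) {
      if (std::holds_alternative<std::string>(source)) {
        std::cout << "load from path: " << std::get<std::string>(source) << '\n';
      } else {
        // The span alternative is a non-owning view of in-memory model bytes.
        auto bytes = std::get<std::span<const uint8_t>>(source);
        std::cout << "load " << bytes.size() << " bytes from memory\n";
      }
    }

    int main() {
      Load(ModelSource{std::string{"training_model.onnx"}});
      std::vector<uint8_t> data{0x08, 0x01};
      Load(ModelSource{std::span<const uint8_t>(data.data(), data.size())});
    }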
diff --git a/orttraining/orttraining/training_api/module.h b/orttraining/orttraining/training_api/module.h
index 9013ab22c124f..adb633343263e 100644
--- a/orttraining/orttraining/training_api/module.h
+++ b/orttraining/orttraining/training_api/module.h
@@ -3,7 +3,9 @@
#pragma once
+#include <optional>
#include "core/session/inference_session.h"
+#include "orttraining/training_api/utils.h"
namespace onnxruntime {
namespace training {
@@ -73,12 +75,12 @@ struct Module {
public:
// Initialize a module from an ORT inference session with loaded
// training ONNX model and load parameters
- Module(const std::string& train_model_path_or_bytes,
+ // The model and checkpoint state can be provided as a file path or a byte array
+ Module(const ModelIdentifiers& model_identifiers,
CheckpointState* state,
const onnxruntime::SessionOptions& session_options,
const Environment& env,
const std::vector<std::shared_ptr<IExecutionProvider>>& providers,
- const std::optional<std::string>& eval_model_path_or_bytes = std::nullopt,
gsl::span<OrtCustomOpDomain* const> op_domains = gsl::span<OrtCustomOpDomain* const>());
// Return the trainable/nontrainable parameters
@@ -159,7 +161,7 @@ struct Module {
CheckpointState* state_; // Non owning pointer to the state.
bool accumulate_gradient_ = false;
- std::string eval_model_path_;
+ std::optional<std::string> eval_model_path_;
size_t train_user_input_count_{0U};
size_t eval_user_input_count_{0U};
};
diff --git a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc
index b84009e7f3591..6693bba348648 100644
--- a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc
+++ b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc
@@ -13,6 +13,8 @@
#include "orttraining/training_api/ort_training_apis.h"
#include "orttraining/training_api/training_session.h"
+using namespace onnxruntime::training::api;
+
namespace {
std::vector<std::shared_ptr<IExecutionProvider>> CreateProviders(
@@ -26,44 +28,85 @@ std::vector<std::shared_ptr<IExecutionProvider>> CreateProviders(
return execution_providers;
}
+static OrtStatus* CreateSessionAndLoadModel(_In_ const OrtEnv* env, _In_ const OrtSessionOptions* options,
+ _Inout_ OrtCheckpointState* checkpoint_state,
+ const ModelIdentifiers& model_identifiers,
+ std::unique_ptr<TrainingSession>& train_sess) {
+ auto chkpt_state = reinterpret_cast<CheckpointState*>(checkpoint_state);
+
+ using ProvidersType = std::vector<std::shared_ptr<IExecutionProvider>>;
+ train_sess = std::make_unique<TrainingSession>(env->GetEnvironment(),
+ options == nullptr ? onnxruntime::SessionOptions() : options->value,
+ options == nullptr
+ ? ProvidersType()
+ : CreateProviders(options->provider_factories),
+ chkpt_state,
+ model_identifiers,
+ options == nullptr
+ ? gsl::span<OrtCustomOpDomain* const>()
+ : options->custom_op_domains_);
+
+ return nullptr;
+}
+
} // namespace
ORT_API_STATUS_IMPL(OrtTrainingApis::CreateTrainingSession, _In_ const OrtEnv* env,
_In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state,
_In_ const ORTCHAR_T* train_model_path, _In_ const ORTCHAR_T* eval_model_path,
- _In_ const ORTCHAR_T* optimizer_model_path, _Outptr_ OrtTrainingSession** out) {
+ _In_ const ORTCHAR_T* optimizer_model_path, _Outptr_result_maybenull_ OrtTrainingSession** out) {
API_IMPL_BEGIN
if (options != nullptr && options->value.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigUseEnvAllocators, "0") == "1") {
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Use Env Allocators is not supported for on device training.");
}
std::unique_ptr<onnxruntime::training::api::TrainingSession> train_sess;
- auto chkpt_state = reinterpret_cast<onnxruntime::training::api::CheckpointState*>(checkpoint_state);
OrtStatus* status = nullptr;
*out = nullptr;
- ORT_TRY {
- using ProvidersType = std::vector<std::shared_ptr<IExecutionProvider>>;
- train_sess = std::make_unique<onnxruntime::training::api::TrainingSession>(
- env->GetEnvironment(),
- options == nullptr ? onnxruntime::SessionOptions() : options->value,
- options == nullptr ? ProvidersType() : CreateProviders(options->provider_factories),
- chkpt_state,
- onnxruntime::training::api::ModelIdentifiers(
- onnxruntime::ToUTF8String(train_model_path),
- eval_model_path ? std::optional<std::string>(onnxruntime::ToUTF8String(eval_model_path))
- : std::nullopt,
- optimizer_model_path ? std::optional<std::string>(onnxruntime::ToUTF8String(optimizer_model_path))
- : std::nullopt),
- options == nullptr ? gsl::span<OrtCustomOpDomain* const>() : options->custom_op_domains_);
-
- *out = reinterpret_cast<OrtTrainingSession*>(train_sess.release());
- }
- ORT_CATCH(const std::exception& e) {
- ORT_HANDLE_EXCEPTION([&]() {
- status = OrtApis::CreateStatus(ORT_FAIL, e.what());
- });
- }
+ ORT_ENFORCE(train_model_path != nullptr,
+ "Train model path is required to create TrainingSession, it cannot be empty.");
+
+ auto model_identifiers = onnxruntime::training::api::ModelIdentifiers(
+ onnxruntime::ToUTF8String(train_model_path),
+ eval_model_path ? std::optional<std::string>(onnxruntime::ToUTF8String(eval_model_path))
+ : std::nullopt,
+ optimizer_model_path ? std::optional<std::string>(onnxruntime::ToUTF8String(optimizer_model_path))
+ : std::nullopt);
+
+ ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(env, options, checkpoint_state, model_identifiers, train_sess));
+ *out = reinterpret_cast<OrtTrainingSession*>(train_sess.release());
+
+ return status;
+ API_IMPL_END
+}
+
+ORT_API_STATUS_IMPL(OrtTrainingApis::CreateTrainingSessionFromBuffer, _In_ const OrtEnv* env,
+ _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state,
+ _In_ const void* train_model_data, size_t train_data_length,
+ _In_ const void* eval_model_data, size_t eval_data_length,
+ _In_ const void* optim_model_data, size_t optim_data_length,
+ _Outptr_result_maybenull_ OrtTrainingSession** out) {
+ API_IMPL_BEGIN
+ std::unique_ptr train_sess;
+ OrtStatus* status = nullptr;
+ *out = nullptr;
+ ORT_ENFORCE(train_model_data != nullptr && train_data_length != 0,
+ "Training Session Creation failed. Train model data cannot be NULL.");
+
+ auto model_identifiers = ModelIdentifiers(gsl::make_span(reinterpret_cast<const uint8_t*>(train_model_data),
+ train_data_length),
+ eval_data_length == 0 || eval_model_data == nullptr
+ ? gsl::span<const uint8_t>()
+ : gsl::make_span(reinterpret_cast<const uint8_t*>(eval_model_data),
+ eval_data_length),
+ optim_data_length == 0 || optim_model_data == nullptr
+ ? gsl::span<const uint8_t>()
+ : gsl::make_span(reinterpret_cast<const uint8_t*>(optim_model_data),
+ optim_data_length));
+
+ ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(env, options, checkpoint_state, model_identifiers, train_sess));
+ *out = reinterpret_cast<OrtTrainingSession*>(train_sess.release());
return status;
API_IMPL_END
}
@@ -523,6 +566,7 @@ static constexpr OrtTrainingApi ort_training_api = {
&OrtTrainingApis::LoadCheckpoint,
&OrtTrainingApis::SaveCheckpoint,
&OrtTrainingApis::CreateTrainingSession,
+ &OrtTrainingApis::CreateTrainingSessionFromBuffer,
&OrtTrainingApis::TrainingSessionGetTrainingModelOutputCount,
&OrtTrainingApis::TrainingSessionGetEvalModelOutputCount,
&OrtTrainingApis::TrainingSessionGetTrainingModelOutputName,
diff --git a/orttraining/orttraining/training_api/optimizer.cc b/orttraining/orttraining/training_api/optimizer.cc
index a6b82f1d50fc0..7f583ce8f6e76 100644
--- a/orttraining/orttraining/training_api/optimizer.cc
+++ b/orttraining/orttraining/training_api/optimizer.cc
@@ -61,19 +61,10 @@ Status GraphInputsAreExpected(gsl::span<std::string> actual_graph_inputs,
} // namespace
std::unique_ptr<OptimizerAlgorithmBase> OptimizerAlorithmFactory::CreateInstance(
- const std::string& optim_path, int32_t& group_count) {
+ std::shared_ptr<Model> model, int32_t& group_count) {
std::map<std::pair<std::string, std::string>, int32_t> opt_type_to_freq_map;
#if !defined(ORT_MINIMAL_BUILD)
- if (const auto optim_path_str = ToPathString(optim_path);
- fbs::utils::IsOrtFormatModel(optim_path_str)) {
- // TODO (baijumeswani): Figure out the best way to extract the optimizer type
- // from an ort format model.
- opt_type_to_freq_map[std::make_pair(kMSDomain, "AdamWOptimizer")] = 1;
- } else {
- std::shared_ptr<Model> model;
- ORT_ENFORCE(Model::Load(optim_path_str, model, nullptr,
- logging::LoggingManager::DefaultLogger())
- .IsOK());
+ if (model != nullptr) {
Graph& graph = model->MainGraph();
for (auto& node : graph.Nodes()) {
if (node.Domain() == kMSDomain && (node.OpType() == "AdamWOptimizer" || node.OpType() == "SGDOptimizerV2")) {
@@ -85,33 +76,71 @@ std::unique_ptr<OptimizerAlgorithmBase> OptimizerAlorithmFactory::CreateInstance
opt_type_to_freq_map[domain_type_pair] += 1;
}
}
- }
+ } else {
#else
- // TODO (baijumeswani): Figure out the best way to extract the optimizer type
- // from the model (either onnx model or ort format model) or from the checkpoint.
- // For now, assume that the optimizer type is AdamWOptimizer in a minimal build.
- ORT_UNUSED_PARAMETER(optim_path);
-
- opt_type_to_freq_map[std::make_pair(kMSDomain, "AdamWOptimizer")] = 1;
+ ORT_UNUSED_PARAMETER(model);
+#endif
+ // TODO(baijumeswani): Figure out the best way to extract the optimizer type
+ // from the model (either onnx model or ort format model) or from the checkpoint.
+ // For now, assume that the optimizer type is AdamWOptimizer when using ort format models.
+ opt_type_to_freq_map[std::make_pair(kMSDomain, "AdamWOptimizer")] = 1;
+#if !defined(ORT_MINIMAL_BUILD)
+ }
#endif
ORT_ENFORCE(opt_type_to_freq_map.size() == 1U, "Only support one type of optimizer algorithm, but got: " +
std::to_string(opt_type_to_freq_map.size()));
auto opt_it = opt_type_to_freq_map.begin();
+ auto& op_type = opt_it->first.second;
group_count = opt_it->second;
- auto& domain = opt_it->first.first;
- auto& type = opt_it->first.second;
+ ORT_ENFORCE(group_count == 1, "Group count can only be 1, but got: " + std::to_string(group_count));
// TODO: to support multiple groups, need to create a mapping between each group to its parameter list.
- if (domain == kMSDomain && type == "AdamWOptimizer") {
+ if (op_type == "AdamWOptimizer") {
return std::make_unique<AdamWOptimizerAlgorithm>();
- } else if (domain == kMSDomain && type == "SGDOptimizerV2") {
+ } else if (op_type == "SGDOptimizerV2") {
return std::make_unique<SGDOptimizerV2Algorithm>();
} else {
ORT_NOT_IMPLEMENTED("Not implemented for optimizer algo: " + opt_it->first.second);
}
}
+std::unique_ptr<OptimizerAlgorithmBase> OptimizerAlorithmFactory::CreateInstance(
+ const PathString& optim_path, int32_t& group_count) {
+ std::shared_ptr<Model> model = nullptr;
+#if !defined(ORT_MINIMAL_BUILD)
+ if (!fbs::utils::IsOrtFormatModel(optim_path)) {
+ ORT_ENFORCE(Model::Load(optim_path, model, nullptr,
+ logging::LoggingManager::DefaultLogger())
+ .IsOK());
+ }
+#else
+ ORT_UNUSED_PARAMETER(optim_path);
+#endif
+ return CreateInstance(model, group_count);
+}
+
+std::unique_ptr<OptimizerAlgorithmBase> OptimizerAlorithmFactory::CreateInstance(
+ const uint8_t* optim_model_data, size_t optim_model_data_len, int32_t& group_count) {
+ std::shared_ptr<Model> model = nullptr;
+#if !defined(ORT_MINIMAL_BUILD)
+ if (!fbs::utils::IsOrtFormatModelBytes(optim_model_data, static_cast<int>(optim_model_data_len))) {
+ ONNX_NAMESPACE::ModelProto model_proto;
+ ORT_ENFORCE(model_proto.ParseFromArray(optim_model_data, static_cast<int>(optim_model_data_len)) == true,
+ "Failed to load model because protobuf parsing failed.");
+
+ ORT_ENFORCE(Model::Load(std::move(model_proto), model, nullptr,
+ logging::LoggingManager::DefaultLogger(), ModelOptions(true, true))
+ .IsOK());
+ }
+#else
+ ORT_UNUSED_PARAMETER(optim_model_data);
+ ORT_UNUSED_PARAMETER(optim_model_data_len);
+#endif
+
+ return CreateInstance(model, group_count);
+}
+
Status Optimizer::GenerateMomentumNamedStates(OptimizerCheckpointState& optimizer_checkpoint_states) {
auto group_optimizer_state_it =
optimizer_checkpoint_states.group_named_optimizer_states.find(GROUP_ZERO_NAME);
@@ -200,14 +229,14 @@ Status Optimizer::ConstructInputs() {
return Status::OK();
} // namespace api
-Optimizer::Optimizer(const std::string& optim_path_or_bytes,
+Optimizer::Optimizer(const ModelIdentifiers& model_identifiers,
CheckpointState* state,
const onnxruntime::SessionOptions& session_options,
const Environment& env,
const std::vector<std::shared_ptr<IExecutionProvider>>& providers,
gsl::span<OrtCustomOpDomain* const> op_domains)
: optim_sess_(std::make_unique<InferenceSession>(session_options, env)), state_(state) {
- Initialize(optim_path_or_bytes, providers, op_domains);
+ Initialize(model_identifiers, providers, op_domains);
ORT_ENFORCE(state != nullptr, "Checkpoint state cannot be null.");
auto g_it = state_->optimizer_checkpoint_state.group_named_optimizer_states.find(GROUP_ZERO_NAME);
@@ -223,7 +252,7 @@ Optimizer::Optimizer(const std::string& optim_path_or_bytes,
}
}
-void Optimizer::Initialize(const std::string& optim_path_or_bytes,
+void Optimizer::Initialize(const ModelIdentifiers& model_identifiers,
const std::vector<std::shared_ptr<IExecutionProvider>>& providers,
[[maybe_unused]] gsl::span<OrtCustomOpDomain* const> op_domains) {
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
@@ -236,7 +265,22 @@ void Optimizer::Initialize(const std::string& optim_path_or_bytes,
ORT_THROW_IF_ERROR(optim_sess_->RegisterExecutionProvider(execution_provider));
}
- ORT_THROW_IF_ERROR(optim_sess_->Load(optim_path_or_bytes));
+ ORT_ENFORCE(model_identifiers.IsOptimizerModelAvailable(), "Optimizer model is not available.");
+
+ if (std::holds_alternative<std::optional<std::string>>(model_identifiers.optim_model)) {
+ auto optimizer_model = std::get<std::optional<std::string>>(model_identifiers.optim_model);
+ // The above call to IsOptimizerModelAvailable() ensures that optimizer_model is not nullopt
+ ORT_THROW_IF_ERROR(optim_sess_->Load(optimizer_model.value()));
+ optimizer_algo_ptr_ = OptimizerAlorithmFactory::CreateInstance(ToWideString(optimizer_model.value()), group_count_);
+ } else {
+ auto optimizer_model = std::get<gsl::span<const uint8_t>>(model_identifiers.optim_model);
+ ORT_THROW_IF_ERROR(optim_sess_->Load(optimizer_model.data(),
+ static_cast<int>(optimizer_model.size())));
+ optimizer_algo_ptr_ = OptimizerAlorithmFactory::CreateInstance(optimizer_model.data(),
+ optimizer_model.size(),
+ group_count_);
+ }
+
ORT_THROW_IF_ERROR(optim_sess_->Initialize());
// Make sure that the checkpoint state can copy tensors
@@ -244,10 +288,6 @@ void Optimizer::Initialize(const std::string& optim_path_or_bytes,
utils::GetGraphInputOutputNames(optim_sess_, input_names_, output_names_);
- optimizer_algo_ptr_ = OptimizerAlorithmFactory::CreateInstance(optim_path_or_bytes, group_count_);
- ORT_ENFORCE(group_count_ == 1, "Group count can only be 1, but got: " + std::to_string(group_count_));
- ORT_ENFORCE(optimizer_algo_ptr_, "optimizer_algo_ptr_ should not be nullptr.");
-
InlinedVector<std::string> all_input_names;
all_input_names.reserve(CommonOptimizerInputs.size() + optimizer_algo_ptr_->optimizer_states_inputs.size());
all_input_names.insert(all_input_names.end(), CommonOptimizerInputs.begin(),
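The buffer overload above relies on fbs::utils::IsOrtFormatModelBytes to tell ORT-format flatbuffers apart from ONNX protobufs before choosing a loader. Conceptually the check reads the flatbuffers file identifier ("ORTM"), which sits after the 4-byte root offset; a simplified sketch, not the real implementation (which delegates to flatbuffers::BufferHasIdentifier):

    #include <cstddef>
    #include <cstdint>
    #include <cstring>

    bool LooksLikeOrtFormat(const uint8_t* bytes, size_t num_bytes) {
      // A flatbuffer begins with a 4-byte root offset followed by an optional
      // 4-byte file identifier, so at least 8 bytes are needed to inspect it.
      return num_bytes >= 8 && std::memcmp(bytes + 4, "ORTM", 4) == 0;
    }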
diff --git a/orttraining/orttraining/training_api/optimizer.h b/orttraining/orttraining/training_api/optimizer.h
index 36ce3297fe3c4..d9bc4870bb7ed 100644
--- a/orttraining/orttraining/training_api/optimizer.h
+++ b/orttraining/orttraining/training_api/optimizer.h
@@ -64,8 +64,11 @@ struct SGDOptimizerV2Algorithm : public OptimizerAlgorithmBase {
};
struct OptimizerAlorithmFactory {
- static std::unique_ptr<OptimizerAlgorithmBase> CreateInstance(const std::string& optim_path_or_bytes,
+ static std::unique_ptr<OptimizerAlgorithmBase> CreateInstance(const PathString& optim_path,
int32_t& group_count);
+ static std::unique_ptr<OptimizerAlgorithmBase> CreateInstance(const uint8_t* optim_model_data,
+ size_t optim_model_data_len, int32_t& group_count);
+ static std::unique_ptr<OptimizerAlgorithmBase> CreateInstance(std::shared_ptr<Model> model, int32_t& group_count);
};
struct CheckpointState;
@@ -96,7 +99,7 @@ struct Optimizer {
// Initialize an optimizer module from an ORT inference session with loaded
// training ONNX model For each parameter, initialize the OptimizerState based
// on the graph input's ValueInfoProto if the parameter doesn't have it already.
- Optimizer(const std::string& optim_path_or_bytes,
+ Optimizer(const ModelIdentifiers& model_identifiers,
CheckpointState* state,
const onnxruntime::SessionOptions& session_options,
const Environment& env,
@@ -121,7 +124,7 @@ struct Optimizer {
}
private:
- void Initialize(const std::string& optim_path_or_bytes,
+ void Initialize(const ModelIdentifiers& model_identifiers,
const std::vector<std::shared_ptr<IExecutionProvider>>& providers,
gsl::span<OrtCustomOpDomain* const> op_domains);
diff --git a/orttraining/orttraining/training_api/ort_training_apis.h b/orttraining/orttraining/training_api/ort_training_apis.h
index 2b383f3b9782a..c87108957c975 100644
--- a/orttraining/orttraining/training_api/ort_training_apis.h
+++ b/orttraining/orttraining/training_api/ort_training_apis.h
@@ -8,7 +8,14 @@ ORT_API(const OrtTrainingApi*, GetTrainingApi, uint32_t version);
ORT_API_STATUS_IMPL(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options,
_Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path,
_In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path,
- _Outptr_ OrtTrainingSession** out);
+ _Outptr_result_maybenull_ OrtTrainingSession** out);
+
+ORT_API_STATUS_IMPL(CreateTrainingSessionFromBuffer, _In_ const OrtEnv* env,
+ _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state,
+ _In_ const void* train_model_data, size_t train_data_length,
+ _In_ const void* eval_model_data, size_t eval_data_length,
+ _In_ const void* optim_model_data, size_t optim_data_length,
+ _Outptr_result_maybenull_ OrtTrainingSession** out);
ORT_API_STATUS_IMPL(TrainingSessionGetTrainingModelOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
diff --git a/orttraining/orttraining/training_api/training_session.cc b/orttraining/orttraining/training_api/training_session.cc
index 6915193a8ff7c..45f0f0ddcf7f4 100644
--- a/orttraining/orttraining/training_api/training_session.cc
+++ b/orttraining/orttraining/training_api/training_session.cc
@@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "orttraining/training_api/training_session.h"
+#include "orttraining/training_api/utils.h"
namespace onnxruntime::training::api {
@@ -12,13 +13,12 @@ TrainingSession::TrainingSession(const Environment& session_env,
const ModelIdentifiers& model_identifiers,
gsl::span<OrtCustomOpDomain* const> custom_op_domains)
: state_{state},
- module_{std::make_unique<Module>(model_identifiers.train_model, state_,
- session_options, session_env, providers,
- model_identifiers.eval_model, custom_op_domains)},
- optimizer_{model_identifiers.optim_model.has_value()
+ module_{std::make_unique<Module>(model_identifiers, state_,
+ session_options, session_env, providers, custom_op_domains)},
+ optimizer_{model_identifiers.IsOptimizerModelAvailable()
? std::make_unique<Optimizer>(
- model_identifiers.optim_model.value(), state_,
- session_options, session_env, providers, custom_op_domains)
+ model_identifiers, state_,
+ session_options, session_env, providers)
: std::unique_ptr<Optimizer>()} {}
Status TrainingSession::RegisterScheduler(
diff --git a/orttraining/orttraining/training_api/training_session.h b/orttraining/orttraining/training_api/training_session.h
index 1a16acd5115f0..13b0ae79093de 100644
--- a/orttraining/orttraining/training_api/training_session.h
+++ b/orttraining/orttraining/training_api/training_session.h
@@ -3,25 +3,17 @@
#pragma once
#include "core/common/common.h"
-#include "module.h"
-#include "optimizer.h"
-#include "lr_scheduler.h"
-#include "checkpoint.h"
+#include "orttraining/training_api/module.h"
+#include "orttraining/training_api/optimizer.h"
+#include "orttraining/training_api/lr_scheduler.h"
+#include "orttraining/training_api/checkpoint.h"
+#include "orttraining/training_api/utils.h"
namespace onnxruntime {
namespace training {
namespace api {
using namespace common;
-struct ModelIdentifiers {
- const std::string train_model;
- const std::optional<std::string> eval_model, optim_model;
- ModelIdentifiers(const std::string& train_model_uri,
- const std::optional<std::string>& eval_model_uri,
- const std::optional<std::string>& optim_model_uri)
- : train_model(train_model_uri), eval_model(eval_model_uri), optim_model(optim_model_uri) {}
-};
-
// Wrapper on top of module and optimizer classes and is the only class exposed via capis
class TrainingSession {
public:
diff --git a/orttraining/orttraining/training_api/utils.h b/orttraining/orttraining/training_api/utils.h
index e856554c971ec..f16f0f947fbd5 100644
--- a/orttraining/orttraining/training_api/utils.h
+++ b/orttraining/orttraining/training_api/utils.h
@@ -10,6 +10,40 @@
namespace onnxruntime {
namespace training {
namespace api {
+
+struct ModelIdentifiers {
+ // ModelIdentifiers struct enables an easy way to store and identify the models used for training, evaluation
+ // and model updates (optimizer model).
+ // The model can be specified by a path to the model file or by a span of bytes containing the model data.
+ // Training model is required, evaluation and optimizer models are optional.
+ std::variant<std::string, gsl::span<const uint8_t>> train_model;
+ std::variant<std::optional<std::string>, gsl::span<const uint8_t>> eval_model;
+ std::variant<std::optional<std::string>, gsl::span<const uint8_t>> optim_model;
+
+ ModelIdentifiers(std::variant<std::string, gsl::span<const uint8_t>> training_model,
+ std::variant<std::optional<std::string>, gsl::span<const uint8_t>> evaluation_model,
+ std::variant<std::optional<std::string>, gsl::span<const uint8_t>