Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CoreML: Add ML Program Concat #21423

Closed
wants to merge 11 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,14 @@ void AddOperationInput(MILSpec::Operation& op, std::string_view input_name, std:
(*op.mutable_inputs())[input_name] = std::move(arg);
}

void AddOperationInputs(MILSpec::Operation& op, std::string_view input_name,
const std::vector<std::string_view>& value_names) {
MILSpec::Argument& arg = (*op.mutable_inputs())[input_name];
vraspar marked this conversation as resolved.
Show resolved Hide resolved
for (const auto& value : value_names) {
arg.mutable_arguments()->Add()->set_name(std::string(value));
}
}

void AddOperationOutput(COREML_SPEC::MILSpec::Operation& op, const NodeArg& output) {
auto& outputs = *op.mutable_outputs();
auto& output_arg = *outputs.Add();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,15 @@
void AddOperationInput(COREML_SPEC::MILSpec::Operation& op,
std::string_view input_name, std::string_view value_name);

/// <summary>
/// Add a variadic input argument to a MILSpec::Operation
/// </summary>
/// <param name="op">Operation to update.</param>
/// <param name="input name">The input name defined by the spec for the operation. </param>
/// <param name="value_names">The input value names.</param>
void AddOperationInputs(COREML_SPEC::MILSpec::Operation& op, std::string_view input_name,
const std::vector<std::string_view>& value_names);

Check warning on line 139 in onnxruntime/core/providers/coreml/builders/impl/builder_utils.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <vector> for vector<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/coreml/builders/impl/builder_utils.h:139: Add #include <vector> for vector<> [build/include_what_you_use] [4]

/// <summary>
/// Add an output to a MILSpec::Operation. Name, data type and shape are used from the NodeArg.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "core/providers/common.h"
#include "core/providers/coreml/builders/helper.h"
#include "core/providers/coreml/builders/impl/base_op_builder.h"
#include "core/providers/coreml/builders/impl/builder_utils.h"
#include "core/providers/coreml/builders/model_builder.h"
#include "core/providers/coreml/builders/op_builder_factory.h"
#include "core/providers/coreml/shape_utils.h"
Expand All @@ -18,27 +19,52 @@

bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
const logging::Logger& logger) const override;

bool SupportsMLProgram() const override { return true; }
};

Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
const Node& node,
const logging::Logger& logger) const {
std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);

layer->mutable_concat()->set_sequenceconcat(false);

for (const auto* input : node.InputDefs()) {
LOGS(logger, VERBOSE) << "input name " << input->Name();
*layer->mutable_input()->Add() = input->Name();
#if defined(COREML_ENABLE_MLPROGRAM)
if (model_builder.CreateMLProgram()) {
using namespace CoreML::Specification::MILSpec; // NOLINT

NodeAttrHelper helper(node);
const auto axis = helper.GetInt64("axis"); // required
const auto interleave = false;

std::unique_ptr<Operation> op = model_builder.CreateOperation(node, "concat");
std::vector<std::string_view> input_names;
for (const auto* input : node.InputDefs()) {
input_names.emplace_back(input->Name());
}
AddOperationInputs(*op, "values", input_names);
AddOperationInput(*op, "axis", model_builder.AddScalarConstant(op->type(), "axis", *axis));
AddOperationInput(*op, "interleave", model_builder.AddScalarConstant(op->type(), "interleave", interleave));
AddOperationOutput(*op, *node.OutputDefs()[0]);
model_builder.AddOperation(std::move(op));

Check warning on line 47 in onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Redundant blank line at the end of a code block should be deleted. [whitespace/blank_line] [3] Raw Output: onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc:47: Redundant blank line at the end of a code block should be deleted. [whitespace/blank_line] [3]
} else

Check warning on line 48 in onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 If an else has a brace on one side, it should have it on both [readability/braces] [5] Raw Output: onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc:48: If an else has a brace on one side, it should have it on both [readability/braces] [5]
#endif // defined(COREML_ENABLE_MLPROGRAM)
{
std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);

layer->mutable_concat()->set_sequenceconcat(false);

for (const auto* input : node.InputDefs()) {
LOGS(logger, VERBOSE) << "input name " << input->Name();
*layer->mutable_input()->Add() = input->Name();
}

*layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();

model_builder.AddLayer(std::move(layer));
}

*layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();

model_builder.AddLayer(std::move(layer));
return Status::OK();
}

bool ConcatOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /* input_params */,
bool ConcatOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
if (input_defs.size() < 2) {
Expand All @@ -50,23 +76,25 @@
if (!GetShape(*input_defs[0], input_shape, logger))
return false;

auto rank = input_shape.size();
if (rank != 4) {
// For some reason, the concat in CoreML running on 3d tensor will concat on wrong axis
// Instead of concat on axis 0, it will concat on axis 1
// Disable Concat support for 3d tensor for now
// TODO, add ExpandDims and Squeeze, 3d -ExpandDims-> 4d -> Concat -Squeeze-> 3d
LOGS(logger, VERBOSE) << "Concat only support 4d shape for now, input is "
<< rank << "d shape";
return false;
}

NodeAttrHelper helper(node);
auto axis = static_cast<size_t>(HandleNegativeAxis(helper.Get("axis", 1), rank));
if (rank != axis + 3) {
LOGS(logger, VERBOSE) << "Concat only support axis to be -3, actual axis: " << axis
<< ", actual rank: " << rank;
return false;
if (!input_params.create_mlprogram) {
auto rank = input_shape.size();
if (rank != 4) {
// For some reason, the concat in CoreML running on 3d tensor will concat on wrong axis
// Instead of concat on axis 0, it will concat on axis 1
// Disable Concat support for 3d tensor for now
// TODO, add ExpandDims and Squeeze, 3d -ExpandDims-> 4d -> Concat -Squeeze-> 3d

Check warning on line 85 in onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc:85: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]

Check warning on line 85 in onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 TODO(my_username) should be followed by a space [whitespace/todo] [2] Raw Output: onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc:85: TODO(my_username) should be followed by a space [whitespace/todo] [2]
LOGS(logger, VERBOSE) << "Concat only support 4d shape for now, input is "
<< rank << "d shape";
return false;
}

NodeAttrHelper helper(node);
auto axis = static_cast<size_t>(HandleNegativeAxis(helper.Get("axis", 1), rank));
if (rank != axis + 3) {
LOGS(logger, VERBOSE) << "Concat only support axis to be -3, actual axis: " << axis
<< ", actual rank: " << rank;
return false;
}
}

return true;
Expand Down
Loading
Loading