Implement CustomOp Output Type Inference function #19906

Merged · 3 commits · Mar 18, 2024
18 changes: 16 additions & 2 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -1837,14 +1837,28 @@ struct OrtApi {

/** \brief Used for custom operators, get an input of a kernel
*
* \see ::OrtCustomOp
* The function attempts to fetch the input of the kernel. If the input is optional
* and not present, the function returns success and out is set to nullptr.
*
* \param[in] context ::OrtKernelContext instance
* \param[in] index Input index. See KernelContext_GetInputCount for bounds checking.
* \param[out] out Set to a pointer to the OrtValue if the input is present, otherwise set to nullptr
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*/
ORT_API2_STATUS(KernelContext_GetInput, _In_ const OrtKernelContext* context, _In_ size_t index,
_Out_ const OrtValue** out);

/** \brief Used for custom operators, get an output of a kernel
*
* \see ::OrtCustomOp
* The function attempts to fetch the output of the kernel. If the output is optional
* and not present, the function returns success and out is set to nullptr.
*
* \param[in] context ::OrtKernelContext instance
* \param[in] index Output index. See KernelContext_GetOutputCount for bounds checking.
* \param[out] out Set to a pointer to the OrtValue if the output is present, otherwise set to nullptr
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*/
ORT_API2_STATUS(KernelContext_GetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index,
_In_ const int64_t* dim_values, size_t dim_count, _Outptr_ OrtValue** out);
4 changes: 4 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -2055,7 +2055,11 @@ struct KernelContext {
explicit KernelContext(OrtKernelContext* context);
size_t GetInputCount() const;
size_t GetOutputCount() const;
// If the input is optional and is not present, the method returns an empty ConstValue
// which can be compared to nullptr.
ConstValue GetInput(size_t index) const;
// If the output is optional and is not present, the method returns an empty UnownedValue
// which can be compared to nullptr.
UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const;
UnownedValue GetOutput(size_t index, const std::vector<int64_t>& dims) const;
void* GetGPUComputeStream() const;
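For illustration, here is a minimal sketch of how a custom op kernel could consume an optional input through the C++ wrappers above. The kernel class and the input index are illustrative only and are not part of this change; the behavior it relies on (an absent optional input yielding an empty ConstValue comparable to nullptr) is the one documented in the new comments.

#include "onnxruntime_cxx_api.h"

struct SketchKernel {
  // Hypothetical kernel whose second input (index 1) is declared INPUT_OUTPUT_OPTIONAL.
  OrtStatusPtr ComputeV2(OrtKernelContext* context) const {
    Ort::KernelContext ctx(context);

    Ort::ConstValue required = ctx.GetInput(0);  // required input, always present

    // For an absent optional input, GetInput succeeds and returns an empty ConstValue
    // that can be compared to nullptr.
    Ort::ConstValue optional = ctx.GetInput(1);
    if (optional != nullptr) {
      // ... use the optional input ...
    }

    return nullptr;  // success
  }
};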
157 changes: 109 additions & 48 deletions onnxruntime/core/session/custom_ops.cc
@@ -1040,59 +1040,120 @@ Status IsCompatible(const ONNX_NAMESPACE::OpSchema& schema, const OrtCustomOp* o
return Status::OK();
}

void InferOutputTypes(const InlinedVector<const KernelDef*>& kernel_defs,
ONNX_NAMESPACE::InferenceContext& infer_ctx) {
for (const auto& kernel_def : kernel_defs) {
// This function attempts to do its best for older custom ops (most of them) that do not have
// their own type and shape inference function. However, it falls short in some cases, and we leave
// those for the user to handle in their own inference function.
static void InferOutputTypes(const ONNX_NAMESPACE::OpSchema& schema, gsl::span<const KernelDef* const> kernel_defs,
ONNX_NAMESPACE::InferenceContext& infer_ctx) {
const auto& inputs = schema.inputs();
const auto node_input_num = infer_ctx.getNumInputs();

const KernelDef* def_selected = nullptr;
bool is_variadic_input = false;
bool is_homogeneous_input = false;
int32_t output_propagate{0};

for (size_t kernel_index = 0;
kernel_index < kernel_defs.size() && def_selected == nullptr;
++kernel_index) {
const auto* kernel_def = kernel_defs[kernel_index];
const auto& type_constraints = kernel_def->TypeConstraints();
auto num_inputs = infer_ctx.getNumInputs();
bool matched = true;
ONNXTensorElementDataType undef = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
// first, make sure there is a constraint for every input
for (size_t i = 0; i < num_inputs && matched; ++i) {
auto input_name = "Input" + std::to_string(i);
auto input_type = infer_ctx.getInputType(i);
if (input_type) {
auto elem_type = static_cast<ONNXTensorElementDataType>(input_type->tensor_type().elem_type());
auto tc_iter = type_constraints.find(input_name);
if (tc_iter != type_constraints.end()) {
if (tc_iter->second.size() > 1) {
undef = elem_type;
} else if (tc_iter->second.size() != 1 ||
tc_iter->second[0] != DataTypeImpl::TensorTypeFromONNXEnum(elem_type)) {
matched = false;
def_selected = kernel_def;

for (size_t i = 0; i < node_input_num; ++i) {
const auto input_type = infer_ctx.getInputType(i);

// Guard against variadic parameter index
const size_t schema_input_index = (i < inputs.size()) ? i : inputs.size() - 1;
const auto& param = inputs[schema_input_index];
const auto& input_name = param.GetName();
if (input_type == nullptr) {
if (param.GetOption() == ONNX_NAMESPACE::OpSchema::FormalParameterOption::Optional)
continue;

ORT_THROW("[CustomOP type inferencing error]: kernel Input: ", input_name,
" is absent, but not optional. Op : ", schema.Name());
}

is_variadic_input = (param.GetOption() == ONNX_NAMESPACE::OpSchema::FormalParameterOption::Variadic);
is_homogeneous_input = param.GetIsHomogeneous();

if (!is_variadic_input || is_homogeneous_input) {
auto hit = type_constraints.find(input_name);
if (hit != type_constraints.end()) {
const auto& types = hit->second;
// For custom ops kernel constraints are never empty
assert(!types.empty());
if (!std::any_of(types.cbegin(), types.cend(),
[input_type](const DataTypeImpl* type) {
return type->IsCompatible(*input_type);
})) {
def_selected = nullptr;
output_propagate = 0;
break;
}

// If multiple types are possible according to the constraints,
// record the last input type and use it to guess the output type in case
// the output may have different types. This works well for symmetric single input/output ops;
// otherwise we give up and let the user supply their own inference function.
if (types.size() > 1) {
output_propagate = input_type->tensor_type().elem_type();
}
} else {
matched = false;
ORT_THROW("[CustomOP type inferencing error]: no type constraint found for input: ",
input_name, " Op: ", schema.Name());
}
} else {
matched = false;
}
} // for
// next, ensure that there is a constraint for every output
auto num_outputs = infer_ctx.getNumOutputs();
for (size_t i = 0; i < num_outputs && matched; i++) {
auto output_name = "Output" + std::to_string(i);
auto tc_iter = type_constraints.find(output_name);
if (tc_iter == type_constraints.end() || tc_iter->second.size() < 1) {
matched = false;
}
}
if (matched) {
for (size_t i = 0; i < num_outputs; i++) {
auto output_name = "Output" + std::to_string(i);
auto output_type = infer_ctx.getOutputType(i);
auto tc_iter = type_constraints.find(output_name);
if (tc_iter->second.size() > 1) {
output_type->mutable_tensor_type()->set_elem_type(undef);
} else {
output_type->mutable_tensor_type()->set_elem_type(
tc_iter->second[0]->GetTypeProto()->tensor_type().elem_type());
}
}
}

if (def_selected == nullptr) {
ORT_THROW("[CustomOP type inferencing error]: no kernel def matches node inputs for Op: ", schema.Name());
}

const auto& outputs = schema.outputs();
const auto node_output_num = infer_ctx.getNumOutputs();
const auto& selected_type_constraints = def_selected->TypeConstraints();

for (size_t i = 0; i < node_output_num; ++i) {
auto output_type = infer_ctx.getOutputType(i);
// Account for variadic outputs
const size_t schema_output_index = (i < outputs.size()) ? i : outputs.size() - 1;
const auto& param = outputs[schema_output_index];
const auto& output_name = param.GetName();

const bool is_variadic_output = (param.GetOption() == ONNX_NAMESPACE::OpSchema::FormalParameterOption::Variadic);
const bool is_homogeneous = param.GetIsHomogeneous();

// We give up on variadic non-homogeneous outputs
// Let the user handle it in their inference function
if (is_variadic_output && !is_homogeneous) {
break;
}

auto hit = selected_type_constraints.find(output_name);
if (hit != selected_type_constraints.end()) {
const auto& types = hit->second;
assert(!types.empty());

if (types.size() == 1) {
// Use the constraint type
output_type->mutable_tensor_type()->set_elem_type(
types[0]->GetTypeProto()->tensor_type().elem_type());
} else if (!is_variadic_input || is_homogeneous_input) {
// If the input is not variadic (or is homogeneous) and multiple types are possible,
// guess from the last recorded input type, as this works for symmetric single input/output ops;
// otherwise give up and let the user supply their own inference function.
output_type->mutable_tensor_type()->set_elem_type(output_propagate);
}
} else {
ORT_THROW("[CustomOP type inferencing error]: no type constraint found for output: ",
output_name, " Op: ", schema.Name());
}
}
}

#endif

common::Status CreateCustomRegistry(gsl::span<OrtCustomOpDomain* const> op_domains,
@@ -1152,13 +1213,13 @@ common::Status CreateCustomRegistry(gsl::span<OrtCustomOpDomain* const> op_domai
}

std::vector<ONNX_NAMESPACE::OpSchema> schemas;
for (auto schema_iter : schema_map) {
schemas.push_back(schema_iter.second);
InlinedVector<const KernelDef*> kernel_defs = std::move(kernel_def_map[schema_iter.first]);
for (auto& [name, schema] : schema_map) {
schemas.push_back(schema);
auto infer_fn = schemas.back().GetTypeAndShapeInferenceFunction();
ONNX_NAMESPACE::InferenceFunction extended_infer_fn =
[infer_fn, kernel_defs](ONNX_NAMESPACE::InferenceContext& infer_ctx) {
InferOutputTypes(kernel_defs, infer_ctx);
[sch = schema, infer_fn = std::move(infer_fn),
kernel_defs = std::move(kernel_def_map[name])](ONNX_NAMESPACE::InferenceContext& infer_ctx) {
InferOutputTypes(sch, kernel_defs, infer_ctx);
if (infer_fn) {
infer_fn(infer_ctx);
}
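To summarize the matching strategy the new InferOutputTypes implements, below is a simplified, self-contained sketch of the idea. It uses plain std:: types only; the names, the "InputN"/"OutputN" keys, and the int element-type encoding are illustrative and deliberately omit the variadic/optional handling of the real code.

#include <algorithm>
#include <map>
#include <optional>
#include <string>
#include <vector>

// A kernel def reduced to its type constraints: formal parameter name -> allowed element types.
using TypeConstraints = std::map<std::string, std::vector<int>>;

// Pick the first kernel def whose constraints accept every node input, then derive output types:
// a single-type constraint is used directly, while an ambiguous one falls back to the last
// ambiguous input type. Returns nullopt where the real code throws or gives up.
std::optional<std::vector<int>> InferOutputTypesSketch(const std::vector<TypeConstraints>& kernel_defs,
                                                       const std::vector<int>& input_types,
                                                       size_t output_count) {
  for (const auto& def : kernel_defs) {
    int propagate = -1;  // last input type seen for a multi-type constraint
    bool matched = true;
    for (size_t i = 0; i < input_types.size() && matched; ++i) {
      const auto it = def.find("Input" + std::to_string(i));
      if (it == def.end()) return std::nullopt;  // real code: ORT_THROW, missing constraint
      const auto& allowed = it->second;
      matched = std::find(allowed.begin(), allowed.end(), input_types[i]) != allowed.end();
      if (allowed.size() > 1) propagate = input_types[i];
    }
    if (!matched) continue;  // this kernel def does not accept the node inputs; try the next one

    std::vector<int> output_types;
    for (size_t i = 0; i < output_count; ++i) {
      const auto it = def.find("Output" + std::to_string(i));
      if (it == def.end()) return std::nullopt;  // real code: ORT_THROW, missing constraint
      const auto& allowed = it->second;
      output_types.push_back(allowed.size() == 1 ? allowed[0] : propagate);
    }
    return output_types;
  }
  return std::nullopt;  // no kernel def matched
}

Selecting a whole kernel def first is what lets a custom op registered for several element types still receive a concrete output type without a user-supplied inference function, while genuinely ambiguous cases are left to the op author.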
93 changes: 91 additions & 2 deletions onnxruntime/test/framework/shape_inference_test.cc
@@ -5,13 +5,16 @@
#include <unordered_map>

#include "gtest/gtest.h"
#include "core/common/span_utils.h"
#include "core/graph/model.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "test/framework/model_builder_utils.h"
#include "test/util/include/asserts.h"
#include "test/util/include/test_utils.h"
#include "test/util/include/inference_session_wrapper.h"
#include "test/test_environment.h"

using namespace ONNX_NAMESPACE;
using namespace std;

namespace onnxruntime {
namespace test {
@@ -22,7 +25,7 @@
protected:
onnxruntime::Model model_;
int node_count_;
std::unordered_map<string, std::unique_ptr<onnxruntime::NodeArg>> name_to_arg_;
std::unordered_map<std::string, std::unique_ptr<onnxruntime::NodeArg>> name_to_arg_;

public:
ShapeInferenceTest() : model_("Test", false, DefaultLoggingManager().DefaultLogger()), node_count_(0) {}
@@ -73,5 +76,91 @@
CheckShapeEquality(InputShape(node), OutputShape(node));
}

namespace {
struct MyCustomKernelWithOptionalInput {
MyCustomKernelWithOptionalInput(const OrtKernelInfo* /*info*/) {

}

OrtStatusPtr ComputeV2(OrtKernelContext* /* context */) const {
return nullptr;
}
};

struct MyCustomOpWithOptionalInput : Ort::CustomOpBase<MyCustomOpWithOptionalInput,
MyCustomKernelWithOptionalInput,
true> {
explicit MyCustomOpWithOptionalInput(const char* provider) : provider_(provider) {}

OrtStatusPtr CreateKernelV2(const OrtApi& /* api */, const OrtKernelInfo* info, void** kernel) const {
*kernel = new MyCustomKernelWithOptionalInput(info);
return nullptr;
};

const char* GetName() const { return "FooBar"; };

const char* GetExecutionProviderType() const { return provider_; };

size_t GetInputTypeCount() const { return 3; };

ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };

OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const {
// The second input (index == 1) is optional
if (index == 1)
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL;

return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
}

size_t GetOutputTypeCount() const { return 1; };

ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };

OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const {
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
}

private:
const char* provider_;
};

const ORTCHAR_T* const OPTIONAL_INPUT_CUSTOM_OP_MODEL_URI_2 = ORT_TSTR("testdata/foo_bar_2.onnx");

} // namespace

// The CustomOp output type inference function quits if it
// encounters an output that is optional and absent.
// It quits without any errors or logging. We want to make sure
// that inference proceeds for all of the outputs even in the presence of absent optional inputs.
TEST(ShapeInferenceCustomOpTest, custom_op_optional_input_inference_test) {
MyCustomOpWithOptionalInput custom_op{onnxruntime::kCpuExecutionProvider};

const auto& env = GetEnvironment();

Ort::CustomOpDomain op_domain("test");
op_domain.Add(&custom_op);

std::initializer_list<OrtCustomOpDomain*> op_domains = {static_cast<OrtCustomOpDomain*>(op_domain)};

SessionOptions sess_opts;
sess_opts.inter_op_param.thread_pool_size = 1;
sess_opts.intra_op_param.thread_pool_size = 1;

InferenceSessionWrapper session{sess_opts, env, OPTIONAL_INPUT_CUSTOM_OP_MODEL_URI_2};
ASSERT_STATUS_OK(session.AddCustomOpDomains(AsSpan(op_domains)));

ASSERT_STATUS_OK(session.Load());
ASSERT_STATUS_OK(session.Initialize());

const onnxruntime::Model& model = session.GetModel();
const auto& graph = model.MainGraph();
const auto& nodes = graph.Nodes();
for (const auto& node : nodes) {
if (node.OpType() == "FooBar") {
// check inferred shapes
const auto* node_arg = node.OutputDefs()[0];
const auto* type_proto = node_arg->TypeAsProto();
ASSERT_NE(nullptr, type_proto);
ASSERT_EQ(ONNX_NAMESPACE::TypeProto::ValueCase::kTensorType, type_proto->value_case());
ASSERT_EQ(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, type_proto->tensor_type().elem_type());
}
}
}

} // namespace test
} // namespace onnxruntime
7 changes: 4 additions & 3 deletions onnxruntime/test/shared_lib/test_inference.cc
@@ -208,7 +208,7 @@ static constexpr PATH_TYPE MODEL_WITH_CUSTOM_MODEL_METADATA = TSTR("testdata/mod
static constexpr PATH_TYPE VARIED_INPUT_CUSTOM_OP_MODEL_URI = TSTR("testdata/VariedInputCustomOp.onnx");
static constexpr PATH_TYPE VARIED_INPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/foo_3.onnx");
static constexpr PATH_TYPE OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI = TSTR("testdata/foo_bar_1.onnx");
static constexpr PATH_TYPE OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/foo_bar_2.onnx");
static constexpr PATH_TYPE OPTIONAL_INPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/foo_bar_2.onnx");
static constexpr PATH_TYPE VARIADIC_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI = TSTR("testdata/custom_op_variadic_io.onnx");
static constexpr PATH_TYPE VARIADIC_UNDEF_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI = TSTR(
"testdata/custom_op_variadic_undef_io.onnx");
@@ -1082,7 +1082,7 @@ TEST(CApiTest, invalid_variadic_input_homogeneity_custom_op) {
}
}

TEST(CApiTest, optional_input_output_custom_op_handler) {
TEST(CApiTest, optional_input_custom_op_handler) {
MyCustomOpWithOptionalInput custom_op{onnxruntime::kCpuExecutionProvider};

// `MyCustomOpFooBar` defines a custom op with at most 3 inputs, and the second input is optional.
@@ -1147,7 +1147,7 @@ TEST(CApiTest, optional_input_output_custom_op_handler) {
{
std::vector<const char*> input_names = {"X1", "X2"};
ort_inputs.erase(ort_inputs.begin() + 2); // remove the last input in the container
Ort::Session session(*ort_env, OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI_2, session_options);
Ort::Session session(*ort_env, OPTIONAL_INPUT_CUSTOM_OP_MODEL_URI_2, session_options);
auto ort_outputs = session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(),
&output_name, 1);
ASSERT_EQ(ort_outputs.size(), 1u);
@@ -1166,6 +1166,7 @@ TEST(CApiTest, optional_input_output_custom_op_handler) {
}
}
}

TEST(CApiTest, custom_op_with_attributes_handler) {
MyCustomOpWithAttributes custom_op{onnxruntime::kCpuExecutionProvider};
