Skip to content

Commit

Permalink
Add support tensor element type for register custom op shape infer fu…
Browse files Browse the repository at this point in the history
…nction (#21387)

### Description
Functionality extension for the SetOutputShape method in custom op shape inference.


### Motivation and Context
-  **SetOutputShape** Interface enhancement Actually, the shape infer function need set the tensor type and shape ,Add a parameter **type** to allow users to specify the tensor type, and set **ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT** as default value to ensure compatibility.

Co-authored-by: mingyue <[email protected]>
  • Loading branch information
mingyueliuh and mingyueliuh authored Jul 29, 2024
1 parent 94eb70d commit d888813
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 2 deletions.
2 changes: 1 addition & 1 deletion include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2216,7 +2216,7 @@ struct ShapeInferContext {

size_t GetInputCount() const { return input_shapes_.size(); }

Status SetOutputShape(size_t indice, const Shape& shape);
Status SetOutputShape(size_t indice, const Shape& shape, ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);

int64_t GetAttrInt(const char* attr_name);

Expand Down
3 changes: 2 additions & 1 deletion include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -1998,9 +1998,10 @@ inline ShapeInferContext::ShapeInferContext(const OrtApi* ort_api,
}
}

inline Status ShapeInferContext::SetOutputShape(size_t indice, const Shape& shape) {
inline Status ShapeInferContext::SetOutputShape(size_t indice, const Shape& shape, ONNXTensorElementDataType type) {
OrtTensorTypeAndShapeInfo* info = {};
ORT_CXX_RETURN_ON_API_FAIL(ort_api_->CreateTensorTypeAndShapeInfo(&info));
ORT_CXX_RETURN_ON_API_FAIL(ort_api_->SetTensorElementType(info, type));

using InfoPtr = std::unique_ptr<OrtTensorTypeAndShapeInfo, std::function<void(OrtTensorTypeAndShapeInfo*)>>;

Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/session/custom_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ struct OrtShapeInferContext {
}
}
ONNX_NAMESPACE::updateOutputShape(ctx_, index, shape_proto);
ONNX_NAMESPACE::updateOutputElemType(ctx_, index, info->type);
return onnxruntime::Status::OK();
}

Expand Down

0 comments on commit d888813

Please sign in to comment.