
[CustomOp] Support duplicable op input and output #31535

Merged

Conversation

chenwhql
Contributor

@chenwhql chenwhql commented Mar 10, 2021

PR types

New features

PR changes

OPs

Describe

[CustomOp] Support duplicable op input and output.

This PR adds variable-length tensor inputs and outputs to custom operators, for example:

// add new param type: const std::vector<paddle::Tensor>&
std::vector<paddle::Tensor> ConcatForwardDynamicAxis(
    const std::vector<paddle::Tensor>& inputs, const paddle::Tensor& axis_t) {
  ...
}

// add new param type: std::vector<std::vector<int64_t>>
std::vector<std::vector<int64_t>> ConcatInferShapeDynamicAxis(
    std::vector<std::vector<int64_t>> input_shapes,
    std::vector<int64_t> axis_shape) {
  return {std::vector<int64_t>(input_shapes[0].size(), -1)};
}

// add new param type: std::vector<paddle::DataType>
std::vector<paddle::DataType> ConcatInferDtypeDynamicAxis(
    std::vector<paddle::DataType> input_dtypes, paddle::DataType axis_dtype) {
  return {input_dtypes[0]};
}

// the input name for a `const std::vector<paddle::Tensor>&` argument needs to be wrapped with `paddle::Vec()`
PD_BUILD_OP(custom_concat)
    .Inputs({paddle::Vec("X"), "Axis"})
    .Outputs({"Out"})
    .SetKernelFn(PD_KERNEL(ConcatForwardDynamicAxis))
    .SetInferShapeFn(PD_INFER_SHAPE(ConcatInferShapeDynamicAxis))
    .SetInferDtypeFn(PD_INFER_DTYPE(ConcatInferDtypeDynamicAxis));

// note the wrap order: paddle::Grad(paddle::Vec("X"))
PD_BUILD_GRAD_OP(custom_concat)
    .Inputs({paddle::Vec("X"), paddle::Grad("Out"), "Axis"})
    .Outputs({paddle::Grad(paddle::Vec("X"))})
    .SetKernelFn(PD_KERNEL(ConcatBackwardDynamicAxis));

@chenwhql chenwhql requested review from Aurelius84, zhwesky2010 and JiabinYang and removed request for Aurelius84 and zhwesky2010 March 11, 2021 11:57
Contributor

@JiabinYang JiabinYang left a comment

LGTM


// for std::vector<Tensor> input
Contributor

This comment is too vague; please describe in detail what the code is for.

Contributor Author

Thanks, I will remove it in the next PR.

    platform::errors::NotFound(
        "Input vector<tensor> (%s) is empty.", in_name));
std::vector<paddle::Tensor> custom_vec_in;
for (size_t i = 0; i < vec_x.size(); ++i) {
Contributor

Style-wise, would a range-based `for (const auto& x : vec_x)` be better?

Contributor Author

If we used `for (const auto& x : vec_x)`, we could not tell users which tensor caused the error in the error message.
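The trade-off discussed here can be shown with a minimal, framework-free sketch (the helper name and message format are hypothetical, not Paddle's actual enforcement code): indexed iteration lets the diagnostic name the offending input.

```cpp
#include <cstdio>
#include <vector>

// Returns false and reports WHICH input is empty; a range-based for loop
// would need a separate counter to produce the same diagnostic.
bool CheckAllNonEmpty(const std::vector<std::vector<float>>& vec_x) {
  for (size_t i = 0; i < vec_x.size(); ++i) {
    if (vec_x[i].empty()) {
      std::fprintf(stderr, "The %zu-th tensor of the input vector is empty.\n", i);
      return false;
    }
  }
  return true;
}
```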

template <int in_idx, int vec_in_idx, typename... PreviousArgs>
static Return InferShape(
    std::vector<std::vector<int64_t>> input_shapes,
    const PreviousArgs&... pargs) {
Contributor

We could unify the parameter types of these functions: both the InferShape and InferDtype functions could take `const&` parameters, which is more consistent with common programming conventions and intuition. This can be changed in the next PR.

Contributor Author

Thanks, I will change it in the next PR.
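For reference, the suggested `const&` parameter style might look like the sketch below (a guess at the follow-up change, not the code merged in this PR):

```cpp
#include <cstdint>
#include <vector>

// Same logic as the PR's InferShape example, but the shape metadata is
// passed by const reference so the nested vectors are not copied per call.
std::vector<std::vector<int64_t>> ConcatInferShapeDynamicAxis(
    const std::vector<std::vector<int64_t>>& input_shapes,
    const std::vector<int64_t>& axis_shape) {
  (void)axis_shape;  // unused in this sketch, mirroring the original example
  return {std::vector<int64_t>(input_shapes[0].size(), -1)};
}
```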

@chenwhql chenwhql merged commit 95cceb2 into PaddlePaddle:develop Mar 12, 2021