Skip to content

Commit

Permalink
【auto parallel】custom op spmd infer add args check (#60633)
Browse files Browse the repository at this point in the history
* add bound check

* add bound check
  • Loading branch information
liuzhenhai93 authored Jan 10, 2024
1 parent 35d445b commit 4dcb045
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions paddle/phi/api/ext/spmd_infer.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@ using CustomSpmdInferAttrArg = paddle::any;
template <typename T>
struct SpmdInferHelperTypeEnd {};

#define PD_INFER_SPMD_CHECK_INPUTS_SIZE_GT(inputs, size_bound) \
PD_CHECK(inputs.size() > size_bound, \
#inputs, \
" size must be great than ", \
size_bound)

#define PD_INFER_SPMD_CHECK_INPUTS_SIZE_EQ(inputs, size_bound) \
PD_CHECK( \
inputs.size() == size_bound, #inputs, " size must be eq ", size_bound)

#define PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(attr_type) \
template <typename... Tail> \
struct SpmdInferHelper<attr_type, Tail...> { \
Expand All @@ -35,6 +45,7 @@ struct SpmdInferHelperTypeEnd {};
const std::vector<CustomSpmdInferAttrArg>& attrs, \
const PreviousArgs&... pargs) { \
try { \
PD_INFER_SPMD_CHECK_INPUTS_SIZE_GT(attrs, attr_idx); \
attr_type arg = paddle::any_cast<attr_type>(attrs[attr_idx]); \
return SpmdInferHelper<Tail...>::template InferSpmd<in_idx, \
attr_idx + 1>( \
Expand Down Expand Up @@ -78,6 +89,8 @@ struct SpmdInferImpl<phi::distributed::SpmdInfo (*)(Args...), impl_fn> {
const std::vector<CustomSpmdInferTensorArg>& inputs,
const std::vector<CustomSpmdInferAttrArg>& attrs,
PreviousArgs&... pargs) {
static_assert(attr_idx == 0, "attributes must come after tensor inputs");
PD_INFER_SPMD_CHECK_INPUTS_SIZE_GT(inputs, in_idx);
auto& arg =
PADDLE_GET_CONST(phi::distributed::DistMetaTensor, inputs[in_idx]);
return SpmdInferHelper<Tail...>::template InferSpmd<in_idx + 1, attr_idx>(
Expand All @@ -94,6 +107,8 @@ struct SpmdInferImpl<phi::distributed::SpmdInfo (*)(Args...), impl_fn> {
const std::vector<CustomSpmdInferTensorArg>& inputs,
const std::vector<CustomSpmdInferAttrArg>& attrs,
PreviousArgs&... pargs) {
static_assert(attr_idx == 0, "attributes must come after tensor inputs");
PD_INFER_SPMD_CHECK_INPUTS_SIZE_GT(inputs, in_idx);
auto& arg = PADDLE_GET_CONST(
std::vector<phi::distributed::DistMetaTensor>, inputs[in_idx]);
return SpmdInferHelper<Tail...>::template InferSpmd<in_idx + 1, attr_idx>(
Expand Down Expand Up @@ -129,6 +144,8 @@ struct SpmdInferImpl<phi::distributed::SpmdInfo (*)(Args...), impl_fn> {
const std::vector<CustomSpmdInferTensorArg>& inputs,
const std::vector<CustomSpmdInferAttrArg>& attrs,
PreviousArgs&... pargs) {
PD_INFER_SPMD_CHECK_INPUTS_SIZE_EQ(inputs, in_idx);
PD_INFER_SPMD_CHECK_INPUTS_SIZE_EQ(attrs, attr_idx);
return impl_fn(pargs...);
}
};
Expand Down

0 comments on commit 4dcb045

Please sign in to comment.