From b52a0a084375e13e30160772348a29de30583363 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Sun, 20 Oct 2024 16:04:00 +0800 Subject: [PATCH 1/5] support bmm_grad in static comp and bmm_double_grad in eager prim --- .../generator/eager_gen.py | 1 + paddle/fluid/eager/general_grad.h | 2 +- .../op_generator/vjp_interface_black_list.py | 1 + paddle/fluid/prim/api/api.yaml | 1 + .../composite_double_backward_api.h | 47 +++++++++++++ paddle/fluid/primitive/primitive.yaml | 1 + paddle/fluid/primitive/rule/vjp/details.h | 18 +++++ paddle/phi/infermeta/backward.cc | 22 ++++++ paddle/phi/infermeta/backward.h | 9 +++ paddle/phi/ops/yaml/backward.yaml | 11 +++ paddle/phi/ops/yaml/op_compat.yaml | 1 + test/legacy_test/test_bmm_op.py | 2 +- test/prim/prim/vjp/test_comp_high_grad.py | 69 +++++++++++++++++-- 13 files changed, 179 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py index 4d44714cd2d71e..fee893f86df719 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py @@ -79,6 +79,7 @@ "exp_double_grad", "log_double_grad", "where_double_grad", + "bmm_double_grad", ] # white ops list whose kernel can automatically do type promotion. diff --git a/paddle/fluid/eager/general_grad.h b/paddle/fluid/eager/general_grad.h index 297b09c639053e..614322021c4ae2 100644 --- a/paddle/fluid/eager/general_grad.h +++ b/paddle/fluid/eager/general_grad.h @@ -66,7 +66,7 @@ class GeneralGrad { PADDLE_ENFORCE_NOT_NULL( target_node, - common::errors::Fatal("There is no grad op for %s:[%d] or it's" + common::errors::Fatal("There is no grad op for %s:[%d] or it's " "stop_gradient=True.", msg, i)); diff --git a/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py b/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py index 1020b564a2d739..0281e2c574f98b 100644 --- a/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py +++ b/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py @@ -32,4 +32,5 @@ 'exp_grad', 'abs_double_grad', 'where_grad', + 'bmm_grad', ] diff --git a/paddle/fluid/prim/api/api.yaml b/paddle/fluid/prim/api/api.yaml index 33401c213366d1..450469743fd4a7 100644 --- a/paddle/fluid/prim/api/api.yaml +++ b/paddle/fluid/prim/api/api.yaml @@ -50,3 +50,4 @@ - tanh - sign - sigmoid +- bmm diff --git a/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h index 69de0a2afd68dd..9110d411ebef5d 100644 --- a/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h +++ b/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h @@ -935,5 +935,52 @@ void abs_triple_grad(const Tensor& x, } } +template +void bmm_double_grad(const Tensor& x, + const Tensor& y, + const Tensor& grad_out, + const paddle::optional& grad_x_grad, + const paddle::optional& grad_y_grad, + Tensor* x_grad, + Tensor* y_grad, + Tensor* grad_out_grad) { + if (x_grad) { + // dx' = bmm(dout, ddy.mT) + Tensor x_grad_tmp; + if (grad_x_grad) { + x_grad_tmp = bmm(grad_out, transpose(grad_y_grad.get(), {0, 2, 1})); + } else { + x_grad_tmp = full(common::vectorize(x.dims()), 0, x.dtype()); + } + set_output(x_grad_tmp, x_grad); + } + if (y_grad) { + // dy' = bmm(ddx.mT, dout) + Tensor y_grad_tmp; + if (grad_x_grad) { + y_grad_tmp = bmm(transpose(grad_x_grad.get(), {0, 2, 
1}), grad_out); + } else { + y_grad_tmp = full(common::vectorize(y.dims()), 0, y.dtype()); + } + set_output(y_grad_tmp, y_grad); + } + if (grad_out_grad) { + // ddout = bmm(ddx, y) + bmm(x, ddy) + Tensor grad_out_grad_tmp; + if (grad_x_grad && grad_y_grad) { + grad_out_grad_tmp = + bmm(grad_x_grad.get(), y) + bmm(x, grad_y_grad.get()); + } else if (grad_x_grad) { + grad_out_grad_tmp = bmm(grad_x_grad.get(), y); + } else if (grad_y_grad) { + grad_out_grad_tmp = bmm(x, grad_y_grad.get()); + } else { + grad_out_grad_tmp = + full(common::vectorize(grad_out.dims()), 0, grad_out.dtype()); + } + set_output(grad_out_grad_tmp, grad_out_grad); + } +} + } // namespace prim } // namespace paddle diff --git a/paddle/fluid/primitive/primitive.yaml b/paddle/fluid/primitive/primitive.yaml index fd08b18e3a2145..a9df5090834ecf 100644 --- a/paddle/fluid/primitive/primitive.yaml +++ b/paddle/fluid/primitive/primitive.yaml @@ -124,3 +124,4 @@ - unique_consecutive - sigmoid - reduce_as +- bmm diff --git a/paddle/fluid/primitive/rule/vjp/details.h b/paddle/fluid/primitive/rule/vjp/details.h index b04a01a273d784..51175e7dae5d4d 100644 --- a/paddle/fluid/primitive/rule/vjp/details.h +++ b/paddle/fluid/primitive/rule/vjp/details.h @@ -2867,6 +2867,24 @@ void trunc_grad(const Tensor& out_grad, Tensor* x_grad) { } } +template +void bmm_grad(const Tensor& x, + const Tensor& y, + const Tensor& out_grad, + Tensor* x_grad, + Tensor* y_grad) { + if (x_grad) { + // dx = bmm(dout, y.mT) + auto x_grad_tmp = bmm(out_grad, transpose(y, {0, 2, 1})); + set_output(x_grad_tmp, x_grad); + } + if (y_grad) { + // dy = bmm(x.mT, dout) + auto y_grad_tmp = bmm(transpose(y, {0, 2, 1}), out_grad); + set_output(y_grad_tmp, y_grad); + } +} + } // namespace details } // namespace primitive } // namespace paddle diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 43e0ef455ac26a..e0748ed015f05a 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -117,6 +117,28 @@ void BmmGradInferMeta(const MetaTensor& x, } } +void BmmDoubleGradInferMeta(const MetaTensor& x, + const MetaTensor& y, + const MetaTensor& out_grad, + const paddle::optional& grad_x_grad, + const paddle::optional& grad_y_grad, + MetaTensor* x_grad, + MetaTensor* y_grad, + MetaTensor* grad_out_grad) { + if (x_grad) { + x_grad->set_dims(x.dims()); + x_grad->set_dtype(x.dtype()); + } + if (y_grad) { + y_grad->set_dims(y.dims()); + y_grad->set_dtype(y.dtype()); + } + if (grad_out_grad) { + grad_out_grad->set_dims(out_grad.dims()); + grad_out_grad->set_dtype(out_grad.dtype()); + } +} + void ChannelShuffleGradInferMeta(const MetaTensor& out_grad, int groups, const std::string& data_format, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index d8570c1b899638..6ef26b28620dac 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -59,6 +59,15 @@ void BmmGradInferMeta(const MetaTensor& x, MetaTensor* x_grad, MetaTensor* y_grad); +void BmmDoubleGradInferMeta(const MetaTensor& x, + const MetaTensor& y, + const MetaTensor& out_grad, + const paddle::optional& grad_x_grad, + const paddle::optional& grad_y_grad, + MetaTensor* x_grad, + MetaTensor* y_grad, + MetaTensor* grad_out_grad); + void ChannelShuffleGradInferMeta(const MetaTensor& out_grad, int groups, const std::string& data_format, diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml index 42d06f5f15d529..22c55ed92add75 100644 --- a/paddle/phi/ops/yaml/backward.yaml +++ 
b/paddle/phi/ops/yaml/backward.yaml @@ -293,6 +293,7 @@ kernel : func : bmm_grad data_type : out_grad + backward : bmm_double_grad - backward_op : broadcast_tensors_grad forward : broadcast_tensors (Tensor[] input) -> Tensor[](out) @@ -3595,6 +3596,16 @@ func : yolo_loss_grad optional : gt_score +- backward_op: bmm_double_grad + forward: bmm_grad (Tensor x, Tensor y, Tensor grad_out) -> Tensor(grad_x), Tensor(grad_y) + args: (Tensor x, Tensor y, Tensor grad_out, Tensor grad_x_grad, Tensor grad_y_grad) + output: Tensor(x_grad), Tensor(y_grad), Tensor(grad_out_grad) + infer_meta : + func : BmmDoubleGradInferMeta + param : [x, y, grad_out, grad_x_grad, grad_y_grad] + composite: bmm_double_grad(x, y, grad_out, grad_x_grad, grad_y_grad, x_grad, y_grad, grad_out_grad) + optional: grad_x_grad, grad_y_grad + - backward_op: disable_check_model_nan_inf_grad forward: disable_check_model_nan_inf (Tensor x, int flag=0) -> Tensor(out) args: (Tensor out_grad, int unsetflag = 1) diff --git a/paddle/phi/ops/yaml/op_compat.yaml b/paddle/phi/ops/yaml/op_compat.yaml index 899a43d6e8287f..1f66719d229fef 100755 --- a/paddle/phi/ops/yaml/op_compat.yaml +++ b/paddle/phi/ops/yaml/op_compat.yaml @@ -486,6 +486,7 @@ {out : Out} - op : bmm + backward : bmm_grad, bmm_double_grad inputs : {x : X, y : Y} outputs : diff --git a/test/legacy_test/test_bmm_op.py b/test/legacy_test/test_bmm_op.py index 132a227c0ac443..eddcc83664891e 100644 --- a/test/legacy_test/test_bmm_op.py +++ b/test/legacy_test/test_bmm_op.py @@ -38,7 +38,7 @@ def test_check_output(self): self.check_output(check_pir=True, check_prim_pir=True) def test_checkout_grad(self): - self.check_grad(['X', 'Y'], 'Out', check_pir=True) + self.check_grad(['X', 'Y'], 'Out', check_pir=True, check_prim_pir=True) class TestBmmFP16Op(OpTest): diff --git a/test/prim/prim/vjp/test_comp_high_grad.py b/test/prim/prim/vjp/test_comp_high_grad.py index 2ded44baa8f507..25c60e78f8c177 100644 --- a/test/prim/prim/vjp/test_comp_high_grad.py +++ b/test/prim/prim/vjp/test_comp_high_grad.py @@ -983,10 +983,10 @@ def test_high_grad(self): @param.parameterized_class( ('shape1', 'shape2'), [ - ([2], [2], True), - ([2, 3], [2, 3], True), - ([2, 3, 4], [2, 3, 4], True), - ([2, 3, 3, 4], [2, 3, 3, 4], True), + ([2], [2]), + ([2, 3], [2, 3]), + ([2, 3, 4], [2, 3, 4]), + ([2, 3, 3, 4], [2, 3, 3, 4]), ], ) class TestMaximumHighGradCheck2(unittest.TestCase): @@ -1043,5 +1043,66 @@ def test_high_grad(self): self.func_triple(p, x_stop, y_stop) +@param.parameterized_class( + ('shape1', 'shape2'), + [ + ([1, 2, 3], [1, 3, 4]), + ([5, 6, 7], [5, 7, 8]), + ([512, 16, 13], [512, 13, 9]), + ], +) +class TestBmmHighGradCheck2(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.shape1 = cls.shape1 + cls.shape2 = cls.shape2 + + def _grad(self, y, x, order): + u = y + dx = paddle.ones_like(x) + for _ in range(order): + dx = paddle.grad(u, x, create_graph=True)[0] + u = dx + return dx + + def func_double(self, place, x_stop, y_stop): + x = paddle.randn(self.shape1).astype("float32") + x.stop_gradient = x_stop + y = paddle.randn(self.shape2).astype("float32") + y.stop_gradient = y_stop + + # wraping with tanh to enable high order gradient + z = paddle.bmm(paddle.tanh(x), paddle.tanh(y)) + + if not x.stop_gradient: + dzdx = self._grad(z, x, 2) + if not y.stop_gradient: + dzdy = self._grad(z, y, 2) + + def func_triple(self, place, x_stop, y_stop): + x = paddle.randn(self.shape1).astype("float32") + x.stop_gradient = x_stop + y = paddle.randn(self.shape2).astype("float32") + y.stop_gradient = y_stop 
+ + z = paddle.bmm(x, y) + + if not x.stop_gradient: + dzdx = self._grad(z, x, 3) + if not y.stop_gradient: + dzdy = self._grad(z, y, 3) + + def test_high_grad(self): + places = [base.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(base.CUDAPlace(0)) + for p in places: + for x_stop in [False, True]: + for y_stop in [False, True]: + with dygraph_guard(): + self.func_double(p, x_stop, y_stop) + self.func_triple(p, x_stop, y_stop) + + if __name__ == '__main__': unittest.main() From 475045c4cf0a4b2bd62062c9fe3facbcdc3d7f21 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Mon, 21 Oct 2024 16:13:50 +0800 Subject: [PATCH 2/5] fix bug --- .../composite_double_backward_api.h | 2 +- test/prim/prim/vjp/test_comp_high_grad.py | 14 ++++++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h index 9110d411ebef5d..0f05a3c13114cd 100644 --- a/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h +++ b/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h @@ -947,7 +947,7 @@ void bmm_double_grad(const Tensor& x, if (x_grad) { // dx' = bmm(dout, ddy.mT) Tensor x_grad_tmp; - if (grad_x_grad) { + if (grad_y_grad) { x_grad_tmp = bmm(grad_out, transpose(grad_y_grad.get(), {0, 2, 1})); } else { x_grad_tmp = full(common::vectorize(x.dims()), 0, x.dtype()); diff --git a/test/prim/prim/vjp/test_comp_high_grad.py b/test/prim/prim/vjp/test_comp_high_grad.py index 25c60e78f8c177..5ee3413817d4af 100644 --- a/test/prim/prim/vjp/test_comp_high_grad.py +++ b/test/prim/prim/vjp/test_comp_high_grad.py @@ -1066,9 +1066,9 @@ def _grad(self, y, x, order): return dx def func_double(self, place, x_stop, y_stop): - x = paddle.randn(self.shape1).astype("float32") + x = paddle.randn(self.shape1).astype("float32").to(device=place) x.stop_gradient = x_stop - y = paddle.randn(self.shape2).astype("float32") + y = paddle.randn(self.shape2).astype("float32").to(device=place) y.stop_gradient = y_stop # wraping with tanh to enable high order gradient @@ -1085,7 +1085,7 @@ def func_triple(self, place, x_stop, y_stop): y = paddle.randn(self.shape2).astype("float32") y.stop_gradient = y_stop - z = paddle.bmm(x, y) + z = paddle.bmm(paddle.tanh(x), paddle.tanh(y)) if not x.stop_gradient: dzdx = self._grad(z, x, 3) @@ -1093,7 +1093,13 @@ def func_triple(self, place, x_stop, y_stop): dzdy = self._grad(z, y, 3) def test_high_grad(self): - places = [base.CPUPlace()] + places = [] + if ( + os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower() + in ['1', 'true', 'on'] + or not core.is_compiled_with_cuda() + ): + places.append(base.CPUPlace()) if core.is_compiled_with_cuda(): places.append(base.CUDAPlace(0)) for p in places: From e6bf3becc647d60204189fde0e75bde8bf382909 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 22 Oct 2024 11:08:40 +0800 Subject: [PATCH 3/5] remove bmm_grad in details.h --- paddle/fluid/primitive/primitive.yaml | 1 - paddle/fluid/primitive/rule/vjp/details.h | 18 ------------------ test/legacy_test/test_bmm_op.py | 2 +- 3 files changed, 1 insertion(+), 20 deletions(-) diff --git a/paddle/fluid/primitive/primitive.yaml b/paddle/fluid/primitive/primitive.yaml index a9df5090834ecf..fd08b18e3a2145 100644 --- a/paddle/fluid/primitive/primitive.yaml +++ b/paddle/fluid/primitive/primitive.yaml @@ -124,4 +124,3 @@ - unique_consecutive - sigmoid - reduce_as -- bmm 
diff --git a/paddle/fluid/primitive/rule/vjp/details.h b/paddle/fluid/primitive/rule/vjp/details.h index 51175e7dae5d4d..b04a01a273d784 100644 --- a/paddle/fluid/primitive/rule/vjp/details.h +++ b/paddle/fluid/primitive/rule/vjp/details.h @@ -2867,24 +2867,6 @@ void trunc_grad(const Tensor& out_grad, Tensor* x_grad) { } } -template -void bmm_grad(const Tensor& x, - const Tensor& y, - const Tensor& out_grad, - Tensor* x_grad, - Tensor* y_grad) { - if (x_grad) { - // dx = bmm(dout, y.mT) - auto x_grad_tmp = bmm(out_grad, transpose(y, {0, 2, 1})); - set_output(x_grad_tmp, x_grad); - } - if (y_grad) { - // dy = bmm(x.mT, dout) - auto y_grad_tmp = bmm(transpose(y, {0, 2, 1}), out_grad); - set_output(y_grad_tmp, y_grad); - } -} - } // namespace details } // namespace primitive } // namespace paddle diff --git a/test/legacy_test/test_bmm_op.py b/test/legacy_test/test_bmm_op.py index eddcc83664891e..2911d5c248f54c 100644 --- a/test/legacy_test/test_bmm_op.py +++ b/test/legacy_test/test_bmm_op.py @@ -38,7 +38,7 @@ def test_check_output(self): self.check_output(check_pir=True, check_prim_pir=True) def test_checkout_grad(self): - self.check_grad(['X', 'Y'], 'Out', check_pir=True, check_prim_pir=True) + self.check_grad(['X', 'Y'], 'Out', check_pir=True, check_prim=True) class TestBmmFP16Op(OpTest): From cabf870a075900398367165d6a1c1e05b77e8256 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 22 Oct 2024 12:50:17 +0800 Subject: [PATCH 4/5] disable check_prim=True --- test/legacy_test/test_bmm_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/legacy_test/test_bmm_op.py b/test/legacy_test/test_bmm_op.py index 2911d5c248f54c..132a227c0ac443 100644 --- a/test/legacy_test/test_bmm_op.py +++ b/test/legacy_test/test_bmm_op.py @@ -38,7 +38,7 @@ def test_check_output(self): self.check_output(check_pir=True, check_prim_pir=True) def test_checkout_grad(self): - self.check_grad(['X', 'Y'], 'Out', check_pir=True, check_prim=True) + self.check_grad(['X', 'Y'], 'Out', check_pir=True) class TestBmmFP16Op(OpTest): From ce47b2a4f25925d6f64999e1d85781cbe49d6395 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Wed, 23 Oct 2024 18:54:39 +0800 Subject: [PATCH 5/5] simplify bmmdouble_grad impl --- paddle/fluid/prim/api/api.yaml | 1 - .../composite_double_backward_api.h | 12 +++++----- paddle/phi/infermeta/backward.cc | 22 ------------------- paddle/phi/infermeta/backward.h | 9 -------- paddle/phi/ops/yaml/backward.yaml | 4 ++-- 5 files changed, 9 insertions(+), 39 deletions(-) diff --git a/paddle/fluid/prim/api/api.yaml b/paddle/fluid/prim/api/api.yaml index 450469743fd4a7..33401c213366d1 100644 --- a/paddle/fluid/prim/api/api.yaml +++ b/paddle/fluid/prim/api/api.yaml @@ -50,4 +50,3 @@ - tanh - sign - sigmoid -- bmm diff --git a/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h index 0f05a3c13114cd..d40d637097fe71 100644 --- a/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h +++ b/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h @@ -948,7 +948,8 @@ void bmm_double_grad(const Tensor& x, // dx' = bmm(dout, ddy.mT) Tensor x_grad_tmp; if (grad_y_grad) { - x_grad_tmp = bmm(grad_out, transpose(grad_y_grad.get(), {0, 2, 1})); + x_grad_tmp = + matmul(grad_out, transpose(grad_y_grad.get(), {0, 2, 1})); } else { x_grad_tmp = full(common::vectorize(x.dims()), 0, x.dtype()); } @@ -958,7 +959,8 @@ void 
bmm_double_grad(const Tensor& x, // dy' = bmm(ddx.mT, dout) Tensor y_grad_tmp; if (grad_x_grad) { - y_grad_tmp = bmm(transpose(grad_x_grad.get(), {0, 2, 1}), grad_out); + y_grad_tmp = + matmul(transpose(grad_x_grad.get(), {0, 2, 1}), grad_out); } else { y_grad_tmp = full(common::vectorize(y.dims()), 0, y.dtype()); } @@ -969,11 +971,11 @@ void bmm_double_grad(const Tensor& x, Tensor grad_out_grad_tmp; if (grad_x_grad && grad_y_grad) { grad_out_grad_tmp = - bmm(grad_x_grad.get(), y) + bmm(x, grad_y_grad.get()); + matmul(grad_x_grad.get(), y) + matmul(x, grad_y_grad.get()); } else if (grad_x_grad) { - grad_out_grad_tmp = bmm(grad_x_grad.get(), y); + grad_out_grad_tmp = matmul(grad_x_grad.get(), y); } else if (grad_y_grad) { - grad_out_grad_tmp = bmm(x, grad_y_grad.get()); + grad_out_grad_tmp = matmul(x, grad_y_grad.get()); } else { grad_out_grad_tmp = full(common::vectorize(grad_out.dims()), 0, grad_out.dtype()); diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index e0748ed015f05a..43e0ef455ac26a 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -117,28 +117,6 @@ void BmmGradInferMeta(const MetaTensor& x, } } -void BmmDoubleGradInferMeta(const MetaTensor& x, - const MetaTensor& y, - const MetaTensor& out_grad, - const paddle::optional& grad_x_grad, - const paddle::optional& grad_y_grad, - MetaTensor* x_grad, - MetaTensor* y_grad, - MetaTensor* grad_out_grad) { - if (x_grad) { - x_grad->set_dims(x.dims()); - x_grad->set_dtype(x.dtype()); - } - if (y_grad) { - y_grad->set_dims(y.dims()); - y_grad->set_dtype(y.dtype()); - } - if (grad_out_grad) { - grad_out_grad->set_dims(out_grad.dims()); - grad_out_grad->set_dtype(out_grad.dtype()); - } -} - void ChannelShuffleGradInferMeta(const MetaTensor& out_grad, int groups, const std::string& data_format, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 6ef26b28620dac..d8570c1b899638 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -59,15 +59,6 @@ void BmmGradInferMeta(const MetaTensor& x, MetaTensor* x_grad, MetaTensor* y_grad); -void BmmDoubleGradInferMeta(const MetaTensor& x, - const MetaTensor& y, - const MetaTensor& out_grad, - const paddle::optional& grad_x_grad, - const paddle::optional& grad_y_grad, - MetaTensor* x_grad, - MetaTensor* y_grad, - MetaTensor* grad_out_grad); - void ChannelShuffleGradInferMeta(const MetaTensor& out_grad, int groups, const std::string& data_format, diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml index 22c55ed92add75..60c8f75c7d2354 100644 --- a/paddle/phi/ops/yaml/backward.yaml +++ b/paddle/phi/ops/yaml/backward.yaml @@ -3601,8 +3601,8 @@ args: (Tensor x, Tensor y, Tensor grad_out, Tensor grad_x_grad, Tensor grad_y_grad) output: Tensor(x_grad), Tensor(y_grad), Tensor(grad_out_grad) infer_meta : - func : BmmDoubleGradInferMeta - param : [x, y, grad_out, grad_x_grad, grad_y_grad] + func : GeneralTernaryGradInferMeta + param : [x, y, grad_out] composite: bmm_double_grad(x, y, grad_out, grad_x_grad, grad_y_grad, x_grad, y_grad, grad_out_grad) optional: grad_x_grad, grad_y_grad
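For reference, the composite rule implemented by bmm_double_grad in this series follows from differentiating the first-order bmm gradients. Per batch element, with out = X·Y, the first-order rules are dX = dOut·Y^T and dY = X^T·dOut. Differentiating these with respect to X, Y and dOut and contracting with the incoming second-order gradients ddX = grad_x_grad and ddY = grad_y_grad gives the three outputs produced above (same notation as the code comments):

    x_grad        = dOut · ddY^T
    y_grad        = ddX^T · dOut
    grad_out_grad = ddX · Y + X · ddY

A minimal dygraph sketch of how the new rule can be exercised once this series is applied; the shapes are illustrative and the tanh wrappers only make the second-order gradient non-trivial, mirroring the added TestBmmHighGradCheck2 test:

    import paddle

    # Per the bmm contract: x is [B, M, K], y is [B, K, N].
    x = paddle.randn([2, 3, 4])
    y = paddle.randn([2, 4, 5])
    x.stop_gradient = False
    y.stop_gradient = False

    z = paddle.bmm(paddle.tanh(x), paddle.tanh(y))

    # First-order gradient, kept on the graph so it can be differentiated again.
    (dx,) = paddle.grad(z.sum(), x, create_graph=True)
    # Second-order gradient; with this series applied it should route through the
    # bmm_double_grad composite rule registered in backward.yaml.
    (dxx,) = paddle.grad(dx.sum(), x)
    assert list(dxx.shape) == [2, 3, 4]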