diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
index 4d44714cd2d71..fee893f86df71 100644
--- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
+++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
@@ -79,6 +79,7 @@
     "exp_double_grad",
     "log_double_grad",
     "where_double_grad",
+    "bmm_double_grad",
 ]
 
 # white ops list whose kernel can automatically do type promotion.
diff --git a/paddle/fluid/eager/general_grad.h b/paddle/fluid/eager/general_grad.h
index 297b09c639053..614322021c4ae 100644
--- a/paddle/fluid/eager/general_grad.h
+++ b/paddle/fluid/eager/general_grad.h
@@ -66,7 +66,7 @@ class GeneralGrad {
       PADDLE_ENFORCE_NOT_NULL(
           target_node,
-          common::errors::Fatal("There is no grad op for %s:[%d] or it's"
+          common::errors::Fatal("There is no grad op for %s:[%d] or it's "
                                 "stop_gradient=True.",
                                 msg,
                                 i));
diff --git a/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py b/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py
index 1020b564a2d73..0281e2c574f98 100644
--- a/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py
+++ b/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py
@@ -32,4 +32,5 @@
     'exp_grad',
     'abs_double_grad',
     'where_grad',
+    'bmm_grad',
 ]
diff --git a/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h
index 69de0a2afd68d..d40d637097fe7 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h
@@ -935,5 +935,54 @@ void abs_triple_grad(const Tensor& x,
   }
 }
 
+template <typename T>
+void bmm_double_grad(const Tensor& x,
+                     const Tensor& y,
+                     const Tensor& grad_out,
+                     const paddle::optional<Tensor>& grad_x_grad,
+                     const paddle::optional<Tensor>& grad_y_grad,
+                     Tensor* x_grad,
+                     Tensor* y_grad,
+                     Tensor* grad_out_grad) {
+  if (x_grad) {
+    // dx' = bmm(dout, ddy.mT)
+    Tensor x_grad_tmp;
+    if (grad_y_grad) {
+      x_grad_tmp =
+          matmul<T>(grad_out, transpose<T>(grad_y_grad.get(), {0, 2, 1}));
+    } else {
+      x_grad_tmp = full<T>(common::vectorize(x.dims()), 0, x.dtype());
+    }
+    set_output<T>(x_grad_tmp, x_grad);
+  }
+  if (y_grad) {
+    // dy' = bmm(ddx.mT, dout)
+    Tensor y_grad_tmp;
+    if (grad_x_grad) {
+      y_grad_tmp =
+          matmul<T>(transpose<T>(grad_x_grad.get(), {0, 2, 1}), grad_out);
+    } else {
+      y_grad_tmp = full<T>(common::vectorize(y.dims()), 0, y.dtype());
+    }
+    set_output<T>(y_grad_tmp, y_grad);
+  }
+  if (grad_out_grad) {
+    // ddout = bmm(ddx, y) + bmm(x, ddy)
+    Tensor grad_out_grad_tmp;
+    if (grad_x_grad && grad_y_grad) {
+      grad_out_grad_tmp =
+          matmul<T>(grad_x_grad.get(), y) + matmul<T>(x, grad_y_grad.get());
+    } else if (grad_x_grad) {
+      grad_out_grad_tmp = matmul<T>(grad_x_grad.get(), y);
+    } else if (grad_y_grad) {
+      grad_out_grad_tmp = matmul<T>(x, grad_y_grad.get());
+    } else {
+      grad_out_grad_tmp =
+          full<T>(common::vectorize(grad_out.dims()), 0, grad_out.dtype());
+    }
+    set_output<T>(grad_out_grad_tmp, grad_out_grad);
+  }
+}
+
 }  // namespace prim
 }  // namespace paddle
diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml
index 42d06f5f15d52..60c8f75c7d235 100644
--- a/paddle/phi/ops/yaml/backward.yaml
+++ b/paddle/phi/ops/yaml/backward.yaml
@@ -293,6 +293,7 @@
   kernel :
     func : bmm_grad
     data_type : out_grad
+  backward : bmm_double_grad
 
 - backward_op : broadcast_tensors_grad
   forward : broadcast_tensors (Tensor[] input) -> Tensor[](out)
@@ -3595,6 +3596,16 @@
     func : yolo_loss_grad
   optional : gt_score
 
+- backward_op: bmm_double_grad
+  forward: bmm_grad (Tensor x, Tensor y, Tensor grad_out) -> Tensor(grad_x), Tensor(grad_y)
+  args: (Tensor x, Tensor y, Tensor grad_out, Tensor grad_x_grad, Tensor grad_y_grad)
+  output: Tensor(x_grad), Tensor(y_grad), Tensor(grad_out_grad)
+  infer_meta :
+    func : GeneralTernaryGradInferMeta
+    param : [x, y, grad_out]
+  composite: bmm_double_grad(x, y, grad_out, grad_x_grad, grad_y_grad, x_grad, y_grad, grad_out_grad)
+  optional: grad_x_grad, grad_y_grad
+
 - backward_op: disable_check_model_nan_inf_grad
   forward: disable_check_model_nan_inf (Tensor x, int flag=0) -> Tensor(out)
   args: (Tensor out_grad, int unsetflag = 1)
diff --git a/paddle/phi/ops/yaml/op_compat.yaml b/paddle/phi/ops/yaml/op_compat.yaml
index 899a43d6e8287..1f66719d229fe 100755
--- a/paddle/phi/ops/yaml/op_compat.yaml
+++ b/paddle/phi/ops/yaml/op_compat.yaml
@@ -486,6 +486,7 @@
     {out : Out}
 
 - op : bmm
+  backward : bmm_grad, bmm_double_grad
   inputs :
     {x : X, y : Y}
   outputs :
diff --git a/test/prim/prim/vjp/test_comp_high_grad.py b/test/prim/prim/vjp/test_comp_high_grad.py
index 2ded44baa8f50..5ee3413817d4a 100644
--- a/test/prim/prim/vjp/test_comp_high_grad.py
+++ b/test/prim/prim/vjp/test_comp_high_grad.py
@@ -983,10 +983,10 @@ def test_high_grad(self):
 @param.parameterized_class(
     ('shape1', 'shape2'),
     [
-        ([2], [2], True),
-        ([2, 3], [2, 3], True),
-        ([2, 3, 4], [2, 3, 4], True),
-        ([2, 3, 3, 4], [2, 3, 3, 4], True),
+        ([2], [2]),
+        ([2, 3], [2, 3]),
+        ([2, 3, 4], [2, 3, 4]),
+        ([2, 3, 3, 4], [2, 3, 3, 4]),
     ],
 )
 class TestMaximumHighGradCheck2(unittest.TestCase):
@@ -1043,5 +1043,72 @@ def test_high_grad(self):
                         self.func_triple(p, x_stop, y_stop)
 
+
+@param.parameterized_class(
+    ('shape1', 'shape2'),
+    [
+        ([1, 2, 3], [1, 3, 4]),
+        ([5, 6, 7], [5, 7, 8]),
+        ([512, 16, 13], [512, 13, 9]),
+    ],
+)
+class TestBmmHighGradCheck2(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.shape1 = cls.shape1
+        cls.shape2 = cls.shape2
+
+    def _grad(self, y, x, order):
+        u = y
+        dx = paddle.ones_like(x)
+        for _ in range(order):
+            dx = paddle.grad(u, x, create_graph=True)[0]
+            u = dx
+        return dx
+
+    def func_double(self, place, x_stop, y_stop):
+        x = paddle.randn(self.shape1).astype("float32").to(device=place)
+        x.stop_gradient = x_stop
+        y = paddle.randn(self.shape2).astype("float32").to(device=place)
+        y.stop_gradient = y_stop
+
+        # wrapping with tanh to enable high-order gradients
+        z = paddle.bmm(paddle.tanh(x), paddle.tanh(y))
+
+        if not x.stop_gradient:
+            dzdx = self._grad(z, x, 2)
+        if not y.stop_gradient:
+            dzdy = self._grad(z, y, 2)
+
+    def func_triple(self, place, x_stop, y_stop):
+        x = paddle.randn(self.shape1).astype("float32")
+        x.stop_gradient = x_stop
+        y = paddle.randn(self.shape2).astype("float32")
+        y.stop_gradient = y_stop
+
+        z = paddle.bmm(paddle.tanh(x), paddle.tanh(y))
+
+        if not x.stop_gradient:
+            dzdx = self._grad(z, x, 3)
+        if not y.stop_gradient:
+            dzdy = self._grad(z, y, 3)
+
+    def test_high_grad(self):
+        places = []
+        if (
+            os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower()
+            in ['1', 'true', 'on']
+            or not core.is_compiled_with_cuda()
+        ):
+            places.append(base.CPUPlace())
+        if core.is_compiled_with_cuda():
+            places.append(base.CUDAPlace(0))
+        for p in places:
+            for x_stop in [False, True]:
+                for y_stop in [False, True]:
+                    with dygraph_guard():
+                        self.func_double(p, x_stop, y_stop)
+                        self.func_triple(p, x_stop, y_stop)
+
 
 if __name__ == '__main__':
     unittest.main()
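
For context, the composite rules added above compute x_grad = bmm(grad_out, grad_y_grad^T), y_grad = bmm(grad_x_grad^T, grad_out), and grad_out_grad = bmm(grad_x_grad, y) + bmm(x, grad_y_grad). Below is a minimal, hedged dygraph sketch (not part of the patch) that cross-checks these closed forms against paddle.grad. It assumes a Paddle build that already includes this change, so paddle.bmm supports double backward; all variable names, shapes, and tolerances are illustrative.

# Hedged verification sketch (illustrative, not part of the diff). Assumes a
# Paddle build with this change applied so paddle.bmm supports double backward.
import numpy as np
import paddle

paddle.seed(2024)
x = paddle.randn([2, 3, 4])
y = paddle.randn([2, 4, 5])
dout = paddle.randn([2, 3, 5])  # cotangent of bmm(x, y)
ddx = paddle.randn([2, 3, 4])   # plays the role of grad_x_grad
ddy = paddle.randn([2, 4, 5])   # plays the role of grad_y_grad
for t in (x, y, dout):
    t.stop_gradient = False

out = paddle.bmm(x, y)
# First-order grads; create_graph=True keeps them differentiable.
dx, dy = paddle.grad([out], [x, y], grad_outputs=[dout], create_graph=True)

# Differentiating <ddx, dx> + <ddy, dy> w.r.t. (x, y, dout) yields exactly the
# three outputs of bmm_double_grad.
gx, gy, gdout = paddle.grad([dx, dy], [x, y, dout], grad_outputs=[ddx, ddy])

# Closed-form composite rules from the header above.
ref_gx = paddle.bmm(dout, paddle.transpose(ddy, [0, 2, 1]))
ref_gy = paddle.bmm(paddle.transpose(ddx, [0, 2, 1]), dout)
ref_gdout = paddle.bmm(ddx, y) + paddle.bmm(x, ddy)

np.testing.assert_allclose(gx.numpy(), ref_gx.numpy(), rtol=1e-5, atol=1e-6)
np.testing.assert_allclose(gy.numpy(), ref_gy.numpy(), rtol=1e-5, atol=1e-6)
np.testing.assert_allclose(gdout.numpy(), ref_gdout.numpy(), rtol=1e-5, atol=1e-6)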