[Prim] Support bmm_double_grad in eager mode #68834

Merged
@@ -79,6 +79,7 @@
"exp_double_grad",
"log_double_grad",
"where_double_grad",
"bmm_double_grad",
]

# white ops list whose kernel can automatically do type promotion.
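With "bmm_double_grad" on this white list, the eager prim machinery routes the second-order backward of bmm through the composite rule added below instead of requiring a dedicated kernel. A minimal sketch of exercising it in eager mode; the core.set_prim_eager_enabled switch and its import path follow existing prim tests and are an assumption, not part of this diff:

import paddle
from paddle.base import core

# Assumed switch used by the eager prim tests to enable composite (prim) rules.
core.set_prim_eager_enabled(True)

x = paddle.randn([2, 3, 4])
y = paddle.randn([2, 4, 5])
x.stop_gradient = False
y.stop_gradient = False

# tanh keeps the second-order gradient non-trivial; bmm alone is bilinear.
z = paddle.bmm(paddle.tanh(x), paddle.tanh(y))
(dzdx,) = paddle.grad(z.sum(), x, create_graph=True)
(d2zdx2,) = paddle.grad(dzdx.sum(), x)  # second-order gradient of bmm in eager mode
print(d2zdx2.shape)  # [2, 3, 4]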
2 changes: 1 addition & 1 deletion paddle/fluid/eager/general_grad.h
@@ -66,7 +66,7 @@ class GeneralGrad {

       PADDLE_ENFORCE_NOT_NULL(
           target_node,
-          common::errors::Fatal("There is no grad op for %s:[%d] or it's"
+          common::errors::Fatal("There is no grad op for %s:[%d] or it's "
                                 "stop_gradient=True.",
                                 msg,
                                 i));
@@ -32,4 +32,5 @@
'exp_grad',
'abs_double_grad',
'where_grad',
'bmm_grad',
]
@@ -935,5 +935,54 @@ void abs_triple_grad(const Tensor& x,
}
}

template <typename T>
void bmm_double_grad(const Tensor& x,
                     const Tensor& y,
                     const Tensor& grad_out,
                     const paddle::optional<Tensor>& grad_x_grad,
                     const paddle::optional<Tensor>& grad_y_grad,
                     Tensor* x_grad,
                     Tensor* y_grad,
                     Tensor* grad_out_grad) {
  if (x_grad) {
    // dx' = bmm(dout, ddy.mT)
    Tensor x_grad_tmp;
    if (grad_y_grad) {
      x_grad_tmp =
          matmul<T>(grad_out, transpose<T>(grad_y_grad.get(), {0, 2, 1}));
    } else {
      x_grad_tmp = full<T>(common::vectorize(x.dims()), 0, x.dtype());
    }
    set_output<T>(x_grad_tmp, x_grad);
  }
  if (y_grad) {
    // dy' = bmm(ddx.mT, dout)
    Tensor y_grad_tmp;
    if (grad_x_grad) {
      y_grad_tmp =
          matmul<T>(transpose<T>(grad_x_grad.get(), {0, 2, 1}), grad_out);
    } else {
      y_grad_tmp = full<T>(common::vectorize(y.dims()), 0, y.dtype());
    }
    set_output<T>(y_grad_tmp, y_grad);
  }
  if (grad_out_grad) {
    // ddout = bmm(ddx, y) + bmm(x, ddy)
    Tensor grad_out_grad_tmp;
    if (grad_x_grad && grad_y_grad) {
      grad_out_grad_tmp =
          matmul<T>(grad_x_grad.get(), y) + matmul<T>(x, grad_y_grad.get());
    } else if (grad_x_grad) {
      grad_out_grad_tmp = matmul<T>(grad_x_grad.get(), y);
    } else if (grad_y_grad) {
      grad_out_grad_tmp = matmul<T>(x, grad_y_grad.get());
    } else {
      grad_out_grad_tmp =
          full<T>(common::vectorize(grad_out.dims()), 0, grad_out.dtype());
    }
    set_output<T>(grad_out_grad_tmp, grad_out_grad);
  }
}

} // namespace prim
} // namespace paddle
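As a sanity check on the composite formulas above (not part of the patch), the three expressions can be evaluated with NumPy for concrete batched shapes; they follow from differentiating the first-order rules dx = bmm(dout, y^T) and dy = bmm(x^T, dout), each of which is linear in x, y, and dout:

import numpy as np

rng = np.random.default_rng(0)
B, M, K, N = 2, 3, 4, 5
x = rng.standard_normal((B, M, K))
y = rng.standard_normal((B, K, N))
grad_out = rng.standard_normal((B, M, N))     # dout, same shape as bmm(x, y)
grad_x_grad = rng.standard_normal((B, M, K))  # ddx
grad_y_grad = rng.standard_normal((B, K, N))  # ddy

# dx' = bmm(dout, ddy.mT);  dy' = bmm(ddx.mT, dout);  ddout = bmm(ddx, y) + bmm(x, ddy)
x_grad = grad_out @ grad_y_grad.transpose(0, 2, 1)
y_grad = grad_x_grad.transpose(0, 2, 1) @ grad_out
grad_out_grad = grad_x_grad @ y + x @ grad_y_grad

assert x_grad.shape == x.shape
assert y_grad.shape == y.shape
assert grad_out_grad.shape == grad_out.shape

The full<T>(..., 0, ...) branches in the C++ rule cover the case where one of the optional incoming gradients is absent, so the corresponding term simply drops out.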
11 changes: 11 additions & 0 deletions paddle/phi/ops/yaml/backward.yaml
@@ -293,6 +293,7 @@
  kernel :
    func : bmm_grad
    data_type : out_grad
  backward : bmm_double_grad

- backward_op : broadcast_tensors_grad
  forward : broadcast_tensors (Tensor[] input) -> Tensor[](out)
@@ -3595,6 +3596,16 @@
    func : yolo_loss_grad
  optional : gt_score

- backward_op: bmm_double_grad
  forward: bmm_grad (Tensor x, Tensor y, Tensor grad_out) -> Tensor(grad_x), Tensor(grad_y)
  args: (Tensor x, Tensor y, Tensor grad_out, Tensor grad_x_grad, Tensor grad_y_grad)
  output: Tensor(x_grad), Tensor(y_grad), Tensor(grad_out_grad)
  infer_meta :
    func : GeneralTernaryGradInferMeta
    param : [x, y, grad_out]
  composite: bmm_double_grad(x, y, grad_out, grad_x_grad, grad_y_grad, x_grad, y_grad, grad_out_grad)
  optional: grad_x_grad, grad_y_grad

- backward_op: disable_check_model_nan_inf_grad
  forward: disable_check_model_nan_inf (Tensor x, int flag=0) -> Tensor(out)
  args: (Tensor out_grad, int unsetflag = 1)
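The optional: grad_x_grad, grad_y_grad entries correspond to inputs whose gradient chain is cut, e.g. by stop_gradient. A hedged eager-mode sketch of that path, reusing the assumed prim switch from the earlier example:

import paddle
from paddle.base import core

core.set_prim_eager_enabled(True)  # assumption: same eager prim switch as above

x = paddle.randn([2, 3, 4])
y = paddle.randn([2, 4, 5])
x.stop_gradient = False
y.stop_gradient = True  # grad_y_grad should then reach bmm_double_grad as an empty optional

z = paddle.bmm(paddle.tanh(x), paddle.tanh(y))
(dzdx,) = paddle.grad(z.sum(), x, create_graph=True)
# Second-order gradient w.r.t. x only; the composite falls back to the single
# available term (or a zero fill) for the missing optional input.
(d2zdx2,) = paddle.grad(dzdx.sum(), x)
print(d2zdx2.shape)  # [2, 3, 4]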
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/op_compat.yaml
@@ -486,6 +486,7 @@
    {out : Out}

- op : bmm
  backward : bmm_grad, bmm_double_grad
  inputs :
    {x : X, y : Y}
  outputs :
75 changes: 71 additions & 4 deletions test/prim/prim/vjp/test_comp_high_grad.py
@@ -983,10 +983,10 @@ def test_high_grad(self):
@param.parameterized_class(
    ('shape1', 'shape2'),
    [
-       ([2], [2], True),
-       ([2, 3], [2, 3], True),
-       ([2, 3, 4], [2, 3, 4], True),
-       ([2, 3, 3, 4], [2, 3, 3, 4], True),
+       ([2], [2]),
+       ([2, 3], [2, 3]),
+       ([2, 3, 4], [2, 3, 4]),
+       ([2, 3, 3, 4], [2, 3, 3, 4]),
    ],
)
class TestMaximumHighGradCheck2(unittest.TestCase):
@@ -1043,5 +1043,72 @@ def test_high_grad(self):
                        self.func_triple(p, x_stop, y_stop)


@param.parameterized_class(
    ('shape1', 'shape2'),
    [
        ([1, 2, 3], [1, 3, 4]),
        ([5, 6, 7], [5, 7, 8]),
        ([512, 16, 13], [512, 13, 9]),
    ],
)
class TestBmmHighGradCheck2(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.shape1 = cls.shape1
        cls.shape2 = cls.shape2

    def _grad(self, y, x, order):
        u = y
        dx = paddle.ones_like(x)
        for _ in range(order):
            dx = paddle.grad(u, x, create_graph=True)[0]
            u = dx
        return dx

    def func_double(self, place, x_stop, y_stop):
        x = paddle.randn(self.shape1).astype("float32").to(device=place)
        x.stop_gradient = x_stop
        y = paddle.randn(self.shape2).astype("float32").to(device=place)
        y.stop_gradient = y_stop

        # wrap with tanh so that higher-order gradients are non-trivial
        z = paddle.bmm(paddle.tanh(x), paddle.tanh(y))

        # smoke test: only checks that the second-order gradient can be computed
        if not x.stop_gradient:
            dzdx = self._grad(z, x, 2)
        if not y.stop_gradient:
            dzdy = self._grad(z, y, 2)

    def func_triple(self, place, x_stop, y_stop):
        x = paddle.randn(self.shape1).astype("float32")
        x.stop_gradient = x_stop
        y = paddle.randn(self.shape2).astype("float32")
        y.stop_gradient = y_stop

        z = paddle.bmm(paddle.tanh(x), paddle.tanh(y))

        if not x.stop_gradient:
            dzdx = self._grad(z, x, 3)
        if not y.stop_gradient:
            dzdy = self._grad(z, y, 3)

    def test_high_grad(self):
        places = []
        if (
            os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower()
            in ['1', 'true', 'on']
            or not core.is_compiled_with_cuda()
        ):
            places.append(base.CPUPlace())
        if core.is_compiled_with_cuda():
            places.append(base.CUDAPlace(0))
        for p in places:
            for x_stop in [False, True]:
                for y_stop in [False, True]:
                    with dygraph_guard():
                        self.func_double(p, x_stop, y_stop)
                        self.func_triple(p, x_stop, y_stop)


if __name__ == '__main__':
    unittest.main()
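A note on why the new tests wrap the inputs in tanh: bmm is bilinear, so the second derivative of a plain bmm with respect to the same input is identically zero, while the mixed derivative is not. An illustrative sketch (not from the patch; the prim switch is the same assumption as above):

import paddle
from paddle.base import core

core.set_prim_eager_enabled(True)

x = paddle.randn([2, 3, 4])
y = paddle.randn([2, 4, 5])
x.stop_gradient = False
y.stop_gradient = False

z = paddle.bmm(x, y)
(dzdx,) = paddle.grad(z.sum(), x, create_graph=True)  # equals bmm(ones_like(z), y.mT)
# Mixed second derivative d/dy of (d z.sum() / dx); this differentiates through
# bmm_grad, i.e. the path served by the new bmm_double_grad rule.
(cross,) = paddle.grad(dzdx.sum(), y)
print(cross.shape)  # [2, 4, 5]; each entry equals 3.0 (dout is all ones, summed over the M=3 rows)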