From 0101c1328479507bcbd3a3725375d5b8e35309e9 Mon Sep 17 00:00:00 2001 From: Xingjian Shi Date: Mon, 21 Oct 2019 17:53:54 -0700 Subject: [PATCH 1/5] Support N_D(N>=3) batch_dot --- src/operator/tensor/dot-inl.h | 178 +++++++++---------------- src/operator/tensor/dot.cc | 85 +++++++++--- src/operator/tensor/dot.cu | 3 - tests/python/unittest/test_numpy_op.py | 117 ++++++++++++++++ tests/python/unittest/test_operator.py | 4 +- 5 files changed, 249 insertions(+), 138 deletions(-) diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h index 96c869f40d40..0166a09f3268 100644 --- a/src/operator/tensor/dot-inl.h +++ b/src/operator/tensor/dot-inl.h @@ -30,6 +30,7 @@ #include #include #include +#include #include "./util/tensor_util-inl.h" #include "../mshadow_op.h" #include "../elemwise_op_common.h" @@ -1353,6 +1354,7 @@ void BatchDotForward_(const nnvm::NodeAttrs& attrs, using namespace mshadow; using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); + if (req[0] == kNullOp) return; const DotParam& param = nnvm::get(attrs.parsed); CHECK_EQ(outputs[0].type_flag_, inputs[0].type_flag_) << "Binary function only support input/output with the same type"; @@ -1362,115 +1364,47 @@ void BatchDotForward_(const nnvm::NodeAttrs& attrs, (outputs[0].type_flag_ == kFloat16 && ctx.run_ctx.ctx.dev_mask() == mshadow::gpu::kDevMask)) << "dot only supports float32/float64 for CPU, and float16/float32/float64 for GPU"; MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { - mshadow::Tensor out = outputs[0].get(s); - mshadow::Tensor mlhs = inputs[0].get(s); - mshadow::Tensor mrhs = inputs[1].get(s); - mshadow::Tensor workspace = - ctx.requested[0].get_space_typed(mshadow::Shape1(3 * out.size(0)), s); - if (kNullOp != req[0]) { - if (param.transpose_a && param.transpose_b) { - mshadow::BatchGEMM(out, mlhs, mrhs, (DType)1.0f, - (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f, - workspace); - } else if (!param.transpose_a && param.transpose_b) { - mshadow::BatchGEMM(out, mlhs, mrhs, (DType)1.0f, - (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f, - workspace); - } else if (param.transpose_a && !param.transpose_b) { - mshadow::BatchGEMM(out, mlhs, mrhs, (DType)1.0f, - (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f, - workspace); - } else { - mshadow::BatchGEMM(out, mlhs, mrhs, (DType)1.0f, - (kAddTo == req[0]) ? 
(DType)1.0f : (DType)0.0f, - workspace); + int ndim = outputs[0].ndim(); + if (outputs[0].shape_.Size() == 0 || inputs[0].shape_.Size() == 0 + || inputs[1].shape_.Size() == 0) { + if (outputs[0].shape_.Size() != 0 && req[0] != kAddTo) { + + mxnet_op::Kernel::Launch(s, outputs[0].shape_.Size(), + outputs[0].dptr()); } + return; } - }); -} - -template -void BatchDotBackward_(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - using namespace mshadow; - using namespace mshadow::expr; - mshadow::Stream *s = ctx.get_stream(); - const DotParam& param = nnvm::get(attrs.parsed); - CHECK_NE(req[1], kWriteInplace); - CHECK_NE(req[0], kWriteInplace); - CHECK(outputs[0].type_flag_ == kFloat32 || outputs[0].type_flag_ == kFloat64 || - (outputs[0].type_flag_ == kFloat16 && ctx.run_ctx.ctx.dev_mask() == mshadow::gpu::kDevMask)) - << "dot only supports float32/float64 for CPU, and float16/float32/float64 for GPU"; - MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { - mshadow::Tensor mout_grad = inputs[0].get(s); - mshadow::Tensor mlhs_data = inputs[1].get(s); - mshadow::Tensor mrhs_data = inputs[2].get(s); - mshadow::Tensor mlhs_grad = outputs[0].get(s); - mshadow::Tensor mrhs_grad = outputs[1].get(s); - mshadow::Tensor workspace = - ctx.requested[0].get_space_typed( - mshadow::Shape2(2, 3 * mout_grad.size(0)), s); - mshadow::Tensor rhs_workspace = workspace[0]; - mshadow::Tensor lhs_workspace = workspace[1]; + size_t batch_size = outputs[0].shape_.ProdShape(0, ndim - 2); + mshadow::Tensor out = + outputs[0].get_with_shape(Shape3(batch_size, + outputs[0].shape_[ndim - 2], + outputs[0].shape_[ndim - 1]), s); + mshadow::Tensor mlhs = + inputs[0].get_with_shape(Shape3(batch_size, + inputs[0].shape_[ndim - 2], + inputs[0].shape_[ndim - 1]), s); + mshadow::Tensor mrhs = + inputs[1].get_with_shape(Shape3(batch_size, + inputs[1].shape_[ndim - 2], + inputs[1].shape_[ndim - 1]), s); + mshadow::Tensor workspace = + ctx.requested[0].get_space_typed(mshadow::Shape1(3 * out.size(0)), s); if (param.transpose_a && param.transpose_b) { - // Gradient of z = dot(x.T, y.T) - // dy = dot(x, dz).T = dot(dz.T, x.T) - // dx = dot(dz, y).T = dot(y.T, dz.T) - if (kNullOp != req[1]) { - mshadow::BatchGEMM(mrhs_grad, mout_grad, mlhs_data, (DType)1.0f, - (kAddTo == req[1]) ? (DType)1.0f : (DType)0.0f, - rhs_workspace); - } - if (kNullOp != req[0]) { - mshadow::BatchGEMM(mlhs_grad, mrhs_data, mout_grad, (DType)1.0f, - (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f, - lhs_workspace); - } + mshadow::BatchGEMM(out, mlhs, mrhs, (DType)1.0f, + (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f, + workspace); } else if (!param.transpose_a && param.transpose_b) { - // Gradient of z = dot(x, y.T) - // dy = dot(x.T, dz).T = dot(dz.T, x) - // dx = dot(dz, y) - if (kNullOp != req[1]) { - mshadow::BatchGEMM(mrhs_grad, mout_grad, mlhs_data, (DType)1.0f, - (kAddTo == req[1]) ? (DType)1.0f : (DType)0.0f, - rhs_workspace); - } - if (kNullOp != req[0]) { - mshadow::BatchGEMM(mlhs_grad, mout_grad, mrhs_data, (DType)1.0f, - (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f, - lhs_workspace); - } + mshadow::BatchGEMM(out, mlhs, mrhs, (DType)1.0f, + (kAddTo == req[0]) ? 
(DType)1.0f : (DType)0.0f, + workspace); } else if (param.transpose_a && !param.transpose_b) { - // Gradient of z = dot(x.T, y) - // dy = dot(x, dz) - // dx = dot(dz, y.T).T = dot(y, dz.T) - if (kNullOp != req[1]) { - mshadow::BatchGEMM(mrhs_grad, mlhs_data, mout_grad, (DType)1.0f, - (kAddTo == req[1]) ? (DType)1.0f : (DType)0.0f, - rhs_workspace); - } - if (kNullOp != req[0]) { - mshadow::BatchGEMM(mlhs_grad, mrhs_data, mout_grad, (DType)1.0f, - (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f, - lhs_workspace); - } + mshadow::BatchGEMM(out, mlhs, mrhs, (DType)1.0f, + (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f, + workspace); } else { - // Gradient of z = dot(x, y) - // dy = dot(x.T, dz) - // dx = dot(dz, y.T) - if (kNullOp != req[1]) { - mshadow::BatchGEMM(mrhs_grad, mlhs_data, mout_grad, (DType)1.0f, - (kAddTo == req[1]) ? (DType)1.0f : (DType)0.0f, - rhs_workspace); - } - if (kNullOp != req[0]) { - mshadow::BatchGEMM(mlhs_grad, mout_grad, mrhs_data, (DType)1.0f, - (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f, - lhs_workspace); - } + mshadow::BatchGEMM(out, mlhs, mrhs, (DType)1.0f, + (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f, + workspace); } }); } @@ -1485,24 +1419,34 @@ inline bool BatchDotShape(const nnvm::NodeAttrs& attrs, mxnet::TShape& rshape = (*in_attrs)[1]; // return false if lhs and rhs both have fully unknown shape if (!ndim_is_known(lshape) || !ndim_is_known(rshape)) return false; - if (lshape.ndim() == 3 && rshape.ndim() == 3) { + if (lshape.ndim() >= 3 && rshape.ndim() >= 3 && lshape.ndim() == rshape.ndim()) { + int ndim = lshape.ndim(); // only partially infer shape if last dim of lhs and second dim of rhs is known - bool last_dim_known = dim_size_is_known(lshape, 2); - bool second_dim_known = dim_size_is_known(rshape, 1); + bool last_dim_known = dim_size_is_known(lshape, ndim - 1); + bool second_dim_known = dim_size_is_known(rshape, ndim - 2); if ( !last_dim_known || !second_dim_known) return false; - CHECK(lshape[0] == rshape[0]) - << "batch_dot shape error(batch_size must be equal): " << lshape << " X " << rshape - << " trans_a=" << param.transpose_a << " trans_b=" << param.transpose_b; - index_t out_m = param.transpose_a ? lshape[2] : lshape[1]; - index_t lshape_k = param.transpose_a ? lshape[1] : lshape[2]; - index_t out_n = param.transpose_b ? rshape[1] : rshape[2]; - index_t rshape_k = param.transpose_b ? rshape[2] : rshape[1]; - CHECK(lshape_k == rshape_k) - << "batch_dot shape error(shape mismatch): " << lshape << " X " << rshape + for (int i = 0; i < ndim - 2; i++) { + CHECK_EQ(lshape[i], rshape[i]) + << "batch_dot shape error (the leading batch dimensions must be equal): " + << lshape << " X " << rshape + << " trans_a=" << param.transpose_a << " trans_b=" << param.transpose_b; + } + dim_t out_m = param.transpose_a ? lshape[ndim - 1] : lshape[ndim - 2]; + dim_t lshape_k = param.transpose_a ? lshape[ndim - 2] : lshape[ndim - 1]; + dim_t out_n = param.transpose_b ? rshape[ndim - 2] : rshape[ndim - 1]; + dim_t rshape_k = param.transpose_b ? 
rshape[ndim - 1] : rshape[ndim - 2]; + CHECK_EQ(lshape_k, rshape_k) + << "batch_dot shape error (shape mismatch): " << lshape << " X " << rshape << " trans_a=" << param.transpose_a << " trans_b=" << param.transpose_b; - SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape3(lshape[0], out_m, out_n)); + std::vector out_shape_vec; + for (int i = 0; i < ndim - 2; i++) { + out_shape_vec.push_back(lshape[i]); + } + out_shape_vec.push_back(out_m); + out_shape_vec.push_back(out_n); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(out_shape_vec)); } else { - LOG(FATAL) << "batch_dot currently only support 3D*3D array" + LOG(FATAL) << "batch_dot currently only support N-D*N-D array (N >= 3)" << lshape << " v.s. " << rshape; } // return true if output shape is fully inferred diff --git a/src/operator/tensor/dot.cc b/src/operator/tensor/dot.cc index 11a056146e1d..5b0c9956b0cd 100644 --- a/src/operator/tensor/dot.cc +++ b/src/operator/tensor/dot.cc @@ -115,13 +115,13 @@ NNVM_REGISTER_OP(batch_dot) .describe(R"doc(Batchwise dot product. ``batch_dot`` is used to compute dot product of ``x`` and ``y`` when ``x`` and -``y`` are data in batch, namely 3D arrays in shape of `(batch_size, :, :)`. +``y`` are data in batch, namely N-D (N >= 3) arrays in shape of `(B0, ..., B_i, :, :)`. -For example, given ``x`` with shape `(batch_size, n, m)` and ``y`` with shape -`(batch_size, m, k)`, the result array will have shape `(batch_size, n, k)`, +For example, given ``x`` with shape `(B_0, ..., B_i, N, M)` and ``y`` with shape +`(B_0, ..., B_i, M, K)`, the result array will have shape `(B_0, ..., B_i, N, K)`, which is computed by:: - batch_dot(x,y)[i,:,:] = dot(x[i,:,:], y[i,:,:]) + batch_dot(x,y)[b_0, ..., b_i, :, :] = dot(x[b_0, ..., b_i, :, :], y[b_0, ..., b_i, :, :]) )doc" ADD_FILELINE) .set_num_inputs(2) @@ -138,21 +138,74 @@ which is computed by:: return std::vector{ResourceRequest::kTempSpace}; }) .set_attr("FCompute", BatchDotForward_) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_batch_dot"}) +.set_attr("FGradient", + [](const nnvm::NodePtr& n, + const std::vector& ograds) { + const DotParam& param = nnvm::get(n->attrs.parsed); + nnvm::NodePtr lhs_grad; + nnvm::NodePtr rhs_grad; + std::string lhs_gnode_name = n->attrs.name + "_backward_lhs"; + std::string rhs_gnode_name = n->attrs.name + "_backward_rhs"; + if (param.transpose_a && param.transpose_b) { + // Gradient of z = dot(x.T, y.T) + // dx = dot(dz, y).T = dot(y.T, dz.T) + // dy = dot(x, dz).T = dot(dz.T, x.T) + lhs_grad = MakeNode("batch_dot", lhs_gnode_name, + {n->inputs[1], ograds[0]}, &(n->attrs.dict), &n); + rhs_grad = MakeNode("batch_dot", rhs_gnode_name, + {ograds[0], n->inputs[0]}, &(n->attrs.dict), &n); + } else if (!param.transpose_a && param.transpose_b) { + // Gradient of z = dot(x, y.T) + // dx = dot(dz, y) + // dy = dot(x.T, dz).T = dot(dz.T, x) + auto lhs_attrs_dict = n->attrs.dict; + auto rhs_attrs_dict = n->attrs.dict; + lhs_attrs_dict["transpose_a"] = "false"; + lhs_attrs_dict["transpose_b"] = "false"; + rhs_attrs_dict["transpose_a"] = "true"; + rhs_attrs_dict["transpose_b"] = "false"; + lhs_grad = MakeNode("batch_dot", lhs_gnode_name, + {ograds[0], n->inputs[1]}, &lhs_attrs_dict, &n); + rhs_grad = MakeNode("batch_dot", rhs_gnode_name, + {ograds[0], n->inputs[0]}, &rhs_attrs_dict, &n); + } else if (param.transpose_a && !param.transpose_b) { + // Gradient of z = dot(x.T, y) + // dx = dot(dz, y.T).T = dot(y, dz.T) + // dy = dot(x, dz) + auto lhs_attrs_dict = n->attrs.dict; + auto rhs_attrs_dict = n->attrs.dict; + 
lhs_attrs_dict["transpose_a"] = "false"; + lhs_attrs_dict["transpose_b"] = "true"; + rhs_attrs_dict["transpose_a"] = "false"; + rhs_attrs_dict["transpose_b"] = "false"; + lhs_grad = MakeNode("batch_dot", lhs_gnode_name, + {n->inputs[1], ograds[0]}, &lhs_attrs_dict, &n); + rhs_grad = MakeNode("batch_dot", rhs_gnode_name, + {n->inputs[0], ograds[0]}, &rhs_attrs_dict, &n); + } else { + // Gradient of z = dot(x, y) + // dx = dot(dz, y.T) + // dy = dot(x.T, dz) + auto lhs_attrs_dict = n->attrs.dict; + auto rhs_attrs_dict = n->attrs.dict; + lhs_attrs_dict["transpose_a"] = "false"; + lhs_attrs_dict["transpose_b"] = "true"; + rhs_attrs_dict["transpose_a"] = "true"; + rhs_attrs_dict["transpose_b"] = "false"; + lhs_grad = MakeNode("batch_dot", lhs_gnode_name, + {ograds[0], n->inputs[1]}, &lhs_attrs_dict, &n); + rhs_grad = MakeNode("batch_dot", rhs_gnode_name, + {n->inputs[0], ograds[0]}, &rhs_attrs_dict, &n); + } + std::vector ret; + ret.emplace_back(nnvm::NodeEntry{lhs_grad, 0, 0}); + ret.emplace_back(nnvm::NodeEntry{rhs_grad, 0, 0}); + return ret; +}) +//.set_attr("FGradient", ElemwiseGradUseIn{"_backward_batch_dot"}) .add_argument("lhs", "NDArray-or-Symbol", "The first input") .add_argument("rhs", "NDArray-or-Symbol", "The second input") .add_arguments(DotParam::__FIELDS__()); -NNVM_REGISTER_OP(_backward_batch_dot) -.set_num_inputs(3) -.set_num_outputs(2) -.set_attr_parser(ParamParser) -.set_attr("FResourceRequest", - [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kTempSpace}; - }) -.set_attr("TIsBackward", true) -.set_attr("FCompute", BatchDotBackward_); - } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/dot.cu b/src/operator/tensor/dot.cu index 8ee2e2832fbb..b245b1c9e5ed 100644 --- a/src/operator/tensor/dot.cu +++ b/src/operator/tensor/dot.cu @@ -38,8 +38,5 @@ NNVM_REGISTER_OP(_backward_dot) NNVM_REGISTER_OP(batch_dot) .set_attr("FCompute", BatchDotForward_); -NNVM_REGISTER_OP(_backward_batch_dot) -.set_attr("FCompute", BatchDotBackward_); - } // namespace op } // namespace mxnet diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index b764ac73d30c..31f4870ad0cb 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -24,6 +24,7 @@ import platform import mxnet as mx import scipy.stats as ss +from nose.tools import assert_raises from mxnet import np, npx from mxnet.gluon import HybridBlock from mxnet.base import MXNetError @@ -901,6 +902,122 @@ def hybrid_forward(self, F, a): expected_grad[basic_index] = 1 assert same(a.grad.asnumpy(), expected_grad) +@with_seed() +@use_np +def test_npx_batch_dot(): + ctx = mx.context.current_context() + dtypes = ['float32', 'float64'] + if ctx.device_type == 'gpu': + dtypes += ['float16'] + class TestBatchDot(HybridBlock): + def __init__(self, transpose_a, transpose_b): + super(TestBatchDot, self).__init__() + self._transpose_a = transpose_a + self._transpose_b = transpose_b + + def hybrid_forward(self, F, lhs, rhs): + return F.npx.batch_dot(lhs, rhs, + transpose_a=self._transpose_a, + transpose_b=self._transpose_b) + + def batch_dot_numpy(lhs, rhs, transpose_a, transpose_b): + assert lhs.ndim == rhs.ndim >= 3 + if transpose_a: + lhs = lhs.swapaxes(-1, -2) + if transpose_b: + rhs = rhs.swapaxes(-1, -2) + return _np.matmul(lhs, rhs) + + def gt_grad_batch_dot_numpy(lhs, rhs, ograd, transpose_a, transpose_b, lhs_req, rhs_req, + init_lhs_grad, init_rhs_grad): + + if transpose_a and transpose_b: + # Gradient of z = dot(x.T, y.T) + # dx = 
dot(dz, y).T = dot(y.T, dz.T) + # dy = dot(x, dz).T = dot(dz.T, x.T) + lhs_grad = batch_dot_numpy(rhs, ograd, transpose_a=True, transpose_b=True) + rhs_grad = batch_dot_numpy(ograd, lhs, transpose_a=True, transpose_b=True) + elif not transpose_a and transpose_b: + # Gradient of z = dot(x, y.T) + # dx = dot(dz, y) + # dy = dot(x.T, dz).T = dot(dz.T, x) + lhs_grad = batch_dot_numpy(ograd, rhs, transpose_a=False, transpose_b=False) + rhs_grad = batch_dot_numpy(ograd, lhs, transpose_a=True, transpose_b=False) + elif transpose_a and not transpose_b: + # Gradient of z = dot(x.T, y) + # dx = dot(dz, y.T).T = dot(y, dz.T) + # dy = dot(x, dz) + lhs_grad = batch_dot_numpy(rhs, ograd, transpose_a=False, transpose_b=True) + rhs_grad = batch_dot_numpy(lhs, ograd, transpose_a=False, transpose_b=False) + else: + # Gradient of z = dot(x, y) + # dx = dot(dz, y.T) + # dy = dot(x.T, dz) + lhs_grad = batch_dot_numpy(ograd, rhs, transpose_a=False, transpose_b=True) + rhs_grad = batch_dot_numpy(lhs, ograd, transpose_a=True, transpose_b=False) + if lhs_req == 'add': + lhs_grad += init_lhs_grad + if rhs_req == 'add': + rhs_grad += init_rhs_grad + return lhs_grad, rhs_grad + + + configs = [ + ((2, 3, 0), (2, 4, 0), False, True), + ((2, 4, 3), (2, 4, 3), True, False), + ((0, 3, 0), (0, 0, 2), False, False), + ((3, 2, 3, 2), (3, 2, 2, 3), True, True), + ((3, 1, 5, 2), (3, 1, 2, 1), False, False) + ] + bad_configs = [ + ((5, 3, 2), (5, 1, 3), False, False), + ((2, 5, 3, 1), (2, 4, 3, 1), True, False) + ] + for hybridize in [True, False]: + for lhs_shape, rhs_shape, transpose_a, transpose_b in configs: + for dtype in dtypes: + for lhs_grad_req in ['write', 'add']: + for rhs_grad_req in ['write', 'add']: + f_batch_dot = TestBatchDot(transpose_a=transpose_a, + transpose_b=transpose_b) + if hybridize: + f_batch_dot.hybridize() + lhs_val = mx.np.array(_np.random.uniform(-1.0, 1.0, lhs_shape), dtype=dtype) + rhs_val = mx.np.array(_np.random.uniform(-1.0, 1.0, rhs_shape), dtype=dtype) + lhs_val.attach_grad(grad_req=lhs_grad_req) + rhs_val.attach_grad(grad_req=rhs_grad_req) + gt_out = batch_dot_numpy(lhs_val.asnumpy(), rhs_val.asnumpy(), + transpose_a, transpose_b) + init_lhs_grad = mx.np.random.uniform(-1.0, 1.0, lhs_shape, dtype=dtype) + init_rhs_grad = mx.np.random.uniform(-1.0, 1.0, rhs_shape, dtype=dtype) + o_grad = mx.np.random.uniform(-1.0, 1.0, gt_out.shape, dtype=dtype) + if lhs_grad_req == 'add': + lhs_val.grad[:] = init_lhs_grad + if rhs_grad_req == 'add': + rhs_val.grad[:] = init_rhs_grad + with mx.autograd.record(): + out = f_batch_dot(lhs_val, rhs_val) + out.backward(o_grad) + assert_almost_equal(out.asnumpy(), gt_out, rtol=1E-5, atol=1E-5) + gt_lhs_grad, gt_rhs_grad = gt_grad_batch_dot_numpy(lhs_val.asnumpy(), + rhs_val.asnumpy(), + o_grad.asnumpy(), + transpose_a=transpose_a, + transpose_b=transpose_b, + lhs_req=lhs_grad_req, + rhs_req=rhs_grad_req, + init_lhs_grad=init_lhs_grad.asnumpy(), + init_rhs_grad=init_rhs_grad.asnumpy()) + assert_almost_equal(lhs_val.grad.asnumpy(), gt_lhs_grad, rtol=1E-5, atol=1E-5) + assert_almost_equal(rhs_val.grad.asnumpy(), gt_rhs_grad, rtol=1E-5, atol=1E-5) + for lhs_shape, rhs_shape, transpose_a, transpose_b in bad_configs: + for dtype in dtypes: + lhs_val = mx.np.array(_np.random.uniform(-1.0, 1.0, lhs_shape), dtype=dtype) + rhs_val = mx.np.array(_np.random.uniform(-1.0, 1.0, rhs_shape), dtype=dtype) + assert_raises(MXNetError, lambda: mx.npx.batch_dot(lhs_val, rhs_val, + transpose_a=transpose_a, + transpose_b=transpose_b)) + @with_seed() @use_np diff --git 
a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index a16dc6c693ab..c87fa6148d3b 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -3291,9 +3291,9 @@ def test_batch_dot(): agrad_npy = np.empty((batch_size, m, k), dtype=data_type) bgrad_npy = np.empty((batch_size, k, n), dtype=data_type) a_init_grad_npy = np.random.normal(size=(batch_size, m, k)) - a_init_grad_npy = a_npy.astype(data_type) + a_init_grad_npy = a_init_grad_npy.astype(data_type) b_init_grad_npy = np.random.normal(size=(batch_size, k, n)) - b_init_grad_npy = b_npy.astype(data_type) + b_init_grad_npy = b_init_grad_npy.astype(data_type) for i in range(batch_size): c_npy[i, :, :] = np.dot(a_npy[i, :, :], b_npy[i, :, :]) bgrad_npy[i, :, :] = np.dot(a_npy[i, :, :].T, ograd_npy[i, :, :]) From 604b52fd60f78da8dd5abe1518bd6bc491083195 Mon Sep 17 00:00:00 2001 From: Xingjian Shi Date: Tue, 22 Oct 2019 20:44:15 -0700 Subject: [PATCH 2/5] use 1E-4 --- tests/python/unittest/test_numpy_op.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 31f4870ad0cb..35ddbe65e8c9 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -998,7 +998,7 @@ def gt_grad_batch_dot_numpy(lhs, rhs, ograd, transpose_a, transpose_b, lhs_req, with mx.autograd.record(): out = f_batch_dot(lhs_val, rhs_val) out.backward(o_grad) - assert_almost_equal(out.asnumpy(), gt_out, rtol=1E-5, atol=1E-5) + assert_almost_equal(out.asnumpy(), gt_out, rtol=1E-4, atol=1E-4) gt_lhs_grad, gt_rhs_grad = gt_grad_batch_dot_numpy(lhs_val.asnumpy(), rhs_val.asnumpy(), o_grad.asnumpy(), @@ -1008,8 +1008,8 @@ def gt_grad_batch_dot_numpy(lhs, rhs, ograd, transpose_a, transpose_b, lhs_req, rhs_req=rhs_grad_req, init_lhs_grad=init_lhs_grad.asnumpy(), init_rhs_grad=init_rhs_grad.asnumpy()) - assert_almost_equal(lhs_val.grad.asnumpy(), gt_lhs_grad, rtol=1E-5, atol=1E-5) - assert_almost_equal(rhs_val.grad.asnumpy(), gt_rhs_grad, rtol=1E-5, atol=1E-5) + assert_almost_equal(lhs_val.grad.asnumpy(), gt_lhs_grad, rtol=1E-4, atol=1E-4) + assert_almost_equal(rhs_val.grad.asnumpy(), gt_rhs_grad, rtol=1E-4, atol=1E-4) for lhs_shape, rhs_shape, transpose_a, transpose_b in bad_configs: for dtype in dtypes: lhs_val = mx.np.array(_np.random.uniform(-1.0, 1.0, lhs_shape), dtype=dtype) From e036834989cd1a2a13080d1e3fc1c7f893338cb9 Mon Sep 17 00:00:00 2001 From: Xingjian Shi Date: Tue, 22 Oct 2019 20:47:28 -0700 Subject: [PATCH 3/5] fix lint --- src/operator/tensor/dot-inl.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h index 0166a09f3268..8405404dc627 100644 --- a/src/operator/tensor/dot-inl.h +++ b/src/operator/tensor/dot-inl.h @@ -30,7 +30,7 @@ #include #include #include -#include + #include "./util/tensor_util-inl.h" #include "../mshadow_op.h" #include "../elemwise_op_common.h" @@ -1368,7 +1368,6 @@ void BatchDotForward_(const nnvm::NodeAttrs& attrs, if (outputs[0].shape_.Size() == 0 || inputs[0].shape_.Size() == 0 || inputs[1].shape_.Size() == 0) { if (outputs[0].shape_.Size() != 0 && req[0] != kAddTo) { - mxnet_op::Kernel::Launch(s, outputs[0].shape_.Size(), outputs[0].dptr()); } From df3f11557f626a673cb9768d070052a662853fda Mon Sep 17 00:00:00 2001 From: Xingjian Shi Date: Tue, 22 Oct 2019 21:30:00 -0700 Subject: [PATCH 4/5] remove unnecessary comment --- src/operator/tensor/dot.cc | 1 - 1 
file changed, 1 deletion(-) diff --git a/src/operator/tensor/dot.cc b/src/operator/tensor/dot.cc index 5b0c9956b0cd..556260ed9600 100644 --- a/src/operator/tensor/dot.cc +++ b/src/operator/tensor/dot.cc @@ -202,7 +202,6 @@ which is computed by:: ret.emplace_back(nnvm::NodeEntry{rhs_grad, 0, 0}); return ret; }) -//.set_attr("FGradient", ElemwiseGradUseIn{"_backward_batch_dot"}) .add_argument("lhs", "NDArray-or-Symbol", "The first input") .add_argument("rhs", "NDArray-or-Symbol", "The second input") .add_arguments(DotParam::__FIELDS__()); From 87ca6a20f5ef1b6583c8cda0c29393ccad5c8871 Mon Sep 17 00:00:00 2001 From: Xingjian Shi Date: Tue, 22 Oct 2019 22:51:11 -0700 Subject: [PATCH 5/5] Update test_numpy_op.py --- tests/python/unittest/test_numpy_op.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 35ddbe65e8c9..ae8ad621df75 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -909,6 +909,7 @@ def test_npx_batch_dot(): dtypes = ['float32', 'float64'] if ctx.device_type == 'gpu': dtypes += ['float16'] + eps_dict = {'float32': 1E-4, 'float64': 1E-4, 'float16': 1E-3} class TestBatchDot(HybridBlock): def __init__(self, transpose_a, transpose_b): super(TestBatchDot, self).__init__() @@ -976,6 +977,7 @@ def gt_grad_batch_dot_numpy(lhs, rhs, ograd, transpose_a, transpose_b, lhs_req, for hybridize in [True, False]: for lhs_shape, rhs_shape, transpose_a, transpose_b in configs: for dtype in dtypes: + eps = eps_dict[dtype] for lhs_grad_req in ['write', 'add']: for rhs_grad_req in ['write', 'add']: f_batch_dot = TestBatchDot(transpose_a=transpose_a, @@ -998,7 +1000,7 @@ def gt_grad_batch_dot_numpy(lhs, rhs, ograd, transpose_a, transpose_b, lhs_req, with mx.autograd.record(): out = f_batch_dot(lhs_val, rhs_val) out.backward(o_grad) - assert_almost_equal(out.asnumpy(), gt_out, rtol=1E-4, atol=1E-4) + assert_almost_equal(out.asnumpy(), gt_out, rtol=eps, atol=eps) gt_lhs_grad, gt_rhs_grad = gt_grad_batch_dot_numpy(lhs_val.asnumpy(), rhs_val.asnumpy(), o_grad.asnumpy(), @@ -1008,8 +1010,8 @@ def gt_grad_batch_dot_numpy(lhs, rhs, ograd, transpose_a, transpose_b, lhs_req, rhs_req=rhs_grad_req, init_lhs_grad=init_lhs_grad.asnumpy(), init_rhs_grad=init_rhs_grad.asnumpy()) - assert_almost_equal(lhs_val.grad.asnumpy(), gt_lhs_grad, rtol=1E-4, atol=1E-4) - assert_almost_equal(rhs_val.grad.asnumpy(), gt_rhs_grad, rtol=1E-4, atol=1E-4) + assert_almost_equal(lhs_val.grad.asnumpy(), gt_lhs_grad, rtol=eps, atol=eps) + assert_almost_equal(rhs_val.grad.asnumpy(), gt_rhs_grad, rtol=eps, atol=eps) for lhs_shape, rhs_shape, transpose_a, transpose_b in bad_configs: for dtype in dtypes: lhs_val = mx.np.array(_np.random.uniform(-1.0, 1.0, lhs_shape), dtype=dtype)
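
Illustrative usage sketch (reviewer aid, not part of the patches above): with the
N-D support added here, ``npx.batch_dot`` treats every leading dimension as a batch
dimension and multiplies the trailing two. The reference below assumes an MXNet
build from this branch with numpy semantics enabled; the shapes and variable names
are made up for the example, and the tolerances mirror the 1E-4 used in the tests::

    import numpy as _np
    import mxnet as mx
    from mxnet import np, npx

    npx.set_np()  # enable numpy / numpy-extension semantics

    # 4-D inputs: leading batch dims (3, 2), trailing matrices (5, 4) x (4, 6).
    lhs = np.array(_np.random.uniform(-1, 1, (3, 2, 5, 4)), dtype='float32')
    rhs = np.array(_np.random.uniform(-1, 1, (3, 2, 4, 6)), dtype='float32')
    lhs.attach_grad()
    rhs.attach_grad()

    with mx.autograd.record():
        out = npx.batch_dot(lhs, rhs)          # shape (3, 2, 5, 6)
    out.backward(np.ones_like(out))

    # Forward agrees with numpy's matmul, which also batches over leading dims.
    expected = _np.matmul(lhs.asnumpy(), rhs.asnumpy())
    assert out.shape == (3, 2, 5, 6)
    assert _np.allclose(out.asnumpy(), expected, rtol=1e-4, atol=1e-4)

    # Backward follows the per-batch matrix rules wired into FGradient above:
    # for z = x . y, dx = dz . y^T and dy = x^T . dz (dz is all ones here).
    ones = _np.ones_like(expected)
    assert _np.allclose(lhs.grad.asnumpy(),
                        _np.matmul(ones, _np.swapaxes(rhs.asnumpy(), 2, 3)),
                        rtol=1e-4, atol=1e-4)
    assert _np.allclose(rhs.grad.asnumpy(),
                        _np.matmul(_np.swapaxes(lhs.asnumpy(), 2, 3), ones),
                        rtol=1e-4, atol=1e-4)

    # transpose_a / transpose_b swap the last two axes of the corresponding
    # input before the multiplication, exactly as in the old 3-D operator.
    out_t = npx.batch_dot(lhs, np.swapaxes(rhs, 2, 3), transpose_b=True)
    assert _np.allclose(out_t.asnumpy(), expected, rtol=1e-4, atol=1e-4)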