This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Numpy] Support N_D(N>=3) batch_dot #16586

Merged 5 commits on Oct 24, 2019
177 changes: 60 additions & 117 deletions src/operator/tensor/dot-inl.h
@@ -30,6 +30,7 @@
 #include <algorithm>
 #include <utility>
 #include <type_traits>
+#include "./util/tensor_util-inl.h"
 #include "../mshadow_op.h"
 #include "../elemwise_op_common.h"
@@ -1353,6 +1354,7 @@ void BatchDotForward_(const nnvm::NodeAttrs& attrs,
   using namespace mshadow;
   using namespace mshadow::expr;
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  if (req[0] == kNullOp) return;
   const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
   CHECK_EQ(outputs[0].type_flag_, inputs[0].type_flag_)
       << "Binary function only support input/output with the same type";
@@ -1362,115 +1364,46 @@ void BatchDotForward_(const nnvm::NodeAttrs& attrs,
         (outputs[0].type_flag_ == kFloat16 && ctx.run_ctx.ctx.dev_mask() == mshadow::gpu::kDevMask))
       << "dot only supports float32/float64 for CPU, and float16/float32/float64 for GPU";
   MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-    mshadow::Tensor<xpu, 3, DType> out = outputs[0].get<xpu, 3, DType>(s);
-    mshadow::Tensor<xpu, 3, DType> mlhs = inputs[0].get<xpu, 3, DType>(s);
-    mshadow::Tensor<xpu, 3, DType> mrhs = inputs[1].get<xpu, 3, DType>(s);
-    mshadow::Tensor<xpu, 1, DType*> workspace =
-      ctx.requested[0].get_space_typed<xpu, 1, DType*>(mshadow::Shape1(3 * out.size(0)), s);
-    if (kNullOp != req[0]) {
-      if (param.transpose_a && param.transpose_b) {
-        mshadow::BatchGEMM<true, true>(out, mlhs, mrhs, (DType)1.0f,
-                                       (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f,
-                                       workspace);
-      } else if (!param.transpose_a && param.transpose_b) {
-        mshadow::BatchGEMM<false, true>(out, mlhs, mrhs, (DType)1.0f,
-                                        (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f,
-                                        workspace);
-      } else if (param.transpose_a && !param.transpose_b) {
-        mshadow::BatchGEMM<true, false>(out, mlhs, mrhs, (DType)1.0f,
-                                        (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f,
-                                        workspace);
-      } else {
-        mshadow::BatchGEMM<false, false>(out, mlhs, mrhs, (DType)1.0f,
-                                         (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f,
-                                         workspace);
-      }
-    }
+    int ndim = outputs[0].ndim();
+    if (outputs[0].shape_.Size() == 0 || inputs[0].shape_.Size() == 0
+        || inputs[1].shape_.Size() == 0) {
+      if (outputs[0].shape_.Size() != 0 && req[0] != kAddTo) {
+        mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(s, outputs[0].shape_.Size(),
+                                                          outputs[0].dptr<DType>());
+      }
+      return;
+    }
+    size_t batch_size = outputs[0].shape_.ProdShape(0, ndim - 2);
+    mshadow::Tensor<xpu, 3, DType> out =
+        outputs[0].get_with_shape<xpu, 3, DType>(Shape3(batch_size,
+                                                        outputs[0].shape_[ndim - 2],
+                                                        outputs[0].shape_[ndim - 1]), s);
+    mshadow::Tensor<xpu, 3, DType> mlhs =
+        inputs[0].get_with_shape<xpu, 3, DType>(Shape3(batch_size,
+                                                       inputs[0].shape_[ndim - 2],
+                                                       inputs[0].shape_[ndim - 1]), s);
+    mshadow::Tensor<xpu, 3, DType> mrhs =
+        inputs[1].get_with_shape<xpu, 3, DType>(Shape3(batch_size,
+                                                       inputs[1].shape_[ndim - 2],
+                                                       inputs[1].shape_[ndim - 1]), s);
+    mshadow::Tensor<xpu, 1, DType*> workspace =
+        ctx.requested[0].get_space_typed<xpu, 1, DType*>(mshadow::Shape1(3 * out.size(0)), s);
+    if (param.transpose_a && param.transpose_b) {
+      mshadow::BatchGEMM<true, true>(out, mlhs, mrhs, (DType)1.0f,
+                                     (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f,
+                                     workspace);
+    } else if (!param.transpose_a && param.transpose_b) {
+      mshadow::BatchGEMM<false, true>(out, mlhs, mrhs, (DType)1.0f,
+                                      (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f,
+                                      workspace);
+    } else if (param.transpose_a && !param.transpose_b) {
+      mshadow::BatchGEMM<true, false>(out, mlhs, mrhs, (DType)1.0f,
+                                      (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f,
+                                      workspace);
+    } else {
+      mshadow::BatchGEMM<false, false>(out, mlhs, mrhs, (DType)1.0f,
+                                       (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f,
+                                       workspace);
+    }
   });
 }
-
-template<typename xpu>
-void BatchDotBackward_(const nnvm::NodeAttrs& attrs,
-                       const OpContext& ctx,
-                       const std::vector<TBlob>& inputs,
-                       const std::vector<OpReqType>& req,
-                       const std::vector<TBlob>& outputs) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
-  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-  const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
-  CHECK_NE(req[1], kWriteInplace);
-  CHECK_NE(req[0], kWriteInplace);
-  CHECK(outputs[0].type_flag_ == kFloat32 || outputs[0].type_flag_ == kFloat64 ||
-        (outputs[0].type_flag_ == kFloat16 && ctx.run_ctx.ctx.dev_mask() == mshadow::gpu::kDevMask))
-      << "dot only supports float32/float64 for CPU, and float16/float32/float64 for GPU";
-  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-    mshadow::Tensor<xpu, 3, DType> mout_grad = inputs[0].get<xpu, 3, DType>(s);
-    mshadow::Tensor<xpu, 3, DType> mlhs_data = inputs[1].get<xpu, 3, DType>(s);
-    mshadow::Tensor<xpu, 3, DType> mrhs_data = inputs[2].get<xpu, 3, DType>(s);
-    mshadow::Tensor<xpu, 3, DType> mlhs_grad = outputs[0].get<xpu, 3, DType>(s);
-    mshadow::Tensor<xpu, 3, DType> mrhs_grad = outputs[1].get<xpu, 3, DType>(s);
-    mshadow::Tensor<xpu, 2, DType*> workspace =
-      ctx.requested[0].get_space_typed<xpu, 2, DType*>(
-        mshadow::Shape2(2, 3 * mout_grad.size(0)), s);
-    mshadow::Tensor<xpu, 1, DType*> rhs_workspace = workspace[0];
-    mshadow::Tensor<xpu, 1, DType*> lhs_workspace = workspace[1];
-    if (param.transpose_a && param.transpose_b) {
-      // Gradient of z = dot(x.T, y.T)
-      // dy = dot(x, dz).T = dot(dz.T, x.T)
-      // dx = dot(dz, y).T = dot(y.T, dz.T)
-      if (kNullOp != req[1]) {
-        mshadow::BatchGEMM<true, true>(mrhs_grad, mout_grad, mlhs_data, (DType)1.0f,
-                                       (kAddTo == req[1]) ? (DType)1.0f : (DType)0.0f,
-                                       rhs_workspace);
-      }
-      if (kNullOp != req[0]) {
-        mshadow::BatchGEMM<true, true>(mlhs_grad, mrhs_data, mout_grad, (DType)1.0f,
-                                       (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f,
-                                       lhs_workspace);
-      }
-    } else if (!param.transpose_a && param.transpose_b) {
-      // Gradient of z = dot(x, y.T)
-      // dy = dot(x.T, dz).T = dot(dz.T, x)
-      // dx = dot(dz, y)
-      if (kNullOp != req[1]) {
-        mshadow::BatchGEMM<true, false>(mrhs_grad, mout_grad, mlhs_data, (DType)1.0f,
-                                        (kAddTo == req[1]) ? (DType)1.0f : (DType)0.0f,
-                                        rhs_workspace);
-      }
-      if (kNullOp != req[0]) {
-        mshadow::BatchGEMM<false, false>(mlhs_grad, mout_grad, mrhs_data, (DType)1.0f,
-                                         (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f,
-                                         lhs_workspace);
-      }
-    } else if (param.transpose_a && !param.transpose_b) {
-      // Gradient of z = dot(x.T, y)
-      // dy = dot(x, dz)
-      // dx = dot(dz, y.T).T = dot(y, dz.T)
-      if (kNullOp != req[1]) {
-        mshadow::BatchGEMM<false, false>(mrhs_grad, mlhs_data, mout_grad, (DType)1.0f,
-                                         (kAddTo == req[1]) ? (DType)1.0f : (DType)0.0f,
-                                         rhs_workspace);
-      }
-      if (kNullOp != req[0]) {
-        mshadow::BatchGEMM<false, true>(mlhs_grad, mrhs_data, mout_grad, (DType)1.0f,
-                                        (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f,
-                                        lhs_workspace);
-      }
-    } else {
-      // Gradient of z = dot(x, y)
-      // dy = dot(x.T, dz)
-      // dx = dot(dz, y.T)
-      if (kNullOp != req[1]) {
-        mshadow::BatchGEMM<true, false>(mrhs_grad, mlhs_data, mout_grad, (DType)1.0f,
-                                        (kAddTo == req[1]) ? (DType)1.0f : (DType)0.0f,
-                                        rhs_workspace);
-      }
-      if (kNullOp != req[0]) {
-        mshadow::BatchGEMM<false, true>(mlhs_grad, mout_grad, mrhs_data, (DType)1.0f,
-                                        (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f,
-                                        lhs_workspace);
-      }
-    }
-  });
-}
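
An aside for readers of this hunk: the heart of the new BatchDotForward_ is collapsing all leading batch dimensions into one (ProdShape(0, ndim - 2) plus get_with_shape) so the existing 3-D BatchGEMM path is reused unchanged. A minimal NumPy sketch of that semantics (illustrative only, not MXNet source; the helper name is made up)::

    import numpy as np

    def batch_dot_nd(lhs, rhs, transpose_a=False, transpose_b=False):
        assert lhs.ndim == rhs.ndim and lhs.ndim >= 3
        batch_shape = lhs.shape[:-2]
        batch_size = int(np.prod(batch_shape))            # ProdShape(0, ndim - 2)
        l3 = lhs.reshape((batch_size,) + lhs.shape[-2:])  # get_with_shape(Shape3(...))
        r3 = rhs.reshape((batch_size,) + rhs.shape[-2:])
        if transpose_a:
            l3 = l3.transpose(0, 2, 1)
        if transpose_b:
            r3 = r3.transpose(0, 2, 1)
        out3 = np.matmul(l3, r3)                          # the 3-D BatchGEMM step
        return out3.reshape(batch_shape + out3.shape[-2:])

    x = np.random.rand(2, 3, 4, 5)
    y = np.random.rand(2, 3, 5, 6)
    print(batch_dot_nd(x, y).shape)   # (2, 3, 4, 6)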
@@ -1485,24 +1418,34 @@ inline bool BatchDotShape(const nnvm::NodeAttrs& attrs,
   mxnet::TShape& rshape = (*in_attrs)[1];
   // return false if lhs and rhs both have fully unknown shape
   if (!ndim_is_known(lshape) || !ndim_is_known(rshape)) return false;
-  if (lshape.ndim() == 3 && rshape.ndim() == 3) {
+  if (lshape.ndim() >= 3 && rshape.ndim() >= 3 && lshape.ndim() == rshape.ndim()) {
+    int ndim = lshape.ndim();
     // only partially infer shape if last dim of lhs and second dim of rhs is known
-    bool last_dim_known = dim_size_is_known(lshape, 2);
-    bool second_dim_known = dim_size_is_known(rshape, 1);
+    bool last_dim_known = dim_size_is_known(lshape, ndim - 1);
+    bool second_dim_known = dim_size_is_known(rshape, ndim - 2);
     if (!last_dim_known || !second_dim_known) return false;
-    CHECK(lshape[0] == rshape[0])
-      << "batch_dot shape error(batch_size must be equal): " << lshape << " X " << rshape
-      << " trans_a=" << param.transpose_a << " trans_b=" << param.transpose_b;
-    index_t out_m = param.transpose_a ? lshape[2] : lshape[1];
-    index_t lshape_k = param.transpose_a ? lshape[1] : lshape[2];
-    index_t out_n = param.transpose_b ? rshape[1] : rshape[2];
-    index_t rshape_k = param.transpose_b ? rshape[2] : rshape[1];
-    CHECK(lshape_k == rshape_k)
-      << "batch_dot shape error(shape mismatch): " << lshape << " X " << rshape
-      << " trans_a=" << param.transpose_a << " trans_b=" << param.transpose_b;
-    SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape3(lshape[0], out_m, out_n));
+    for (int i = 0; i < ndim - 2; i++) {
+      CHECK_EQ(lshape[i], rshape[i])
+        << "batch_dot shape error (the leading batch dimensions must be equal): "
+        << lshape << " X " << rshape
+        << " trans_a=" << param.transpose_a << " trans_b=" << param.transpose_b;
+    }
+    dim_t out_m = param.transpose_a ? lshape[ndim - 1] : lshape[ndim - 2];
+    dim_t lshape_k = param.transpose_a ? lshape[ndim - 2] : lshape[ndim - 1];
+    dim_t out_n = param.transpose_b ? rshape[ndim - 2] : rshape[ndim - 1];
+    dim_t rshape_k = param.transpose_b ? rshape[ndim - 1] : rshape[ndim - 2];
+    CHECK_EQ(lshape_k, rshape_k)
+      << "batch_dot shape error (shape mismatch): " << lshape << " X " << rshape
+      << " trans_a=" << param.transpose_a << " trans_b=" << param.transpose_b;
+    std::vector<dim_t> out_shape_vec;
+    for (int i = 0; i < ndim - 2; i++) {
+      out_shape_vec.push_back(lshape[i]);
+    }
+    out_shape_vec.push_back(out_m);
+    out_shape_vec.push_back(out_n);
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(out_shape_vec));
   } else {
-    LOG(FATAL) << "batch_dot currently only support 3D*3D array"
+    LOG(FATAL) << "batch_dot currently only support N-D*N-D array (N >= 3)"
               << lshape << " v.s. " << rshape;
   }
   // return true if output shape is fully inferred
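
The rule BatchDotShape enforces is compact: equal ndim >= 3, equal leading batch dimensions, and a matching reduction axis after the optional transposes. A Python mirror of that logic (a sketch for fully known shapes; the function name is hypothetical)::

    def batch_dot_shape(lshape, rshape, transpose_a=False, transpose_b=False):
        ndim = len(lshape)
        assert ndim == len(rshape) and ndim >= 3, "need same-ndim N-D (N >= 3) inputs"
        assert lshape[:-2] == rshape[:-2], "leading batch dimensions must be equal"
        out_m, l_k = (lshape[-1], lshape[-2]) if transpose_a else (lshape[-2], lshape[-1])
        r_k, out_n = (rshape[-1], rshape[-2]) if transpose_b else (rshape[-2], rshape[-1])
        assert l_k == r_k, "reduction dimension mismatch"
        return lshape[:-2] + (out_m, out_n)

    print(batch_dot_shape((2, 3, 4, 5), (2, 3, 5, 6)))                    # (2, 3, 4, 6)
    print(batch_dot_shape((2, 3, 4, 5), (2, 3, 6, 5), transpose_b=True))  # (2, 3, 4, 6)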
84 changes: 68 additions & 16 deletions src/operator/tensor/dot.cc
@@ -115,13 +115,13 @@ NNVM_REGISTER_OP(batch_dot)
.describe(R"doc(Batchwise dot product.

``batch_dot`` is used to compute dot product of ``x`` and ``y`` when ``x`` and
``y`` are data in batch, namely 3D arrays in shape of `(batch_size, :, :)`.
``y`` are data in batch, namely N-D (N >= 3) arrays in shape of `(B0, ..., B_i, :, :)`.

For example, given ``x`` with shape `(batch_size, n, m)` and ``y`` with shape
`(batch_size, m, k)`, the result array will have shape `(batch_size, n, k)`,
For example, given ``x`` with shape `(B_0, ..., B_i, N, M)` and ``y`` with shape
`(B_0, ..., B_i, M, K)`, the result array will have shape `(B_0, ..., B_i, N, K)`,
which is computed by::

batch_dot(x,y)[i,:,:] = dot(x[i,:,:], y[i,:,:])
batch_dot(x,y)[b_0, ..., b_i, :, :] = dot(x[b_0, ..., b_i, :, :], y[b_0, ..., b_i, :, :])

)doc" ADD_FILELINE)
.set_num_inputs(2)
@@ -138,21 +138,73 @@ which is computed by::
     return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
   })
 .set_attr<FCompute>("FCompute<cpu>", BatchDotForward_<cpu>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_batch_dot"})
+.set_attr<nnvm::FGradient>("FGradient",
+  [](const nnvm::NodePtr& n,
+     const std::vector<nnvm::NodeEntry>& ograds) {
+    const DotParam& param = nnvm::get<DotParam>(n->attrs.parsed);
+    nnvm::NodePtr lhs_grad;
+    nnvm::NodePtr rhs_grad;
+    std::string lhs_gnode_name = n->attrs.name + "_backward_lhs";
+    std::string rhs_gnode_name = n->attrs.name + "_backward_rhs";
+    if (param.transpose_a && param.transpose_b) {
+      // Gradient of z = dot(x.T, y.T)
+      // dx = dot(dz, y).T = dot(y.T, dz.T)
+      // dy = dot(x, dz).T = dot(dz.T, x.T)
+      lhs_grad = MakeNode("batch_dot", lhs_gnode_name,
+                          {n->inputs[1], ograds[0]}, &(n->attrs.dict), &n);
+      rhs_grad = MakeNode("batch_dot", rhs_gnode_name,
+                          {ograds[0], n->inputs[0]}, &(n->attrs.dict), &n);
+    } else if (!param.transpose_a && param.transpose_b) {
+      // Gradient of z = dot(x, y.T)
+      // dx = dot(dz, y)
+      // dy = dot(x.T, dz).T = dot(dz.T, x)
+      auto lhs_attrs_dict = n->attrs.dict;
+      auto rhs_attrs_dict = n->attrs.dict;
+      lhs_attrs_dict["transpose_a"] = "false";
+      lhs_attrs_dict["transpose_b"] = "false";
+      rhs_attrs_dict["transpose_a"] = "true";
+      rhs_attrs_dict["transpose_b"] = "false";
+      lhs_grad = MakeNode("batch_dot", lhs_gnode_name,
+                          {ograds[0], n->inputs[1]}, &lhs_attrs_dict, &n);
+      rhs_grad = MakeNode("batch_dot", rhs_gnode_name,
+                          {ograds[0], n->inputs[0]}, &rhs_attrs_dict, &n);
+    } else if (param.transpose_a && !param.transpose_b) {
+      // Gradient of z = dot(x.T, y)
+      // dx = dot(dz, y.T).T = dot(y, dz.T)
+      // dy = dot(x, dz)
+      auto lhs_attrs_dict = n->attrs.dict;
+      auto rhs_attrs_dict = n->attrs.dict;
+      lhs_attrs_dict["transpose_a"] = "false";
+      lhs_attrs_dict["transpose_b"] = "true";
+      rhs_attrs_dict["transpose_a"] = "false";
+      rhs_attrs_dict["transpose_b"] = "false";
+      lhs_grad = MakeNode("batch_dot", lhs_gnode_name,
+                          {n->inputs[1], ograds[0]}, &lhs_attrs_dict, &n);
+      rhs_grad = MakeNode("batch_dot", rhs_gnode_name,
+                          {n->inputs[0], ograds[0]}, &rhs_attrs_dict, &n);
+    } else {
+      // Gradient of z = dot(x, y)
+      // dx = dot(dz, y.T)
+      // dy = dot(x.T, dz)
+      auto lhs_attrs_dict = n->attrs.dict;
+      auto rhs_attrs_dict = n->attrs.dict;
+      lhs_attrs_dict["transpose_a"] = "false";
+      lhs_attrs_dict["transpose_b"] = "true";
+      rhs_attrs_dict["transpose_a"] = "true";
+      rhs_attrs_dict["transpose_b"] = "false";
+      lhs_grad = MakeNode("batch_dot", lhs_gnode_name,
+                          {ograds[0], n->inputs[1]}, &lhs_attrs_dict, &n);
+      rhs_grad = MakeNode("batch_dot", rhs_gnode_name,
+                          {n->inputs[0], ograds[0]}, &rhs_attrs_dict, &n);
+    }
+    std::vector<nnvm::NodeEntry> ret;
+    ret.emplace_back(nnvm::NodeEntry{lhs_grad, 0, 0});
+    ret.emplace_back(nnvm::NodeEntry{rhs_grad, 0, 0});
+    return ret;
+  })
 .add_argument("lhs", "NDArray-or-Symbol", "The first input")
 .add_argument("rhs", "NDArray-or-Symbol", "The second input")
 .add_arguments(DotParam::__FIELDS__());
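
With the registration above, batch_dot accepts any number of leading batch dimensions. A usage sketch (assumes an MXNet build that includes this PR)::

    import mxnet as mx

    x = mx.nd.random.uniform(shape=(2, 3, 4, 5))
    y = mx.nd.random.uniform(shape=(2, 3, 5, 6))
    print(mx.nd.batch_dot(x, y).shape)                      # (2, 3, 4, 6)

    # transpose_b contracts over the last axis of both inputs instead:
    yt = mx.nd.random.uniform(shape=(2, 3, 6, 5))
    print(mx.nd.batch_dot(x, yt, transpose_b=True).shape)   # (2, 3, 4, 6)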

Member:
Is there any performance difference with the new backward_batch_dot implementation?

Contributor:
@eric-haibin-lin Here is the benchmark script for getting backward performance: https://gist.github.com/haojin2/c1a2bd1373530f4686bdefd2eafbee84
Results:
lhs: (32, 128, 768)  rhs: (32, 128, 768)  transpose_b: True  0.212037 ms -> 0.213933 ms
lhs: (32, 1, 768)    rhs: (32, 128, 768)  transpose_b: True  0.119977 ms -> 0.124208 ms
There's no obvious regression in performance.
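
The gist itself is not reproduced in this thread; a minimal sketch of such a timing harness (illustrative: it times forward plus backward together, and the warm-up/repeat counts are assumptions)::

    import time
    import mxnet as mx

    def bench_backward(lhs_shape, rhs_shape, transpose_b, repeat=100):
        x = mx.nd.random.uniform(shape=lhs_shape)
        y = mx.nd.random.uniform(shape=rhs_shape)
        x.attach_grad()
        y.attach_grad()
        for _ in range(10):                     # warm-up
            with mx.autograd.record():
                z = mx.nd.batch_dot(x, y, transpose_b=transpose_b)
            z.backward()
        mx.nd.waitall()
        start = time.time()
        for _ in range(repeat):
            with mx.autograd.record():
                z = mx.nd.batch_dot(x, y, transpose_b=transpose_b)
            z.backward()
        mx.nd.waitall()                         # flush the async engine before stopping the clock
        print(lhs_shape, rhs_shape, "%.6f ms" % ((time.time() - start) / repeat * 1e3))

    bench_backward((32, 128, 768), (32, 128, 768), transpose_b=True)
    bench_backward((32, 1, 768), (32, 128, 768), transpose_b=True)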

-NNVM_REGISTER_OP(_backward_batch_dot)
-.set_num_inputs(3)
-.set_num_outputs(2)
-.set_attr_parser(ParamParser<DotParam>)
-.set_attr<FResourceRequest>("FResourceRequest",
-  [](const NodeAttrs& attrs) {
-    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
-  })
-.set_attr<nnvm::TIsBackward>("TIsBackward", true)
-.set_attr<FCompute>("FCompute<cpu>", BatchDotBackward_<cpu>);

 }  // namespace op
 }  // namespace mxnet
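
Because the backward pass is now composed from batch_dot nodes, the gradient identities in the comments above can be checked numerically from Python. A quick sketch for the no-transpose case, where dx = dot(dz, y.T) and dy = dot(x.T, dz) (assumes an MXNet build with this PR)::

    import mxnet as mx

    x = mx.nd.random.uniform(shape=(2, 3, 4, 5))
    y = mx.nd.random.uniform(shape=(2, 3, 5, 6))
    x.attach_grad()
    y.attach_grad()
    with mx.autograd.record():
        z = mx.nd.batch_dot(x, y)
    z.backward()                 # implicit head gradient dz = ones(z.shape)

    dz = mx.nd.ones(z.shape)
    dx = mx.nd.batch_dot(dz, y, transpose_b=True)   # dot(dz, y.T)
    dy = mx.nd.batch_dot(x, dz, transpose_a=True)   # dot(x.T, dz)
    print(mx.nd.abs(x.grad - dx).max().asscalar())  # ~0
    print(mx.nd.abs(y.grad - dy).max().asscalar())  # ~0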
3 changes: 0 additions & 3 deletions src/operator/tensor/dot.cu
@@ -38,8 +38,5 @@ NNVM_REGISTER_OP(_backward_dot)
 NNVM_REGISTER_OP(batch_dot)
 .set_attr<FCompute>("FCompute<gpu>", BatchDotForward_<gpu>);

-NNVM_REGISTER_OP(_backward_batch_dot)
-.set_attr<FCompute>("FCompute<gpu>", BatchDotBackward_<gpu>);
-
 }  // namespace op
 }  // namespace mxnet
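
Worth noting: because the gradient is now composed from batch_dot nodes, which already have a GPU FCompute, there is no dedicated backward GPU kernel left to register here. A sketch of the same numerical check on GPU (assumes a CUDA build; falls back to CPU otherwise)::

    import mxnet as mx

    ctx = mx.gpu(0) if mx.context.num_gpus() > 0 else mx.cpu()
    x = mx.nd.random.uniform(shape=(2, 3, 4, 5), ctx=ctx)
    y = mx.nd.random.uniform(shape=(2, 3, 5, 6), ctx=ctx)
    x.attach_grad()
    with mx.autograd.record():
        z = mx.nd.batch_dot(x, y)
    z.backward()
    print(x.grad.context, x.grad.shape)   # same ctx, (2, 3, 4, 5)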