Skip to content

Commit

Permalink
Add operator for dot(dns, csr) = csr (apache#8938)
Browse files Browse the repository at this point in the history
* Add operator for dot(dns, csr) = csr

* Fix whitespace

* Add comments

* Add comments and fix error message

* Fixes for dot dns csr

* Fixes

* Remove non required statements

* Add fallback for GPU

* Remove unused if

* Fix comments and casting

* Add operator to the documentation
  • Loading branch information
anirudh2290 authored and zheng-da committed Jun 28, 2018
1 parent 1808128 commit db40eb1
Show file tree
Hide file tree
Showing 5 changed files with 260 additions and 25 deletions.
51 changes: 36 additions & 15 deletions benchmark/python/sparse/dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,10 @@ def bench_dot(lhs_shape, rhs_shape, lhs_stype, rhs_stype,
# Create matrix instances
lhs_nd = rand_ndarray(lhs_shape, lhs_stype, density=lhs_den, distribution=distribution)
# only uniform distribution supported for rhs
rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=rhs_den, distribution="uniform")
if rhs_stype == 'csr':
rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=rhs_den, distribution=distribution)
else:
rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=rhs_den, distribution="uniform")
lhs_dns = None
rhs_dns = None
dense_cost = None
Expand Down Expand Up @@ -337,27 +340,41 @@ def print_benchmark_info(lhs, rhs, lhs_trans, fw):

def run_benchmark(ctx=None, lhs="csr", lhs_trans=False, rhs="dns", fw="mxnet", rhs_density=1,
distribution="uniform"):
if lhs != "csr":
raise ValueError("Value other than csr for lhs not supported")

if rhs_density > 1 or rhs_density < 0:
raise ValueError("rhs_density has to be between 0 and 1")

print_benchmark_info(lhs, rhs, lhs_trans, fw)

if rhs == "csr":
lhs_stype = "default"
rhs_stype = "csr"
assert (lhs_stype == 'default'), "Only dot(default, csr) supported"
# Arrange dimensions according to use case. For below csr will have num_rows << num_cols
feature_dim_list = data_dict['batch_size']
batch_size_list = data_dict['m']
output_dim_list = data_dict['feature_dim']
density_list = data_dict['density']
default_output_index = data_dict['default_index']['feature_dim']
default_density_index = data_dict['default_index']['density']
default_feature_index = data_dict['default_index']['batch_size']
default_batch_size_index = data_dict['default_index']['output_dim']
num_repeat = data_dict['num_repeat']

lhs_stype = "csr"
rhs_stype = "row_sparse" if rhs == "rsp" else "default"
else:
lhs_stype = "csr"
rhs_stype = "row_sparse" if rhs == "rsp" else "default"

feature_dim_list = data_dict['feature_dim']
output_dim_list = data_dict['m']
batch_size_list = data_dict['batch_size']
density_list = data_dict['density']
feature_dim_list = data_dict['feature_dim']
output_dim_list = data_dict['m']
batch_size_list = data_dict['batch_size']
density_list = data_dict['density']

default_output_index = data_dict['default_index']['output_dim']
default_batch_size_index = data_dict['default_index']['batch_size']
default_feature_index = data_dict['default_index']['feature_dim']
default_density_index = data_dict['default_index']['density']
num_repeat = data_dict['num_repeat']
default_output_index = data_dict['default_index']['output_dim']
default_batch_size_index = data_dict['default_index']['batch_size']
default_feature_index = data_dict['default_index']['feature_dim']
default_density_index = data_dict['default_index']['density']
num_repeat = data_dict['num_repeat']

for output_dim in output_dim_list:
if lhs_trans:
Expand Down Expand Up @@ -403,7 +420,7 @@ def run_benchmark(ctx=None, lhs="csr", lhs_trans=False, rhs="dns", fw="mxnet", r
feature_dim_list[default_feature_index]),
(output_row_dim,
output_dim_list[default_output_index]),
lhs_stype, rhs_stype, density, rhs_density, lhs_trans, ctx,
lhs_stype, rhs_stype, density, density, lhs_trans, ctx,
num_repeat=num_repeat, fw=fw, distribution=distribution)

check_call(_LIB.MXSetNumOMPThreads(ctypes.c_int(ARGS.num_omp_threads)))
Expand All @@ -423,6 +440,10 @@ def run_benchmark(ctx=None, lhs="csr", lhs_trans=False, rhs="dns", fw="mxnet", r
rhs="rsp", lhs_trans=False,
fw="mxnet", rhs_density=0.05,
distribution=distribution)
run_benchmark(context, lhs="default",
rhs="csr", lhs_trans=False,
fw="mxnet", rhs_density=0.001,
distribution=distribution)
if not ARGS.gpu:
run_benchmark(context, lhs="csr",
rhs="default", lhs_trans=False,
Expand Down
5 changes: 4 additions & 1 deletion include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,10 @@ class NDArray {
bool fresh_out_grad() const;
/*! \return updated grad state in entry_ */
void set_fresh_out_grad(bool state) const;
// returns true if a sparse ndarray's aux_data and storage are initialized
/*! \brief Returns true if a sparse ndarray's aux_data and storage are initialized
* Throws an exception if the indices array shape is inconsistent
* Returns false if the indices array is empty(nnz = 0) for csr/row_sparse
*/
inline bool storage_initialized() const {
if (is_none()) return false;
auto stype = storage_type();
Expand Down
201 changes: 192 additions & 9 deletions src/operator/tensor/dot-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,22 +202,25 @@ void DotBackward_(const nnvm::NodeAttrs& attrs,
inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
// csr has many zero columns, so the result of dot(csr.T, matrix) should be rsp
// csr has many zero columns, so the result of dot(csr.T, matrix) should be
// rsp
const auto& lhs_stype = in_attrs->at(0);
const auto& rhs_stype = in_attrs->at(1);
auto& out_stype = out_attrs->at(0);
bool dispatched = false;
bool only_lhs_transpose = param.transpose_a && !param.transpose_b;
bool rhs_rsp_or_dns = rhs_stype == kRowSparseStorage || rhs_stype == kDefaultStorage;
if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kDefaultStorage) {
bool rhs_rsp_or_dns =
rhs_stype == kRowSparseStorage || rhs_stype == kDefaultStorage;
if (!dispatched && lhs_stype == kDefaultStorage &&
rhs_stype == kDefaultStorage) {
// dns, dns -> dns
dispatched = storage_type_assign(&out_stype, kDefaultStorage,
dispatch_mode, DispatchMode::kFCompute);
dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
DispatchMode::kFCompute);
}
if (!dispatched && lhs_stype == kCSRStorage && only_lhs_transpose &&
(rhs_stype == kRowSparseStorage || rhs_stype == kDefaultStorage)) {
Expand All @@ -228,11 +231,22 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs,
if (!dispatched && lhs_stype == kCSRStorage && rhs_rsp_or_dns &&
!param.transpose_a && !param.transpose_b) {
// csr, rsp/dns -> dns
dispatched = storage_type_assign(&out_stype, kDefaultStorage,
dispatch_mode, DispatchMode::kFComputeEx);
dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
DispatchMode::kFComputeEx);
}
if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage &&
!param.transpose_a && !param.transpose_b) {
// dns, csr -> csr
const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask;
const auto dispatch_ex = invalid_ctx ? DispatchMode::kFComputeFallback
: DispatchMode::kFComputeEx;
dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode,
dispatch_ex);
}
if (!dispatched) {
dispatch_fallback(out_attrs, dispatch_mode);
}
if (static_cast<DispatchMode>(*dispatch_mode) == DispatchMode::kFComputeFallback) {
LogStorageFallback(attrs, dev_mask, in_attrs, out_attrs);
}
return true;
Expand Down Expand Up @@ -527,6 +541,80 @@ struct DotCsrTransRspRspByRowBlocks {
}
};

/*!
* \brief CPU Kernel of PopulateCsrForNNC
* Parallelization by individual rows
* Populates the indptr and indices array
* based on number of non zero columns
*/
struct PopulateCsrForNNC {
/*!
* \brief
* \param i the i-th thread
* \param nnc_idx all non zero column indexes
* \param indptr_out indptr array for output
* \param col_idx_out column indices for output
* \param nnc number of non zero columns in the output
* \param num_rows_l number of rows in lhs
*/
template <typename IType, typename CType>
MSHADOW_CINLINE static void Map(int i, const CType* nnc_idx,
IType* indptr_out, CType* col_idx_out,
const nnvm::dim_t nnc,
const nnvm::dim_t num_rows_l) {
const CType start_idx = i * nnc;
nnvm::dim_t cur = 0;
indptr_out[i] = start_idx;
if (static_cast<nnvm::dim_t>(i) == (num_rows_l - 1)) indptr_out[i + 1] = indptr_out[i] + nnc;
for (IType idx = start_idx; idx < (start_idx + nnc); idx++) {
col_idx_out[idx] = nnc_idx[cur++];
}
}
};

/*!
* \brief CPU Impl of dot(dns, csr) = csr
*/
struct DotDnsCsrCsrByRowBlocks {
/*!
* \brief
* \param i the i-th thread
* \param num_rows_r number of rows in rhs
* \param num_rows_l number of rows in lhs
* \param num_cols number of columns in output
* \param nnc number of non zero columns
*/

template <typename DType, typename IType, typename CType>
MSHADOW_CINLINE static void Map(
int i, DType* out, const DType* data_l, const IType* indptr_r,
const CType* col_idx_r, const DType* data_r, const nnvm::dim_t seg_len,
const IType num_rows_r, const IType num_rows_l,
const nnvm::dim_t num_cols, const nnvm::dim_t nnc,
const CType* prefix_sum) {
using nnvm::dim_t;
const dim_t seg_start = i * seg_len;
if (seg_start >= num_rows_l) return;
const dim_t seg_end = std::min(seg_start + seg_len, num_rows_l);

for (dim_t j = seg_start; j < seg_end; j++) {
for (dim_t k = 0; k < num_rows_r; k++) {
const dim_t working_idx = j * num_rows_r + k;
const DType val = data_l[working_idx];
if (indptr_r[k] == indptr_r[k + 1]) continue;
const dim_t row_start = j * nnc;
for (dim_t cur = indptr_r[k]; cur < indptr_r[k + 1]; cur++) {
dim_t cur_col_idx_r = col_idx_r[cur];
const dim_t out_idx = row_start + prefix_sum[cur_col_idx_r] - 1;
out[out_idx] += val * data_r[cur];
}
}
}
}
};



/*!
* \brief CPU Impl of dot(csr, dns1) = dns2 and dot(csr.T, dns1) = dns2
*/
Expand Down Expand Up @@ -811,6 +899,96 @@ inline void DotCsrRspRspImpl(const OpContext& ctx,
});
}

/*
* \brief CPU Impl of dot(dns, csr) = csr
*/
template<typename xpu>
inline void DotDnsCsrCsrImpl(const OpContext& ctx,
const TBlob& lhs, const NDArray& rhs,
const OpReqType req, NDArray* ret) {
if (kNullOp == req) return;

CHECK_EQ(req, kWriteTo);
CHECK_EQ(rhs.storage_type(), kCSRStorage);

using namespace mshadow;
using namespace mshadow::expr;
using nnvm::dim_t;

/* Initialize data structures */
mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
const NDArray& out = *ret;
const TBlob data_l = lhs;
const TBlob data_r = rhs.data();
const TBlob indptr_r = rhs.aux_data(csr::kIndPtr);
const TBlob col_idx_r = rhs.aux_data(csr::kIdx);
if (!rhs.storage_initialized()) {
FillZerosCsrImpl(s, *ret);
return;
}

MSHADOW_SGL_DBL_TYPE_SWITCH(data_r.type_flag_, DType, { // data type
MSHADOW_IDX_TYPE_SWITCH(indptr_r.type_flag_, IType, { // indptr type
MSHADOW_IDX_TYPE_SWITCH(col_idx_r.type_flag_, CType, { // colidx type
/* Allocate workspace */
CType num_cols_out = out.shape()[1];
CType rhs_data_size = static_cast<CType>(col_idx_r.shape_.Size());
size_t workspace_size = 2 * num_cols_out * sizeof(CType);
Tensor<cpu, 1, char> workspace =
ctx.requested[0].get_space_typed<cpu, 1, char>(
Shape1(workspace_size), s);
CType* col_flg = reinterpret_cast<dim_t*>(workspace.dptr_);

CType* prefix_sum = col_flg;
CType* nnc_idx = prefix_sum + num_cols_out;

/* Set the column flags for nnz columns */
mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(s, num_cols_out,
col_flg);
mxnet_op::Kernel<MarkRowFlgKernel, cpu>::Launch(
s, rhs_data_size, col_flg, col_idx_r.dptr<CType>());

/* 1. Calculate prefix sum from col flgs
* 2. Storage all non zero column indexes in nnc_idx
*/
CType cur = 0;
prefix_sum[0] = col_flg[0];
if (prefix_sum[0]) nnc_idx[cur++] = 0;
for (CType i = 1; i < num_cols_out; i++) {
prefix_sum[i] = prefix_sum[i - 1] + col_flg[i];
if (prefix_sum[i] > prefix_sum[i - 1]) nnc_idx[cur++] = i;
}

/* Allocate aux data for out */
IType num_rows_l = lhs.shape_[0];
dim_t nnc = prefix_sum[num_cols_out - 1];
dim_t nnz = nnc * num_rows_l;
out.CheckAndAllocAuxData(csr::kIndPtr, Shape1(num_rows_l + 1));
out.CheckAndAllocAuxData(csr::kIdx, Shape1(nnz));
out.CheckAndAllocData(Shape1(nnz));

/* Set csr indptr and index according to nnc_idx*/
IType* indptr_out = out.aux_data(csr::kIndPtr).dptr<IType>();
CType* col_idx_out = out.aux_data(csr::kIdx).dptr<CType>();
DType* data_out = out.data().dptr<DType>();
mxnet_op::Kernel<PopulateCsrForNNC, cpu>::Launch(
s, num_rows_l, nnc_idx, indptr_out, col_idx_out, nnc, num_rows_l);
mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(s, nnz, data_out);

const dim_t num_threads = mxnet_op::get_num_threads<cpu>(num_rows_l);
const dim_t seg_len = (num_rows_l + num_threads - 1) / num_threads;

IType num_rows_r = rhs.shape()[0];
mxnet_op::Kernel<DotDnsCsrCsrByRowBlocks, cpu>::Launch(
s, num_threads, data_out, data_l.dptr<DType>(),
indptr_r.dptr<IType>(), col_idx_r.dptr<CType>(),
data_r.dptr<DType>(), seg_len, num_rows_r, num_rows_l, num_cols_out,
nnc, prefix_sum);
});
});
});
}

inline bool DotShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
Expand Down Expand Up @@ -886,6 +1064,11 @@ void DotForwardEx(const nnvm::NodeAttrs& attrs,
&& out_stype == kRowSparseStorage && !param.transpose_b) {
NDArray ret = outputs[0];
DotCsrRspRspImpl(ctx, xpu(), inputs[0], inputs[1], req[0], param.transpose_a, &ret);
} else if (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage &&
out_stype == kCSRStorage &&
!(param.transpose_a || param.transpose_b)) {
NDArray ret = outputs[0];
DotDnsCsrCsrImpl<xpu>(ctx, inputs[0].data(), inputs[1], req[0], &ret);
} else {
LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs);
}
Expand Down
1 change: 1 addition & 0 deletions src/operator/tensor/dot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ The storage type of ``dot`` output depends on storage types of inputs and transp
- dot(csr, default) = default
- dot(csr.T, default) = row_sparse
- dot(csr, row_sparse) = default
- dot(default, csr) = csr
- otherwise, ``dot`` generates output with default storage
)doc" ADD_FILELINE)
Expand Down
27 changes: 27 additions & 0 deletions tests/python/unittest/test_sparse_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1223,6 +1223,31 @@ def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs, lhs_density, rhs_de
grad_req={'lhs': 'null', 'rhs': 'write'},
rtol=1e-3, atol=1e-4)

def test_dot_dns_csr(lhs_shape, rhs_shape, lhs_density, rhs_density, trans_lhs=False, trans_rhs=False):
lhs_nd = rand_ndarray(lhs_shape, stype='default', density=lhs_density)
rhs_nd = rand_ndarray(rhs_shape, stype='csr', density=rhs_density)
rhs_dns = rhs_nd.tostype('default')

out = mx.nd.sparse.dot(lhs_nd, rhs_nd, transpose_a=trans_lhs, transpose_b=trans_rhs)
out_dns = mx.nd.dot(lhs_nd, rhs_dns, transpose_a=trans_lhs, transpose_b=trans_rhs)
out_np = out_dns.asnumpy()
assert_almost_equal(out.asnumpy(), out_np, rtol=1e-4, atol=1e-5)

# test symbolic forward
lhs = mx.symbol.Variable('lhs', stype='default')
rhs = mx.symbol.Variable('rhs', stype='csr')
out = mx.symbol.sparse.dot(lhs, rhs, transpose_a=trans_lhs, transpose_b=trans_rhs)
location = {'lhs': lhs_nd, 'rhs': rhs_nd}
check_symbolic_forward(out, location, [out_np], rtol=1e-3, atol=1e-4)

# test symbolic backward
backward_trans = not trans_lhs
rhs_backward_grad = mx.nd.dot(lhs_nd, out_dns, transpose_a=backward_trans).asnumpy()
expected = {'rhs': rhs_backward_grad}
check_symbolic_backward(out, location, [out_np], expected,
grad_req={'lhs': 'null', 'rhs': 'write'},
rtol=1e-3, atol=1e-4)

def test_sparse_dot_zero_output(lhs_shape, trans_lhs, rhs_num_cols):
"""Test for nnr_out = 0. Before the fix, the test would fail."""
lhs = mx.nd.zeros(lhs_shape)
Expand All @@ -1248,10 +1273,12 @@ def test_sparse_dot_zero_output(lhs_shape, trans_lhs, rhs_num_cols):
test_dot_csr(lhs_shape, (lhs_shape[0], 1), 'default', True, lhs_d, rhs_d) # (vector kernel)
test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(5, 10)), 'default', False, lhs_d, rhs_d) # test gpu SpMM
test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(5, 10)), 'default', True, lhs_d, rhs_d) # (scalar kernel)
test_dot_dns_csr(lhs_shape, (lhs_shape[1], rnd.randint(50, 200)), lhs_d, lhs_d)
for rhs_d in density:
test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False, lhs_d, rhs_d)
test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'row_sparse', True, lhs_d, rhs_d)


test_sparse_dot_zero_output(rand_shape_2d(50, 200), False, 40)
test_sparse_dot_zero_output(rand_shape_2d(50, 200), True, 40)

Expand Down

0 comments on commit db40eb1

Please sign in to comment.