-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Add operator for dot(dns, csr) = csr #8938
Changes from 3 commits
cffb161
73a9f3d
362af57
a1b3db0
3fad020
2521627
4afbd3a
7077c53
0747ada
fb20f45
44789d2
1525405
d5d3216
d6f13a3
9a81b78
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -231,6 +231,12 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, | |
dispatched = storage_type_assign(&out_stype, kDefaultStorage, | ||
dispatch_mode, DispatchMode::kFComputeEx); | ||
} | ||
if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage && | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is the implementation only available on CPU? No fallback on GPU ctx? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have added a check for CPU. will fallback to default storage for gpu |
||
!param.transpose_a && !param.transpose_b) { | ||
// dns, csr -> csr | ||
dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode, | ||
DispatchMode::kFComputeEx); | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm. we should log storage fallback as long as dispatch mode is dispatch_fallback: Maybe I should move this logic to the common path instead of letting developers specify that in operators There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes moving the logic to common path would be nice. I see multiple places where we don't have this check. For example, https://github.com/apache/incubator-mxnet/blob/d2a856a3a2abb4e72edc301b8b821f0b75f30722/src/operator/tensor/elemwise_unary_op_basic.cc#L65 and https://github.com/apache/incubator-mxnet/blob/d2a856a3a2abb4e72edc301b8b821f0b75f30722/src/operator/tensor/elemwise_binary_scalar_op_basic.cc#L68 . These also need to be fixed right ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. We can fix that in a separate PR. |
||
if (!dispatched) { | ||
dispatch_fallback(out_attrs, dispatch_mode); | ||
LogStorageFallback(attrs, dev_mask, in_attrs, out_attrs); | ||
|
@@ -527,6 +533,78 @@ struct DotCsrTransRspRspByRowBlocks { | |
} | ||
}; | ||
|
||
/*! | ||
* \brief CPU Kernel of PopulateCsrForNNC | ||
* Parallelization by individual rows | ||
*/ | ||
struct PopulateCsrForNNC { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add brief description on what this kernel is for? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added. |
||
/*! | ||
* \brief | ||
* \param i the i-th thread | ||
* \param nnc_idx all non zero column indexes | ||
* \param indptr_out indptr array for output | ||
* \param col_idx_out column indices for output | ||
* \param nnc number of non zero columns in the output | ||
* \param num_rows_l number of rows in lhs | ||
*/ | ||
template <typename IType, typename CType> | ||
MSHADOW_CINLINE static void Map(int i, const CType* nnc_idx, | ||
IType* indptr_out, CType* col_idx_out, | ||
const nnvm::dim_t nnc, | ||
const nnvm::dim_t num_rows_l) { | ||
const CType start_idx = i * nnc; | ||
nnvm::dim_t cur = 0; | ||
indptr_out[i] = start_idx; | ||
if (i == static_cast<int>(num_rows_l - 1)) indptr_out[i + 1] = indptr_out[i] + nnc; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As we are adding large array support in the future, it's more appropriate to cast i up to |
||
for (IType idx = start_idx; idx < (start_idx + nnc); idx++) { | ||
col_idx_out[idx] = nnc_idx[cur++]; | ||
} | ||
} | ||
}; | ||
|
||
/*! | ||
* \brief CPU Impl of dot(dns, csr) = csr | ||
*/ | ||
struct DotDnsCsrCsrByRowBlocks { | ||
/*! | ||
* \brief | ||
* \param i the i-th thread | ||
* \param num_rows_r number of rows in rhs | ||
* \param num_rows_l number of rows in lhs | ||
* \param num_cols number of columns in output | ||
* \param nnc number of non zero columns | ||
*/ | ||
|
||
template <typename DType, typename IType, typename CType> | ||
MSHADOW_CINLINE static void Map( | ||
int i, DType* out, const DType* data_l, const IType* indptr_r, | ||
const CType* col_idx_r, const DType* data_r, const nnvm::dim_t seg_len, | ||
const IType num_rows_r, const IType num_rows_l, | ||
const nnvm::dim_t num_cols, const nnvm::dim_t nnc, | ||
const CType* prefix_sum) { | ||
using nnvm::dim_t; | ||
const dim_t seg_start = i * seg_len; | ||
if (seg_start >= num_rows_l) return; | ||
const dim_t seg_end = std::min(seg_start + seg_len, num_rows_l); | ||
|
||
for (dim_t j = seg_start; j < seg_end; j++) { | ||
for (dim_t k = 0; k < num_rows_r; k++) { | ||
const dim_t working_idx = j * num_rows_r + k; | ||
const DType val = data_l[working_idx]; | ||
if (indptr_r[k] == indptr_r[k + 1]) continue; | ||
const dim_t row_start = j * nnc; | ||
for (dim_t cur = indptr_r[k]; cur < indptr_r[k + 1]; cur++) { | ||
dim_t cur_col_idx_r = col_idx_r[cur]; | ||
const dim_t out_idx = row_start + prefix_sum[cur_col_idx_r] - 1; | ||
out[out_idx] += val * data_r[cur]; | ||
} | ||
} | ||
} | ||
} | ||
}; | ||
|
||
|
||
|
||
/*! | ||
* \brief CPU Impl of dot(csr, dns1) = dns2 and dot(csr.T, dns1) = dns2 | ||
*/ | ||
|
@@ -811,6 +889,94 @@ inline void DotCsrRspRspImpl(const OpContext& ctx, | |
}); | ||
} | ||
|
||
/* | ||
* \brief CPU Impl of dot(dns, csr) = csr | ||
*/ | ||
inline void DotDnsCsrCsrImpl(const OpContext& ctx, const cpu& cpu_dev, | ||
const TBlob& lhs, const NDArray& rhs, | ||
const OpReqType req, NDArray* ret) { | ||
if (kNullOp == req) return; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for pointing this. Fixed. |
||
CHECK_EQ(rhs.storage_type(), kCSRStorage); | ||
if (!rhs.storage_initialized()) return; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shall we set the result to be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed! |
||
|
||
using namespace mshadow; | ||
using namespace mshadow::expr; | ||
using nnvm::dim_t; | ||
|
||
/*Initialize data structures*/ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: space after /* |
||
mshadow::Stream<cpu>* s = ctx.get_stream<cpu>(); | ||
const NDArray& out = *ret; | ||
const TBlob data_l = lhs; | ||
const TBlob data_r = rhs.data(); | ||
const TBlob indptr_r = rhs.aux_data(csr::kIndPtr); | ||
const TBlob col_idx_r = rhs.aux_data(csr::kIdx); | ||
|
||
MSHADOW_SGL_DBL_TYPE_SWITCH(data_r.type_flag_, DType, { // data type | ||
MSHADOW_IDX_TYPE_SWITCH(indptr_r.type_flag_, IType, { // indptr type | ||
MSHADOW_IDX_TYPE_SWITCH(col_idx_r.type_flag_, CType, { // colidx type | ||
/* Allocate workspace */ | ||
CType num_cols_out = out.shape()[1]; | ||
CType rhs_data_size = static_cast<CType>(col_idx_r.shape_.Size()); | ||
size_t workspace_size = 2 * num_cols_out * sizeof(CType); | ||
Tensor<cpu, 1, char> workspace = | ||
ctx.requested[0].get_space_typed<cpu, 1, char>( | ||
Shape1(workspace_size), s); | ||
CType* col_flg = reinterpret_cast<dim_t*>(workspace.dptr_); | ||
|
||
CType* prefix_sum = col_flg; | ||
CType* nnc_idx = prefix_sum + num_cols_out; | ||
|
||
/* Set the column flags for nnz columns */ | ||
mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(s, num_cols_out, | ||
col_flg); | ||
mxnet_op::Kernel<MarkRowFlgKernel, cpu>::Launch( | ||
s, rhs_data_size, col_flg, col_idx_r.dptr<CType>()); | ||
|
||
/* 1. Calculate prefix sum from col flgs | ||
* 2. Storage all non zero column indexes in nnc_idx | ||
*/ | ||
CType cur = 0; | ||
prefix_sum[0] = col_flg[0]; | ||
if (prefix_sum[0]) nnc_idx[cur++] = 0; | ||
for (CType i = 1; i < num_cols_out; i++) { | ||
prefix_sum[i] = prefix_sum[i - 1] + col_flg[i]; | ||
if (prefix_sum[i] > prefix_sum[i - 1]) nnc_idx[cur++] = i; | ||
} | ||
|
||
/* Allocate aux data for out */ | ||
IType num_rows_l = lhs.shape_[0]; | ||
dim_t nnc = prefix_sum[num_cols_out - 1]; | ||
dim_t nnz = nnc * num_rows_l; | ||
out.CheckAndAllocAuxData(csr::kIndPtr, Shape1(num_rows_l + 1)); | ||
out.CheckAndAllocAuxData(csr::kIdx, Shape1(nnz)); | ||
out.CheckAndAllocData(Shape1(nnz)); | ||
|
||
/* Set csr indptr and index according to nnc_idx*/ | ||
IType* indptr_out = out.aux_data(csr::kIndPtr).dptr<IType>(); | ||
CType* col_idx_out = out.aux_data(csr::kIdx).dptr<CType>(); | ||
DType* data_out = out.data().dptr<DType>(); | ||
mxnet_op::Kernel<PopulateCsrForNNC, cpu>::Launch( | ||
s, num_rows_l, nnc_idx, indptr_out, col_idx_out, nnc, num_rows_l); | ||
mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(s, nnz, data_out); | ||
|
||
if (nnc == 0) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why should nnc never be 0 ? This is possible when number of non zero columns are zero(matrix with all zeros) in the rhs. In this case we return the output correctly. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Because you already checked There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have removed the if and also added some documentation for storage_initialized |
||
return; | ||
} | ||
|
||
dim_t num_threads = mxnet_op::get_num_threads<cpu>(num_rows_l); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: const for both |
||
dim_t seg_len = (num_rows_l + num_threads - 1) / num_threads; | ||
|
||
IType num_rows_r = rhs.shape()[0]; | ||
mxnet_op::Kernel<DotDnsCsrCsrByRowBlocks, cpu>::Launch( | ||
s, num_threads, data_out, data_l.dptr<DType>(), | ||
indptr_r.dptr<IType>(), col_idx_r.dptr<CType>(), | ||
data_r.dptr<DType>(), seg_len, num_rows_r, num_rows_l, num_cols_out, | ||
nnc, prefix_sum); | ||
}); | ||
}); | ||
}); | ||
} | ||
|
||
inline bool DotShape(const nnvm::NodeAttrs& attrs, | ||
std::vector<TShape> *in_attrs, | ||
std::vector<TShape> *out_attrs) { | ||
|
@@ -886,6 +1052,11 @@ void DotForwardEx(const nnvm::NodeAttrs& attrs, | |
&& out_stype == kRowSparseStorage && !param.transpose_b) { | ||
NDArray ret = outputs[0]; | ||
DotCsrRspRspImpl(ctx, xpu(), inputs[0], inputs[1], req[0], param.transpose_a, &ret); | ||
} else if (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage && | ||
out_stype == kCSRStorage && | ||
!(param.transpose_a || param.transpose_b)) { | ||
NDArray ret = outputs[0]; | ||
DotDnsCsrCsrImpl(ctx, xpu(), inputs[0].data(), inputs[1], req[0], &ret); | ||
} else { | ||
LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs); | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1223,6 +1223,31 @@ def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs, lhs_density, rhs_de | |
grad_req={'lhs': 'null', 'rhs': 'write'}, | ||
rtol=1e-3, atol=1e-4) | ||
|
||
def test_dot_dns_csr(lhs_shape, rhs_shape, lhs_density, rhs_density, trans_lhs=False, trans_rhs=False): | ||
lhs_nd = rand_ndarray(lhs_shape, stype='default', density=lhs_density) | ||
rhs_nd = rand_ndarray(rhs_shape, stype='csr', density=rhs_density) | ||
rhs_dns = rhs_nd.tostype('default') | ||
|
||
out = mx.nd.sparse.dot(lhs_nd, rhs_nd, transpose_a=trans_lhs, transpose_b=trans_rhs) | ||
out_dns = mx.nd.dot(lhs_nd, rhs_dns, transpose_a=trans_lhs, transpose_b=trans_rhs) | ||
out_np = out_dns.asnumpy() | ||
assert_almost_equal(out.asnumpy(), out_np, rtol=1e-4, atol=1e-5) | ||
|
||
# test symbolic forward | ||
lhs = mx.symbol.Variable('lhs', stype='default') | ||
rhs = mx.symbol.Variable('rhs', stype='csr') | ||
out = mx.symbol.sparse.dot(lhs, rhs, transpose_a=trans_lhs, transpose_b=trans_rhs) | ||
location = {'lhs': lhs_nd, 'rhs': rhs_nd} | ||
check_symbolic_forward(out, location, [out_np], rtol=1e-3, atol=1e-4) | ||
|
||
# test symbolic backward | ||
backward_trans = not trans_lhs | ||
rhs_backward_grad = mx.nd.dot(lhs_nd, out_dns, transpose_a=backward_trans).asnumpy() | ||
expected = {'rhs': rhs_backward_grad} | ||
check_symbolic_backward(out, location, [out_np], expected, | ||
grad_req={'lhs': 'null', 'rhs': 'write'}, | ||
rtol=1e-3, atol=1e-4) | ||
|
||
def test_sparse_dot_zero_output(lhs_shape, trans_lhs, rhs_num_cols): | ||
"""Test for nnr_out = 0. Before the fix, the test would fail.""" | ||
lhs = mx.nd.zeros(lhs_shape) | ||
|
@@ -1248,10 +1273,12 @@ def test_sparse_dot_zero_output(lhs_shape, trans_lhs, rhs_num_cols): | |
test_dot_csr(lhs_shape, (lhs_shape[0], 1), 'default', True, lhs_d, rhs_d) # (vector kernel) | ||
test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(5, 10)), 'default', False, lhs_d, rhs_d) # test gpu SpMM | ||
test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(5, 10)), 'default', True, lhs_d, rhs_d) # (scalar kernel) | ||
test_dot_dns_csr(lhs_shape, (lhs_shape[1], rnd.randint(500, 1000)), lhs_d, lhs_d) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. randint(50,200) is large (and slow) enough for testing. No need to increase the dim to 1000. |
||
for rhs_d in density: | ||
test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False, lhs_d, rhs_d) | ||
test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'row_sparse', True, lhs_d, rhs_d) | ||
|
||
|
||
test_sparse_dot_zero_output(rand_shape_2d(50, 200), False, 40) | ||
test_sparse_dot_zero_output(rand_shape_2d(50, 200), True, 40) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The check and the error statement don't seem to match?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch! Fixed.