This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

rsp push and rsp pull for comm device, used in kvstore('device') #8732

Merged
merged 35 commits into from
Jan 15, 2018
Changes from 27 commits
35 commits
fb0077e
comm device for rsp push and pull
ZiyueHuang Nov 20, 2017
5288bc6
update
ZiyueHuang Nov 20, 2017
865a117
resolve conflict
ZiyueHuang Nov 26, 2017
296f122
update test
ZiyueHuang Nov 27, 2017
8b8c14f
optimization for same row_ids
ZiyueHuang Nov 27, 2017
4fb29ae
resolve conflict
ZiyueHuang Nov 28, 2017
c37ee41
add stream->wait
ZiyueHuang Nov 28, 2017
96c7a2f
remove using space
ZiyueHuang Nov 28, 2017
0990c69
fix race of rsc and extend ElementwiseSum to rsp cases
ZiyueHuang Nov 29, 2017
32f25c8
add log fatal in ElementwiseSum
ZiyueHuang Nov 29, 2017
079981f
resolve
ZiyueHuang Dec 1, 2017
0e4a1c6
direct copy rows if full rsp and put all outputs on ctx of src
ZiyueHuang Dec 2, 2017
6a18d83
Merge remote-tracking branch 'upstream/master' into comm_device
ZiyueHuang Dec 2, 2017
31bfad8
trigger
ZiyueHuang Dec 2, 2017
910d4fa
fix
ZiyueHuang Dec 2, 2017
9dca449
simplify copy
ZiyueHuang Dec 2, 2017
20b28eb
move check same rowids to utils and add test for same rowids case
ZiyueHuang Dec 3, 2017
4084fc2
Merge remote-tracking branch 'upstream/master' into comm_device
Dec 8, 2017
75c0656
Merge remote-tracking branch 'upstream/master' into comm_device
Dec 9, 2017
690ec92
remove direct copy row by row
ZiyueHuang Dec 10, 2017
0c833ed
Merge remote-tracking branch 'upstream/master' into comm_device
ZiyueHuang Dec 10, 2017
1723594
fix checkSameRowid
ZiyueHuang Dec 12, 2017
b0d53ad
Merge remote-tracking branch 'upstream/master' into comm_device
ZiyueHuang Dec 12, 2017
9e06a08
gpu unique impl draft
ZiyueHuang Dec 17, 2017
d84bf47
unique
ZiyueHuang Dec 17, 2017
5f55545
Merge remote-tracking branch 'upstream/master' into comm_device
ZiyueHuang Dec 17, 2017
f16faa1
update
ZiyueHuang Dec 17, 2017
72e752d
Merge remote-tracking branch 'upstream/master' into comm_device
ZiyueHuang Dec 20, 2017
5695b52
fix windows build
ZiyueHuang Dec 20, 2017
66ae47d
trigger windows build
ZiyueHuang Dec 20, 2017
8179fab
Merge remote-tracking branch 'upstream/master' into comm_device
ZiyueHuang Dec 22, 2017
134b98f
support single rowid with multiple vals
ZiyueHuang Dec 22, 2017
c96b158
address comments
ZiyueHuang Jan 1, 2018
1b09d09
check same row_ids and copy in frontend
ZiyueHuang Jan 11, 2018
e0a68c4
revise names and disable test for local kvstore
ZiyueHuang Jan 12, 2018
10 changes: 10 additions & 0 deletions src/common/utils.cc
@@ -24,6 +24,7 @@

#include "./utils.h"
#include "../operator/tensor/cast_storage-inl.h"
#include "../operator/tensor/sparse_retain-inl.h"

namespace mxnet {
namespace common {
@@ -34,6 +35,15 @@ void CheckFormatWrapper<cpu>(const RunContext &rctx, const NDArray &input,
CheckFormatImpl<cpu>(rctx, input, err_cpu, full_check);
}

template<>
void SparseRetainOpForwardRspWrapper<cpu>(mshadow::Stream<cpu> *s,
const NDArray& input_nd,
const TBlob& idx_data,
const OpReqType req,
NDArray* output_nd) {
mxnet::op::SparseRetainOpForwardRspImpl<cpu>(s, input_nd, idx_data, req, output_nd);
}

template<>
void CastStorageDispatch<cpu>(const OpContext& ctx,
const NDArray& input,
10 changes: 10 additions & 0 deletions src/common/utils.cu
@@ -24,6 +24,7 @@

#include "./utils.h"
#include "../operator/tensor/cast_storage-inl.h"
#include "../operator/tensor/sparse_retain-inl.h"

namespace mxnet {
namespace common {
@@ -34,6 +35,15 @@ void CheckFormatWrapper<gpu>(const RunContext &rctx, const NDArray &input,
CheckFormatImpl<gpu>(rctx, input, err_cpu, full_check);
}

template<>
void SparseRetainOpForwardRspWrapper<gpu>(mshadow::Stream<gpu> *s,
const NDArray& input_nd,
const TBlob& idx_data,
const OpReqType req,
NDArray* output_nd) {
mxnet::op::SparseRetainOpForwardRspImpl<gpu>(s, input_nd, idx_data, req, output_nd);
}

template<>
void CastStorageDispatch<gpu>(const OpContext& ctx,
const NDArray& input,
11 changes: 11 additions & 0 deletions src/common/utils.h
@@ -214,7 +214,18 @@ void CheckFormatImpl(const RunContext &rctx, const NDArray &input,
}
}

/*! \brief Pick rows specified by user input index array from a row sparse ndarray
* and save them in the output sparse ndarray.
*/
template<typename xpu>
Member: Let's make sure functions in .h are documented. Should add some description for CastStorageDispatch too...

void SparseRetainOpForwardRspWrapper(mshadow::Stream<xpu> *s,
const NDArray& input_nd,
const TBlob& idx_data,
const OpReqType req,
NDArray* output_nd);

/* \brief Casts tensor storage type to the new type.
*/
template<typename xpu>
void CastStorageDispatch(const OpContext& ctx, const NDArray& input, const NDArray& output);

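Following the reviewer's note above about documenting header functions, a fuller doxygen-style comment for CastStorageDispatch might look like the sketch below; the wording and parameter descriptions are assumptions, not text from the PR.

/*!
 * \brief Cast the storage of `input` to the storage type of `output` and write
 *        the converted data into `output`. (Sketch of a possible doc comment.)
 * \param ctx    operator context carrying the stream to run on
 * \param input  source NDArray
 * \param output destination NDArray whose storage type determines the cast
 */
template<typename xpu>
void CastStorageDispatch(const OpContext& ctx, const NDArray& input, const NDArray& output);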
135 changes: 104 additions & 31 deletions src/kvstore/comm.h
@@ -34,6 +34,7 @@
#include "gradient_compression.h"
#include "../ndarray/ndarray_function.h"
#include "../operator/tensor/sparse_retain-inl.h"
#include "./utils.h"
namespace mxnet {
namespace kvstore {
/**
@@ -176,17 +177,17 @@ class CommCPU : public Comm {
reduce[i] = buf.copy_buf[i];
const_vars[i] = reduce[i].var();
}
auto result = buf.merged;
NDArray result = buf.merged;
Resource rsc = ResourceManager::Get()->Request(result.ctx(),
ResourceRequest(ResourceRequest::kTempSpace));
Engine::Get()->PushAsync(
[reduce, result, this](RunContext rctx, Engine::CallbackOnComplete on_complete) {
[reduce, result, rsc, this](RunContext rctx, Engine::CallbackOnComplete on_complete) {
NDArray out = result;
Resource rsc = ResourceManager::Get()->Request(rctx.ctx,
ResourceRequest(ResourceRequest::kTempSpace));
is_serial_push_?
ReduceSumCPUExSerial(reduce, &out)
: mxnet::ndarray::ElementwiseSum(rctx.get_stream<cpu>(), rsc, reduce, &out);
on_complete();
}, Context::CPU(), const_vars, {result.var()},
}, Context::CPU(), const_vars, {result.var(), rsc.var},
FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce"));
}

@@ -491,11 +492,7 @@ class CommDevice : public Comm {

void Init(int key, const NDArrayStorageType stype, const TShape& shape,
int dtype = mshadow::kFloat32) override {
if (stype == kDefaultStorage) {
sorted_key_attrs_.push_back(std::make_tuple(key, shape, dtype));
} else {
LOG(FATAL) << "storage type " << stype << " not implemented for device yet";
}
sorted_key_attrs_.push_back(std::make_tuple(key, shape, dtype, stype));
Contributor: Using emplace_back(key, shape, dtype, stype) can avoid constructing a temporary tuple object.

}
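A minimal illustration of the emplace_back suggestion, assuming sorted_key_attrs_ is a std::vector of the KeyAttrs tuple defined later in this file:

// push_back(std::make_tuple(...)) builds a temporary tuple and then moves it into the vector:
sorted_key_attrs_.push_back(std::make_tuple(key, shape, dtype, stype));
// emplace_back forwards the arguments and constructs the tuple in place, avoiding the temporary:
sorted_key_attrs_.emplace_back(key, shape, dtype, stype);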

void InitBuffersAndComm(const std::vector<NDArray>& src) {
@@ -528,26 +525,42 @@ class CommDevice : public Comm {
InitBuffersAndComm(src);
auto& buf = merge_buf_[key];
std::vector<NDArray> reduce(src.size());
CopyFromTo(src[0], &(buf.merged), priority);
reduce[0] = buf.merged;

if (buf.copy_buf.empty()) {
// TODO(mli) this results in large device memory usage for huge ndarray,
// such as the largest fullc in VGG. consider to do segment reduce with
// NDArray.Slice or gpu direct memory access. for the latter, we need to
// remove some ctx check, and also it reduces 20% perf
buf.copy_buf.resize(src.size()-1);
const NDArrayStorageType stype = buf.merged.storage_type();
if (stype == kDefaultStorage) {
CopyFromTo(src[0], &(buf.merged), priority);
reduce[0] = buf.merged;

if (buf.copy_buf.empty()) {
// TODO(mli) this results in large device memory usage for huge ndarray,
// such as the largest fullc in VGG. consider to do segment reduce with
// NDArray.Slice or gpu direct memory access. for the latter, we need to
// remove some ctx check, and also it reduces 20% perf
buf.copy_buf.resize(src.size()-1);
for (size_t i = 0; i < src.size()-1; ++i) {
buf.copy_buf[i] = NDArray(
buf.merged.shape(), buf.merged.ctx(), false, buf.merged.dtype());
}
}
for (size_t i = 0; i < src.size()-1; ++i) {
buf.copy_buf[i] = NDArray(
buf.merged.shape(), buf.merged.ctx(), false, buf.merged.dtype());
CopyFromTo(src[i+1], &(buf.copy_buf[i]), priority);
reduce[i+1] = buf.copy_buf[i];
}
} else {
if (buf.copy_buf.empty()) {
buf.copy_buf.resize(src.size());
for (size_t j = 0; j < src.size(); ++j) {
buf.copy_buf[j] = NDArray(
buf.merged.storage_type(), buf.merged.shape(), buf.merged.ctx(),
true, buf.merged.dtype());
}
}
for (size_t i = 0; i < src.size(); ++i) {
CopyFromTo(src[i], &(buf.copy_buf[i]), priority);
reduce[i] = buf.copy_buf[i];
}
}
for (size_t i = 0; i < src.size()-1; ++i) {
CopyFromTo(src[i+1], &(buf.copy_buf[i]), priority);
reduce[i+1] = buf.copy_buf[i];
}

ElementwiseSum(reduce, &buf.merged);
ElementwiseSum(reduce, &buf.merged, priority);
return buf.merged;
}
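For reference, the two buffer-allocation forms used in the branches above; the meaning of the boolean argument (delay_alloc) is an assumption based on the NDArray constructor signatures, not spelled out in this diff.

// Dense branch: eager allocation (third argument is delay_alloc = false).
NDArray dense_buf(buf.merged.shape(), buf.merged.ctx(), false, buf.merged.dtype());
// Row-sparse branch: delayed allocation (fourth argument is delay_alloc = true),
// since the number of stored rows is only known after CopyFromTo runs.
NDArray rsp_buf(buf.merged.storage_type(), buf.merged.shape(), buf.merged.ctx(),
                true, buf.merged.dtype());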

@@ -621,7 +634,62 @@ class CommDevice : public Comm {
const std::vector<std::pair<NDArray*, NDArray>>& dst,
const bool use_copy,
const int priority) override {
LOG(FATAL) << "Not implemented yet";
CHECK_EQ(src.storage_type(), kRowSparseStorage)
<< "BroadcastRowSparse expects row-sparse src NDArray";

// whether the indices are the same
Member: This code is duplicated in comm.h and kvstore_local.h. Shall we move it to utils.h?
Member Author: Moved.

const bool is_same_rowid = CheckSameRowid(dst);
for (size_t i = 0; i < dst.size(); ++i) {
// the result can be copied to other devices without invoking sparse retain operator
// if the indices are the same
if (is_same_rowid && i != 0) {
CopyFromTo(*dst[0].first, dst[i].first, priority);
continue;
}

NDArray* out = dst[i].first;
NDArray row_id = dst[i].second;
if (use_copy) {
CopyFromTo(src, out, priority);
} else {
CHECK_EQ(out->storage_type(), kRowSparseStorage)
<< "BroadcastRowSparse expects row_sparse dst NDArray";

const bool is_diff_ctx = out->ctx() != src.ctx();
Member: Are we assuming src is always on GPU? If so, should we perform retain first before copying it to other devices?
Member Author: src is not assumed to be on GPU. Actually src is always on CPU. As you can see in https://github.com/apache/incubator-mxnet/blob/master/src/kvstore/kvstore_local.h#L233, src is local_[key], and local_[key] is initialized on pinned_ctx_, which is always CPU: https://github.com/apache/incubator-mxnet/blob/master/src/kvstore/kvstore_local.h#L152.
Member: That's true at the beginning, but as soon as you push some gradients on GPU, it copies the weight from pinned_ctx to GPU. See https://github.com/apache/incubator-mxnet/blob/master/src/kvstore/kvstore_local.h#L173.
Member: Nonetheless, I think performing sparse retain before the copy makes more sense since the source array is usually very large.

NDArray out_gpu = is_diff_ctx? NDArray(kRowSparseStorage, out->shape(),
src.ctx(), true, out->dtype(), out->aux_types()) : *out;

CHECK_EQ(row_id.ctx(), src.ctx())
<< "row_id and src are expected to be on the same context";

Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
NDArray temp = out_gpu;
const TBlob& indices = row_id.data();
switch (temp.ctx().dev_mask()) {
case cpu::kDevMask: {
mxnet::common::SparseRetainOpForwardRspWrapper<cpu>(rctx.get_stream<cpu>(),
src, indices, kWriteTo, &temp);
break;
}
#if MXNET_USE_CUDA
case gpu::kDevMask: {
mxnet::common::SparseRetainOpForwardRspWrapper<gpu>(rctx.get_stream<gpu>(),
src, indices, kWriteTo, &temp);
// wait for GPU operations to complete
rctx.get_stream<gpu>()->Wait();
break;
}
Member: Is Stream->Wait() missing?

#endif
default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
}
on_complete();
}, out_gpu.ctx(), {src.var(), row_id.var()}, {out_gpu.var()},
FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreSparseRetain"));
if (is_diff_ctx) {
CopyFromTo(out_gpu, out, priority);
}
}
}
}
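The CheckSameRowid helper used at the top of BroadcastRowSparse is the utility the review thread above asked to move into a shared header; its implementation is not part of this diff. A rough sketch of what such a check could do, with hypothetical naming and logic:

// Hypothetical sketch only -- the real CheckSameRowid lives in the shared kvstore utils
// and may be implemented differently. The idea: if every destination requests the same
// row_id NDArray, sparse retain only needs to run once and the result can be copied to
// the remaining devices.
template <typename PairVec>   // e.g. std::vector<std::pair<NDArray*, NDArray>>
bool CheckSameRowidSketch(const PairVec& dst) {
  if (dst.empty()) return false;
  for (size_t i = 1; i < dst.size(); ++i) {
    // Comparing engine variables is one cheap way to test "same ndarray";
    // this shortcut is an assumption, not necessarily what the PR does.
    if (dst[i].second.var() != dst[0].second.var()) return false;
  }
  return true;
}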

private:
@@ -667,7 +735,7 @@
#endif
}

using KeyAttrs = std::tuple<int, TShape, int>;
using KeyAttrs = std::tuple<int, TShape, int, NDArrayStorageType>;
// try to allocate buff on device evenly
void InitMergeBuffer(const std::vector<Context>& devs) {
std::sort(sorted_key_attrs_.begin(), sorted_key_attrs_.end(), [](
@@ -681,8 +749,9 @@
}
for (size_t i = 0; i < sorted_key_attrs_.size(); ++i) {
int key = std::get<0>(sorted_key_attrs_[i]);
Contributor: const int

TShape s = std::get<1>(sorted_key_attrs_[i]);
TShape shape = std::get<1>(sorted_key_attrs_[i]);
Contributor: const TShape&

int type = std::get<2>(sorted_key_attrs_[i]);
Contributor: const int

const NDArrayStorageType stype = std::get<3>(sorted_key_attrs_[i]);
auto& buf = merge_buf_[key];
Context ctx;
size_t min_size = std::numeric_limits<size_t>::max();
Expand All @@ -693,8 +762,12 @@ class CommDevice : public Comm {
min_size = size;
}
}
buf.merged = NDArray(s, ctx, false, type);
ctx_info[ctx.dev_id].second += s.Size();
if (stype == kDefaultStorage) {
buf.merged = NDArray(shape, ctx, false, type);
} else {
buf.merged = NDArray(stype, shape, ctx, true, type);
}
ctx_info[ctx.dev_id].second += shape.Size();
}
inited_ = true;
}
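Taken together, the reviewer's const suggestions above would turn the tuple unpacking into something like this sketch (behavior unchanged):

const int key = std::get<0>(sorted_key_attrs_[i]);
const TShape& shape = std::get<1>(sorted_key_attrs_[i]);
const int type = std::get<2>(sorted_key_attrs_[i]);
const NDArrayStorageType stype = std::get<3>(sorted_key_attrs_[i]);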
57 changes: 39 additions & 18 deletions src/kvstore/kvstore_local.h
@@ -34,6 +34,7 @@
#include <functional>
#include <algorithm>
#include "./comm.h"
#include "./utils.h"

namespace mxnet {
namespace kvstore {
@@ -223,12 +224,18 @@
<< "PullRowSparse expects row_sparse src NDArray";
auto &target_val_rowids = grouped_val_rowids[i];
const size_t num_vals = target_val_rowids.size();
for (size_t i = 0; i < num_vals; i++) {
auto &row_id = target_val_rowids[i].second;
NDArray indices(row_id.shape(), pinned_ctx_, false, mshadow::kInt64);
CopyFromTo(row_id, &indices, 0);
Unique(&indices, priority);
target_val_rowids[i].second = indices;
// whether the indices are the same
const bool is_same_rowid = CheckSameRowid(target_val_rowids);
for (size_t j = 0; j < num_vals; j++) {
if (is_same_rowid && j != 0) {
target_val_rowids[j].second = target_val_rowids[0].second;
} else {
auto &row_id = target_val_rowids[j].second;
NDArray indices(row_id.shape(), local.ctx(), false, mshadow::kInt64);
CopyFromTo(row_id, &indices, 0);
Unique(&indices, priority);
target_val_rowids[j].second = indices;
}
}
comm_->BroadcastRowSparse(key, local, grouped_val_rowids[i], false, priority);
}
@@ -354,29 +361,43 @@
}

/**
* \brief sort and get unique values. Output is expected to be on cpu_pinned context
* \brief sort and get unique values.
*/
void Unique(NDArray *out, int priority = 0) {
CHECK_EQ(out->ctx().dev_mask(), pinned_ctx_.dev_mask())
<< "Unique expects input with `pinned_ctx_`";
void Unique(NDArray *out, int priority) {
Resource rsc = ResourceManager::Get()->Request(out->ctx(),
ResourceRequest(ResourceRequest::kTempSpace));
Engine::Get()->PushAsync(
[out](RunContext rctx, Engine::CallbackOnComplete on_complete) {
[rsc, out](RunContext rctx, Engine::CallbackOnComplete on_complete) {
NDArray *output = out;
CHECK_EQ(out->shape().ndim(), 1) << "Unique expects 1D inputs";
const auto size = out->shape()[0];
auto size = out->shape()[0];
auto out_data = output->data();
MSHADOW_IDX_TYPE_SWITCH(out_data.type_flag_, IType, {
auto dptr = output->data().dptr<IType>();
common::ParallelSort(dptr, dptr + size, omp_get_max_threads());
auto num_unique_idx = std::unique(dptr, dptr + size) - dptr;
*output = output->Reshape(mshadow::Shape1(num_unique_idx));
IType *dptr = output->data().dptr<IType>();
switch (out->ctx().dev_mask()) {
case cpu::kDevMask: {
mshadow::Stream<cpu> *s = rctx.get_stream<cpu>();
UniqueImpl(rsc, s, output, size);
break;
}
#if MXNET_USE_CUDA
case gpu::kDevMask: {
mshadow::Stream<gpu> *s = rctx.get_stream<gpu>();
UniqueImpl(rsc, s, output, size);
break;
}
#endif
default:
LOG(FATAL) << "GPU not enabled.";
}
});
on_complete();
}, pinned_ctx_, {}, {out->var()},
FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreUnique"));
}, out->ctx(), {}, {out->var(), rsc.var},
FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreUnique"));
out->WaitToRead();
}
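The UniqueImpl overloads dispatched to above live in the shared kvstore utils added by this PR and are not shown in the diff. Based on the CPU code this function previously contained (ParallelSort followed by std::unique and a reshape), a minimal CPU-only sketch might be:

// CPU-only sketch; the helper name and the use of std::sort instead of
// common::ParallelSort are assumptions for illustration.
#include <algorithm>
#include <cstddef>

template <typename IType>
size_t UniqueCpuSketch(IType* dptr, size_t size) {
  std::sort(dptr, dptr + size);                    // sort the row ids
  return std::unique(dptr, dptr + size) - dptr;    // compact and count the unique prefix
}
// The caller then reshapes the 1-D NDArray to the returned length, as the previous
// implementation did with output->Reshape(mshadow::Shape1(num_unique_idx)).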


/// reducer and broadcaster
Comm* comm_;
/// pinned context