rsp push and rsp pull for comm device, used in kvstore('device') #8732
Changes from 27 commits
@@ -34,6 +34,7 @@
 #include "gradient_compression.h"
 #include "../ndarray/ndarray_function.h"
 #include "../operator/tensor/sparse_retain-inl.h"
+#include "./utils.h"
 namespace mxnet {
 namespace kvstore {
 /**
@@ -176,17 +177,17 @@ class CommCPU : public Comm {
       reduce[i] = buf.copy_buf[i];
       const_vars[i] = reduce[i].var();
     }
-    auto result = buf.merged;
+    NDArray result = buf.merged;
+    Resource rsc = ResourceManager::Get()->Request(result.ctx(),
+        ResourceRequest(ResourceRequest::kTempSpace));
     Engine::Get()->PushAsync(
-      [reduce, result, this](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+      [reduce, result, rsc, this](RunContext rctx, Engine::CallbackOnComplete on_complete) {
         NDArray out = result;
-        Resource rsc = ResourceManager::Get()->Request(rctx.ctx,
-            ResourceRequest(ResourceRequest::kTempSpace));
         is_serial_push_?
             ReduceSumCPUExSerial(reduce, &out)
           : mxnet::ndarray::ElementwiseSum(rctx.get_stream<cpu>(), rsc, reduce, &out);
         on_complete();
-      }, Context::CPU(), const_vars, {result.var()},
+      }, Context::CPU(), const_vars, {result.var(), rsc.var},
       FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce"));
   }
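In this hunk the temp-space Resource is requested on the scheduling thread (from result.ctx()) and rsc.var is added to the mutated vars, so the engine serializes every pushed function that shares that scratch space. The fragment below is an illustrative sketch of the pattern, reusing only calls that already appear in the hunk; it assumes the surrounding CommCPU::Reduce state (buf, reduce, const_vars, priority) and omits the serial-push branch.

```cpp
// Sketch only: assumes the surrounding CommCPU::Reduce members and MXNet headers.
NDArray result = buf.merged;
Resource rsc = ResourceManager::Get()->Request(
    result.ctx(), ResourceRequest(ResourceRequest::kTempSpace));
Engine::Get()->PushAsync(
    [reduce, result, rsc](RunContext rctx, Engine::CallbackOnComplete on_complete) {
      NDArray out = result;
      // the captured resource supplies the scratch memory for the reduction
      mxnet::ndarray::ElementwiseSum(rctx.get_stream<cpu>(), rsc, reduce, &out);
      on_complete();  // signal completion of the async op to the engine
    },
    Context::CPU(),
    const_vars,                // read deps: the per-device copies being summed
    {result.var(), rsc.var},   // write deps: the merge buffer and the temp space
    FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce"));
```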
@@ -491,11 +492,7 @@ class CommDevice : public Comm {

   void Init(int key, const NDArrayStorageType stype, const TShape& shape,
             int dtype = mshadow::kFloat32) override {
-    if (stype == kDefaultStorage) {
-      sorted_key_attrs_.push_back(std::make_tuple(key, shape, dtype));
-    } else {
-      LOG(FATAL) << "storage type " << stype << " not implemented for device yet";
-    }
+    sorted_key_attrs_.push_back(std::make_tuple(key, shape, dtype, stype));
   }

   void InitBuffersAndComm(const std::vector<NDArray>& src) {
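Init now records the storage type instead of rejecting non-default storage with LOG(FATAL). The standalone sketch below (simplified stand-in types, not MXNet headers) shows the bookkeeping: each key carries its shape, dtype, and storage type until InitMergeBuffer consumes them.

```cpp
#include <cstdio>
#include <tuple>
#include <vector>

enum NDArrayStorageType { kDefaultStorage = 0, kRowSparseStorage = 1 };  // stand-in for the MXNet enum
using TShapeStub = std::vector<long>;   // stand-in for mxnet::TShape
using KeyAttrs = std::tuple<int, TShapeStub, int, NDArrayStorageType>;

int main() {
  std::vector<KeyAttrs> sorted_key_attrs;
  // dense and row_sparse keys are both recorded now; neither path aborts
  sorted_key_attrs.push_back(std::make_tuple(0, TShapeStub{128, 64}, 0, kDefaultStorage));
  sorted_key_attrs.push_back(std::make_tuple(1, TShapeStub{100000, 64}, 0, kRowSparseStorage));
  for (const KeyAttrs& attr : sorted_key_attrs) {
    std::printf("key %d stype %d\n", std::get<0>(attr), static_cast<int>(std::get<3>(attr)));
  }
  return 0;
}
```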
@@ -528,26 +525,42 @@ class CommDevice : public Comm {
     InitBuffersAndComm(src);
     auto& buf = merge_buf_[key];
     std::vector<NDArray> reduce(src.size());
-    CopyFromTo(src[0], &(buf.merged), priority);
-    reduce[0] = buf.merged;
-
-    if (buf.copy_buf.empty()) {
-      // TODO(mli) this results in large device memory usage for huge ndarray,
-      // such as the largest fullc in VGG. consider to do segment reduce with
-      // NDArray.Slice or gpu direct memory access. for the latter, we need to
-      // remove some ctx check, and also it reduces 20% perf
-      buf.copy_buf.resize(src.size()-1);
-      for (size_t i = 0; i < src.size()-1; ++i) {
-        buf.copy_buf[i] = NDArray(
-          buf.merged.shape(), buf.merged.ctx(), false, buf.merged.dtype());
-      }
-    }
-    for (size_t i = 0; i < src.size()-1; ++i) {
-      CopyFromTo(src[i+1], &(buf.copy_buf[i]), priority);
-      reduce[i+1] = buf.copy_buf[i];
-    }
+    const NDArrayStorageType stype = buf.merged.storage_type();
+    if (stype == kDefaultStorage) {
+      CopyFromTo(src[0], &(buf.merged), priority);
+      reduce[0] = buf.merged;
+
+      if (buf.copy_buf.empty()) {
+        // TODO(mli) this results in large device memory usage for huge ndarray,
+        // such as the largest fullc in VGG. consider to do segment reduce with
+        // NDArray.Slice or gpu direct memory access. for the latter, we need to
+        // remove some ctx check, and also it reduces 20% perf
+        buf.copy_buf.resize(src.size()-1);
+        for (size_t i = 0; i < src.size()-1; ++i) {
+          buf.copy_buf[i] = NDArray(
+            buf.merged.shape(), buf.merged.ctx(), false, buf.merged.dtype());
+        }
+      }
+      for (size_t i = 0; i < src.size()-1; ++i) {
+        CopyFromTo(src[i+1], &(buf.copy_buf[i]), priority);
+        reduce[i+1] = buf.copy_buf[i];
+      }
+    } else {
+      if (buf.copy_buf.empty()) {
+        buf.copy_buf.resize(src.size());
+        for (size_t j = 0; j < src.size(); ++j) {
+          buf.copy_buf[j] = NDArray(
+            buf.merged.storage_type(), buf.merged.shape(), buf.merged.ctx(),
+            true, buf.merged.dtype());
+        }
+      }
+      for (size_t i = 0; i < src.size(); ++i) {
+        CopyFromTo(src[i], &(buf.copy_buf[i]), priority);
+        reduce[i] = buf.copy_buf[i];
+      }
+    }

-    ElementwiseSum(reduce, &buf.merged);
+    ElementwiseSum(reduce, &buf.merged, priority);
     return buf.merged;
   }
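In the row_sparse branch above, every source array gets its own copy buffer and the merged result is produced entirely by ElementwiseSum over row_sparse operands. The standalone example below (plain C++ maps, not MXNet code) illustrates what that sum produces: the union of the row ids, with overlapping rows accumulated element-wise.

```cpp
#include <cstdio>
#include <map>
#include <vector>

using RowSparse = std::map<int, std::vector<float>>;  // row id -> row values

RowSparse ElementwiseSumRsp(const std::vector<RowSparse>& operands) {
  RowSparse merged;
  for (const RowSparse& op : operands) {
    for (const auto& kv : op) {
      auto it = merged.find(kv.first);
      if (it == merged.end()) {
        merged[kv.first] = kv.second;            // new row: copy it
      } else {
        for (size_t c = 0; c < kv.second.size(); ++c) {
          it->second[c] += kv.second[c];         // existing row: accumulate
        }
      }
    }
  }
  return merged;
}

int main() {
  RowSparse a{{0, {1.f, 1.f}}, {4, {2.f, 2.f}}};
  RowSparse b{{4, {3.f, 3.f}}, {7, {1.f, 1.f}}};
  RowSparse sum = ElementwiseSumRsp({a, b});     // rows 0, 4, 7; row 4 summed
  for (const auto& kv : sum) {
    std::printf("row %d: %.1f %.1f\n", kv.first, kv.second[0], kv.second[1]);
  }
  return 0;
}
```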
@@ -621,7 +634,62 @@ class CommDevice : public Comm {
                           const std::vector<std::pair<NDArray*, NDArray>>& dst,
                           const bool use_copy,
                           const int priority) override {
-    LOG(FATAL) << "Not implemented yet";
+    CHECK_EQ(src.storage_type(), kRowSparseStorage)
+      << "BroadcastRowSparse expects row-sparse src NDArray";
+
+    // whether the indices are the same
+    const bool is_same_rowid = CheckSameRowid(dst);
+    for (size_t i = 0; i < dst.size(); ++i) {
+      // the result can be copied to other devices without invoking sparse retain operator
+      // if the indices are the same
+      if (is_same_rowid && i != 0) {
+        CopyFromTo(*dst[0].first, dst[i].first, priority);
+        continue;
+      }
+      NDArray* out = dst[i].first;
+      NDArray row_id = dst[i].second;
+      if (use_copy) {
+        CopyFromTo(src, out, priority);
+      } else {
+        CHECK_EQ(out->storage_type(), kRowSparseStorage)
+          << "BroadcastRowSparse expects row_sparse dst NDArray";
+
+        const bool is_diff_ctx = out->ctx() != src.ctx();
+        NDArray out_gpu = is_diff_ctx? NDArray(kRowSparseStorage, out->shape(),
+            src.ctx(), true, out->dtype(), out->aux_types()) : *out;
+
+        CHECK_EQ(row_id.ctx(), src.ctx())
+          << "row_id and src are expected to be on the same context";
+
+        Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+            NDArray temp = out_gpu;
+            const TBlob& indices = row_id.data();
+            switch (temp.ctx().dev_mask()) {
+              case cpu::kDevMask: {
+                mxnet::common::SparseRetainOpForwardRspWrapper<cpu>(rctx.get_stream<cpu>(),
+                    src, indices, kWriteTo, &temp);
+                break;
+              }
+#if MXNET_USE_CUDA
+              case gpu::kDevMask: {
+                mxnet::common::SparseRetainOpForwardRspWrapper<gpu>(rctx.get_stream<gpu>(),
+                    src, indices, kWriteTo, &temp);
+                // wait for GPU operations to complete
+                rctx.get_stream<gpu>()->Wait();
+                break;
+              }
+#endif
+              default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+            }
+            on_complete();
+          }, out_gpu.ctx(), {src.var(), row_id.var()}, {out_gpu.var()},
+          FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreSparseRetain"));
+        if (is_diff_ctx) {
+          CopyFromTo(out_gpu, out, priority);
+        }
+      }
+    }
   }

 private:

Review discussion on the same-rowid check:
- Is this code duplicated in ...?
- Moved ...

Review discussion on is_diff_ctx and where to run sparse retain:
- Are we assuming src is always on GPU?
- src is not assumed to be on GPU. Actually src is always on CPU. As you can see in https://github.com/apache/incubator-mxnet/blob/master/src/kvstore/kvstore_local.h#L233, src is ...
- That's true at the beginning. But as soon as you push some gradients on GPU, it copies the weight from pinned_ctx to GPU. See ...
- Nonetheless, I think performing sparse retain before the copy makes more sense, since the source array is usually very large.
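The last point in the discussion above argues for running sparse retain on the source context before the cross-device copy. The standalone sketch below (plain C++, not MXNet code) shows the idea: only the rows requested by row_id are materialized on the source side, so the copy that follows moves a small retained array rather than the full source.

```cpp
#include <cstdio>
#include <map>
#include <vector>

using RowSparse = std::map<int, std::vector<float>>;  // row id -> row values

// Keep only the rows listed in row_ids; everything else is dropped.
RowSparse SparseRetain(const RowSparse& src, const std::vector<int>& row_ids) {
  RowSparse retained;
  for (int rid : row_ids) {
    auto it = src.find(rid);
    if (it != src.end()) {
      retained[rid] = it->second;
    }
  }
  return retained;
}

int main() {
  RowSparse weight{{0, {0.1f}}, {3, {0.3f}}, {9, {0.9f}}};
  std::vector<int> row_ids{3, 9};
  // In the PR this step runs on src.ctx(); only the retained rows would then
  // be copied to the destination device.
  RowSparse pulled = SparseRetain(weight, row_ids);
  for (const auto& kv : pulled) {
    std::printf("row %d: %.1f\n", kv.first, kv.second[0]);
  }
  return 0;
}
```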
@@ -667,7 +735,7 @@ class CommDevice : public Comm {
 #endif
   }

-  using KeyAttrs = std::tuple<int, TShape, int>;
+  using KeyAttrs = std::tuple<int, TShape, int, NDArrayStorageType>;
   // try to allocate buff on device evenly
   void InitMergeBuffer(const std::vector<Context>& devs) {
     std::sort(sorted_key_attrs_.begin(), sorted_key_attrs_.end(), [](

@@ -681,8 +749,9 @@ class CommDevice : public Comm {
     }
     for (size_t i = 0; i < sorted_key_attrs_.size(); ++i) {
       int key = std::get<0>(sorted_key_attrs_[i]);
-      TShape s = std::get<1>(sorted_key_attrs_[i]);
+      TShape shape = std::get<1>(sorted_key_attrs_[i]);
       int type = std::get<2>(sorted_key_attrs_[i]);
+      const NDArrayStorageType stype = std::get<3>(sorted_key_attrs_[i]);
       auto& buf = merge_buf_[key];
       Context ctx;
       size_t min_size = std::numeric_limits<size_t>::max();

Review comments on the locals above: const int was suggested for key, const TShape& for shape, and const int for type.

@@ -693,8 +762,12 @@ class CommDevice : public Comm {
           min_size = size;
         }
       }
-      buf.merged = NDArray(s, ctx, false, type);
-      ctx_info[ctx.dev_id].second += s.Size();
+      if (stype == kDefaultStorage) {
+        buf.merged = NDArray(shape, ctx, false, type);
+      } else {
+        buf.merged = NDArray(stype, shape, ctx, true, type);
+      }
+      ctx_info[ctx.dev_id].second += shape.Size();
     }
     inited_ = true;
   }
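If the const suggestions from the review were applied, the loop body might read as in this sketch (illustrative only; the PR may keep the originals):

```cpp
// Sketch of the suggested const qualifiers; not code from the PR.
const int key = std::get<0>(sorted_key_attrs_[i]);
const TShape& shape = std::get<1>(sorted_key_attrs_[i]);
const int type = std::get<2>(sorted_key_attrs_[i]);
const NDArrayStorageType stype = std::get<3>(sorted_key_attrs_[i]);
```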
Review comment: Let's make sure functions in .h are documented. Should add some description for CastStorageDispatch too...
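Following up on that comment, a header documentation block could look like the sketch below. The signature shown is an assumption based on how CastStorageDispatch is typically used in MXNet's common utilities, not code copied from the PR.

```cpp
// Hedged sketch of a doc comment; the signature is assumed and may differ.
/*!
 * \brief Dispatch a storage cast (e.g. dense <-> row_sparse) for an NDArray
 *        on the device type xpu.
 * \param ctx     operator context holding the stream to run on
 * \param input   source NDArray whose storage type is to be converted
 * \param output  destination NDArray with the target storage type
 */
template<typename xpu>
void CastStorageDispatch(const OpContext& ctx, const NDArray& input, const NDArray& output);
```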