diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py
index b2a4beaf93b6..890c9024d872 100644
--- a/python/mxnet/kvstore.py
+++ b/python/mxnet/kvstore.py
@@ -298,7 +298,8 @@ def pull(self, key, out=None, priority=0):
 
     def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
         """ Pulls a single RowSparseNDArray value or a sequence of RowSparseNDArray values \
-        from the store with specified row_ids.
+        from the store with specified row_ids. When there is only one row_id, KVStoreRowSparsePull \
+        is invoked just once and the result is broadcast to the rest of the outputs.
 
         `row_sparse_pull` is executed asynchronously after all previous
         `pull`/`row_sparse_pull` calls and the last `push` call for the
@@ -349,7 +350,17 @@ def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
         """
         assert(out is not None)
         assert(row_ids is not None)
-        ckeys, cvals, use_str_keys = _ctype_key_value(key, out)
+        if isinstance(row_ids, NDArray):
+            row_ids = [row_ids]
+        assert(isinstance(row_ids, list)), \
+            "row_ids should be NDArray or list of NDArray"
+        first_out = out
+        # whether a single row_ids array is shared by all outputs
+        single_rowid = False
+        if len(row_ids) == 1 and isinstance(out, list):
+            single_rowid = True
+            first_out = [out[0]]
+        ckeys, cvals, use_str_keys = _ctype_key_value(key, first_out)
         _, crow_ids, _ = _ctype_key_value(key, row_ids)
         assert(len(crow_ids) == len(cvals)), \
             "the number of row_ids doesn't match the number of values"
@@ -359,6 +370,11 @@ def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
         else:
             check_call(_LIB.MXKVStorePullRowSparse(
                 self.handle, mx_uint(len(ckeys)), ckeys, cvals, crow_ids, ctypes.c_int(priority)))
+        # if the indices are the same, the result can be copied to the other outputs
+        # without invoking row_sparse_pull again
+        if single_rowid:
+            for out_i in out[1:]:
+                out[0].copyto(out_i)
 
     def set_gradient_compression(self, compression_params):
         """ Specifies type of low-bit quantization for gradient compression \
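For context, a minimal sketch of how the fast path above is exercised from the Python API; the key name, shape, and kvstore type are illustrative, not part of this diff. A single row_ids NDArray shared by a list of outputs triggers one MXKVStorePullRowSparse call into out[0], followed by copyto broadcasts:

    import mxnet as mx

    kv = mx.kv.create('local')
    shape = (8, 4)
    kv.init('w', mx.nd.ones(shape).tostype('row_sparse'))

    # one shared row_ids for two outputs -> the backend pull happens once,
    # then out[0] is copied to the remaining outputs
    outs = [mx.nd.sparse.zeros('row_sparse', shape) for _ in range(2)]
    row_ids = mx.nd.array([0, 3, 5], dtype='int64')
    kv.row_sparse_pull('w', out=outs, row_ids=row_ids)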
diff --git a/src/common/utils.cc b/src/common/utils.cc
index 784fcf8651ae..9fe46d94d036 100644
--- a/src/common/utils.cc
+++ b/src/common/utils.cc
@@ -24,6 +24,7 @@
 
 #include "./utils.h"
 #include "../operator/tensor/cast_storage-inl.h"
+#include "../operator/tensor/sparse_retain-inl.h"
 
 namespace mxnet {
 namespace common {
@@ -34,6 +35,15 @@ void CheckFormatWrapper<cpu>(const RunContext &rctx, const NDArray &input,
   CheckFormatImpl<cpu>(rctx, input, err_cpu, full_check);
 }
 
+template<>
+void SparseRetainOpForwardRspWrapper<cpu>(mshadow::Stream<cpu> *s,
+                                          const NDArray& input_nd,
+                                          const TBlob& idx_data,
+                                          const OpReqType req,
+                                          NDArray* output_nd) {
+  mxnet::op::SparseRetainOpForwardRspImpl<cpu>(s, input_nd, idx_data, req, output_nd);
+}
+
 template<>
 void CastStorageDispatch<cpu>(const OpContext& ctx,
                               const NDArray& input,
diff --git a/src/common/utils.cu b/src/common/utils.cu
index c6e2bf813876..0937d7aa5145 100644
--- a/src/common/utils.cu
+++ b/src/common/utils.cu
@@ -24,6 +24,7 @@
 
 #include "./utils.h"
 #include "../operator/tensor/cast_storage-inl.h"
+#include "../operator/tensor/sparse_retain-inl.h"
 
 namespace mxnet {
 namespace common {
@@ -34,6 +35,15 @@ void CheckFormatWrapper<gpu>(const RunContext &rctx, const NDArray &input,
   CheckFormatImpl<gpu>(rctx, input, err_cpu, full_check);
 }
 
+template<>
+void SparseRetainOpForwardRspWrapper<gpu>(mshadow::Stream<gpu> *s,
+                                          const NDArray& input_nd,
+                                          const TBlob& idx_data,
+                                          const OpReqType req,
+                                          NDArray* output_nd) {
+  mxnet::op::SparseRetainOpForwardRspImpl<gpu>(s, input_nd, idx_data, req, output_nd);
+}
+
 template<>
 void CastStorageDispatch<gpu>(const OpContext& ctx,
                               const NDArray& input,
diff --git a/src/common/utils.h b/src/common/utils.h
index 038ab2a04721..535dc75200f8 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -214,7 +214,18 @@ void CheckFormatImpl(const RunContext &rctx, const NDArray &input,
   }
 }
 
+/*! \brief Pick rows specified by user input index array from a row sparse ndarray
+ *         and save them in the output sparse ndarray.
+ */
+template<typename xpu>
+void SparseRetainOpForwardRspWrapper(mshadow::Stream<xpu> *s,
+                                     const NDArray& input_nd,
+                                     const TBlob& idx_data,
+                                     const OpReqType req,
+                                     NDArray* output_nd);
+
+/*! \brief Casts tensor storage type to the new type.
+ */
 template<typename xpu>
 void CastStorageDispatch(const OpContext& ctx,
                          const NDArray& input,
                          const NDArray& output);
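The wrappers above route to the same kernel as the existing sparse_retain operator, which is also exposed in Python. A small reference example of the retain semantics (the data values are illustrative):

    import mxnet as mx

    data = mx.nd.arange(12).reshape((4, 3)).tostype('row_sparse')
    indices = mx.nd.array([0, 2], dtype='int64')
    # rows 0 and 2 are kept; rows 1 and 3 are zero in the result
    retained = mx.nd.sparse.retain(data, indices)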
diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h
index 5429df70b173..d41fa64cf538 100644
--- a/src/kvstore/comm.h
+++ b/src/kvstore/comm.h
@@ -34,6 +34,7 @@
 #include "gradient_compression.h"
 #include "../ndarray/ndarray_function.h"
 #include "../operator/tensor/sparse_retain-inl.h"
+#include "./utils.h"
 namespace mxnet {
 namespace kvstore {
 /**
@@ -176,17 +177,17 @@ class CommCPU : public Comm {
         reduce[i] = buf.copy_buf[i];
         const_vars[i] = reduce[i].var();
       }
-      auto result = buf.merged;
+      NDArray result = buf.merged;
+      Resource rsc = ResourceManager::Get()->Request(result.ctx(),
+          ResourceRequest(ResourceRequest::kTempSpace));
       Engine::Get()->PushAsync(
-        [reduce, result, this](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+        [reduce, result, rsc, this](RunContext rctx, Engine::CallbackOnComplete on_complete) {
           NDArray out = result;
-          Resource rsc = ResourceManager::Get()->Request(rctx.ctx,
-              ResourceRequest(ResourceRequest::kTempSpace));
           is_serial_push_?
             ReduceSumCPUExSerial(reduce, &out)
             : mxnet::ndarray::ElementwiseSum(rctx.get_stream<cpu>(), rsc, reduce, &out);
           on_complete();
-        }, Context::CPU(), const_vars, {result.var()},
+        }, Context::CPU(), const_vars, {result.var(), rsc.var},
         FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce"));
     }
 
@@ -491,11 +492,7 @@ class CommDevice : public Comm {
 
   void Init(int key, const NDArrayStorageType stype, const TShape& shape,
             int dtype = mshadow::kFloat32) override {
-    if (stype == kDefaultStorage) {
-      sorted_key_attrs_.push_back(std::make_tuple(key, shape, dtype));
-    } else {
-      LOG(FATAL) << "storage type " << stype << " not implemented for device yet";
-    }
+    sorted_key_attrs_.emplace_back(key, shape, dtype, stype);
   }
 
   void InitBuffersAndComm(const std::vector<NDArray>& src) {
@@ -528,26 +525,42 @@ class CommDevice : public Comm {
     InitBuffersAndComm(src);
     auto& buf = merge_buf_[key];
     std::vector<NDArray> reduce(src.size());
-    CopyFromTo(src[0], &(buf.merged), priority);
-    reduce[0] = buf.merged;
 
-    if (buf.copy_buf.empty()) {
-      // TODO(mli) this results in large device memory usage for huge ndarray,
-      // such as the largest fullc in VGG. consider to do segment reduce with
-      // NDArray.Slice or gpu direct memory access. for the latter, we need to
-      // remove some ctx check, and also it reduces 20% perf
-      buf.copy_buf.resize(src.size()-1);
+    const NDArrayStorageType stype = buf.merged.storage_type();
+    if (stype == kDefaultStorage) {
+      CopyFromTo(src[0], &(buf.merged), priority);
+      reduce[0] = buf.merged;
+
+      if (buf.copy_buf.empty()) {
+        // TODO(mli) this results in large device memory usage for huge ndarray,
+        // such as the largest fullc in VGG. consider to do segment reduce with
+        // NDArray.Slice or gpu direct memory access. for the latter, we need to
+        // remove some ctx check, and also it reduces 20% perf
+        buf.copy_buf.resize(src.size()-1);
+        for (size_t i = 0; i < src.size()-1; ++i) {
+          buf.copy_buf[i] = NDArray(
+            buf.merged.shape(), buf.merged.ctx(), false, buf.merged.dtype());
+        }
+      }
       for (size_t i = 0; i < src.size()-1; ++i) {
-        buf.copy_buf[i] = NDArray(
-          buf.merged.shape(), buf.merged.ctx(), false, buf.merged.dtype());
+        CopyFromTo(src[i+1], &(buf.copy_buf[i]), priority);
+        reduce[i+1] = buf.copy_buf[i];
+      }
+    } else {
+      if (buf.copy_buf.empty()) {
+        buf.copy_buf.resize(src.size());
+        for (size_t j = 0; j < src.size(); ++j) {
+          buf.copy_buf[j] = NDArray(
+            buf.merged.storage_type(), buf.merged.shape(), buf.merged.ctx(),
+            true, buf.merged.dtype());
+        }
+      }
+      for (size_t i = 0; i < src.size(); ++i) {
+        CopyFromTo(src[i], &(buf.copy_buf[i]), priority);
+        reduce[i] = buf.copy_buf[i];
       }
     }
-    for (size_t i = 0; i < src.size()-1; ++i) {
-      CopyFromTo(src[i+1], &(buf.copy_buf[i]), priority);
-      reduce[i+1] = buf.copy_buf[i];
-    }
-
-    ElementwiseSum(reduce, &buf.merged);
+    ElementwiseSum(reduce, &buf.merged, priority);
 
     return buf.merged;
   }
@@ -621,7 +634,53 @@ class CommDevice : public Comm {
                           const std::vector<std::pair<NDArray*, NDArray>>& dst,
                           const bool use_copy,
                           const int priority) override {
-    LOG(FATAL) << "Not implemented yet";
+    CHECK_EQ(src.storage_type(), kRowSparseStorage)
+      << "BroadcastRowSparse expects row-sparse src NDArray";
+
+    for (size_t i = 0; i < dst.size(); ++i) {
+      NDArray* out = dst[i].first;
+      NDArray row_id = dst[i].second;
+      if (use_copy) {
+        CopyFromTo(src, out, priority);
+      } else {
+        CHECK_EQ(out->storage_type(), kRowSparseStorage)
+          << "BroadcastRowSparse expects row_sparse dst NDArray";
+
+        const bool is_diff_ctx = out->ctx() != src.ctx();
+        NDArray out_gpu = is_diff_ctx? NDArray(kRowSparseStorage, out->shape(),
+            src.ctx(), true, out->dtype(), out->aux_types()) : *out;
+
+        CHECK_EQ(row_id.ctx(), src.ctx())
+          << "row_id and src are expected to be on the same context";
+
+        Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+            NDArray temp = out_gpu;
+            const TBlob& indices = row_id.data();
+            switch (temp.ctx().dev_mask()) {
+              case cpu::kDevMask: {
+                mxnet::common::SparseRetainOpForwardRspWrapper<cpu>(rctx.get_stream<cpu>(),
+                    src, indices, kWriteTo, &temp);
+                break;
+              }
+#if MXNET_USE_CUDA
+              case gpu::kDevMask: {
+                mxnet::common::SparseRetainOpForwardRspWrapper<gpu>(rctx.get_stream<gpu>(),
+                    src, indices, kWriteTo, &temp);
+                // wait for GPU operations to complete
+                rctx.get_stream<gpu>()->Wait();
+                break;
+              }
+#endif
+              default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+            }
+            on_complete();
+          }, out_gpu.ctx(), {src.var(), row_id.var()}, {out_gpu.var()},
+          FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreSparseRetain"));
+        if (is_diff_ctx) {
+          CopyFromTo(out_gpu, out, priority);
+        }
+      }
+    }
   }
 
 private:
@@ -667,7 +726,7 @@ class CommDevice : public Comm {
 #endif
   }
 
-  using KeyAttrs = std::tuple<int, TShape, int>;
+  using KeyAttrs = std::tuple<int, TShape, int, NDArrayStorageType>;
   // try to allocate buff on device evenly
   void InitMergeBuffer(const std::vector<Context>& devs) {
     std::sort(sorted_key_attrs_.begin(), sorted_key_attrs_.end(), [](
@@ -680,9 +739,10 @@ class CommDevice : public Comm {
       ctx_info[d.dev_id] = std::make_pair(d, 0);
     }
     for (size_t i = 0; i < sorted_key_attrs_.size(); ++i) {
-      int key = std::get<0>(sorted_key_attrs_[i]);
-      TShape s = std::get<1>(sorted_key_attrs_[i]);
-      int type = std::get<2>(sorted_key_attrs_[i]);
+      const int key = std::get<0>(sorted_key_attrs_[i]);
+      const TShape& shape = std::get<1>(sorted_key_attrs_[i]);
+      const int type = std::get<2>(sorted_key_attrs_[i]);
+      const NDArrayStorageType stype = std::get<3>(sorted_key_attrs_[i]);
       auto& buf = merge_buf_[key];
       Context ctx;
       size_t min_size = std::numeric_limits<size_t>::max();
@@ -693,8 +753,12 @@ class CommDevice : public Comm {
           min_size = size;
         }
       }
-      buf.merged = NDArray(s, ctx, false, type);
-      ctx_info[ctx.dev_id].second += s.Size();
+      if (stype == kDefaultStorage) {
+        buf.merged = NDArray(shape, ctx, false, type);
+      } else {
+        buf.merged = NDArray(stype, shape, ctx, true, type);
+      }
+      ctx_info[ctx.dev_id].second += shape.Size();
    }
    inited_ = true;
  }
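With the comm changes above, CommDevice::Reduce can now sum row_sparse gradients into a row_sparse merge buffer. The NDArray-level equivalent, assuming add_n keeps row_sparse storage when every input is row_sparse (the ElementwiseSum path this diff extends):

    import mxnet as mx

    grads = [mx.nd.ones((4, 2)).tostype('row_sparse') for _ in range(3)]
    total = mx.nd.add_n(*grads)
    assert total.stype == 'row_sparse'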
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
index 1bb84fdc1114..78b6c8f231b3 100644
--- a/src/kvstore/kvstore_local.h
+++ b/src/kvstore/kvstore_local.h
@@ -34,6 +34,7 @@
 #include <functional>
 #include <algorithm>
 #include "./comm.h"
+#include "./utils.h"
 
 namespace mxnet {
 namespace kvstore {
@@ -223,12 +224,12 @@ class KVStoreLocal : public KVStore {
           << "PullRowSparse expects row_sparse src NDArray";
       auto &target_val_rowids = grouped_val_rowids[i];
       const size_t num_vals = target_val_rowids.size();
-      for (size_t i = 0; i < num_vals; i++) {
-        auto &row_id = target_val_rowids[i].second;
-        NDArray indices(row_id.shape(), pinned_ctx_, false, mshadow::kInt64);
+      for (size_t j = 0; j < num_vals; j++) {
+        auto &row_id = target_val_rowids[j].second;
+        NDArray indices(row_id.shape(), local.ctx(), false, mshadow::kInt64);
         CopyFromTo(row_id, &indices, 0);
         Unique(&indices, priority);
-        target_val_rowids[i].second = indices;
+        target_val_rowids[j].second = indices;
       }
       comm_->BroadcastRowSparse(key, local, grouped_val_rowids[i], false, priority);
     }
@@ -354,29 +355,41 @@ class KVStoreLocal : public KVStore {
   }
 
   /**
-   * \brief sort and get unique values. Output is expected to be on cpu_pinned context
+   * \brief sort and get unique values.
    */
-  void Unique(NDArray *out, int priority = 0) {
-    CHECK_EQ(out->ctx().dev_mask(), pinned_ctx_.dev_mask())
-      << "Unique expects input with `pinned_ctx_`";
+  void Unique(NDArray *out, int priority) {
+    Resource rsc = ResourceManager::Get()->Request(out->ctx(),
+      ResourceRequest(ResourceRequest::kTempSpace));
     Engine::Get()->PushAsync(
-      [out](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+      [rsc, out](RunContext rctx, Engine::CallbackOnComplete on_complete) {
         NDArray *output = out;
         CHECK_EQ(out->shape().ndim(), 1) << "Unique expects 1D inputs";
-        const auto size = out->shape()[0];
-        auto out_data = output->data();
-        MSHADOW_IDX_TYPE_SWITCH(out_data.type_flag_, IType, {
-          auto dptr = output->data().dptr<IType>();
-          common::ParallelSort(dptr, dptr + size, omp_get_max_threads());
-          auto num_unique_idx = std::unique(dptr, dptr + size) - dptr;
-          *output = output->Reshape(mshadow::Shape1(num_unique_idx));
-        });
+        nnvm::dim_t size = out->shape()[0];
+        switch (out->ctx().dev_mask()) {
+          case cpu::kDevMask: {
+            mshadow::Stream<cpu> *s = rctx.get_stream<cpu>();
+            UniqueImpl(rsc, s, output, size);
+            break;
+          }
+          #if MXNET_USE_CUDA
+          case gpu::kDevMask: {
+            mshadow::Stream<gpu> *s = rctx.get_stream<gpu>();
+            UniqueImpl(rsc, s, output, size);
+            // wait for GPU operations to complete
+            s->Wait();
+            break;
+          }
+          #endif
+          default:
+            LOG(FATAL) << "GPU not enabled.";
+        }
         on_complete();
-      }, pinned_ctx_, {}, {out->var()},
-      FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreUnique"));
+      }, out->ctx(), {}, {out->var(), rsc.var},
+      FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreUnique"));
     out->WaitToRead();
   }
+
   /// reducer and broadcaster
   Comm* comm_;
   /// pinned context
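Because Unique now dedups and sorts the requested indices before BroadcastRowSparse, callers may pass unsorted row ids containing duplicates. A sketch (key name and shape are illustrative):

    import mxnet as mx

    kv = mx.kv.create('local')
    shape = (8, 4)
    kv.init('w', mx.nd.ones(shape).tostype('row_sparse'))

    out = mx.nd.sparse.zeros('row_sparse', shape)
    dup_ids = mx.nd.array([5, 1, 5, 5, 1], dtype='int64')  # duplicates are fine
    kv.row_sparse_pull('w', out=out, row_ids=dup_ids)
    # out holds rows 1 and 5 exactly once each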
diff --git a/src/kvstore/utils.cc b/src/kvstore/utils.cc
new file mode 100644
index 000000000000..c22553f3b6ac
--- /dev/null
+++ b/src/kvstore/utils.cc
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file utils.cc
+ * \brief cpu implementation of util functions
+ */
+
+#include "./utils.h"
+#include "../common/utils.h"
+
+namespace mxnet {
+namespace kvstore {
+
+
+template<>
+void UniqueImpl<cpu>(const Resource& rsc, mshadow::Stream<cpu> *s,
+                     NDArray *out, nnvm::dim_t size) {
+  MSHADOW_IDX_TYPE_SWITCH(out->data().type_flag_, IType, {
+    IType *dptr = out->data().dptr<IType>();
+    common::ParallelSort(dptr, dptr + size, omp_get_max_threads());
+    size_t num_unique_idx = std::unique(dptr, dptr + size) - dptr;
+    *out = out->Reshape(mshadow::Shape1(num_unique_idx));
+  });
+}
+
+
+}  // namespace kvstore
+}  // namespace mxnet
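The CPU path above is a parallel sort followed by std::unique. As a reference for the intended semantics, numpy's unique performs the same sort-and-dedup in one step:

    import numpy as np

    row_ids = np.array([4, 1, 4, 0, 1], dtype=np.int64)
    # sorted, duplicates removed: [0, 1, 4]
    assert (np.unique(row_ids) == np.array([0, 1, 4])).all()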
diff --git a/src/kvstore/utils.cu b/src/kvstore/utils.cu
new file mode 100644
index 000000000000..088a49efc808
--- /dev/null
+++ b/src/kvstore/utils.cu
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file utils.cu
+ * \brief gpu implementation of util functions
+ */
+#if defined(_MSC_VER) && __CUDACC_VER_MAJOR__ == 8 && __CUDACC_VER_BUILD__ != 44
+// Many CUDA 8 compilers other than V8.0.44 crash on Windows
+#pragma warning("Potential crash on CUDA compiler detected. Switching sorting from CUB to Thrust")
+#define SORT_WITH_THRUST
+#include <thrust/sort.h>
+#include <thrust/functional.h>
+#include <thrust/execution_policy.h>
+#else
+#undef SORT_WITH_THRUST
+#endif
+#include "./utils.h"
+#include "../common/utils.h"
+#include <cub/device/device_radix_sort.cuh>
+#include <cub/device/device_select.cuh>
+
+namespace mxnet {
+namespace kvstore {
+
+
+template<typename IType>
+size_t UniqueImplGPU(const Resource& rsc, mshadow::Stream<gpu> *s,
+                     IType *dptr, nnvm::dim_t size) {
+#ifndef SORT_WITH_THRUST
+  size_t sort_temp_bytes = 0;
+  cub::DeviceRadixSort::SortKeys(NULL, sort_temp_bytes,
+      dptr, dptr, size, 0, sizeof(IType)*8, mshadow::Stream<gpu>::GetStream(s));
+  mshadow::Tensor<gpu, 1, char> sort_space = rsc
+      .get_space_typed<gpu, 1, char>(
+      mshadow::Shape1(sort_temp_bytes), s);
+  void *sort_temp_storage = static_cast<void*>(sort_space.dptr_);
+  cub::DeviceRadixSort::SortKeys(sort_temp_storage, sort_temp_bytes,
+      dptr, dptr, size, 0, sizeof(IType)*8, mshadow::Stream<gpu>::GetStream(s));
+#else
+  thrust::sort(thrust::cuda::par.on(mshadow::Stream<gpu>::GetStream(s)),
+      dptr, dptr + size, thrust::greater<IType>());
+#endif
+  size_t unique_temp_bytes = 0;
+  mshadow::Tensor<gpu, 1, char> dummy_space = rsc
+      .get_space_typed<gpu, 1, char>(
+      mshadow::Shape1(sizeof(size_t)), s);
+  size_t *dummy_ptr = reinterpret_cast<size_t*>(dummy_space.dptr_);
+  cub::DeviceSelect::Unique(NULL, unique_temp_bytes, dptr, dptr,
+      dummy_ptr, size, mshadow::Stream<gpu>::GetStream(s));
+
+  mshadow::Tensor<gpu, 1, char> unique_space = rsc
+      .get_space_typed<gpu, 1, char>(
+      mshadow::Shape1((unique_temp_bytes + sizeof(size_t) + 7) / 8 * 8), s);
+
+  void *unique_temp_storage = static_cast<void*>(
+      unique_space.dptr_);
+  size_t *d_num_selected_out = reinterpret_cast<size_t*>(
+      unique_space.dptr_ + (unique_temp_bytes + 7) / 8 * 8);
+
+  cub::DeviceSelect::Unique(unique_temp_storage, unique_temp_bytes, dptr, dptr,
+      d_num_selected_out, size, mshadow::Stream<gpu>::GetStream(s));
+
+  size_t num_selected_out = 0;
+  CUDA_CALL(cudaMemcpy(&num_selected_out, d_num_selected_out, sizeof(size_t),
+      cudaMemcpyDeviceToHost));
+  return num_selected_out;
+}
+
+/*!
+ * \brief sort and get unique values.
+ */
+template<>
+void UniqueImpl<gpu>(const Resource& rsc, mshadow::Stream<gpu> *s,
+                     NDArray *out, nnvm::dim_t size) {
+  MSHADOW_IDX_TYPE_SWITCH(out->data().type_flag_, IType, {
+    IType *dptr = out->data().dptr<IType>();
+    size_t num_selected_out = UniqueImplGPU(rsc, s, dptr, size);
+    *out = out->Reshape(mshadow::Shape1(num_selected_out));
+  });
+}
+
+
+}  // namespace kvstore
+}  // namespace mxnet
Switching sorting from CUB to Thrust") +#define SORT_WITH_THRUST +#include +#include +#include +#else +#undef SORT_WITH_THRUST +#endif +#include "./utils.h" +#include "../common/utils.h" +#include +#include + +namespace mxnet { +namespace kvstore { + + +template +size_t UniqueImplGPU(const Resource& rsc, mshadow::Stream *s, + IType *dptr, nnvm::dim_t size) { +#ifndef SORT_WITH_THRUST + size_t sort_temp_bytes = 0; + cub::DeviceRadixSort::SortKeys(NULL, sort_temp_bytes, + dptr, dptr, size, 0, sizeof(IType)*8, mshadow::Stream::GetStream(s)); + mshadow::Tensor sort_space = rsc + .get_space_typed( + mshadow::Shape1(sort_temp_bytes), s); + void *sort_temp_storage = static_cast(sort_space.dptr_); + cub::DeviceRadixSort::SortKeys(sort_temp_storage, sort_temp_bytes, + dptr, dptr, size, 0, sizeof(IType)*8, mshadow::Stream::GetStream(s)); +#else + thrust::sort(thrust::cuda::par.on(mshadow::Stream::GetStream(s)), + dptr, dptr + size, thrust::greater()); +#endif + size_t unique_temp_bytes = 0; + mshadow::Tensor dummy_space = rsc + .get_space_typed( + mshadow::Shape1(sizeof(size_t)), s); + size_t *dummy_ptr = reinterpret_cast(dummy_space.dptr_); + cub::DeviceSelect::Unique(NULL, unique_temp_bytes, dptr, dptr, + dummy_ptr, size, mshadow::Stream::GetStream(s)); + + mshadow::Tensor unique_space = rsc + .get_space_typed( + mshadow::Shape1((unique_temp_bytes + sizeof(size_t) + 7) / 8 * 8), s); + + void *unique_temp_storage = static_cast( + unique_space.dptr_); + size_t *d_num_selected_out = reinterpret_cast( + unique_space.dptr_ + (unique_temp_bytes + 7) / 8 * 8); + + cub::DeviceSelect::Unique(unique_temp_storage, unique_temp_bytes, dptr, dptr, + d_num_selected_out, size, mshadow::Stream::GetStream(s)); + + size_t num_selected_out = 0; + CUDA_CALL(cudaMemcpy(&num_selected_out, d_num_selected_out, sizeof(size_t), + cudaMemcpyDeviceToHost)); + return num_selected_out; +} + +/*! + * \brief sort and get unique values. + */ +template<> +void UniqueImpl(const Resource& rsc, mshadow::Stream *s, + NDArray *out, nnvm::dim_t size) { + MSHADOW_IDX_TYPE_SWITCH(out->data().type_flag_, IType, { + IType *dptr = out->data().dptr(); + size_t num_selected_out = UniqueImplGPU(rsc, s, dptr, size); + *out = out->Reshape(mshadow::Shape1(num_selected_out)); + }); +} + + +} // namespace kvstore +} // namespace mxnet diff --git a/src/kvstore/utils.h b/src/kvstore/utils.h new file mode 100644 index 000000000000..75473452ce00 --- /dev/null +++ b/src/kvstore/utils.h @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file utils.h + * \brief Basic utilility functions. + */ +#ifndef MXNET_KVSTORE_UTILS_H_ +#define MXNET_KVSTORE_UTILS_H_ + +#include +#include +#include +#include +#include + +namespace mxnet { +namespace kvstore { + + +/*! 
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index f09f168977ab..4196a7d5155c 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -606,36 +606,66 @@ void ElementwiseSum(const std::vector<NDArray> &source, NDArray *out, int priority) {
 
   // important: callback must always capture by value
   NDArray ret = *out;
 
-  switch (out->ctx().dev_mask()) {
-    case cpu::kDevMask: {
-      Engine::Get()->PushSync([source, ret](RunContext ctx) {
-          std::vector<TBlob> source_tblob(source.size());
-          for (size_t i = 0; i < source.size(); ++i) {
-            source_tblob[i] = source[i].data();
-          }
-          TBlob tmp = ret.data();
-          ndarray::ElementwiseSum<cpu>(source_tblob, &tmp, ctx);
-        }, out->ctx(), const_vars, {ret.var()},
-        FnProperty::kNormal, priority, PROFILER_MESSAGE_FUNCNAME);
-      break;
+  const NDArrayStorageType stype = ret.storage_type();
+
+  if (stype == kDefaultStorage) {
+    switch (out->ctx().dev_mask()) {
+      case cpu::kDevMask: {
+        Engine::Get()->PushSync([source, ret](RunContext ctx) {
+            std::vector<TBlob> source_tblob(source.size());
+            for (size_t i = 0; i < source.size(); ++i) {
+              source_tblob[i] = source[i].data();
+            }
+            TBlob tmp = ret.data();
+            ndarray::ElementwiseSum<cpu>(source_tblob, &tmp, ctx);
+          }, out->ctx(), const_vars, {ret.var()},
+          FnProperty::kNormal, priority, PROFILER_MESSAGE_FUNCNAME);
+        break;
+      }
+#if MXNET_USE_CUDA
+      case gpu::kDevMask: {
+        Engine::Get()->PushSync([source, ret](RunContext ctx) {
+            std::vector<TBlob> source_tblob(source.size());
+            for (size_t i = 0; i < source.size(); ++i) {
+              source_tblob[i] = source[i].data();
+            }
+            TBlob tmp = ret.data();
+            ndarray::ElementwiseSum<gpu>(source_tblob, &tmp, ctx);
+            // Wait GPU kernel to complete
+            ctx.get_stream<gpu>()->Wait();
+          }, out->ctx(), const_vars, {ret.var()},
+          FnProperty::kNormal, priority, PROFILER_MESSAGE("DenseElementwiseSum"));
+        break;
+      }
+#endif
+      default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
     }
+  } else if (stype == kRowSparseStorage) {
+    Resource rsc = ResourceManager::Get()->Request(ret.ctx(),
+      ResourceRequest(ResourceRequest::kTempSpace));
+
+    Engine::Get()->PushSync(
+      [source, ret, rsc](RunContext rctx) {
+        NDArray result = ret;
+        switch (ret.ctx().dev_mask()) {
+          case cpu::kDevMask: {
+            mxnet::ndarray::ElementwiseSum<cpu>(rctx.get_stream<cpu>(), rsc, source, &result);
+            break;
+          }
 #if MXNET_USE_CUDA
-    case gpu::kDevMask: {
-      Engine::Get()->PushSync([source, ret](RunContext ctx) {
-          std::vector<TBlob> source_tblob(source.size());
-          for (size_t i = 0; i < source.size(); ++i) {
-            source_tblob[i] = source[i].data();
+          case gpu::kDevMask: {
+            mxnet::ndarray::ElementwiseSum<gpu>(rctx.get_stream<gpu>(), rsc, source, &result);
+            // wait for GPU operations to complete
+            rctx.get_stream<gpu>()->Wait();
+            break;
           }
-          TBlob tmp = ret.data();
-          ndarray::ElementwiseSum<gpu>(source_tblob, &tmp, ctx);
-          // Wait GPU kernel to complete
-          ctx.get_stream<gpu>()->Wait();
-        }, out->ctx(), const_vars, {ret.var()},
-        FnProperty::kNormal, priority, PROFILER_MESSAGE_FUNCNAME);
-      break;
-    }
 #endif
-    default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+          default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+        }
+      }, ret.ctx(), const_vars, {ret.var(), rsc.var},
+      FnProperty::kNormal, priority, PROFILER_MESSAGE("RowSparseElementwiseSum"));
+  } else {
+    LOG(FATAL) << "Not implemented for storage_type " << common::stype_string(stype);
   }
 }
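ElementwiseSum now dispatches on the output's storage type. From Python, the result storage of add_n follows its inputs; a sketch, assuming all-dense inputs take the dense branch and all-row_sparse inputs take the new row-sparse branch:

    import mxnet as mx

    dense = [mx.nd.ones((4, 2)) for _ in range(2)]
    sparse = [x.tostype('row_sparse') for x in dense]
    print(mx.nd.add_n(*dense).stype)   # 'default'
    print(mx.nd.add_n(*sparse).stype)  # 'row_sparse'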
diff --git a/tests/python/gpu/test_kvstore_gpu.py b/tests/python/gpu/test_kvstore_gpu.py
index 20528be664fa..3249c98b2d8b 100644
--- a/tests/python/gpu/test_kvstore_gpu.py
+++ b/tests/python/gpu/test_kvstore_gpu.py
@@ -26,9 +26,9 @@
 str_keys = ['b', 'c', 'd']
 
 
-def init_kv_with_str(stype='default'):
+def init_kv_with_str(stype='default', kv_type='local'):
     """init kv """
-    kv = mx.kv.create()
+    kv = mx.kv.create(kv_type)
     # single
     kv.init('a', mx.nd.zeros(shape, stype=stype))
     # list
@@ -36,34 +36,54 @@ def init_kv_with_str(stype='default'):
     return kv
 
 
-def test_row_sparse_pull():
-    kv = init_kv_with_str('row_sparse')
-    kv.init('e', mx.nd.ones(shape).tostype('row_sparse'))
+def test_rsp_push_pull():
+    def check_rsp_push_pull(kv_type, is_push_cpu=True):
+        kv = init_kv_with_str('row_sparse', kv_type)
+        kv.init('e', mx.nd.ones(shape).tostype('row_sparse'))
+        push_ctxs = [mx.cpu(i) if is_push_cpu else mx.gpu(i) for i in range(2)]
+        kv.push('e', [mx.nd.ones(shape, ctx=context).tostype('row_sparse') for context in push_ctxs])
 
-    def check_row_sparse_pull(kv, count, ctx=default_context()):
-        num_rows = shape[0]
-        vals = []
-        row_ids = []
-        all_row_ids = np.arange(num_rows)
-        for i in range(count):
-            vals.append(mx.nd.zeros(shape, ctx=ctx).tostype('row_sparse'))
-            row_id = np.random.randint(num_rows, size=num_rows)
-            row_ids.append(mx.nd.array(row_id, dtype='int64'))
-        row_ids_to_pull = row_ids[0] if len(row_ids) == 1 else row_ids
-        vals_to_pull = vals[0] if len(vals) == 1 else vals
+        def check_rsp_pull(kv, count, ctxs, is_same_rowid=False, use_slice=False):
+            num_rows = shape[0]
+            row_ids = []
+            all_row_ids = np.arange(num_rows)
+            vals = [mx.nd.sparse.zeros(shape=shape, ctx=ctxs[i], stype='row_sparse') for i in range(count)]
+            if is_same_rowid:
+                row_id = np.random.randint(num_rows, size=num_rows)
+                row_ids = [mx.nd.array(row_id, dtype='int64')] * count
+            elif use_slice:
+                total_row_ids = mx.nd.array(np.random.randint(num_rows, size=count*num_rows), dtype='int64')
+                row_ids = [total_row_ids[i*num_rows : (i+1)*num_rows] for i in range(count)]
+            else:
+                for i in range(count):
+                    row_id = np.random.randint(num_rows, size=num_rows)
+                    row_ids.append(mx.nd.array(row_id, dtype='int64'))
+            row_ids_to_pull = row_ids[0] if (len(row_ids) == 1 or is_same_rowid) else row_ids
+            vals_to_pull = vals[0] if len(vals) == 1 else vals
 
-        kv.row_sparse_pull('e', out=vals_to_pull, row_ids=row_ids_to_pull)
-        for val, row_id in zip(vals, row_ids):
-            retained = val.asnumpy()
-            excluded_row_ids = np.setdiff1d(all_row_ids, row_id.asnumpy())
-            for row in range(num_rows):
-                expected_val = np.zeros_like(retained[row])
-                expected_val += 0 if row in excluded_row_ids else 1
-                assert_almost_equal(retained[row], expected_val)
+            kv.row_sparse_pull('e', out=vals_to_pull, row_ids=row_ids_to_pull)
+            for val, row_id in zip(vals, row_ids):
+                retained = val.asnumpy()
+                excluded_row_ids = np.setdiff1d(all_row_ids, row_id.asnumpy())
+                for row in range(num_rows):
+                    expected_val = np.zeros_like(retained[row])
+                    expected_val += 0 if row in excluded_row_ids else 2
+                    assert_almost_equal(retained[row], expected_val)
 
-    check_row_sparse_pull(kv, 1, mx.gpu(0))
-    check_row_sparse_pull(kv, 4, mx.gpu(0))
+        check_rsp_pull(kv, 1, [mx.gpu(0)])
+        check_rsp_pull(kv, 1, [mx.cpu(0)])
+        check_rsp_pull(kv, 4, [mx.gpu(i//2) for i in range(4)])
+        check_rsp_pull(kv, 4, [mx.gpu(i//2) for i in range(4)], is_same_rowid=True)
+        check_rsp_pull(kv, 4, [mx.cpu(i) for i in range(4)])
+        check_rsp_pull(kv, 4, [mx.cpu(i) for i in range(4)], is_same_rowid=True)
+        check_rsp_pull(kv, 4, [mx.gpu(i//2) for i in range(4)], use_slice=True)
+        check_rsp_pull(kv, 4, [mx.cpu(i) for i in range(4)], use_slice=True)
+
+    # test fails intermittently. temporarily disabled till it gets fixed. tracked at
+    # https://github.com/apache/incubator-mxnet/issues/9384
+    # check_rsp_push_pull('local')
+    check_rsp_push_pull('device')
+    check_rsp_push_pull('device', is_push_cpu=False)
 
 
 if __name__ == '__main__':
-    test_row_sparse_pull()
+    test_rsp_push_pull()
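End to end, the new test exercises a multi-device push followed by row_sparse_pull. A condensed sketch of the same flow (shape and row ids are illustrative; pulled rows come out as 2.0 because the two pushed arrays are summed):

    import mxnet as mx

    shape = (4, 4)
    kv = mx.kv.create('device')
    kv.init('e', mx.nd.ones(shape).tostype('row_sparse'))
    kv.push('e', [mx.nd.ones(shape, ctx=mx.cpu(i)).tostype('row_sparse') for i in range(2)])

    out = mx.nd.sparse.zeros('row_sparse', shape)
    kv.row_sparse_pull('e', out=out, row_ids=mx.nd.array([0, 2], dtype='int64'))
    # rows 0 and 2 now hold 2.0; all other rows stay 0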