Skip to content

Commit

Permalink
dash
Browse files Browse the repository at this point in the history
  • Loading branch information
chengduoZH committed Apr 16, 2018
1 parent 21084a7 commit 2455fd3
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 20 deletions.
5 changes: 3 additions & 2 deletions paddle/fluid/framework/details/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,6 @@ cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_ha
device_context broadcast_op_handle)
cc_test(gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
device_context gather_op_handle)
cc_test(reduce_op_handle_test SRCS reduce_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
device_context reduce_op_handle)

#cc_test(reduce_op_handle_test SRCS reduce_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
# device_context reduce_op_handle)
1 change: 0 additions & 1 deletion paddle/fluid/framework/details/gather_op_handle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ void GatherOpHandle::RunImpl() {
local_scopes_[out_scope_idx]->FindVar(out_var_handles[0]->name_);

auto out = out_var->GetMutable<framework::SelectedRows>();

out->set_height(pre_in.height());
out->set_rows(out_rows);
size_t rows = out_rows.size();
Expand Down
5 changes: 3 additions & 2 deletions paddle/fluid/framework/details/reduce_and_gather.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,13 @@ struct GatherSelectedRows {
void operator()(
const std::vector<SelectedRows> &src_selecte_rows_,
const std::vector<platform::Place> &in_places,
const std::map<platform::Place, platform::DeviceContext> &dev_ctxes,
const std::unordered_map<platform::Place, platform::DeviceContext *,
platform::PlaceHash> &dev_ctxes,
SelectedRows *dst_selecte_rows) const {
PADDLE_ENFORCE(!src_selecte_rows_.empty());

std::vector<Tensor> in_tensors;
std::vector<std::vector<int64_t>> out_rows;
std::vector<int64_t> out_rows;

for (auto &in_sr : src_selecte_rows_) {
in_tensors.emplace_back(in_sr.value());
Expand Down
33 changes: 20 additions & 13 deletions paddle/fluid/framework/details/reduce_op_handle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,20 @@ namespace paddle {
namespace framework {
namespace details {

#ifdef PADDLE_WITH_CUDA
ReduceOpHandle::ReduceOpHandle(const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap &ctxs)
: local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) {}
: local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) {
for (auto p_ctx : ctxs.contexts_) {
dev_ctxes_[p_ctx.first] = p_ctx.second.ctx_.get();
}
}
#endif

ReduceOpHandle::ReduceOpHandle(const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places)
: local_scopes_(local_scopes), places_(places) {}

void ReduceOpHandle::RunImpl() {
// the input may have dummy var.
Expand All @@ -51,11 +61,11 @@ void ReduceOpHandle::RunImpl() {
"The number of output should be one.");

// Wait input done, this Wait is asynchronous operation
auto &in_place = in_var_handle[0]->place_;
if (in_var_handle[0]->generated_op_) {
auto &in_place = in_var_handles[0]->place_;
if (in_var_handles[0]->generated_op_) {
for (auto *in : in_var_handles) {
auto &in_p = in->place_;
in_var_handle[0]->generated_op_->Wait(dev_ctxes_[in_p]);
in_var_handles[0]->generated_op_->Wait(dev_ctxes_[in_p]);
}
}

Expand All @@ -64,13 +74,11 @@ void ReduceOpHandle::RunImpl() {
auto pre_place = in_0_handle->place_;

std::vector<platform::Place> in_places;
std::map<platform::Place, platform::DeviceContext> dev_ctxes;
for (auto *in_handle : in_var_handles) {
auto in_p = in_handle->place_;
PADDLE_ENFORCE_EQ(in_p.which(), pre_place.which(),
"Places must be all on CPU or all on CUDA.");
in_places.emplace_back(in_p);
dev_ctxes[in_p] = nccl_ctxs_[in_p];
}

auto out_var = local_scopes_[out_var_handles[0]->scope_idx_]->FindVar(
Expand All @@ -82,6 +90,7 @@ void ReduceOpHandle::RunImpl() {
if (pre_in_var->IsType<framework::SelectedRows>()) {
GatherSelectedRows gather;
std::vector<SelectedRows> in_selected_rows;
auto pre_in = pre_in_var->Get<framework::SelectedRows>();

for (auto *in_handle : in_var_handles) {
auto in_var =
Expand All @@ -93,9 +102,8 @@ void ReduceOpHandle::RunImpl() {

in_selected_rows.emplace_back(in_sr);
}
auto &trg = out_var->GetMutable<framework::SelectedRows>();
gather(in_selected_rows, in_places, dev_ctxes, &trg);

auto trg = out_var->GetMutable<framework::SelectedRows>();
gather(in_selected_rows, in_places, dev_ctxes_, trg);
} else {
// reduce tensor
auto pre_in = pre_in_var->Get<framework::LoDTensor>();
Expand All @@ -110,15 +118,15 @@ void ReduceOpHandle::RunImpl() {
PADDLE_ENFORCE_EQ(in_sr.type(), pre_in.type(),
"The type of input is not consistent.");

lod_tensors.emplace_back(in_sr.value());
lod_tensors.emplace_back(in_sr);
}

auto &trg = out_var->GetMutable<framework::LoDTensor>();
auto trg = out_var->GetMutable<framework::LoDTensor>();
trg->Resize(pre_in.dims());
trg->mutable_data(out_var_handles[0]->place_, pre_in.type());

if (paddle::platform::is_cpu_place(pre_place)) {
ReduceLoDTensor func(lod_tensors, &trg);
ReduceLoDTensor func(lod_tensors, trg);
VisitDataType(ToDataType(lod_tensors[0].type()), func);

} else if (paddle::platform::is_gpu_place(pre_place)) {
Expand Down Expand Up @@ -162,7 +170,6 @@ void ReduceOpHandle::RunImpl() {
#else
PADDLE_THROW("CUDA is not support.");
#endif

} else {
PADDLE_THROW("Error");
}
Expand Down
8 changes: 7 additions & 1 deletion paddle/fluid/framework/details/reduce_op_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/nccl_helper.h"

namespace paddle {
namespace framework {
Expand All @@ -31,11 +32,16 @@ namespace details {
struct ReduceOpHandle : public OpHandleBase {
const std::vector<Scope *> &local_scopes_;
const std::vector<platform::Place> &places_;
const platform::NCCLContextMap &nccl_ctxs_;

#ifdef PADDLE_WITH_CUDA
const platform::NCCLContextMap &nccl_ctxs_;
ReduceOpHandle(const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap &nccl_ctxs_);
#endif

ReduceOpHandle(const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places);

std::string Name() const override;

Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/details/reduce_op_handle_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ TEST(ReduceTester, TestCPUReduceTestSelectedRows) {
TEST(ReduceTester, TestGPUReduceTestSelectedRows) {
TestReduceOpHandle test_op;
size_t input_scope_idx = 0;
test_op.InitCtxOnGpu(false);
test_op.InitCtxOnGpu(true);
test_op.InitReduceOp(input_scope_idx);
test_op.TestReduceSelectedRows(input_scope_idx);
}
Expand Down

0 comments on commit 2455fd3

Please sign in to comment.