
Enable CPU on Parallel executor #11306

Merged: 13 commits, Jun 11, 2018
8 changes: 4 additions & 4 deletions paddle/fluid/framework/details/CMakeLists.txt
@@ -13,14 +13,14 @@ cc_library(ssa_graph_checker SRCS ssa_graph_checker.cc DEPS ssa_graph_builder)
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)

if(WITH_GPU)
- nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
+ nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda variable_visitor)
- set(multi_devices_graph_builder_deps nccl_all_reduce_op_handle)
nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda)
nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda)

else()
- set(multi_devices_graph_builder_deps)
+ cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
+ variable_visitor)
cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim)
cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
endif()
@@ -29,7 +29,7 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d
cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)

cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
- scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
+ scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle)


cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker)
paddle/fluid/framework/details/{nccl_all_reduce_op_handle.cc → all_reduce_op_handle.cc}
@@ -13,25 +13,33 @@
// limitations under the License.
#include <algorithm>

#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/reduce_and_gather.h"
#include "paddle/fluid/framework/details/variable_visitor.h"

namespace paddle {
namespace framework {
namespace details {
- NCCLAllReduceOpHandle::NCCLAllReduceOpHandle(
- const std::vector<Scope *> &local_scopes,
- const std::vector<platform::Place> &places,
- const platform::NCCLContextMap &ctxs)
+ #ifdef PADDLE_WITH_CUDA
+ AllReduceOpHandle::AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
+ const std::vector<platform::Place> &places,
+ const platform::NCCLContextMap *ctxs)
: local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) {
- for (auto &p : places_) {
- this->dev_ctxes_[p] = nccl_ctxs_.DevCtx(p);
+ if (nccl_ctxs_) {
+ for (auto &p : places_) {
+ this->dev_ctxes_[p] = nccl_ctxs_->DevCtx(p);
+ }
}
}
+ #else
+ AllReduceOpHandle::AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
+ const std::vector<platform::Place> &places)
+ : local_scopes_(local_scopes), places_(places) {}
+ #endif

- void NCCLAllReduceOpHandle::RunImpl() {
+ void AllReduceOpHandle::RunImpl() {
if (NoDummyInputSize() == 1) {
return; // No need to all reduce when GPU count = 1;
} else {
@@ -58,6 +66,8 @@ void NCCLAllReduceOpHandle::RunImpl() {
}

if (platform::is_gpu_place(lod_tensors[0]->place())) {
+ #ifdef PADDLE_WITH_CUDA
+ PADDLE_ENFORCE(nccl_ctxs_, "nccl_ctxs should not be nullptr.");
int dtype = -1;
size_t numel = 0;
std::vector<std::function<void()>> all_reduce_calls;
@@ -75,7 +85,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
}

int dev_id = boost::get<platform::CUDAPlace>(p).device;
- auto &nccl_ctx = nccl_ctxs_.at(dev_id);
+ auto &nccl_ctx = nccl_ctxs_->at(dev_id);
auto stream = nccl_ctx.stream();
auto comm = nccl_ctx.comm_;
all_reduce_calls.emplace_back([=] {
@@ -90,22 +100,25 @@ void NCCLAllReduceOpHandle::RunImpl() {
call();
}
});
+ #else
+ PADDLE_THROW("Not compiled with CUDA");
+ #endif
} else { // Special handle CPU only Operator's gradient. Like CRF
auto &trg = *this->local_scopes_[0]
->FindVar(kLocalExecScopeName)
->Get<Scope *>()
- ->Var()
+ ->FindVar(out_var_handles[0]->name_)
->GetMutable<framework::LoDTensor>();

// Reduce All Tensor to trg in CPU
ReduceLoDTensor func(lod_tensors, &trg);
VisitDataType(ToDataType(lod_tensors[0]->type()), func);

- for (size_t i = 0; i < local_scopes_.size(); ++i) {
+ for (size_t i = 1; i < local_scopes_.size(); ++i) {
auto &scope =
*local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto &p = places_[i];
- auto *var = scope.FindVar(in_var_handles[i]->name_);
+ auto *var = scope.FindVar(out_var_handles[i]->name_);
auto *dev_ctx = dev_ctxes_[p];

RunAndRecordEvent(p, [&trg, var, dev_ctx, p] {
@@ -118,7 +131,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
}
}

- std::string NCCLAllReduceOpHandle::Name() const { return "nccl_all_reduce"; }
+ std::string AllReduceOpHandle::Name() const { return "all_reduce"; }
} // namespace details
} // namespace framework
} // namespace paddle
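
Note: the CPU branch above is the core of this change. All-reduce on CPU degenerates into a serial reduce into scope 0's output variable followed by a copy into scopes 1..N-1, which is why the loop now starts at i = 1 and the lookup switches from in_var_handles to out_var_handles. A minimal sketch of the same pattern, with std::vector<float> standing in for LoDTensor (illustrative names, not Paddle APIs):

    #include <cstddef>
    #include <vector>

    // Sum every device's gradient copy into the first buffer, then
    // mirror the result back to the remaining buffers.
    void CpuAllReduce(std::vector<std::vector<float>> *grads) {
      std::vector<float> &trg = (*grads)[0];  // plays the role of out_var_handles[0]
      for (std::size_t i = 1; i < grads->size(); ++i) {
        for (std::size_t j = 0; j < trg.size(); ++j) {
          trg[j] += (*grads)[i][j];
        }
      }
      // Scope 0 already holds the reduced value, so copying starts at i = 1.
      for (std::size_t i = 1; i < grads->size(); ++i) {
        (*grads)[i] = trg;
      }
    }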
paddle/fluid/framework/details/{nccl_all_reduce_op_handle.h → all_reduce_op_handle.h}
@@ -20,17 +20,23 @@
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
+ #ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
+ #endif

namespace paddle {
namespace framework {
namespace details {

- struct NCCLAllReduceOpHandle : public OpHandleBase {
- NCCLAllReduceOpHandle(const std::vector<Scope *> &local_scopes,
- const std::vector<platform::Place> &places,
- const platform::NCCLContextMap &ctxs);
+ struct AllReduceOpHandle : public OpHandleBase {
+ #ifdef PADDLE_WITH_CUDA
+ AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
+ const std::vector<platform::Place> &places,
+ const platform::NCCLContextMap *ctxs);
+ #else
+ AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
+ const std::vector<platform::Place> &places);
+ #endif
std::string Name() const override;

// Delay and buffer nccl_all_reduce together can significantly increase
@@ -43,7 +49,9 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
private:
std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_;
- const platform::NCCLContextMap &nccl_ctxs_;
+ #ifdef PADDLE_WITH_CUDA
+ const platform::NCCLContextMap *nccl_ctxs_;
+ #endif
};

} // namespace details
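Note: the pointer member is what lets a single handle type serve both builds. A reference member (the old const platform::NCCLContextMap &nccl_ctxs_) must always be bound, while a pointer can stay null when a CUDA build runs on CPU places. A hedged sketch of that member design in isolation (CommBackend and Handle are invented for illustration):

    struct CommBackend;  // stand-in for platform::NCCLContextMap

    class Handle {
     public:
      // A pointer member may legitimately be null; a reference cannot.
      explicit Handle(const CommBackend *comm) : comm_(comm) {}

      bool HasComm() const { return comm_ != nullptr; }

     private:
      const CommBackend *comm_;  // null on CPU-only execution paths
    };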
2 changes: 1 addition & 1 deletion paddle/fluid/framework/details/execution_strategy.h
@@ -20,7 +20,7 @@ namespace details {

struct ExecutionStrategy {
size_t num_threads_{0};
- bool use_event_{true};
+ bool use_cuda_{true};
bool allow_op_delay_{false};
size_t num_iteration_per_drop_scope_{100};
};
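Note: the rename from use_event_ to use_cuda_ matches what the flag actually gates: OpHandleBase::Run records CUDA events only when it is true, and a CPU-only binary rejects a true value outright. A sketch of configuring a CPU run, assuming (as the initializers above suggest) that num_threads_ == 0 means "let the executor pick a default":

    #include "paddle/fluid/framework/details/execution_strategy.h"

    paddle::framework::details::ExecutionStrategy MakeCpuStrategy() {
      paddle::framework::details::ExecutionStrategy strategy;
      strategy.use_cuda_ = false;  // formerly use_event_; no CUDA events on CPU
      strategy.num_threads_ = 0;   // assumed: 0 defers to the executor's default
      return strategy;
    }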
51 changes: 29 additions & 22 deletions paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -17,6 +17,7 @@
#include <utility>
#include <vector>

#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
@@ -26,10 +27,6 @@
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/scope.h"

- #ifdef PADDLE_WITH_CUDA
- #include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
- #endif

namespace paddle {
namespace framework {
namespace details {
@@ -243,7 +240,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
CreateReduceOp(&result, g_name, 0);
CreateBroadcastOp(&result, g_name, 0);
} else {
- InsertNCCLAllReduceOp(&result, g_name);
+ InsertAllReduceOp(&result, g_name);
}
break;
}
@@ -286,6 +283,19 @@ bool MultiDevSSAGraphBuilder::IsSparseGradient(
return false;
}

+ void MultiDevSSAGraphBuilder::SetCommunicationContext(
+ OpHandleBase *op_handle, const platform::Place &p) const {
+ #ifdef PADDLE_WITH_CUDA
+ if (nccl_ctxs_ == nullptr) {
+ op_handle->SetDeviceContext(p,
+ platform::DeviceContextPool::Instance().Get(p));
+ }
+ #else
+ op_handle->SetDeviceContext(p,
+ platform::DeviceContextPool::Instance().Get(p));
+ #endif
+ }

void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
const std::string &p_name,
size_t src_dev_id) const {
@@ -300,15 +310,12 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
op_handle->AddInput(in);

for (size_t i = 0; i < places_.size(); ++i) {
- auto &vars = result->vars_.at(i).at(p_name);
auto &p = places_[i];
+ SetCommunicationContext(op_handle, p);
+ auto &vars = result->vars_.at(i).at(p_name);
auto *out_var = new VarHandle(vars.size(), i, p_name, p);
vars.emplace_back(out_var);
op_handle->AddOutput(out_var);
- #ifndef ADDLE_WITH_CUDA
- op_handle->SetDeviceContext(p,
- platform::DeviceContextPool::Instance().Get(p));
- #endif
}
}

@@ -320,15 +327,19 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result,
CreateOpHandleIOs(result, op, dev_id);
}

- void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
- SSAGraph *result, const std::string &og) const {
+ void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result,
+ const std::string &og) const {
#ifdef PADDLE_WITH_CUDA
result->ops_.emplace_back(
- new NCCLAllReduceOpHandle(local_scopes_, places_, *nccl_ctxs_));
+ new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
+ #else
+ result->ops_.emplace_back(new AllReduceOpHandle(local_scopes_, places_));
+ #endif
auto *op_handle = result->ops_.back().get();

for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
+ SetCommunicationContext(op_handle, p);
auto &vars = result->vars_[i][og];
PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back();
@@ -338,9 +349,6 @@ void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
vars.emplace_back(var);
op_handle->AddOutput(var);
}
- #else
- PADDLE_ENFORCE("Not implemented");
- #endif
}

bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
@@ -379,7 +387,9 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
for (size_t i = 0; i < places_.size(); ++i) {
// Insert ScaleCost OpHandle
#ifdef PADDLE_WITH_CUDA
- auto *communication_dev_ctx = nccl_ctxs_->DevCtx(places_[i]);
+ auto *communication_dev_ctx =
+ nccl_ctxs_ ? nccl_ctxs_->DevCtx(places_[i])
+ : platform::DeviceContextPool::Instance().Get(places_[i]);
#else
auto *communication_dev_ctx =
platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
@@ -424,12 +434,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
auto *op_handle = result->ops_.back().get();

for (size_t i = 0; i < places_.size(); ++i) {
- auto &vars = result->vars_[i][og];
- #ifndef PADDLE_WITH_CUDA
auto &p = places_[i];
- op_handle->SetDeviceContext(p,
- platform::DeviceContextPool::Instance().Get(p));
- #endif
+ SetCommunicationContext(op_handle, p);
+ auto &vars = result->vars_[i][og];
PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad.get());
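Note: SetCommunicationContext centralizes the device-context choice that used to hide behind scattered #ifndef guards, including the one with the broken ADDLE_WITH_CUDA spelling that this hunk deletes. On a CUDA build an op handle keeps the NCCL device contexts installed by its constructor unless there is no NCCL map; on a CPU build it always gets the plain per-place context. The same decision as a standalone sketch, with Pool and NcclMap as stand-ins for platform::DeviceContextPool and platform::NCCLContextMap:

    struct DeviceContext {};
    struct Place {};

    struct NcclMap {
      mutable DeviceContext ctx;
      DeviceContext *DevCtx(const Place &) const { return &ctx; }
    };

    struct Pool {
      static Pool &Instance() { static Pool pool; return pool; }
      DeviceContext *Get(const Place &) { return &ctx; }
      DeviceContext ctx;
    };

    // The branch that SetCommunicationContext and CreateScaleLossGradOp encode:
    DeviceContext *PickCommContext(const Place &p, const NcclMap *nccl_ctxs) {
    #ifdef PADDLE_WITH_CUDA
      // Prefer the NCCL stream when a context map exists; otherwise
      // fall back to the ordinary per-place device context.
      return nccl_ctxs ? nccl_ctxs->DevCtx(p) : Pool::Instance().Get(p);
    #else
      (void)nccl_ctxs;  // no NCCL in a CPU-only build
      return Pool::Instance().Get(p);
    #endif
    }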
5 changes: 4 additions & 1 deletion paddle/fluid/framework/details/multi_devices_graph_builder.h
@@ -100,7 +100,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
const OpDesc &op) const;

- void InsertNCCLAllReduceOp(SSAGraph *result, const std::string &og) const;
+ void InsertAllReduceOp(SSAGraph *result, const std::string &og) const;

void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
size_t src_dev_id) const;
@@ -111,6 +111,9 @@

private:
BuildStrategy strategy_;

+ void SetCommunicationContext(OpHandleBase *op_handle,
+ const platform::Place &p) const;
};
} // namespace details
} // namespace framework
6 changes: 3 additions & 3 deletions paddle/fluid/framework/details/op_handle_base.cc
@@ -39,9 +39,9 @@ OpHandleBase::~OpHandleBase() {
#endif
}

- void OpHandleBase::Run(bool use_event) {
+ void OpHandleBase::Run(bool use_cuda) {
#ifdef PADDLE_WITH_CUDA
- if (events_.empty() && use_event) {
+ if (events_.empty() && use_cuda) {
for (auto &p : dev_ctxes_) {
int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
PADDLE_ENFORCE(cudaSetDevice(dev_id));
@@ -50,7 +50,7 @@ void OpHandleBase::Run(bool use_event) {
}
}
#else
- PADDLE_ENFORCE(!use_event);
+ PADDLE_ENFORCE(!use_cuda);
#endif

RunImpl();
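Note: under the old name it was easy to read the flag as a pure optimization toggle; it actually marks CUDA execution, where Run lazily creates one event per device context so that downstream ops can wait on this op's completion. A condensed sketch of the contract, with RecordEventsOnce and RunImpl standing in for the real members:

    #include <cassert>

    struct OpSketch {
      void Run(bool use_cuda) {
    #ifdef PADDLE_WITH_CUDA
        if (use_cuda) {
          RecordEventsOnce();  // lazily create one event per device context
        }
    #else
        assert(!use_cuda);  // a CPU-only binary must never run in CUDA mode
    #endif
        RunImpl();
      }

      void RecordEventsOnce() {}  // placeholder for the cudaEventCreate loop
      void RunImpl() {}           // placeholder for the op body
    };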
2 changes: 1 addition & 1 deletion paddle/fluid/framework/details/op_handle_base.h
@@ -36,7 +36,7 @@ class OpHandleBase {

virtual std::string Name() const = 0;

- void Run(bool use_event);
+ void Run(bool use_cuda);

virtual void RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx);

4 changes: 3 additions & 1 deletion paddle/fluid/framework/details/reduce_and_gather.h
@@ -37,7 +37,9 @@ struct ReduceLoDTensor {
PADDLE_ENFORCE_NE(t0.numel(), 0);
dst_tensor_.Resize(t0.dims());
T *dst = dst_tensor_.mutable_data<T>(platform::CPUPlace());
- std::copy(t0.data<T>(), t0.data<T>() + t0.numel(), dst);
+ if (dst != t0.data<T>()) {
+ std::copy(t0.data<T>(), t0.data<T>() + t0.numel(), dst);
+ }

for (size_t i = 1; i < src_tensors_.size(); ++i) {
auto &t = *src_tensors_[i];
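Note: the guard matters on the new CPU all-reduce path, where the reduce target trg can be the very buffer that already holds t0's data (scope 0 now reduces into its own gradient variable). std::copy requires that the destination not lie within the source range, so dst == t0.data<T>() would be formally undefined behavior as well as wasted work. A small self-contained illustration:

    #include <algorithm>
    #include <cstddef>

    // Seed dst with src unless dst already aliases src, mirroring the
    // check added to ReduceLoDTensor above.
    void SeedAccumulator(const float *src, std::size_t n, float *dst) {
      if (dst != src) {
        std::copy(src, src + n, dst);
      }
      // When dst == src the buffer already holds the first tensor's
      // values, so accumulation of the remaining tensors can start at once.
    }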
6 changes: 5 additions & 1 deletion paddle/fluid/framework/details/ssa_graph_builder_factory.h
@@ -40,7 +40,11 @@ class SSAGraphBuilderFactory {
loss_var_name_(loss_var_name),
param_names_(param_names),
local_scopes_(local_scopes),
- strategy_(strategy) {}
+ strategy_(strategy) {
+ #ifdef PADDLE_WITH_CUDA
+ nccl_ctxs_ = nullptr;
+ #endif
+ }

#ifdef PADDLE_WITH_CUDA
void SetNCCLContextMap(platform::NCCLContextMap* nccl_ctxs) {
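Note: explicitly zeroing nccl_ctxs_ in the constructor turns "no NCCL" into an observable state instead of an indeterminate pointer. SetNCCLContextMap is only called on the GPU path, and the builder code above branches on the null. A hedged usage sketch for a PADDLE_WITH_CUDA build (Create() as the factory's build step is assumed from the surrounding code; on a CPU-only build the SetNCCLContextMap call simply does not exist):

    std::unique_ptr<SSAGraphBuilder> BuildGraphBuilder(
        SSAGraphBuilderFactory *factory, bool use_cuda,
        platform::NCCLContextMap *nccl_ctxs) {
      if (use_cuda) {
        factory->SetNCCLContextMap(nccl_ctxs);  // GPU path only
      }
      // When use_cuda is false, nccl_ctxs_ stays nullptr and
      // SetCommunicationContext installs plain per-place device contexts.
      return factory->Create();
    }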
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
@@ -193,7 +193,7 @@ void ThreadedSSAGraphExecutor::RunOp(
if (VLOG_IS_ON(10)) {
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
}
- op->Run(strategy_.use_event_);
+ op->Run(strategy_.use_cuda_);
VLOG(10) << op << " " << op->Name() << " Done ";
running_ops_--;
ready_var_q->Extend(op->Outputs());