add fuse var op handle
chengduoZH committed Jun 6, 2018
1 parent 9dc3ed4 commit a584bc8
Showing 7 changed files with 149 additions and 13 deletions.
3 changes: 2 additions & 1 deletion paddle/fluid/framework/details/CMakeLists.txt
@@ -12,7 +12,7 @@ cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_ro

if(WITH_GPU)
nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
-          dynload_cuda)
+          dynload_cuda variable_visitor)
set(multi_devices_graph_builder_deps nccl_all_reduce_op_handle)
nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda)
nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda)
@@ -24,6 +24,7 @@ else()
endif()

cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
+cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)

cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
51 changes: 51 additions & 0 deletions paddle/fluid/framework/details/fuse_vars_op_handle.cc
@@ -0,0 +1,51 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/fuse_vars_op_handle.h"

namespace paddle {
namespace framework {
namespace details {

void FuseVarsOpHandle::RunImpl() {
WaitInputVarGenerated(place_);

auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
PADDLE_ENFORCE_EQ(in_var_handles.size(), 0);
PADDLE_ENFORCE_EQ(out_var_handles.size() - 1, inputs_numel_.size(), "");

auto scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();

auto out_var_handle = out_var_handles[0];
auto out_var = scope->Var(out_var_handle->name_);

auto out_tensor = out_var->GetMutable<LoDTensor>();
out_tensor->Resize({total_numel_}).mutable_data(this->place_, type_);

int64_t s = 0;
for (size_t i = 1; i < out_var_handles.size(); ++i) {
auto out_name = out_var_handles[i]->name_;
auto out_t = scope->Var(out_name)->GetMutable<LoDTensor>();
auto numel = this->inputs_numel_.at(out_name);
out_t->ShareDataWith(out_tensor->Slice(s, s + numel));
s += numel;
}
this->RunAndRecordEvent([this] {});
}

std::string FuseVarsOpHandle::Name() const { return "fuse vars"; }
} // namespace details
} // namespace framework
} // namespace paddle
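
The pattern implemented above (one contiguous allocation whose sub-ranges back the individual output variables through ShareDataWith on slices) can be sketched in plain C++ without the Paddle types. This is a minimal illustration only; FusedBuffer and its members are hypothetical names, and a real tensor would also carry dtype and place:

#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical sketch of the fusion idea: one fused allocation, with each
// named variable viewing the half-open slice [offset, offset + numel),
// mirroring out_t->ShareDataWith(out_tensor->Slice(s, s + numel)) above.
struct FusedBuffer {
  std::vector<float> storage;                      // the fused allocation
  std::unordered_map<std::string, float *> views;  // per-variable slices

  explicit FusedBuffer(
      const std::unordered_map<std::string, int64_t> &inputs_numel) {
    int64_t total = 0;
    for (const auto &kv : inputs_numel) total += kv.second;
    storage.resize(static_cast<size_t>(total));

    int64_t offset = 0;
    for (const auto &kv : inputs_numel) {
      views[kv.first] = storage.data() + offset;  // shares the fused memory
      offset += kv.second;
    }
  }
};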
63 changes: 63 additions & 0 deletions paddle/fluid/framework/details/fuse_vars_op_handle.h
@@ -0,0 +1,63 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <map>
#include <string>
#include <vector>

#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace framework {
namespace details {

struct FuseVarsOpHandle : public OpHandleBase {
public:
FuseVarsOpHandle(Scope *local_scope, const platform::Place &place,
const std::unordered_map<std::string, int64_t> &inputs_numel,
const std::type_index &var_type)
: local_scope_(local_scope),
place_(place),
inputs_numel_(inputs_numel),
type_(var_type) {
total_numel_ = 0;
for (auto in_numel : inputs_numel) {
PADDLE_ENFORCE_GT(in_numel.second, 0);
total_numel_ += in_numel.second;
}
}

std::string Name() const override;

bool IsMultiDeviceTransfer() override { return false; };

protected:
void RunImpl() override;

private:
Scope *local_scope_;
const platform::Place place_;
const std::unordered_map<std::string, int64_t> inputs_numel_;
const std::type_index type_;
int64_t total_numel_;
};
} // namespace details
} // namespace framework
} // namespace paddle
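
A hypothetical call site, assuming a graph builder that already holds a local_scope and a place (this snippet is not part of the commit):

// Hypothetical usage (not from this commit): fuse two float variables of
// 256 and 512 elements into one contiguous buffer on `place`. Per RunImpl
// above, the first output handle names the fused buffer and the remaining
// output handles name the individual slice variables.
std::unordered_map<std::string, int64_t> inputs_numel{{"w1@GRAD", 256},
                                                      {"w2@GRAD", 512}};
auto *fuse_op = new FuseVarsOpHandle(local_scope, place, inputs_numel,
                                     std::type_index(typeid(float)));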
29 changes: 19 additions & 10 deletions paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
@@ -11,10 +11,12 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
+#include <algorithm>
+
+#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
-#include <algorithm>
#include "paddle/fluid/framework/details/reduce_and_gather.h"
+#include "paddle/fluid/framework/details/variable_visitor.h"

namespace paddle {
namespace framework {
@@ -30,27 +32,34 @@ NCCLAllReduceOpHandle::NCCLAllReduceOpHandle(
}

void NCCLAllReduceOpHandle::RunImpl() {
-  if (inputs_.size() == 1) {
+  if (NoDummyInputSize() == 1) {
return; // No need to all reduce when GPU count = 1;
} else {
// Wait input done
WaitInputVarGenerated();

-    auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_;
-    int dtype = -1;
-    size_t numel = 0;
+    auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
+    auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
+    PADDLE_ENFORCE_EQ(
+        in_var_handles.size(), places_.size(),
+        "The NoDummyInputSize should be equal to the number of places.");
+    PADDLE_ENFORCE_EQ(
+        in_var_handles.size(), out_var_handles.size(),
+        "The NoDummyInputSize and NoDummyOutputSize should be equal.");

std::vector<const LoDTensor *> lod_tensors;

for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto *s = local_scopes_[i];
auto &local_scope = *s->FindVar(kLocalExecScopeName)->Get<Scope *>();

-      auto &lod_tensor = local_scope.FindVar(var_name)->Get<LoDTensor>();
+      auto &lod_tensor =
+          local_scope.FindVar(in_var_handles[i]->name_)->Get<LoDTensor>();
      lod_tensors.emplace_back(&lod_tensor);
+      PADDLE_ENFORCE_EQ(in_var_handles[i]->name_, out_var_handles[i]->name_,
+                        "The name of input and output should be equal.");
}

if (platform::is_gpu_place(lod_tensors[0]->place())) {
+      int dtype = -1;
+      size_t numel = 0;
std::vector<std::function<void()>> all_reduce_calls;
for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto &p = places_[i];
@@ -96,7 +105,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
auto &scope =
*local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto &p = places_[i];
-      auto *var = scope.FindVar(var_name);
+      auto *var = scope.FindVar(in_var_handles[i]->name_);
auto *dev_ctx = dev_ctxes_[p];

RunAndRecordEvent(p, [&trg, var, dev_ctx, p] {
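
The RunImpl changes above replace the single shared variable name with a per-place lookup and add two sanity checks. Restated as plain C++ with hypothetical names, the invariants the new PADDLE_ENFORCE_EQ calls encode are:

#include <cassert>
#include <string>
#include <vector>

// Hypothetical restatement: one non-dummy input per place, paired
// one-to-one with an output of the same name (the reduction is in place).
void CheckAllReduceInvariants(const std::vector<std::string> &in_names,
                              const std::vector<std::string> &out_names,
                              size_t num_places) {
  assert(in_names.size() == num_places);        // one input per device
  assert(in_names.size() == out_names.size());  // inputs pair with outputs
  for (size_t i = 0; i < in_names.size(); ++i) {
    assert(in_names[i] == out_names[i]);        // same variable, in place
  }
}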
4 changes: 2 additions & 2 deletions paddle/fluid/framework/details/nccl_all_reduce_op_handle.h
@@ -41,8 +41,8 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
void RunImpl() override;

private:
-  const std::vector<Scope *> &local_scopes_;
-  const std::vector<platform::Place> &places_;
+  const std::vector<Scope *> local_scopes_;
+  const std::vector<platform::Place> places_;
const platform::NCCLContextMap &nccl_ctxs_;
};

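
This header change stores local_scopes_ and places_ by value instead of by const reference, presumably so the handle's lifetime no longer depends on the vectors passed to its constructor. A minimal sketch of the hazard the copies avoid, using hypothetical types rather than Paddle code:

#include <vector>

struct ByRef {
  const std::vector<int> &v;  // dangles once the referenced vector dies
};

struct ByValue {
  std::vector<int> v;  // owns its storage; safe to keep indefinitely
};

ByRef MakeByRef() {
  std::vector<int> tmp{1, 2, 3};
  return ByRef{tmp};  // bug: the returned object refers to a dead vector
}

ByValue MakeByValue() {
  std::vector<int> tmp{1, 2, 3};
  return ByValue{tmp};  // copies tmp; the returned object is self-contained
}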
10 changes: 10 additions & 0 deletions paddle/fluid/framework/details/op_handle_base.cc
@@ -104,6 +104,16 @@ void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) {
}
}

+size_t OpHandleBase::NoDummyInputSize() const {
+  size_t cnt = 0;
+  for (auto *in : inputs_) {
+    if (dynamic_cast<DummyVarHandle *>(in) == nullptr) {
+      ++cnt;
+    }
+  }
+  return cnt;
+}
+
bool OpHandleBase::NeedWait(VarHandleBase *in_var) {
return in_var && in_var->generated_op_;
}
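
For reference, the same count can be written with std::count_if; an equivalent sketch, not the committed code:

#include <algorithm>

size_t OpHandleBase::NoDummyInputSize() const {
  return static_cast<size_t>(std::count_if(
      inputs_.begin(), inputs_.end(), [](VarHandleBase *in) {
        return dynamic_cast<DummyVarHandle *>(in) == nullptr;
      }));
}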
2 changes: 2 additions & 0 deletions paddle/fluid/framework/details/op_handle_base.h
@@ -80,6 +80,8 @@ class OpHandleBase {

const std::vector<VarHandleBase *> &Outputs() const { return outputs_; }

+  size_t NoDummyInputSize() const;
+
protected:
void RunAndRecordEvent(const std::function<void()> &callback);

