Skip to content

Commit

Permalink
Merge pull request #9043 from Xreki/core_inference_remove_clone
Browse files Browse the repository at this point in the history
Remove unnecessary clone of program in C++ Executor.Run
  • Loading branch information
luotao1 authored Mar 19, 2018
2 parents df99b16 + 371c53f commit c042137
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 31 deletions.
33 changes: 20 additions & 13 deletions paddle/fluid/framework/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,11 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
// and feed_holder_name. Raise exception when any mismatch is found.
// Return true if the block has feed operators and holder of matching info.
static bool has_feed_operators(
BlockDesc* block, std::map<std::string, const LoDTensor*>& feed_targets,
const BlockDesc& block,
std::map<std::string, const LoDTensor*>& feed_targets,
const std::string& feed_holder_name) {
size_t feed_count = 0;
for (auto* op : block->AllOps()) {
for (auto* op : block.AllOps()) {
if (op->Type() == kFeedOpType) {
feed_count++;
PADDLE_ENFORCE_EQ(op->Input("X")[0], feed_holder_name,
Expand All @@ -135,7 +136,7 @@ static bool has_feed_operators(
"The number of feed operators should match 'feed_targets'");

// When feed operator are present, so should be feed_holder
auto var = block->FindVar(feed_holder_name);
auto var = block.FindVar(feed_holder_name);
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
feed_holder_name);
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH,
Expand All @@ -153,10 +154,10 @@ static bool has_feed_operators(
// and fetch_holder_name. Raise exception when any mismatch is found.
// Return true if the block has fetch operators and holder of matching info.
static bool has_fetch_operators(
BlockDesc* block, std::map<std::string, LoDTensor*>& fetch_targets,
const BlockDesc& block, std::map<std::string, LoDTensor*>& fetch_targets,
const std::string& fetch_holder_name) {
size_t fetch_count = 0;
for (auto* op : block->AllOps()) {
for (auto* op : block.AllOps()) {
if (op->Type() == kFetchOpType) {
fetch_count++;
PADDLE_ENFORCE_EQ(op->Output("Out")[0], fetch_holder_name,
Expand All @@ -175,7 +176,7 @@ static bool has_fetch_operators(
"The number of fetch operators should match 'fetch_targets'");

// When fetch operator are present, so should be fetch_holder
auto var = block->FindVar(fetch_holder_name);
auto var = block.FindVar(fetch_holder_name);
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
fetch_holder_name);
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST,
Expand All @@ -192,10 +193,19 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
const std::string& feed_holder_name,
const std::string& fetch_holder_name) {
platform::RecordBlock b(kProgramId);
auto* copy_program = new ProgramDesc(program);
bool has_feed_ops =
has_feed_operators(program.Block(0), feed_targets, feed_holder_name);
bool has_fetch_ops =
has_fetch_operators(program.Block(0), fetch_targets, fetch_holder_name);

ProgramDesc* copy_program = const_cast<ProgramDesc*>(&program);
if (!has_feed_ops || !has_fetch_ops) {
copy_program = std::unique_ptr<ProgramDesc>(new ProgramDesc(program)).get();
}

auto* global_block = copy_program->MutableBlock(0);

if (!has_feed_operators(global_block, feed_targets, feed_holder_name)) {
if (!has_feed_ops) {
// create feed_holder variable
auto* feed_holder = global_block->Var(feed_holder_name);
feed_holder->SetType(proto::VarType::FEED_MINIBATCH);
Expand Down Expand Up @@ -228,7 +238,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
}
}

if (!has_fetch_operators(global_block, fetch_targets, fetch_holder_name)) {
if (!has_fetch_ops) {
// create fetch_holder variable
auto* fetch_holder = global_block->Var(fetch_holder_name);
fetch_holder->SetType(proto::VarType::FETCH_LIST);
Expand Down Expand Up @@ -262,8 +272,6 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
GetFetchVariable(*scope, fetch_holder_name, idx);
}
}

delete copy_program;
}

ExecutorPrepareContext* Executor::Prepare(const ProgramDesc& program,
Expand Down Expand Up @@ -313,9 +321,8 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
} // if (create_vars)

for (auto& op : ctx->ops_) {
VLOG(4) << place_ << " " << op->DebugStringEx(local_scope);
op->Run(*local_scope, place_);
VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
op->Run(*local_scope, place_);

if (FLAGS_benchmark) {
VLOG(2) << "Memory used after operator " + op->Type() + " running: "
Expand Down
16 changes: 8 additions & 8 deletions paddle/fluid/operators/conv_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,16 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {

framework::OpKernelType ConvOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain};
framework::LibraryType library{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_CUDA
if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN;
library = framework::LibraryType::kCUDNN;
}
#endif
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
library = framework::LibraryType::kMKLDNN;
}
#endif

Expand All @@ -91,15 +91,15 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
"input and filter data type should be consistent");

if (input_data_type == framework::proto::VarType::FP16) {
PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN,
PADDLE_ENFORCE_EQ(library, framework::LibraryType::kCUDNN,
"float16 can only be used when CUDNN is used");
}

std::string data_format = ctx.Attr<std::string>("data_format");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,
library_);
framework::DataLayout layout = framework::StringToDataLayout(data_format);
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library);
}

Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
Expand Down
11 changes: 6 additions & 5 deletions paddle/fluid/operators/feed_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/profiler.h"

namespace paddle {
namespace operators {
Expand All @@ -28,6 +29,10 @@ class FeedOp : public framework::OperatorBase {
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
// get device context from pool
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place);
platform::RecordEvent record_event(Type(), dev_ctx);

auto feed_var_name = Input("X");
auto *feed_var = scope.FindVar(feed_var_name);

Expand All @@ -50,14 +55,10 @@ class FeedOp : public framework::OperatorBase {
auto &feed_item = feed_list.at(static_cast<size_t>(col));
auto *out_item = out_var->GetMutable<framework::FeedFetchType>();

// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);

if (platform::is_same_place(feed_item.place(), place)) {
out_item->ShareDataWith(feed_item);
} else {
framework::TensorCopy(feed_item, place, dev_ctx, out_item);
framework::TensorCopy(feed_item, place, *dev_ctx, out_item);
}
out_item->set_lod(feed_item.lod());
}
Expand Down
5 changes: 4 additions & 1 deletion paddle/fluid/operators/fetch_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler.h"

namespace paddle {
namespace operators {
Expand All @@ -29,6 +30,9 @@ class FetchOp : public framework::OperatorBase {
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
platform::RecordEvent record_event(Type(), pool.Get(place));

auto fetch_var_name = Input("X");
auto *fetch_var = scope.FindVar(fetch_var_name);
PADDLE_ENFORCE(fetch_var != nullptr,
Expand All @@ -53,7 +57,6 @@ class FetchOp : public framework::OperatorBase {

// FIXME(yuyang18): Should we assume the fetch operator always generate
// CPU outputs?
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(src_item.place());

TensorCopy(src_item, platform::CPUPlace(), dev_ctx, &dst_item);
Expand Down
10 changes: 6 additions & 4 deletions paddle/fluid/operators/load_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler.h"

namespace paddle {
namespace operators {
Expand All @@ -29,6 +30,9 @@ class LoadOp : public framework::OperatorBase {
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place);
platform::RecordEvent record_event(Type(), dev_ctx);

auto filename = Attr<std::string>("file_path");
std::ifstream fin(filename);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
Expand All @@ -41,9 +45,7 @@ class LoadOp : public framework::OperatorBase {

auto *tensor = out_var->GetMutable<framework::LoDTensor>();

platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
DeserializeFromStream(fin, tensor, dev_ctx);
DeserializeFromStream(fin, tensor, *dev_ctx);

if (platform::is_gpu_place(place)) {
// copy CPU to GPU
Expand All @@ -55,7 +57,7 @@ class LoadOp : public framework::OperatorBase {
out_var->Clear();
tensor = out_var->GetMutable<framework::LoDTensor>();
tensor->set_lod(cpu_tensor.lod());
TensorCopy(cpu_tensor, place, dev_ctx, tensor);
TensorCopy(cpu_tensor, place, *dev_ctx, tensor);
}
}
};
Expand Down

0 comments on commit c042137

Please sign in to comment.