Skip to content

Commit

Permalink
Locate and fix the memory leak issue.
Browse files Browse the repository at this point in the history
  • Loading branch information
panyx0718 committed Jun 4, 2018
1 parent b05e173 commit ca8913f
Show file tree
Hide file tree
Showing 11 changed files with 54 additions and 33 deletions.
15 changes: 4 additions & 11 deletions paddle/fluid/framework/block_desc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ BlockDesc::BlockDesc(ProgramDesc *prog, proto::BlockDesc *desc)
}

BlockDesc::BlockDesc(const BlockDesc &other, proto::BlockDesc *desc,
ProgramDesc *prog)
ProgramDesc *prog, bool is_test)
: prog_(prog), desc_(desc) {
need_update_ = true;
for (auto &op : other.ops_) {
Expand All @@ -218,19 +218,12 @@ BlockDesc::BlockDesc(const BlockDesc &other, proto::BlockDesc *desc,
}

// Removes all op messages from the wrapped proto block.
// Clear() both removes and deletes the proto-owned OpDesc messages; the
// previous ReleaseLast() loop abandoned ownership without freeing them,
// which is the memory leak this commit fixes.
void BlockDesc::ClearPBOps() { this->desc_->mutable_ops()->Clear(); }

// Removes all variable messages from the wrapped proto block.
// Clear() deletes the proto-owned VarDesc messages; the old ReleaseLast()
// loop released ownership without ever freeing them, leaking each VarDesc.
void BlockDesc::ClearPBVars() { this->desc_->mutable_vars()->Clear(); }

void BlockDesc::SetForwardBlockID(int32_t forward_block_id) {
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/framework/block_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ class BlockDesc {
public:
BlockDesc(ProgramDesc *prog, proto::BlockDesc *desc);

BlockDesc(const BlockDesc &other, proto::BlockDesc *desc, ProgramDesc *prog);
BlockDesc(const BlockDesc &other, proto::BlockDesc *desc, ProgramDesc *prog,
bool is_test=false);

~BlockDesc() {
this->ClearPBVars();
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/framework/op_desc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ proto::OpDesc *OpDesc::Proto() {
return &desc_;
}

// Read-only view of the underlying proto::OpDesc message (no flush/update,
// unlike the mutable Proto() accessor).
const proto::OpDesc &OpDesc::ConstProto() const { return desc_; }

const std::vector<std::string> &OpDesc::Input(const std::string &name) const {
auto it = inputs_.find(name);
PADDLE_ENFORCE(it != inputs_.end(), "Input %s cannot be found in Op %s", name,
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/op_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ class OpDesc {

proto::OpDesc *Proto();

const proto::OpDesc& ConstProto() const;

// Returns the operator's type name as stored in the underlying proto.
std::string Type() const { return desc_.type(); }

// Sets the operator's type name on the underlying proto.
void SetType(const std::string &type) { desc_.set_type(type); }
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/framework/program_desc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ ProgramDesc::ProgramDesc() {
blocks_.emplace_back(new BlockDesc(this, block));
}

ProgramDesc::ProgramDesc(const ProgramDesc &o) {
ProgramDesc::ProgramDesc(const ProgramDesc &o, bool is_test) {
desc_ = o.desc_;
for (int i = 0; i < desc_.blocks_size(); ++i) {
auto *block = desc_.mutable_blocks(i);
blocks_.emplace_back(new BlockDesc(*o.blocks_[i], block, this));
blocks_.emplace_back(new BlockDesc(*o.blocks_[i], block, this, is_test));
}
for (auto &block : blocks_) {
for (auto *op : block->AllOps()) {
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/program_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class ProgramDesc {

explicit ProgramDesc(const proto::ProgramDesc &desc);

ProgramDesc(const ProgramDesc &o);
ProgramDesc(const ProgramDesc &o, bool is_test=false);

explicit ProgramDesc(const std::string &binary_str);

Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,8 @@ if(NOT WITH_MKLDNN)
list(REMOVE_ITEM GENERAL_OPS fc_op)
endif(NOT WITH_MKLDNN)

# NOTE(review): reduce_op is unconditionally excluded from the general op
# build here with no stated reason — presumably a temporary measure from
# the leak-debugging commit; confirm before keeping this in mainline.
list(REMOVE_ITEM GENERAL_OPS reduce_op)

# Build every remaining general op as an op_library target.
foreach(src ${GENERAL_OPS})
op_library(${src})
endforeach()
Expand Down
5 changes: 5 additions & 0 deletions paddle/fluid/pybind/protobuf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ static pybind11::bytes SerializeMessage(
void BindProgramDesc(pybind11::module *m) {
pybind11::class_<pd::ProgramDesc>(*m, "ProgramDesc", "")
.def(pybind11::init<>())
.def("__init__",
[](pd::ProgramDesc &self, const pd::ProgramDesc &other,
bool is_test=false) {
new (&self) pd::ProgramDesc(other, is_test);
})
.def("__init__",
[](pd::ProgramDesc &self, const pd::ProgramDesc &other) {
new (&self) pd::ProgramDesc(other);
Expand Down
22 changes: 13 additions & 9 deletions python/paddle/fluid/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import contextlib
from framework import Program, default_main_program, Variable
from . import core
import sys

__all__ = [
'Executor', 'global_scope', 'scope_guard', 'switch_scope', 'fetch_var'
Expand Down Expand Up @@ -207,7 +208,7 @@ def _add_program_cache(self, program_cache_key, program):
def _add_feed_fetch_ops(self, program, feed, fetch_list, feed_var_name,
fetch_var_name):
tmp_program = program.clone()

"""
global_block = tmp_program.global_block()
if feed_var_name in global_block.vars:
Expand Down Expand Up @@ -246,7 +247,7 @@ def _add_feed_fetch_ops(self, program, feed, fetch_list, feed_var_name,
inputs={'X': [var]},
outputs={'Out': [fetch_var]},
attrs={'col': i})

"""
return tmp_program

def _feed_data(self, program, feed, feed_var_name, scope):
Expand Down Expand Up @@ -277,7 +278,8 @@ def run(self,
fetch_var_name='fetch',
scope=None,
return_numpy=True,
use_program_cache=False):
use_program_cache=False,
keep_create=False):
""" Run program by this Executor. Feed data by feed map, fetch result by fetch_list.
Python executor takes a program, add feed operators and fetch operators to this program according
Expand Down Expand Up @@ -329,12 +331,14 @@ def run(self,
program = cached_program
else:
self.program_caches.pop(cache_key, None)
program = self._add_feed_fetch_ops(
program=program,
feed=feed,
fetch_list=fetch_list,
feed_var_name=feed_var_name,
fetch_var_name=fetch_var_name)
while keep_create:
program = self._add_feed_fetch_ops(
program=program,
feed=feed,
fetch_list=fetch_list,
feed_var_name=feed_var_name,
fetch_var_name=fetch_var_name)
sys.stderr.write('created a program\n')

self._feed_data(program, feed, feed_var_name, scope)
self.executor.run(program.desc, scope, 0, True, True)
Expand Down
24 changes: 16 additions & 8 deletions python/paddle/fluid/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,7 +1020,9 @@ def clone_variable(self, var):


class Program(object):
def __init__(self):
def __init__(self, is_test=False):
if is_test:
return
self.desc = core.ProgramDesc()
self.blocks = [Block(self, 0)]
self.current_block_idx = 0
Expand Down Expand Up @@ -1101,13 +1103,19 @@ def clone(self, for_test=False):
if for_test:
p = self.inference_optimize()
else:
p = Program()
p.desc = core.ProgramDesc(self.desc)
p.blocks = [Block(p, i) for i in xrange(self.desc.num_blocks())]
p.sync_with_cpp()

p.copy_param_info_from(self)
p.copy_data_info_from(self)
p = Program(is_test=True)
p.desc = core.ProgramDesc()
p.blocks = [Block(p, 0)]
p.current_block_idx = 0
p._seed = 0
p._current_role = core.op_proto_and_checker_maker.OpRole.Forward
p._op_role_var = []
p.desc = core.ProgramDesc(self.desc, True)
# p.blocks = [Block(p, i) for i in xrange(self.desc.num_blocks())]
# p.sync_with_cpp()

# p.copy_param_info_from(self)
# p.copy_data_info_from(self)
return p

def prune(self, targets):
Expand Down
4 changes: 3 additions & 1 deletion python/paddle/fluid/tests/book/test_word2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ def train_loop(main_program):
for batch_id, data in enumerate(train_reader()):
avg_cost_np = exe.run(main_program,
feed=feeder.feed(data),
fetch_list=[avg_cost])
fetch_list=[avg_cost],
use_program_cache=False,
keep_create=True)
sys.stderr.write('pass: %d, batch_id: %d, cost: %s\n' %
(pass_id, batch_id, avg_cost_np))
if math.isnan(float(avg_cost_np[0])):
Expand Down

0 comments on commit ca8913f

Please sign in to comment.