Skip to content

Commit

Permalink
Merge pull request #10671 from chengduoZH/fix_fetch_op_handle
Browse files Browse the repository at this point in the history
Refine fetch op handle
  • Loading branch information
reyoung authored May 16, 2018
2 parents 7ebb246 + 624caee commit 8b1b756
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
15 changes: 8 additions & 7 deletions paddle/fluid/framework/details/fetch_op_handle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,18 @@ void FetchOpHandle::RunImpl() {
WaitInputVarGenerated(platform::CPUPlace());

tensors_.resize(inputs_.size());
auto *var_handle = static_cast<VarHandle *>(inputs_[0]);
auto &var_name = var_handle->name_;
platform::CPUPlace cpu;
auto &scopes = *local_scopes_;

for (size_t i = 0; i < scopes.size(); ++i) {
auto &scope = scopes[i];
auto *var =
scope->FindVar(kLocalExecScopeName)->Get<Scope *>()->FindVar(var_name);
for (size_t i = 0; i < inputs_.size(); ++i) {
auto *var_handle = static_cast<VarHandle *>(inputs_[i]);
auto &scope = scopes.at(var_handle->scope_idx_);
auto *var = scope->FindVar(kLocalExecScopeName)
->Get<Scope *>()
->FindVar(var_handle->name_);
PADDLE_ENFORCE_NOT_NULL(var, "Cannot find variable %s in execution scope",
var_name);
var_handle->name_);

auto &t = var->Get<framework::LoDTensor>();
if (platform::is_gpu_place(t.place())) {
#ifdef PADDLE_WITH_CUDA
Expand Down
5 changes: 2 additions & 3 deletions python/paddle/fluid/tests/unittests/test_parallel_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,7 @@ def test_update_sparse_parameter_reduce(self):
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
self.check_network_convergence(
is_sparse=False, build_strategy=build_strategy)
is_sparse=True, build_strategy=build_strategy)

def test_update_dense_parameter_reduce(self):
build_strategy = fluid.BuildStrategy()
Expand Down Expand Up @@ -849,8 +849,7 @@ def parallel_exe(self, train_inputs, seed):
assert not math.isnan(np.sum(ret[i])) and \
not math.isinf(np.sum(ret[i]))

@unittest.skip("this test is buggy")
def test_feed(self):
def test_fetch_op(self):
tst_reader = paddle.batch(flowers.test(use_xmap=False), batch_size=16)
tst_reader_iter = tst_reader()

Expand Down

0 comments on commit 8b1b756

Please sign in to comment.