Skip to content

Commit

Permalink
fix while_grad_op first step loss lod problem (#7490)
Browse files Browse the repository at this point in the history
* fix while_grad_op first step loss lod problem

* optimize code
  • Loading branch information
jacquesqiao authored Jan 15, 2018
1 parent 59bc4c4 commit 9deb175
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
1 change: 1 addition & 0 deletions paddle/operators/shrink_rnn_memory_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ class ShrinkRNNMemoryGradOp : public ArrayOp {
math::set_constant(dev_ctx, &rest_tensor, 0.0f);
}
}
dx_tensor.set_lod(x_tensor.lod());
}
};

Expand Down
16 changes: 10 additions & 6 deletions paddle/operators/while_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ class WhileGradOp : public framework::OperatorBase {
for (size_t i = 0; i < outside_og_names.size(); ++i) {
auto outside_og_name = outside_og_names[i];
auto inside_og_name = inside_og_names[i];
VLOG(10) << "Linking outside " << outside_og_name << " --> inside "
<< inside_og_name;
VLOG(8) << "Linking outside " << outside_og_name << " --> inside "
<< inside_og_name;
auto &og_outside =
detail::Ref(scope.FindVar(outside_og_name),
"Cannot find Outside Gradient %s", outside_og_name);
Expand All @@ -141,11 +141,11 @@ class WhileGradOp : public framework::OperatorBase {
auto &outside_array = og_outside.Get<framework::LoDTensorArray>();
auto &inside_array =
detail::Ref(og_inside.GetMutable<framework::LoDTensorArray>());
VLOG(10) << outside_og_name << " size = " << outside_array.size();
VLOG(8) << outside_og_name << " size = " << outside_array.size();
inside_array.resize(outside_array.size());

for (size_t j = 0; j < inside_array.size(); ++j) {
VLOG(10) << j << " " << outside_array[j].numel();
VLOG(8) << j << " " << outside_array[j].numel();
if (outside_array[j].numel() != 0) {
inside_array[j].set_lod(outside_array[j].lod());
inside_array[j].ShareDataWith(outside_array[j]);
Expand Down Expand Up @@ -187,10 +187,14 @@ class WhileGradOp : public framework::OperatorBase {
attrs["shape"] = framework::vectorize2int(inside_tensor.dims());
attrs["value"] = 0.0f;

auto var_name = pg_names[param_id];
auto zero_op = framework::OpRegistry::CreateOp(
"fill_constant", framework::VariableNameMap{},
{{"Out", {pg_names[param_id]}}}, attrs);
{{"Out", {var_name}}}, attrs);
zero_op->Run(scope, dev_place);
scope.FindVar(var_name)
->GetMutable<framework::LoDTensor>()
->set_lod(inside_tensor.lod());
}
}

Expand Down Expand Up @@ -231,7 +235,7 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
auto igs = InputGrad(kX, /*do not drop empty gradient*/ false);
for (auto &each_ig : igs) {
if (inner_op_outputs.find(each_ig) == inner_op_outputs.end()) {
VLOG(10) << "Ignore " << each_ig;
VLOG(8) << "Ignore " << each_ig;
each_ig = framework::kEmptyVarName;
}
}
Expand Down

0 comments on commit 9deb175

Please sign in to comment.