Feature/shrink memory op #5419
Conversation
LGTM.
paddle/operators/shrink_state_op.cc (outdated)

```cpp
PADDLE_ENFORCE(context->HasInput("X"));
PADDLE_ENFORCE(context->HasInput("I"));
PADDLE_ENFORCE(context->HasInput("RankTable"));
context->SetOutputDim("Out", context->GetInputDim("X"));
```
`ShrinkStateOpInferShape` looks difficult to implement. At compile time we don't have enough information to infer the shape of `Out`, because we don't know what will be filled into the `RankTable`. At runtime we could use the `RankTable`, but that would make the infer-shape code differ between the two phases.
The `RankTable` only affects the value of `Out`'s first dimension; the other dimensions of `Out` are the same as `X`'s. Inferring those other dimensions and leaving the first one to runtime is exactly what `InferShape` should do.
I see. Good point.
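A minimal sketch of that compile-time rule, with the framework's `DDim` replaced by a plain `std::vector` and `-1` used as a hypothetical placeholder for the runtime-only first dimension (`InferShrinkShape` is an illustrative name, not the real Paddle API):

```cpp
#include <vector>

// Simplified stand-in for the proposed InferShape rule: Out keeps every
// dimension of X except the first, which depends on the RankTable and is
// only known at runtime. -1 marks that dynamic dimension.
std::vector<long long> InferShrinkShape(const std::vector<long long> &x_dims) {
  std::vector<long long> out_dims = x_dims;  // trailing dims copied from X
  if (!out_dims.empty()) out_dims[0] = -1;   // first dim left to runtime
  return out_dims;
}
```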
paddle/operators/shrink_state_op.cc
Outdated
auto *out_var = scope.FindVar(Output("Out")); | ||
PADDLE_ENFORCE(out_var != nullptr, "Output Out must be set"); | ||
auto &out_tensor = *out_var->GetMutable<framework::LoDTensor>(); | ||
out_tensor.ShareDataWith(x_tensor.Slice(0, dst_num_rows)); |
Note: the shape of `out_tensor` is changed by `ShareDataWith`. This violates the general design that a tensor's shape should only be modified by `InferShape`. But as pointed out below, it is relatively hard to infer the shape of `out_tensor`, so we treat this behavior as an exception.
Used to shrink the memory state in DyRNN. The height of the state can be shrunk after running a step block.
2341662 to 2dd91dd
paddle/operators/shrink_state_op.cc (outdated)

```cpp
auto &rank_items = rank_table.items();
for (auto &rank_item : rank_items) {
```
Using a binary search would make this faster when `rank_table` is large.
Done
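A sketch of that suggestion, assuming (as `LoDRankTable` guarantees) the items are sorted by sequence length in descending order; `ActiveRowsAt` and the plain length vector are hypothetical simplifications of the real rank-table API:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Given sequence lengths sorted in descending order, return how many
// sequences are still active at time step `step` -- the new height of the
// shrunk state tensor.
std::size_t ActiveRowsAt(const std::vector<std::size_t> &sorted_lengths,
                         std::size_t step) {
  // The active prefix ends at the first element whose length is <= step.
  // std::lower_bound with a reversed comparator finds that boundary in
  // O(log n) instead of scanning the whole table.
  auto it = std::lower_bound(
      sorted_lengths.begin(), sorted_lengths.end(), step,
      [](std::size_t len, std::size_t s) { return len > s; });
  return static_cast<std::size_t>(it - sorted_lengths.begin());
}
```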
paddle/operators/shrink_state_op.cc (outdated)

```cpp
namespace paddle {
namespace operators {

class ShrinkStateOp : public ArrayOp {
```
I don't think `ShrinkStateOp` is a good name. We don't really 'shrink' the memory block; we only take a slice of the original one. And `state` is also confusing. Maybe we can call it `RearrangeRnnMemoryOp`.
We should indicate `RNN` in the name, since the op will only be used in RNN. We could even put `DyRNN` in the name.
Done.
paddle/operators/shrink_state_op.cc (outdated)

```cpp
const platform::DeviceContext &dev_ctx) const override {
  auto *dout_var = scope.FindVar(Input(framework::GradVarName("Out")));
  auto dx_name = Output(framework::GradVarName("X"));
  auto *dx_var = scope.FindVar(dx_name);
```
Why not `auto *dx_var = scope.FindVar(Output(framework::GradVarName("X")));`?
Done.
paddle/operators/shrink_state_op.cc (outdated)

```cpp
auto height = dout_tensor.dims()[0];
dx_tensor.Slice(0, static_cast<int>(height))
    .CopyFrom(dout_tensor, dout_tensor.place(), dev_ctx);
if (height < dout_tensor.dims()[0]) {
```
How could `height < dout_tensor.dims()[0]` ever hold? Line 110 sets `auto height = dout_tensor.dims()[0];`, so the comparison is always false as written.
Done.
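The fix amounts to comparing `dout`'s height against `dx`'s height and zero-filling only the tail rows. A self-contained sketch over flat row-major buffers (`LoDTensor`, `Slice`, and `CopyFrom` replaced by `std::vector` operations; all names here are hypothetical):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// dout carries gradients for the first dout_rows rows of dx; the forward op
// shrank the state, so dx may have more rows. Those tail rows get zero
// gradient. The guard that matters is dout_rows < dx_rows, not a comparison
// of dout's height against itself.
std::vector<float> ShrinkStateGrad(const std::vector<float> &dout,
                                   long long dout_rows, long long dx_rows,
                                   long long row_width) {
  std::vector<float> dx(static_cast<std::size_t>(dx_rows * row_width), 0.0f);
  // Copy the rows that actually received gradient.
  std::copy(dout.begin(), dout.begin() + dout_rows * row_width, dx.begin());
  // Rows [dout_rows, dx_rows) remain zero -- this case only arises when
  // dout_rows < dx_rows.
  return dx;
}
```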
```diff
@@ -115,20 +85,21 @@ class WriteToArrayInferVarType : public framework::VarTypeInference {
  public:
   void operator()(const framework::OpDescBind &op_desc,
                   framework::BlockDescBind *block) const override {
+    VLOG(10) << "I am here?";
```
Can this be more meaningful?
Done.