Enable hidden/cell state initialization and enhance unit testing in LSTM operator. #5429
Conversation
1. Users can disable peephole connections. 2. Some gradients are no longer calculated.
int pre_h_start = static_cast<int>(batch_starts[n - 1]);
int pre_h_end = pre_h_start + cur_batch_size;
auto pre_hidden_t = batch_hidden.Slice(pre_h_start, pre_h_end);
math::matmul<Place, T>(device_ctx, pre_hidden_t, false, *weight, false,
                       static_cast<T>(1.0), &gate_t,
                       static_cast<T>(1.0));
} else if (hidden_t0) {
Since H0 is optional, what will happen if hidden_t0 is nullptr?
If hidden_t0 is nullptr, there is no hidden-state initialization, that is, the initial hidden state is zero, so there is no need to compute the matmul at line 122.
"whether to enable diagonal/peephole connections.")
    .SetDefault(true);
AddAttr<bool>("isReverse",
              "(bool, default False) "
AddAttr<bool>("is_reverse",
Does the LSTM operator only support a single-level RNN, or can LSTMs be nested?
It only supports one level.
paddle/operators/lstm_op.cc
Outdated
auto c0_g_name = framework::GradVarName("C0");
if (ctx->HasOutput(c0_g_name))
  ctx->SetOutputDim(c0_g_name, ctx->GetInputDim("C0"));
}
Lines 248-267 can be written as (the function name can be adjusted; note the lambda must capture ctx to use it):
auto setOutputDim = [&ctx](const std::string& name) {
  auto g_name = framework::GradVarName(name);
  if (ctx->HasOutput(g_name))
    ctx->SetOutputDim(g_name, ctx->GetInputDim(name));
};
setOutputDim("Input");
setOutputDim("Weight");
...
Done.
@@ -109,15 +115,23 @@ class LSTMKernel : public framework::OpKernel<T> {
int cur_batch_size = bend - bstart;

if (n != 0) {
if (n > 0) {
if (n) would be enough here.
Changed uniformly to n > 0, to make explicit that this starts from the second time step.
paddle/operators/lstm_op.h
Outdated
row_shuffle(device_ctx, *hidden_t0, order, ordered_h0, true);
math::matmul<Place, T>(device_ctx, ordered_h0, false, *weight, false,
                       static_cast<T>(1.0), &gate_t,
                       static_cast<T>(1.0));
What do lines 126-133 and lines 86-92 implement?
- Should a brief formula or comment be added here?
- Except for the last line, these two segments implement the same thing (as do lines 190-196 below). Should they be wrapped into a function?
- Do lines 126 and 127 need to be defined inside every iteration of the for loop?
Should a brief formula or comment be added here?
Comments have been added in the code.
Except for the last line, these two segments implement the same thing (as do lines 190-196 below). Should they be wrapped into a function?
Done.
Do lines 126 and 127 need to be defined inside every iteration of the for loop?
No. That branch is only taken when n == 0, so it executes only once.
paddle/operators/lstm_op.h
Outdated
auto& device_ctx = ctx.device_context();
math::SetConstant<Place, T> zero;
if (weight_g) {
  weight_g->mutable_data<T>(ctx.GetPlace());
  zero(device_ctx, weight_g, static_cast<T>(0.0));
}

Tensor ordered_h0, ordered_c0, ordered_h0_g, ordered_c0_g;
Should these variables have comments explaining what each one represents?
Done.
zero(device_ctx, bias_g, static_cast<T>(0.0));
}
if (bias && bias_g && ctx.Attr<bool>("use_peepholes")) {
  T* bias_g_data = bias_g->data<T>();
Is the logic here that lstm_grad.checkXXX is null as soon as any one of the three conditions at line 221 is false?
Yes. Without a bias, the backward pass does not compute bias_grad. lstm_grad.checkXXX holds the peephole-connection weights, which are vectors stored inside the bias; if use_peepholes = False, lstm_grad.checkXXX is null.
@@ -226,9 +258,9 @@ class LSTMGradKernel : public framework::OpKernel<T> {
batch_gate_g.mutable_data<T>(batch_gate->dims(), ctx.GetPlace());
batch_gate_g.set_lod(batch_gate->lod());
Lines 235-259 could also be shortened with a function, same as above.
Done.
@@ -250,23 +282,32 @@ class LSTMGradKernel : public framework::OpKernel<T> {
lstm_grad.gateGrad = gate_g.data<T>();
lstm_grad.outputGrad = out_g.data<T>();

if (n) {
if (n > 0) {
n alone would work; why change it to n > 0?
paddle/operators/lstm_op.h
Outdated
if (c0) {
  lstm_value.prevStateValue = ordered_c0.data<T>();
} else {
  lstm_value.prevStateValue = nullptr;
lstm_value.prevStateValue = c0 ? ordered_c0.data<T>() : nullptr;
Done.
['Input', 'Weight', 'Bias', 'H0', 'C0'], ['Hidden'],
max_relative_error=5e-4)

# In order to speed up, skip following testing
Are these unit tests time-consuming? Roughly how long do they take?
If they are not going to be run, they could all be deleted. Same below.
These are gradient checks, which take relatively long; I haven't measured the exact time.
I changed the testing approach and moved them from the base class into a subclass.
@luotao1 Thanks very much.
LGTM
Fix #5420
Fix #5428