Enable hidden/cell state initialization and enhance unit testing in LSTM operator. #5429

Merged: 7 commits into PaddlePaddle:develop on Nov 9, 2017

Conversation

Contributor @qingqing01 commented Nov 7, 2017

Fix #5420
Fix #5428

  • Enable hidden/cell state initialization in LSTM operator.
  • Users can disable peephole connections.
  • Enhance unit testing.
  • Also fix the attribute names (e.g., isReverse becomes is_reverse); a sketch of the interface changes is shown below.
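
For a quick picture of the interface changes, the sketch below shows roughly what the new optional state inputs and renamed attributes look like in the op definition; the descriptive strings and exact form are assumptions based on the diff fragments in this PR, not the committed code verbatim:

AddInput("H0",
         "(Tensor, optional) the initial hidden state of the LSTM operator.");
AddInput("C0",
         "(Tensor, optional) the initial cell state of the LSTM operator.");
AddAttr<bool>("use_peepholes",
              "(bool, default True) "
              "whether to enable diagonal/peephole connections.")
    .SetDefault(true);
AddAttr<bool>("is_reverse",
              "(bool, default False) "
              "whether to compute reversed LSTM.")
    .SetDefault(false);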

@qingqing01 qingqing01 changed the title Enable initial hidden/cell state and enhance unit testing in LSTM Operator. Enable hidden/cell state initialization and enhance unit testing in LSTM Operator. Nov 7, 2017
@qingqing01 qingqing01 changed the title Enable hidden/cell state initialization and enhance unit testing in LSTM Operator. Enable hidden/cell state initialization and enhance unit testing in LSTM operator. Nov 7, 2017
int pre_h_start = static_cast<int>(batch_starts[n - 1]);
int pre_h_end = pre_h_start + cur_batch_size;
auto pre_hidden_t = batch_hidden.Slice(pre_h_start, pre_h_end);
math::matmul<Place, T>(device_ctx, pre_hidden_t, false, *weight, false,
                       static_cast<T>(1.0), &gate_t,
                       static_cast<T>(1.0));
} else if (hidden_t0) {
Member

Since H0 is optional, what will happen if hidden_t0 is nullptr?

Contributor Author

If hidden_t0 is nullptr, there is no initial hidden state, i.e., the initial hidden state is zero, so there is no need to compute the matmul on line 122.
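
For reference, a condensed sketch of the branch being discussed, using the variable names from the kernel snippets in this thread (a simplified illustration, not the exact committed code):

if (n > 0) {
  // gate_t += hidden_{t-1} * Weight, using the previous batch's hidden output.
  auto pre_hidden_t = batch_hidden.Slice(pre_h_start, pre_h_end);
  math::matmul<Place, T>(device_ctx, pre_hidden_t, false, *weight, false,
                         static_cast<T>(1.0), &gate_t, static_cast<T>(1.0));
} else if (hidden_t0) {
  // First time step with a user-provided H0: reorder H0 to the batch layout,
  // then gate_t += H0 * Weight.
  row_shuffle(device_ctx, *hidden_t0, order, ordered_h0, true);
  math::matmul<Place, T>(device_ctx, ordered_h0, false, *weight, false,
                         static_cast<T>(1.0), &gate_t, static_cast<T>(1.0));
}
// else: H0 was not provided, so the initial hidden state is treated as zero
// and the H0 * Weight term is skipped.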

"whether to enable diagonal/peephole connections.")
.SetDefault(true);
AddAttr<bool>("isReverse",
"(bool, default False) "
AddAttr<bool>("is_reverse",
Member @QiJune commented Nov 7, 2017

Does the LSTM operator only support single-level RNN, or can LSTMs be nested?

Contributor Author

It only supports one level.

Contributor @luotao1 left a comment

Could you add more detail to the descriptions of the two issues #5420 and #5428?


auto c0_g_name = framework::GradVarName("C0");
if (ctx->HasOutput(c0_g_name))
ctx->SetOutputDim(c0_g_name, ctx->GetInputDim("C0"));
}
Contributor

Lines 248-267 could be written as the following (the function name can be further adjusted):

auto setOutputDim = [&ctx](const std::string& name) {
  auto g_name = framework::GradVarName(name);
  if (ctx->HasOutput(g_name))
    ctx->SetOutputDim(g_name, ctx->GetInputDim(name));
};
setOutputDim("Input");
setOutputDim("Weight");
...

Contributor Author

Done.

@@ -109,15 +115,23 @@ class LSTMKernel : public framework::OpKernel<T> {

int cur_batch_size = bend - bstart;

if (n != 0) {
if (n > 0) {
Contributor

if (n) would be enough.

Contributor Author

Changed them all to n > 0 to make it explicit that this branch starts from the second time step.

row_shuffle(device_ctx, *hidden_t0, order, ordered_h0, true);
math::matmul<Place, T>(device_ctx, ordered_h0, false, *weight, false,
                       static_cast<T>(1.0), &gate_t,
                       static_cast<T>(1.0));
Contributor

What do the code at lines 126-133 and the code at lines 86-92 implement?

  • Should a brief formula or comment be added here?
  • Apart from the last line, these two blocks do the same thing (and so do lines 190-196 below). Should they be wrapped into a function?
  • Do lines 126 and 127 need to be defined inside every iteration of the for loop?

Contributor Author

Should a brief formula or comment be added here?

Comments have been added in the code.

Apart from the last line, these two blocks do the same thing (and so do lines 190-196 below). Should they be wrapped into a function?

Done.

Do lines 126 and 127 need to be defined inside every iteration of the for loop?

No. This branch is taken only when n == 0, so it executes only once.
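
As a rough illustration of the refactor discussed above (a sketch only; the helper name and the exact committed form may differ), the duplicated matmul could be factored into one small helper shared by the n > 0 branch, the H0 branch, and the backward counterpart:

// Sketch: accumulate gate_t += prev_hidden * Weight; the helper name is
// illustrative, not taken from the PR.
auto add_prev_hidden = [&](const Tensor& prev_hidden, Tensor* gate) {
  math::matmul<Place, T>(device_ctx, prev_hidden, false, *weight, false,
                         static_cast<T>(1.0), gate, static_cast<T>(1.0));
};

if (n > 0) {
  add_prev_hidden(batch_hidden.Slice(pre_h_start, pre_h_end), &gate_t);
} else if (hidden_t0) {
  row_shuffle(device_ctx, *hidden_t0, order, ordered_h0, true);
  add_prev_hidden(ordered_h0, &gate_t);  // taken at most once, when n == 0
}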

auto& device_ctx = ctx.device_context();
math::SetConstant<Place, T> zero;
if (weight_g) {
weight_g->mutable_data<T>(ctx.GetPlace());
zero(device_ctx, weight_g, static_cast<T>(0.0));
}

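// Buffers for the initial hidden/cell states (H0/C0) reordered to the batch
// order used inside the kernel, and for their corresponding gradients.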
Tensor ordered_h0, ordered_c0, ordered_h0_g, ordered_c0_g;
Contributor

Should these variables have comments explaining what each of them represents?

Contributor Author

Done.

zero(device_ctx, bias_g, static_cast<T>(0.0));
}
if (bias && bias_g && ctx.Attr<bool>("use_peepholes")) {
T* bias_g_data = bias_g->data<T>();
Contributor

So the logic here is: if any of the three conditions on line 221 is false, lstm_grad.checkXXX will be null?

Contributor Author

Yes. If there is no bias, the backward pass does not compute a bias gradient. lstm_grad.checkXXX refers to the peephole connection weights, which form a vector stored inside the bias; if use_peepholes = False, lstm_grad.checkXXX is null.
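
To make the described behavior concrete, here is a rough sketch of how the peephole gradient pointers can be wired up; the checkIgGrad/checkFgGrad/checkOgGrad field names, the bias layout (4 * frame_size gate biases followed by 3 * frame_size peephole weights), and frame_size itself are assumptions inferred from this discussion, not code quoted from the PR:

if (bias && bias_g && ctx.Attr<bool>("use_peepholes")) {
  // The peephole weights are stored in Bias after the gate biases, so their
  // gradients live at the matching offsets of bias_g.
  T* bias_g_data = bias_g->data<T>();
  lstm_grad.checkIgGrad = bias_g_data + 4 * frame_size;
  lstm_grad.checkFgGrad = lstm_grad.checkIgGrad + frame_size;
  lstm_grad.checkOgGrad = lstm_grad.checkFgGrad + frame_size;
} else {
  // No bias, no bias gradient, or peepholes disabled: leave the pointers null
  // so the backward pass skips the peephole gradients.
  lstm_grad.checkIgGrad = nullptr;
  lstm_grad.checkFgGrad = nullptr;
  lstm_grad.checkOgGrad = nullptr;
}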

@@ -226,9 +258,9 @@ class LSTMGradKernel : public framework::OpKernel<T> {
batch_gate_g.mutable_data<T>(batch_gate->dims(), ctx.GetPlace());
batch_gate_g.set_lod(batch_gate->lod());
Contributor

Lines 235-259 could also be shortened by using a function. Same as above.

Contributor Author

Done.

@@ -250,23 +282,32 @@ class LSTMGradKernel : public framework::OpKernel<T> {
lstm_grad.gateGrad = gate_g.data<T>();
lstm_grad.outputGrad = out_g.data<T>();

if (n) {
if (n > 0) {
Contributor

n alone would do; why change it to n > 0?

if (c0) {
lstm_value.prevStateValue = ordered_c0.data<T>();
} else {
lstm_value.prevStateValue = nullptr;
Contributor

lstm_value.prevStateValue = c0 ? ordered_c0.data<T>() : nullptr;

Contributor Author

Done.

['Input', 'Weight', 'Bias', 'H0', 'C0'], ['Hidden'],
max_relative_error=5e-4)

# In order to speed up, skip following testing
Contributor

Are these unit tests very time-consuming? Roughly how long do they take?
If they are not going to be run, these unit tests can all be deleted. Same below.

Contributor Author

These are gradient checks, which take relatively long; I have not measured the exact time.
I changed the way the tests are organized and moved these checks from the base class into a subclass.

Contributor Author @qingqing01 left a comment

@luotao1 Thanks very much.
Contributor @luotao1 left a comment

LGTM

@qingqing01 qingqing01 merged commit 41d0533 into PaddlePaddle:develop Nov 9, 2017
@qingqing01 qingqing01 deleted the lstm_fix branch March 7, 2018 12:03

Successfully merging this pull request may close these issues.

Enhance unit testing for LSTM operator.
Enable hidden state and cell state initialization in LSTM Operator.
3 participants