Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable hidden/cell state initialization and enhance unit testing in LSTM operator. #5429

Merged
merged 7 commits into from
Nov 9, 2017
Merged
76 changes: 45 additions & 31 deletions paddle/operators/lstm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ class LSTMOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Weight"),
"Input(Weight) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Bias"),
"Input(Bias) of LSTM should not be null.");

PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(Hidden) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Cell"),
Expand Down Expand Up @@ -59,11 +64,13 @@ class LSTMOp : public framework::OperatorWithKernel {
"The second dimension of Input(Weight) "
"should be 4 * %d.",
frame_size);

auto b_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
PADDLE_ENFORCE_EQ(b_dims[0], 1,
"The first dimension of Input(Bias) should be 1.");
if (ctx->Attrs().Get<bool>("usePeepholes")) {

if (ctx->Attrs().Get<bool>("use_peepholes")) {
PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size,
"The second dimension of Input(Bias) should be "
"7 * %d if enable peepholes connection",
Expand All @@ -74,6 +81,7 @@ class LSTMOp : public framework::OperatorWithKernel {
"4 * %d if disable peepholes connection",
frame_size);
}

framework::DDim out_dims({in_dims[0], frame_size});
ctx->SetOutputDim("Hidden", out_dims);
ctx->SetOutputDim("Cell", out_dims);
Expand Down Expand Up @@ -118,14 +126,13 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Bias",
"(Tensor) the learnable weights, which contains two parts: "
"input-hidden bias weight and peephole connections weight if "
"setting `usePeepholes` True. "
"1. `usePeepholes = False` "
"setting `use_peepholes` True. "
"1. `use_peepholes = False` "
" - The shape is (1 x 4D). "
" - Bias = {b_c, b_i, b_f, b_o}."
"2. `usePeepholes = True` "
"2. `use_peepholes = True` "
" - The shape is (1 x 7D). "
" - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.")
.AsDispensable();
" - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.");
AddOutput("Hidden",
"(LoDTensor) the hidden state of LSTM operator. "
"The shape is (T x D), and lod is the same with the `Input`.");
Expand All @@ -145,29 +152,32 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
"(LoDTensor) This LoDTensor is obtained in the forward and used "
"in the backward.")
.AsIntermediate();
AddAttr<bool>("usePeepholes",
"(bool, default True) "
AddAttr<bool>("use_peepholes",
"(bool, defalut: True) "
"whether to enable diagonal/peephole connections.")
.SetDefault(true);
AddAttr<bool>("isReverse",
"(bool, default False) "
AddAttr<bool>("is_reverse",
Copy link
Member

@QiJune QiJune Nov 7, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does LSTM operator only support one level RNN? or can have Nesting LSTM?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It only supports one level.

"(bool, defalut: False) "
"whether to compute reversed LSTM.")
.SetDefault(false);
AddAttr<std::string>(
"gateActivation",
"(string, default sigmoid)"
"gate_activation",
"(string, default: sigmoid)"
"The activation for input gate, forget gate and output "
"gate, `sigmoid` by default.")
.SetDefault("sigmoid");
AddAttr<std::string>("cellActivation",
"(string, default tanh)"
.SetDefault("sigmoid")
.InEnum({"sigmoid", "tanh", "relu", "identity"});
AddAttr<std::string>("cell_activation",
"(string, default: tanh)"
"The activation for cell output, `tanh` by defalut.")
.SetDefault("tanh");
AddAttr<std::string>("candidateActivation",
"(string, default tanh)"
.SetDefault("tanh")
.InEnum({"sigmoid", "tanh", "relu", "identity"});
AddAttr<std::string>("candidate_activation",
"(string, default: tanh)"
"The activation for candidate hidden state, "
"`tanh` by default.")
.SetDefault("tanh");
.SetDefault("tanh")
.InEnum({"sigmoid", "tanh", "relu", "identity"});
AddComment(R"DOC(
Long-Short Term Memory (LSTM) Operator.

Expand Down Expand Up @@ -203,7 +213,7 @@ are the cell input and cell output activation functions and `tanh` is usually
used for them. \f$\tilde{c_t}\f$ is also called candidate hidden state,
which is computed based on the current input and the previous hidden state.

Set usePeepholes False to disable peephole connection
Set `use_peepholes` False to disable peephole connection
(http://www.bioinf.jku.at/publications/older/2604.pdf). The formula
is omitted here.

Expand All @@ -226,23 +236,27 @@ class LSTMGradOp : public framework::OperatorWithKernel {
"Input(Hidden) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Cell"),
"Input(Cell) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Weight"),
"Input(Weight) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Bias"),
"Input(Bias) of LSTM should not be null.");

PADDLE_ENFORCE(ctx->HasInput("BatchGate"),
"Input(BatchGate) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("BatchCellPreAct"),
"Input(BatchGate) of LSTM should not be null.");

auto in_g_name = framework::GradVarName("Input");
if (ctx->HasOutput(in_g_name))
ctx->SetOutputDim(in_g_name, ctx->GetInputDim("Input"));

auto w_g_name = framework::GradVarName("Weight");
if (ctx->HasOutput(w_g_name))
ctx->SetOutputDim(w_g_name, ctx->GetInputDim("Weight"));

auto b_g_name = framework::GradVarName("Bias");
if (ctx->HasOutput(b_g_name))
ctx->SetOutputDim(b_g_name, ctx->GetInputDim("Bias"));
auto SetOutGradDim = [&ctx](const std::string& name) {
auto g_name = framework::GradVarName(name);
if (ctx->HasOutput(g_name))
ctx->SetOutputDim(g_name, ctx->GetInputDim(name));
};

SetOutGradDim("Input");
SetOutGradDim("Weight");
SetOutGradDim("Bias");
SetOutGradDim("H0");
SetOutGradDim("C0");
}

protected:
Expand Down
Loading