expose h0 in dynamic_lstm #11391

Merged: 4 commits, Jun 13, 2018

@Yancey1989 Yancey1989 commented Jun 12, 2018

Fixed #11340
Fixed #11339
Related issue: #11335

@@ -387,12 +395,19 @@ def dynamic_lstm(input,
cell = helper.create_tmp_variable(dtype)
batch_gate = helper.create_tmp_variable(dtype)
batch_cell_pre_act = helper.create_tmp_variable(dtype)
inputs = {'Input': input, 'Weight': weight, 'Bias': bias}
if h_0:
assert h_0.shape == (-1, size), \

Contributor

The first dimension is a placeholder for the batch size and may not be -1. Using input.shape[0] instead of -1 would be better, so that we don't have to assume how users set the batch size.

Contributor

Yeah, if the shape is already checked in InferShape, maybe the Python code doesn't need to check it.

Contributor Author

InferShape doesn't check the shape. We can check it in the Python code, which is clearer for the user.
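
The check discussed in this thread can be sketched in plain Python. This is only an illustration, not the code merged in the PR: `check_initial_state_shape` is a hypothetical helper, and shapes are modeled as plain tuples rather than Paddle Variables. It takes the batch dimension from the input's own shape instead of hard-coding -1, as the first comment suggests.

```python
def check_initial_state_shape(state_shape, input_shape, size, name="h_0"):
    """Validate an initial-state shape against (batch_size, size).

    Hypothetical helper for illustration; shapes are plain tuples.
    """
    # Take the batch dimension from the input rather than assuming -1.
    batch_size = input_shape[0]
    expected = (batch_size, size)
    if tuple(state_shape) != expected:
        raise ValueError(
            "The shape of %s should be %s, but got %s"
            % (name, expected, tuple(state_shape)))
    return expected
```

Raising an error with both the expected and the actual shape gives the user a more direct message than a bare assert would.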

-                 name=None):
+                 name=None,
+                 h_0=None,
+                 c_0=None):

Contributor

Put the name at the end?

Remove the TODO in line 264.

Contributor Author

Done.

batch size and D is the hidden size.
c_0(Variable): The initial cell state is an optional
input. This is a tensor with shape (N x D), where N is the
batch size. `h_0` and `c_0` can be NULL but only at the same time.

Contributor

Tell users that if these are not set, the default is zero.

Contributor Author

Done.
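
The documented behaviour in this thread (both states optional, omitted only together, defaulting to zero) can be sketched as follows. This is an assumption-laden illustration rather than the operator's implementation: `resolve_initial_states` is a hypothetical helper, and states are represented as nested Python lists of shape (N x D).

```python
def resolve_initial_states(batch_size, size, h_0=None, c_0=None):
    """Return (h_0, c_0), defaulting both to zeros of shape (N x D).

    Hypothetical helper for illustration; states are nested lists.
    """
    # h_0 and c_0 may be omitted, but only at the same time.
    if (h_0 is None) != (c_0 is None):
        raise ValueError("h_0 and c_0 must both be set or both be None")
    if h_0 is None:
        # Build two independent zero-filled (batch_size x size) states.
        h_0 = [[0.0] * size for _ in range(batch_size)]
        c_0 = [[0.0] * size for _ in range(batch_size)]
    return h_0, c_0
```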

     inputs = {'Input': input, 'Weight': weight, 'Bias': bias}
     if h_0 != None:
         assert h_0.shape == (
-            size, size), 'The shape of h0 should be(%d, %d)' % (size, size)
+            batch_size, size
+        ), 'The shape of h0 should be(batch_size, %d)' % size
         inputs['h0'] = h_0

Contributor

It should also be inputs['H0'] = h_0, as in LSTM. Sorry for leaving this out.

Contributor Author

Thanks! Done.
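
A minimal sketch of the corrected input wiring, with the capitalized key the reviewer asks for. `build_lstm_inputs` is a hypothetical helper written for illustration, not a function in the PR, and the 'C0' key for the cell state is an assumption made by analogy with 'H0'; only 'H0' is confirmed in this thread.

```python
def build_lstm_inputs(input, weight, bias, h_0=None, c_0=None):
    """Assemble the inputs dict passed to the lstm op.

    Hypothetical helper for illustration only.
    """
    inputs = {'Input': input, 'Weight': weight, 'Bias': bias}
    # Compare against None explicitly instead of relying on truthiness.
    if h_0 is not None:
        inputs['H0'] = h_0  # key is 'H0', not 'h0'
    if c_0 is not None:
        inputs['C0'] = c_0  # assumed key, by analogy with 'H0'
    return inputs
```

The sketch uses `is not None` rather than `if h_0:` or `h_0 != None`, since an identity check does not depend on how the object defines truthiness or comparison.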

Contributor

@guoshengCS guoshengCS left a comment

LGTM. Thanks for the fix.

@Yancey1989 Yancey1989 merged commit 14e8337 into PaddlePaddle:develop Jun 13, 2018
@Yancey1989 Yancey1989 deleted the expose_h0 branch June 13, 2018 03:21
3 participants