-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
expose h0 in dynamic_lstm #11391
expose h0 in dynamic_lstm #11391
Conversation
python/paddle/fluid/layers/nn.py
Outdated
@@ -387,12 +395,19 @@ def dynamic_lstm(input, | |||
cell = helper.create_tmp_variable(dtype) | |||
batch_gate = helper.create_tmp_variable(dtype) | |||
batch_cell_pre_act = helper.create_tmp_variable(dtype) | |||
inputs = {'Input': input, 'Weight': weight, 'Bias': bias} | |||
if h_0: | |||
assert h_0.shape == (-1, size), \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The first dimension represents placeholder of batch size and may not be -1, input.shape[0]
might be better than -1
thus we needn't assume users' batch size setting.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, if the shape has been checked in the InferShape
, maybe the Python doesn't need to check it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Infershape
doesn't check the shape, we can check it in Python code, it's more clear for the user.
python/paddle/fluid/layers/nn.py
Outdated
name=None): | ||
name=None, | ||
h_0=None, | ||
c_0=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Put the name at the end?
Remove the TODO in line 264.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
python/paddle/fluid/layers/nn.py
Outdated
@@ -387,12 +395,19 @@ def dynamic_lstm(input, | |||
cell = helper.create_tmp_variable(dtype) | |||
batch_gate = helper.create_tmp_variable(dtype) | |||
batch_cell_pre_act = helper.create_tmp_variable(dtype) | |||
inputs = {'Input': input, 'Weight': weight, 'Bias': bias} | |||
if h_0: | |||
assert h_0.shape == (-1, size), \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, if the shape has been checked in the InferShape
, maybe the Python doesn't need to check it.
python/paddle/fluid/layers/nn.py
Outdated
batch size and D is the hidden size. | ||
c_0(Variable): The initial cell state is an optional | ||
input. This is a tensor with shape (N x D), where N is the | ||
batch size. `h_0` and `c_0` can be NULL but only at the same time. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tell the users if not set, the default is zero.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
python/paddle/fluid/layers/nn.py
Outdated
inputs = {'Input': input, 'Weight': weight, 'Bias': bias} | ||
if h_0 != None: | ||
assert h_0.shape == ( | ||
size, size), 'The shape of h0 should be(%d, %d)' % (size, size) | ||
batch_size, size | ||
), 'The shape of h0 should be(batch_size, %d)' % size | ||
inputs['h0'] = h_0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It also should be inputs['H0'] = h_0
like in LSTM. Sorry for leaving out this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks for the fix.
Fixed #11340
Fixed #11339
Related issue: #11335