implement bidirectional lstm #1035
Conversation
burn-core/src/nn/rnn/lstm.rs
Outdated
if self.bidirectional {
    (input_gate_bw, forget_gate_bw, output_gate_bw, cell_gate_bw) = (
        Some(new_gate()),
        Some(new_gate()),
        Some(new_gate()),
        Some(new_gate()),
    );
} else {
    (input_gate_bw, forget_gate_bw, output_gate_bw, cell_gate_bw) =
        (None, None, None, None);
}
Even if it's a boolean condition, I feel like a `match` might be easier to read, but using `new_gate` is a big win.
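For illustration, here is a minimal standalone sketch of that match-based variant (purely illustrative: burn's GateController is replaced by a placeholder type and the helper function name is made up, so this is not the actual module code):

// Placeholder standing in for burn's GateController in this sketch.
struct GateController;

// Returns the four optional backward gates: Some(...) when bidirectional,
// None otherwise, mirroring the if/else block above but written as a match.
fn backward_gates(
    bidirectional: bool,
    new_gate: impl Fn() -> GateController,
) -> (
    Option<GateController>,
    Option<GateController>,
    Option<GateController>,
    Option<GateController>,
) {
    match bidirectional {
        true => (
            Some(new_gate()),
            Some(new_gate()),
            Some(new_gate()),
            Some(new_gate()),
        ),
        false => (None, None, None, None),
    }
}

fn main() {
    let (i, f, o, c) = backward_gates(true, || GateController);
    assert!(i.is_some() && f.is_some() && o.is_some() && c.is_some());
}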
burn-core/src/nn/rnn/lstm.rs
Outdated
@@ -110,7 +124,7 @@ impl<B: Backend> Lstm<B> {
     /// Parameters:
     ///     batched_input: The input tensor of shape [batch_size, sequence_length, input_size].
     ///     state: An optional tuple of tensors representing the initial cell state and hidden state.
-    ///            Each state tensor has shape [batch_size, hidden_size].
+    ///            Each state tensor has shape [num_directions, batch_size, hidden_size].
I feel like it would be better to have a `BiDirectionalLstm` module instead of one module that does both. The bi-directional module might just have two Lstm modules instead of duplicating all the gates. We could also make it more Rusty by using Kind types:
pub trait LstmKind<B: Backend> {
    fn forward(...) -> ...;
}

#[derive(Module)]
pub struct UniDirectional<B: Backend> {
    input_gate: GateController<B>,
    forget_gate: GateController<B>,
    output_gate: GateController<B>,
    cell_gate: GateController<B>,
}

#[derive(Module)]
pub struct BiDirectional<B: Backend> {
    forward: Lstm<B, UniDirectional<B>>,
    backward: Lstm<B, UniDirectional<B>>,
}

pub struct Lstm<B: Backend, K: LstmKind<B> = UniDirectional<B>> {
    kind: K,
}

impl<B: Backend, K: LstmKind<B>> Lstm<B, K> {
    pub fn forward(...) -> ... {
        self.kind.forward(...) // static dispatch to the right forward pass depending on the kind
    }
}

impl<B: Backend> LstmKind<B> for UniDirectional<B> {
    fn forward(...) -> ... {
        // uni-directional forward pass
    }
}

impl<B: Backend> LstmKind<B> for BiDirectional<B> {
    fn forward(...) -> ... {
        // bi-directional forward pass
    }
}
This is equivalent to having two different modules for the unidirectional and bidirectional LSTM, but with syntax sugar so users can write `Lstm<B>` for unidirectional or `Lstm<B, BiDirectional>` when they want both directions. This is the same pattern that we use for the Tensor API.
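For reference, a self-contained toy version of that pattern (burn-specific types replaced by plain placeholders, so this is only an illustration of the default type parameter and static dispatch, not the proposed burn code):

trait LstmKind {
    fn forward(&self, input: &str) -> String;
}

struct UniDirectional;

struct BiDirectional {
    forward: UniDirectional,
    backward: UniDirectional,
}

impl LstmKind for UniDirectional {
    fn forward(&self, input: &str) -> String {
        format!("uni({input})")
    }
}

impl LstmKind for BiDirectional {
    // Conceptually: run both directions and combine the results.
    fn forward(&self, input: &str) -> String {
        format!(
            "bi({}, {})",
            self.forward.forward(input),
            self.backward.forward(input)
        )
    }
}

// The default type parameter lets users keep writing `Lstm` (unidirectional),
// while `Lstm<BiDirectional>` opts into both directions.
struct Lstm<K: LstmKind = UniDirectional> {
    kind: K,
}

impl<K: LstmKind> Lstm<K> {
    fn forward(&self, input: &str) -> String {
        self.kind.forward(input) // static dispatch based on the kind
    }
}

fn main() {
    let uni = Lstm { kind: UniDirectional };
    let bi = Lstm {
        kind: BiDirectional {
            forward: UniDirectional,
            backward: UniDirectional,
        },
    };
    println!("{}", uni.forward("x")); // uni(x)
    println!("{}", bi.forward("x")); // bi(uni(x), uni(x))
}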
This would also make it backward compatible in terms of code, but not in terms of state (Recorder).
@agelas what are your thoughts?
@nathanielsimard I agree that this pattern fits better with the overall approach that is usually taken. I think composing LSTM modules like this is a bit more elegant than having various toggles in the `forward` pass, and we can reuse most of the prior implementation that way.
@nathanielsimard The problem I have is that the `forward()` implementations for `UniDirectional` and `BiDirectional` have different function signatures due to the inconsistency in the shape of the `state`. This doesn't seem easy to resolve, so it's probably best to keep the LSTM implementation unchanged and let users implement bidirectional LSTM or multi-layer LSTM on their own.
Not sure how often bidirectional LSTM is used, but could we have an implementation completely separate from the unidirectional LSTM and just offer both? Because I do agree that messing with function signatures might get a bit bothersome.
The pattern I'm proposing creates two different types, so there are no problems with the type signature. It's just a way for both types to use the LSTM nomenclature. It's probably easier to start with two different types; then we can add the pattern afterward if it makes sense.
@nathanielsimard This pattern is a bit tricky for me, so I think it's best for me to create a new type for bidirectional LSTM. I sincerely hope you can improve it if possible. Thank you very much!
burn-core/src/nn/rnn/lstm.rs
Outdated
Some((cell_state, hidden_state)) => (cell_state, hidden_state),
let [batch_size, seq_length, _] = batched_input.shape().dims;
let mut batched_cell_state =
    Tensor::zeros_device([batch_size, seq_length, self.d_hidden * num_directions], &device);
@nathanielsimard Is this necessary? The implementation is already tied to a device through the backend, right?
In my understanding, `Module::fork` can only fork the tensors in the module's struct fields to the given device. So, when creating a new tensor in `forward()`, it is necessary to specify the device, right?
By the way, I now feel that bidirectional LSTM may not be a common requirement for everyone. If someone needs a bidirectional LSTM or a multi-layer LSTM, it's probably best for them to implement it themselves. Additionally, keeping the implementation of `Lstm` unchanged is essential to avoid breaking compatibility.
The device is necessary when creating a new tensor; otherwise, the new tensor will be created on the default device, which is not necessarily the same device as the module.
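A minimal sketch of that point, assuming the `Tensor::zeros(shape, &device)` form used later in this diff and that tensors expose `.device()` (both part of burn's tensor API):

// Take the device from the input tensor so newly created tensors end up on
// the same device as the data flowing through the module, not the default one.
let device = batched_input.device();
let batched_hidden_state =
    Tensor::<B, 3>::zeros([batch_size, seq_length, self.d_hidden], &device);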
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

@@            Coverage Diff             @@
##             main    #1035      +/-   ##
==========================================
+ Coverage   86.41%   86.51%   +0.09%
==========================================
  Files         696      696
  Lines       81131    81499     +368
==========================================
+ Hits        70112    70508     +396
+ Misses      11019    10991      -28

☔ View full report in Codecov by Sentry.
burn-core/src/nn/rnn/lstm.rs
Outdated
input_gate: GateController<B>,
forget_gate: GateController<B>,
output_gate: GateController<B>,
cell_gate: GateController<B>,
input_gate_reverse: GateController<B>,
forget_gate_reverse: GateController<B>,
output_gate_reverse: GateController<B>,
cell_gate_reverse: GateController<B>,
I guess it wasn't possible to use two Lstm modules here?
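For illustration, the composition being asked about could look roughly like this (a hedged sketch with assumed field names, not necessarily what was merged; it assumes the burn crate's `Module` derive and the `nn::Lstm` re-export):

use burn::module::Module;
use burn::nn::Lstm;
use burn::tensor::backend::Backend;

// Hypothetical sketch: build the bidirectional module out of two Lstm modules
// instead of duplicating every gate as a *_reverse field.
#[derive(Module, Debug)]
pub struct BiLstm<B: Backend> {
    /// LSTM applied to the sequence in its original order.
    forward: Lstm<B>,
    /// LSTM applied to the sequence in reverse order.
    reverse: Lstm<B>,
}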
I think the current implementation of bidirectional LSTM can work during inference, but it cannot update the parameters of the model during backward propagation due to #1098.
Closing this ticket and linking to #1537 so someone else can pick it up.
The wgpu test failed, but I don't know why...
What was the error in wgpu?
Hm, ok, well on ubuntu-22.04 everything seems fine when I run the tests locally. Given that Ubuntu and Windows work, maybe it hints at possible compatibility issues or differences in how Metal handles resource management compared to Vulkan (i.e., Linux) or DirectX (Windows)?
I like it! I have a few comments, but nothing major. Would also like @laggui to review this.
let new_gate = || {
    GateController::new(
        self.d_input,
        d_output,
        self.bias,
        self.initializer.clone(),
        device,
    )
};
Yesss!
crates/burn-core/src/nn/rnn/lstm.rs
Outdated
/// ## Returns:
/// A tuple of tensors, where the first tensor represents the cell states and
/// the second tensor represents the hidden states for each sequence element.
/// Both output tensors have the shape `[batch_size, sequence_length, hidden_size]`.
pub fn forward(
    &self,
    batched_input: Tensor<B, 3>,
    state: Option<(Tensor<B, 2>, Tensor<B, 2>)>,
Even though that's a breaking change, I think it might be beneficial to create a type for the state:
pub struct LstmState<B: Backend> {
    pub cell: Tensor<B, 2>,
    pub hidden: Tensor<B, 2>,
}
We can remove the optional and implement `Default` instead. It also gives us a space to document what each element of the state is used for.
We could also do the same for the return type, where it's easy to make a mistake (choosing the wrong returned tensor).
What do you think?
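For example, the caller side could then read roughly like this (a hypothetical sketch, assuming the state type above and a forward that returns the per-step output together with the final state):

// Hypothetical usage: named fields make it much harder to mix up the returned
// tensors than a positional (cell, hidden) tuple does.
let (output, state) = lstm.forward(batched_input, None);
let last_hidden = state.hidden; // shape [batch_size, hidden_size]
let last_cell = state.cell; // shape [batch_size, hidden_size]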
I quite like that idea, especially for the return type!
If you agree with modifying the outputs to match PyTorch, `batched_cell_state` is unnecessary. For the state, we only need to return the hidden state and cell state of the last time step.
crates/burn-core/src/nn/rnn/lstm.rs
Outdated
let mut batched_cell_state = Tensor::zeros([batch_size, seq_length, self.d_hidden], device);
let mut batched_hidden_state =
    Tensor::zeros([batch_size, seq_length, self.d_hidden], device);
Not linked to the current change, but is zero necessary, or would empty be enough here?
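For reference, a sketch of the alternative (assuming burn's `Tensor::empty`, which allocates without zero-filling; whether that is safe depends on every element being overwritten in the loop before it is read):

// Skip the zero fill if every slot is written before being read.
let mut batched_cell_state: Tensor<B, 3> =
    Tensor::empty([batch_size, seq_length, self.d_hidden], device);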
crates/burn-core/src/nn/rnn/lstm.rs
Outdated
input_biases: [f32; D1],
hidden_weights: [[f32; D1]; D1],
hidden_biases: [f32; D1],
device: &<TestBackend as Backend>::Device,
Just a detail, but we can use `&burn_tensor::Device<TestBackend>` instead of this notation, just a bit prettier, especially when importing `burn_tensor::Device`.
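A tiny illustration of the two spellings (generic over any backend here, just to show the `Device` alias; not the actual test helper):

use burn_tensor::backend::Backend;
use burn_tensor::Device;

// The alias form the comment suggests...
fn with_alias<B: Backend>(_device: &Device<B>) {}

// ...versus the fully qualified form used in the test helper.
fn without_alias<B: Backend>(_device: &<B as Backend>::Device) {}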
Great rewrite to add the `BiLstm`! I don't see any issues 🙂 Just a minor comment regarding the state as pointed out by Nath.
The outputs of both `Lstm` and `BiLstm` are consistent with the PyTorch outputs. The script I used to generate tests for the bidirectional LSTM can be found here. The tests for the wgpu backend still failed, possibly due to some data race issues in the wgpu backend?
@nathanielsimard @louisfd I changed the batch size to 2, and the test passed. This seems to be a bug with the Wgpu backend, so I've opened an issue #1656.
No more comments from my side, great job! We will investigate the wgpu test problem, but it's quite unrelated.
We will merge it once we have an approval from @laggui.
Good job! 😄 I like the addition of `LstmState` and the returned output.
Only a few minor changes, and then we can merge 🎉
crates/burn-core/src/nn/rnn/lstm.rs
Outdated
/// returns hidden state for each element in a sequence (i.e., across `seq_length`) and a final state,
/// producing 3-dimensional tensors where the dimensions represent `[batch_size, sequence_length, hidden_size]`.
Since we document the return types, I think we can simplify the docstring here to:
Applies the forward pass on the input tensor. This LSTM implementation returns the state for each element in a sequence (i.e., across `seq_length`) and a final state.
(This also removes the ambiguity of "returns hidden state": the state contains both the hidden and cell state.)
crates/burn-core/src/nn/rnn/lstm.rs
Outdated
/// Applies the forward pass on the input tensor. This Bidirectional LSTM implementation
/// returns hidden state for each element in a sequence (i.e., across `seq_length`) and a final state,
/// producing 3-dimensional tensors where the dimensions represent `[batch_size, sequence_length, hidden_size * 2]`.
Same as with the `Lstm` forward docstring. We can simplify to:
Applies the forward pass on the input tensor. This Bidirectional LSTM implementation returns the state for each element in a sequence (i.e., across `seq_length`) and a final state.
crates/burn-core/src/nn/rnn/lstm.rs
Outdated
    .select(0, Tensor::arange(0..1, &device))
    .squeeze(0);
cell_state
// let (cell_state_batch, hidden_state_batch) = lstm.forward(input, None);
Dead comment, we can remove 🙂
All done.
🚀 🚀
I need a bidirectional LSTM in a CRNN model.
Checklist
The `run-checks all` script has been executed.