Skip to content

Commit

Permalink
Implement bidirectional LSTM (#1035)
Browse files Browse the repository at this point in the history
* resolve conflict

* move `gate_product` to `GateController`

* BiLstm needs to use its own initializer when init

* resolve conflicts

* add some comments

* improve doc

* correct the description of GateController

* fix fmt

* add `LstmState`

* add test for state

* set batch 2 in bilstm test

* resolve conflict

* fix

* fix doc

* change the batch size back to 1

* change the batch size back to 1

* modify docstring; delete dead comment
  • Loading branch information
wcshds authored Apr 26, 2024
1 parent 6ae3926 commit b387829
Show file tree
Hide file tree
Showing 3 changed files with 461 additions and 137 deletions.
2 changes: 1 addition & 1 deletion burn-book/src/building-blocks/module.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ Burn comes with built-in modules that you can use to build your own modules.
| Burn API | PyTorch Equivalent |
|------------------|------------------------|
| `Gru` | `nn.GRU` |
| `Lstm` | `nn.LSTM` |
| `Lstm`/`BiLstm` | `nn.LSTM` |
| `GateController` | _No direct equivalent_ |

### Transformer
Expand Down
26 changes: 19 additions & 7 deletions crates/burn-core/src/nn/rnn/gate_controller.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
use crate as burn;

use crate::module::Module;
use crate::nn::Initializer;
use crate::nn::Linear;
use crate::nn::LinearConfig;
use burn_tensor::backend::Backend;
use crate::nn::{Initializer, Linear, LinearConfig};
use burn_tensor::{backend::Backend, Tensor};

/// A GateController represents a gate in an LSTM cell. An
/// LSTM cell generally contains three gates: an input gate,
/// forget gate, and cell gate.
/// forget gate, and output gate. Additionally, cell gate
/// is just used to compute the cell state.
///
/// An Lstm gate is modeled as two linear transformations.
/// The results of these transformations are used to calculate
/// the gate's output.
#[derive(Module, Debug)]
pub struct GateController<B: Backend> {
/// Represents the affine transformation applied to input vector
pub(crate) input_transform: Linear<B>,
pub input_transform: Linear<B>,
/// Represents the affine transformation applied to the hidden state
pub(crate) hidden_transform: Linear<B>,
pub hidden_transform: Linear<B>,
}

impl<B: Backend> GateController<B> {
Expand Down Expand Up @@ -48,6 +47,19 @@ impl<B: Backend> GateController<B> {
}
}

/// Helper function for performing weighted matrix product for a gate and adds
/// bias, if any.
///
/// Mathematically, performs `Wx*X + Wh*H + b`, where:
/// Wx = weight matrix for the connection to input vector X
/// Wh = weight matrix for the connection to hidden state H
/// X = input vector
/// H = hidden state
/// b = bias terms
pub fn gate_product(&self, input: Tensor<B, 2>, hidden: Tensor<B, 2>) -> Tensor<B, 2> {
self.input_transform.forward(input) + self.hidden_transform.forward(hidden)
}

/// Used to initialize a gate controller with known weight layers,
/// allowing for predictable behavior. Used only for testing in
/// lstm.
Expand Down
Loading

0 comments on commit b387829

Please sign in to comment.