From fc49a970c7d762cc2212ab485eddb3bbc0c63415 Mon Sep 17 00:00:00 2001 From: wcshds Date: Thu, 21 Dec 2023 11:31:28 +0800 Subject: [PATCH 01/17] resolve conflict --- burn-core/src/nn/rnn/lstm.rs | 404 ++++++++++++++++++++++++++++++----- 1 file changed, 352 insertions(+), 52 deletions(-) diff --git a/burn-core/src/nn/rnn/lstm.rs b/burn-core/src/nn/rnn/lstm.rs index 2635c02f06..48c8ff32cb 100644 --- a/burn-core/src/nn/rnn/lstm.rs +++ b/burn-core/src/nn/rnn/lstm.rs @@ -2,7 +2,6 @@ use crate as burn; use crate::config::Config; use crate::module::Module; -use crate::nn::rnn::gate_controller; use crate::nn::Initializer; use crate::nn::LinearConfig; use crate::tensor::backend::Backend; @@ -46,40 +45,21 @@ impl LstmConfig { pub fn init(&self, device: &B::Device) -> Lstm { let d_output = self.d_hidden; - let input_gate = gate_controller::GateController::new( - self.d_input, - d_output, - self.bias, - self.initializer.clone(), - device, - ); - let forget_gate = gate_controller::GateController::new( - self.d_input, - d_output, - self.bias, - self.initializer.clone(), - device, - ); - let output_gate = gate_controller::GateController::new( - self.d_input, - d_output, - self.bias, - self.initializer.clone(), - device, - ); - let cell_gate = gate_controller::GateController::new( - self.d_input, - d_output, - self.bias, - self.initializer.clone(), - device, - ); + let new_gate = || { + GateController::new( + self.d_input, + d_output, + self.bias, + self.initializer.clone(), + device, + ) + }; Lstm { - input_gate, - forget_gate, - output_gate, - cell_gate, + input_gate: new_gate(), + forget_gate: new_gate(), + output_gate: new_gate(), + cell_gate: new_gate(), d_hidden: self.d_hidden, } } @@ -94,19 +74,10 @@ impl LstmConfig { }; Lstm { - input_gate: gate_controller::GateController::new_with( - &linear_config, - record.input_gate, - ), - forget_gate: gate_controller::GateController::new_with( - &linear_config, - record.forget_gate, - ), - output_gate: gate_controller::GateController::new_with( - &linear_config, - record.output_gate, - ), - cell_gate: gate_controller::GateController::new_with(&linear_config, record.cell_gate), + input_gate: GateController::new_with(&linear_config, record.input_gate), + forget_gate: GateController::new_with(&linear_config, record.forget_gate), + output_gate: GateController::new_with(&linear_config, record.output_gate), + cell_gate: GateController::new_with(&linear_config, record.cell_gate), d_hidden: self.d_hidden, } } @@ -132,8 +103,26 @@ impl Lstm { batched_input: Tensor, state: Option<(Tensor, Tensor)>, ) -> (Tensor, Tensor) { - let [batch_size, seq_length, _] = batched_input.shape().dims; - let device = &batched_input.device(); + let device = batched_input.clone().device(); + let [batch_size, seq_length, _] = batched_input.dims(); + + self.forward_iter( + batched_input.iter_dim(1).zip(0..seq_length), + state, + batch_size, + seq_length, + &device, + ) + } + + fn forward_iter, usize)>>( + &self, + input_timestep_iter: I, + state: Option<(Tensor, Tensor)>, + batch_size: usize, + seq_length: usize, + device: &B::Device, + ) -> (Tensor, Tensor) { let mut batched_cell_state = Tensor::zeros([batch_size, seq_length, self.d_hidden], device); let mut batched_hidden_state = Tensor::zeros([batch_size, seq_length, self.d_hidden], device); @@ -146,7 +135,7 @@ impl Lstm { ), }; - for (t, input_t) in batched_input.iter_dim(1).enumerate() { + for (input_t, t) in input_timestep_iter { let input_t = input_t.squeeze(1); // f(orget)g(ate) tensors let biased_fg_input_sum = 
self.gate_product(&input_t, &hidden_state, &self.forget_gate); @@ -226,6 +215,132 @@ impl Lstm { } } +/// The configuration for a [bidirectional lstm](BiLstm) module. +#[derive(Config)] +pub struct BiLstmConfig { + /// The size of the input features. + pub d_input: usize, + /// The size of the hidden state. + pub d_hidden: usize, + /// If a bias should be applied during the BiLstm transformation. + pub bias: bool, + /// BiLstm initializer + #[config(default = "Initializer::XavierNormal{gain:1.0}")] + pub initializer: Initializer, +} + +/// The Lstm module. This implementation is for bidirectional Lstm. +#[derive(Module, Debug)] +pub struct BiLstm { + forward: Lstm, + reverse: Lstm, + d_hidden: usize, +} + +impl BiLstmConfig { + /// Initialize a new [bidirectional LSTM](BiLstm) module on an automatically selected device. + pub fn init_devauto(&self) -> BiLstm { + let device = B::Device::default(); + self.init(&device) + } + + /// Initialize a new [bidirectional LSTM](BiLstm) module. + pub fn init(&self, device: &B::Device) -> BiLstm { + BiLstm { + forward: LstmConfig::new(self.d_input, self.d_hidden, self.bias).init(device), + reverse: LstmConfig::new(self.d_input, self.d_hidden, self.bias).init(device), + d_hidden: self.d_hidden, + } + } + + /// Initialize a new [bidirectional LSTM](BiLstm) module with a [record](BiLstmRecord). + pub fn init_with(&self, record: BiLstmRecord) -> BiLstm { + BiLstm { + forward: LstmConfig::new(self.d_input, self.d_hidden, self.bias) + .init_with(record.forward), + reverse: LstmConfig::new(self.d_input, self.d_hidden, self.bias) + .init_with(record.reverse), + d_hidden: self.d_hidden, + } + } +} + +impl BiLstm { + /// Applies the forward pass on the input tensor. This LSTM implementation + /// returns the cell state and hidden state for each element in a sequence (i.e., across `seq_length`), + /// producing 3-dimensional tensors where the dimensions represent [batch_size, sequence_length, hidden_size * 2]. + /// + /// Parameters: + /// batched_input: The input tensor of shape [batch_size, sequence_length, input_size]. + /// state: An optional tuple of tensors representing the initial cell state and hidden state. + /// Each state tensor has shape [2, batch_size, hidden_size]. + /// If no initial state is provided, these tensors are initialized to zeros. + /// + /// Returns: + /// A tuple of tensors, where the first tensor represents the cell states and + /// the second tensor represents the hidden states for each sequence element. + /// Both output tensors have the shape [batch_size, sequence_length, hidden_size * 2]. 
+ pub fn forward( + &self, + batched_input: Tensor, + state: Option<(Tensor, Tensor)>, + ) -> (Tensor, Tensor) { + let device = batched_input.clone().device(); + let [batch_size, seq_length, _] = batched_input.shape().dims; + + let (cell_state_forward, hidden_state_forward, cell_state_reverse, hidden_state_reverse) = + match state { + Some((cell_state, hidden_state)) => ( + cell_state + .clone() + .slice([0..1, 0..batch_size, 0..self.d_hidden]) + .squeeze(0), + hidden_state + .clone() + .slice([0..1, 0..batch_size, 0..self.d_hidden]) + .squeeze(0), + cell_state + .slice([1..2, 0..batch_size, 0..self.d_hidden]) + .squeeze(0), + hidden_state + .slice([1..2, 0..batch_size, 0..self.d_hidden]) + .squeeze(0), + ), + None => ( + Tensor::zeros([batch_size, self.d_hidden], &device), + Tensor::zeros([batch_size, self.d_hidden], &device), + Tensor::zeros([batch_size, self.d_hidden], &device), + Tensor::zeros([batch_size, self.d_hidden], &device), + ), + }; + + let (batched_cell_state_forward, batched_hidden_state_forward) = self.forward.forward( + batched_input.clone(), + Some((cell_state_forward, hidden_state_forward)), + ); + + // reverse direction + let (batched_cell_state_reverse, batched_hidden_state_reverse) = self.reverse.forward_iter( + batched_input.iter_dim(1).rev().zip((0..seq_length).rev()), + Some((cell_state_reverse, hidden_state_reverse)), + batch_size, + seq_length, + &device, + ); + + let batched_cell_state = Tensor::cat( + [batched_cell_state_forward, batched_cell_state_reverse].to_vec(), + 2, + ); + let batched_hidden_state = Tensor::cat( + [batched_hidden_state_forward, batched_hidden_state_reverse].to_vec(), + 2, + ); + + (batched_cell_state, batched_hidden_state) + } +} + #[cfg(test)] mod tests { use super::*; @@ -263,7 +378,7 @@ mod tests { TestBackend::seed(0); let config = LstmConfig::new(1, 1, false); let device = Default::default(); - let mut lstm = config.init::(&device); + let mut lstm = config.init_devauto::(); fn create_gate_controller( weights: f32, @@ -278,7 +393,7 @@ mod tests { weight: Param::from(Tensor::from_data(Data::from([[weights]]), device)), bias: Some(Param::from(Tensor::from_data(Data::from([biases]), device))), }; - gate_controller::GateController::create_with_weights( + GateController::create_with_weights( d_input, d_output, bias, @@ -346,7 +461,7 @@ mod tests { #[test] fn test_batched_forward_pass() { let device = Default::default(); - let lstm = LstmConfig::new(64, 1024, true).init(&device); + let lstm = LstmConfig::new(64, 1024, true).init::(&device); let batched_input = Tensor::::random([8, 10, 64], Distribution::Default, &device); @@ -355,4 +470,189 @@ mod tests { assert_eq!(cell_state.shape().dims, [8, 10, 1024]); assert_eq!(hidden_state.shape().dims, [8, 10, 1024]); } + + #[test] + fn test_bidirectional() { + TestBackend::seed(0); + let config = BiLstmConfig::new(2, 4, true); + let device = Default::default(); + let mut lstm = config.init::(&device); + + fn create_gate_controller( + input_weights: [[f32; D1]; D2], + input_biases: [f32; D1], + hidden_weights: [[f32; D1]; D1], + hidden_biases: [f32; D1], + device: &::Device, + ) -> GateController { + let d_input = input_weights[0].len(); + let d_output = input_weights.len(); + + let input_record = LinearRecord { + weight: Param::from(Tensor::from_data(Data::from(input_weights), device)), + bias: Some(Param::from(Tensor::from_data( + Data::from(input_biases), + device, + ))), + }; + let hidden_record = LinearRecord { + weight: Param::from(Tensor::from_data(Data::from(hidden_weights), device)), + 
bias: Some(Param::from(Tensor::from_data( + Data::from(hidden_biases), + device, + ))), + }; + GateController::create_with_weights( + d_input, + d_output, + true, + Initializer::XavierUniform { gain: 1.0 }, + input_record, + hidden_record, + ) + } + + let input = Tensor::::from_data( + Data::from([[[-0.131, -1.591], [1.378, -1.867], [0.397, 0.047]]]), + &device, + ); + + lstm.forward.input_gate = create_gate_controller( + [[0.078, 0.234, 0.398, 0.333], [0.452, 0.124, -0.042, -0.152]], + [0.196, 0.094, -0.270, 0.008], + [ + [0.054, 0.057, 0.282, 0.021], + [0.065, -0.303, -0.499, 0.069], + [-0.007, 0.226, -0.131, -0.307], + [-0.025, 0.072, 0.197, 0.129], + ], + [0.278, -0.211, 0.435, -0.162], + &device, + ); + + lstm.forward.forget_gate = create_gate_controller( + [ + [-0.187, -0.201, 0.078, -0.314], + [0.169, 0.229, 0.218, 0.466], + ], + [0.320, -0.135, -0.301, 0.180], + [ + [0.392, -0.028, 0.470, -0.025], + [-0.284, -0.286, -0.211, -0.001], + [0.245, -0.259, 0.102, -0.379], + [-0.096, -0.462, 0.170, 0.232], + ], + [0.458, 0.039, 0.287, -0.327], + &device, + ); + + lstm.forward.cell_gate = create_gate_controller( + [ + [-0.216, 0.256, 0.369, 0.160], + [0.453, -0.238, 0.306, -0.411], + ], + [0.360, 0.001, 0.303, 0.438], + [ + [0.356, -0.185, 0.494, 0.325], + [0.111, -0.388, 0.051, -0.150], + [-0.434, 0.296, -0.185, 0.290], + [-0.010, -0.023, 0.460, 0.238], + ], + [0.268, -0.136, -0.452, 0.471], + &device, + ); + + lstm.forward.output_gate = create_gate_controller( + [[0.235, -0.132, 0.049, 0.157], [-0.280, 0.229, 0.102, 0.448]], + [0.237, -0.396, -0.134, -0.047], + [ + [-0.243, 0.196, 0.087, 0.163], + [0.138, -0.247, -0.401, -0.462], + [0.030, -0.263, 0.473, 0.259], + [-0.413, -0.173, -0.206, 0.324], + ], + [-0.364, -0.023, 0.215, -0.401], + &device, + ); + + lstm.reverse.input_gate = create_gate_controller( + [ + [0.220, -0.191, 0.062, -0.443], + [-0.112, -0.353, -0.443, 0.080], + ], + [-0.418, 0.209, 0.297, -0.429], + [ + [-0.121, -0.408, 0.132, -0.450], + [0.231, 0.154, -0.294, 0.022], + [0.378, 0.239, 0.176, -0.361], + [0.480, 0.427, -0.156, -0.137], + ], + [0.267, -0.474, -0.393, 0.190], + &device, + ); + + lstm.reverse.forget_gate = create_gate_controller( + [ + [0.151, 0.148, 0.341, -0.112], + [-0.368, -0.476, 0.003, 0.083], + ], + [-0.489, -0.361, -0.035, 0.328], + [ + [0.460, -0.124, -0.377, -0.033], + [-0.296, 0.162, 0.456, -0.271], + [0.320, 0.235, 0.383, 0.423], + [-0.167, 0.332, -0.493, 0.086], + ], + [-0.425, 0.219, 0.294, -0.075], + &device, + ); + + lstm.reverse.cell_gate = create_gate_controller( + [ + [-0.451, 0.285, 0.305, -0.344], + [-0.399, 0.344, -0.022, 0.263], + ], + [0.215, -0.028, 0.097, 0.197], + [ + [0.072, 0.106, -0.030, 0.056], + [-0.278, -0.256, -0.129, -0.252], + [-0.305, 0.219, 0.045, -0.123], + [0.224, 0.011, -0.199, -0.362], + ], + [0.086, 0.466, -0.152, 0.353], + &device, + ); + + lstm.reverse.output_gate = create_gate_controller( + [ + [0.057, -0.357, 0.031, 0.235], + [-0.492, -0.109, -0.316, -0.422], + ], + [0.233, 0.053, 0.162, -0.465], + [ + [0.240, 0.223, -0.188, -0.181], + [-0.427, -0.390, -0.176, -0.338], + [-0.158, 0.152, -0.105, 0.106], + [-0.223, -0.186, -0.059, 0.319], + ], + [0.207, 0.295, 0.361, 0.029], + &device, + ); + + let expected_result = Data::from([[ + [ + -0.01604, 0.02718, -0.14959, 0.10219, 0.34534, 0.06087, 0.07809, 0.01806, + ], + [ + -0.13098, 0.07478, -0.10684, 0.15549, 0.19981, 0.12038, 0.19815, -0.02509, + ], + [ + 0.09250, 0.03285, -0.04502, 0.24134, 0.03017, 0.11454, 0.01943, 0.06517, + ], + ]]); + + let (_, hidden_state) = 
lstm.forward(input, None); + + hidden_state.to_data().assert_approx_eq(&expected_result, 3) + } } From eb92f78bfa8a98089ef12fe2d3f257ce61ec2387 Mon Sep 17 00:00:00 2001 From: wcshds Date: Fri, 12 Jan 2024 18:59:07 +0800 Subject: [PATCH 02/17] move `gate_product` to `GateController` --- burn-core/src/nn/rnn/gate_controller.rs | 14 +++++ burn-core/src/nn/rnn/lstm.rs | 75 ++++++------------------- 2 files changed, 32 insertions(+), 57 deletions(-) diff --git a/burn-core/src/nn/rnn/gate_controller.rs b/burn-core/src/nn/rnn/gate_controller.rs index b5cc373bf8..1c9e2c7d44 100644 --- a/burn-core/src/nn/rnn/gate_controller.rs +++ b/burn-core/src/nn/rnn/gate_controller.rs @@ -5,6 +5,7 @@ use crate::nn::Initializer; use crate::nn::Linear; use crate::nn::LinearConfig; use burn_tensor::backend::Backend; +use burn_tensor::Tensor; /// A GateController represents a gate in an LSTM cell. An /// LSTM cell generally contains three gates: an input gate, @@ -59,6 +60,19 @@ impl GateController { } } + /// Helper function for performing weighted matrix product for a gate and adds + /// bias, if any. + /// + /// Mathematically, performs `Wx*X + Wh*H + b`, where: + /// Wx = weight matrix for the connection to input vector X + /// Wh = weight matrix for the connection to hidden state H + /// X = input vector + /// H = hidden state + /// b = bias terms + pub fn gate_product(&self, input: Tensor, hidden: Tensor) -> Tensor { + self.input_transform.forward(input) + self.hidden_transform.forward(hidden) + } + /// Used to initialize a gate controller with known weight layers, /// allowing for predictable behavior. Used only for testing in /// lstm. diff --git a/burn-core/src/nn/rnn/lstm.rs b/burn-core/src/nn/rnn/lstm.rs index 03716d9429..f5a89c5e9c 100644 --- a/burn-core/src/nn/rnn/lstm.rs +++ b/burn-core/src/nn/rnn/lstm.rs @@ -132,28 +132,34 @@ impl Lstm { for (input_t, t) in input_timestep_iter { let input_t = input_t.squeeze(1); // f(orget)g(ate) tensors - let biased_fg_input_sum = self.gate_product(&input_t, &hidden_state, &self.forget_gate); + let biased_fg_input_sum = self + .forget_gate + .gate_product(input_t.clone(), hidden_state.clone()); let forget_values = activation::sigmoid(biased_fg_input_sum); // to multiply with cell state // i(nput)g(ate) tensors - let biased_ig_input_sum = self.gate_product(&input_t, &hidden_state, &self.input_gate); + let biased_ig_input_sum = self + .input_gate + .gate_product(input_t.clone(), hidden_state.clone()); let add_values = activation::sigmoid(biased_ig_input_sum); // o(output)g(ate) tensors - let biased_og_input_sum = self.gate_product(&input_t, &hidden_state, &self.output_gate); + let biased_og_input_sum = self + .output_gate + .gate_product(input_t.clone(), hidden_state.clone()); let output_values = activation::sigmoid(biased_og_input_sum); // c(ell)g(ate) tensors - let biased_cg_input_sum = self.gate_product(&input_t, &hidden_state, &self.cell_gate); + let biased_cg_input_sum = self + .cell_gate + .gate_product(input_t.clone(), hidden_state.clone()); let candidate_cell_values = biased_cg_input_sum.tanh(); cell_state = forget_values * cell_state.clone() + add_values * candidate_cell_values; hidden_state = output_values * cell_state.clone().tanh(); - let unsqueezed_shape = [cell_state.shape().dims[0], 1, cell_state.shape().dims[1]]; - - let unsqueezed_cell_state = cell_state.clone().reshape(unsqueezed_shape); - let unsqueezed_hidden_state = hidden_state.clone().reshape(unsqueezed_shape); + let unsqueezed_cell_state = cell_state.clone().unsqueeze_dim(1); + let 
unsqueezed_hidden_state = hidden_state.clone().unsqueeze_dim(1); // store the state for this timestep batched_cell_state = batched_cell_state.slice_assign( @@ -168,45 +174,6 @@ impl Lstm { (batched_cell_state, batched_hidden_state) } - - /// Helper function for performing weighted matrix product for a gate and adds - /// bias, if any. - /// - /// Mathematically, performs `Wx*X + Wh*H + b`, where: - /// Wx = weight matrix for the connection to input vector X - /// Wh = weight matrix for the connection to hidden state H - /// X = input vector - /// H = hidden state - /// b = bias terms - fn gate_product( - &self, - input: &Tensor, - hidden: &Tensor, - gate: &GateController, - ) -> Tensor { - let input_product = input.clone().matmul(gate.input_transform.weight.val()); - let hidden_product = hidden.clone().matmul(gate.hidden_transform.weight.val()); - - let input_bias = gate - .input_transform - .bias - .as_ref() - .map(|bias_param| bias_param.val()); - let hidden_bias = gate - .hidden_transform - .bias - .as_ref() - .map(|bias_param| bias_param.val()); - - match (input_bias, hidden_bias) { - (Some(input_bias), Some(hidden_bias)) => { - input_product + input_bias.unsqueeze() + hidden_product + hidden_bias.unsqueeze() - } - (Some(input_bias), None) => input_product + input_bias.unsqueeze() + hidden_product, - (None, Some(hidden_bias)) => input_product + hidden_product + hidden_bias.unsqueeze(), - (None, None) => input_product + hidden_product, - } - } } /// The configuration for a [bidirectional lstm](BiLstm) module. @@ -232,12 +199,6 @@ pub struct BiLstm { } impl BiLstmConfig { - /// Initialize a new [bidirectional LSTM](BiLstm) module on an automatically selected device. - pub fn init_devauto(&self) -> BiLstm { - let device = B::Device::default(); - self.init(&device) - } - /// Initialize a new [bidirectional LSTM](BiLstm) module. 
pub fn init(&self, device: &B::Device) -> BiLstm { BiLstm { @@ -347,7 +308,7 @@ mod tests { let config = LstmConfig::new(5, 5, false) .with_initializer(Initializer::Uniform { min: 0.0, max: 1.0 }); - let lstm = config.init::(&Default::default()); + let lstm = config.init(&Default::default()); let gate_to_data = |gate: GateController| gate.input_transform.weight.val().to_data(); @@ -372,7 +333,7 @@ mod tests { TestBackend::seed(0); let config = LstmConfig::new(1, 1, false); let device = Default::default(); - let mut lstm = config.init_devauto::(); + let mut lstm = config.init(&device); fn create_gate_controller( weights: f32, @@ -455,7 +416,7 @@ mod tests { #[test] fn test_batched_forward_pass() { let device = Default::default(); - let lstm = LstmConfig::new(64, 1024, true).init::(&device); + let lstm = LstmConfig::new(64, 1024, true).init(&device); let batched_input = Tensor::::random([8, 10, 64], Distribution::Default, &device); @@ -470,7 +431,7 @@ mod tests { TestBackend::seed(0); let config = BiLstmConfig::new(2, 4, true); let device = Default::default(); - let mut lstm = config.init::(&device); + let mut lstm = config.init(&device); fn create_gate_controller( input_weights: [[f32; D1]; D2], From 2f8debf0c8b9a815bffa9d2efa4e4bc898532a70 Mon Sep 17 00:00:00 2001 From: wcshds Date: Sat, 13 Jan 2024 02:04:59 +0800 Subject: [PATCH 03/17] BiLstm needs to use its own initializer when init --- burn-core/src/nn/rnn/lstm.rs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/burn-core/src/nn/rnn/lstm.rs b/burn-core/src/nn/rnn/lstm.rs index f5a89c5e9c..94cc49b582 100644 --- a/burn-core/src/nn/rnn/lstm.rs +++ b/burn-core/src/nn/rnn/lstm.rs @@ -202,8 +202,12 @@ impl BiLstmConfig { /// Initialize a new [bidirectional LSTM](BiLstm) module. 
pub fn init(&self, device: &B::Device) -> BiLstm { BiLstm { - forward: LstmConfig::new(self.d_input, self.d_hidden, self.bias).init(device), - reverse: LstmConfig::new(self.d_input, self.d_hidden, self.bias).init(device), + forward: LstmConfig::new(self.d_input, self.d_hidden, self.bias) + .with_initializer(self.initializer.clone()) + .init(device), + reverse: LstmConfig::new(self.d_input, self.d_hidden, self.bias) + .with_initializer(self.initializer.clone()) + .init(device), d_hidden: self.d_hidden, } } @@ -212,8 +216,10 @@ impl BiLstmConfig { pub fn init_with(&self, record: BiLstmRecord) -> BiLstm { BiLstm { forward: LstmConfig::new(self.d_input, self.d_hidden, self.bias) + .with_initializer(self.initializer.clone()) .init_with(record.forward), reverse: LstmConfig::new(self.d_input, self.d_hidden, self.bias) + .with_initializer(self.initializer.clone()) .init_with(record.reverse), d_hidden: self.d_hidden, } From 0abe2a06db283a7954e4fdc9f621d604040ffad4 Mon Sep 17 00:00:00 2001 From: wcshds Date: Tue, 16 Apr 2024 16:58:48 +0800 Subject: [PATCH 04/17] resolve conflicts --- .../burn-core/src/nn/rnn/gate_controller.rs | 19 +- crates/burn-core/src/nn/rnn/lstm.rs | 425 ++++++++++++++---- 2 files changed, 356 insertions(+), 88 deletions(-) diff --git a/crates/burn-core/src/nn/rnn/gate_controller.rs b/crates/burn-core/src/nn/rnn/gate_controller.rs index 9d1c696e49..1655f87eae 100644 --- a/crates/burn-core/src/nn/rnn/gate_controller.rs +++ b/crates/burn-core/src/nn/rnn/gate_controller.rs @@ -1,10 +1,8 @@ use crate as burn; use crate::module::Module; -use crate::nn::Initializer; -use crate::nn::Linear; -use crate::nn::LinearConfig; -use burn_tensor::backend::Backend; +use crate::nn::{Initializer, Linear, LinearConfig}; +use burn_tensor::{backend::Backend, Tensor}; /// A GateController represents a gate in an LSTM cell. An /// LSTM cell generally contains three gates: an input gate, @@ -48,6 +46,19 @@ impl GateController { } } + /// Helper function for performing weighted matrix product for a gate and adds + /// bias, if any. + /// + /// Mathematically, performs `Wx*X + Wh*H + b`, where: + /// Wx = weight matrix for the connection to input vector X + /// Wh = weight matrix for the connection to hidden state H + /// X = input vector + /// H = hidden state + /// b = bias terms + pub fn gate_product(&self, input: Tensor, hidden: Tensor) -> Tensor { + self.input_transform.forward(input) + self.hidden_transform.forward(hidden) + } + /// Used to initialize a gate controller with known weight layers, /// allowing for predictable behavior. Used only for testing in /// lstm. diff --git a/crates/burn-core/src/nn/rnn/lstm.rs b/crates/burn-core/src/nn/rnn/lstm.rs index 1c6f874131..599710c928 100644 --- a/crates/burn-core/src/nn/rnn/lstm.rs +++ b/crates/burn-core/src/nn/rnn/lstm.rs @@ -2,14 +2,12 @@ use crate as burn; use crate::config::Config; use crate::module::Module; -use crate::nn::rnn::gate_controller; +use crate::nn::rnn::gate_controller::GateController; use crate::nn::Initializer; use crate::tensor::backend::Backend; use crate::tensor::Tensor; use burn_tensor::activation; -use super::gate_controller::GateController; - /// The configuration for a [lstm](Lstm) module. #[derive(Config)] pub struct LstmConfig { @@ -27,10 +25,14 @@ pub struct LstmConfig { /// The Lstm module. This implementation is for a unidirectional, stateless, Lstm. 
#[derive(Module, Debug)] pub struct Lstm { - input_gate: GateController, - forget_gate: GateController, - output_gate: GateController, - cell_gate: GateController, + /// input gate + pub input_gate: GateController, + /// forget gate + pub forget_gate: GateController, + /// output gate + pub output_gate: GateController, + /// cell gate + pub cell_gate: GateController, d_hidden: usize, } @@ -39,40 +41,21 @@ impl LstmConfig { pub fn init(&self, device: &B::Device) -> Lstm { let d_output = self.d_hidden; - let input_gate = gate_controller::GateController::new( - self.d_input, - d_output, - self.bias, - self.initializer.clone(), - device, - ); - let forget_gate = gate_controller::GateController::new( - self.d_input, - d_output, - self.bias, - self.initializer.clone(), - device, - ); - let output_gate = gate_controller::GateController::new( - self.d_input, - d_output, - self.bias, - self.initializer.clone(), - device, - ); - let cell_gate = gate_controller::GateController::new( - self.d_input, - d_output, - self.bias, - self.initializer.clone(), - device, - ); + let new_gate = || { + GateController::new( + self.d_input, + d_output, + self.bias, + self.initializer.clone(), + device, + ) + }; Lstm { - input_gate, - forget_gate, - output_gate, - cell_gate, + input_gate: new_gate(), + forget_gate: new_gate(), + output_gate: new_gate(), + cell_gate: new_gate(), d_hidden: self.d_hidden, } } @@ -98,8 +81,26 @@ impl Lstm { batched_input: Tensor, state: Option<(Tensor, Tensor)>, ) -> (Tensor, Tensor) { - let [batch_size, seq_length, _] = batched_input.shape().dims; - let device = &batched_input.device(); + let device = batched_input.device(); + let [batch_size, seq_length, _] = batched_input.dims(); + + self.forward_iter( + batched_input.iter_dim(1).zip(0..seq_length), + state, + batch_size, + seq_length, + &device, + ) + } + + fn forward_iter, usize)>>( + &self, + input_timestep_iter: I, + state: Option<(Tensor, Tensor)>, + batch_size: usize, + seq_length: usize, + device: &B::Device, + ) -> (Tensor, Tensor) { let mut batched_cell_state = Tensor::zeros([batch_size, seq_length, self.d_hidden], device); let mut batched_hidden_state = Tensor::zeros([batch_size, seq_length, self.d_hidden], device); @@ -112,31 +113,37 @@ impl Lstm { ), }; - for (t, input_t) in batched_input.iter_dim(1).enumerate() { + for (input_t, t) in input_timestep_iter { let input_t = input_t.squeeze(1); // f(orget)g(ate) tensors - let biased_fg_input_sum = self.gate_product(&input_t, &hidden_state, &self.forget_gate); + let biased_fg_input_sum = self + .forget_gate + .gate_product(input_t.clone(), hidden_state.clone()); let forget_values = activation::sigmoid(biased_fg_input_sum); // to multiply with cell state // i(nput)g(ate) tensors - let biased_ig_input_sum = self.gate_product(&input_t, &hidden_state, &self.input_gate); + let biased_ig_input_sum = self + .input_gate + .gate_product(input_t.clone(), hidden_state.clone()); let add_values = activation::sigmoid(biased_ig_input_sum); // o(output)g(ate) tensors - let biased_og_input_sum = self.gate_product(&input_t, &hidden_state, &self.output_gate); + let biased_og_input_sum = self + .output_gate + .gate_product(input_t.clone(), hidden_state.clone()); let output_values = activation::sigmoid(biased_og_input_sum); // c(ell)g(ate) tensors - let biased_cg_input_sum = self.gate_product(&input_t, &hidden_state, &self.cell_gate); + let biased_cg_input_sum = self + .cell_gate + .gate_product(input_t.clone(), hidden_state.clone()); let candidate_cell_values = biased_cg_input_sum.tanh(); 
cell_state = forget_values * cell_state.clone() + add_values * candidate_cell_values; hidden_state = output_values * cell_state.clone().tanh(); - let unsqueezed_shape = [cell_state.shape().dims[0], 1, cell_state.shape().dims[1]]; - - let unsqueezed_cell_state = cell_state.clone().reshape(unsqueezed_shape); - let unsqueezed_hidden_state = hidden_state.clone().reshape(unsqueezed_shape); + let unsqueezed_cell_state = cell_state.clone().unsqueeze_dim(1); + let unsqueezed_hidden_state = hidden_state.clone().unsqueeze_dim(1); // store the state for this timestep batched_cell_state = batched_cell_state.slice_assign( @@ -151,44 +158,115 @@ impl Lstm { (batched_cell_state, batched_hidden_state) } +} + +/// The configuration for a [Bidirectional LSTM](BiLstm) module. +#[derive(Config)] +pub struct BiLstmConfig { + /// The size of the input features. + pub d_input: usize, + /// The size of the hidden state. + pub d_hidden: usize, + /// If a bias should be applied during the BiLstm transformation. + pub bias: bool, + /// BiLstm initializer + #[config(default = "Initializer::XavierNormal{gain:1.0}")] + pub initializer: Initializer, +} + +/// The BiLstm module. This implementation is for Bidirectional LSTM. +#[derive(Module, Debug)] +pub struct BiLstm { + forward: Lstm, + reverse: Lstm, + d_hidden: usize, +} + +impl BiLstmConfig { + /// Initialize a new [Bidirectional LSTM](BiLstm) module. + pub fn init(&self, device: &B::Device) -> BiLstm { + BiLstm { + forward: LstmConfig::new(self.d_input, self.d_hidden, self.bias) + .with_initializer(self.initializer.clone()) + .init(device), + reverse: LstmConfig::new(self.d_input, self.d_hidden, self.bias) + .with_initializer(self.initializer.clone()) + .init(device), + d_hidden: self.d_hidden, + } + } +} - /// Helper function for performing weighted matrix product for a gate and adds - /// bias, if any. +impl BiLstm { + /// Applies the forward pass on the input tensor. This Bidirectional LSTM implementation + /// returns the cell state and hidden state for each element in a sequence (i.e., across `seq_length`), + /// producing 3-dimensional tensors where the dimensions represent [batch_size, sequence_length, hidden_size * 2]. + /// + /// Parameters: + /// batched_input: The input tensor of shape [batch_size, sequence_length, input_size]. + /// state: An optional tuple of tensors representing the initial cell state and hidden state. + /// Each state tensor has shape [2, batch_size, hidden_size]. + /// If no initial state is provided, these tensors are initialized to zeros. /// - /// Mathematically, performs `Wx*X + Wh*H + b`, where: - /// Wx = weight matrix for the connection to input vector X - /// Wh = weight matrix for the connection to hidden state H - /// X = input vector - /// H = hidden state - /// b = bias terms - fn gate_product( + /// Returns: + /// A tuple of tensors, where the first tensor represents the cell states and + /// the second tensor represents the hidden states for each sequence element. + /// Both output tensors have the shape [batch_size, sequence_length, hidden_size * 2]. 
+ pub fn forward( &self, - input: &Tensor, - hidden: &Tensor, - gate: &GateController, - ) -> Tensor { - let input_product = input.clone().matmul(gate.input_transform.weight.val()); - let hidden_product = hidden.clone().matmul(gate.hidden_transform.weight.val()); - - let input_bias = gate - .input_transform - .bias - .as_ref() - .map(|bias_param| bias_param.val()); - let hidden_bias = gate - .hidden_transform - .bias - .as_ref() - .map(|bias_param| bias_param.val()); + batched_input: Tensor, + state: Option<(Tensor, Tensor)>, + ) -> (Tensor, Tensor) { + let device = batched_input.clone().device(); + let [batch_size, seq_length, _] = batched_input.shape().dims; - match (input_bias, hidden_bias) { - (Some(input_bias), Some(hidden_bias)) => { - input_product + input_bias.unsqueeze() + hidden_product + hidden_bias.unsqueeze() + let [state_forward, state_reverse] = match state { + Some((cell_state, hidden_state)) => { + let cell_state_forward = cell_state + .clone() + .slice([0..1, 0..batch_size, 0..self.d_hidden]) + .squeeze(0); + let hidden_state_forward = hidden_state + .clone() + .slice([0..1, 0..batch_size, 0..self.d_hidden]) + .squeeze(0); + let cell_state_reverse = cell_state + .slice([1..2, 0..batch_size, 0..self.d_hidden]) + .squeeze(0); + let hidden_state_reverse = hidden_state + .slice([1..2, 0..batch_size, 0..self.d_hidden]) + .squeeze(0); + + [ + Some((cell_state_forward, hidden_state_forward)), + Some((cell_state_reverse, hidden_state_reverse)), + ] } - (Some(input_bias), None) => input_product + input_bias.unsqueeze() + hidden_product, - (None, Some(hidden_bias)) => input_product + hidden_product + hidden_bias.unsqueeze(), - (None, None) => input_product + hidden_product, - } + None => [None, None], + }; + + let (batched_cell_state_forward, batched_hidden_state_forward) = + self.forward.forward(batched_input.clone(), state_forward); + + // reverse direction + let (batched_cell_state_reverse, batched_hidden_state_reverse) = self.reverse.forward_iter( + batched_input.iter_dim(1).rev().zip((0..seq_length).rev()), + state_reverse, + batch_size, + seq_length, + &device, + ); + + let batched_cell_state = Tensor::cat( + [batched_cell_state_forward, batched_cell_state_reverse].to_vec(), + 2, + ); + let batched_hidden_state = Tensor::cat( + [batched_hidden_state_forward, batched_hidden_state_reverse].to_vec(), + 2, + ); + + (batched_cell_state, batched_hidden_state) } } @@ -251,7 +329,7 @@ mod tests { weight: Param::from_data(Data::from([[weights]]), device), bias: Some(Param::from_data(Data::from([biases]), device)), }; - gate_controller::GateController::create_with_weights( + GateController::create_with_weights( d_input, d_output, bias, @@ -353,4 +431,183 @@ mod tests { // Asserts that the gradients exist and are non-zero assert!(*some_gradient.any().into_data().value.first().unwrap()); } + + #[test] + fn test_bidirectional() { + TestBackend::seed(0); + let config = BiLstmConfig::new(2, 4, true); + let device = Default::default(); + let mut lstm = config.init(&device); + + fn create_gate_controller( + input_weights: [[f32; D1]; D2], + input_biases: [f32; D1], + hidden_weights: [[f32; D1]; D1], + hidden_biases: [f32; D1], + device: &::Device, + ) -> GateController { + let d_input = input_weights[0].len(); + let d_output = input_weights.len(); + + let input_record = LinearRecord { + weight: Param::from_data(Data::from(input_weights), device), + bias: Some(Param::from_data(Data::from(input_biases), device)), + }; + let hidden_record = LinearRecord { + weight: 
Param::from_data(Data::from(hidden_weights), device), + bias: Some(Param::from_data(Data::from(hidden_biases), device)), + }; + GateController::create_with_weights( + d_input, + d_output, + true, + Initializer::XavierUniform { gain: 1.0 }, + input_record, + hidden_record, + ) + } + + let input = Tensor::::from_data( + Data::from([[[-0.131, -1.591], [1.378, -1.867], [0.397, 0.047]]]), + &device, + ); + + lstm.forward.input_gate = create_gate_controller( + [[0.078, 0.234, 0.398, 0.333], [0.452, 0.124, -0.042, -0.152]], + [0.196, 0.094, -0.270, 0.008], + [ + [0.054, 0.057, 0.282, 0.021], + [0.065, -0.303, -0.499, 0.069], + [-0.007, 0.226, -0.131, -0.307], + [-0.025, 0.072, 0.197, 0.129], + ], + [0.278, -0.211, 0.435, -0.162], + &device, + ); + + lstm.forward.forget_gate = create_gate_controller( + [ + [-0.187, -0.201, 0.078, -0.314], + [0.169, 0.229, 0.218, 0.466], + ], + [0.320, -0.135, -0.301, 0.180], + [ + [0.392, -0.028, 0.470, -0.025], + [-0.284, -0.286, -0.211, -0.001], + [0.245, -0.259, 0.102, -0.379], + [-0.096, -0.462, 0.170, 0.232], + ], + [0.458, 0.039, 0.287, -0.327], + &device, + ); + + lstm.forward.cell_gate = create_gate_controller( + [ + [-0.216, 0.256, 0.369, 0.160], + [0.453, -0.238, 0.306, -0.411], + ], + [0.360, 0.001, 0.303, 0.438], + [ + [0.356, -0.185, 0.494, 0.325], + [0.111, -0.388, 0.051, -0.150], + [-0.434, 0.296, -0.185, 0.290], + [-0.010, -0.023, 0.460, 0.238], + ], + [0.268, -0.136, -0.452, 0.471], + &device, + ); + + lstm.forward.output_gate = create_gate_controller( + [[0.235, -0.132, 0.049, 0.157], [-0.280, 0.229, 0.102, 0.448]], + [0.237, -0.396, -0.134, -0.047], + [ + [-0.243, 0.196, 0.087, 0.163], + [0.138, -0.247, -0.401, -0.462], + [0.030, -0.263, 0.473, 0.259], + [-0.413, -0.173, -0.206, 0.324], + ], + [-0.364, -0.023, 0.215, -0.401], + &device, + ); + + lstm.reverse.input_gate = create_gate_controller( + [ + [0.220, -0.191, 0.062, -0.443], + [-0.112, -0.353, -0.443, 0.080], + ], + [-0.418, 0.209, 0.297, -0.429], + [ + [-0.121, -0.408, 0.132, -0.450], + [0.231, 0.154, -0.294, 0.022], + [0.378, 0.239, 0.176, -0.361], + [0.480, 0.427, -0.156, -0.137], + ], + [0.267, -0.474, -0.393, 0.190], + &device, + ); + + lstm.reverse.forget_gate = create_gate_controller( + [ + [0.151, 0.148, 0.341, -0.112], + [-0.368, -0.476, 0.003, 0.083], + ], + [-0.489, -0.361, -0.035, 0.328], + [ + [0.460, -0.124, -0.377, -0.033], + [-0.296, 0.162, 0.456, -0.271], + [0.320, 0.235, 0.383, 0.423], + [-0.167, 0.332, -0.493, 0.086], + ], + [-0.425, 0.219, 0.294, -0.075], + &device, + ); + + lstm.reverse.cell_gate = create_gate_controller( + [ + [-0.451, 0.285, 0.305, -0.344], + [-0.399, 0.344, -0.022, 0.263], + ], + [0.215, -0.028, 0.097, 0.197], + [ + [0.072, 0.106, -0.030, 0.056], + [-0.278, -0.256, -0.129, -0.252], + [-0.305, 0.219, 0.045, -0.123], + [0.224, 0.011, -0.199, -0.362], + ], + [0.086, 0.466, -0.152, 0.353], + &device, + ); + + lstm.reverse.output_gate = create_gate_controller( + [ + [0.057, -0.357, 0.031, 0.235], + [-0.492, -0.109, -0.316, -0.422], + ], + [0.233, 0.053, 0.162, -0.465], + [ + [0.240, 0.223, -0.188, -0.181], + [-0.427, -0.390, -0.176, -0.338], + [-0.158, 0.152, -0.105, 0.106], + [-0.223, -0.186, -0.059, 0.319], + ], + [0.207, 0.295, 0.361, 0.029], + &device, + ); + + let expected_result = Data::from([[ + [ + -0.01604, 0.02718, -0.14959, 0.10219, 0.34534, 0.06087, 0.07809, 0.01806, + ], + [ + -0.13098, 0.07478, -0.10684, 0.15549, 0.19981, 0.12038, 0.19815, -0.02509, + ], + [ + 0.09250, 0.03285, -0.04502, 0.24134, 0.03017, 0.11454, 0.01943, 0.06517, + ], + 
]]); + + let (_, hidden_state) = lstm.forward(input, None); + + hidden_state.to_data().assert_approx_eq(&expected_result, 3) + } } From 102e2a87c288dce83db3f4384db837489d530eb5 Mon Sep 17 00:00:00 2001 From: wcshds Date: Tue, 16 Apr 2024 19:48:07 +0800 Subject: [PATCH 05/17] add some comments --- crates/burn-core/src/nn/rnn/gate_controller.rs | 4 ++-- crates/burn-core/src/nn/rnn/lstm.rs | 15 +++++++++------ 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/crates/burn-core/src/nn/rnn/gate_controller.rs b/crates/burn-core/src/nn/rnn/gate_controller.rs index 1655f87eae..3ddc04e12c 100644 --- a/crates/burn-core/src/nn/rnn/gate_controller.rs +++ b/crates/burn-core/src/nn/rnn/gate_controller.rs @@ -14,9 +14,9 @@ use burn_tensor::{backend::Backend, Tensor}; #[derive(Module, Debug)] pub struct GateController { /// Represents the affine transformation applied to input vector - pub(crate) input_transform: Linear, + pub input_transform: Linear, /// Represents the affine transformation applied to the hidden state - pub(crate) hidden_transform: Linear, + pub hidden_transform: Linear, } impl GateController { diff --git a/crates/burn-core/src/nn/rnn/lstm.rs b/crates/burn-core/src/nn/rnn/lstm.rs index 599710c928..d7f745ac53 100644 --- a/crates/burn-core/src/nn/rnn/lstm.rs +++ b/crates/burn-core/src/nn/rnn/lstm.rs @@ -25,13 +25,13 @@ pub struct LstmConfig { /// The Lstm module. This implementation is for a unidirectional, stateless, Lstm. #[derive(Module, Debug)] pub struct Lstm { - /// input gate + /// The input gate regulates which information to update and store in the memory cell at each time step. pub input_gate: GateController, - /// forget gate + /// The forget gate is used to control which information to discard or keep in the memory cell at each time step. pub forget_gate: GateController, - /// output gate + /// The output gate determines which information from the memory cell to output at each time step. pub output_gate: GateController, - /// cell gate + /// The cell gate is used to compute the cell state that stores and carries information through time. pub cell_gate: GateController, d_hidden: usize, } @@ -177,8 +177,10 @@ pub struct BiLstmConfig { /// The BiLstm module. This implementation is for Bidirectional LSTM. #[derive(Module, Debug)] pub struct BiLstm { - forward: Lstm, - reverse: Lstm, + /// LSTM for the forward direction + pub forward: Lstm, + /// LSTM for the reverse direction + pub reverse: Lstm, d_hidden: usize, } @@ -245,6 +247,7 @@ impl BiLstm { None => [None, None], }; + // forward direction let (batched_cell_state_forward, batched_hidden_state_forward) = self.forward.forward(batched_input.clone(), state_forward); From e09d36003e754b810047676d5ec114b9574a82a4 Mon Sep 17 00:00:00 2001 From: wcshds Date: Tue, 16 Apr 2024 20:02:03 +0800 Subject: [PATCH 06/17] improve doc --- crates/burn-core/src/nn/rnn/lstm.rs | 41 +++++++++++++++-------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/crates/burn-core/src/nn/rnn/lstm.rs b/crates/burn-core/src/nn/rnn/lstm.rs index d7f745ac53..20ffb74f6a 100644 --- a/crates/burn-core/src/nn/rnn/lstm.rs +++ b/crates/burn-core/src/nn/rnn/lstm.rs @@ -64,18 +64,18 @@ impl LstmConfig { impl Lstm { /// Applies the forward pass on the input tensor. This LSTM implementation /// returns the cell state and hidden state for each element in a sequence (i.e., across `seq_length`), - /// producing 3-dimensional tensors where the dimensions represent [batch_size, sequence_length, hidden_size]. 
+ /// producing 3-dimensional tensors where the dimensions represent `[batch_size, sequence_length, hidden_size]`. /// - /// Parameters: - /// batched_input: The input tensor of shape [batch_size, sequence_length, input_size]. - /// state: An optional tuple of tensors representing the initial cell state and hidden state. - /// Each state tensor has shape [batch_size, hidden_size]. + /// ## Parameters: + /// - batched_input: The input tensor of shape `[batch_size, sequence_length, input_size]`. + /// - state: An optional tuple of tensors representing the initial cell state and hidden state. + /// Each state tensor has shape `[batch_size, hidden_size]`. /// If no initial state is provided, these tensors are initialized to zeros. /// - /// Returns: - /// A tuple of tensors, where the first tensor represents the cell states and - /// the second tensor represents the hidden states for each sequence element. - /// Both output tensors have the shape [batch_size, sequence_length, hidden_size]. + /// ## Returns: + /// A tuple of tensors, where the first tensor represents the cell states and + /// the second tensor represents the hidden states for each sequence element. + /// Both output tensors have the shape `[batch_size, sequence_length, hidden_size]`. pub fn forward( &self, batched_input: Tensor, @@ -177,9 +177,9 @@ pub struct BiLstmConfig { /// The BiLstm module. This implementation is for Bidirectional LSTM. #[derive(Module, Debug)] pub struct BiLstm { - /// LSTM for the forward direction + /// LSTM for the forward direction. pub forward: Lstm, - /// LSTM for the reverse direction + /// LSTM for the reverse direction. pub reverse: Lstm, d_hidden: usize, } @@ -202,18 +202,19 @@ impl BiLstmConfig { impl BiLstm { /// Applies the forward pass on the input tensor. This Bidirectional LSTM implementation /// returns the cell state and hidden state for each element in a sequence (i.e., across `seq_length`), - /// producing 3-dimensional tensors where the dimensions represent [batch_size, sequence_length, hidden_size * 2]. + /// producing 3-dimensional tensors where the dimensions represent `[batch_size, sequence_length, hidden_size * 2]`. /// - /// Parameters: - /// batched_input: The input tensor of shape [batch_size, sequence_length, input_size]. - /// state: An optional tuple of tensors representing the initial cell state and hidden state. - /// Each state tensor has shape [2, batch_size, hidden_size]. + /// ## Parameters: + /// + /// - batched_input: The input tensor of shape `[batch_size, sequence_length, input_size]`. + /// - state: An optional tuple of tensors representing the initial cell state and hidden state. + /// Each state tensor has shape `[2, batch_size, hidden_size]`. /// If no initial state is provided, these tensors are initialized to zeros. /// - /// Returns: - /// A tuple of tensors, where the first tensor represents the cell states and - /// the second tensor represents the hidden states for each sequence element. - /// Both output tensors have the shape [batch_size, sequence_length, hidden_size * 2]. + /// ## Returns: + /// A tuple of tensors, where the first tensor represents the cell states and + /// the second tensor represents the hidden states for each sequence element. + /// Both output tensors have the shape `[batch_size, sequence_length, hidden_size * 2]`. 
pub fn forward( &self, batched_input: Tensor, From f5ed22cc5c377b9b8f51e7f9cd42c0e293283241 Mon Sep 17 00:00:00 2001 From: wcshds Date: Tue, 16 Apr 2024 20:09:02 +0800 Subject: [PATCH 07/17] correct the description of GateController --- crates/burn-core/src/nn/rnn/gate_controller.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/burn-core/src/nn/rnn/gate_controller.rs b/crates/burn-core/src/nn/rnn/gate_controller.rs index 3ddc04e12c..8cc0fe860e 100644 --- a/crates/burn-core/src/nn/rnn/gate_controller.rs +++ b/crates/burn-core/src/nn/rnn/gate_controller.rs @@ -6,7 +6,8 @@ use burn_tensor::{backend::Backend, Tensor}; /// A GateController represents a gate in an LSTM cell. An /// LSTM cell generally contains three gates: an input gate, -/// forget gate, and cell gate. +/// forget gate, and output gate. Additionally, cell gate +/// is just used to compute the cell state. /// /// An Lstm gate is modeled as two linear transformations. /// The results of these transformations are used to calculate From 847db7fa54016ed6a8c58aae8387b0cd8bc947e5 Mon Sep 17 00:00:00 2001 From: wcshds Date: Tue, 16 Apr 2024 21:44:37 +0800 Subject: [PATCH 08/17] fix fmt --- crates/burn-core/src/nn/rnn/gate_controller.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/burn-core/src/nn/rnn/gate_controller.rs b/crates/burn-core/src/nn/rnn/gate_controller.rs index 8cc0fe860e..0c52f383f4 100644 --- a/crates/burn-core/src/nn/rnn/gate_controller.rs +++ b/crates/burn-core/src/nn/rnn/gate_controller.rs @@ -6,7 +6,7 @@ use burn_tensor::{backend::Backend, Tensor}; /// A GateController represents a gate in an LSTM cell. An /// LSTM cell generally contains three gates: an input gate, -/// forget gate, and output gate. Additionally, cell gate +/// forget gate, and output gate. Additionally, cell gate /// is just used to compute the cell state. /// /// An Lstm gate is modeled as two linear transformations. From 5a9345ed6add3264e998e82aefa2cc1b9226c13f Mon Sep 17 00:00:00 2001 From: wcshds Date: Thu, 18 Apr 2024 12:21:51 +0800 Subject: [PATCH 09/17] add `LstmState` --- burn-book/src/building-blocks/module.md | 2 +- crates/burn-core/src/nn/rnn/lstm.rs | 167 ++++++++++++++---------- 2 files changed, 98 insertions(+), 71 deletions(-) diff --git a/burn-book/src/building-blocks/module.md b/burn-book/src/building-blocks/module.md index 0a5d81eaae..12c3e6bdbc 100644 --- a/burn-book/src/building-blocks/module.md +++ b/burn-book/src/building-blocks/module.md @@ -147,7 +147,7 @@ Burn comes with built-in modules that you can use to build your own modules. | Burn API | PyTorch Equivalent | |------------------|------------------------| | `Gru` | `nn.GRU` | -| `Lstm` | `nn.LSTM` | +| `Lstm`/`BiLstm` | `nn.LSTM` | | `GateController` | _No direct equivalent_ | ### Transformer diff --git a/crates/burn-core/src/nn/rnn/lstm.rs b/crates/burn-core/src/nn/rnn/lstm.rs index 20ffb74f6a..a621d49ab1 100644 --- a/crates/burn-core/src/nn/rnn/lstm.rs +++ b/crates/burn-core/src/nn/rnn/lstm.rs @@ -8,6 +8,21 @@ use crate::tensor::backend::Backend; use crate::tensor::Tensor; use burn_tensor::activation; +/// A LstmState is used to store cell state and hidden state in LSTM. +pub struct LstmState { + /// The cell state. + pub cell: Tensor, + /// The hidden state. + pub hidden: Tensor, +} + +impl LstmState { + /// Initialize a new [LSTM State](LstmState). + pub fn new(cell: Tensor, hidden: Tensor) -> Self { + Self { cell, hidden } + } +} + /// The configuration for a [lstm](Lstm) module. 
#[derive(Config)] pub struct LstmConfig { @@ -25,11 +40,11 @@ pub struct LstmConfig { /// The Lstm module. This implementation is for a unidirectional, stateless, Lstm. #[derive(Module, Debug)] pub struct Lstm { - /// The input gate regulates which information to update and store in the memory cell at each time step. + /// The input gate regulates which information to update and store in the cell state at each time step. pub input_gate: GateController, /// The forget gate is used to control which information to discard or keep in the memory cell at each time step. pub forget_gate: GateController, - /// The output gate determines which information from the memory cell to output at each time step. + /// The output gate determines which information from the cell state to output at each time step. pub output_gate: GateController, /// The cell gate is used to compute the cell state that stores and carries information through time. pub cell_gate: GateController, @@ -63,24 +78,24 @@ impl LstmConfig { impl Lstm { /// Applies the forward pass on the input tensor. This LSTM implementation - /// returns the cell state and hidden state for each element in a sequence (i.e., across `seq_length`), + /// returns hidden state for each element in a sequence (i.e., across `seq_length`) and a final state, /// producing 3-dimensional tensors where the dimensions represent `[batch_size, sequence_length, hidden_size]`. /// /// ## Parameters: /// - batched_input: The input tensor of shape `[batch_size, sequence_length, input_size]`. - /// - state: An optional tuple of tensors representing the initial cell state and hidden state. - /// Each state tensor has shape `[batch_size, hidden_size]`. - /// If no initial state is provided, these tensors are initialized to zeros. + /// - state: An optional `LstmState` representing the initial cell state and hidden state. + /// Each state tensor has shape `[batch_size, hidden_size]`. + /// If no initial state is provided, these tensors are initialized to zeros. /// /// ## Returns: - /// A tuple of tensors, where the first tensor represents the cell states and - /// the second tensor represents the hidden states for each sequence element. - /// Both output tensors have the shape `[batch_size, sequence_length, hidden_size]`. + /// - output: A tensor represents the output features of LSTM. Shape: `[batch_size, sequence_length, hidden_size]` + /// - state: A `LstmState` represents the final forward and reverse states. Both `state.cell` and + /// `state.hidden` have the shape `[batch_size, hidden_size]`. 
pub fn forward( &self, batched_input: Tensor, - state: Option<(Tensor, Tensor)>, - ) -> (Tensor, Tensor) { + state: Option>, + ) -> (Tensor, LstmState) { let device = batched_input.device(); let [batch_size, seq_length, _] = batched_input.dims(); @@ -96,17 +111,16 @@ impl Lstm { fn forward_iter, usize)>>( &self, input_timestep_iter: I, - state: Option<(Tensor, Tensor)>, + state: Option>, batch_size: usize, seq_length: usize, device: &B::Device, - ) -> (Tensor, Tensor) { - let mut batched_cell_state = Tensor::zeros([batch_size, seq_length, self.d_hidden], device); + ) -> (Tensor, LstmState) { let mut batched_hidden_state = - Tensor::zeros([batch_size, seq_length, self.d_hidden], device); + Tensor::empty([batch_size, seq_length, self.d_hidden], device); let (mut cell_state, mut hidden_state) = match state { - Some((cell_state, hidden_state)) => (cell_state, hidden_state), + Some(state) => (state.cell, state.hidden), None => ( Tensor::zeros([batch_size, self.d_hidden], device), Tensor::zeros([batch_size, self.d_hidden], device), @@ -142,21 +156,19 @@ impl Lstm { cell_state = forget_values * cell_state.clone() + add_values * candidate_cell_values; hidden_state = output_values * cell_state.clone().tanh(); - let unsqueezed_cell_state = cell_state.clone().unsqueeze_dim(1); let unsqueezed_hidden_state = hidden_state.clone().unsqueeze_dim(1); - // store the state for this timestep - batched_cell_state = batched_cell_state.slice_assign( - [0..batch_size, t..(t + 1), 0..self.d_hidden], - unsqueezed_cell_state.clone(), - ); + // store the hidden state for this timestep batched_hidden_state = batched_hidden_state.slice_assign( [0..batch_size, t..(t + 1), 0..self.d_hidden], unsqueezed_hidden_state.clone(), ); } - (batched_cell_state, batched_hidden_state) + ( + batched_hidden_state, + LstmState::new(cell_state, hidden_state), + ) } } @@ -201,76 +213,88 @@ impl BiLstmConfig { impl BiLstm { /// Applies the forward pass on the input tensor. This Bidirectional LSTM implementation - /// returns the cell state and hidden state for each element in a sequence (i.e., across `seq_length`), + /// returns hidden state for each element in a sequence (i.e., across `seq_length`) and a final state, /// producing 3-dimensional tensors where the dimensions represent `[batch_size, sequence_length, hidden_size * 2]`. /// /// ## Parameters: /// /// - batched_input: The input tensor of shape `[batch_size, sequence_length, input_size]`. - /// - state: An optional tuple of tensors representing the initial cell state and hidden state. - /// Each state tensor has shape `[2, batch_size, hidden_size]`. - /// If no initial state is provided, these tensors are initialized to zeros. + /// - state: An optional `LstmState` representing the initial cell state and hidden state. + /// Each state tensor has shape `[2, batch_size, hidden_size]`. + /// If no initial state is provided, these tensors are initialized to zeros. /// /// ## Returns: - /// A tuple of tensors, where the first tensor represents the cell states and - /// the second tensor represents the hidden states for each sequence element. - /// Both output tensors have the shape `[batch_size, sequence_length, hidden_size * 2]`. + /// - output: A tensor represents the output features of LSTM. Shape: `[batch_size, sequence_length, hidden_size * 2]` + /// - state: A `LstmState` represents the final forward and reverse states. Both `state.cell` and + /// `state.hidden` have the shape `[2, batch_size, hidden_size]`. 
pub fn forward( &self, batched_input: Tensor, - state: Option<(Tensor, Tensor)>, - ) -> (Tensor, Tensor) { + state: Option>, + ) -> (Tensor, LstmState) { let device = batched_input.clone().device(); let [batch_size, seq_length, _] = batched_input.shape().dims; - let [state_forward, state_reverse] = match state { - Some((cell_state, hidden_state)) => { - let cell_state_forward = cell_state + let [init_state_forward, init_state_reverse] = match state { + Some(state) => { + let cell_state_forward = state + .cell .clone() .slice([0..1, 0..batch_size, 0..self.d_hidden]) .squeeze(0); - let hidden_state_forward = hidden_state + let hidden_state_forward = state + .hidden .clone() .slice([0..1, 0..batch_size, 0..self.d_hidden]) .squeeze(0); - let cell_state_reverse = cell_state + let cell_state_reverse = state + .cell .slice([1..2, 0..batch_size, 0..self.d_hidden]) .squeeze(0); - let hidden_state_reverse = hidden_state + let hidden_state_reverse = state + .hidden .slice([1..2, 0..batch_size, 0..self.d_hidden]) .squeeze(0); [ - Some((cell_state_forward, hidden_state_forward)), - Some((cell_state_reverse, hidden_state_reverse)), + Some(LstmState::new(cell_state_forward, hidden_state_forward)), + Some(LstmState::new(cell_state_reverse, hidden_state_reverse)), ] } None => [None, None], }; // forward direction - let (batched_cell_state_forward, batched_hidden_state_forward) = - self.forward.forward(batched_input.clone(), state_forward); + let (batched_hidden_state_forward, final_state_forward) = self + .forward + .forward(batched_input.clone(), init_state_forward); // reverse direction - let (batched_cell_state_reverse, batched_hidden_state_reverse) = self.reverse.forward_iter( + let (batched_hidden_state_reverse, final_state_reverse) = self.reverse.forward_iter( batched_input.iter_dim(1).rev().zip((0..seq_length).rev()), - state_reverse, + init_state_reverse, batch_size, seq_length, &device, ); - let batched_cell_state = Tensor::cat( - [batched_cell_state_forward, batched_cell_state_reverse].to_vec(), - 2, - ); - let batched_hidden_state = Tensor::cat( + let output = Tensor::cat( [batched_hidden_state_forward, batched_hidden_state_reverse].to_vec(), 2, ); - (batched_cell_state, batched_hidden_state) + let state = LstmState::new( + Tensor::stack( + [final_state_forward.cell, final_state_reverse.cell].to_vec(), + 0, + ), + Tensor::stack( + [final_state_forward.hidden, final_state_reverse.hidden].to_vec(), + 0, + ), + ); + + (output, state) } } @@ -278,7 +302,7 @@ impl BiLstm { mod tests { use super::*; use crate::{module::Param, nn::LinearRecord, TestBackend}; - use burn_tensor::{Data, Distribution}; + use burn_tensor::{Data, Device, Distribution}; #[cfg(feature = "std")] use crate::TestAutodiffBackend; @@ -323,7 +347,7 @@ mod tests { d_output: usize, bias: bool, initializer: Initializer, - device: &::Device, + device: &Device, ) -> GateController { let record_1 = LinearRecord { weight: Param::from_data(Data::from([[weights]]), device), @@ -383,19 +407,21 @@ mod tests { // single timestep with single feature let input = Tensor::::from_data(Data::from([[[0.1]]]), &device); - let (cell_state_batch, hidden_state_batch) = lstm.forward(input, None); - let cell_state = cell_state_batch - .select(0, Tensor::arange(0..1, &device)) - .squeeze(0); - let hidden_state = hidden_state_batch - .select(0, Tensor::arange(0..1, &device)) - .squeeze(0); - cell_state + // let (cell_state_batch, hidden_state_batch) = lstm.forward(input, None); + let (output, state) = lstm.forward(input, None); + state + .cell .to_data() 
.assert_approx_eq(&Data::from([[0.046]]), 3); - hidden_state + state + .hidden + .to_data() + .assert_approx_eq(&Data::from([[0.024]]), 3); + output + .select(0, Tensor::arange(0..1, &device)) + .squeeze(0) .to_data() - .assert_approx_eq(&Data::from([[0.024]]), 3) + .assert_approx_eq(&state.hidden.to_data(), 3); } #[test] @@ -405,10 +431,11 @@ mod tests { let batched_input = Tensor::::random([8, 10, 64], Distribution::Default, &device); - let (cell_state, hidden_state) = lstm.forward(batched_input, None); + let (output, state) = lstm.forward(batched_input, None); - assert_eq!(cell_state.shape().dims, [8, 10, 1024]); - assert_eq!(hidden_state.shape().dims, [8, 10, 1024]); + assert_eq!(output.dims(), [8, 10, 1024]); + assert_eq!(state.cell.dims(), [8, 1024]); + assert_eq!(state.hidden.dims(), [8, 1024]); } #[test] @@ -421,8 +448,8 @@ mod tests { let batched_input = Tensor::::random(shape, Distribution::Default, &device); - let (cell_state, hidden_state) = lstm.forward(batched_input.clone(), None); - let fake_loss = cell_state + hidden_state; + let (output, _) = lstm.forward(batched_input.clone(), None); + let fake_loss = output; let grads = fake_loss.backward(); let some_gradient = lstm @@ -448,7 +475,7 @@ mod tests { input_biases: [f32; D1], hidden_weights: [[f32; D1]; D1], hidden_biases: [f32; D1], - device: &::Device, + device: &Device, ) -> GateController { let d_input = input_weights[0].len(); let d_output = input_weights.len(); @@ -610,8 +637,8 @@ mod tests { ], ]]); - let (_, hidden_state) = lstm.forward(input, None); + let (output, _) = lstm.forward(input, None); - hidden_state.to_data().assert_approx_eq(&expected_result, 3) + output.to_data().assert_approx_eq(&expected_result, 3) } } From 7c0e761008baba16dcab3813a19e579a110f5823 Mon Sep 17 00:00:00 2001 From: wcshds Date: Thu, 18 Apr 2024 12:34:39 +0800 Subject: [PATCH 10/17] add test for state --- crates/burn-core/src/nn/rnn/lstm.rs | 207 ++++++++++++++++------------ 1 file changed, 117 insertions(+), 90 deletions(-) diff --git a/crates/burn-core/src/nn/rnn/lstm.rs b/crates/burn-core/src/nn/rnn/lstm.rs index a621d49ab1..2a578207da 100644 --- a/crates/burn-core/src/nn/rnn/lstm.rs +++ b/crates/burn-core/src/nn/rnn/lstm.rs @@ -466,7 +466,7 @@ mod tests { #[test] fn test_bidirectional() { TestBackend::seed(0); - let config = BiLstmConfig::new(2, 4, true); + let config = BiLstmConfig::new(2, 3, true); let device = Default::default(); let mut lstm = config.init(&device); @@ -499,146 +499,173 @@ mod tests { } let input = Tensor::::from_data( - Data::from([[[-0.131, -1.591], [1.378, -1.867], [0.397, 0.047]]]), + Data::from([[ + [0.949, -0.861], + [0.892, 0.927], + [-0.173, -0.301], + [-0.081, 0.992], + ]]), + &device, + ); + let h0 = Tensor::::from_data( + Data::from([[[0.280, 0.360, -1.242]], [[-0.588, 0.729, -0.788]]]), + &device, + ); + let c0 = Tensor::::from_data( + Data::from([[[0.723, 0.397, -0.262]], [[0.471, 0.613, 1.885]]]), &device, ); lstm.forward.input_gate = create_gate_controller( - [[0.078, 0.234, 0.398, 0.333], [0.452, 0.124, -0.042, -0.152]], - [0.196, 0.094, -0.270, 0.008], + [[0.367, 0.091, 0.342], [0.322, 0.533, 0.059]], + [-0.196, 0.354, 0.209], [ - [0.054, 0.057, 0.282, 0.021], - [0.065, -0.303, -0.499, 0.069], - [-0.007, 0.226, -0.131, -0.307], - [-0.025, 0.072, 0.197, 0.129], + [-0.320, 0.232, -0.165], + [0.093, -0.572, -0.315], + [-0.467, 0.325, 0.046], ], - [0.278, -0.211, 0.435, -0.162], + [0.181, -0.190, -0.245], &device, ); lstm.forward.forget_gate = create_gate_controller( + [[-0.342, -0.084, -0.420], 
[-0.432, 0.119, 0.191]], + [0.315, -0.413, -0.041], [ - [-0.187, -0.201, 0.078, -0.314], - [0.169, 0.229, 0.218, 0.466], + [0.453, 0.063, 0.561], + [0.211, 0.149, 0.213], + [-0.499, -0.158, 0.068], ], - [0.320, -0.135, -0.301, 0.180], - [ - [0.392, -0.028, 0.470, -0.025], - [-0.284, -0.286, -0.211, -0.001], - [0.245, -0.259, 0.102, -0.379], - [-0.096, -0.462, 0.170, 0.232], - ], - [0.458, 0.039, 0.287, -0.327], + [-0.431, -0.535, 0.125], &device, ); lstm.forward.cell_gate = create_gate_controller( + [[-0.046, -0.382, 0.321], [-0.533, 0.558, 0.004]], + [-0.358, 0.282, -0.078], [ - [-0.216, 0.256, 0.369, 0.160], - [0.453, -0.238, 0.306, -0.411], - ], - [0.360, 0.001, 0.303, 0.438], - [ - [0.356, -0.185, 0.494, 0.325], - [0.111, -0.388, 0.051, -0.150], - [-0.434, 0.296, -0.185, 0.290], - [-0.010, -0.023, 0.460, 0.238], + [-0.358, 0.109, 0.139], + [-0.345, 0.091, -0.368], + [-0.508, 0.221, -0.507], ], - [0.268, -0.136, -0.452, 0.471], + [0.502, -0.509, -0.247], &device, ); lstm.forward.output_gate = create_gate_controller( - [[0.235, -0.132, 0.049, 0.157], [-0.280, 0.229, 0.102, 0.448]], - [0.237, -0.396, -0.134, -0.047], + [[-0.577, -0.359, 0.216], [-0.550, 0.268, 0.243]], + [-0.227, -0.274, 0.039], [ - [-0.243, 0.196, 0.087, 0.163], - [0.138, -0.247, -0.401, -0.462], - [0.030, -0.263, 0.473, 0.259], - [-0.413, -0.173, -0.206, 0.324], + [-0.383, 0.449, 0.222], + [-0.357, -0.093, 0.449], + [-0.106, 0.236, 0.360], ], - [-0.364, -0.023, 0.215, -0.401], + [-0.361, -0.209, -0.454], &device, ); lstm.reverse.input_gate = create_gate_controller( + [[-0.055, 0.506, 0.247], [-0.369, 0.178, -0.258]], + [0.540, -0.164, 0.033], [ - [0.220, -0.191, 0.062, -0.443], - [-0.112, -0.353, -0.443, 0.080], - ], - [-0.418, 0.209, 0.297, -0.429], - [ - [-0.121, -0.408, 0.132, -0.450], - [0.231, 0.154, -0.294, 0.022], - [0.378, 0.239, 0.176, -0.361], - [0.480, 0.427, -0.156, -0.137], + [0.159, 0.180, -0.037], + [-0.443, 0.485, -0.488], + [0.098, -0.085, -0.140], ], - [0.267, -0.474, -0.393, 0.190], + [-0.510, 0.105, 0.114], &device, ); lstm.reverse.forget_gate = create_gate_controller( + [[-0.154, -0.432, -0.547], [-0.369, -0.310, -0.175]], + [0.141, 0.004, 0.055], [ - [0.151, 0.148, 0.341, -0.112], - [-0.368, -0.476, 0.003, 0.083], - ], - [-0.489, -0.361, -0.035, 0.328], - [ - [0.460, -0.124, -0.377, -0.033], - [-0.296, 0.162, 0.456, -0.271], - [0.320, 0.235, 0.383, 0.423], - [-0.167, 0.332, -0.493, 0.086], + [-0.005, -0.277, -0.515], + [-0.011, -0.101, -0.365], + [0.426, 0.379, 0.337], ], - [-0.425, 0.219, 0.294, -0.075], + [-0.382, 0.331, -0.176], &device, ); lstm.reverse.cell_gate = create_gate_controller( + [[-0.571, 0.228, -0.287], [-0.331, 0.110, 0.219]], + [-0.206, -0.546, 0.462], [ - [-0.451, 0.285, 0.305, -0.344], - [-0.399, 0.344, -0.022, 0.263], - ], - [0.215, -0.028, 0.097, 0.197], - [ - [0.072, 0.106, -0.030, 0.056], - [-0.278, -0.256, -0.129, -0.252], - [-0.305, 0.219, 0.045, -0.123], - [0.224, 0.011, -0.199, -0.362], + [0.449, -0.240, 0.071], + [-0.045, 0.131, 0.124], + [0.138, -0.201, 0.191], ], - [0.086, 0.466, -0.152, 0.353], + [-0.030, 0.211, -0.352], &device, ); lstm.reverse.output_gate = create_gate_controller( + [[0.491, -0.442, 0.333], [0.313, -0.121, -0.070]], + [-0.387, -0.250, 0.066], [ - [0.057, -0.357, 0.031, 0.235], - [-0.492, -0.109, -0.316, -0.422], + [-0.030, 0.268, 0.299], + [-0.019, -0.280, -0.314], + [0.466, -0.365, -0.248], ], - [0.233, 0.053, 0.162, -0.465], - [ - [0.240, 0.223, -0.188, -0.181], - [-0.427, -0.390, -0.176, -0.338], - [-0.158, 0.152, -0.105, 0.106], - [-0.223, 
-0.186, -0.059, 0.319], - ], - [0.207, 0.295, 0.361, 0.029], + [-0.398, -0.199, -0.566], &device, ); - let expected_result = Data::from([[ - [ - -0.01604, 0.02718, -0.14959, 0.10219, 0.34534, 0.06087, 0.07809, 0.01806, - ], - [ - -0.13098, 0.07478, -0.10684, 0.15549, 0.19981, 0.12038, 0.19815, -0.02509, - ], - [ - 0.09250, 0.03285, -0.04502, 0.24134, 0.03017, 0.11454, 0.01943, 0.06517, - ], + let expected_output_with_init_state = Data::from([[ + [0.23764, -0.03442, 0.04414, -0.15635, -0.03366, -0.05798], + [0.00473, -0.02254, 0.02988, -0.16510, -0.00306, 0.08742], + [0.06210, -0.06509, -0.05339, -0.01710, 0.02091, 0.16012], + [-0.03420, 0.07774, -0.09774, -0.02604, 0.12584, 0.20872], ]]); - - let (output, _) = lstm.forward(input, None); - - output.to_data().assert_approx_eq(&expected_result, 3) + let expected_output_without_init_state = Data::from([[ + [0.08679, -0.08776, -0.00528, -0.15969, -0.05322, -0.08863], + [-0.02577, -0.05057, 0.00033, -0.17558, -0.03679, 0.03142], + [0.02942, -0.07411, -0.06044, -0.03601, -0.09998, 0.04846], + [-0.04026, 0.07178, -0.10189, -0.07349, -0.04576, 0.05550], + ]]); + let expected_hn_with_init_state = Data::from([ + [[-0.03420, 0.07774, -0.09774]], + [[-0.15635, -0.03366, -0.05798]], + ]); + let expected_cn_with_init_state = Data::from([ + [[-0.13593, 0.17125, -0.22395]], + [[-0.45425, -0.11206, -0.12908]], + ]); + let expected_hn_without_init_state = Data::from([ + [[-0.04026, 0.07178, -0.10189]], + [[-0.15969, -0.05322, -0.08863]], + ]); + let expected_cn_without_init_state = Data::from([ + [[-0.15839, 0.15923, -0.23569]], + [[-0.47407, -0.17493, -0.19643]], + ]); + + let (output_with_init_state, state_with_init_state) = + lstm.forward(input.clone(), Some(LstmState::new(c0, h0))); + let (output_without_init_state, state_without_init_state) = lstm.forward(input, None); + + output_with_init_state + .to_data() + .assert_approx_eq(&expected_output_with_init_state, 3); + output_without_init_state + .to_data() + .assert_approx_eq(&expected_output_without_init_state, 3); + state_with_init_state + .hidden + .to_data() + .assert_approx_eq(&expected_hn_with_init_state, 3); + state_with_init_state + .cell + .to_data() + .assert_approx_eq(&expected_cn_with_init_state, 3); + state_without_init_state + .hidden + .to_data() + .assert_approx_eq(&expected_hn_without_init_state, 3); + state_without_init_state + .cell + .to_data() + .assert_approx_eq(&expected_cn_without_init_state, 3); } } From bfeb74a4e30b1e11add6d69848360b7fc60b6f4a Mon Sep 17 00:00:00 2001 From: wcshds Date: Thu, 18 Apr 2024 18:09:07 +0800 Subject: [PATCH 11/17] set batch 2 in bilstm test --- crates/burn-core/src/nn/rnn/lstm.rs | 182 ++++++++++++++++------------ 1 file changed, 106 insertions(+), 76 deletions(-) diff --git a/crates/burn-core/src/nn/rnn/lstm.rs b/crates/burn-core/src/nn/rnn/lstm.rs index 2a578207da..22242a6487 100644 --- a/crates/burn-core/src/nn/rnn/lstm.rs +++ b/crates/burn-core/src/nn/rnn/lstm.rs @@ -499,146 +499,176 @@ mod tests { } let input = Tensor::::from_data( - Data::from([[ - [0.949, -0.861], - [0.892, 0.927], - [-0.173, -0.301], - [-0.081, 0.992], - ]]), + Data::from([ + [ + [1.647, -0.499], + [-1.991, 0.439], + [0.571, 0.563], + [0.149, -1.048], + ], + [ + [0.039, -0.786], + [-0.703, 1.071], + [-0.417, -1.480], + [-0.621, -0.827], + ], + ]), &device, ); let h0 = Tensor::::from_data( - Data::from([[[0.280, 0.360, -1.242]], [[-0.588, 0.729, -0.788]]]), + Data::from([ + [[0.680, -0.813, 0.760], [0.336, 0.827, -0.749]], + [[-1.736, -0.235, 0.925], [-0.048, 0.218, 0.909]], + 
]), &device, ); let c0 = Tensor::::from_data( - Data::from([[[0.723, 0.397, -0.262]], [[0.471, 0.613, 1.885]]]), + Data::from([ + [[-0.298, 0.507, -0.058], [-1.805, 0.768, 0.523]], + [[0.364, -1.398, 1.188], [0.087, -0.555, 0.500]], + ]), &device, ); lstm.forward.input_gate = create_gate_controller( - [[0.367, 0.091, 0.342], [0.322, 0.533, 0.059]], - [-0.196, 0.354, 0.209], + [[0.050, 0.292, -0.044], [-0.392, 0.409, -0.110]], + [-0.007, 0.483, 0.038], [ - [-0.320, 0.232, -0.165], - [0.093, -0.572, -0.315], - [-0.467, 0.325, 0.046], + [0.036, 0.511, -0.236], + [-0.232, 0.449, 0.146], + [-0.282, -0.365, -0.329], ], - [0.181, -0.190, -0.245], + [0.355, -0.259, -0.300], &device, ); lstm.forward.forget_gate = create_gate_controller( - [[-0.342, -0.084, -0.420], [-0.432, 0.119, 0.191]], - [0.315, -0.413, -0.041], + [[0.360, -0.228, -0.036], [0.123, -0.077, -0.341]], + [-0.306, -0.335, -0.039], [ - [0.453, 0.063, 0.561], - [0.211, 0.149, 0.213], - [-0.499, -0.158, 0.068], + [0.156, -0.156, -0.360], + [-0.117, -0.429, -0.259], + [0.023, 0.226, 0.455], ], - [-0.431, -0.535, 0.125], + [0.255, -0.067, -0.125], &device, ); lstm.forward.cell_gate = create_gate_controller( - [[-0.046, -0.382, 0.321], [-0.533, 0.558, 0.004]], - [-0.358, 0.282, -0.078], + [[-0.375, -0.128, 0.363], [0.041, -0.109, 0.071]], + [0.014, 0.489, 0.218], [ - [-0.358, 0.109, 0.139], - [-0.345, 0.091, -0.368], - [-0.508, 0.221, -0.507], + [0.559, -0.561, -0.426], + [0.205, -0.492, 0.010], + [0.280, -0.496, -0.220], ], - [0.502, -0.509, -0.247], + [0.239, 0.166, -0.176], &device, ); lstm.forward.output_gate = create_gate_controller( - [[-0.577, -0.359, 0.216], [-0.550, 0.268, 0.243]], - [-0.227, -0.274, 0.039], + [[0.352, 0.206, 0.020], [0.343, -0.327, 0.208]], + [-0.451, 0.071, -0.232], [ - [-0.383, 0.449, 0.222], - [-0.357, -0.093, 0.449], - [-0.106, 0.236, 0.360], + [-0.257, -0.346, -0.343], + [0.490, -0.473, 0.208], + [0.457, 0.105, 0.093], ], - [-0.361, -0.209, -0.454], + [-0.531, 0.178, -0.475], &device, ); lstm.reverse.input_gate = create_gate_controller( - [[-0.055, 0.506, 0.247], [-0.369, 0.178, -0.258]], - [0.540, -0.164, 0.033], + [[0.098, 0.072, 0.429], [0.397, 0.479, -0.320]], + [-0.129, 0.442, -0.044], [ - [0.159, 0.180, -0.037], - [-0.443, 0.485, -0.488], - [0.098, -0.085, -0.140], + [-0.543, 0.344, -0.013], + [-0.388, 0.389, -0.480], + [-0.496, -0.193, -0.169], ], - [-0.510, 0.105, 0.114], + [-0.042, 0.576, -0.465], &device, ); lstm.reverse.forget_gate = create_gate_controller( - [[-0.154, -0.432, -0.547], [-0.369, -0.310, -0.175]], - [0.141, 0.004, 0.055], + [[-0.514, -0.553, -0.569], [-0.045, 0.367, 0.521]], + [0.240, -0.500, 0.502], [ - [-0.005, -0.277, -0.515], - [-0.011, -0.101, -0.365], - [0.426, 0.379, 0.337], + [0.270, 0.027, 0.411], + [-0.123, -0.447, -0.051], + [-0.280, -0.056, 0.261], ], - [-0.382, 0.331, -0.176], + [0.189, -0.567, 0.117], &device, ); lstm.reverse.cell_gate = create_gate_controller( - [[-0.571, 0.228, -0.287], [-0.331, 0.110, 0.219]], - [-0.206, -0.546, 0.462], + [[-0.488, 0.185, -0.163], [-0.243, -0.307, -0.098]], + [0.368, -0.306, -0.524], [ - [0.449, -0.240, 0.071], - [-0.045, 0.131, 0.124], - [0.138, -0.201, 0.191], + [0.572, -0.365, 0.348], + [-0.492, 0.512, -0.023], + [-0.144, 0.050, 0.098], ], - [-0.030, 0.211, -0.352], + [0.148, 0.163, -0.546], &device, ); lstm.reverse.output_gate = create_gate_controller( - [[0.491, -0.442, 0.333], [0.313, -0.121, -0.070]], - [-0.387, -0.250, 0.066], + [[-0.069, -0.455, 0.461], [-0.274, 0.266, 0.519]], + [0.388, 0.545, -0.388], [ - [-0.030, 0.268, 
0.299], - [-0.019, -0.280, -0.314], - [0.466, -0.365, -0.248], + [0.180, -0.462, 0.106], + [0.543, 0.295, -0.411], + [-0.011, -0.066, 0.470], ], - [-0.398, -0.199, -0.566], + [-0.179, -0.196, -0.067], &device, ); - let expected_output_with_init_state = Data::from([[ - [0.23764, -0.03442, 0.04414, -0.15635, -0.03366, -0.05798], - [0.00473, -0.02254, 0.02988, -0.16510, -0.00306, 0.08742], - [0.06210, -0.06509, -0.05339, -0.01710, 0.02091, 0.16012], - [-0.03420, 0.07774, -0.09774, -0.02604, 0.12584, 0.20872], - ]]); - let expected_output_without_init_state = Data::from([[ - [0.08679, -0.08776, -0.00528, -0.15969, -0.05322, -0.08863], - [-0.02577, -0.05057, 0.00033, -0.17558, -0.03679, 0.03142], - [0.02942, -0.07411, -0.06044, -0.03601, -0.09998, 0.04846], - [-0.04026, 0.07178, -0.10189, -0.07349, -0.04576, 0.05550], - ]]); + let expected_output_with_init_state = Data::from([ + [ + [-0.05291, 0.20481, 0.00247, 0.14330, -0.01617, -0.32437], + [0.06333, 0.17128, -0.08646, 0.23891, -0.39256, -0.08803], + [0.08696, 0.21229, 0.00791, 0.02564, -0.08598, -0.15525], + [0.06240, 0.25245, 0.00132, 0.00171, 0.01233, 0.01195], + ], + [ + [-0.11673, 0.23612, 0.05902, 0.34088, -0.09401, -0.16047], + [-0.01053, 0.20343, 0.01439, 0.26776, -0.29267, -0.15661], + [0.04320, 0.28468, -0.02198, 0.24269, 0.04973, -0.04563], + [0.07891, 0.24718, -0.04706, 0.13683, -0.01629, 0.03767], + ], + ]); + let expected_output_without_init_state = Data::from([ + [ + [-0.08461, 0.18986, 0.07192, 0.18021, -0.02266, -0.35150], + [0.06048, 0.17062, -0.05256, 0.29482, -0.40167, -0.12416], + [0.08438, 0.20755, 0.02044, 0.10186, -0.08353, -0.25673], + [0.06173, 0.24971, 0.00638, 0.13258, 0.06368, -0.09722], + ], + [ + [0.02993, 0.18217, 0.00005, 0.35562, -0.09828, -0.17992], + [0.09557, 0.16621, -0.02360, 0.28457, -0.29604, -0.21862], + [0.06623, 0.26088, -0.03991, 0.27286, 0.05034, -0.08039], + [0.08877, 0.24112, -0.05770, 0.16840, -0.00154, -0.06161], + ], + ]); let expected_hn_with_init_state = Data::from([ - [[-0.03420, 0.07774, -0.09774]], - [[-0.15635, -0.03366, -0.05798]], + [[0.06240, 0.25245, 0.00132], [0.07891, 0.24718, -0.04706]], + [[0.14330, -0.01617, -0.32437], [0.34088, -0.09401, -0.16047]], ]); let expected_cn_with_init_state = Data::from([ - [[-0.13593, 0.17125, -0.22395]], - [[-0.45425, -0.11206, -0.12908]], + [[0.27726, 0.43163, 0.00460], [0.40963, 0.47434, -0.15836]], + [[0.28537, -0.05057, -0.68145], [0.67802, -0.19816, -0.55872]], ]); let expected_hn_without_init_state = Data::from([ - [[-0.04026, 0.07178, -0.10189]], - [[-0.15969, -0.05322, -0.08863]], + [[0.06173, 0.24971, 0.00638], [0.08877, 0.24112, -0.05770]], + [[0.18021, -0.02266, -0.35150], [0.35562, -0.09828, -0.17992]], ]); let expected_cn_without_init_state = Data::from([ - [[-0.15839, 0.15923, -0.23569]], - [[-0.47407, -0.17493, -0.19643]], + [[0.27319, 0.42555, 0.02218], [0.47942, 0.46064, -0.19706]], + [[0.36375, -0.07220, -0.76521], [0.71734, -0.20792, -0.66048]], ]); let (output_with_init_state, state_with_init_state) = From 02e3c56ad5a965a982c0beabd69b1bc590de139b Mon Sep 17 00:00:00 2001 From: wcshds Date: Fri, 26 Apr 2024 21:28:14 +0800 Subject: [PATCH 12/17] resolve conflict --- crates/burn-core/src/nn/rnn/lstm.rs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/crates/burn-core/src/nn/rnn/lstm.rs b/crates/burn-core/src/nn/rnn/lstm.rs index 22242a6487..1f603baebd 100644 --- a/crates/burn-core/src/nn/rnn/lstm.rs +++ b/crates/burn-core/src/nn/rnn/lstm.rs @@ -438,6 +438,20 @@ mod tests { assert_eq!(state.hidden.dims(), [8, 
1024]);
     }
 
+    #[test]
+    fn test_batched_forward_pass_batch_of_one() {
+        let device = Default::default();
+        let lstm = LstmConfig::new(64, 1024, true).init(&device);
+
+        let batched_input =
+            Tensor::<TestBackend, 3>::random([1, 2, 64], Distribution::Default, &device);
+
+        let (output, state) = lstm.forward(batched_input, None);
+
+        assert_eq!(output.dims(), [1, 10, 1024]);
+        assert_eq!(state.cell.dims(), [1, 1024]);
+        assert_eq!(state.hidden.dims(), [1, 1024]);
+    }
+
     #[test]
     #[cfg(feature = "std")]
     fn test_batched_backward_pass() {

From 744fb40194127785f7e9833ea640f964168ed6ac Mon Sep 17 00:00:00 2001
From: wcshds
Date: Fri, 26 Apr 2024 21:31:07 +0800
Subject: [PATCH 13/17] fix

---
 crates/burn-core/src/nn/rnn/lstm.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/crates/burn-core/src/nn/rnn/lstm.rs b/crates/burn-core/src/nn/rnn/lstm.rs
index 1f603baebd..7072921c3a 100644
--- a/crates/burn-core/src/nn/rnn/lstm.rs
+++ b/crates/burn-core/src/nn/rnn/lstm.rs
@@ -447,7 +447,7 @@ mod tests {
 
         let (output, state) = lstm.forward(batched_input, None);
 
-        assert_eq!(output.dims(), [1, 10, 1024]);
+        assert_eq!(output.dims(), [1, 2, 1024]);
         assert_eq!(state.cell.dims(), [1, 1024]);
         assert_eq!(state.hidden.dims(), [1, 1024]);
     }

From 9b3e1f904ab655bb492397d2b1b0f5ad27fc350d Mon Sep 17 00:00:00 2001
From: wcshds
Date: Fri, 26 Apr 2024 21:33:50 +0800
Subject: [PATCH 14/17] fix doc

---
 crates/burn-core/src/nn/rnn/lstm.rs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/crates/burn-core/src/nn/rnn/lstm.rs b/crates/burn-core/src/nn/rnn/lstm.rs
index 7072921c3a..29bd6e2f63 100644
--- a/crates/burn-core/src/nn/rnn/lstm.rs
+++ b/crates/burn-core/src/nn/rnn/lstm.rs
@@ -89,8 +89,8 @@ impl<B: Backend> Lstm<B> {
     ///
     /// ## Returns:
     /// - output: A tensor represents the output features of LSTM. Shape: `[batch_size, sequence_length, hidden_size]`
-    /// - state: A `LstmState` represents the final forward and reverse states. Both `state.cell` and
-    ///   `state.hidden` have the shape `[batch_size, hidden_size]`.
+    /// - state: A `LstmState` represents the final states. Both `state.cell` and `state.hidden` have the shape
+    ///   `[batch_size, hidden_size]`.
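The practical reason for returning the final `LstmState` is that a long sequence can be processed in chunks, feeding each chunk's final state back in as the next chunk's initial state. A minimal sketch of that pattern (the chunk names, sizes, and `TestBackend` usage are assumptions for illustration, not taken from the patch):

    let device = Default::default();
    let lstm = LstmConfig::new(64, 128, true).init(&device);
    let chunk_a = Tensor::<TestBackend, 3>::random([8, 10, 64], Distribution::Default, &device);
    let chunk_b = Tensor::<TestBackend, 3>::random([8, 10, 64], Distribution::Default, &device);

    // The first chunk starts from a zeroed state; the second resumes
    // from the final state of the first, as if it were one sequence.
    let (_, state) = lstm.forward(chunk_a, None);
    let (output_b, _) = lstm.forward(chunk_b, Some(state));
    assert_eq!(output_b.dims(), [8, 10, 128]);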
pub fn forward( &self, From 5d26ea4145d2d371c3551b6ebad592c56b69b213 Mon Sep 17 00:00:00 2001 From: wcshds Date: Fri, 26 Apr 2024 21:44:41 +0800 Subject: [PATCH 16/17] change the batch size back to 1 --- crates/burn-core/src/nn/rnn/lstm.rs | 182 ++++++++++++---------------- 1 file changed, 76 insertions(+), 106 deletions(-) diff --git a/crates/burn-core/src/nn/rnn/lstm.rs b/crates/burn-core/src/nn/rnn/lstm.rs index cd535ded35..1ba9f53879 100644 --- a/crates/burn-core/src/nn/rnn/lstm.rs +++ b/crates/burn-core/src/nn/rnn/lstm.rs @@ -513,176 +513,146 @@ mod tests { } let input = Tensor::::from_data( - Data::from([ - [ - [1.647, -0.499], - [-1.991, 0.439], - [0.571, 0.563], - [0.149, -1.048], - ], - [ - [0.039, -0.786], - [-0.703, 1.071], - [-0.417, -1.480], - [-0.621, -0.827], - ], - ]), + Data::from([[ + [0.949, -0.861], + [0.892, 0.927], + [-0.173, -0.301], + [-0.081, 0.992], + ]]), &device, ); let h0 = Tensor::::from_data( - Data::from([ - [[0.680, -0.813, 0.760], [0.336, 0.827, -0.749]], - [[-1.736, -0.235, 0.925], [-0.048, 0.218, 0.909]], - ]), + Data::from([[[0.280, 0.360, -1.242]], [[-0.588, 0.729, -0.788]]]), &device, ); let c0 = Tensor::::from_data( - Data::from([ - [[-0.298, 0.507, -0.058], [-1.805, 0.768, 0.523]], - [[0.364, -1.398, 1.188], [0.087, -0.555, 0.500]], - ]), + Data::from([[[0.723, 0.397, -0.262]], [[0.471, 0.613, 1.885]]]), &device, ); lstm.forward.input_gate = create_gate_controller( - [[0.050, 0.292, -0.044], [-0.392, 0.409, -0.110]], - [-0.007, 0.483, 0.038], + [[0.367, 0.091, 0.342], [0.322, 0.533, 0.059]], + [-0.196, 0.354, 0.209], [ - [0.036, 0.511, -0.236], - [-0.232, 0.449, 0.146], - [-0.282, -0.365, -0.329], + [-0.320, 0.232, -0.165], + [0.093, -0.572, -0.315], + [-0.467, 0.325, 0.046], ], - [0.355, -0.259, -0.300], + [0.181, -0.190, -0.245], &device, ); lstm.forward.forget_gate = create_gate_controller( - [[0.360, -0.228, -0.036], [0.123, -0.077, -0.341]], - [-0.306, -0.335, -0.039], + [[-0.342, -0.084, -0.420], [-0.432, 0.119, 0.191]], + [0.315, -0.413, -0.041], [ - [0.156, -0.156, -0.360], - [-0.117, -0.429, -0.259], - [0.023, 0.226, 0.455], + [0.453, 0.063, 0.561], + [0.211, 0.149, 0.213], + [-0.499, -0.158, 0.068], ], - [0.255, -0.067, -0.125], + [-0.431, -0.535, 0.125], &device, ); lstm.forward.cell_gate = create_gate_controller( - [[-0.375, -0.128, 0.363], [0.041, -0.109, 0.071]], - [0.014, 0.489, 0.218], + [[-0.046, -0.382, 0.321], [-0.533, 0.558, 0.004]], + [-0.358, 0.282, -0.078], [ - [0.559, -0.561, -0.426], - [0.205, -0.492, 0.010], - [0.280, -0.496, -0.220], + [-0.358, 0.109, 0.139], + [-0.345, 0.091, -0.368], + [-0.508, 0.221, -0.507], ], - [0.239, 0.166, -0.176], + [0.502, -0.509, -0.247], &device, ); lstm.forward.output_gate = create_gate_controller( - [[0.352, 0.206, 0.020], [0.343, -0.327, 0.208]], - [-0.451, 0.071, -0.232], + [[-0.577, -0.359, 0.216], [-0.550, 0.268, 0.243]], + [-0.227, -0.274, 0.039], [ - [-0.257, -0.346, -0.343], - [0.490, -0.473, 0.208], - [0.457, 0.105, 0.093], + [-0.383, 0.449, 0.222], + [-0.357, -0.093, 0.449], + [-0.106, 0.236, 0.360], ], - [-0.531, 0.178, -0.475], + [-0.361, -0.209, -0.454], &device, ); lstm.reverse.input_gate = create_gate_controller( - [[0.098, 0.072, 0.429], [0.397, 0.479, -0.320]], - [-0.129, 0.442, -0.044], + [[-0.055, 0.506, 0.247], [-0.369, 0.178, -0.258]], + [0.540, -0.164, 0.033], [ - [-0.543, 0.344, -0.013], - [-0.388, 0.389, -0.480], - [-0.496, -0.193, -0.169], + [0.159, 0.180, -0.037], + [-0.443, 0.485, -0.488], + [0.098, -0.085, -0.140], ], - [-0.042, 0.576, -0.465], + [-0.510, 0.105, 
0.114], &device, ); lstm.reverse.forget_gate = create_gate_controller( - [[-0.514, -0.553, -0.569], [-0.045, 0.367, 0.521]], - [0.240, -0.500, 0.502], + [[-0.154, -0.432, -0.547], [-0.369, -0.310, -0.175]], + [0.141, 0.004, 0.055], [ - [0.270, 0.027, 0.411], - [-0.123, -0.447, -0.051], - [-0.280, -0.056, 0.261], + [-0.005, -0.277, -0.515], + [-0.011, -0.101, -0.365], + [0.426, 0.379, 0.337], ], - [0.189, -0.567, 0.117], + [-0.382, 0.331, -0.176], &device, ); lstm.reverse.cell_gate = create_gate_controller( - [[-0.488, 0.185, -0.163], [-0.243, -0.307, -0.098]], - [0.368, -0.306, -0.524], + [[-0.571, 0.228, -0.287], [-0.331, 0.110, 0.219]], + [-0.206, -0.546, 0.462], [ - [0.572, -0.365, 0.348], - [-0.492, 0.512, -0.023], - [-0.144, 0.050, 0.098], + [0.449, -0.240, 0.071], + [-0.045, 0.131, 0.124], + [0.138, -0.201, 0.191], ], - [0.148, 0.163, -0.546], + [-0.030, 0.211, -0.352], &device, ); lstm.reverse.output_gate = create_gate_controller( - [[-0.069, -0.455, 0.461], [-0.274, 0.266, 0.519]], - [0.388, 0.545, -0.388], + [[0.491, -0.442, 0.333], [0.313, -0.121, -0.070]], + [-0.387, -0.250, 0.066], [ - [0.180, -0.462, 0.106], - [0.543, 0.295, -0.411], - [-0.011, -0.066, 0.470], + [-0.030, 0.268, 0.299], + [-0.019, -0.280, -0.314], + [0.466, -0.365, -0.248], ], - [-0.179, -0.196, -0.067], + [-0.398, -0.199, -0.566], &device, ); - let expected_output_with_init_state = Data::from([ - [ - [-0.05291, 0.20481, 0.00247, 0.14330, -0.01617, -0.32437], - [0.06333, 0.17128, -0.08646, 0.23891, -0.39256, -0.08803], - [0.08696, 0.21229, 0.00791, 0.02564, -0.08598, -0.15525], - [0.06240, 0.25245, 0.00132, 0.00171, 0.01233, 0.01195], - ], - [ - [-0.11673, 0.23612, 0.05902, 0.34088, -0.09401, -0.16047], - [-0.01053, 0.20343, 0.01439, 0.26776, -0.29267, -0.15661], - [0.04320, 0.28468, -0.02198, 0.24269, 0.04973, -0.04563], - [0.07891, 0.24718, -0.04706, 0.13683, -0.01629, 0.03767], - ], - ]); - let expected_output_without_init_state = Data::from([ - [ - [-0.08461, 0.18986, 0.07192, 0.18021, -0.02266, -0.35150], - [0.06048, 0.17062, -0.05256, 0.29482, -0.40167, -0.12416], - [0.08438, 0.20755, 0.02044, 0.10186, -0.08353, -0.25673], - [0.06173, 0.24971, 0.00638, 0.13258, 0.06368, -0.09722], - ], - [ - [0.02993, 0.18217, 0.00005, 0.35562, -0.09828, -0.17992], - [0.09557, 0.16621, -0.02360, 0.28457, -0.29604, -0.21862], - [0.06623, 0.26088, -0.03991, 0.27286, 0.05034, -0.08039], - [0.08877, 0.24112, -0.05770, 0.16840, -0.00154, -0.06161], - ], - ]); + let expected_output_with_init_state = Data::from([[ + [0.23764, -0.03442, 0.04414, -0.15635, -0.03366, -0.05798], + [0.00473, -0.02254, 0.02988, -0.16510, -0.00306, 0.08742], + [0.06210, -0.06509, -0.05339, -0.01710, 0.02091, 0.16012], + [-0.03420, 0.07774, -0.09774, -0.02604, 0.12584, 0.20872], + ]]); + let expected_output_without_init_state = Data::from([[ + [0.08679, -0.08776, -0.00528, -0.15969, -0.05322, -0.08863], + [-0.02577, -0.05057, 0.00033, -0.17558, -0.03679, 0.03142], + [0.02942, -0.07411, -0.06044, -0.03601, -0.09998, 0.04846], + [-0.04026, 0.07178, -0.10189, -0.07349, -0.04576, 0.05550], + ]]); let expected_hn_with_init_state = Data::from([ - [[0.06240, 0.25245, 0.00132], [0.07891, 0.24718, -0.04706]], - [[0.14330, -0.01617, -0.32437], [0.34088, -0.09401, -0.16047]], + [[-0.03420, 0.07774, -0.09774]], + [[-0.15635, -0.03366, -0.05798]], ]); let expected_cn_with_init_state = Data::from([ - [[0.27726, 0.43163, 0.00460], [0.40963, 0.47434, -0.15836]], - [[0.28537, -0.05057, -0.68145], [0.67802, -0.19816, -0.55872]], + [[-0.13593, 0.17125, -0.22395]], + 
[[-0.45425, -0.11206, -0.12908]],
         ]);
         let expected_hn_without_init_state = Data::from([
-            [[0.06173, 0.24971, 0.00638], [0.08877, 0.24112, -0.05770]],
-            [[0.18021, -0.02266, -0.35150], [0.35562, -0.09828, -0.17992]],
+            [[-0.04026, 0.07178, -0.10189]],
+            [[-0.15969, -0.05322, -0.08863]],
         ]);
         let expected_cn_without_init_state = Data::from([
-            [[0.27319, 0.42555, 0.02218], [0.47942, 0.46064, -0.19706]],
-            [[0.36375, -0.07220, -0.76521], [0.71734, -0.20792, -0.66048]],
+            [[-0.15839, 0.15923, -0.23569]],
+            [[-0.47407, -0.17493, -0.19643]],
         ]);
 
         let (output_with_init_state, state_with_init_state) =

From 56f6d7cfbf32679f6d160afda43fd335b833215f Mon Sep 17 00:00:00 2001
From: wcshds
Date: Sat, 27 Apr 2024 01:51:56 +0800
Subject: [PATCH 17/17] modify docstring; delete dead comment

---
 crates/burn-core/src/nn/rnn/lstm.rs | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/crates/burn-core/src/nn/rnn/lstm.rs b/crates/burn-core/src/nn/rnn/lstm.rs
index 1ba9f53879..82025c6364 100644
--- a/crates/burn-core/src/nn/rnn/lstm.rs
+++ b/crates/burn-core/src/nn/rnn/lstm.rs
@@ -78,8 +78,7 @@ impl LstmConfig {
 impl<B: Backend> Lstm<B> {
     /// Applies the forward pass on the input tensor. This LSTM implementation
-    /// returns hidden state for each element in a sequence (i.e., across `seq_length`) and a final state,
-    /// producing 3-dimensional tensors where the dimensions represent `[batch_size, sequence_length, hidden_size]`.
+    /// returns the state for each element in a sequence (i.e., across seq_length) and a final state.
     ///
     /// ## Parameters:
     /// - batched_input: The input tensor of shape `[batch_size, sequence_length, input_size]`.
@@ -213,11 +212,9 @@ impl<B: Backend> BiLstm<B> {
 impl<B: Backend> BiLstm<B> {
     /// Applies the forward pass on the input tensor. This Bidirectional LSTM implementation
-    /// returns hidden state for each element in a sequence (i.e., across `seq_length`) and a final state,
-    /// producing 3-dimensional tensors where the dimensions represent `[batch_size, sequence_length, hidden_size * 2]`.
+    /// returns the state for each element in a sequence (i.e., across seq_length) and a final state.
     ///
     /// ## Parameters:
-    ///
     /// - batched_input: The input tensor of shape `[batch_size, sequence_length, input_size]`.
     /// - state: An optional `LstmState` representing the initial cell state and hidden state.
     ///          Each state tensor has shape `[2, batch_size, hidden_size]`.
     ///          If no initial state is provided, these tensors are initialized to zeros.
     ///
@@ -407,7 +404,6 @@ mod tests {
         // single timestep with single feature
         let input = Tensor::<TestBackend, 3>::from_data(Data::from([[[0.1]]]), &device);
 
-        // let (cell_state_batch, hidden_state_batch) = lstm.forward(input, None);
         let (output, state) = lstm.forward(input, None);
         state
             .cell