
After using slice_assign, gradient descent cannot track the model parameters #1098

Closed · wcshds opened this issue Dec 25, 2023 · 6 comments · Fixed by #1146

@wcshds (Contributor) commented Dec 25, 2023

I found that after using slice_assign in the loss function, gradient descent cannot track the model parameters. I believe this is the main reason why the loss becomes NaN after the first iteration when I apply my implementation of CTC loss to the CRNN model.

pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 1> {
    let device = input.device();
    let [d1, d2, d3] = input.dims();

    let input2 = Tensor::empty(input.shape(), &device);
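    // Note: slice_assign does not modify `input2` in place; it returns a new
    // tensor, which is discarded here (see the discussion below).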
    input2
        .clone()
        .slice_assign([0..d1, 0..d2, 0..d3], input.clone());

    input2.mean()
}
@antimora added the bug (Something isn't working) label on Dec 26, 2023
@wcshds (Contributor, Author) commented Dec 27, 2023

The actual reason my implementation of CTC loss becomes NaN after a few iterations is that it takes the logarithm of zero; it is not due to slice_assign. Now I am not certain whether there is a bug in slice_assign... perhaps more investigation is needed.
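
For anyone hitting the same NaN: a common guard is to clamp values to a small positive floor before taking the logarithm. A minimal sketch (the helper name and the 1e-10 floor are illustrative, assuming Burn's clamp_min on float tensors):

use burn::tensor::{backend::Backend, Tensor};

// Hypothetical helper: avoid log(0) by flooring the input first.
fn safe_log<B: Backend, const D: usize>(probs: Tensor<B, D>) -> Tensor<B, D> {
    probs.clamp_min(1e-10).log()
}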

@wcshds (Contributor, Author) commented Jan 10, 2024

@nathanielsimard I'm trying to train a CRNN model from scratch, but after a day of training there's still no sign of convergence. Then I noticed that only the parameters of the last layer were being updated. Here is a minimal reproducible example:

use burn::{
    backend::{ndarray::NdArrayDevice, Autodiff, NdArray},
    module::Module,
    nn::{Linear, LinearConfig, Lstm, LstmConfig},
    optim::{AdamConfig, GradientsParams, Optimizer},
    record::{FullPrecisionSettings, PrettyJsonFileRecorder},
    tensor::{
        backend::{AutodiffBackend, Backend},
        Tensor,
    },
};

fn main() {
    run::<Autodiff<NdArray>>(NdArrayDevice::Cpu);
}

fn run<B: AutodiffBackend>(device: B::Device) {
    let mut model = Model::<B>::new(&device);
    let mut optim = AdamConfig::new().init();
    let pfr = PrettyJsonFileRecorder::<FullPrecisionSettings>::new();

    for iteration in 0..51 {
        let input = Tensor::random(
            [2, 10, 5],
            burn::tensor::Distribution::Uniform(-1.0, 1.0),
            &device,
        );
        let output = model.forward(input);
        let loss = output.mean();

        println!(
            "[Train - Iteration {}] Loss {:.5}",
            iteration,
            loss.clone().into_scalar()
        );

        let grads = loss.backward();
        let grads = GradientsParams::from_grads(grads, &model);

        model = optim.step(0.001, model, grads);

        if iteration % 10 == 0 {
            model
                .clone()
                .lstm
                .save_file(format!("./lstm-{:02}", iteration), &pfr)
                .unwrap();
            model
                .clone()
                .linear
                .save_file(format!("./linear-{:02}", iteration), &pfr)
                .unwrap();
        }
    }
}

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    lstm: Lstm<B>,
    linear: Linear<B>,
}

impl<B: Backend> Model<B> {
    pub fn new(device: &B::Device) -> Self {
        Self {
            lstm: LstmConfig::new(5, 10, true).init(device),
            linear: LinearConfig::new(10, 20).init(device),
        }
    }

    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 2> {
        let (_, x) = self.lstm.forward(input, None);
        let [batch_size, seq_length, d_hidden] = x.dims();
        let x = x.reshape([batch_size * seq_length, d_hidden]);
        let x = self.linear.forward(x);

        x
    }
}

After some investigation, I believe this is due to the use of slice_assign in the implementation of LSTM. After replacing Tensor::slice_assign with Tensor::cat, the LSTM parameters are updated correctly:

pub fn forward(
    &self,
    batched_input: Tensor<B, 3>,
    state: Option<(Tensor<B, 2>, Tensor<B, 2>)>,
) -> (Tensor<B, 3>, Tensor<B, 3>) {
    let [batch_size, seq_length, _] = batched_input.shape().dims;
    let device = &batched_input.device();

    let (mut cell_state, mut hidden_state) = match state {
        Some((cell_state, hidden_state)) => (cell_state, hidden_state),
        None => (
            Tensor::zeros([batch_size, self.d_hidden], device),
            Tensor::zeros([batch_size, self.d_hidden], device),
        ),
    };

    let mut batched_cell_state_vec = Vec::with_capacity(seq_length);
    let mut batched_hidden_state_vec = Vec::with_capacity(seq_length);

    for input_t in batched_input.iter_dim(1) {
        let input_t = input_t.squeeze(1);
        // f(orget)g(ate) tensors
        let biased_fg_input_sum = self.gate_product(&input_t, &hidden_state, &self.forget_gate);
        let forget_values = activation::sigmoid(biased_fg_input_sum); // to multiply with cell state

        // i(nput)g(ate) tensors
        let biased_ig_input_sum = self.gate_product(&input_t, &hidden_state, &self.input_gate);
        let add_values = activation::sigmoid(biased_ig_input_sum);

        // o(utput)g(ate) tensors
        let biased_og_input_sum = self.gate_product(&input_t, &hidden_state, &self.output_gate);
        let output_values = activation::sigmoid(biased_og_input_sum);

        // c(ell)g(ate) tensors
        let biased_cg_input_sum = self.gate_product(&input_t, &hidden_state, &self.cell_gate);
        let candidate_cell_values = biased_cg_input_sum.tanh();

        cell_state = forget_values * cell_state.clone() + add_values * candidate_cell_values;
        hidden_state = output_values * cell_state.clone().tanh();

        let unsqueezed_shape = [cell_state.shape().dims[0], 1, cell_state.shape().dims[1]];

        let unsqueezed_cell_state = cell_state.clone().reshape(unsqueezed_shape);
        let unsqueezed_hidden_state = hidden_state.clone().reshape(unsqueezed_shape);

        // store the state for this timestep
        batched_cell_state_vec.push(unsqueezed_cell_state);
        batched_hidden_state_vec.push(unsqueezed_hidden_state);
    }

    let batched_cell_state = Tensor::cat(batched_cell_state_vec, 1);
    let batched_hidden_state = Tensor::cat(batched_hidden_state_vec, 1);

    (batched_cell_state, batched_hidden_state)
}

@nathanielsimard (Member) commented:

Thanks @wcshds for the LSTM example; it will help us fix this.

@nathanielsimard (Member) commented Jan 16, 2024, quoting the original report:

I found that after using slice_assign in the loss function, gradient descent cannot track the model parameters. I believe this is the main reason why the loss becomes NaN after the first iteration when I apply my implementation of CTC loss to the CRNN model.

pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 1> {
    let device = input.device();
    let [d1, d2, d3] = input.dims();

    let input2 = Tensor::empty(input.shape(), &device);
    input2
        .clone()
        .slice_assign([0..d1, 0..d2, 0..d3], input.clone());

    input2.mean()
}

I found the actual problem in this code: slice_assign doesn't mutate any data in input2; it returns a new tensor handle that should be used afterward:

pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 1> {
    let device = input.device();
    let [d1, d2, d3] = input.dims();

    let input2 = Tensor::empty(input.shape(), &device);
    let x = input2.slice_assign([0..d1, 0..d2, 0..d3], input);

    x.mean()
}

There are no mutable operations in the tensor API; every operation returns a result that should be used!

This doesn't explain the LSTM bug, though.

@wcshds (Contributor, Author) commented Jan 16, 2024, quoting the reply above:

"There are no mutable operations in the tensor API; every operation returns a result that should be used!"

Thank you very much for pointing out the actual problem! I often forget that.

@louisfd (Member) commented Jan 17, 2024

@wcshds I added some tests comparing the backward passes of slice_assign and cat in #1146, but I can't find a bug.
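
For reference, a comparison along those lines can be sketched as follows; this is an illustrative reconstruction on Autodiff<NdArray>, not the actual test from #1146:

use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
use burn::tensor::{Distribution, Tensor};

type B = Autodiff<NdArray>;

fn main() {
    let device = NdArrayDevice::Cpu;
    let input = Tensor::<B, 2>::random([2, 3], Distribution::Uniform(-1.0, 1.0), &device)
        .require_grad();

    // Path 1: write `input` into a fresh tensor with slice_assign.
    let via_slice = Tensor::zeros([2, 3], &device).slice_assign([0..2, 0..3], input.clone());

    // Path 2: rebuild the same tensor with cat along dim 0.
    let via_cat = Tensor::cat(input.clone().iter_dim(0).collect(), 0);

    // Both paths compute the same value, so they should yield identical
    // gradients with respect to `input`.
    let grad_slice = input.grad(&via_slice.mean().backward()).unwrap();
    let grad_cat = input.grad(&via_cat.mean().backward()).unwrap();
    println!("{:?}\n{:?}", grad_slice, grad_cat);
}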

The github-project-automation bot moved this from Todo to Done in Burn 🔥 on Jan 29, 2024.