diff --git a/pybamm/solvers/jax_bdf_solver.py b/pybamm/solvers/jax_bdf_solver.py index 079307407a..8161d3ea3a 100644 --- a/pybamm/solvers/jax_bdf_solver.py +++ b/pybamm/solvers/jax_bdf_solver.py @@ -165,7 +165,7 @@ def _bdf_init(fun, jac, t0, y0, h0, rtol, atol): state['n_equal_steps'] = 0 D = np.empty((MAX_ORDER + 1, len(y0)), dtype=y0.dtype) D = jax.ops.index_update(D, jax.ops.index[0, :], y0) - D = jax.ops.index_update(D, jax.ops.index[1, :], f0 * h0) + D = jax.ops.index_update(D, jax.ops.index[1, :], f0 * state['h']) state['D'] = D state['y0'] = None state['scale_y0'] = None @@ -177,7 +177,7 @@ def _bdf_init(fun, jac, t0, y0, h0, rtol, atol): kappa = np.array([0, -0.1850, -1 / 9, -0.0823, -0.0415, 0]) gamma = np.hstack((0, np.cumsum(1 / np.arange(1, MAX_ORDER + 1)))) alpha = 1.0 / ((1 - kappa) * gamma) - c = h0 * alpha[order] + c = state['h'] * alpha[order] error_const = kappa * gamma + 1 / np.arange(1, MAX_ORDER + 2) state['kappa'] = kappa