From 3559529d6f5aa55c10bc9f20c14b0df4017c280a Mon Sep 17 00:00:00 2001 From: Martin Robinson Date: Sun, 12 Jul 2020 09:27:19 +0100 Subject: [PATCH] #1105 use approximated step size for D and c calculation --- pybamm/solvers/jax_bdf_solver.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pybamm/solvers/jax_bdf_solver.py b/pybamm/solvers/jax_bdf_solver.py index 079307407a..8161d3ea3a 100644 --- a/pybamm/solvers/jax_bdf_solver.py +++ b/pybamm/solvers/jax_bdf_solver.py @@ -165,7 +165,7 @@ def _bdf_init(fun, jac, t0, y0, h0, rtol, atol): state['n_equal_steps'] = 0 D = np.empty((MAX_ORDER + 1, len(y0)), dtype=y0.dtype) D = jax.ops.index_update(D, jax.ops.index[0, :], y0) - D = jax.ops.index_update(D, jax.ops.index[1, :], f0 * h0) + D = jax.ops.index_update(D, jax.ops.index[1, :], f0 * state['h']) state['D'] = D state['y0'] = None state['scale_y0'] = None @@ -177,7 +177,7 @@ def _bdf_init(fun, jac, t0, y0, h0, rtol, atol): kappa = np.array([0, -0.1850, -1 / 9, -0.0823, -0.0415, 0]) gamma = np.hstack((0, np.cumsum(1 / np.arange(1, MAX_ORDER + 1)))) alpha = 1.0 / ((1 - kappa) * gamma) - c = h0 * alpha[order] + c = state['h'] * alpha[order] error_const = kappa * gamma + 1 / np.arange(1, MAX_ORDER + 2) state['kappa'] = kappa