Skip to content

Commit

Permalink
#1105 use approximated step size for D and c calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Jul 12, 2020
1 parent b2488ce commit 3559529
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pybamm/solvers/jax_bdf_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def _bdf_init(fun, jac, t0, y0, h0, rtol, atol):
state['n_equal_steps'] = 0
D = np.empty((MAX_ORDER + 1, len(y0)), dtype=y0.dtype)
D = jax.ops.index_update(D, jax.ops.index[0, :], y0)
D = jax.ops.index_update(D, jax.ops.index[1, :], f0 * h0)
D = jax.ops.index_update(D, jax.ops.index[1, :], f0 * state['h'])
state['D'] = D
state['y0'] = None
state['scale_y0'] = None
Expand All @@ -177,7 +177,7 @@ def _bdf_init(fun, jac, t0, y0, h0, rtol, atol):
kappa = np.array([0, -0.1850, -1 / 9, -0.0823, -0.0415, 0])
gamma = np.hstack((0, np.cumsum(1 / np.arange(1, MAX_ORDER + 1))))
alpha = 1.0 / ((1 - kappa) * gamma)
c = h0 * alpha[order]
c = state['h'] * alpha[order]
error_const = kappa * gamma + 1 / np.arange(1, MAX_ORDER + 2)

state['kappa'] = kappa
Expand Down

0 comments on commit 3559529

Please sign in to comment.