-
-
Notifications
You must be signed in to change notification settings - Fork 599
/
Copy pathjax_bdf_solver.py
1026 lines (847 loc) · 36.4 KB
/
jax_bdf_solver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import collections
import operator as op
from functools import partial
import numpy as onp
import pybamm
if pybamm.have_jax():
import jax
import jax.numpy as jnp
from jax import core, dtypes
from jax import linear_util as lu
from jax.api_util import flatten_fun_nokwargs
from jax.config import config
from jax.flatten_util import ravel_pytree
from jax.interpreters import partial_eval as pe
from jax.tree_util import tree_flatten, tree_map, tree_unflatten
from jax.util import cache, safe_map, split_list
config.update("jax_enable_x64", True)
# Highest BDF order the integrator may use; sizes the difference table D and the
# per-order coefficient arrays (gamma, error_const) built in _bdf_init.
MAX_ORDER = 5
# Maximum number of iterations of the per-step Newton solve (_newton_iteration).
NEWTON_MAXITER = 4
# Maximum number of iterations when solving for consistent initial values of the
# algebraic variables (_select_initial_conditions).
ROOT_SOLVE_MAXITER = 15
# Lower bound on the step-size factor applied when a converged step is rejected
# because its error estimate is too large (_bdf_step).
MIN_FACTOR = 0.2
# Upper bound on the step-size factor chosen after an accepted step
# (_prepare_next_step_order_change).
MAX_FACTOR = 10
# https://github.com/google/jax/issues/4572#issuecomment-709809897
def some_hash_function(x):
    """Return a hash of ``x`` computed from its string representation.

    This gives a hash value to objects that are not hashable themselves
    (e.g. numpy arrays).
    """
    text = str(x)
    return hash(text)
class HashableArrayWrapper:
    """
    Wrapper for a numpy array (or any other value) to make it hashable, so that
    it can be passed to ``jax.jit`` as a static argument.

    Two wrappers compare equal when their wrapped values have the same shape and
    the same elements.
    """

    def __init__(self, val):
        # the wrapped value; read back through ``.val`` by gnool_jit
        self.val = val

    def __hash__(self):
        return some_hash_function(self.val)

    def __eq__(self, other):
        if not isinstance(other, HashableArrayWrapper):
            return False
        # array_equal also checks the shapes: an element-wise comparison would
        # broadcast (so e.g. [1.] would equal [1., 1.]) or raise on arrays whose
        # shapes cannot be broadcast together
        return onp.array_equal(self.val, other.val)
def gnool_jit(fun, static_array_argnums=(), static_argnums=()):
    """
    Redefinition of jax jit to allow static array args.

    Parameters
    ----------
    fun: callable
        function to be jit-compiled
    static_array_argnums: (optional) int or sequence of int
        positions of arguments that are static but not hashable (e.g. numpy
        arrays). These are wrapped in a HashableArrayWrapper before being handed
        to ``jax.jit`` and unwrapped again before ``fun`` is called.
    static_argnums: (optional) int or sequence of int
        positions of additional, already hashable, static arguments. They are
        forwarded to ``jax.jit`` unchanged.

    Returns
    -------
    callable
        function with the same positional signature as ``fun``
    """
    if static_argnums is None:
        static_argnums = ()
    if isinstance(static_argnums, int):
        static_argnums = (static_argnums,)
    if isinstance(static_array_argnums, int):
        static_array_argnums = (static_array_argnums,)
    static_array_argnums = tuple(static_array_argnums)

    # jax.jit must treat both kinds of argument as static; only the array-like
    # ones need the hashable wrapper
    all_static_argnums = tuple(
        sorted(set(static_array_argnums) | set(static_argnums))
    )

    @partial(jax.jit, static_argnums=all_static_argnums)
    def callee(*args):
        args = list(args)
        for i in static_array_argnums:
            args[i] = args[i].val
        return fun(*args)

    def caller(*args):
        args = list(args)
        for i in static_array_argnums:
            args[i] = HashableArrayWrapper(args[i])
        return callee(*args)

    return caller
@partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2, 3))
def _bdf_odeint(fun, mass, rtol, atol, y0, t_eval, *args):
    """
    Implements a Backward Difference formula (BDF) implicit multistep integrator.
    The basic algorithm is derived in :footcite:t:`byrne1975polyalgorithm`. This
    particular implementation follows that implemented in the Matlab routine ode15s
    described in :footcite:t:`shampine1997matlab` and the SciPy implementation
    :footcite:t:`Virtanen2020`, which features the NDF formulas for improved
    stability with associated differences in the error constants, and calculates
    the jacobian at J(t_{n+1}, y^0_{n+1}). This implementation was based on that
    implemented in the SciPy library :footcite:t:`Virtanen2020`, which also mainly
    follows :footcite:t:`shampine1997matlab` but uses the more standard Jacobian
    update.

    Parameters
    ----------
    fun: callable
        function to evaluate the time derivative of the solution `y` at time
        `t` as `fun(y, t, *args)`, producing the same shape/structure as `y0`.
    mass: ndarray
        mass matrix, used as a 2-D array with shape (n, n); a zero on its
        diagonal marks the corresponding variable as algebraic
    rtol: float
        relative tolerance for the solver
    atol: float
        absolute tolerance for the solver
    y0: ndarray
        initial state vector, has shape (n,)
    t_eval: ndarray
        time points to evaluate the solution, has shape (m,). The first trial
        step size is taken as ``t_eval[1] - t_eval[0]``.
    args: (optional)
        tuple of additional arguments for `fun`, which must be arrays
        scalars, or (nested) standard Python containers (tuples, lists, dicts,
        namedtuples, i.e. pytrees) of those types.

    Returns
    -------
    y: ndarray with shape (m, n)
        calculated state vector at each of the m time points
    """

    def fun_bind_inputs(y, t):
        return fun(y, t, *args)

    # jacobian of the right-hand side with respect to y
    jac_bind_inputs = jax.jacfwd(fun_bind_inputs, argnums=0)

    t0 = t_eval[0]
    h0 = t_eval[1] - t0

    stepper = _bdf_init(
        fun_bind_inputs, jac_bind_inputs, mass, t0, y0, h0, rtol, atol
    )
    # i is the index of the next entry of t_eval that still has to be written
    i = 0
    y_out = jnp.empty((len(t_eval), len(y0)), dtype=y0.dtype)

    init_state = [stepper, t_eval, i, y_out]

    def cond_fun(state):
        # keep stepping until every requested time point has been written
        _, t_eval, i, _ = state
        return i < len(t_eval)

    def body_fun(state):
        stepper, t_eval, i, y_out = state
        stepper = _bdf_step(stepper, fun_bind_inputs, jac_bind_inputs)
        # index of the first requested time that is not before the new stepper
        # time; entries i .. index - 1 can be interpolated from this step
        index = jnp.searchsorted(t_eval, stepper.t)
        index = index.astype(
            "int" + t_eval.dtype.name[-2:]
        )  # Coerce index to correct type

        def for_body(j, y_out):
            t = t_eval[j]
            y_out = y_out.at[jnp.index_exp[j, :]].set(_bdf_interpolate(stepper, t))
            return y_out

        y_out = jax.lax.fori_loop(i, index, for_body, y_out)
        return [stepper, t_eval, index, y_out]

    stepper, t_eval, i, y_out = jax.lax.while_loop(cond_fun, body_fun, init_state)
    return y_out
# Names of the fields making up the solver state, in the order in which they are
# stored in BDFState. The values are first assembled in _bdf_init.
BDFInternalStates = [
    "t",  # current time
    "atol",  # absolute tolerance
    "rtol",  # relative tolerance
    "M",  # mass matrix
    "newton_tol",  # convergence tolerance of the Newton iterations
    "order",  # current order of the method
    "h",  # current step size
    "n_equal_steps",  # accepted steps since the step size was last changed
    "D",  # table of backward differences
    "y0",  # predicted solution for the next step
    "scale_y0",  # atol + rtol * |y| scaling used in the error norms
    "kappa",  # per-order NDF coefficients
    "gamma",  # per-order cumulative sums of 1 / j
    "alpha",  # 1 / ((1 - kappa) * gamma)
    "c",  # h * alpha[order]
    "error_const",  # per-order error constants
    "J",  # jacobian of the right-hand side
    "LU",  # LU factorisation of M - c * J
    "U",  # _compute_R(order, 1) as evaluated in _bdf_init
    "psi",  # psi term of the BDF equation (see _update_psi)
    "n_function_evals",  # counter
    "n_jacobian_evals",  # counter
    "n_lu_decompositions",  # counter
    "n_steps",  # counter of accepted steps
    "consistent_y0_failed",  # flag returned by _select_initial_conditions
]
BDFState = collections.namedtuple("BDFState", BDFInternalStates)

# Register BDFState as a pytree: its fields are the children and there is no
# auxiliary data, so a state can be carried through jax.lax loops and tree_map.
jax.tree_util.register_pytree_node(
    BDFState, lambda xs: (tuple(xs), None), lambda _, xs: BDFState(*xs)
)
def _bdf_init(fun, jac, mass, t0, y0, h0, rtol, atol):
    """
    Initiation routine for Backward Difference formula (BDF) implicit multistep
    integrator.

    See _bdf_odeint function above for details, this function returns a BDFState
    namedtuple with the initial state of the solver

    Parameters
    ----------
    fun: callable
        function with signature (y, t), where t is a scalar time and y is a ndarray
        with shape (n,), returns the rhs of the system of ODE equations as an nd
        array with shape (n,)
    jac: callable
        function with signature (y, t), where t is a scalar time and y is a ndarray
        with shape (n,), returns the jacobian matrix of fun as an ndarray with
        shape (n,n)
    mass: ndarray
        mass matrix, used as a 2-D array with shape (n, n) (it is combined with
        the jacobian as ``M - c * J``)
    t0: float
        initial time
    y0: ndarray
        initial state vector with shape (n,)
    h0: float
        initial step size
    rtol: float
        relative tolerance for the solver
    atol: float
        absolute tolerance for the solver

    Returns
    -------
    BDFState
        initial state of the solver, with y0, scale_y0 and psi already updated
        for the first step
    """
    # the state is collected in a dict first and converted to a BDFState at the end
    state = {}
    state["t"] = t0
    state["atol"] = atol
    state["rtol"] = rtol
    state["M"] = mass
    EPS = jnp.finfo(y0.dtype).eps
    state["newton_tol"] = jnp.maximum(
        10 * EPS / rtol, jnp.minimum(0.03, rtol**0.5)
    )

    # make the initial algebraic variables consistent before anything else uses y0
    scale_y0 = atol + rtol * jnp.abs(y0)
    y0, not_converged = _select_initial_conditions(
        fun, mass, t0, y0, state["newton_tol"], scale_y0
    )
    state["consistent_y0_failed"] = not_converged

    f0 = fun(y0, t0)
    order = 1
    state["order"] = order
    state["h"] = _select_initial_step(atol, rtol, fun, t0, y0, f0, h0)
    state["n_equal_steps"] = 0
    # difference table: row 0 holds y0 and row 1 holds h * f0; the remaining rows
    # are filled in as steps are taken
    D = jnp.empty((MAX_ORDER + 1, len(y0)), dtype=y0.dtype)
    D = D.at[jnp.index_exp[0, :]].set(y0)
    D = D.at[jnp.index_exp[1, :]].set(f0 * state["h"])
    state["D"] = D
    state["y0"] = y0
    state["scale_y0"] = scale_y0

    # kappa values for difference orders, taken from Table 1 of [1]
    kappa = jnp.array([0, -0.1850, -1 / 9, -0.0823, -0.0415, 0])
    gamma = jnp.hstack((0, jnp.cumsum(1 / jnp.arange(1, MAX_ORDER + 1))))
    alpha = 1.0 / ((1 - kappa) * gamma)
    c = state["h"] * alpha[order]
    error_const = kappa * gamma + 1 / jnp.arange(1, MAX_ORDER + 2)

    state["kappa"] = kappa
    state["gamma"] = gamma
    state["alpha"] = alpha
    state["c"] = c
    state["error_const"] = error_const

    J = jac(y0, t0)
    state["J"] = J

    state["LU"] = jax.scipy.linalg.lu_factor(state["M"] - c * J)

    state["U"] = _compute_R(order, 1)
    # psi is filled in below, once the state can be passed to _update_psi
    state["psi"] = None

    # fun has been called twice so far: once above for f0 and once inside
    # _select_initial_step
    state["n_function_evals"] = 2
    state["n_jacobian_evals"] = 1
    state["n_lu_decompositions"] = 1
    state["n_steps"] = 0

    tuple_state = BDFState(*[state[k] for k in BDFInternalStates])
    y0, scale_y0 = _predict(tuple_state, D)
    psi = _update_psi(tuple_state, D)
    return tuple_state._replace(y0=y0, scale_y0=scale_y0, psi=psi)
def _compute_R(order, factor):
    """
    computes the R matrix with entries
    given by the first equation on page 8 of [1]

    This is used to update the differences matrix when step size h is varied
    according to factor = h_{n+1} / h_n

    Note that the U matrix also defined in the same section can be also be
    found using factor = 1, which corresponds to R with a constant step size

    Parameters
    ----------
    order: int
        current order of the method. It is not used here: the matrix is always
        built at full size (MAX_ORDER + 1, MAX_ORDER + 1).
    factor: float
        ratio of the new step size to the old one

    Returns
    -------
    ndarray with shape (MAX_ORDER + 1, MAX_ORDER + 1)
    """
    rows = jnp.arange(1, MAX_ORDER + 1).reshape(-1, 1)
    cols = jnp.arange(1, MAX_ORDER + 1)
    # start from zeros rather than an uninitialised array: the first column
    # below row 0 is never assigned but is read by the cumulative product
    M = jnp.zeros((MAX_ORDER + 1, MAX_ORDER + 1))
    M = M.at[jnp.index_exp[1:, 1:]].set((rows - 1 - factor * cols) / rows)
    M = M.at[jnp.index_exp[0]].set(1)
    R = jnp.cumprod(M, axis=0)
    return R
def _select_initial_conditions(fun, M, t0, y0, tol, scale_y0):
    """
    Calculate consistent initial values for the algebraic variables.

    Parameters
    ----------
    fun: callable
        function with signature (y, t) returning the rhs of the system
    M: ndarray
        mass matrix; a zero on its diagonal marks an algebraic variable
    t0: float
        initial time
    y0: ndarray
        initial state vector
    tol: float
        convergence tolerance of the newton iteration
    scale_y0: ndarray
        scaling applied to the newton updates when computing their norm

    Returns
    -------
    y0: ndarray
        state vector with the algebraic entries replaced by the result of the
        newton iteration (returned unchanged if there are no algebraic variables)
    not_converged: bool
        True if the newton iteration did not converge within ROOT_SOLVE_MAXITER
        iterations, False otherwise (including when there are no algebraic
        variables)
    """
    # identify algebraic variables as zeros on diagonal
    algebraic_variables = onp.diag(M) == 0.0

    # if all differentiable variables then return y0 (can use normal python if
    # since M is static)
    if not onp.any(algebraic_variables):
        return y0, False

    # calculate consistent initial conditions via a newton on -J_a @ delta = f_a
    # This follows this reference:
    #
    # Shampine, L. F., Reichelt, M. W., & Kierzenka, J. A. (1999).
    # Solving index-1 DAEs in MATLAB and Simulink. SIAM review, 41(3), 538-552.

    # calculate fun_a, function of algebraic variables
    def fun_a(y_a):
        y_full = y0.at[algebraic_variables].set(y_a)
        return fun(y_full, t0)[algebraic_variables]

    y0_a = y0[algebraic_variables]
    scale_y0_a = scale_y0[algebraic_variables]

    d = jnp.zeros(y0_a.shape[0], dtype=y0.dtype)
    y_a = jnp.array(y0_a, copy=True)

    # calculate neg jacobian of fun_a; it is factorised once and reused for
    # every iteration
    J_a = jax.jacfwd(fun_a)(y_a)
    LU = jax.scipy.linalg.lu_factor(-J_a)

    converged = False
    # a negative old norm marks the first iteration, where no rate is available
    dy_norm_old = -1.0
    k = 0
    while_state = [k, converged, dy_norm_old, d, y_a]

    def while_cond(while_state):
        k, converged, _, _, _ = while_state
        return (converged == False) * (k < ROOT_SOLVE_MAXITER)  # noqa: E712

    def while_body(while_state):
        k, converged, dy_norm_old, d, y_a = while_state
        f_eval = fun_a(y_a)
        dy = jax.scipy.linalg.lu_solve(LU, f_eval)
        dy_norm = jnp.sqrt(jnp.mean((dy / scale_y0_a) ** 2))
        rate = dy_norm / dy_norm_old

        d += dy
        y_a = y0_a + d

        # if converged then break out of iteration early
        pred = dy_norm_old >= 0.0
        pred *= rate / (1 - rate) * dy_norm < tol
        converged = (dy_norm == 0.0) + pred

        dy_norm_old = dy_norm

        return [k + 1, converged, dy_norm_old, d, y_a]

    k, converged, dy_norm_old, d, y_a = jax.lax.while_loop(
        while_cond, while_body, while_state
    )
    y_tilde = y0.at[algebraic_variables].set(y_a)

    # the second return value is a failure flag (the early return above yields
    # False for "no failure"), so the convergence flag has to be negated
    return y_tilde, jnp.logical_not(converged)
def _select_initial_step(atol, rtol, fun, t0, y0, f0, h0):
    """
    Choose the initial step size.

    A single forward euler step of size h0 is taken and the rhs is evaluated at
    the resulting point; the scaled change in the rhs is then used to pick the
    step size via formula (4.12) in :footcite:t:`hairer1993solving`. The result
    is limited to at most 100 * h0.
    """
    order = 1
    weights = atol + jnp.abs(y0) * rtol

    # rhs at the end of a forward euler step of size h0
    f_trial = fun(y0 + h0 * f0, t0 + h0)

    scaled_change = (f_trial - f0) / weights
    d2 = jnp.sqrt(jnp.mean(scaled_change**2))

    h_candidate = h0 * d2 ** (-1 / (order + 1))
    h_limit = 100 * h0
    return jnp.minimum(h_limit, h_candidate)
def _predict(state, D):
    """
    Predict the solution at the new step (eq 2 in [1]) by summing the rows of the
    difference table D up to and including the current order.

    Returns the predicted solution together with the error scaling
    ``atol + rtol * |y0|``, which is evaluated with the y0 stored in ``state``.
    """
    # column mask selecting rows 0 .. order of D; it broadcasts over the columns
    row_is_active = jnp.arange(MAX_ORDER + 1).reshape(-1, 1) <= state.order
    y0 = jnp.where(row_is_active, D, 0).sum(axis=0)
    scale_y0 = state.atol + state.rtol * jnp.abs(state.y0)
    return y0, scale_y0
def _update_psi(state, D):
    """
    Update the psi term as defined in the second equation on page 9 of [1].

    Only rows 1 .. order of D (and the matching entries of gamma) contribute.
    """
    order = state.order
    orders = jnp.arange(MAX_ORDER + 1)
    # rows that take part in the sum: 0 < row <= order
    is_active = jnp.logical_and(orders > 0, orders <= order)
    gamma_active = jnp.where(is_active, state.gamma, 0)
    D_active = jnp.where(is_active.reshape(-1, 1), D, 0)
    return jnp.dot(D_active.T, gamma_active) * state.alpha[order]
def _update_difference_for_next_step(state, d):
    """
    Update the table of backward differences after an accepted step, reusing the
    newton correction d and the existing table D.

    From first equation on page 4 of [1]:

        d = y_n - y^0_n = D^{k + 1} y_n

    and the standard backwards difference satisfies

        D^{j + 1} y_n = D^{j} y_n - D^{j} y_{n - 1}

    so rows order + 2 and order + 1 are set from d, after which every row from
    ``order`` down to 0 has the row above it added.

    NOTE(review): rows ``order + 1`` and ``order + 2`` are written here, while
    _bdf_init allocates D with ``MAX_ORDER + 1`` rows - confirm that the indexing
    stays in bounds for the highest orders.
    """
    order = state.order
    D = state.D
    D = D.at[jnp.index_exp[order + 2]].set(d - D[order + 1])
    D = D.at[jnp.index_exp[order + 1]].set(d)

    def rows_remaining(carry):
        row, _ = carry
        return row >= 0

    def add_row_above(carry):
        row, table = carry
        table = table.at[jnp.index_exp[row]].add(table[row + 1])
        return row - 1, table

    _, D = jax.lax.while_loop(rows_remaining, add_row_above, (order, D))
    return D
def _update_step_size_and_lu(state, factor):
    """
    Scale the step size by ``factor`` (see _update_step_size) and refresh the LU
    factorisation of ``M - c * J``, which depends on the step size through c.
    """
    resized = _update_step_size(state, factor)
    # redo lu (c has changed)
    iteration_matrix = resized.M - resized.c * resized.J
    return resized._replace(
        LU=jax.scipy.linalg.lu_factor(iteration_matrix),
        n_lu_decompositions=resized.n_lu_decompositions + 1,
    )
def _update_step_size(state, factor):
    """
    Change the step size h by ``factor`` and update everything in the state that
    depends on it (first equation of page 9 of [1]):

    - the constant c = h / ((1 - kappa) * gamma_k)
    - the difference table D
    - the psi term
    - the predicted solution y0 and its scaling

    The counter of equal steps is reset to zero. The LU factorisation of
    (M - c * J) is NOT refreshed here, see _update_step_size_and_lu.
    """
    order = state.order
    new_h = state.h * factor
    new_c = new_h * state.alpha[order]

    # update D using equations in section 3.2 of [1]
    RU = _compute_R(order, factor).dot(state.U)
    row_idx = jnp.arange(0, MAX_ORDER + 1).reshape(-1, 1)
    col_idx = jnp.arange(0, MAX_ORDER + 1)
    # only the leading (order + 1, order + 1) block of RU is applied; outside it
    # the identity is used so that the remaining rows of D are left unchanged
    in_leading_block = jnp.logical_and(row_idx <= order, col_idx <= order)
    RU = jnp.where(in_leading_block, RU, jnp.identity(MAX_ORDER + 1))
    new_D = jnp.dot(RU.T, state.D)

    # psi and the prediction both depend on D
    new_psi = _update_psi(state, new_D)
    new_y0, new_scale_y0 = _predict(state, new_D)

    return state._replace(
        n_equal_steps=0,
        h=new_h,
        c=new_c,
        D=new_D,
        psi=new_psi,
        y0=new_y0,
        scale_y0=new_scale_y0,
    )
def _update_jacobian(state, jac):
    """
    Re-evaluate the jacobian and the LU factorisation that depends on it.

    The jacobian is evaluated at J(t_{n+1}, y^0_{n+1}), following the scipy bdf
    implementation, rather than at J(t_n, y_n) as per [1]. Both the jacobian and
    the LU decomposition counters are incremented.
    """
    t_next = state.t + state.h
    J = jac(state.y0, t_next)
    LU = jax.scipy.linalg.lu_factor(state.M - state.c * J)
    return state._replace(
        J=J,
        n_jacobian_evals=state.n_jacobian_evals + 1,
        LU=LU,
        n_lu_decompositions=state.n_lu_decompositions + 1,
    )
def _newton_iteration(state, fun):
    """
    Solve the implicit BDF equation for the step from t to t + h with a newton
    iteration that starts from the predicted solution ``state.y0`` and reuses the
    LU factorisation stored in the state.

    Returns
    -------
    converged: bool
        whether the iteration converged
    k: int
        iteration counter at exit (set to NEWTON_MAXITER when the iteration is
        aborted early)
    y: ndarray
        the solution estimate at t + h
    d: ndarray
        the accumulated correction, y = y0 + d
    state: BDFState
        input state with the function evaluation counter updated
    """
    tol = state.newton_tol
    c = state.c
    psi = state.psi
    y0 = state.y0
    LU = state.LU
    M = state.M
    scale_y0 = state.scale_y0
    t = state.t + state.h
    d = jnp.zeros(y0.shape, dtype=y0.dtype)
    y = jnp.array(y0, copy=True)
    n_function_evals = state.n_function_evals

    converged = False
    # a negative old norm marks the first iteration, where no rate is available
    dy_norm_old = -1.0
    k = 0
    while_state = [k, converged, dy_norm_old, d, y, n_function_evals]

    def while_cond(while_state):
        k, converged, _, _, _, _ = while_state
        return (converged == False) * (k < NEWTON_MAXITER)  # noqa: E712

    def while_body(while_state):
        k, converged, dy_norm_old, d, y, n_function_evals = while_state
        f_eval = fun(y, t)
        n_function_evals += 1
        b = c * f_eval - M @ (psi + d)
        dy = jax.scipy.linalg.lu_solve(LU, b)
        dy_norm = jnp.sqrt(jnp.mean((dy / scale_y0) ** 2))
        rate = dy_norm / dy_norm_old

        # if iteration is not going to converge in NEWTON_MAXITER
        # (assuming the current rate), then abort
        pred = rate >= 1
        pred += rate ** (NEWTON_MAXITER - k) / (1 - rate) * dy_norm > tol
        pred *= dy_norm_old >= 0
        # aborting is done by jumping k forward, so that k + 1 returned below
        # equals NEWTON_MAXITER and while_cond stops the loop
        k += pred * (NEWTON_MAXITER - k - 1)

        d += dy
        y = y0 + d

        # if converged then break out of iteration early
        pred = dy_norm_old >= 0.0
        pred *= rate / (1 - rate) * dy_norm < tol
        converged = (dy_norm == 0.0) + pred

        dy_norm_old = dy_norm

        return [k + 1, converged, dy_norm_old, d, y, n_function_evals]

    k, converged, dy_norm_old, d, y, n_function_evals = jax.lax.while_loop(
        while_cond, while_body, while_state
    )
    return converged, k, y, d, state._replace(n_function_evals=n_function_evals)
def rms_norm(arg):
    """Return the root-mean-square of the entries of ``arg``."""
    mean_square = jnp.mean(arg**2)
    return jnp.sqrt(mean_square)
def _prepare_next_step(state, d):
    """
    Prepare the state for the next step when neither the order nor the step size
    changes: update the difference table with the correction d, then refresh the
    psi term and the prediction that depend on it.
    """
    new_D = _update_difference_for_next_step(state, d)
    new_psi = _update_psi(state, new_D)
    new_y0, new_scale_y0 = _predict(state, new_D)
    return state._replace(
        D=new_D, psi=new_psi, y0=new_y0, scale_y0=new_scale_y0
    )
def _prepare_next_step_order_change(state, d, y, n_iter):
    """
    Prepare the state for the next step, allowing the order to change by at most
    one and rescaling the step size accordingly.

    The error norms for orders k - 1, k and k + 1 are estimated, each is turned
    into a step size factor, and the order giving the largest factor is selected.
    The factor is scaled by a safety coefficient that depends on the number of
    newton iterations ``n_iter`` and is limited to MAX_FACTOR.
    """
    current_order = state.order
    D = _update_difference_for_next_step(state, d)

    # Note: we are recalculating these from the while loop above, could re-use?
    scale_y = state.atol + state.rtol * jnp.abs(y)
    error_norm = rms_norm(state.error_const[current_order] * d / scale_y)
    safety = 0.9 * (2 * NEWTON_MAXITER + 1) / (2 * NEWTON_MAXITER + n_iter)

    # similar to the optimal step size factor we calculated above for the current
    # order k, we need to calculate the optimal step size factors for orders
    # k-1 and k+1. To do this, we note that the error = C_k * D^{k+1} y_n
    lower_order_norm = rms_norm(
        state.error_const[current_order - 1] * D[current_order] / scale_y
    )
    higher_order_norm = rms_norm(
        state.error_const[current_order + 1] * D[current_order + 2] / scale_y
    )
    # an infinite norm rules out an order outside the range 1 .. MAX_ORDER
    error_m_norm = jnp.where(current_order > 1, lower_order_norm, jnp.inf)
    error_p_norm = jnp.where(current_order < MAX_ORDER, higher_order_norm, jnp.inf)

    error_norms = jnp.array([error_m_norm, error_norm, error_p_norm])
    factors = error_norms ** (-1 / (jnp.arange(3) + current_order))

    # now we have the three factors for orders k-1, k and k+1, pick the maximum in
    # order to maximise the resultant step size
    best = jnp.argmax(factors)
    new_order = current_order + (best - 1)
    factor = jnp.minimum(MAX_FACTOR, safety * factors[best])

    return _update_step_size_and_lu(state._replace(D=D, order=new_order), factor)
def _bdf_step(state, fun, jac):
    """
    Advance the solver state by one accepted step.

    The step is retried, with an updated jacobian and/or a reduced step size,
    until the newton iteration converges and the error estimate is within
    bounds. The accepted step is then taken and the state is prepared for the
    next step, possibly with a change of order and step size.

    Every candidate update is evaluated unconditionally and then selected field
    by field with ``jnp.where`` through ``tree_map``.

    Parameters
    ----------
    state: BDFState
        current state of the solver
    fun: callable
        rhs function with signature (y, t)
    jac: callable
        jacobian function with signature (y, t)

    Returns
    -------
    BDFState
        state of the solver after the accepted step
    """
    # print('bdf_step', state.t, state.h)
    # we will try and use the old jacobian unless convergence of newton iteration
    # fails
    updated_jacobian = False
    # initialise step size and try to make the step,
    # iterate, reducing step size until error is in bounds
    step_accepted = False
    y = jnp.empty_like(state.y0)
    d = jnp.empty_like(state.y0)
    n_iter = -1

    # loop until step is accepted
    while_state = [state, step_accepted, updated_jacobian, y, d, n_iter]

    def while_cond(while_state):
        _, step_accepted, _, _, _, _ = while_state
        return step_accepted == False  # noqa: E712

    def while_body(while_state):
        state, step_accepted, updated_jacobian, y, d, n_iter = while_state

        # solve BDF equation using y0 as starting point
        converged, n_iter, y, d, state = _newton_iteration(state, fun)
        not_converged = converged == False  # noqa: E712

        # newton iteration did not converge, but jacobian has already been
        # evaluated so reduce step size by 0.3 (as per [1]) and try again
        state = tree_map(
            partial(jnp.where, not_converged * updated_jacobian),
            _update_step_size_and_lu(state, 0.3),
            state,
        )

        # if not_converged * updated_jacobian:
        #     print('not converged, update step size by 0.3')
        # if not_converged * (updated_jacobian == False):
        #     print('not converged, update jacobian')

        # if not converged and jacobian not updated, then update the jacobian and
        # try again
        (state, updated_jacobian) = tree_map(
            partial(
                jnp.where, not_converged * (updated_jacobian == False)  # noqa: E712
            ),
            (_update_jacobian(state, jac), True),
            (state, False + updated_jacobian),
        )

        safety = 0.9 * (2 * NEWTON_MAXITER + 1) / (2 * NEWTON_MAXITER + n_iter)
        scale_y = state.atol + state.rtol * jnp.abs(y)

        # combine eq 3, 4 and 6 from [1] to obtain error
        # Note that error = C_k * h^{k+1} y^{k+1}
        # and d = D^{k+1} y_{n+1} \approx h^{k+1} y^{k+1}
        error = state.error_const[state.order] * d

        error_norm = rms_norm(error / scale_y)

        # calculate optimal step size factor as per eq 2.46 of [2]
        factor = jnp.maximum(
            MIN_FACTOR, safety * error_norm ** (-1 / (state.order + 1))
        )

        # if converged * (error_norm > 1):
        #     print(
        #         "converged, but error is too large",
        #         error_norm,
        #         factor,
        #         d,
        #         scale_y,
        #     )

        # converged but error too large: shrink the step and reject it;
        # otherwise the step is accepted exactly when the iteration converged
        (state, step_accepted) = tree_map(
            partial(jnp.where, converged * (error_norm > 1)),  # noqa: E712
            (_update_step_size_and_lu(state, factor), False),
            (state, converged),
        )

        return [state, step_accepted, updated_jacobian, y, d, n_iter]

    state, step_accepted, updated_jacobian, y, d, n_iter = jax.lax.while_loop(
        while_cond, while_body, while_state
    )

    # take the accepted step
    n_steps = state.n_steps + 1
    t = state.t + state.h

    # a change in order is only done after running at order k for k + 1 steps
    # (see page 83 of [2])
    n_equal_steps = state.n_equal_steps + 1

    state = state._replace(n_equal_steps=n_equal_steps, t=t, n_steps=n_steps)

    # both continuations are evaluated; the order-changing one is selected once
    # n_equal_steps has reached order + 1
    state = tree_map(
        partial(jnp.where, n_equal_steps < state.order + 1),
        _prepare_next_step(state, d),
        _prepare_next_step_order_change(state, d, y, n_iter),
    )

    return state
def _bdf_interpolate(state, t_eval):
    """
    Interpolate the solution at a time value t* where t-h < t* < t.

    The interpolating polynomial (defined on page 7 of [1]) is evaluated term by
    term from the difference table: starting from D[0], the term for row j + 1 is
    D[j + 1] multiplied by the running product of
    (t* - (t - h * j)) / (h * (1 + j)), for j = 0 .. order - 1.
    """
    order = state.order
    t = state.t
    h = state.h
    D = state.D

    def terms_remaining(carry):
        term, _, _ = carry
        return term < order

    def add_term(carry):
        term, weight, total = carry
        weight = weight * ((t_eval - (t - h * term)) / (h * (1 + term)))
        total = total + D[term + 1] * weight
        return term + 1, weight, total

    _, _, interpolated = jax.lax.while_loop(
        terms_remaining, add_term, (0, 1.0, D[0])
    )
    return interpolated
def block_diag(lst):
    """
    Assemble the arrays in ``lst`` into a block diagonal matrix.

    The i-th array is placed in the i-th diagonal block and every off-diagonal
    block is filled with zeros. An array with fewer than two dimensions
    contributes one row / one column to the shape of the zero blocks.
    """

    def zero_block(row_array, col_array):
        n_rows = row_array.shape[0] if row_array.ndim > 1 else 1
        n_cols = col_array.shape[1] if col_array.ndim > 1 else 1
        return onp.zeros((n_rows, n_cols), dtype=row_array.dtype)

    blocks = []
    for i, Ai in enumerate(lst):
        block_row = []
        for j, Aj in enumerate(lst):
            block_row.append(Ai if i == j else zero_block(Ai, Aj))
        blocks.append(block_row)

    return onp.block(blocks)
# NOTE: the code below (except the docstring on jax_bdf_integrate and other minor
# edits), has been modified from the JAX library at https://github.com/google/jax.
# The main difference is the addition of support for semi-explicit dae index 1
# problems via the addition of a mass matrix.
# This is under an Apache license, a short form of which is given here:
#
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use
# this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.
def flax_while_loop(cond_fun, body_fun, init_val):  # pragma: no cover
    """
    Plain python stand-in for jax.lax.while_loop, for debugging purposes.

    Applies ``body_fun`` repeatedly, starting from ``init_val``, for as long as
    ``cond_fun`` of the current value is true, and returns the final value.
    """
    current = init_val
    while cond_fun(current):
        current = body_fun(current)
    return current
def flax_fori_loop(start, stop, body_fun, init_val):  # pragma: no cover
    """
    Plain python stand-in for jax.lax.fori_loop, for debugging purposes.

    Calls ``body_fun(i, value)`` for i = start .. stop - 1, feeding each result
    into the next call, and returns the final value.
    """
    current = init_val
    for index in range(start, stop):
        current = body_fun(index, current)
    return current
def flax_scan(f, init, xs, length=None):  # pragma: no cover
    """
    Plain python stand-in for jax.lax.scan, for debugging purposes.

    Calls ``f(carry, x)`` for every x in ``xs`` (or ``length`` times with
    x = None when ``xs`` is None). Returns the final carry and the stacked
    per-iteration outputs.
    """
    inputs = [None] * length if xs is None else xs
    carry = init
    outputs = []
    for x in inputs:
        carry, y = f(carry, x)
        outputs.append(y)
    return carry, onp.stack(outputs)
@partial(gnool_jit, static_array_argnums=(0, 1, 2, 3))
def _bdf_odeint_wrapper(func, mass, rtol, atol, y0, ts, *args):
    """
    Jit-compiled entry point to _bdf_odeint for an initial state of any pytree
    structure.

    The initial state is flattened to a vector, the system is integrated in that
    flat representation, and every row of the result is converted back to the
    structure of ``y0``. ``func``, ``mass``, ``rtol`` and ``atol`` are static
    arguments.
    """
    flat_y0, unravel = ravel_pytree(y0)
    flat_func = ravel_first_arg(func, unravel)
    flat_solution = _bdf_odeint(flat_func, mass, rtol, atol, flat_y0, ts, *args)
    return jax.vmap(unravel)(flat_solution)
def _bdf_odeint_fwd(func, mass, rtol, atol, y0, ts, *args):
    """
    Forward pass of the custom VJP of _bdf_odeint.

    Returns the solution together with the residuals (solution, time points and
    extra arguments) that the reverse pass needs.
    """
    solution = _bdf_odeint(func, mass, rtol, atol, y0, ts, *args)
    residuals = (solution, ts, args)
    return solution, residuals
def _bdf_odeint_rev(func, mass, rtol, atol, res, g):
    """
    Reverse pass of the custom VJP of _bdf_odeint.

    The adjoint system is integrated backwards in time between consecutive
    output times with jax_bdf_integrate, accumulating the cotangents with respect
    to the initial state, the output times and the extra arguments.

    Parameters
    ----------
    func, mass, rtol, atol:
        the non-differentiable arguments of _bdf_odeint
    res: tuple
        residuals (ys, ts, args) saved by _bdf_odeint_fwd
    g: ndarray
        cotangent of the output of _bdf_odeint, indexed by output time

    Returns
    -------
    tuple
        (y_bar, ts_bar, *args_bar), the cotangents of y0, t_eval and args
    """
    ys, ts, args = res

    def aug_dynamics(augmented_state, t, *args):
        """Original system augmented with vjp_y, vjp_t and vjp_args."""
        y, y_bar, *_ = augmented_state
        # `t` here is negative time, so we need to negate again to get back to
        # normal time. See the `odeint` invocation in `scan_fun` below.
        y_dot, vjpfun = jax.vjp(func, y, -t, *args)

        # Adjoint equations for semi-explicit dae index 1 system from
        #
        # [1] Cao, Y., Li, S., Petzold, L., & Serban, R. (2003). Adjoint sensitivity
        # analysis for differential-algebraic equations: The adjoint DAE system and
        # its numerical solution.
        # SIAM journal on scientific computing, 24(3), 1076-1089.
        #
        #   y_bar_dot_d = -J_dd^T y_bar_d - J_ad^T y_bar_a
        #             0 =  J_da^T y_bar_d + J_aa^T y_bar_d
        #
        # NOTE(review): the last term presumably should read J_aa^T y_bar_a -
        # confirm against [1].

        y_bar_dot, *rest = vjpfun(y_bar)

        return (-y_dot, y_bar_dot, *rest)

    # zeros on the diagonal of the mass matrix mark the algebraic variables
    algebraic_variables = onp.diag(mass) == 0.0
    differentiable_variables = algebraic_variables == False  # noqa: E712
    mass_is_I = onp.array_equal(mass, onp.eye(mass.shape[0]))
    is_dae = onp.any(algebraic_variables)

    if not mass_is_I:
        # factorise the block of the mass matrix belonging to the differentiable
        # variables once, it is reused by every call to initialise
        M_dd = mass[onp.ix_(differentiable_variables, differentiable_variables)]
        LU_invM_dd = jax.scipy.linalg.lu_factor(M_dd)

    def initialise(g0, y0, t0):
        # [1] gives init conditions for y_bar_a = g_d - J_ad^T (J_aa^T)^-1 g_a
        # NOTE(review): the code below uses g0_a, J_ad and J_aa where the formula
        # above has g_d, J_ad^T and J_aa^T - confirm against [1].
        if mass_is_I:
            y_bar = g0
        elif is_dae:
            J = jax.jacfwd(func)(y0, t0, *args)

            # boolean arguments not implemented in jnp.ix_
            J_aa = J[onp.ix_(algebraic_variables, algebraic_variables)]
            J_ad = J[onp.ix_(algebraic_variables, differentiable_variables)]
            LU = jax.scipy.linalg.lu_factor(J_aa)
            g0_a = g0[algebraic_variables]
            invJ_aa = jax.scipy.linalg.lu_solve(LU, g0_a)
            y_bar = g0.at[differentiable_variables].set(
                jax.scipy.linalg.lu_solve(LU_invM_dd, g0_a - J_ad @ invJ_aa)
            )
        else:
            y_bar = jax.scipy.linalg.lu_solve(LU_invM_dd, g0)
        return y_bar

    y_bar = initialise(g[-1], ys[-1], ts[-1])
    # placeholder only, ts_bar is rebuilt from the scan results below
    ts_bar = []
    t0_bar = 0.0

    def arg_to_identity(arg):
        return onp.identity(arg.shape[0] if arg.ndim > 0 else 1, dtype=arg.dtype)

    def arg_dicts_to_values(args):
        """
        Note:JAX puts in empty arrays into args for some reason, we remove them here
        """
        return sum((tuple(b.values()) for b in args if isinstance(b, dict)), ())

    # mass matrices of the augmented system: one for y, one for y_bar, a scalar
    # one for the time cotangent and an identity for every entry of the dicts
    # found in args
    aug_mass = (mass, mass, onp.array(1.0)) + arg_dicts_to_values(
        tree_map(arg_to_identity, args)
    )

    def scan_fun(carry, i):
        y_bar, t0_bar, args_bar = carry
        # Compute effect of moving measurement time
        t_bar = jnp.dot(func(ys[i], ts[i], *args), g[i])
        t0_bar = t0_bar - t_bar
        # Run augmented system backwards to previous observation
        _, y_bar, t0_bar, args_bar = jax_bdf_integrate(
            aug_dynamics,
            (ys[i], y_bar, t0_bar, args_bar),
            jnp.array([-ts[i], -ts[i - 1]]),
            *args,
            mass=aug_mass,
            rtol=rtol,
            atol=atol,
        )
        # keep only the values at the end of the integration interval
        y_bar, t0_bar, args_bar = tree_map(
            op.itemgetter(1), (y_bar, t0_bar, args_bar)
        )
        # Add gradient from current output
        y_bar = y_bar + initialise(g[i - 1], ys[i - 1], ts[i - 1])
        return (y_bar, t0_bar, args_bar), t_bar

    init_carry = (y_bar, t0_bar, tree_map(jnp.zeros_like, args))
    # walk over the output times from the last one down to the second one
    (y_bar, t0_bar, args_bar), rev_ts_bar = jax.lax.scan(
        scan_fun, init_carry, jnp.arange(len(ts) - 1, 0, -1)
    )
    ts_bar = jnp.concatenate([jnp.array([t0_bar]), rev_ts_bar[::-1]])
    return (y_bar, ts_bar, *args_bar)
# Register the forward and reverse passes as the custom VJP rule of _bdf_odeint.
_bdf_odeint.defvjp(_bdf_odeint_fwd, _bdf_odeint_rev)
@cache()
def closure_convert(fun, in_tree, in_avals):
    """
    Trace ``fun`` to a jaxpr and return an equivalent function in which some of
    the constants the trace closed over become explicit arguments.

    Returns
    -------
    converted_fun: callable
        function with signature (y, t, *hconsts_args), where the hoisted
        constants come first in ``hconsts_args`` followed by the original extra
        arguments
    hoisted_consts: list
        the constants that have to be passed to ``converted_fun``
    """
    wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
    out_tree = out_tree()

    # We only want to closure convert for constants with respect to which we're
    # differentiating. As a proxy for that, we hoist consts with float dtype.
    # TODO(mattjj): revise this approach
    def is_float(c):
        return dtypes.issubdtype(type(c), jnp.inexact)

    # constants for which is_float is true are hoisted, the rest stay closed over
    (closure_consts, hoisted_consts), merge = partition_list(is_float, consts)
    num_consts = len(hoisted_consts)

    def converted_fun(y, t, *hconsts_args):
        hoisted_consts, args = split_list(hconsts_args, [num_consts])
        # restore the original ordering of the constants expected by the jaxpr
        consts = merge(closure_consts, hoisted_consts)
        all_args, _ = tree_flatten((y, t, *args))
        out_flat = core.eval_jaxpr(jaxpr, consts, *all_args)
        return tree_unflatten(out_tree, out_flat)

    return converted_fun, hoisted_consts
def partition_list(choice, lst):
    """
    Split ``lst`` into two lists according to ``choice``.

    Parameters
    ----------
    choice: callable
        called exactly once per element; the result (False/0 or True/1) selects
        the list the element is appended to
    lst: iterable
        the elements to partition

    Returns
    -------
    out: tuple of two lists
        the elements for which ``choice`` was false, and those for which it was
        true, each in their original order
    merge: callable
        ``merge(l1, l2)`` interleaves two lists with the same lengths as the
        partitions back into the original ordering
    """
    out = [], []
    which = []
    for elt in lst:
        # evaluate the predicate once, so that the list an element goes to and
        # the choice recorded for merging can never disagree
        selected = choice(elt)
        out[selected].append(elt)
        which.append(selected)

    def merge(l1, l2):
        i1, i2 = iter(l1), iter(l2)
        return [next(i2 if snd else i1) for snd in which]

    return out, merge
def abstractify(x):
    """Return the abstract value of ``x`` raised to a shaped abstract value."""
    abstract_value = core.get_aval(x)
    return core.raise_to_shaped(abstract_value)
def ravel_first_arg(f, unravel):
    """
    Return a version of ``f`` that takes its first argument as a flat vector
    (converted back with ``unravel`` before the call) and returns a flattened
    result.
    """
    wrapped = lu.wrap_init(f)
    return ravel_first_arg_(wrapped, unravel).call_wrapped
@lu.transformation
def ravel_first_arg_(unravel, y_flat, *args):
    # Generator-style transformation: the first yield hands the converted
    # arguments to the wrapped function and receives its result, the second
    # yield produces the transformed result.
    y = unravel(y_flat)
    ans = yield (y,) + args, {}
    # flatten the result so that it matches the flat representation of y
    ans_flat, _ = ravel_pytree(ans)
    yield ans_flat
def jax_bdf_integrate(func, y0, t_eval, *args, rtol=1e-6, atol=1e-6, mass=None):
"""
Backward Difference formula (BDF) implicit multistep integrator. The basic algorithm
is derived in :footcite:t:`byrne1975polyalgorithm`. This particular implementation
follows that implemented in the Matlab routine ode15s described in
:footcite:t:`shampine1997matlab` and the SciPy implementation
:footcite:t:`Virtanen2020` which features the NDF formulas for improved stability,
with associated differences in the error constants, and calculates the jacobian at
J(t_{n+1}, y^0_{n+1}). This implementation was based on that implemented in the
SciPy library :footcite:t:`Virtanen2020`, which also mainly follows
:footcite:t:`shampine1997matlab` but uses the more standard jacobian update.
Parameters
----------
func: callable
function to evaluate the time derivative of the solution `y` at time
`t` as `func(y, t, *args)`, producing the same shape/structure as `y0`.
y0: ndarray
initial state vector
t_eval: ndarray
time points to evaluate the solution, has shape (m,)
args: (optional)
tuple of additional arguments for `fun`, which must be arrays
scalars, or (nested) standard Python containers (tuples, lists, dicts,
namedtuples, i.e. pytrees) of those types.
rtol: (optional) float
relative tolerance for the solver
atol: (optional) float
absolute tolerance for the solver
mass: (optional) ndarray
diagonal of the mass matrix with shape (n,)
Returns