-
Notifications
You must be signed in to change notification settings - Fork 30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
change field line integration to use diffrax #610
Conversation
The RuntimeWarning for numpy header size being different is still an issue, I think because the warning is coming from C level code (@f0uriest @unalmis any ideas on how to ignore those? seems like a |
…m netCDF4 but cannot only target that module)
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #610 +/- ##
==========================================
- Coverage 95.45% 95.45% -0.01%
==========================================
Files 95 95
Lines 23419 23429 +10
==========================================
+ Hits 22354 22363 +9
- Misses 1065 1066 +1
|
…f it is not benign, Cython would raise a ValueError, not a RuntimeWarning), and add kwrgs to integrate function
…diffrax's part), undo accidental ml_dtypes change, add ignore for benign equinox warning
… interest defined by passed-in bounds
it looks like there are older versions of diffrax that work with older versions of jax. They might not have the discrete terminating event stuff but might still be useful for being able to select different integration schemes etc. |
exp. decrease RHS outside the bounding box (exp(-r)*B) |
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
test_build_transform_fft_midres | -3.98 +/- 5.49 | -2.57e-02 +/- 3.55e-02 | 6.21e-01 +/- 3.1e-02 | 6.46e-01 +/- 1.7e-02 |
test_build_transform_fft_highres | -5.12 +/- 3.82 | -5.40e-02 +/- 4.03e-02 | 1.00e+00 +/- 3.8e-02 | 1.06e+00 +/- 1.4e-02 |
test_equilibrium_init_lowres | -6.77 +/- 2.85 | -2.81e-01 +/- 1.19e-01 | 3.87e+00 +/- 8.7e-02 | 4.15e+00 +/- 8.0e-02 |
test_objective_compile_atf | -0.44 +/- 3.12 | -3.49e-02 +/- 2.48e-01 | 7.93e+00 +/- 2.1e-01 | 7.96e+00 +/- 1.3e-01 |
test_objective_compute_atf | -0.53 +/- 2.12 | -5.40e-05 +/- 2.17e-04 | 1.02e-02 +/- 1.1e-04 | 1.02e-02 +/- 1.9e-04 |
test_objective_jac_atf | +0.26 +/- 1.38 | +4.96e-03 +/- 2.65e-02 | 1.92e+00 +/- 1.7e-02 | 1.92e+00 +/- 2.1e-02 |
test_perturb_1 | +2.56 +/- 5.52 | +3.16e-01 +/- 6.80e-01 | 1.26e+01 +/- 5.8e-01 | 1.23e+01 +/- 3.5e-01 |
test_proximal_jac_atf | +0.34 +/- 1.47 | +2.78e-02 +/- 1.19e-01 | 8.12e+00 +/- 6.9e-02 | 8.09e+00 +/- 9.7e-02 |
test_proximal_freeb_compute | +2.33 +/- 1.08 | +4.26e-03 +/- 1.98e-03 | 1.87e-01 +/- 1.5e-03 | 1.83e-01 +/- 1.2e-03 |
test_build_transform_fft_lowres | +0.30 +/- 6.13 | +1.60e-03 +/- 3.26e-02 | 5.33e-01 +/- 2.4e-02 | 5.31e-01 +/- 2.1e-02 |
test_equilibrium_init_medres | +1.40 +/- 5.57 | +5.74e-02 +/- 2.29e-01 | 4.17e+00 +/- 2.3e-01 | 4.11e+00 +/- 3.5e-02 |
test_equilibrium_init_highres | +1.59 +/- 2.35 | +8.64e-02 +/- 1.28e-01 | 5.53e+00 +/- 1.2e-01 | 5.45e+00 +/- 4.5e-02 |
test_objective_compile_dshape_current | +0.40 +/- 1.18 | +1.54e-02 +/- 4.49e-02 | 3.82e+00 +/- 8.8e-03 | 3.81e+00 +/- 4.4e-02 |
test_objective_compute_dshape_current | +0.63 +/- 1.60 | +2.19e-05 +/- 5.52e-05 | 3.48e-03 +/- 3.9e-05 | 3.46e-03 +/- 3.9e-05 |
test_objective_jac_dshape_current | -1.70 +/- 5.33 | -6.86e-04 +/- 2.14e-03 | 3.96e-02 +/- 1.4e-03 | 4.03e-02 +/- 1.6e-03 |
test_perturb_2 | +0.48 +/- 1.96 | +8.30e-02 +/- 3.38e-01 | 1.73e+01 +/- 1.9e-01 | 1.72e+01 +/- 2.8e-01 |
test_proximal_freeb_jac | -0.42 +/- 0.77 | -3.18e-02 +/- 5.76e-02 | 7.48e+00 +/- 3.2e-02 | 7.51e+00 +/- 4.8e-02 |
test_solve_fixed_iter | -0.31 +/- 61.51 | -1.52e-02 +/- 3.05e+00 | 4.95e+00 +/- 2.2e+00 | 4.96e+00 +/- 2.1e+00 | |
@dpanici I know you can't formally approve this but lmk if you have any comments |
r, z = field_line_integrate(r0, z0, phis, field, bounds_Z=(-np.inf, 0.05)) | ||
np.testing.assert_allclose(z[-1], 0.05, atol=3e-3) | ||
phis = [0, 2 * np.pi * 25] | ||
r, z = field_line_integrate(r0, z0, phis, field, solver=Dopri5()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do we use another integrator for this test (an implicit one?)? Maybe we should specify some integrator options for the user to choose depending on the case. Or basically reference https://docs.kidger.site/diffrax/api/solvers/ode_solvers/
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Dopri5 and Tsit5 are both rk45 methods, just with slightly different coefficients. There's not any real reason to prefer one over the other, other than to just test that they both work.
) | ||
return jnp.array( | ||
[r * br / bp * jnp.sign(bp), jnp.sign(bp), r * bz / bp * jnp.sign(bp)] | ||
).squeeze() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
).squeeze() | |
# step along the field line | |
# Torodial component of the field is used to normalize the step size | |
return jnp.array( | |
jnp.sign(bp) * [r * br / bp, 1, r * bz / bp] | |
).squeeze() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not that important if you don't change
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it's not that we're normalizing the step or anything, its just because we're using the toroidal angle as our "time" coordinate so that modifies the ODE slightly.
Updated comment by @f0uriest
field_line_integrate
now usesdiffrax
instead ofjax.experimental.ode.odeint
which has been soft deprecated.test_plot_poincare
which callsfield_line_integrate
under the hood went from ~65s to ~55s with these changes.Resolves #609