Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bounce averaging #854

Merged
merged 355 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
355 commits
Select commit Hold shift + click to select a range
a34764e
Merge branch 'master' into bounce
rahulgaur104 Apr 10, 2024
b3ee158
increasing rtol because quantities that cross zero can have a huge rtol
rahulgaur104 Apr 10, 2024
6f77af3
Merge branch 'bounce' of https://github.com/PlasmaControl/DESC into b…
rahulgaur104 Apr 10, 2024
53142c9
adding pytest.warning
rahulgaur104 Apr 10, 2024
f04523d
override_grid=False fixes coefficient jumps. Should add this to the r…
rahulgaur104 Apr 10, 2024
3515abf
Merge branch 'master' into bounce
rahulgaur104 Apr 10, 2024
530bcbb
Change API for bounce integral to support integrating...
unalmis Apr 11, 2024
f2a6107
Simplify shape broadc in docstring
unalmis Apr 11, 2024
88efe51
Make sure tests_bounce_average_drifts test runs with new API change
unalmis Apr 11, 2024
e8b360a
Fix bug in test where bounce points should be computed from normalize…
unalmis Apr 11, 2024
fc2b59b
Merge branch 'master' into bounce
unalmis Apr 11, 2024
d08b360
Add quadrature test
unalmis Apr 12, 2024
cc20d0b
Add unit change of variables for clari
unalmis Apr 12, 2024
a5b1793
fixing the bounce averaged drifts test
rahulgaur104 Apr 12, 2024
fad449c
Remove test frgrid on meshgrid
unalmis Apr 12, 2024
7d001e0
Use modified tanh sinh quadrature
unalmis Apr 12, 2024
77ddaa4
Remove old root finding bounce point code
unalmis Apr 12, 2024
12ee415
Fix typo in comment
unalmis Apr 12, 2024
fa7a617
Use general automorphism for quadrature and fix change of variable bug
unalmis Apr 13, 2024
990c96a
Switch to sin automorphism to suppress singularity, add lebgauss test
unalmis Apr 13, 2024
ff328c6
Fix typos
unalmis Apr 13, 2024
28b2b19
Improve automorphism test, remove unused code, update requirements
unalmis Apr 14, 2024
5a1bdc4
Fix test failure due to comment change
unalmis Apr 14, 2024
69c5134
Choose better default knots for bounce integral
unalmis Apr 14, 2024
c39a073
Make bounce integration more modular so that custom grid can be used
unalmis Apr 14, 2024
e681c22
Make variables in bounce_integral.py private
unalmis Apr 14, 2024
df41869
Use map_coords in desc_grid_from_field_line_coords now that it uses
unalmis Apr 16, 2024
df5198b
bounce average test modified, scipy's incomplete elliptic integrals s…
rahulgaur104 Apr 18, 2024
85ec3a5
Merge branch 'master' into bounce
rahulgaur104 Apr 18, 2024
b6a0286
Do quadrature instead of using brokescipy.special functions
unalmis Apr 19, 2024
0d74884
small changes to the bounce average test
rahulgaur104 Apr 19, 2024
ae0a6f8
more changes to the bounce average test; most of the numerical values…
rahulgaur104 Apr 19, 2024
d05482c
adding responses to questions
rahulgaur104 Apr 19, 2024
b4604f6
bavg_drift_num has the same dimension as pitch now :)
rahulgaur104 Apr 19, 2024
8666648
Add helper function to compute epsilon effective
unalmis Apr 21, 2024
77c88bf
Make sure nan won't appear in gradient and improve composite linspace…
unalmis Apr 21, 2024
bf6cb73
Simplify composite_linspace function
unalmis Apr 21, 2024
a460ca1
A bug has been caught, but is not yet squashed!
unalmis Apr 21, 2024
df16d7d
Fix interpolation error induced nan propogation
unalmis Apr 22, 2024
ed8dc4f
Add methods for debugging bounce point inversion
unalmis Apr 22, 2024
37f2dd0
Think the splines are failing.
unalmis Apr 22, 2024
130f816
Squash floating point bugs that cause bounce points to not be detecte…
unalmis Apr 23, 2024
64e8b46
Fix all floating point error induced issues for detecting bounce points!
unalmis Apr 23, 2024
9d53735
Improve test_bounce_averaged_drifts
unalmis Apr 23, 2024
33c1249
Clean up bounce_average_drift test
unalmis Apr 23, 2024
81aa30a
Make it simpler to change spline method in test_bounce_averaged_drifts
unalmis Apr 24, 2024
26a9727
Move function definition to stop circular import
unalmis Apr 24, 2024
599eff6
Rename some variables to avoid confusion
unalmis Apr 24, 2024
43cad50
Remove override_grid from eq.compute in bounce_integral.
unalmis Apr 24, 2024
ab9ba6b
Reorder arguments into guard function for consistency
unalmis Apr 24, 2024
22fee12
Use same resolution in quadrature in bounce average drift test
unalmis Apr 24, 2024
cb9c55c
Fix floating point error in automorphism sin
unalmis Apr 25, 2024
9069695
floating point error resistant quadratic root
unalmis Apr 25, 2024
57e2d32
Use clip instead of shifting in arcsin automorph
unalmis Apr 25, 2024
49e87d5
API changes to be able to compute effective ripple in compute_funs
unalmis Apr 25, 2024
c7ec435
Remove override_grid=False, doesn't seem to matter anymore
unalmis Apr 25, 2024
772ed2a
fix bugs in bounce average drift test
unalmis Apr 26, 2024
9767f07
Add back testing assertions to analytic expressions
unalmis Apr 26, 2024
025f09d
Add methods to plot interpolated integrand.
unalmis Apr 27, 2024
74840d5
Merge branch 'master' into bounce
unalmis Apr 27, 2024
4cca3c1
Fix bad merge
unalmis Apr 27, 2024
f303a6e
Add back override_grid=False.
unalmis Apr 27, 2024
4a457c0
corrected errors in the analytical integrals, analytical and numerica…
rahulgaur104 Apr 27, 2024
d6deb24
correcting k2; it's incorrect in Hegna's paper
rahulgaur104 Apr 27, 2024
87f5c2d
replacing dPdrho with alpha_MHD terms; will take a look again later.
rahulgaur104 Apr 27, 2024
5431e0c
adding b dot grad theta in the analytical denominator
rahulgaur104 Apr 27, 2024
767e74a
adding gradpar_theta_analytic = b dot grad theta_PEST for the correct…
rahulgaur104 Apr 28, 2024
a449715
Fix the fourth integral in the bounce average test and add numerical …
unalmis Apr 28, 2024
3a4a384
Make changes to get_extrema to make computing eps_eff easier
unalmis Apr 28, 2024
03754ab
Make sure test_bounce_average_drifts computes things on correct grid
unalmis Apr 28, 2024
8b1d066
Make sure _compute_field_line recomputes field line quantities
unalmis Apr 28, 2024
b508b54
Working now!
unalmis Apr 29, 2024
6dc4a1b
Clean up some code and add image comparison test for drift
unalmis Apr 29, 2024
05465ef
Fix bug with recomputing quantities on incorrect grid
unalmis Apr 29, 2024
c1b6792
Regenerate baseline image for test_drift with pip matplotlib not conda
unalmis Apr 29, 2024
ab74a17
Remove extraneous details from image_comparison test plot in bounce_d…
unalmis Apr 29, 2024
427e5de
Merge branch 'master' into bounce
unalmis Apr 29, 2024
5361c3f
Merge branch 'eq_compute_bug' into bounce
unalmis Apr 29, 2024
22106cc
Add non-batched option to save memory when computing eps_eff
unalmis May 1, 2024
0f82588
Set default batched option to true
unalmis May 1, 2024
a657222
Reduce memory usage in interpolate
unalmis May 1, 2024
fe9da7d
Merge branch 'master' into bounce
rahulgaur104 May 1, 2024
cb2b57f
Simplify bounce quad looped algorithm transposing
unalmis May 1, 2024
31dbb5e
found a sneaky sqrt(2); agreement looks better now :)
rahulgaur104 May 2, 2024
ef86002
forgot to add the sqrt(2) + adding low-order drift expressions for c…
rahulgaur104 May 2, 2024
b5f3c26
removing last pitch points because the analytical integral is not clo…
rahulgaur104 May 2, 2024
c56c7e0
Reduce resolution and remove plotting test since numerical comparison…
unalmis May 2, 2024
29662c4
Merge branch 'master' into bounce
unalmis May 2, 2024
b02be9f
Merge branch 'eq_compute_bug' into bounce
unalmis May 2, 2024
889eb62
Reduce resolution, add back image test, switch default quadrature
unalmis May 2, 2024
fb9cede
Clean up orthax import
unalmis May 2, 2024
71cae5d
Generalize vmap to work with in_axes!=0
unalmis May 2, 2024
7f9eb37
Reduce resolution where possible, increase resolution until convergen…
unalmis May 3, 2024
2aaa35b
Merge branch 'master' into bounce
unalmis May 3, 2024
2c20bb5
Remove changes to grid spacing as that is done in #985
unalmis May 3, 2024
8114551
Make derivative suppression option available
unalmis May 4, 2024
5ef6e95
Tighten tolerance in test after trying out preiciosn things with mpmath
unalmis May 5, 2024
5c4da4e
Merge branch 'master' into bounce and fix stuff that broke from grid …
unalmis May 9, 2024
334562f
Merge branch 'master' into bounce
unalmis May 21, 2024
2279a04
Merge branch 'fieldline_compute' into bounce
unalmis May 21, 2024
68d855b
Merge branch 'fieldline_compute' into bounce
unalmis May 21, 2024
f4710c3
Modify create meshgrid for new grid attributes
unalmis May 21, 2024
d437a1d
Merge branch 'fieldline_compute' into bounce
unalmis May 22, 2024
4961cae
Use source grid attributes in desc_grid_from_field_line_coords
unalmis May 22, 2024
32c6bd2
Merge branch 'fieldline_compute' into bounce
unalmis May 23, 2024
f919bcc
Pass in quadrature points to bounce integral instead of function
unalmis May 26, 2024
7c7a1a6
Merge branch 'fieldline_compute' into bounce
unalmis May 26, 2024
f105ec3
Change keyword argument to batch from batched
unalmis May 28, 2024
1d8609d
Remove sort option from get_extrema
unalmis May 28, 2024
67a26c4
Merge branch 'fieldline_compute' into bounce
unalmis May 30, 2024
db33416
Move things from #1003 into #854
unalmis May 30, 2024
d5a00c8
Make sure dℓ parameterizes the distance along the field line in meter…
unalmis May 30, 2024
6bbaca5
Add g^pa magnetic axis limit
unalmis May 30, 2024
711f734
Remove confusing paranthesis
unalmis May 30, 2024
09a9d18
Merge branch 'fieldline_compute' into bounce
unalmis Jun 1, 2024
8c667bc
Update things after last merge
unalmis Jun 1, 2024
a93b88e
Add get_pitch utility method
unalmis Jun 2, 2024
a05d50c
Merge branch 'fieldline_compute' into bounce
unalmis Jun 2, 2024
9ba02bb
Fix assert statement for shape in get_pitch
unalmis Jun 2, 2024
7a85e87
Merge branch 'fieldline_compute' into bounce
unalmis Jun 3, 2024
61a8c3d
Fix use of rtz_grid() after merge with other branch
unalmis Jun 3, 2024
8518cee
Merge branch 'fieldline_compute' into bounce
unalmis Jun 3, 2024
11f8fe2
Merge branch 'fieldline_compute' into bounce
unalmis Jun 3, 2024
6711df0
Update interpax requirement to 0.3.2 for differentiable spline
unalmis Jun 4, 2024
8077d4e
Avoid complex arithmetic when computing roots to fix bug in effective…
unalmis Jun 10, 2024
b14426a
Clean up some documentation and docstrings
unalmis Jun 11, 2024
deb12d6
Fix type hinting and reorganize order of methods
unalmis Jun 12, 2024
782274d
Merge branch 'fieldline_compute' into bounce
unalmis Jun 17, 2024
4333973
Fix sign of B^zeta and B_z_ra
unalmis Jun 18, 2024
d51aa16
Partially undo previous commit
unalmis Jun 18, 2024
2a5e4b7
Merge branch 'fieldline_compute' into bounce
rahulgaur104 Jun 20, 2024
02afa2c
No more nan in effective ripple gradient
unalmis Jun 22, 2024
1d58463
Merge branch 'fieldline_compute' into bounce
unalmis Jun 22, 2024
3b5e9f9
Add test for finite nonzero derivative
unalmis Jun 22, 2024
ff46b4d
Merge branch 'fieldline_compute' into bounce
unalmis Jun 25, 2024
390e782
move changes from ripple to bounce (make some functions private)
unalmis Jun 25, 2024
8c4d28f
Merge branch 'fieldline_compute' into bounce
unalmis Jun 28, 2024
fd11816
Fix imports after merge
unalmis Jun 28, 2024
4b91a5e
Merge branch 'fieldline_compute' into bounce
unalmis Jun 28, 2024
b9de417
Remove unneeded compute funs
unalmis Jul 1, 2024
670ad66
Change label per Rory's request
unalmis Jul 2, 2024
f07cdae
Fix label
unalmis Jul 2, 2024
3322e94
Remove g^pa per review request
unalmis Jul 2, 2024
c10a59a
Remove old code
unalmis Jul 2, 2024
d1981b2
Merge branch 'fieldline_compute' into bounce
unalmis Jul 5, 2024
718ceb3
Merge branch 'fieldline_compute' into bounce
unalmis Jul 10, 2024
74ff461
Merge branch 'fieldline_compute' into bounce
unalmis Jul 11, 2024
b00ddc5
Merging fieldline_compute branch
unalmis Jul 11, 2024
ba92b47
Merge branch 'master' into bounce
unalmis Jul 20, 2024
e25b4ad
Merge branch 'clebsh_basis' into bounce
unalmis Jul 20, 2024
12219af
Merge branch 'clebsh_basis' into bounce
unalmis Jul 20, 2024
311425b
Clean up desc.backend and add eigh_tridiagonal
unalmis Jul 23, 2024
d921447
Add Guass-Lobatto quadrature for effective ripple, and speed up bounc…
unalmis Jul 23, 2024
6ed9d92
skeleton for fourier bounce integrals
unalmis Jul 24, 2024
24d75e0
Merge branch 'clebsh_basis' into bounce
unalmis Jul 24, 2024
8a20570
Merge branch 'bounce' into ku/fourier_bounce
unalmis Jul 24, 2024
30b7c1b
Move more efficient bounce points computation from fourier bounce to …
unalmis Jul 24, 2024
263930d
Fix dynamic jaxpr shape error induced from previous commit
unalmis Jul 25, 2024
dc1cc86
Merge branch 'bounce' into ku/fourier_bounce
unalmis Jul 25, 2024
95ed8a8
Add num_wells parameter to reduce size of bounce points matrix by fac…
unalmis Jul 25, 2024
1992e15
Specify num_wells expclitly in test to make sure it doesnt affect aut…
unalmis Jul 25, 2024
f316d89
Merge branch 'clebsh_basis' into bounce
unalmis Jul 25, 2024
142c86d
Merge branch 'bounce' into ku/fourier_bounce
unalmis Jul 25, 2024
e6b38bf
Clean up docstring comments
unalmis Jul 26, 2024
6acab01
Merge branch 'bounce' into ku/fourier_bounce
unalmis Jul 26, 2024
ad30aa0
Add remaining fourier bounce methods
unalmis Jul 31, 2024
8349019
Merge branch 'ku/map_coordinates_clebsch' into ku/fourier_bounce
unalmis Jul 31, 2024
689d3be
Merge branch 'ku/map_coordinates_clebsch' into ku/fourier_bounce
unalmis Aug 6, 2024
e5b535b
Merge branch 'clebsh_basis' into bounce
unalmis Aug 7, 2024
8c9c531
Downstream changes needed to implement Nemov's Gamma_c from Gamma_c b…
unalmis Aug 7, 2024
4ef040d
WIP: Commit before merge to save progress
unalmis Aug 7, 2024
289aba2
Merge branch 'bounce' into ku/fourier_bounce
unalmis Aug 7, 2024
e582143
Merge branch 'master' into bounce
unalmis Aug 9, 2024
89a03ee
Merge branch 'bounce' into ku/fourier_bounce
unalmis Aug 9, 2024
239db74
Merge branch 'master' into bounce
f0uriest Aug 9, 2024
a277a19
Merge branch 'master' into bounce
f0uriest Aug 9, 2024
16f5791
Merge branch 'master' into bounce
unalmis Aug 10, 2024
04eb8ab
Merge branch 'bounce' into ku/fourier_bounce
unalmis Aug 10, 2024
f7469e8
merging master test_compute_everyting fails! @unalmis
rahulgaur104 Aug 14, 2024
a8f8e82
Merge remote-tracking branch 'origin' into bounce
unalmis Aug 14, 2024
c914728
Merge commit 'f7469e8' into bounce
unalmis Aug 14, 2024
62219b3
Adding tests part 1
unalmis Aug 15, 2024
6571b8e
Merge branch 'master' into bounce
unalmis Aug 15, 2024
672b163
Making progress on tests
unalmis Aug 15, 2024
ff991cc
Replace einsum with vandermode matrix
unalmis Aug 15, 2024
744540a
Adding tests part 2
unalmis Aug 16, 2024
07eb550
Merge branch 'master' into bounce
rahulgaur104 Aug 18, 2024
ade0a5e
Merge branch 'bounce' into ku/fourier_bounce
unalmis Aug 18, 2024
4658415
Merge branch 'master' into bounce
dpanici Aug 20, 2024
8197f71
Force push with lease to avoid diverging branch with remote due to co…
unalmis Aug 20, 2024
d3f8f6b
Merge branch 'bounce' into ku/fourier_bounce
unalmis Aug 20, 2024
714a8f0
Merge branch 'master' into bounce
unalmis Aug 20, 2024
b6ee838
Merge branch 'bounce' into ku/fourier_bounce
unalmis Aug 21, 2024
8b64e0d
Move integration algorithms to integrals subfolder
unalmis Aug 21, 2024
7f993bf
Merge branch 'integrals' into ku/fourier_bounce
unalmis Aug 21, 2024
14ca329
Merge branch 'integrals' into ku/fourier_bounce
unalmis Aug 21, 2024
7c2d7c2
Simplify some broadcasting add short comments explaining theory
unalmis Aug 21, 2024
ae28995
Merge branch 'integrals' into ku/fourier_bounce
unalmis Aug 22, 2024
3e8a7bf
Merge branch 'integrals' into ku/fourier_bounce
unalmis Aug 22, 2024
c8bb170
Merge branch 'integrals' into ku/fourier_bounce
unalmis Aug 22, 2024
8094209
Make compatible with new meshgrid structure on master
unalmis Aug 22, 2024
2724c5b
Making progress on tests. All the bounce pointand splines
unalmis Aug 22, 2024
7007e9e
Merge branch 'master' into ku/fourier_bounce
unalmis Aug 22, 2024
c683bb8
Fix comment
unalmis Aug 22, 2024
32d64e9
Commit before I start modifying bounce_integral.py
unalmis Aug 25, 2024
819bff2
Major refactoring of bounce integrals
unalmis Aug 25, 2024
baa907d
Fix some stuff from previous commit
unalmis Aug 25, 2024
4b9ff2d
Merge branch 'master' into ku/fourier_bounce
unalmis Aug 25, 2024
cee3da7
Change interp2argmin to expect already reshaped data to be consistent…
unalmis Aug 25, 2024
a1f249c
Containerize and refactor basis used in bounce integrals
unalmis Aug 25, 2024
b7cc0e3
Merge branch 'master' into ku/fourier_bounce
unalmis Aug 25, 2024
540d062
Review algorithm. Fix documentation of integrals and use better names…
unalmis Aug 26, 2024
04f87a3
Debugging fourier bounce stuff
unalmis Aug 26, 2024
1a24a43
Fix bug in Fourier bounce with interpolation of b_sup_z
unalmis Aug 26, 2024
b6ade4c
Preparing merge into bounce branch
unalmis Aug 27, 2024
5fa3758
fix docstrings
unalmis Aug 27, 2024
1180cf9
Merge branch 'master' into bounce
unalmis Aug 27, 2024
f7350e7
Merge branch 'ku/fourier_bounce' into bounce
unalmis Aug 27, 2024
594cbd8
Remove stuff that should be in ku/fourier_bounce that came here after…
unalmis Aug 27, 2024
74e229f
Remove code that should be in fourier_bounce branch
unalmis Aug 27, 2024
dedc01b
Improve tests and fix failing test
unalmis Aug 27, 2024
bc03ab4
Merge branch 'master' into bounce
unalmis Aug 27, 2024
3e57eca
Clean up tests and API
unalmis Aug 27, 2024
3f479bb
Merge branch 'master' into bounce
unalmis Aug 27, 2024
6297d14
Merge branch 'master' into bounce
unalmis Aug 28, 2024
62d553b
Fix plotting bug from recent commits and address review comments part 1
unalmis Aug 28, 2024
afdd826
Merge branch 'master' into bounce
unalmis Aug 28, 2024
90e3596
Merge branch 'master' into bounce
unalmis Aug 28, 2024
5653376
Resolves #1228 on pull request #854
unalmis Aug 28, 2024
add9aaf
Fix claim for number of wells in an unoptimized stellarator
unalmis Aug 28, 2024
b17e513
Add description to test requested by Rory
unalmis Aug 28, 2024
e4dcd2e
Clarify test as requested by @f0uriest
unalmis Aug 29, 2024
1c1fa96
Make Bounce1D pytree and ioable and ensure eigh_tridiagonal is revers…
unalmis Aug 29, 2024
45b7c7a
Use more efficient vectorization
unalmis Aug 29, 2024
03ff0b1
Make super useful plotting function public and super user-friendly
unalmis Aug 29, 2024
6896586
Merge branch 'master' into bounce
unalmis Aug 29, 2024
deff086
Increase tolerance for plotting test
unalmis Aug 29, 2024
8edc317
Finishing touch clean up some docstrings
unalmis Aug 30, 2024
75c13fd
Make pitch optional argument for plot function
unalmis Aug 30, 2024
446c0b7
Address review comments and fix regression in batch argument from re…
unalmis Aug 30, 2024
239e441
Increase coverage
unalmis Aug 30, 2024
8376d03
Make broadcasting simpler for end user
unalmis Sep 1, 2024
2b6e9b6
Merge branch 'master' into bounce
f0uriest Sep 2, 2024
c531a82
Merge branch 'master' into bounce
rahulgaur104 Sep 2, 2024
e39dc14
Pull down changes from ripple branch
unalmis Sep 2, 2024
eb04894
Merge commit 'c531a82' into bounce
unalmis Sep 2, 2024
08e4257
Merge branch 'bounce' into bounce_pitch_shape
unalmis Sep 2, 2024
1436035
Swap vectorization order in bounce integrals (#1242)
unalmis Sep 3, 2024
138f90c
Merge branch 'master' into bounce
unalmis Sep 3, 2024
917ad1c
Tweak documentation as requested in pull request review
unalmis Sep 3, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 77 additions & 42 deletions desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,23 @@
)

if use_jax: # noqa: C901 - FIXME: simplify this, define globally and then assign?
jit = jax.jit
fori_loop = jax.lax.fori_loop
cond = jax.lax.cond
switch = jax.lax.switch
while_loop = jax.lax.while_loop
vmap = jax.vmap
bincount = jnp.bincount
repeat = jnp.repeat
take = jnp.take
scan = jax.lax.scan
from jax import custom_jvp
from jax import custom_jvp, jit, vmap

imap = jax.lax.map
from jax.experimental.ode import odeint
from jax.scipy.linalg import block_diag, cho_factor, cho_solve, qr, solve_triangular
from jax.lax import cond, fori_loop, scan, switch, while_loop
from jax.nn import softmax as softargmax
from jax.numpy import bincount, flatnonzero, repeat, take
from jax.numpy.fft import irfft, rfft, rfft2
from jax.scipy.fft import dct, idct
from jax.scipy.linalg import (
block_diag,
cho_factor,
cho_solve,
eigh_tridiagonal,
f0uriest marked this conversation as resolved.
Show resolved Hide resolved
qr,
solve_triangular,
)
from jax.scipy.special import gammaln, logsumexp
from jax.tree_util import (
register_pytree_node,
Expand All @@ -90,6 +94,10 @@
treedef_is_leaf,
)

trapezoid = (
jnp.trapezoid if hasattr(jnp, "trapezoid") else jax.scipy.integrate.trapezoid
)

def put(arr, inds, vals):
"""Functional interface for array "fancy indexing".

Expand Down Expand Up @@ -328,6 +336,8 @@ def root(
This routine may be used on over or under-determined systems, in which case it
will solve it in a least squares / least norm sense.
"""
from desc.compute.utils import safenorm

if fixup is None:
fixup = lambda x, *args: x
if jac is None:
Expand Down Expand Up @@ -392,7 +402,7 @@ def tangent_solve(g, y):
x, (res, niter) = jax.lax.custom_root(
res, x0, solve, tangent_solve, has_aux=True
)
return x, (jnp.linalg.norm(res), niter)
return x, (safenorm(res), niter)


# we can't really test the numpy backend stuff in automated testing, so we ignore it
Expand All @@ -401,15 +411,54 @@ def tangent_solve(g, y):
jit = lambda func, *args, **kwargs: func
execute_on_cpu = lambda func: func
import scipy.optimize
from numpy.fft import irfft, rfft, rfft2 # noqa: F401
from scipy.fft import dct, idct # noqa: F401
unalmis marked this conversation as resolved.
Show resolved Hide resolved
from scipy.integrate import odeint # noqa: F401
from scipy.linalg import ( # noqa: F401
block_diag,
cho_factor,
cho_solve,
eigh_tridiagonal,
qr,
solve_triangular,
)
from scipy.special import gammaln, logsumexp # noqa: F401
from scipy.special import softmax as softargmax # noqa: F401

trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz

def imap(f, xs, batch_size=None, in_axes=0, out_axes=0):
"""Generalizes jax.lax.map; uses numpy."""
if not isinstance(xs, np.ndarray):
raise NotImplementedError(
"Require numpy array input, or install jax to support pytrees."
)
xs = np.moveaxis(xs, source=in_axes, destination=0)
return np.stack([f(x) for x in xs], axis=out_axes)

def vmap(fun, in_axes=0, out_axes=0):
"""A numpy implementation of jax.lax.map whose API is a subset of jax.vmap.

Like Python's builtin map,
except inputs and outputs are in the form of stacked arrays,
and the returned object is a vectorized version of the input function.

Parameters
----------
fun: callable
Function (A -> B)
in_axes: int
Axis to map over.
out_axes: int
An integer indicating where the mapped axis should appear in the output.

Returns
-------
fun_vmap: callable
Vectorized version of fun.

"""
return lambda xs: imap(fun, xs, in_axes=in_axes, out_axes=out_axes)

def tree_stack(*args, **kwargs):
"""Stack pytree for numpy backend."""
Expand Down Expand Up @@ -592,32 +641,6 @@ def while_loop(cond_fun, body_fun, init_val):
val = body_fun(val)
return val

def vmap(fun, out_axes=0):
"""A numpy implementation of jax.lax.map whose API is a subset of jax.vmap.

Like Python's builtin map,
except inputs and outputs are in the form of stacked arrays,
and the returned object is a vectorized version of the input function.

Parameters
----------
fun: callable
Function (A -> B)
out_axes: int
An integer indicating where the mapped axis should appear in the output.

Returns
-------
fun_vmap: callable
Vectorized version of fun.

"""

def fun_vmap(fun_inputs):
return np.stack([fun(fun_input) for fun_input in fun_inputs], axis=out_axes)

return fun_vmap

def scan(f, init, xs, length=None, reverse=False, unroll=1):
"""Scan a function over leading array axes while carrying along state.

Expand Down Expand Up @@ -657,9 +680,14 @@ def scan(f, init, xs, length=None, reverse=False, unroll=1):
ys.append(y)
return carry, np.stack(ys)

def bincount(x, weights=None, minlength=None, length=None):
"""Same as np.bincount but with a dummy parameter to match jnp.bincount API."""
return np.bincount(x, weights, minlength)
def bincount(x, weights=None, minlength=0, length=None):
rahulgaur104 marked this conversation as resolved.
Show resolved Hide resolved
"""A numpy implementation of jnp.bincount."""
x = np.clip(x, 0, None)
if length is None:
length = max(minlength, x.max() + 1)
else:
minlength = max(minlength, length)
return np.bincount(x, weights, minlength)[:length]

def repeat(a, repeats, axis=None, total_repeat_length=None):
"""A numpy implementation of jnp.repeat."""
Expand Down Expand Up @@ -778,6 +806,13 @@ def root(
out = scipy.optimize.root(fun, x0, args, jac=jac, tol=tol)
return out.x, out

def flatnonzero(a, size=None, fill_value=0):
"""A numpy implementation of jnp.flatnonzero."""
nz = np.flatnonzero(a)
if size is not None:
nz = np.pad(nz, (0, max(size - nz.size, 0)), constant_values=fill_value)
return nz

def take(
a,
indices,
Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from scipy.special import roots_legendre

from ..backend import fori_loop, jnp
from ..integrals import surface_averages_map
from ..integrals.surface_integral import surface_averages_map
from .data_index import register_compute_fun


Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_equil.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from desc.backend import jnp

from ..integrals import surface_averages
from ..integrals.surface_integral import surface_averages
from .data_index import register_compute_fun
from .utils import cross, dot, safediv, safenorm

Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from desc.backend import jnp

from ..integrals import (
from ..integrals.surface_integral import (
surface_averages,
surface_integrals_map,
surface_max,
Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from desc.backend import jnp

from ..integrals import surface_averages
from ..integrals.surface_integral import surface_averages
from .data_index import register_compute_fun
from .utils import cross, dot, safediv, safenorm

Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from desc.backend import cond, jnp

from ..integrals import surface_averages, surface_integrals
from ..integrals.surface_integral import surface_averages, surface_integrals
from .data_index import register_compute_fun
from .utils import cumtrapz, dot, safediv

Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from desc.backend import jnp

from ..integrals import surface_integrals_map
from ..integrals.surface_integral import surface_integrals_map
from .data_index import register_compute_fun
from .utils import dot

Expand Down
7 changes: 5 additions & 2 deletions desc/equilibrium/coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,11 +685,14 @@ def get_rtz_grid(
rvp : rho, theta_PEST, phi
rtz : rho, theta, zeta
period : tuple of float
Assumed periodicity for each quantity in inbasis.
Assumed periodicity for functions of the given coordinates.
Use ``np.inf`` to denote no periodicity.
jitable : bool, optional
If false the returned grid has additional attributes.
Required to be false to retain nodes at magnetic axis.
kwargs
Additional parameters to supply to the coordinate mapping function.
See ``desc.equilibrium.coords.map_coordinates``.

Returns
-------
Expand All @@ -701,7 +704,7 @@ def get_rtz_grid(
[radial, poloidal, toroidal], coordinates=coordinates, period=period
)
if "iota" in kwargs:
kwargs["iota"] = grid.expand(kwargs["iota"])
kwargs["iota"] = grid.expand(jnp.atleast_1d(kwargs["iota"]))
inbasis = {
"r": "rho",
"t": "theta",
Expand Down
6 changes: 5 additions & 1 deletion desc/equilibrium/equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -1255,7 +1255,11 @@
point. Only returned if ``full_output`` is True.

"""
warnif(True, DeprecationWarning, msg="Use map_coordinates instead.")
warnif(

Check warning on line 1258 in desc/equilibrium/equilibrium.py

View check run for this annotation

Codecov / codecov/patch

desc/equilibrium/equilibrium.py#L1258

Added line #L1258 was not covered by tests
True,
DeprecationWarning,
"Use map_coordinates instead of compute_theta_coords.",
)
return map_coordinates(
self,
flux_coords,
Expand Down
18 changes: 13 additions & 5 deletions desc/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,7 @@ def meshgrid_reshape(self, x, order):
-------
x : ndarray
Data reshaped to align with grid nodes.

"""
errorif(
not self.is_meshgrid,
Expand All @@ -637,7 +638,8 @@ def meshgrid_reshape(self, x, order):
vec = True
shape += (-1,)
x = x.reshape(shape, order="F")
x = jnp.moveaxis(x, 1, 0) # now shape rtz/raz etc
# swap to change shape from trz/arz to rtz/raz etc.
x = jnp.swapaxes(x, 1, 0)
newax = tuple(self.coordinates.index(c) for c in order)
if vec:
newax += (3,)
Expand Down Expand Up @@ -788,10 +790,11 @@ def create_meshgrid(
rtz : rho, theta, zeta
period : tuple of float
Assumed periodicity for each coordinate.
Use np.inf to denote no periodicity.
Use ``np.inf`` to denote no periodicity.
NFP : int
Number of field periods (Default = 1).
Only makes sense to change from 1 if ``period[2]==2π``.
Only makes sense to change from 1 if last coordinate is periodic
with some constant divided by ``NFP``.

Returns
-------
Expand Down Expand Up @@ -1885,8 +1888,13 @@ def _periodic_spacing(x, period=2 * jnp.pi, sort=False, jnp=jnp):
x = jnp.sort(x, axis=0)
# choose dx to be half the distance between its neighbors
if x.size > 1:
dx_0 = x[1] + (period - x[-1]) % period
dx_1 = x[0] + (period - x[-2]) % period
if np.isfinite(period):
dx_0 = x[1] + (period - x[-1]) % period
dx_1 = x[0] + (period - x[-2]) % period
else:
# just set to 0 to stop nan gradient, even though above gives expected value
unalmis marked this conversation as resolved.
Show resolved Hide resolved
dx_0 = 0
dx_1 = 0
if x.size == 2:
# then dx[0] == period and dx[-1] == 0, so fix this
dx_1 = dx_0
Expand Down
1 change: 1 addition & 0 deletions desc/integrals/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Classes for function integration."""

from .bounce_integral import Bounce1D
from .singularities import (
DFTInterpolator,
FFTInterpolator,
Expand Down
Loading
Loading