-
Notifications
You must be signed in to change notification settings - Fork 26
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix JAX error with FourierCurrentPotentialField
in Flux objectives
#1002
Conversation
…CurrentField to avoid issues when jitting the objective
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
test_build_transform_fft_lowres | +7.06 +/- 8.20 | +3.58e-02 +/- 4.16e-02 | 5.43e-01 +/- 4.1e-02 | 5.07e-01 +/- 6.4e-03 |
test_build_transform_fft_midres | +6.67 +/- 3.29 | +3.96e-02 +/- 1.95e-02 | 6.33e-01 +/- 1.7e-02 | 5.93e-01 +/- 9.9e-03 |
test_build_transform_fft_highres | +3.78 +/- 3.51 | +3.73e-02 +/- 3.46e-02 | 1.02e+00 +/- 1.5e-02 | 9.87e-01 +/- 3.1e-02 |
test_equilibrium_init_lowres | +1.75 +/- 2.66 | +6.43e-02 +/- 9.76e-02 | 3.73e+00 +/- 9.7e-02 | 3.67e+00 +/- 9.8e-03 |
test_equilibrium_init_medres | +1.82 +/- 2.17 | +7.53e-02 +/- 8.98e-02 | 4.22e+00 +/- 8.9e-02 | 4.14e+00 +/- 1.0e-02 |
test_equilibrium_init_highres | +1.21 +/- 1.88 | +6.71e-02 +/- 1.04e-01 | 5.60e+00 +/- 8.5e-02 | 5.54e+00 +/- 6.0e-02 |
test_objective_compile_dshape_current | +0.64 +/- 5.96 | +2.43e-02 +/- 2.26e-01 | 3.81e+00 +/- 2.2e-01 | 3.79e+00 +/- 1.9e-02 |
test_objective_compile_atf | +1.09 +/- 3.08 | +8.93e-02 +/- 2.54e-01 | 8.32e+00 +/- 1.5e-01 | 8.23e+00 +/- 2.0e-01 |
test_objective_compute_dshape_current | -1.81 +/- 4.61 | -2.31e-05 +/- 5.87e-05 | 1.25e-03 +/- 2.9e-05 | 1.27e-03 +/- 5.1e-05 |
test_objective_compute_atf | +0.42 +/- 6.38 | +1.76e-05 +/- 2.70e-04 | 4.25e-03 +/- 2.3e-04 | 4.23e-03 +/- 1.5e-04 |
test_objective_jac_dshape_current | -1.67 +/- 11.51 | -6.06e-04 +/- 4.18e-03 | 3.57e-02 +/- 2.3e-03 | 3.64e-02 +/- 3.5e-03 |
test_objective_jac_atf | +3.58 +/- 2.49 | +6.67e-02 +/- 4.63e-02 | 1.93e+00 +/- 3.4e-02 | 1.86e+00 +/- 3.2e-02 |
test_perturb_1 | +0.51 +/- 0.68 | +6.71e-02 +/- 8.89e-02 | 1.31e+01 +/- 7.4e-02 | 1.30e+01 +/- 4.9e-02 |
test_perturb_2 | +0.74 +/- 1.23 | +1.33e-01 +/- 2.22e-01 | 1.81e+01 +/- 1.3e-01 | 1.80e+01 +/- 1.8e-01 |
test_proximal_jac_atf | -0.61 +/- 1.44 | -4.48e-02 +/- 1.05e-01 | 7.28e+00 +/- 6.0e-02 | 7.32e+00 +/- 8.7e-02 |
test_proximal_freeb_compute | -1.67 +/- 0.74 | -3.00e-03 +/- 1.32e-03 | 1.76e-01 +/- 1.0e-03 | 1.79e-01 +/- 8.2e-04 |
test_proximal_freeb_jac | -0.12 +/- 1.09 | -8.66e-03 +/- 7.98e-02 | 7.31e+00 +/- 5.9e-02 | 7.32e+00 +/- 5.4e-02 |
test_solve_fixed_iter | -0.38 +/- 9.08 | -5.59e-02 +/- 1.34e+00 | 1.47e+01 +/- 9.9e-01 | 1.48e+01 +/- 9.0e-01 | |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #1002 +/- ##
==========================================
- Coverage 94.98% 94.97% -0.01%
==========================================
Files 87 87
Lines 21749 21762 +13
==========================================
+ Hits 20658 20669 +11
- Misses 1091 1093 +2
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the reason coils worked while the current potential doesnt is that by default the coil classes create jitable transforms in their compute method, since they're only 1d its not really worth it to try to do the fft.
I'd recommend adding the new classes to the existing get_transforms
logic, I think most of it should "just work" assuming attributes are named correctly. (might need some logic somewhere for tree-like coilsets etc, but do-able.
desc/objectives/_coils.py
Outdated
@@ -753,13 +753,28 @@ def build(self, use_jit=True, verbose=1): | |||
Bplasma = compute_B_plasma( | |||
eq, eval_grid, self._source_grid, normal_only=True | |||
) | |||
field = self._field | |||
if hasattr(field, "Phi_mn"): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would it be better to get this all to work with the existing desc.compute.utils.get_transforms
helper? That would work for coils etc as well.
The problem is less the classes and more that not every magnetic field class needs the same keys to calculate the magnetic field. Coils need "I" and "x", the current potential fields need "K", toroidal fields dont need any keys as they just have their own compute magnetic field function. |
add case for coils and 1D transforms (by calling get_transforms in build always asking for "x" (if is a curve class or CoilSet) and maybe also "K" if it is a current potential field. Also add in the current potential field compute, a jitable=True (as if there is a current potential field in a SumMagneticField it might not work correctly...) |
@f0uriest I don't exactly remember the issue with the current method/how putting the logic in We still would need to call I think I also found in this PR a fix to something causing a JAX error in a few places when current potential fields are being optimized, so if possible I'd like to get this one in. Unless there is a better case for putting more logic into get_transforms, which I think should wait until we make "B" a data index quantity for MagneticField objects and then can go through that route to get all the proper transforms etc based off of the specific MagneticField parameterization |
tests/test_optimizer.py
Outdated
|
||
|
||
@pytest.mark.unit | ||
def test_tor_flux_with_surface_current_field(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could these tests be one test with an inner test()
function or is it better to keep them separate?
@@ -644,7 +647,23 @@ def _compute_magnetic_field_from_CurrentPotentialField( | |||
# compute surface current, and store grid quantities | |||
# needed for integration in class | |||
# TODO: does this have to be xyz, or can it be computed in rpz as well? | |||
data = field.compute(["K", "x"], grid=source_grid, basis="xyz", params=params) | |||
if not params and not transforms: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why this conditional?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To avoid calling field.compute
inside a compute function of an objective, as there is some logic there that jit
does not like
Equinox for runtime error checking in jitted functions |
@f0uriest check this again for placing the code in a util function |
FourierCurrentPotentialField
in Flux objectivesFourierCurrentPotentialField
in Flux objectives
Merge after increasing the coverage. |
|
||
|
||
@pytest.mark.unit | ||
def test_quad_flux_with_surface_current_field(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the plan to remove/combine this test with #1025 ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure, as that one will add coil tests while this one is using surface current fields. If one test on that branch involves a sum magnetic field with coils and a surface current field then we can remove this in favor of that one
@@ -1272,7 +1272,10 @@ def compute(self, field_params, constants=None): | |||
|
|||
# B_ext is not pre-computed because field is not fixed | |||
B_ext = constants["field"].compute_magnetic_field( | |||
x, source_grid=constants["field_grid"], basis="rpz", params=field_params |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No real change to this file, right? You just like to make the code longer to bother me?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I blame my IDE
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
but yea no change here just formatting
- doesn't pass due to jax bug, maybe PR #1002 will fix it?
On master, an optimization with a
FourierCurrentPotentialField
and theQuadraticFlux
objective would fail. This happens because in QuadraticFlux.compute, field.compute_magnetic_field is called. If the field needs transforms to evaluate, then these transforms will be created on the fly if they are not provided, resulting in an error.This PR fixes that by adding
jitable=True
to the.compute
call inside.compute_magnetic_field
It also adds
transform
as an argument toMagneticField.compute_magnetic_field
methods, in preparation for when #1079 is resolved and objectives can pre-compute and pass in transform objects for all magnetic fields easily.This PR kicks the can down the road of making
get_transforms
work withMagneticField.compute_magnetic_field
objects (to allow us to pre-compute the transforms used for magnetic field computation #1079 ) and instead opts for the simple fix.