-
Notifications
You must be signed in to change notification settings - Fork 26
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ExternalObjective
function to wrap external codes
#1028
base: master
Are you sure you want to change the base?
Conversation
ExternalObjective
function to wrap external codes
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #1028 +/- ##
==========================================
+ Coverage 95.61% 95.62% +0.01%
==========================================
Files 98 98
Lines 25420 25487 +67
==========================================
+ Hits 24306 24373 +67
Misses 1114 1114
|
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
test_build_transform_fft_lowres | -1.63 +/- 4.92 | -9.24e-03 +/- 2.79e-02 | 5.59e-01 +/- 2.0e-02 | 5.68e-01 +/- 1.9e-02 |
test_equilibrium_init_medres | +0.56 +/- 3.08 | +2.49e-02 +/- 1.37e-01 | 4.47e+00 +/- 1.0e-01 | 4.44e+00 +/- 8.9e-02 |
test_equilibrium_init_highres | +3.46 +/- 3.67 | +1.98e-01 +/- 2.10e-01 | 5.91e+00 +/- 1.8e-01 | 5.71e+00 +/- 1.1e-01 |
test_objective_compile_dshape_current | +0.31 +/- 2.48 | +1.25e-02 +/- 1.00e-01 | 4.05e+00 +/- 6.4e-02 | 4.04e+00 +/- 7.7e-02 |
test_objective_compute_dshape_current | -7.90 +/- 4.08 | -4.44e-04 +/- 2.30e-04 | 5.18e-03 +/- 9.6e-05 | 5.63e-03 +/- 2.1e-04 |
test_objective_jac_dshape_current | -4.27 +/- 8.36 | -1.99e-03 +/- 3.90e-03 | 4.46e-02 +/- 1.2e-03 | 4.66e-02 +/- 3.7e-03 |
test_perturb_2 | -1.57 +/- 1.49 | -3.30e-01 +/- 3.13e-01 | 2.06e+01 +/- 1.6e-01 | 2.10e+01 +/- 2.7e-01 |
test_proximal_freeb_jac | -0.31 +/- 2.27 | -2.37e-02 +/- 1.73e-01 | 7.57e+00 +/- 1.6e-01 | 7.59e+00 +/- 6.8e-02 |
test_solve_fixed_iter | +0.99 +/- 2.21 | +3.42e-01 +/- 7.65e-01 | 3.49e+01 +/- 6.7e-01 | 3.45e+01 +/- 3.7e-01 |
test_LinearConstraintProjection_build | -0.01 +/- 4.20 | -7.32e-04 +/- 4.53e-01 | 1.08e+01 +/- 1.6e-01 | 1.08e+01 +/- 4.2e-01 |
test_build_transform_fft_midres | +1.12 +/- 4.41 | +6.88e-03 +/- 2.71e-02 | 6.21e-01 +/- 2.0e-02 | 6.14e-01 +/- 1.9e-02 |
test_build_transform_fft_highres | -0.70 +/- 1.96 | -6.84e-03 +/- 1.92e-02 | 9.74e-01 +/- 9.3e-03 | 9.81e-01 +/- 1.7e-02 |
test_equilibrium_init_lowres | +1.37 +/- 1.75 | +5.32e-02 +/- 6.79e-02 | 3.93e+00 +/- 4.5e-02 | 3.88e+00 +/- 5.1e-02 |
test_objective_compile_atf | -0.12 +/- 4.73 | -9.51e-03 +/- 3.87e-01 | 8.16e+00 +/- 3.0e-01 | 8.17e+00 +/- 2.5e-01 |
test_objective_compute_atf | -1.44 +/- 3.73 | -2.31e-04 +/- 5.99e-04 | 1.58e-02 +/- 4.1e-04 | 1.61e-02 +/- 4.4e-04 |
test_objective_jac_atf | +2.47 +/- 1.49 | +4.86e-02 +/- 2.94e-02 | 2.02e+00 +/- 2.0e-02 | 1.97e+00 +/- 2.2e-02 |
test_perturb_1 | +0.63 +/- 1.71 | +9.39e-02 +/- 2.54e-01 | 1.50e+01 +/- 2.5e-01 | 1.49e+01 +/- 6.4e-02 |
test_proximal_jac_atf | +0.48 +/- 0.87 | +3.96e-02 +/- 7.26e-02 | 8.35e+00 +/- 4.1e-02 | 8.31e+00 +/- 6.0e-02 |
test_proximal_freeb_compute | +0.67 +/- 1.54 | +1.33e-03 +/- 3.07e-03 | 2.01e-01 +/- 2.5e-03 | 1.99e-01 +/- 1.8e-03 |
test_solve_fixed_iter_compiled | -0.31 +/- 1.00 | -6.77e-02 +/- 2.19e-01 | 2.19e+01 +/- 1.5e-01 | 2.20e+01 +/- 1.6e-01 | |
This reverts commit debecad.
|
||
|
||
@pytest.mark.unit | ||
@pytest.mark.slow |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This takes a minute or so to run on my laptop, since VMECIO.save
is slow. If we really need to speed this up we could either reduce the equilibrium resolution or manually save only the VMEC quantities that are used.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I sped this up significantly by manually saving the few quantities instead of calling VMECIO.save
Might be useful when we want to do multithreading jax-ml/jax#24756 |
I'm having an issue with the new test in this PR after updating with |
Put wrapper fxn in backend which checks JAX version and uses correct API |
|
||
""" | ||
|
||
def wrap_pure_callback(func): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
may want io_callback instead
func, | ||
result_shape_dtype, | ||
*args, | ||
vmap_method="expand_dims" if vectorized else "sequential", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This works with JAX v0.4.35 and later. Need to add backwards compatibility.
I think any other vmap_method
besides "expand_dims"
would also work since it only changes the dimensions for internal calculations and not the final result.
Creates an abstract base class for wrapping external codes with finite differences, like GX, TERPSICHORE, etc.
TODO: