Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make map_coordinates differentiable for JAX 0.4.34 #1293

Merged
merged 60 commits into from
Oct 30, 2024
Merged

Conversation

YigitElma
Copy link
Collaborator

@YigitElma YigitElma commented Oct 4, 2024

  • Adds full_output flags to root and root_scalar to make them differentiable.

  • Adds tests for differentiability of root and root_scalar in addition to map_coordinates_derivative

  • While working on jax problems, I used this PR to update our test_jax workflow with new jax versions and better dependency installation routine (i.e. previously since jax was uploaded later, rest of the packages were latest and only jax was old, this was causing incompatibilities and false-errors)

Resolves #1291

@YigitElma YigitElma added test_jax Run tests against different versions of JAX easy Short and simple to code or review bug fix Something was fixed labels Oct 4, 2024
@YigitElma YigitElma self-assigned this Oct 4, 2024
Copy link
Contributor

github-actions bot commented Oct 4, 2024

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        | 
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
 test_build_transform_fft_midres         |     +1.10 +/- 3.45     | +6.73e-03 +/- 2.11e-02 |  6.19e-01 +/- 1.8e-02  |  6.12e-01 +/- 1.0e-02  |
 test_build_transform_fft_highres        |     +0.04 +/- 2.56     | +3.64e-04 +/- 2.59e-02 |  1.01e+00 +/- 2.1e-02  |  1.01e+00 +/- 1.5e-02  |
 test_equilibrium_init_lowres            |     +1.28 +/- 1.35     | +4.96e-02 +/- 5.23e-02 |  3.92e+00 +/- 3.3e-02  |  3.87e+00 +/- 4.0e-02  |
 test_objective_compile_atf              |     +0.43 +/- 4.17     | +3.33e-02 +/- 3.26e-01 |  7.86e+00 +/- 2.6e-01  |  7.82e+00 +/- 2.0e-01  |
 test_objective_compute_atf              |     -0.22 +/- 2.94     | -2.32e-05 +/- 3.09e-04 |  1.05e-02 +/- 2.2e-04  |  1.05e-02 +/- 2.2e-04  |
 test_objective_jac_atf                  |     -1.93 +/- 2.33     | -3.75e-02 +/- 4.53e-02 |  1.91e+00 +/- 4.0e-02  |  1.95e+00 +/- 2.1e-02  |
 test_perturb_1                          |     +0.31 +/- 2.35     | +3.92e-02 +/- 2.99e-01 |  1.27e+01 +/- 1.2e-01  |  1.27e+01 +/- 2.7e-01  |
 test_proximal_jac_atf                   |     +0.41 +/- 0.98     | +3.30e-02 +/- 7.94e-02 |  8.10e+00 +/- 3.5e-02  |  8.07e+00 +/- 7.1e-02  |
 test_proximal_freeb_compute             |     +0.03 +/- 0.88     | +5.43e-05 +/- 1.61e-03 |  1.83e-01 +/- 1.3e-03  |  1.83e-01 +/- 9.1e-04  |
 test_build_transform_fft_lowres         |     -7.38 +/- 7.18     | -4.11e-02 +/- 4.00e-02 |  5.16e-01 +/- 3.6e-02  |  5.57e-01 +/- 1.8e-02  |
 test_equilibrium_init_medres            |     -3.06 +/- 8.60     | -1.28e-01 +/- 3.60e-01 |  4.05e+00 +/- 3.3e-02  |  4.18e+00 +/- 3.6e-01  |
 test_equilibrium_init_highres           |     +0.00 +/- 1.40     | +2.00e-04 +/- 7.54e-02 |  5.38e+00 +/- 5.5e-02  |  5.38e+00 +/- 5.2e-02  |
 test_objective_compile_dshape_current   |     +0.18 +/- 5.51     | +6.76e-03 +/- 2.10e-01 |  3.82e+00 +/- 2.1e-01  |  3.81e+00 +/- 4.1e-02  |
 test_objective_compute_dshape_current   |     +0.27 +/- 2.30     | +9.61e-06 +/- 8.30e-05 |  3.61e-03 +/- 6.7e-05  |  3.60e-03 +/- 4.9e-05  |
 test_objective_jac_dshape_current       |     -1.00 +/- 11.38    | -3.92e-04 +/- 4.47e-03 |  3.89e-02 +/- 3.2e-03  |  3.93e-02 +/- 3.1e-03  |
 test_perturb_2                          |     +2.36 +/- 2.95     | +4.09e-01 +/- 5.12e-01 |  1.77e+01 +/- 1.4e-01  |  1.73e+01 +/- 4.9e-01  |
 test_proximal_freeb_jac                 |     -2.36 +/- 1.26     | -1.77e-01 +/- 9.50e-02 |  7.35e+00 +/- 6.7e-02  |  7.53e+00 +/- 6.7e-02  |
 test_solve_fixed_iter                   |     -1.16 +/- 57.77    | -5.85e-02 +/- 2.92e+00 |  4.99e+00 +/- 2.1e+00  |  5.05e+00 +/- 2.1e+00  |

Copy link

codecov bot commented Oct 4, 2024

Codecov Report

Attention: Patch coverage is 85.36585% with 6 lines in your changes missing coverage. Please review.

Project coverage is 95.51%. Comparing base (831e7bd) to head (c508da8).
Report is 61 commits behind head on master.

Files with missing lines Patch % Lines
desc/backend.py 86.36% 3 Missing ⚠️
desc/equilibrium/coords.py 77.77% 2 Missing ⚠️
desc/geometry/surface.py 75.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1293      +/-   ##
==========================================
- Coverage   95.52%   95.51%   -0.02%     
==========================================
  Files          96       96              
  Lines       24010    24032      +22     
==========================================
+ Hits        22936    22954      +18     
- Misses       1074     1078       +4     
Files with missing lines Coverage Δ
desc/basis.py 98.18% <100.00%> (ø)
desc/geometry/surface.py 96.60% <75.00%> (-0.25%) ⬇️
desc/equilibrium/coords.py 88.25% <77.77%> (-0.13%) ⬇️
desc/backend.py 90.44% <86.36%> (+0.20%) ⬆️

... and 1 file with indirect coverage changes

@dpanici dpanici added test_jax Run tests against different versions of JAX and removed test_jax Run tests against different versions of JAX labels Oct 5, 2024
@dpanici dpanici added test_jax Run tests against different versions of JAX and removed test_jax Run tests against different versions of JAX labels Oct 6, 2024
@YigitElma YigitElma marked this pull request as draft October 6, 2024 18:56
@dpanici dpanici added test_jax Run tests against different versions of JAX and removed test_jax Run tests against different versions of JAX labels Oct 6, 2024
@dpanici dpanici added test_jax Run tests against different versions of JAX and removed test_jax Run tests against different versions of JAX labels Oct 7, 2024
@YigitElma YigitElma requested review from unalmis and IssraAli October 28, 2024 18:22
unalmis
unalmis previously approved these changes Oct 28, 2024
@YigitElma YigitElma dismissed stale reviews from f0uriest and dpanici October 28, 2024 21:00

requested changes are made

@YigitElma YigitElma requested review from f0uriest, dpanici, rahulgaur104 and ddudt and removed request for f0uriest, rahulgaur104, ddudt and dpanici October 28, 2024 21:02
f0uriest
f0uriest previously approved these changes Oct 28, 2024
devtools/dev-requirements_conda.yml Show resolved Hide resolved
requirements.txt Show resolved Hide resolved
requirements_conda.yml Show resolved Hide resolved
tests/test_equilibrium.py Outdated Show resolved Hide resolved
@YigitElma YigitElma dismissed stale reviews from f0uriest and unalmis via 1968e23 October 29, 2024 18:02
@YigitElma
Copy link
Collaborator Author

I added this to the change log, because after @dpanici's changes in #1327, we need to say this bug is fixed.

Copy link
Collaborator

@unalmis unalmis left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dpanici dpanici merged commit 7887a52 into master Oct 30, 2024
23 of 24 checks passed
@dpanici dpanici deleted the yge/customjvp_fix branch October 30, 2024 18:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug fix Something was fixed test_jax Run tests against different versions of JAX
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Error in map_coordinates with jax==0.4.34
6 participants