Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove jit method of objective, directly compile methods #1043

Merged
merged 23 commits into from
Aug 22, 2024

Conversation

f0uriest
Copy link
Member

@f0uriest f0uriest commented Jun 4, 2024

Basically, right now we close over self when compiling the methods of ObjectiveFunction. This means that JAX may bake all the attributes of the objective (ie, transforms, profiles, fields, equilibrium etc) into the compiled function which likely both slows down compilation and may lead to extra memory usage.

This changes things so that instead we JIT the method directly, treating self as just another argument. Doing this requires refactoring how the derivatives get handled a bit (they are now only local to their respective functions rather than being created separately, shouldn't be any performance hit since creating the Derivative objects is basically free).

Resolves #957
Resolves #1191

@f0uriest f0uriest marked this pull request as draft June 4, 2024 20:32
desc/objectives/objective_funs.py Outdated Show resolved Hide resolved
desc/objectives/objective_funs.py Outdated Show resolved Hide resolved
desc/objectives/objective_funs.py Outdated Show resolved Hide resolved
desc/objectives/objective_funs.py Outdated Show resolved Hide resolved
desc/objectives/objective_funs.py Outdated Show resolved Hide resolved
desc/objectives/objective_funs.py Outdated Show resolved Hide resolved
desc/objectives/objective_funs.py Outdated Show resolved Hide resolved
desc/objectives/objective_funs.py Outdated Show resolved Hide resolved
desc/objectives/linear_objectives.py Outdated Show resolved Hide resolved
desc/objectives/linear_objectives.py Outdated Show resolved Hide resolved
Copy link
Contributor

github-actions bot commented Jun 4, 2024

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        | 
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
 test_build_transform_fft_lowres         |     +0.61 +/- 10.68    | +3.31e-03 +/- 5.76e-02 |  5.43e-01 +/- 4.0e-02  |  5.39e-01 +/- 4.1e-02  |
 test_build_transform_fft_midres         |     +2.68 +/- 9.01     | +1.64e-02 +/- 5.52e-02 |  6.28e-01 +/- 3.8e-02  |  6.12e-01 +/- 4.0e-02  |
 test_build_transform_fft_highres        |     +3.35 +/- 4.50     | +3.34e-02 +/- 4.49e-02 |  1.03e+00 +/- 3.0e-02  |  9.97e-01 +/- 3.3e-02  |
 test_equilibrium_init_lowres            |     +9.11 +/- 8.22     | +3.51e-01 +/- 3.17e-01 |  4.21e+00 +/- 2.9e-01  |  3.86e+00 +/- 1.3e-01  |
 test_equilibrium_init_medres            |     +2.58 +/- 6.03     | +1.19e-01 +/- 2.79e-01 |  4.75e+00 +/- 2.1e-01  |  4.63e+00 +/- 1.8e-01  |
 test_equilibrium_init_highres           |     +3.75 +/- 4.12     | +2.13e-01 +/- 2.33e-01 |  5.88e+00 +/- 1.7e-01  |  5.67e+00 +/- 1.6e-01  |
 test_objective_compile_dshape_current   |     +4.94 +/- 2.51     | +1.87e-01 +/- 9.50e-02 |  3.98e+00 +/- 6.3e-02  |  3.79e+00 +/- 7.1e-02  |
 test_objective_compile_atf              |     -4.31 +/- 2.42     | -3.62e-01 +/- 2.03e-01 |  8.04e+00 +/- 1.4e-01  |  8.40e+00 +/- 1.5e-01  |
-test_objective_compute_dshape_current   |    +172.30 +/- 4.71    | +2.19e-03 +/- 5.99e-05 |  3.46e-03 +/- 3.7e-05  |  1.27e-03 +/- 4.7e-05  |
-test_objective_compute_atf              |    +134.12 +/- 4.95    | +5.98e-03 +/- 2.21e-04 |  1.04e-02 +/- 1.7e-04  |  4.46e-03 +/- 1.4e-04  |
 test_objective_jac_dshape_current       |     +4.88 +/- 7.31     | +1.90e-03 +/- 2.84e-03 |  4.08e-02 +/- 2.4e-03  |  3.89e-02 +/- 1.5e-03  |
 test_objective_jac_atf                  |     -1.03 +/- 3.34     | -1.98e-02 +/- 6.42e-02 |  1.90e+00 +/- 5.5e-02  |  1.92e+00 +/- 3.3e-02  |
 test_perturb_1                          |     -7.18 +/- 7.10     | -9.60e-01 +/- 9.49e-01 |  1.24e+01 +/- 1.6e-01  |  1.34e+01 +/- 9.4e-01  |
 test_perturb_2                          |     -6.97 +/- 2.94     | -1.30e+00 +/- 5.47e-01 |  1.73e+01 +/- 3.6e-01  |  1.86e+01 +/- 4.1e-01  |
 test_proximal_jac_atf                   |     +0.07 +/- 1.40     | +5.68e-03 +/- 1.14e-01 |  8.11e+00 +/- 8.8e-02  |  8.11e+00 +/- 7.2e-02  |
-test_proximal_freeb_compute             |     +5.08 +/- 1.63     | +8.93e-03 +/- 2.86e-03 |  1.85e-01 +/- 2.1e-03  |  1.76e-01 +/- 2.0e-03  |
 test_proximal_freeb_jac                 |     +2.16 +/- 1.33     | +1.57e-01 +/- 9.65e-02 |  7.44e+00 +/- 8.9e-02  |  7.28e+00 +/- 3.8e-02  |
+test_solve_fixed_iter                   |    -71.70 +/- 17.72    | -1.24e+01 +/- 3.07e+00 |  4.91e+00 +/- 2.1e+00  |  1.73e+01 +/- 2.2e+00  |

@dpanici
Copy link
Collaborator

dpanici commented Jun 25, 2024

@YigitElma @dpanici @kianorr @rahulgaur104 we should profile this change memory-wise

Copy link

codecov bot commented Aug 15, 2024

Codecov Report

Attention: Patch coverage is 93.53234% with 13 lines in your changes missing coverage. Please review.

Project coverage is 95.42%. Comparing base (1c076fc) to head (a719f8a).
Report is 1682 commits behind head on master.

Files with missing lines Patch % Lines
desc/objectives/objective_funs.py 91.66% 11 Missing ⚠️
desc/objectives/linear_objectives.py 71.42% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1043      +/-   ##
==========================================
- Coverage   95.42%   95.42%   -0.01%     
==========================================
  Files          87       87              
  Lines       22341    22423      +82     
==========================================
+ Hits        21320    21398      +78     
- Misses       1021     1025       +4     
Files with missing lines Coverage Δ
desc/io/optimizable_io.py 86.30% <100.00%> (+0.08%) ⬆️
desc/objectives/utils.py 100.00% <100.00%> (ø)
desc/optimize/_constraint_wrappers.py 95.82% <100.00%> (+0.08%) ⬆️
desc/objectives/linear_objectives.py 97.04% <71.42%> (+0.01%) ⬆️
desc/objectives/objective_funs.py 93.97% <91.66%> (+0.02%) ⬆️

... and 5 files with indirect coverage changes

---- 🚨 Try these New Features:

@f0uriest f0uriest marked this pull request as ready for review August 15, 2024 16:05
@f0uriest
Copy link
Member Author

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        | 
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
 test_build_transform_fft_lowres         |     -1.56 +/- 9.77     | -8.12e-03 +/- 5.09e-02 |  5.13e-01 +/- 2.9e-02  |  5.21e-01 +/- 4.2e-02  |
 test_build_transform_fft_midres         |     -0.57 +/- 2.07     | -3.38e-03 +/- 1.22e-02 |  5.89e-01 +/- 9.9e-03  |  5.92e-01 +/- 7.2e-03  |
 test_build_transform_fft_highres        |     +3.10 +/- 3.37     | +3.05e-02 +/- 3.32e-02 |  1.01e+00 +/- 2.1e-02  |  9.84e-01 +/- 2.6e-02  |
 test_equilibrium_init_lowres            |     +3.36 +/- 8.44     | +1.32e-01 +/- 3.32e-01 |  4.07e+00 +/- 1.7e-01  |  3.94e+00 +/- 2.9e-01  |
 test_equilibrium_init_medres            |     -1.05 +/- 4.88     | -4.42e-02 +/- 2.04e-01 |  4.15e+00 +/- 1.0e-01  |  4.19e+00 +/- 1.8e-01  |
 test_equilibrium_init_highres           |     +1.22 +/- 3.98     | +6.80e-02 +/- 2.22e-01 |  5.65e+00 +/- 1.4e-01  |  5.59e+00 +/- 1.7e-01  |
 test_objective_compile_dshape_current   |     +2.97 +/- 4.68     | +1.17e-01 +/- 1.85e-01 |  4.06e+00 +/- 1.7e-01  |  3.95e+00 +/- 6.1e-02  |
 test_objective_compile_atf              |     -5.39 +/- 3.01     | -4.55e-01 +/- 2.54e-01 |  7.98e+00 +/- 7.0e-02  |  8.44e+00 +/- 2.4e-01  |
-test_objective_compute_dshape_current   |    +177.24 +/- 4.92    | +2.22e-03 +/- 6.17e-05 |  3.48e-03 +/- 5.1e-05  |  1.25e-03 +/- 3.4e-05  |
-test_objective_compute_atf              |    +139.12 +/- 4.43    | +5.95e-03 +/- 1.89e-04 |  1.02e-02 +/- 1.1e-04  |  4.27e-03 +/- 1.5e-04  |
 test_objective_jac_dshape_current       |     -0.65 +/- 8.29     | -2.51e-04 +/- 3.19e-03 |  3.82e-02 +/- 1.8e-03  |  3.85e-02 +/- 2.6e-03  |
 test_objective_jac_atf                  |     +1.18 +/- 4.31     | +2.19e-02 +/- 8.03e-02 |  1.88e+00 +/- 3.7e-02  |  1.86e+00 +/- 7.1e-02  |
 test_perturb_1                          |     -9.89 +/- 5.29     | -1.38e+00 +/- 7.40e-01 |  1.26e+01 +/- 3.2e-01  |  1.40e+01 +/- 6.7e-01  |
 test_perturb_2                          |     -3.33 +/- 2.48     | -6.19e-01 +/- 4.61e-01 |  1.80e+01 +/- 4.5e-01  |  1.86e+01 +/- 1.1e-01  |
 test_proximal_jac_atf                   |     +1.05 +/- 1.09     | +7.65e-02 +/- 7.96e-02 |  7.35e+00 +/- 6.6e-02  |  7.28e+00 +/- 4.4e-02  |
-test_proximal_freeb_compute             |     +3.08 +/- 0.75     | +5.44e-03 +/- 1.33e-03 |  1.82e-01 +/- 9.9e-04  |  1.76e-01 +/- 8.9e-04  |
 test_proximal_freeb_jac                 |     -0.08 +/- 0.95     | -5.65e-03 +/- 6.95e-02 |  7.35e+00 +/- 5.6e-02  |  7.35e+00 +/- 4.2e-02  |
+test_solve_fixed_iter                   |    -60.47 +/- 16.46    | -1.08e+01 +/- 2.94e+00 |  7.06e+00 +/- 2.9e+00  |  1.78e+01 +/- 5.7e-01  |

I think the slowdown in the compute benchmark is because it has to flatten/unflatten the ObjectiveFunction pytree at each call. It looks like a big slowdown but in absolute time its only a few ms, so likely worth it to get a big speedup on actual solve/optimization (that said I think 60% faster solve is higher than what I'm seeing in practice, it's usually more like 10-20% faster, so might be unique to the benchmark case)

Copy link
Collaborator

@YigitElma YigitElma left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • I woud prefer to get rid of set_derivatives function completely since it basically does nothing now.
  • For blocked jacobian, extra explanation would be good

desc/objectives/objective_funs.py Show resolved Hide resolved
desc/objectives/objective_funs.py Show resolved Hide resolved
@@ -816,10 +848,6 @@ def __init__(

def _set_derivatives(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as above, I would prefer having this in init function.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see comment above

desc/objectives/utils.py Outdated Show resolved Hide resolved
# in ProximalProjection we have an explicit state that we keep track of (and add
# to as we go) meaning if we jit anything with self static it doesn't update
# correctly, while if we leave self unstatic then it recompiles every time because
# the pytree structure of ProximalProjection is changing. To get around that we
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, for example ForceBalance constraint changes in terms of value but not the structure so, it doesn't recompile everytime. but do to appended state to the self, structure of ProximalProjection changes? Do I understand correctly?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes exactly. In theory we might be able to get around with with #1034 but that's future work.

YigitElma
YigitElma previously approved these changes Aug 18, 2024
for method in methods:
try:
delattr(self, method)
setattr(
self, method, functools.partial(getattr(self, method)._fun, self)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was going to ask why grad jac etc had the Derivative call inside of them now, but this functools.partial call basically partially evaluates the method withself, which also creates the Derivative and sets up the correct AD method, so that subsequent calls are not re-creating the grad/jac function again, but rather just using it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code is only for unjitting something (ie if use_jit=False) in which case it will effectively create a new Derivative object on each call (partial doesn't actually evaluate anything when its created).

The default (when jitted) is that the creation of the Derivative object basically gets compiled away by jax so its not costing anything. And even if it's created each time instantiating the object is negligible compared to actually computing the derivative.

@@ -1534,7 +1534,7 @@ def test_boundary_error_print(capsys):
obj = VacuumBoundaryError(eq, coilset, field_grid=coil_grid)
obj.build()

f = np.abs(obj.compute_unscaled(*obj.xs(eq)))
f = np.abs(obj.compute_unscaled(*obj.xs(eq, coilset)))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was the fact that the old way worked before somewhat unintentional?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah seems like it was an uncaught bug.

Copy link
Collaborator

@dpanici dpanici left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just had some questions but nothing on the content

@f0uriest f0uriest requested a review from YigitElma August 22, 2024 01:53
@f0uriest f0uriest merged commit bdc5de4 into master Aug 22, 2024
17 of 18 checks passed
@f0uriest f0uriest deleted the rc/objective_jit branch August 22, 2024 05:32
unalmis added a commit that referenced this pull request Sep 3, 2024
After the recent refactoring to the `Bounce1D` class that resulted from
#1214, the API is a little too strict for computations like effective
ripple etc. where we vectorize the computation over over some dimensions
and loop over others to save memory.

This PR changes the expected shape of the pitch angle input to
`Bounce1D` in #854 from `(P, M, L)` to `(M, L, P)`. With this change,
the two leading axes of all inputs to the methods in that class is `(M,
L)`.

These changes are tested and already included in downstream branches. I
am making new PR instead of directly committing to the `bounce` branch
for people who have already reviewed the `bounce` PR.

 This is better because
1. Easier usage for end users. (Previously, you'd have to manually add
trailing axes to pitch angle array).
2. Makes it much simpler to use with JAX's new batched map.
3. Previously we would loop over the pitch angles to save memory.
However, this means some computation is repeated because interpax would
interpolate multiple times. By looping over the field lines instead and
doing the interpolation for all the pitch angles at once, both
`_bounce_quadrature` and `interp_to_argmin` are faster. (I'm seeing 30%
faster speed just from computing effective ripple (no optimization), but
I don't plan to do any benchmarking to see whether that is from recent
changes like #1154 or #1043 , or others).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Excessive recompilation in ProximalProjection Don't close over self when jitting objective functions
4 participants