-
Notifications
You must be signed in to change notification settings - Fork 26
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Remove jit method of objective, directly compile methods #1043
Conversation
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
test_build_transform_fft_lowres | +0.61 +/- 10.68 | +3.31e-03 +/- 5.76e-02 | 5.43e-01 +/- 4.0e-02 | 5.39e-01 +/- 4.1e-02 |
test_build_transform_fft_midres | +2.68 +/- 9.01 | +1.64e-02 +/- 5.52e-02 | 6.28e-01 +/- 3.8e-02 | 6.12e-01 +/- 4.0e-02 |
test_build_transform_fft_highres | +3.35 +/- 4.50 | +3.34e-02 +/- 4.49e-02 | 1.03e+00 +/- 3.0e-02 | 9.97e-01 +/- 3.3e-02 |
test_equilibrium_init_lowres | +9.11 +/- 8.22 | +3.51e-01 +/- 3.17e-01 | 4.21e+00 +/- 2.9e-01 | 3.86e+00 +/- 1.3e-01 |
test_equilibrium_init_medres | +2.58 +/- 6.03 | +1.19e-01 +/- 2.79e-01 | 4.75e+00 +/- 2.1e-01 | 4.63e+00 +/- 1.8e-01 |
test_equilibrium_init_highres | +3.75 +/- 4.12 | +2.13e-01 +/- 2.33e-01 | 5.88e+00 +/- 1.7e-01 | 5.67e+00 +/- 1.6e-01 |
test_objective_compile_dshape_current | +4.94 +/- 2.51 | +1.87e-01 +/- 9.50e-02 | 3.98e+00 +/- 6.3e-02 | 3.79e+00 +/- 7.1e-02 |
test_objective_compile_atf | -4.31 +/- 2.42 | -3.62e-01 +/- 2.03e-01 | 8.04e+00 +/- 1.4e-01 | 8.40e+00 +/- 1.5e-01 |
-test_objective_compute_dshape_current | +172.30 +/- 4.71 | +2.19e-03 +/- 5.99e-05 | 3.46e-03 +/- 3.7e-05 | 1.27e-03 +/- 4.7e-05 |
-test_objective_compute_atf | +134.12 +/- 4.95 | +5.98e-03 +/- 2.21e-04 | 1.04e-02 +/- 1.7e-04 | 4.46e-03 +/- 1.4e-04 |
test_objective_jac_dshape_current | +4.88 +/- 7.31 | +1.90e-03 +/- 2.84e-03 | 4.08e-02 +/- 2.4e-03 | 3.89e-02 +/- 1.5e-03 |
test_objective_jac_atf | -1.03 +/- 3.34 | -1.98e-02 +/- 6.42e-02 | 1.90e+00 +/- 5.5e-02 | 1.92e+00 +/- 3.3e-02 |
test_perturb_1 | -7.18 +/- 7.10 | -9.60e-01 +/- 9.49e-01 | 1.24e+01 +/- 1.6e-01 | 1.34e+01 +/- 9.4e-01 |
test_perturb_2 | -6.97 +/- 2.94 | -1.30e+00 +/- 5.47e-01 | 1.73e+01 +/- 3.6e-01 | 1.86e+01 +/- 4.1e-01 |
test_proximal_jac_atf | +0.07 +/- 1.40 | +5.68e-03 +/- 1.14e-01 | 8.11e+00 +/- 8.8e-02 | 8.11e+00 +/- 7.2e-02 |
-test_proximal_freeb_compute | +5.08 +/- 1.63 | +8.93e-03 +/- 2.86e-03 | 1.85e-01 +/- 2.1e-03 | 1.76e-01 +/- 2.0e-03 |
test_proximal_freeb_jac | +2.16 +/- 1.33 | +1.57e-01 +/- 9.65e-02 | 7.44e+00 +/- 8.9e-02 | 7.28e+00 +/- 3.8e-02 |
+test_solve_fixed_iter | -71.70 +/- 17.72 | -1.24e+01 +/- 3.07e+00 | 4.91e+00 +/- 2.1e+00 | 1.73e+01 +/- 2.2e+00 | |
@YigitElma @dpanici @kianorr @rahulgaur104 we should profile this change memory-wise |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #1043 +/- ##
==========================================
- Coverage 95.42% 95.42% -0.01%
==========================================
Files 87 87
Lines 22341 22423 +82
==========================================
+ Hits 21320 21398 +78
- Misses 1021 1025 +4
|
I think the slowdown in the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- I woud prefer to get rid of
set_derivatives
function completely since it basically does nothing now. - For blocked jacobian, extra explanation would be good
@@ -816,10 +848,6 @@ def __init__( | |||
|
|||
def _set_derivatives(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same comment as above, I would prefer having this in init function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
see comment above
# in ProximalProjection we have an explicit state that we keep track of (and add | ||
# to as we go) meaning if we jit anything with self static it doesn't update | ||
# correctly, while if we leave self unstatic then it recompiles every time because | ||
# the pytree structure of ProximalProjection is changing. To get around that we |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So, for example ForceBalance constraint changes in terms of value but not the structure so, it doesn't recompile everytime. but do to appended state to the self, structure of ProximalProjection changes? Do I understand correctly?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes exactly. In theory we might be able to get around with with #1034 but that's future work.
for method in methods: | ||
try: | ||
delattr(self, method) | ||
setattr( | ||
self, method, functools.partial(getattr(self, method)._fun, self) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was going to ask why grad jac etc had the Derivative
call inside of them now, but this functools.partial call basically partially evaluates the method withself
, which also creates the Derivative and sets up the correct AD method, so that subsequent calls are not re-creating the grad/jac function again, but rather just using it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This code is only for unjitting something (ie if use_jit=False
) in which case it will effectively create a new Derivative
object on each call (partial
doesn't actually evaluate anything when its created).
The default (when jitted) is that the creation of the Derivative
object basically gets compiled away by jax so its not costing anything. And even if it's created each time instantiating the object is negligible compared to actually computing the derivative.
@@ -1534,7 +1534,7 @@ def test_boundary_error_print(capsys): | |||
obj = VacuumBoundaryError(eq, coilset, field_grid=coil_grid) | |||
obj.build() | |||
|
|||
f = np.abs(obj.compute_unscaled(*obj.xs(eq))) | |||
f = np.abs(obj.compute_unscaled(*obj.xs(eq, coilset))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
was the fact that the old way worked before somewhat unintentional?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah seems like it was an uncaught bug.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just had some questions but nothing on the content
After the recent refactoring to the `Bounce1D` class that resulted from #1214, the API is a little too strict for computations like effective ripple etc. where we vectorize the computation over over some dimensions and loop over others to save memory. This PR changes the expected shape of the pitch angle input to `Bounce1D` in #854 from `(P, M, L)` to `(M, L, P)`. With this change, the two leading axes of all inputs to the methods in that class is `(M, L)`. These changes are tested and already included in downstream branches. I am making new PR instead of directly committing to the `bounce` branch for people who have already reviewed the `bounce` PR. This is better because 1. Easier usage for end users. (Previously, you'd have to manually add trailing axes to pitch angle array). 2. Makes it much simpler to use with JAX's new batched map. 3. Previously we would loop over the pitch angles to save memory. However, this means some computation is repeated because interpax would interpolate multiple times. By looping over the field lines instead and doing the interpolation for all the pitch angles at once, both `_bounce_quadrature` and `interp_to_argmin` are faster. (I'm seeing 30% faster speed just from computing effective ripple (no optimization), but I don't plan to do any benchmarking to see whether that is from recent changes like #1154 or #1043 , or others).
Basically, right now we close over
self
when compiling the methods ofObjectiveFunction
. This means that JAX may bake all the attributes of the objective (ie, transforms, profiles, fields, equilibrium etc) into the compiled function which likely both slows down compilation and may lead to extra memory usage.This changes things so that instead we JIT the method directly, treating
self
as just another argument. Doing this requires refactoring how the derivatives get handled a bit (they are now only local to their respective functions rather than being created separately, shouldn't be any performance hit since creating theDerivative
objects is basically free).Resolves #957
Resolves #1191