Allow Optimization with OptimizableCollections to Work #857
…Fixed how A matrix is made to work with opt collection, and fixed unpack params to work under jit, but still more to be done
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
| test_build_transform_fft_lowres | +1.48 +/- 1.25 | +1.82e-04 +/- 1.54e-04 | 1.25e-02 +/- 1.2e-04 | 1.23e-02 +/- 1.0e-04 |
| test_build_transform_fft_midres | +0.32 +/- 0.89 | +2.92e-04 +/- 8.11e-04 | 9.13e-02 +/- 7.2e-04 | 9.10e-02 +/- 3.8e-04 |
| test_build_transform_fft_highres | +0.41 +/- 0.87 | +1.89e-03 +/- 4.00e-03 | 4.61e-01 +/- 2.5e-03 | 4.59e-01 +/- 3.1e-03 |
| test_equilibrium_init_lowres | -1.32 +/- 0.96 | -1.05e-02 +/- 7.70e-03 | 7.89e-01 +/- 3.8e-03 | 7.99e-01 +/- 6.7e-03 |
| test_equilibrium_init_medres | +0.03 +/- 2.73 | +3.83e-04 +/- 3.85e-02 | 1.41e+00 +/- 1.1e-02 | 1.41e+00 +/- 3.7e-02 |
| test_equilibrium_init_highres | +0.05 +/- 0.97 | +1.95e-03 +/- 4.03e-02 | 4.17e+00 +/- 2.7e-02 | 4.17e+00 +/- 3.0e-02 |
| test_objective_compile_dshape_current | +0.34 +/- 8.05 | +1.46e-02 +/- 3.46e-01 | 4.31e+00 +/- 2.5e-01 | 4.29e+00 +/- 2.4e-01 |
| test_objective_compile_atf | +2.53 +/- 6.60 | +2.29e-01 +/- 5.98e-01 | 9.28e+00 +/- 4.5e-01 | 9.05e+00 +/- 3.9e-01 |
| test_objective_compute_dshape_current | -0.42 +/- 4.47 | -9.26e-06 +/- 9.79e-05 | 2.18e-03 +/- 5.5e-05 | 2.19e-03 +/- 8.1e-05 |
| test_objective_compute_atf | +4.57 +/- 2.59 | +3.40e-04 +/- 1.92e-04 | 7.78e-03 +/- 1.7e-04 | 7.44e-03 +/- 9.1e-05 |
| test_objective_jac_dshape_current | +5.06 +/- 15.87 | +2.31e-03 +/- 7.26e-03 | 4.81e-02 +/- 5.1e-03 | 4.57e-02 +/- 5.2e-03 |
| test_objective_jac_atf | +2.20 +/- 3.22 | +4.84e-02 +/- 7.09e-02 | 2.25e+00 +/- 5.9e-02 | 2.20e+00 +/- 4.0e-02 |
| test_perturb_1 | +0.85 +/- 12.88 | +7.21e-02 +/- 1.10e+00 | 8.58e+00 +/- 8.2e-01 | 8.50e+00 +/- 7.2e-01 |
| test_perturb_2 | -0.75 +/- 5.06 | -1.05e-01 +/- 7.07e-01 | 1.39e+01 +/- 5.0e-01 | 1.40e+01 +/- 5.0e-01 |
…clean up the code for general examples... but the test is an optimization with Optimizable collection, Reworked FixParameter and factorize_linear_constraints to work with OptimizableCollection
)
# store split indices for unpacking
# as a numpy array to avoid jax issues later
self._split_idx = np.asarray(offset)
Is there a way to put this in a super init for OptimizableCollection? Otherwise this would have to be done in the init of every subclass of OptimizableCollection.
what is this needed for? I think there already is something similar in OptimizableCollection
There was, but it used jnp.split, and JAX threw an error complaining that it could not deal with non-static indices. It seems that jnp.split cannot be jitted when the split indices are traced, but if you use an np array as the indices it works.
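The behaviour described in this comment can be reproduced outside DESC with a small sketch. The names `split_idx`, `unpack_static` and `unpack_traced` are made up for illustration and are not DESC functions; the offsets are arbitrary.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical offsets for a collection whose members own 2, 3 and 3 parameters.
split_idx = np.asarray([2, 5])  # NumPy array: a concrete constant while tracing


@jax.jit
def unpack_static(x):
    # split_idx is closed over as a NumPy array, so the split points are
    # known at trace time and every output chunk has a fixed shape.
    return jnp.split(x, split_idx)


@jax.jit
def unpack_traced(x, idx):
    # idx is an argument of the jitted function, so it is traced and its
    # values are not available when the output shapes must be decided.
    return jnp.split(x, idx)


x = jnp.arange(8.0)
parts = unpack_static(x)
shapes = [p.shape for p in parts]

try:
    unpack_traced(x, jnp.asarray([2, 5]))
    traced_failed = False
except Exception as err:  # a concretization error is expected here
    traced_failed = True
    print(type(err).__name__)
```

The output shapes of a jitted function must be fixed at trace time, which is why the split points have to be concrete values rather than traced arrays.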
hmmm ok. We could add an init method to do this, and then call super init from all the subclasses. It might also be possible to do some metaclass stuff like we do to register things as pytrees:
DESC/desc/io/equilibrium_io.py
Line 107 in 969774a
class _AutoRegisterPytree(type):
Another option might be something like we do in the Optimizable class, where it gets built the first time it's called:
Line 18 in 969774a
def optimizable_params(self):
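A minimal sketch of the "build it the first time it is called" option, assuming each member exposes its parameter count as `dim_x`. The classes `Collection`, `Thing` and `SumOfThings` are hypothetical stand-ins, not the DESC classes.

```python
import numpy as np


class Thing:
    """Hypothetical member object that owns `dim_x` parameters."""

    def __init__(self, dim_x):
        self.dim_x = dim_x


class Collection:
    """Hypothetical base class for a collection of things."""

    def __init__(self, *things):
        self.things = list(things)

    @property
    def split_idx(self):
        # Build the offsets on first access and cache them, so subclasses
        # do not have to set this up in their own __init__.
        if not hasattr(self, "_split_idx"):
            sizes = [t.dim_x for t in self.things]
            # Kept as a NumPy array so JAX can treat it as static.
            self._split_idx = np.cumsum(sizes)[:-1]
        return self._split_idx


class SumOfThings(Collection):
    def __init__(self, *things):
        # Deliberately skips super().__init__ to show that the lazy
        # property still works for such a subclass.
        self.things = list(things)


coll = SumOfThings(Thing(2), Thing(3), Thing(3))
chunks = np.split(np.arange(8), coll.split_idx)
print(coll.split_idx, [len(c) for c in chunks])
```

Compared with a super init, this keeps working when a subclass forgets to call `super().__init__`, at the cost of a stale cache if the members change after first access.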
)
)
optimizer = Optimizer("lsq-exact")
constraints = (FixParameter(field, [["R0"], []]),)
Currently I think the API I have set up (I have not changed the docs yet) is that if the thing is an OptimizableCollection, you MUST pass in params as a list of lists, with the outer list having length len(OptimizableCollection), to avoid ambiguity in what is being input. Then if you only want to fix one subthing's params, as in the case here, you leave the other lists of params empty.
This case fixes only the "R0" of the TF field but leaves its "B0" and the vertical field's "B0" free.
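The list-of-lists convention can be sketched as a small validation helper. `fixed_mask_for_collection` is a hypothetical function written for illustration and is not how FixParameter is implemented; plain dicts stand in for the fields' parameters.

```python
def fixed_mask_for_collection(member_params, params):
    """Return, per member, a dict mapping each parameter name to True if fixed.

    member_params: list of dicts, one per member of the collection
        (hypothetical stand-in for each subthing's parameters).
    params: list of lists of names to fix, one inner list per member;
        an empty inner list fixes nothing for that member.
    """
    if len(params) != len(member_params):
        raise ValueError(
            f"expected {len(member_params)} lists of params, got {len(params)}"
        )
    masks = []
    for i, (available, names) in enumerate(zip(member_params, params)):
        unknown = [name for name in names if name not in available]
        if unknown:
            raise ValueError(f"member {i} has no params named {unknown}")
        masks.append({name: name in names for name in available})
    return masks


tf_params = {"R0": 10.0, "B0": 1.0}  # toroidal field stand-in
vf_params = {"B0": 0.1}  # vertical field stand-in
masks = fixed_mask_for_collection([tf_params, vf_params], [["R0"], []])
print(masks)
```

Requiring one inner list per member makes a mismatched input fail early instead of being silently applied to the wrong subthing.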
Something like this might be easier? https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#key-paths with the keypath indexing into thing.params. We could use something similar for targets etc.
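As a rough illustration of the key-path idea, the sketch below uses `jax.tree_util.tree_flatten_with_path`, `tree_map_with_path` and `keystr` on a list of dicts standing in for the collection's params. The parameter values and the `fixed_paths` set are assumptions made for the example, not part of the PR.

```python
import jax

# Stand-in for the params of a two-member collection (values are arbitrary).
params = [{"R0": 10.0, "B0": 1.0}, {"B0": 0.1}]

# Each leaf comes with the path of keys needed to reach it.
leaves_with_paths, treedef = jax.tree_util.tree_flatten_with_path(params)
paths = [jax.tree_util.keystr(path) for path, _ in leaves_with_paths]

# Hypothetical way to say "fix R0 of member 0" with a key path string.
fixed_paths = {"[0]['R0']"}

# Boolean mask with the same structure as params.
mask = jax.tree_util.tree_map_with_path(
    lambda path, leaf: jax.tree_util.keystr(path) in fixed_paths, params
)
print(paths)
print(mask)
```

A key path addresses a single leaf unambiguously, so nothing has to be passed for members that have no fixed parameters.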
for subthing in thing:
    if subthing in con._subthings:
        # this returns a tuple of length 1 for some reason?
        A_ = con.jac_scaled(*xz)[0][con._subthings.index(subthing)]
Not sure why I needed a 0 index here, but con.jac_scaled(*xz) was returning a tuple of length 1 that contained a list of dicts of length len(con._subthings), so I had to index into the tuple here. I am really tired, so likely I am just missing something stupid here.
iirc con.jac_scaled returns a tuple of dicts of arrays, where the length of the tuple corresponds to how many sets of params the constraint takes, and each dict is the derivative wrt that set of params
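The same "tuple of dicts" shape shows up in plain JAX when a Jacobian is taken with a tuple of argument indices. The sketch below uses `jax.jacfwd` on a toy function; `residual` and its parameters are invented for illustration, and this only mirrors the structure described here, not the internals of `jac_scaled`.

```python
import jax
import jax.numpy as jnp


def residual(params):
    # Toy function of one set of params, returning 3 residuals.
    return jnp.concatenate([params["R0"] * 2.0, params["B0"] ** 2])


params = {"R0": jnp.array([10.0]), "B0": jnp.array([1.0, 3.0])}

# argnums is a tuple, so the result is a tuple with one entry per
# differentiated argument; each entry mirrors that argument's pytree.
jac = jax.jacfwd(residual, argnums=(0,))(params)

print(len(jac))
print({name: block.shape for name, block in jac[0].items()})
```

With a single set of params the outer tuple has length 1, which is why a leading `[0]` is needed before indexing into the per-parameter blocks.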
        # this returns a tuple of length 1 for some reason?
        A_ = con.jac_scaled(*xz)[0][con._subthings.index(subthing)]
    else:
        A_ = {
This is completely untested right now. The way I have it set up as of commit 237da24 is that FixParameter has in its self._subthings list every single object contained in the OptimizableCollection it has as its self.things[0], so even though I am not fixing anything for the VF field in the example, it is still counted in the first logic branch above this line and the else is not triggered.
Just as a general comment for this type of stuff, the jax pytree utils are usually super helpful and a lot easier than rolling our own versions of similar stuff.
I will give this a look tomorrow, thanks
Codecov Report
Additional details and impacted files
@@ Coverage Diff @@
## master #857 +/- ##
==========================================
- Coverage 94.92% 94.87% -0.05%
==========================================
Files 80 80
Lines 19619 19724 +105
==========================================
+ Hits 18623 18713 +90
- Misses 996 1011 +15
So action items @ddudt for OptimizableCollection stuff
#956 takes care of this
Resolves #860
This is an inelegant attempt to get OptimizableCollection optimization working. It currently works for the simple test case I have in here, so it could be a possible way towards implementing this?
- FixParameter to work with OptimizableCollection, though the API needs to be hammered out, and we still need to decide how to handle all the different ways things may be passed into FixParameter for an OptimizableCollection (i.e. a list of lists of params of len(collection), or maybe just one param?). I think we should enforce that it is a list of lists of length len(collection), as parsing without that assumption is hard. Currently I think it should accept a list of lists of length len(collection), and if nothing is being fixed for a particular subthing in the collection (say, the ith subthing), then the corresponding entry is params[i] = [ ] (an empty list).
- OptimizableCollection.unpack_params to work with jit. I think the best solution would be to have the init of an OptimizableCollection subclass (or really, it should be in the init of the OptimizableCollection itself) create the self._split_idx array as a numpy array (to be considered static by JAX during the jit and thus allowed to be used for indexing in unpack_params). Currently I have this attribute created in the SumMagneticField init, which does not feel like the neatest way, as it really only needs to know about OptimizableCollection stuff.

I have a dummy objective that just takes in a sum of fields in this PR, as every other PR with a sensible objective is not yet in master.
Some of the issues to be addressed (feel free to add to these)
- A matrix construction assumes things.dimensions is not a list
- FixParameter objects when thing is actually an OptimizableCollection? Should it be passed the thing, and then for the params a list of the usual argument for each subthing in things, e.g. FixParameter(sumfield, params=[["B0", "R0"], ["Phi_mn"]]) for sumfield being a SumMagneticField of a ToroidalMagneticField and a FourierCurrentPotentialField?
- unpack_params for OptimizableCollection as implemented in master is currently not jittable. I think there is a way around it, but we'd need to define the split_idx in the init of the OptimizableCollection objects as an np.array so JAX treats it as a static array when it is jitted.

@f0uriest @ddudt I am not claiming this entirely as my own PR, just wanted to get the conversation started and see if I could make any progress