Allow Optimization with OptimizableCollections to Work #857

Closed
wants to merge 6 commits into from

dpanici
Collaborator

@dpanici dpanici commented Feb 5, 2024

Resolves #860
This is an inelegant attempt to get OptimizableCollection optimization working. It currently works for the simple test case I have in here, so it could be a possible way towards implementing this.

  • Modified FixParameter to work with OptimizableCollection, though the API still needs to be hammered out. We still need to decide how to handle the different forms of params that may be passed to FixParameter for an OptimizableCollection (a list of lists of length len(collection), or maybe just one param?). I think we should enforce a list of lists of length len(collection), as parsing without that assumption is hard. If nothing is being fixed for a particular subthing in the collection (say, the ith subthing), then the corresponding entry is an empty list: params[i] = [].
  • Modified factorize_linear_constraints to work with at least this one example of optimizing with an OptimizableCollection
  • The Jacobian calculation from the constraint seemed to return a tuple, which was weird, so I must be missing something there
  • Modified OptimizableCollection.unpack_params to work with jit. I think the best solution would be to have the init of OptimizableCollection itself (rather than the init of each subclass) create the self._split_idx array as a NumPy array, so that JAX treats it as static during jit and allows it to be used for indexing in unpack_params. Currently I create this attribute in the SumMagneticField init, which does not feel like the neatest way, as it really only needs to know about OptimizableCollection stuff.

I have a dummy objective that just takes in a sum of fields in this PR as every other PR with a sensible objective is not yet in master.

Some of the issues to be addressed (feel free to add to these)

  • In factorize_linear_constraints, the A matrix construction assumes things.dimensions is not a list
  • How to use FixParameter objects when thing is actually an OptimizableCollection? Should it be passed the thing and then, for params, a list of the usual argument for each subthing in things, e.g. FixParameter(sumfield, params=[["B0", "R0"], ["Phi_mn"]]) for sumfield being a SumMagneticField of a ToroidalMagneticField and a FourierCurrentPotentialField?
    • And fixing anywhere that assumes FixParameter is not acting on a list of things
    • Or making a new objective for fixing collections of objects?
  • unpack_params for OptimizableCollection, as currently implemented in master, is not jittable. I think there is a way around it, but we'd need to define split_idx in the init of the OptimizableCollection objects as an np.array so JAX treats it as a static array when it is jitted

@f0uriest @ddudt I am not claiming this entirely as my own PR, just wanted to get the conversation started and see if I could make any progress

dpanici and others added 4 commits February 1, 2024 01:41
Contributor

github-actions bot commented Feb 5, 2024

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
| test_build_transform_fft_lowres        |     +1.48 +/- 1.25     | +1.82e-04 +/- 1.54e-04 |  1.25e-02 +/- 1.2e-04  |  1.23e-02 +/- 1.0e-04  |
| test_build_transform_fft_midres        |     +0.32 +/- 0.89     | +2.92e-04 +/- 8.11e-04 |  9.13e-02 +/- 7.2e-04  |  9.10e-02 +/- 3.8e-04  |
| test_build_transform_fft_highres       |     +0.41 +/- 0.87     | +1.89e-03 +/- 4.00e-03 |  4.61e-01 +/- 2.5e-03  |  4.59e-01 +/- 3.1e-03  |
| test_equilibrium_init_lowres           |     -1.32 +/- 0.96     | -1.05e-02 +/- 7.70e-03 |  7.89e-01 +/- 3.8e-03  |  7.99e-01 +/- 6.7e-03  |
| test_equilibrium_init_medres           |     +0.03 +/- 2.73     | +3.83e-04 +/- 3.85e-02 |  1.41e+00 +/- 1.1e-02  |  1.41e+00 +/- 3.7e-02  |
| test_equilibrium_init_highres          |     +0.05 +/- 0.97     | +1.95e-03 +/- 4.03e-02 |  4.17e+00 +/- 2.7e-02  |  4.17e+00 +/- 3.0e-02  |
| test_objective_compile_dshape_current  |     +0.34 +/- 8.05     | +1.46e-02 +/- 3.46e-01 |  4.31e+00 +/- 2.5e-01  |  4.29e+00 +/- 2.4e-01  |
| test_objective_compile_atf             |     +2.53 +/- 6.60     | +2.29e-01 +/- 5.98e-01 |  9.28e+00 +/- 4.5e-01  |  9.05e+00 +/- 3.9e-01  |
| test_objective_compute_dshape_current  |     -0.42 +/- 4.47     | -9.26e-06 +/- 9.79e-05 |  2.18e-03 +/- 5.5e-05  |  2.19e-03 +/- 8.1e-05  |
| test_objective_compute_atf             |     +4.57 +/- 2.59     | +3.40e-04 +/- 1.92e-04 |  7.78e-03 +/- 1.7e-04  |  7.44e-03 +/- 9.1e-05  |
| test_objective_jac_dshape_current      |     +5.06 +/- 15.87    | +2.31e-03 +/- 7.26e-03 |  4.81e-02 +/- 5.1e-03  |  4.57e-02 +/- 5.2e-03  |
| test_objective_jac_atf                 |     +2.20 +/- 3.22     | +4.84e-02 +/- 7.09e-02 |  2.25e+00 +/- 5.9e-02  |  2.20e+00 +/- 4.0e-02  |
| test_perturb_1                         |     +0.85 +/- 12.88    | +7.21e-02 +/- 1.10e+00 |  8.58e+00 +/- 8.2e-01  |  8.50e+00 +/- 7.2e-01  |
| test_perturb_2                         |     -0.75 +/- 5.06     | -1.05e-01 +/- 7.07e-01 |  1.39e+01 +/- 5.0e-01  |  1.40e+01 +/- 5.0e-01  |

@dpanici dpanici marked this pull request as draft February 6, 2024 01:20
…clean up the code for general examples... but the test is an optimization with Optimizable collection, Reworked FixParameter and factorize_linear_constraints to work with OptimizableCollection
@dpanici dpanici requested review from f0uriest and ddudt February 6, 2024 04:14
)
# store split indices for unpacking
# as a numpy array to avoid jax issues later
self._split_idx = np.asarray(offset)
Collaborator Author

Is there a way to put this in a super init for OptimizableCollection? Otherwise this would have to be done in the init of every subclass of OptimizableCollection.

Member

what is this needed for? I think there already is something similar in OptimizableCollection

Collaborator Author

There was, but it used jnp.split, and JAX threw an error complaining that it could not deal with non-static indices. It seems that jnp.split cannot be jitted when the split indices are traced values, but if you use an np array as the indices it works.
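A minimal sketch of the behavior described above, not the DESC implementation: the member sizes and the `unpack` function are hypothetical stand-ins. The split points are kept as a NumPy array, so they stay concrete when the function is traced by `jax.jit`.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical number of parameters held by each member of a collection.
sizes = [2, 3, 1]

# Split points stored as a NumPy array, so JAX never traces them.
split_idx = np.cumsum(sizes)[:-1]


@jax.jit
def unpack(x):
    # split_idx is a closed-over NumPy constant, so its values are
    # concrete at trace time and jnp.split can use them as indices.
    return jnp.split(x, split_idx)


parts = unpack(jnp.arange(6.0))
shapes = [p.shape for p in parts]
```

If `split_idx` were instead passed into the jitted function as a JAX array argument, it would be traced and the split would fail with a concretization error, which matches the error reported in this comment.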

Member

hmmm ok. We could add an init method to do this, and then call super init from all the subclasses. It might also be possible to do some metaclass stuff like we do to register things as pytrees:

class _AutoRegisterPytree(type):

Another option might be something like what we do in the Optimizable class, where it gets built the first time it's called:

def optimizable_params(self):
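A rough sketch of the "build it the first time it's called" option, assuming hypothetical class names (`CollectionSketch`, `SumFieldSketch`) and a hypothetical `member_sizes` method; none of these are DESC APIs. Because the attribute is created lazily, subclasses would not need to call a super init.

```python
import numpy as np


class CollectionSketch:
    """Hypothetical stand-in for a collection base class."""

    def member_sizes(self):
        raise NotImplementedError

    @property
    def split_idx(self):
        # Build the static NumPy split indices on first access and cache them.
        if getattr(self, "_split_idx", None) is None:
            self._split_idx = np.cumsum(self.member_sizes())[:-1]
        return self._split_idx


class SumFieldSketch(CollectionSketch):
    """Hypothetical subclass that never calls super().__init__()."""

    def __init__(self, sizes):
        self._sizes = list(sizes)

    def member_sizes(self):
        return self._sizes


field = SumFieldSketch([2, 3, 1])
idx = field.split_idx
```

The trade-off is that the cached value would go stale if the members change after the first access, so a real implementation would need some way to invalidate it.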

)
)
optimizer = Optimizer("lsq-exact")
constraints = (FixParameter(field, [["R0"], []]),)
Collaborator Author

Currently I think the API I have set up (I have not changed the docs yet) is that if thing is an OptimizableCollection, you MUST pass in params as a list of lists, with the outer list of length len(OptimizableCollection), to avoid ambiguity in what is being input. Then, if you only want to fix one subthing's param, like in the case here, you leave the other lists of params empty.

This case is fixing only the "R0" of the TF field but leaving its "B0" and the vertical field's "B0" free.
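The list-of-lists convention described above could be checked along these lines. This is only a sketch: `normalize_collection_params` is a hypothetical helper and not part of DESC or of this PR.

```python
def normalize_collection_params(params, n_members):
    """Validate a list-of-lists ``params`` argument for a collection.

    Hypothetical helper: one inner list per member of the collection,
    where an empty list means "fix nothing for this member".
    """
    if not isinstance(params, (list, tuple)) or len(params) != n_members:
        raise ValueError(
            f"params must be a list of {n_members} lists, one per member"
        )
    normalized = []
    for i, member_params in enumerate(params):
        # A bare string would be ambiguous, so require an explicit list.
        if isinstance(member_params, str) or not isinstance(
            member_params, (list, tuple)
        ):
            raise ValueError(f"params[{i}] must be a list of parameter names")
        normalized.append(list(member_params))
    return normalized


# Fix only "R0" of the first member and nothing on the second.
fixed = normalize_collection_params([["R0"], []], n_members=2)
```

Rejecting a flat list such as `["R0", "B0"]` keeps the input unambiguous, at the cost of a more verbose call when only one member is constrained.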

Member

@f0uriest f0uriest Feb 6, 2024

something like this might be easier? https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#key-paths

with the keypath being into thing.params. We could use something similar for targets etc
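As an illustration of the key-path idea, here is a small sketch using `jax.tree_util.tree_flatten_with_path`. The `params` pytree and the `to_fix` set are made-up stand-ins for `thing.params` and a constraint specification; they are not DESC structures.

```python
import jax

# Hypothetical parameter pytree for a sum of two fields.
params = [
    {"B0": 1.0, "R0": 10.0},
    {"B0": 0.5},
]

# Each leaf comes back paired with the path of keys that leads to it.
leaves_with_paths, treedef = jax.tree_util.tree_flatten_with_path(params)

# Human-readable form of each path, useful for error messages.
path_strings = [jax.tree_util.keystr(path) for path, _ in leaves_with_paths]

# Hypothetical specification: fix "R0" of member 0 only.
to_fix = {(0, "R0")}
fix_mask = [
    (path[0].idx, path[1].key) in to_fix for path, _ in leaves_with_paths
]
```

Addressing parameters by path avoids the empty placeholder lists needed by the list-of-lists convention, since only the leaves being constrained have to be named.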

for subthing in thing:
    if subthing in con._subthings:
        # this returns a tuple of length 1 for some reason?
        A_ = con.jac_scaled(*xz)[0][con._subthings.index(subthing)]
Collaborator Author

Not sure why I needed a 0 index here, but con.jac_scaled(*xz) was returning a tuple of length 1 that contained a list of dicts of length len(con._subthings), so I had to index into the tuple here. I am really tired, so likely I am just missing something stupid here.

Member

iirc con.jac_scaled returns a tuple of dict of array, where the tuple corresponds to how many sets of params the constraint takes, and each dict is the derivative wrt that set of params
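The same tuple-of-dict structure can be reproduced with plain JAX, which may help explain the extra `[0]`. This sketch uses `jax.jacfwd` directly on a hypothetical residual function; it is not how `jac_scaled` is implemented in DESC.

```python
import jax
import jax.numpy as jnp


def residual(params):
    # Hypothetical linear constraint on a dict of parameters.
    return jnp.array([params["R0"] - 10.0])


params = {"B0": jnp.array(1.0), "R0": jnp.array(10.0)}

# argnums given as a tuple: the result is a tuple with one entry per
# differentiated argument, and each entry mirrors that argument's pytree.
jac = jax.jacfwd(residual, argnums=(0,))(params)

# One set of params was passed, so the tuple has length 1.
d_residual_d_R0 = jac[0]["R0"]
```

With a single set of params the outer tuple has length 1, so indexing with `[0]` before selecting a parameter is expected rather than a bug.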

        # this returns a tuple of length 1 for some reason?
        A_ = con.jac_scaled(*xz)[0][con._subthings.index(subthing)]
    else:
        A_ = {
Collaborator Author

This is completely untested right now. The way I have it set up as of commit 237da24, FixParameter has in its self._subthings list every single object contained in the OptimizableCollection it has as its self.things[0]. So even though I am not fixing anything for the VF field in the example, it is still caught by the first logic branch above this line, and the else is not triggered.

@f0uriest
Member

f0uriest commented Feb 6, 2024

Just as a general comment for this type of stuff, the jax pytree utils are usually super helpful and a lot easier than rolling our own versions of similar stuff:
https://jax.readthedocs.io/en/latest/jax.tree_util.html

@dpanici
Collaborator Author

dpanici commented Feb 6, 2024

> Just as a general comment for this type of stuff, the jax pytree utils are usually super helpful and a lot easier than rolling our own versions of similar stuff: https://jax.readthedocs.io/en/latest/jax.tree_util.html

I will give this a look tomorrow, thanks


codecov bot commented Feb 6, 2024

Codecov Report

Attention: 17 lines in your changes are missing coverage. Please review.

Comparison is base (dc7d2eb) 94.92% compared to head (366fd43) 94.87%.
Report is 7 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #857      +/-   ##
==========================================
- Coverage   94.92%   94.87%   -0.05%     
==========================================
  Files          80       80              
  Lines       19619    19724     +105     
==========================================
+ Hits        18623    18713      +90     
- Misses        996     1011      +15     
Files Coverage Δ
desc/magnetic_fields.py 96.08% <100.00%> (+0.01%) ⬆️
desc/objectives/__init__.py 100.00% <ø> (ø)
desc/objectives/utils.py 97.14% <93.33%> (-0.76%) ⬇️
desc/optimizable.py 96.80% <80.00%> (-1.00%) ⬇️
desc/objectives/_geometry.py 97.72% <90.90%> (-0.71%) ⬇️
desc/objectives/linear_objectives.py 95.73% <84.41%> (-0.92%) ⬇️

... and 1 file with indirect coverage changes

@dpanici
Collaborator Author

dpanici commented Feb 6, 2024

  • make FixParameter only act on single objects -> neater that way
  • make a separate class to fix all currents of a given coilset, for instance (FixParameters, FixCollectionParameters...)
  • Fix all subthings that have the parameter we are asking to constrain
  • have it basically call FixParameter for each single object, and then stack all its A matrices like Combine near axis constraints into fewer number of Objectives #528
  • ShareParameter constraint?
  • One idea: it could have the OptimizableCollection as its thing, so that when it builds A it knows all the params of the subthings and their ordering. It could build the A matrix like we do now, and then inside factorize_linear_constraints, where we check subthing in things, we would also check thing in things, and that would create the right A matrix... but no, that is not for the right arg? Or maybe this does work; I need to think more
  • How to handle indexing?
  • How does optimizer know which subthings are actually being constrained?
  • things, is_leaf=lambda x: isinstance(x, Optimizable)
    change is_leaf to be and not isinstance(x, OptimizableCollection) in order to flatten to base Optimizable objects (maybe need to ensure that OptimizableCollection has no params to itself)
  • Rory wants to bypass factorize_linear_constraints: instead put all linear constraints into a single ObjectiveFunction and get the x vector from it (ensure it only has unique parts)
  • total A = jac(objective)
  • b = objective(zeros)
  • and add a combine_args call for the objective and the linear constraint ObjectiveFunction as well (move before the linear constraints get factorized in the projection)

So action items @ddudt

  • create an ObjectiveFunction inside of optimizer.optimize that takes in the linear constraints
  • call combine_args on it and on the objective before we do the linear constraint projection
  • the self._constraints argument of factorize_linear_constraints, at
    ) = factorize_linear_constraints(
    is now just an objective function; use its jac_scaled to get A_full at
    A_full = jnp.vstack([jnp.hstack(Ai) for Ai in A])
    and call obj.compute(zeros) to get b_full
  • keep the indexing stuff after that line above, and everything should work fine for optimizable optimization
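The action items above rely on the fact that a linear constraint is fully described by its Jacobian and its value at zero. A minimal sketch, assuming a hypothetical residual of the form f(x) = A x - b (the sign convention, `A_true`, `b_true` and `constraint` are all assumptions, not DESC code):

```python
import jax
import jax.numpy as jnp

# Hypothetical linear constraint residual f(x) = A_true @ x - b_true.
A_true = jnp.array([[1.0, 0.0, 0.0],
                    [0.0, 1.0, -1.0]])
b_true = jnp.array([10.0, 0.0])


def constraint(x):
    return A_true @ x - b_true


zeros = jnp.zeros(3)

# The Jacobian of a linear map is constant, so evaluating it at zero gives A.
A_full = jax.jacfwd(constraint)(zeros)

# With the assumed sign convention, f(0) = -b.
b_full = -constraint(zeros)
```

If the residual were defined as b - A x instead, the signs of both results would flip, so the convention has to match whatever the objective's compute method actually returns.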

for OptimizableCollection stuff

  • things, is_leaf=lambda x: isinstance(x, Optimizable)
    change is_leaf to be and not isinstance(x, OptimizableCollection) in order to flatten to base Optimizable objects (maybe need to ensure that OptimizableCollection has no params to itself)
  • Make FixParameter only be for a single object (so reverting what I have done in this PR)
  • make a FixCollectionParameter for fixing a collection of objects which share the same parameter (enforce that they are the same kind of subthing, i.e. all the same coil type in a coilset?)
  • change the things inside of optimizable.py for everything leading up to the split_idx, e.g. at
    key: jnp.asarray(getattr(self, key)).size for key in self.optimizable_params
    to avoid jax complaining about non-static indices
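The `is_leaf` change proposed in the list above can be sketched with stand-in classes. Both classes here are hypothetical minimal versions written for illustration; the real DESC classes are registered as pytrees differently and carry parameters.

```python
import jax


class Optimizable:
    """Hypothetical stand-in for a base optimizable object."""

    def __init__(self, name):
        self.name = name


@jax.tree_util.register_pytree_node_class
class OptimizableCollection(Optimizable):
    """Hypothetical stand-in: a pytree node whose children are its members."""

    def __init__(self, members):
        self.members = list(members)

    def tree_flatten(self):
        return tuple(self.members), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children)


things = [
    OptimizableCollection([Optimizable("tf"), Optimizable("vf")]),
    Optimizable("coil"),
]

# Current predicate: a collection counts as a leaf and is kept whole.
whole, _ = jax.tree_util.tree_flatten(
    things, is_leaf=lambda x: isinstance(x, Optimizable)
)

# Proposed predicate: collections are opened up, base objects are leaves.
base, _ = jax.tree_util.tree_flatten(
    things,
    is_leaf=lambda x: isinstance(x, Optimizable)
    and not isinstance(x, OptimizableCollection),
)
base_names = [t.name for t in base]
```

This only flattens cleanly if the collection holds no parameters of its own, which is the caveat already noted in the list.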

@dpanici
Collaborator Author

dpanici commented May 1, 2024

#956 takes care of this

@dpanici dpanici closed this May 1, 2024
Successfully merging this pull request may close these issues.

Make optimization framework work with OptimizableCollection
2 participants