From d8fe0bb242d043ef5b0139f2904cfc3ca6e706b3 Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Tue, 4 Jun 2024 16:14:37 -0400 Subject: [PATCH 01/12] Remove jit method of objective, directly compile methods --- desc/compute/utils.py | 2 + desc/objectives/linear_objectives.py | 8 +- desc/objectives/objective_funs.py | 315 +++++++++++++++------------ 3 files changed, 185 insertions(+), 140 deletions(-) diff --git a/desc/compute/utils.py b/desc/compute/utils.py index b65b9365a9..c1cc580671 100644 --- a/desc/compute/utils.py +++ b/desc/compute/utils.py @@ -344,6 +344,8 @@ def get_transforms(keys, obj, grid, jitable=False, **kwargs): basis = getattr(obj, c + "_basis") # first check if we already have a transform with a compatible basis for transform in transforms.values(): + if jitable: # re-using transforms doesn't work under jit, so skip + continue if basis.equiv(getattr(transform, "basis", None)): ders = np.unique( np.vstack([derivs[c], transform.derivatives]), axis=0 diff --git a/desc/objectives/linear_objectives.py b/desc/objectives/linear_objectives.py index b20317e3aa..ea91b4bcb6 100644 --- a/desc/objectives/linear_objectives.py +++ b/desc/objectives/linear_objectives.py @@ -38,8 +38,8 @@ def update_target(self, thing): assert len(new_target) == len(self.target) self.target = new_target self._target_from_user = self.target # in case the Objective is re-built - if self._use_jit: - self.jit() + # if self._use_jit: + # self.jit() def _parse_target_from_user( self, target_from_user, default_target, default_bounds, idx @@ -231,8 +231,8 @@ def update_target(self, thing): """ self.target = self.compute(thing.params_dict) - if self._use_jit: - self.jit() + # if self._use_jit: + # self.jit() class BoundaryRSelfConsistency(_Objective): diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index 0a98bb0a28..6211a7cc18 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -1,7 +1,6 @@ """Base classes for objectives.""" from abc import ABC, abstractmethod -from functools import partial import numpy as np @@ -63,90 +62,47 @@ def _set_derivatives(self): self._deriv_mode = "batched" else: self._deriv_mode = "blocked" - if self._deriv_mode in {"batched", "looped", "blocked"}: - self._grad = Derivative(self.compute_scalar, mode="grad") - self._hess = Derivative(self.compute_scalar, mode="hess") - if self._deriv_mode == "batched": - self._jac_scaled = Derivative(self.compute_scaled, mode="fwd") - self._jac_scaled_error = Derivative(self.compute_scaled_error, mode="fwd") - self._jac_unscaled = Derivative(self.compute_unscaled, mode="fwd") - if self._deriv_mode == "looped": - self._jac_scaled = Derivative(self.compute_scaled, mode="looped") - self._jac_scaled_error = Derivative( - self.compute_scaled_error, mode="looped" - ) - self._jac_unscaled = Derivative(self.compute_unscaled, mode="looped") - if self._deriv_mode == "blocked": - # could also do something similar for grad and hess, but probably not - # worth it. grad is already super cheap to eval all at once, and blocked - # hess would only be block diag which may miss important interactions. 
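Aside on the "blocked hess would only be block diag" comment above: each sub-objective only ever sees its own things, so second-derivative cross-terms between things would never be computed. A minimal JAX sketch with a hypothetical coupled objective (f, x1, x2 are illustrative, not DESC API) makes the dropped interactions concrete:

    import jax
    import jax.numpy as jnp

    # hypothetical scalar objective coupling two "things"
    def f(x1, x2):
        return jnp.sum((x1 * x2) ** 2)

    g = lambda x: f(x[:2], x[2:])  # same objective on the concatenated state
    H = jax.hessian(g)(jnp.array([1.0, 2.0, 3.0, 4.0]))
    # the off-diagonal blocks H[:2, 2:] are nonzero: a blocked Hessian
    # (one block per thing) would zero out exactly these cross terms
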
- - def jac_(op, x, constants=None): - if constants is None: - constants = self.constants - xs_splits = np.cumsum([t.dim_x for t in self.things]) - xs = jnp.split(x, xs_splits) - J = [] - for obj, const in zip(self.objectives, constants): - # get the xs that go to that objective - xi = [x for x, t in zip(xs, self.things) if t in obj.things] - Ji_ = getattr(obj, op)( - *xi, constants=const - ) # jac wrt to just those things - Ji = [] # jac wrt all things - for thing in self.things: - if thing in obj.things: - i = obj.things.index(thing) - Ji += [Ji_[i]] - else: - Ji += [jnp.zeros((obj.dim_f, thing.dim_x))] - Ji = jnp.hstack(Ji) - J += [Ji] - return jnp.vstack(J) - - self._jac_scaled = partial(jac_, "jac_scaled") - self._jac_scaled_error = partial(jac_, "jac_scaled_error") - self._jac_unscaled = partial(jac_, "jac_unscaled") - - def jit(self): # noqa: C901 - """Apply JIT to compute methods, or re-apply after updating self.""" - # can't loop here because del doesn't work on getattr - # main idea is that when jitting a method, jax replaces that method - # with a CompiledFunction object, with self compiled in. To re-jit - # (ie, after updating attributes of self), we just need to delete the jax - # CompiledFunction object, which will then leave the raw method in its place, - # and then jit the raw method with the new self - - self._use_jit = True - - methods = [ - "compute_scaled", - "compute_scaled_error", - "compute_unscaled", - "compute_scalar", - "jac_scaled", - "jac_scaled_error", - "jac_unscaled", - "hess", - "grad", - "jvp_scaled", - "jvp_scaled_error", - "jvp_unscaled", - "vjp_scaled", - "vjp_scaled_error", - "vjp_unscaled", - ] - - for method in methods: - try: - delattr(self, method) - except AttributeError: - pass - setattr(self, method, jit(getattr(self, method))) - - for obj in self._objectives: - if obj._use_jit: - obj.jit() + + # TODO: figure out how to not use jit if user requests + # def jit(self): # noqa: C901 + # """Apply JIT to compute methods, or re-apply after updating self.""" + # # can't loop here because del doesn't work on getattr + # # main idea is that when jitting a method, jax replaces that method + # # with a CompiledFunction object, with self compiled in. To re-jit + # # (ie, after updating attributes of self), we just need to delete the jax + # # CompiledFunction object, which will then leave the raw method in its place, + # # and then jit the raw method with the new self + + # self._use_jit = True + + # methods = [ + # "compute_scaled", + # "compute_scaled_error", + # "compute_unscaled", + # "compute_scalar", + # "jac_scaled", + # "jac_scaled_error", + # "jac_unscaled", + # "hess", + # "grad", + # "jvp_scaled", + # "jvp_scaled_error", + # "jvp_unscaled", + # "vjp_scaled", + # "vjp_scaled_error", + # "vjp_unscaled", + # ] + + # for method in methods: + # try: + # delattr(self, method) + # except AttributeError: + # pass + # setattr(self, method, jit(getattr(self, method))) + + # for obj in self._objectives: + # if obj._use_jit: + # obj.jit() def build(self, use_jit=None, verbose=1): """Build the objective. 
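The jit() method removed above leaned on how Python resolves attributes: jitting a bound method stores a compiled wrapper as an instance attribute with the current self captured, and deleting that attribute re-exposes the raw class method so it can be jitted again. A minimal sketch of the pattern (hypothetical Model class, not DESC code):

    import jax

    class Model:
        def __init__(self, scale):
            self.scale = scale

        def compute(self, x):
            return self.scale * x

    m = Model(2.0)
    m.compute = jax.jit(m.compute)  # instance attribute shadows the class method
    m.compute(1.0)                  # compiles with scale=2.0 captured
    m.scale = 3.0                   # no effect on the already-compiled version
    del m.compute                   # removes the shadow; Model.compute is visible again
    m.compute = jax.jit(m.compute)  # re-jit against the updated self

This is exactly the fragility the patch removes: any update to self silently required the delete-and-rejit dance.
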
@@ -178,11 +134,19 @@ def build(self, use_jit=None, verbose=1): self._scalar = False self._set_derivatives() - if self.use_jit: - self.jit() + # if self.use_jit: + # self.jit() self._set_things() + # this is needed to know which "thing" goes with which sub-objective, + # ie objectives[i].things == [things[k] for k in things_per_objective_idx[i]] + self._things_per_objective_idx = [] + for obj in self.objectives: + self._things_per_objective_idx.append( + [i for i, t in enumerate(self.things) if t in obj.things] + ) + self._built = True timer.stop("Objective build") if verbose > 1: @@ -235,6 +199,7 @@ def flatten(things): self._unflatten = unflatten self._flatten = flatten + @jit def compute_unscaled(self, x, constants=None): """Compute the raw value of the objective function. @@ -262,6 +227,7 @@ def compute_unscaled(self, x, constants=None): ) return f + @jit def compute_scaled(self, x, constants=None): """Compute the objective function and apply weighting and normalization. @@ -289,6 +255,7 @@ def compute_scaled(self, x, constants=None): ) return f + @jit def compute_scaled_error(self, x, constants=None): """Compute and apply the target/bounds, weighting, and normalization. @@ -316,6 +283,7 @@ def compute_scaled_error(self, x, constants=None): ) return f + @jit def compute_scalar(self, x, constants=None): """Compute the sum of squares error. @@ -395,8 +363,8 @@ def unpack_state(self, x, per_objective=True): params = self._unflatten(params) # this filters out the params of things that are unused by each objective params = [ - [par for par, thing in zip(param, self.things) if thing in obj.things] - for param, obj in zip(params, self.objectives) + [param[i] for i in idx] + for param, idx in zip(params, self._things_per_objective_idx) ] return params @@ -408,35 +376,94 @@ def x(self, *things): xs = [t.pack_params(t.params_dict) for t in things] return jnp.concatenate(xs) + @jit def grad(self, x, constants=None): """Compute gradient vector of self.compute_scalar wrt x.""" if constants is None: constants = self.constants - return jnp.atleast_1d(self._grad(x, constants).squeeze()) + return jnp.atleast_1d( + Derivative(self.compute_scalar, mode="grad")(x, constants).squeeze() + ) + @jit def hess(self, x, constants=None): """Compute Hessian matrix of self.compute_scalar wrt x.""" if constants is None: constants = self.constants - return jnp.atleast_2d(self._hess(x, constants).squeeze()) + return jnp.atleast_2d( + Derivative(self.compute_scalar, mode="hess")(x, constants).squeeze() + ) + + def _jac(self, op, x, constants=None): + # could also do something similar for grad and hess, but probably not + # worth it. grad is already super cheap to eval all at once, and blocked + # hess would only be block diag which may miss important interactions. 
+ if constants is None: + constants = self.constants + xs_splits = np.cumsum([t.dim_x for t in self.things]) + xs = jnp.split(x, xs_splits) + J = [] + for k, (obj, const) in enumerate(zip(self.objectives, constants)): + # get the xs that go to that objective + xi = [xs[i] for i in self._things_per_objective_idx[k]] + Ji_ = getattr(obj, op)(*xi, constants=const) # jac wrt to just those things + Ji = [] # jac wrt all things + for i, (thing, idx) in enumerate( + zip(self.things, self._things_per_objective_idx) + ): + if i in idx: + Ji += [Ji_[idx.index(i)]] + else: + Ji += [jnp.zeros((obj.dim_f, thing.dim_x))] + Ji = jnp.hstack(Ji) + J += [Ji] + return jnp.vstack(J) + + @jit def jac_scaled(self, x, constants=None): """Compute Jacobian matrix of self.compute_scaled wrt x.""" if constants is None: constants = self.constants - return jnp.atleast_2d(self._jac_scaled(x, constants).squeeze()) + if self._deriv_mode == "batched": + J = Derivative(self.compute_scaled, mode="fwd")(x, constants) + if self._deriv_mode == "looped": + J = Derivative(self.compute_scaled, mode="looped")(x, constants) + if self._deriv_mode == "blocked": + J = self._jac("jac_scaled", x, constants) + + return jnp.atleast_2d(J.squeeze()) + + @jit def jac_scaled_error(self, x, constants=None): """Compute Jacobian matrix of self.compute_scaled_error wrt x.""" if constants is None: constants = self.constants - return jnp.atleast_2d(self._jac_scaled_error(x, constants).squeeze()) + if self._deriv_mode == "batched": + J = Derivative(self.compute_scaled_error, mode="fwd")(x, constants) + if self._deriv_mode == "looped": + J = Derivative(self.compute_scaled_error, mode="looped")(x, constants) + if self._deriv_mode == "blocked": + J = self._jac("jac_scaled_error", x, constants) + + return jnp.atleast_2d(J.squeeze()) + + @jit def jac_unscaled(self, x, constants=None): """Compute Jacobian matrix of self.compute_unscaled wrt x.""" if constants is None: constants = self.constants - return jnp.atleast_2d(self._jac_unscaled(x, constants).squeeze()) + + if self._deriv_mode == "batched": + J = Derivative(self.compute_unscaled, mode="fwd")(x, constants) + if self._deriv_mode == "looped": + J = Derivative(self.compute_unscaled, mode="looped")(x, constants) + if self._deriv_mode == "blocked": + J = self._jac("jac_unscaled", x, constants) + + return jnp.atleast_2d(J.squeeze()) def _jvp(self, v, x, constants=None, op="compute_scaled"): v = v if isinstance(v, (tuple, list)) else (v,) @@ -456,6 +483,7 @@ def _jvp(self, v, x, constants=None, op="compute_scaled"): else: raise NotImplementedError("Cannot compute JVP higher than 3rd order.") + @jit def jvp_scaled(self, v, x, constants=None): """Compute Jacobian-vector product of self.compute_scaled. @@ -472,6 +500,7 @@ def jvp_scaled(self, v, x, constants=None): """ return self._jvp(v, x, constants, "compute_scaled") + @jit def jvp_scaled_error(self, v, x, constants=None): """Compute Jacobian-vector product of self.compute_scaled_error. @@ -488,6 +517,7 @@ def jvp_scaled_error(self, v, x, constants=None): """ return self._jvp(v, x, constants, "compute_scaled_error") + @jit def jvp_unscaled(self, v, x, constants=None): """Compute Jacobian-vector product of self.compute_unscaled. @@ -508,6 +538,7 @@ def _vjp(self, v, x, constants=None, op="compute_scaled"): fun = lambda x: getattr(self, op)(x, constants) return Derivative.compute_vjp(fun, 0, v, x) + @jit def vjp_scaled(self, v, x, constants=None): """Compute vector-Jacobian product of self.compute_scaled. 
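The blocked Jacobian assembled in _jac above is a block matrix with one row of blocks per sub-objective, and zero blocks wherever an objective does not depend on a thing. A standalone sketch of the assembly (dimensions are illustrative):

    import jax.numpy as jnp

    # two objectives over two "things" with dim_x = 3 and 2
    J1_x1 = jnp.ones((4, 3))   # objective 1 depends only on thing 1
    J2_x2 = jnp.ones((5, 2))   # objective 2 depends only on thing 2
    row1 = jnp.hstack([J1_x1, jnp.zeros((4, 2))])  # [df1/dx1, 0]
    row2 = jnp.hstack([jnp.zeros((5, 3)), J2_x2])  # [0, df2/dx2]
    J = jnp.vstack([row1, row2])                   # full (9, 5) Jacobian
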
@@ -523,6 +554,7 @@ def vjp_scaled(self, v, x, constants=None): """ return self._vjp(v, x, constants, "compute_scaled") + @jit def vjp_scaled_error(self, v, x, constants=None): """Compute vector-Jacobian product of self.compute_scaled_error. @@ -538,6 +570,7 @@ def vjp_scaled_error(self, v, x, constants=None): """ return self._vjp(v, x, constants, "compute_scaled_error") + @jit def vjp_unscaled(self, v, x, constants=None): """Compute vector-Jacobian product of self.compute_unscaled. @@ -815,10 +848,6 @@ def __init__( def _set_derivatives(self): """Set up derivatives of the objective wrt each argument.""" - argnums = tuple(range(len(self.things))) - # derivatives return tuple, one for each thing - self._grad = Derivative(self.compute_scalar, argnums, mode="grad") - self._hess = Derivative(self.compute_scalar, argnums, mode="hess") if self._deriv_mode == "auto": # choose based on shape of jacobian. fwd mode is more memory efficient # so we prefer that unless the jacobian is really wide @@ -827,38 +856,29 @@ def _set_derivatives(self): if self.dim_f >= 0.5 * sum(t.dim_x for t in self.things) else "rev" ) - self._jac_scaled = Derivative( - self.compute_scaled, argnums, mode=self._deriv_mode - ) - self._jac_scaled_error = Derivative( - self.compute_scaled_error, argnums, mode=self._deriv_mode - ) - self._jac_unscaled = Derivative( - self.compute_unscaled, argnums, mode=self._deriv_mode - ) - def jit(self): # noqa: C901 - """Apply JIT to compute methods, or re-apply after updating self.""" - self._use_jit = True - - methods = [ - "compute_scaled", - "compute_scaled_error", - "compute_unscaled", - "compute_scalar", - "jac_scaled", - "jac_scaled_error", - "jac_unscaled", - "hess", - "grad", - ] - - for method in methods: - try: - delattr(self, method) - except AttributeError: - pass - setattr(self, method, jit(getattr(self, method))) + # def jit(self): # noqa: C901 + # """Apply JIT to compute methods, or re-apply after updating self.""" + # self._use_jit = True + + # methods = [ + # "compute_scaled", + # "compute_scaled_error", + # "compute_unscaled", + # "compute_scalar", + # "jac_scaled", + # "jac_scaled_error", + # "jac_unscaled", + # "hess", + # "grad", + # ] + + # for method in methods: + # try: + # delattr(self, method) + # except AttributeError: + # pass + # setattr(self, method, jit(getattr(self, method))) def _check_dimensions(self): """Check that len(target) = len(bounds) = len(weight) = dim_f.""" @@ -912,8 +932,8 @@ def build(self, use_jit=True, verbose=1): if use_jit is not None: self._use_jit = use_jit - if self._use_jit: - self.jit() + # if self._use_jit: + # self.jit() self._built = True @@ -930,6 +950,7 @@ def _maybe_array_to_params(self, *args): argsout += (arg,) return argsout + @jit def compute_unscaled(self, *args, **kwargs): """Compute the raw value of the objective.""" args = self._maybe_array_to_params(*args) @@ -938,6 +959,7 @@ def compute_unscaled(self, *args, **kwargs): f = self._loss_function(f) return jnp.atleast_1d(f) + @jit def compute_scaled(self, *args, **kwargs): """Compute and apply weighting and normalization.""" args = self._maybe_array_to_params(*args) @@ -946,6 +968,7 @@ def compute_scaled(self, *args, **kwargs): f = self._loss_function(f) return jnp.atleast_1d(self._scale(f, **kwargs)) + @jit def compute_scaled_error(self, *args, **kwargs): """Compute and apply the target/bounds, weighting, and normalization.""" args = self._maybe_array_to_params(*args) @@ -988,6 +1011,7 @@ def _scale(self, f, *args, **kwargs): f_norm = jnp.atleast_1d(f) / self.normalization 
# normalization return f_norm * w * self.weight + @jit def compute_scalar(self, *args, **kwargs): """Compute the scalar form of the objective.""" if self.scalar: @@ -996,25 +1020,41 @@ def compute_scalar(self, *args, **kwargs): f = jnp.sum(self.compute_scaled_error(*args, **kwargs) ** 2) / 2 return f.squeeze() + @jit def grad(self, *args, **kwargs): """Compute gradient vector of self.compute_scalar wrt x.""" - return self._grad(*args, **kwargs) + argnums = tuple(range(len(self.things))) + return Derivative(self.compute_scalar, argnums, mode="grad")(*args, **kwargs) + @jit def hess(self, *args, **kwargs): """Compute Hessian matrix of self.compute_scalar wrt x.""" - return self._hess(*args, **kwargs) + argnums = tuple(range(len(self.things))) + return Derivative(self.compute_scalar, argnums, mode="hess")(*args, **kwargs) + @jit def jac_scaled(self, *args, **kwargs): """Compute Jacobian matrix of self.compute_scaled wrt x.""" - return self._jac_scaled(*args, **kwargs) + argnums = tuple(range(len(self.things))) + return Derivative(self.compute_scaled, argnums, mode=self._deriv_mode)( + *args, **kwargs + ) + @jit def jac_scaled_error(self, *args, **kwargs): """Compute Jacobian matrix of self.compute_scaled_error wrt x.""" - return self._jac_scaled_error(*args, **kwargs) + argnums = tuple(range(len(self.things))) + return Derivative(self.compute_scaled_error, argnums, mode=self._deriv_mode)( + *args, **kwargs + ) + @jit def jac_unscaled(self, *args, **kwargs): """Compute Jacobian matrix of self.compute_unscaled wrt x.""" - return self._jac_unscaled(*args, **kwargs) + argnums = tuple(range(len(self.things))) + return Derivative(self.compute_unscaled, argnums, mode=self._deriv_mode)( + *args, **kwargs + ) def _jvp(self, v, x, constants=None, op="compute_scaled"): v = v if isinstance(v, (tuple, list)) else (v,) @@ -1026,6 +1066,7 @@ def _jvp(self, v, x, constants=None, op="compute_scaled"): sig = ",".join(f"(n{i})" for i in range(len(x))) + "->(k)" return jnp.vectorize(jvpfun, signature=sig)(*v) + @jit def jvp_scaled(self, v, x, constants=None): """Compute Jacobian-vector product of self.compute_scaled. @@ -1041,6 +1082,7 @@ def jvp_scaled(self, v, x, constants=None): """ return self._jvp(v, x, constants, "compute_scaled") + @jit def jvp_scaled_error(self, v, x, constants=None): """Compute Jacobian-vector product of self.compute_scaled_error. @@ -1056,6 +1098,7 @@ def jvp_scaled_error(self, v, x, constants=None): """ return self._jvp(v, x, constants, "compute_scaled_error") + @jit def jvp_unscaled(self, v, x, constants=None): """Compute Jacobian-vector product of self.compute_unscaled. 
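Net effect of patch 01: the compute and derivative methods are jitted once at class definition time, so self flows through tracing on every call instead of being compiled in. For that to work the instance must be a valid JAX type; a minimal sketch of the underlying idea with an explicitly pytree-registered class (hypothetical Residual class; DESC does the registration through its own IOAble/Optimizable machinery):

    import jax
    import jax.numpy as jnp

    @jax.tree_util.register_pytree_node_class
    class Residual:
        def __init__(self, target):
            self.target = target

        def tree_flatten(self):
            return (self.target,), None

        @classmethod
        def tree_unflatten(cls, aux, children):
            return cls(*children)

        @jax.jit  # self is traced like any other argument
        def compute(self, x):
            return x - self.target

    r = Residual(jnp.ones(3))
    r.compute(jnp.zeros(3))      # compiles once
    r.target = 2 * jnp.ones(3)
    r.compute(jnp.zeros(3))      # reuses the compiled code with the new leaf

Updating an attribute that is a traced leaf no longer requires re-jitting, which is what makes the deleted jit()/re-jit machinery unnecessary.
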
From f99c3d182d68cdcf3bb00e5fef3adb3c4f39358a Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Tue, 4 Jun 2024 17:52:34 -0400 Subject: [PATCH 02/12] Fix incorrect indexing --- desc/objectives/objective_funs.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index 6211a7cc18..dedb019687 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -406,14 +406,13 @@ def _jac(self, op, x, constants=None): J = [] for k, (obj, const) in enumerate(zip(self.objectives, constants)): # get the xs that go to that objective - xi = [xs[i] for i in self._things_per_objective_idx[k]] + thing_idx = self._things_per_objective_idx[k] + xi = [xs[i] for i in thing_idx] Ji_ = getattr(obj, op)(*xi, constants=const) # jac wrt to just those things Ji = [] # jac wrt all things - for i, (thing, idx) in enumerate( - zip(self.things, self._things_per_objective_idx) - ): - if i in idx: - Ji += [Ji_[idx.index(i)]] + for i, thing in enumerate(self.things): + if i in thing_idx: + Ji += [Ji_[thing_idx.index(i)]] else: Ji += [jnp.zeros((obj.dim_f, thing.dim_x))] Ji = jnp.hstack(Ji) From e37eb35c8a71c5de089888f6d440fa8ddf0fecaf Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Tue, 25 Jun 2024 19:36:46 -0400 Subject: [PATCH 03/12] Add unjit method --- desc/objectives/linear_objectives.py | 8 +- desc/objectives/objective_funs.py | 120 ++++++++++++--------------- 2 files changed, 58 insertions(+), 70 deletions(-) diff --git a/desc/objectives/linear_objectives.py b/desc/objectives/linear_objectives.py index ea91b4bcb6..50479542cd 100644 --- a/desc/objectives/linear_objectives.py +++ b/desc/objectives/linear_objectives.py @@ -38,8 +38,8 @@ def update_target(self, thing): assert len(new_target) == len(self.target) self.target = new_target self._target_from_user = self.target # in case the Objective is re-built - # if self._use_jit: - # self.jit() + if not self._use_jit: + self._unjit() def _parse_target_from_user( self, target_from_user, default_target, default_bounds, idx @@ -231,8 +231,8 @@ def update_target(self, thing): """ self.target = self.compute(thing.params_dict) - # if self._use_jit: - # self.jit() + if not self._use_jit: + self._unjit() class BoundaryRSelfConsistency(_Objective): diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index 930b8e000d..68e71731dc 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -1,5 +1,6 @@ """Base classes for objectives.""" +import functools from abc import ABC, abstractmethod import numpy as np @@ -63,46 +64,32 @@ def _set_derivatives(self): else: self._deriv_mode = "blocked" - # TODO: figure out how to not use jit if user requests - # def jit(self): # noqa: C901 - # """Apply JIT to compute methods, or re-apply after updating self.""" - # # can't loop here because del doesn't work on getattr - # # main idea is that when jitting a method, jax replaces that method - # # with a CompiledFunction object, with self compiled in. 
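Patch 03's _unjit (defined just below) swaps each jitted method for a functools.partial of the underlying raw function bound to self. A minimal sketch of the unwrapping idea, assuming the jit wrapper keeps a reference to the original function (the ._fun attribute used in the patch belongs to desc.backend's jit wrapper; plain jax.jit exposes the original as __wrapped__):

    import functools
    import jax

    class A:
        @jax.jit
        def f(self, x):
            return 2 * x

    a = A()
    raw = A.f.__wrapped__            # the undecorated f(self, x)
    a.f = functools.partial(raw, a)  # instance attribute shadows the jitted method
    a.f(1.0)                         # now runs uncompiled
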
To re-jit - # # (ie, after updating attributes of self), we just need to delete the jax - # # CompiledFunction object, which will then leave the raw method in its place, - # # and then jit the raw method with the new self - - # self._use_jit = True - - # methods = [ - # "compute_scaled", - # "compute_scaled_error", - # "compute_unscaled", - # "compute_scalar", - # "jac_scaled", - # "jac_scaled_error", - # "jac_unscaled", - # "hess", - # "grad", - # "jvp_scaled", - # "jvp_scaled_error", - # "jvp_unscaled", - # "vjp_scaled", - # "vjp_scaled_error", - # "vjp_unscaled", - # ] - - # for method in methods: - # try: - # delattr(self, method) - # except AttributeError: - # pass - # setattr(self, method, jit(getattr(self, method))) - - # for obj in self._objectives: - # if obj._use_jit: - # obj.jit() + def _unjit(self): + """Remove jit compiled methods.""" + methods = [ + "compute_scaled", + "compute_scaled_error", + "compute_unscaled", + "compute_scalar", + "jac_scaled", + "jac_scaled_error", + "jac_unscaled", + "hess", + "grad", + "jvp_scaled", + "jvp_scaled_error", + "jvp_unscaled", + "vjp_scaled", + "vjp_scaled_error", + "vjp_unscaled", + ] + for method in methods: + try: + setattr( + self, method, functools.partial(getattr(self, method)._fun, self) + ) + except AttributeError: + pass def build(self, use_jit=None, verbose=1): """Build the objective. @@ -134,8 +121,8 @@ def build(self, use_jit=None, verbose=1): self._scalar = False self._set_derivatives() - # if self.use_jit: - # self.jit() + if not self.use_jit: + self._unjit() self._set_things() @@ -856,28 +843,29 @@ def _set_derivatives(self): else "rev" ) - # def jit(self): # noqa: C901 - # """Apply JIT to compute methods, or re-apply after updating self.""" - # self._use_jit = True - - # methods = [ - # "compute_scaled", - # "compute_scaled_error", - # "compute_unscaled", - # "compute_scalar", - # "jac_scaled", - # "jac_scaled_error", - # "jac_unscaled", - # "hess", - # "grad", - # ] - - # for method in methods: - # try: - # delattr(self, method) - # except AttributeError: - # pass - # setattr(self, method, jit(getattr(self, method))) + def _unjit(self): + """Remove jit compiled methods.""" + methods = [ + "compute_scaled", + "compute_scaled_error", + "compute_unscaled", + "compute_scalar", + "jac_scaled", + "jac_scaled_error", + "jac_unscaled", + "jvp_scaled", + "jvp_scaled_error", + "jvp_unscaled", + "hess", + "grad", + ] + for method in methods: + try: + setattr( + self, method, functools.partial(getattr(self, method)._fun, self) + ) + except AttributeError: + pass def _check_dimensions(self): """Check that len(target) = len(bounds) = len(weight) = dim_f.""" @@ -931,8 +919,8 @@ def build(self, use_jit=True, verbose=1): if use_jit is not None: self._use_jit = use_jit - # if self._use_jit: - # self.jit() + if not self._use_jit: + self._unjit() self._built = True From 8980654899a7c8a5907d6aab9f90442899a75cbb Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Tue, 25 Jun 2024 19:37:54 -0400 Subject: [PATCH 04/12] Add some asserts to catch weird edge cases --- desc/objectives/objective_funs.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index 68e71731dc..b8bcd00a56 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -126,14 +126,6 @@ def build(self, use_jit=None, verbose=1): self._set_things() - # this is needed to know which "thing" goes with which sub-objective, - # ie 
objectives[i].things == [things[k] for k in things_per_objective_idx[i]] - self._things_per_objective_idx = [] - for obj in self.objectives: - self._things_per_objective_idx.append( - [i for i, t in enumerate(self.things) if t in obj.things] - ) - self._built = True timer.stop("Objective build") if verbose > 1: @@ -169,6 +161,14 @@ def _set_things(self, things=None): ) unique_, inds_ = unique_list(flat_) + # this is needed to know which "thing" goes with which sub-objective, + # ie objectives[i].things == [things[k] for k in things_per_objective_idx[i]] + self._things_per_objective_idx = [] + for obj in self.objectives: + self._things_per_objective_idx.append( + [i for i, t in enumerate(unique_) if t in obj.things] + ) + def unflatten(unique): assert len(unique) == len(unique_) flat = [unique[i] for i in inds_] @@ -206,6 +206,7 @@ def compute_unscaled(self, x, constants=None): params = self.unpack_state(x) if constants is None: constants = self.constants + assert len(params) == len(constants) == len(self.objectives) f = jnp.concatenate( [ obj.compute_unscaled(*par, constants=const) @@ -234,6 +235,7 @@ def compute_scaled(self, x, constants=None): params = self.unpack_state(x) if constants is None: constants = self.constants + assert len(params) == len(constants) == len(self.objectives) f = jnp.concatenate( [ obj.compute_scaled(*par, constants=const) @@ -262,6 +264,7 @@ def compute_scaled_error(self, x, constants=None): params = self.unpack_state(x) if constants is None: constants = self.constants + assert len(params) == len(constants) == len(self.objectives) f = jnp.concatenate( [ obj.compute_scaled_error(*par, constants=const) @@ -309,6 +312,7 @@ def print_value(self, x, constants=None): f = jnp.sum(self.compute_scaled_error(x, constants=constants) ** 2) / 2 print("Total (sum of squares): {:10.3e}, ".format(f)) params = self.unpack_state(x) + assert len(params) == len(constants) == len(self.objectives) for par, obj, const in zip(params, self.objectives, constants): obj.print_value(*par, constants=const) return None @@ -344,11 +348,14 @@ def unpack_state(self, x, per_objective=True): xs_splits = np.cumsum([t.dim_x for t in self.things]) xs = jnp.split(x, xs_splits) + xs = xs[: len(self.things)] # jnp.split returns an empty array at the end + assert len(xs) == len(self.things) params = [t.unpack_params(xi) for t, xi in zip(self.things, xs)] if per_objective: # params is a list of lists of dicts, for each thing and for each objective params = self._unflatten(params) # this filters out the params of things that are unused by each objective + assert len(params) == len(self._things_per_objective_idx) params = [ [param[i] for i in idx] for param, idx in zip(params, self._things_per_objective_idx) @@ -359,6 +366,7 @@ def x(self, *things): """Return the full state vector from the Optimizable objects things.""" # TODO: also check resolution etc? 
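The trimming added to unpack_state above (xs = xs[: len(self.things)]) works around a quirk of split: the split points come from np.cumsum over all the sizes, so the last point lands at the very end of the array and yields a trailing empty chunk. A quick standalone check (dims are illustrative):

    import numpy as np
    import jax.numpy as jnp

    dims = [3, 4]                        # dim_x of each thing
    x = jnp.arange(sum(dims))
    xs = jnp.split(x, np.cumsum(dims))   # split points at [3, 7]
    [xi.size for xi in xs]               # [3, 4, 0] -- trailing empty array
    xs = xs[: len(dims)]                 # drop it, as unpack_state now does
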
things = things or self.things + assert len(things) == len(self.things) assert all([type(t1) is type(t2) for t1, t2 in zip(things, self.things)]) xs = [t.pack_params(t.params_dict) for t in things] return jnp.concatenate(xs) @@ -391,6 +399,7 @@ def _jac(self, op, x, constants=None): xs_splits = np.cumsum([t.dim_x for t in self.things]) xs = jnp.split(x, xs_splits) J = [] + assert len(self.objectives) == len(self.constants) for k, (obj, const) in enumerate(zip(self.objectives, constants)): # get the xs that go to that objective thing_idx = self._things_per_objective_idx[k] @@ -930,6 +939,7 @@ def compute(self, *args, **kwargs): def _maybe_array_to_params(self, *args): argsout = tuple() + assert len(args) == len(self.things) for arg, thing in zip(args, self.things): if isinstance(arg, (np.ndarray, jnp.ndarray)): argsout += (thing.unpack_params(arg),) @@ -1270,6 +1280,7 @@ def things(self, new): if not isinstance(new, (tuple, list)): new = [new] assert all(isinstance(x, Optimizable) for x in new) + assert len(new) == len(self.things) assert all(type(a) is type(b) for a, b in zip(new, self.things)) self._things = list(new) # can maybe improve this later to not rebuild if resolution is the same From cfcb3937c1dfa3be4424fa31fe48a90b8dce39e7 Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Mon, 12 Aug 2024 23:23:26 -0400 Subject: [PATCH 05/12] Fix hashing of partial objects --- desc/io/optimizable_io.py | 5 ++++- desc/optimize/_constraint_wrappers.py | 22 +++++++++++++++------- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/desc/io/optimizable_io.py b/desc/io/optimizable_io.py index 5a11d39245..554cdac070 100644 --- a/desc/io/optimizable_io.py +++ b/desc/io/optimizable_io.py @@ -1,6 +1,7 @@ """Functions and methods for saving and loading equilibria and other objects.""" import copy +import functools import os import pickle import pydoc @@ -86,7 +87,9 @@ def _unjittable(x): return any([_unjittable(y) for y in x.values()]) if hasattr(x, "dtype") and np.ndim(x) == 0: return np.issubdtype(x.dtype, np.bool_) or np.issubdtype(x.dtype, np.int_) - return isinstance(x, (str, types.FunctionType, bool, int, np.int_)) + return isinstance( + x, (str, types.FunctionType, functools.partial, bool, int, np.int_) + ) def _make_hashable(x): diff --git a/desc/optimize/_constraint_wrappers.py b/desc/optimize/_constraint_wrappers.py index 93bf5385ae..31384f4987 100644 --- a/desc/optimize/_constraint_wrappers.py +++ b/desc/optimize/_constraint_wrappers.py @@ -1,7 +1,5 @@ """Wrappers for doing STELLOPT/SIMSOPT like optimization.""" -import functools - import numpy as np from desc.backend import jit, jnp @@ -972,7 +970,7 @@ def jvp_scaled(self, v, x, constants=None): v = v[0] if isinstance(v, (tuple, list)) else v constants = setdefault(constants, self.constants) xg, xf = self._update_equilibrium(x, store=True) - jvpfun = lambda u: self._jvp(u, xf, xg, constants, op="scaled") + jvpfun = lambda u: self._jvp_scaled(u, xf, xg, constants) return jnp.vectorize(jvpfun, signature="(n)->(k)")(v) def jvp_scaled_error(self, v, x, constants=None): @@ -992,7 +990,7 @@ def jvp_scaled_error(self, v, x, constants=None): v = v[0] if isinstance(v, (tuple, list)) else v constants = setdefault(constants, self.constants) xg, xf = self._update_equilibrium(x, store=True) - jvpfun = lambda u: self._jvp(u, xf, xg, constants, op="scaled_error") + jvpfun = lambda u: self._jvp_scaled_error(u, xf, xg, constants) return jnp.vectorize(jvpfun, signature="(n)->(k)")(v) def jvp_unscaled(self, v, x, constants=None): @@ -1012,10 +1010,9 @@ 
def jvp_unscaled(self, v, x, constants=None): v = v[0] if isinstance(v, (tuple, list)) else v constants = setdefault(constants, self.constants) xg, xf = self._update_equilibrium(x, store=True) - jvpfun = lambda u: self._jvp(u, xf, xg, constants, op="unscaled") + jvpfun = lambda u: self._jvp_unscaled(u, xf, xg, constants) return jnp.vectorize(jvpfun, signature="(n)->(k)")(v) - @functools.partial(jit, static_argnames=("self", "op")) def _jvp_f(self, xf, dc, constants, op): Fx = getattr(self._constraint, "jac_" + op)(xf, constants) Fx_reduced = Fx[:, self._unfixed_idx] @ self._Z @@ -1028,7 +1025,6 @@ def _jvp_f(self, xf, dc, constants, op): Fxh_inv = vtf.T @ (sfi[..., None] * uf.T) return Fxh_inv @ Fc - @functools.partial(jit, static_argnames=("self", "op")) def _jvp(self, v, xf, xg, constants, op): # we're replacing stuff like this with jvps # Fx_reduced = Fx[:, unfixed_idx] @ Z # noqa: E800 @@ -1086,6 +1082,18 @@ def _jvp(self, v, xf, xg, constants, op): out = jnp.concatenate(out) return -out + @jit + def _jvp_scaled(self, v, xf, xg, constants): + return self._jvp(v, xf, xg, constants, "scaled") + + @jit + def _jvp_scaled_error(self, v, xf, xg, constants): + return self._jvp(v, xf, xg, constants, "scaled_error") + + @jit + def _jvp_unscaled(self, v, xf, xg, constants): + return self._jvp(v, xf, xg, constants, "unscaled") + @property def constants(self): """list: constant parameters for each sub-objective.""" From 2b04c096d1d09e8bb376ebde1e23bc4d10b90875 Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Mon, 12 Aug 2024 23:26:21 -0400 Subject: [PATCH 06/12] Better error message when using wrong things to get state vector --- desc/objectives/objective_funs.py | 37 ++++++++++++++++++++++++++++--- tests/test_objective_funs.py | 8 +++---- 2 files changed, 38 insertions(+), 7 deletions(-) diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index b42dc0040c..8e04112bb1 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -9,7 +9,14 @@ from desc.derivatives import Derivative from desc.io import IOAble from desc.optimizable import Optimizable -from desc.utils import Timer, flatten_list, is_broadcastable, setdefault, unique_list +from desc.utils import ( + Timer, + errorif, + flatten_list, + is_broadcastable, + setdefault, + unique_list, +) class ObjectiveFunction(IOAble): @@ -367,8 +374,19 @@ def x(self, *things): """Return the full state vector from the Optimizable objects things.""" # TODO: also check resolution etc? 
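Patch 05 above also dropped the static_argnames=("self", "op") jit in favor of thin per-op jitted wrappers. Marking self static means it must be hashable, and every instance that hashes differently becomes a separate compilation-cache entry. A minimal sketch of that behavior (hypothetical Solver class, not DESC code):

    import functools
    import jax
    import jax.numpy as jnp

    class Solver:
        def __init__(self, n):
            self.n = n

        def __hash__(self):
            return hash(self.n)

        def __eq__(self, other):
            return isinstance(other, Solver) and self.n == other.n

        @functools.partial(jax.jit, static_argnums=0)
        def step(self, x):
            return x + self.n            # self.n is a compile-time constant

    Solver(1.0).step(jnp.zeros(2))  # compiles
    Solver(1.0).step(jnp.zeros(2))  # cache hit: equal hash and __eq__
    Solver(2.0).step(jnp.zeros(2))  # recompiles for the new static self

Splitting op into separate jitted methods trades three small cache entries for avoiding a string-keyed static argument, and keeping self non-static sidesteps the hashability requirement entirely.
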
things = things or self.things - assert len(things) == len(self.things) - assert all([type(t1) is type(t2) for t1, t2 in zip(things, self.things)]) + errorif( + len(things) != len(self.things), + ValueError, + "Got the wrong number of things, " + f"expected {len(self.things)} got {len(things)}", + ) + for t1, t2 in zip(things, self.things): + errorif( + not isinstance(t1, type(t2)), + TypeError, + f"got incompatible types between things {type(t1)} " + f"and self.things {type(t2)}", + ) xs = [t.pack_params(t.params_dict) for t in things] return jnp.concatenate(xs) @@ -1188,6 +1206,19 @@ def print_value(self, *args, **kwargs): def xs(self, *things): """Return a tuple of args required by this objective from optimizable things.""" things = things or self.things + errorif( + len(things) != len(self.things), + ValueError, + "Got the wrong number of things, " + f"expected {len(self.things)} got {len(things)}", + ) + for t1, t2 in zip(things, self.things): + errorif( + not isinstance(t1, type(t2)), + TypeError, + f"got incompatible types between things {type(t1)} " + f"and self.things {type(t2)}", + ) return tuple([t.params_dict for t in things]) @property diff --git a/tests/test_objective_funs.py b/tests/test_objective_funs.py index d6f67a5ffc..d664ac356a 100644 --- a/tests/test_objective_funs.py +++ b/tests/test_objective_funs.py @@ -639,7 +639,7 @@ def test_plasma_vessel_distance(self): surface_fixed=True, ) obj.build() - d = obj.compute_unscaled(*obj.xs(eq, surface)) + d = obj.compute_unscaled(*obj.xs(eq)) assert d.size == obj.dim_f assert abs(d.min() - (a_s - a_p)) < 1e-14 assert abs(d.max() - (a_s - a_p)) < surf_grid.spacing[0, 1] * a_p @@ -1534,7 +1534,7 @@ def test_boundary_error_print(capsys): obj = VacuumBoundaryError(eq, coilset, field_grid=coil_grid) obj.build() - f = np.abs(obj.compute_unscaled(*obj.xs(eq))) + f = np.abs(obj.compute_unscaled(*obj.xs(eq, coilset))) n = len(f) // 2 f1 = f[:n] f2 = f[n:] @@ -1609,7 +1609,7 @@ def test_boundary_error_print(capsys): obj = BoundaryError(eq, coilset, field_grid=coil_grid) obj.build() - f = np.abs(obj.compute_unscaled(*obj.xs(eq))) + f = np.abs(obj.compute_unscaled(*obj.xs(eq, coilset))) n = len(f) // 2 f1 = f[:n] f2 = f[n:] @@ -1685,7 +1685,7 @@ def test_boundary_error_print(capsys): obj = BoundaryError(eq, coilset, field_grid=coil_grid) obj.build() - f = np.abs(obj.compute_unscaled(*obj.xs(eq))) + f = np.abs(obj.compute_unscaled(*obj.xs(eq, coilset))) n = len(f) // 3 f1 = f[:n] f2 = f[n : 2 * n] From 7b1553fe36ce226e9c426031d3d6f1003dd99a4f Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Mon, 12 Aug 2024 23:27:07 -0400 Subject: [PATCH 07/12] Don't use object identity comparison in blocked jacobian --- desc/optimize/_constraint_wrappers.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/desc/optimize/_constraint_wrappers.py b/desc/optimize/_constraint_wrappers.py index 31384f4987..f141485845 100644 --- a/desc/optimize/_constraint_wrappers.py +++ b/desc/optimize/_constraint_wrappers.py @@ -1060,16 +1060,20 @@ def _jvp(self, v, xf, xg, constants, op): vgs = jnp.split(tangent, np.cumsum(self._dimx_per_thing)) xgs = jnp.split(xg, np.cumsum(self._dimx_per_thing)) out = [] - for obj, const in zip( - self._objective.objectives, self._objective.constants + for k, (obj, const) in enumerate( + zip(self._objective.objectives, self._objective.constants) ): - xi = [x for x, t in zip(xgs, self._objective.things) if t in obj.things] - vi = [v for v, t in zip(vgs, self._objective.things) if t in obj.things] + thing_idx = 
self._objective._things_per_objective_idx[k] + xi = [xgs[i] for i in thing_idx] + vi = [vgs[i] for i in thing_idx] + assert len(xi) > 0 + assert len(vi) > 0 + assert len(xi) == len(vi) if obj._deriv_mode == "rev": - # obj might now allow fwd mode, so compute full rev mode jacobian + # obj might not allow fwd mode, so compute full rev mode jacobian # and do matmul manually. This is slightly inefficient, but usually # when rev mode is used, dim_f <<< dim_x, so its not too bad. - Ji = getattr(obj, "jac_" + op)(*xi, const) + Ji = getattr(obj, "jac_" + op)(*xi, constants=const) outi = jnp.array([Jii @ vii.T for Jii, vii in zip(Ji, vi)]).sum( axis=0 ) From b2762464a667b70584d4d05af1bd98d5a30cc788 Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Mon, 12 Aug 2024 23:37:12 -0400 Subject: [PATCH 08/12] Make sure objectives use jit by default --- desc/objectives/objective_funs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index 8e04112bb1..eeaccc2c25 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -849,7 +849,7 @@ def __init__( self._normalization = 1 self._deriv_mode = deriv_mode self._name = name - self._use_jit = None + self._use_jit = True self._built = False self._loss_function = { "mean": jnp.mean, From e690715c7652a5437d298688c9f9f83130acd965 Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Wed, 14 Aug 2024 17:36:30 -0400 Subject: [PATCH 09/12] Mark unhashable attributes as static --- desc/objectives/linear_objectives.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/desc/objectives/linear_objectives.py b/desc/objectives/linear_objectives.py index d855a025a6..5905802f89 100644 --- a/desc/objectives/linear_objectives.py +++ b/desc/objectives/linear_objectives.py @@ -3184,6 +3184,7 @@ class FixNearAxisR(_FixedObjective): """ + _static_attrs = ["_nae_eq"] _target_arg = "R_lmn" _fixed = False # not "diagonal", since its fixing a sum _units = "(m)" @@ -3320,6 +3321,7 @@ class FixNearAxisZ(_FixedObjective): """ + _static_attrs = ["_nae_eq"] _target_arg = "Z_lmn" _fixed = False # not "diagonal", since its fixing a sum _units = "(m)" @@ -3462,6 +3464,7 @@ class FixNearAxisLambda(_FixedObjective): """ + _static_attrs = ["_nae_eq"] _target_arg = "L_lmn" _fixed = False # not "diagonal", since its fixing a sum _units = "(dimensionless)" From de22dd88b5e6c116d5f8dabda2dea32fbf5b2f77 Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Wed, 14 Aug 2024 17:36:59 -0400 Subject: [PATCH 10/12] Use hashable callable classes instead of local functions to avoid recompilation --- desc/objectives/objective_funs.py | 55 ++++++++++++++++++++++--------- desc/objectives/utils.py | 46 +++++++++++++++++++------- 2 files changed, 74 insertions(+), 27 deletions(-) diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index eeaccc2c25..43ce4e7348 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -177,22 +177,8 @@ def _set_things(self, things=None): [i for i, t in enumerate(unique_) if t in obj.things] ) - def unflatten(unique): - assert len(unique) == len(unique_) - flat = [unique[i] for i in inds_] - return tree_unflatten(treedef_, flat) - - def flatten(things): - flat, treedef = tree_flatten( - things, is_leaf=lambda x: isinstance(x, Optimizable) - ) - assert treedef == treedef_ - assert len(flat) == len(flat_) - unique, _ = unique_list(flat) - return unique - - self._unflatten = unflatten - self._flatten = flatten + 
self._unflatten = _ThingUnflattener(len(unique_), inds_, treedef_) + self._flatten = _ThingFlattener(len(flat_), treedef_) @jit def compute_unscaled(self, x, constants=None): @@ -1317,3 +1303,40 @@ def things(self, new): self._things = list(new) # can maybe improve this later to not rebuild if resolution is the same self._built = False + + +# local functions assigned as attributes aren't hashable so they cause stuff to +# recompile, so instead we define a hashable class to do the same thing. + + +class _ThingUnflattener(IOAble): + + _static_attrs = ["length", "inds", "treedef"] + + def __init__(self, length, inds, treedef): + self.length = length + self.inds = inds + self.treedef = treedef + + def __call__(self, unique): + assert len(unique) == self.length + flat = [unique[i] for i in self.inds] + return tree_unflatten(self.treedef, flat) + + +class _ThingFlattener(IOAble): + + _static_attrs = ["length", "treedef"] + + def __init__(self, length, treedef): + self.length = length + self.treedef = treedef + + def __call__(self, things): + flat, treedef = tree_flatten( + things, is_leaf=lambda x: isinstance(x, Optimizable) + ) + assert treedef == self.treedef + assert len(flat) == self.length + unique, _ = unique_list(flat) + return unique diff --git a/desc/objectives/utils.py b/desc/objectives/utils.py index 27a8d30b37..f34d7d2d0b 100644 --- a/desc/objectives/utils.py +++ b/desc/objectives/utils.py @@ -6,6 +6,7 @@ import numpy as np from desc.backend import cond, jit, jnp, logsumexp, put +from desc.io import IOAble from desc.utils import Index, errorif, flatten_list, svd_inv_null, unique_list, warnif @@ -142,17 +143,8 @@ def factorize_linear_constraints(objective, constraint): # noqa: C901 xp = put(xp, unfixed_idx, Ainv_full @ b) xp = jnp.asarray(xp) - @jit - def project(x): - """Project a full state vector into the reduced optimization vector.""" - x_reduced = Z.T @ ((x - xp)[unfixed_idx]) - return jnp.atleast_1d(jnp.squeeze(x_reduced)) - - @jit - def recover(x_reduced): - """Recover the full state vector from the reduced optimization vector.""" - dx = put(jnp.zeros(objective.dim_x), unfixed_idx, Z @ x_reduced) - return jnp.atleast_1d(jnp.squeeze(xp + dx)) + project = _Project(Z, xp, unfixed_idx) + recover = _Recover(Z, xp, unfixed_idx, objective.dim_x) # check that all constraints are actually satisfiable params = objective.unpack_state(xp, False) @@ -200,6 +192,38 @@ def recover(x_reduced): return xp, A, b, Z, unfixed_idx, project, recover +class _Project(IOAble): + _io_attrs_ = ["Z", "xp", "unfixed_idx"] + + def __init__(self, Z, xp, unfixed_idx): + self.Z = Z + self.xp = xp + self.unfixed_idx = unfixed_idx + + @jit + def __call__(self, x): + """Project a full state vector into the reduced optimization vector.""" + x_reduced = self.Z.T @ ((x - self.xp)[self.unfixed_idx]) + return jnp.atleast_1d(jnp.squeeze(x_reduced)) + + +class _Recover(IOAble): + _io_attrs_ = ["Z", "xp", "unfixed_idx", "dim_x"] + _static_attrs = ["dim_x"] + + def __init__(self, Z, xp, unfixed_idx, dim_x): + self.Z = Z + self.xp = xp + self.unfixed_idx = unfixed_idx + self.dim_x = dim_x + + @jit + def __call__(self, x_reduced): + """Recover the full state vector from the reduced optimization vector.""" + dx = put(jnp.zeros(self.dim_x), self.unfixed_idx, self.Z @ x_reduced) + return jnp.atleast_1d(jnp.squeeze(self.xp + dx)) + + def softmax(arr, alpha): """JAX softmax implementation. 
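The comment above ("local functions assigned as attributes aren't hashable so they cause stuff to recompile") is the crux of patch 10: closures compare by object identity, so every rebuild produces a "new" static value and jit recompiles, while a callable class with value-based __hash__/__eq__ shares the cache. A standalone sketch of the difference (apply, make_scaler, and Scale are illustrative, not DESC API):

    import functools
    import jax
    import jax.numpy as jnp

    @functools.partial(jax.jit, static_argnums=0)
    def apply(fn, x):
        return fn(x)

    def make_scaler(factor):
        def scale(x):
            return factor * x
        return scale                     # fresh function object every call

    class Scale:
        def __init__(self, factor):
            self.factor = factor

        def __hash__(self):
            return hash(self.factor)

        def __eq__(self, other):
            return isinstance(other, Scale) and self.factor == other.factor

        def __call__(self, x):
            return self.factor * x

    x = jnp.ones(3)
    apply(make_scaler(2.0), x)  # compiles
    apply(make_scaler(2.0), x)  # compiles again: distinct closure objects
    apply(Scale(2.0), x)        # compiles
    apply(Scale(2.0), x)        # cache hit: equal instances

The sketch shows only the principle; the _Project, _Recover, and flattener classes in the patch plug into DESC's existing IOAble handling rather than hand-written dunders.
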
From 330a04d88ae8b74b8bc16420f78eec8d26aa4dcd Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Wed, 14 Aug 2024 17:37:39 -0400 Subject: [PATCH 11/12] Move some proximal logic to pure jax functions to avoid recompilation --- desc/optimize/_constraint_wrappers.py | 114 ++++++++++++++------------ 1 file changed, 61 insertions(+), 53 deletions(-) diff --git a/desc/optimize/_constraint_wrappers.py b/desc/optimize/_constraint_wrappers.py index f141485845..4249b36a00 100644 --- a/desc/optimize/_constraint_wrappers.py +++ b/desc/optimize/_constraint_wrappers.py @@ -1,5 +1,7 @@ """Wrappers for doing STELLOPT/SIMSOPT like optimization.""" +import functools + import numpy as np from desc.backend import jit, jnp @@ -970,7 +972,7 @@ def jvp_scaled(self, v, x, constants=None): v = v[0] if isinstance(v, (tuple, list)) else v constants = setdefault(constants, self.constants) xg, xf = self._update_equilibrium(x, store=True) - jvpfun = lambda u: self._jvp_scaled(u, xf, xg, constants) + jvpfun = lambda u: self._jvp(u, xf, xg, constants, op="scaled") return jnp.vectorize(jvpfun, signature="(n)->(k)")(v) def jvp_scaled_error(self, v, x, constants=None): @@ -990,7 +992,7 @@ def jvp_scaled_error(self, v, x, constants=None): v = v[0] if isinstance(v, (tuple, list)) else v constants = setdefault(constants, self.constants) xg, xf = self._update_equilibrium(x, store=True) - jvpfun = lambda u: self._jvp_scaled_error(u, xf, xg, constants) + jvpfun = lambda u: self._jvp(u, xf, xg, constants, op="scaled_error") return jnp.vectorize(jvpfun, signature="(n)->(k)")(v) def jvp_unscaled(self, v, x, constants=None): @@ -1010,21 +1012,9 @@ def jvp_unscaled(self, v, x, constants=None): v = v[0] if isinstance(v, (tuple, list)) else v constants = setdefault(constants, self.constants) xg, xf = self._update_equilibrium(x, store=True) - jvpfun = lambda u: self._jvp_unscaled(u, xf, xg, constants) + jvpfun = lambda u: self._jvp(u, xf, xg, constants, op="unscaled") return jnp.vectorize(jvpfun, signature="(n)->(k)")(v) - def _jvp_f(self, xf, dc, constants, op): - Fx = getattr(self._constraint, "jac_" + op)(xf, constants) - Fx_reduced = Fx[:, self._unfixed_idx] @ self._Z - Fc = Fx @ (self._dxdc @ dc) - Fxh = Fx_reduced - cutoff = jnp.finfo(Fxh.dtype).eps * max(Fxh.shape) - uf, sf, vtf = jnp.linalg.svd(Fxh, full_matrices=False) - sf += sf[-1] # add a tiny bit of regularization - sfi = jnp.where(sf < cutoff * sf[0], 0, 1 / sf) - Fxh_inv = vtf.T @ (sfi[..., None] * uf.T) - return Fxh_inv @ Fc - def _jvp(self, v, xf, xg, constants, op): # we're replacing stuff like this with jvps # Fx_reduced = Fx[:, unfixed_idx] @ Z # noqa: E800 @@ -1037,7 +1027,16 @@ def _jvp(self, v, xf, xg, constants, op): # want jvp_f to only get parts from equilibrium, not other things vs = jnp.split(v, np.cumsum(self._dimc_per_thing)) # this is Fx_reduced_inv @ Fc - dfdc = self._jvp_f(xf, vs[self._eq_idx], constants[1], op) + dfdc = _proximal_jvp_f_pure( + self._constraint, + xf, + constants[1], + vs[self._eq_idx], + self._unfixed_idx, + self._Z, + self._dxdc, + op, + ) # broadcasting against multiple things dfdcs = [jnp.zeros(dim) for dim in self._dimc_per_thing] dfdcs[self._eq_idx] = dfdc @@ -1059,45 +1058,9 @@ def _jvp(self, v, xf, xg, constants, op): else: # deriv_mode == "blocked" vgs = jnp.split(tangent, np.cumsum(self._dimx_per_thing)) xgs = jnp.split(xg, np.cumsum(self._dimx_per_thing)) - out = [] - for k, (obj, const) in enumerate( - zip(self._objective.objectives, self._objective.constants) - ): - thing_idx = self._objective._things_per_objective_idx[k] - xi 
= [xgs[i] for i in thing_idx] - vi = [vgs[i] for i in thing_idx] - assert len(xi) > 0 - assert len(vi) > 0 - assert len(xi) == len(vi) - if obj._deriv_mode == "rev": - # obj might not allow fwd mode, so compute full rev mode jacobian - # and do matmul manually. This is slightly inefficient, but usually - # when rev mode is used, dim_f <<< dim_x, so its not too bad. - Ji = getattr(obj, "jac_" + op)(*xi, constants=const) - outi = jnp.array([Jii @ vii.T for Jii, vii in zip(Ji, vi)]).sum( - axis=0 - ) - out.append(outi) - else: - outi = getattr(obj, "jvp_" + op)( - [_vi for _vi in vi], xi, constants=const - ).T - out.append(outi) - out = jnp.concatenate(out) + out = _proximal_jvp_blocked_pure(self._objective, vgs, xgs, op) return -out - @jit - def _jvp_scaled(self, v, xf, xg, constants): - return self._jvp(v, xf, xg, constants, "scaled") - - @jit - def _jvp_scaled_error(self, v, xf, xg, constants): - return self._jvp(v, xf, xg, constants, "scaled_error") - - @jit - def _jvp_unscaled(self, v, xf, xg, constants): - return self._jvp(v, xf, xg, constants, "unscaled") - @property def constants(self): """list: constant parameters for each sub-objective.""" @@ -1106,3 +1069,48 @@ def constants(self): def __getattr__(self, name): """For other attributes we defer to the base objective.""" return getattr(self._objective, name) + + +# in ProximalProjection we have an explicit state that we keep track of (and add +# to as we go) meaning if we jit anything with self static it doesn't update +# correctly, while if we leave self unstatic then it recompiles every time because +# the pytree structure of ProximalProjection is changing. To get around that we +# define these helper functions that are stateless so we can safely jit them + + +@functools.partial(jit, static_argnames=["op"]) +def _proximal_jvp_f_pure(constraint, xf, constants, dc, unfixed_idx, Z, dxdc, op): + Fx = getattr(constraint, "jac_" + op)(xf, constants) + Fx_reduced = Fx[:, unfixed_idx] @ Z + Fc = Fx @ (dxdc @ dc) + Fxh = Fx_reduced + cutoff = jnp.finfo(Fxh.dtype).eps * max(Fxh.shape) + uf, sf, vtf = jnp.linalg.svd(Fxh, full_matrices=False) + sf += sf[-1] # add a tiny bit of regularization + sfi = jnp.where(sf < cutoff * sf[0], 0, 1 / sf) + Fxh_inv = vtf.T @ (sfi[..., None] * uf.T) + return Fxh_inv @ Fc + + +@functools.partial(jit, static_argnames=["op"]) +def _proximal_jvp_blocked_pure(objective, vgs, xgs, op): + out = [] + for k, (obj, const) in enumerate(zip(objective.objectives, objective.constants)): + thing_idx = objective._things_per_objective_idx[k] + xi = [xgs[i] for i in thing_idx] + vi = [vgs[i] for i in thing_idx] + assert len(xi) > 0 + assert len(vi) > 0 + assert len(xi) == len(vi) + if obj._deriv_mode == "rev": + # obj might not allow fwd mode, so compute full rev mode jacobian + # and do matmul manually. This is slightly inefficient, but usually + # when rev mode is used, dim_f <<< dim_x, so its not too bad. 
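The cutoff SVD in _proximal_jvp_f_pure above is a self-contained recipe worth reading in isolation: shift the spectrum by the smallest singular value for regularization, then zero out anything below a relative cutoff before inverting. A sketch of the same recipe as a standalone helper (pinv_cutoff is illustrative, not DESC API):

    import jax.numpy as jnp

    def pinv_cutoff(A):
        u, s, vt = jnp.linalg.svd(A, full_matrices=False)
        cutoff = jnp.finfo(A.dtype).eps * max(A.shape)
        s = s + s[-1]                                # tiny regularization
        si = jnp.where(s < cutoff * s[0], 0, 1 / s)  # drop tiny singular values
        return vt.T @ (si[..., None] * u.T)          # pseudo-inverse of A
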
+ Ji = getattr(obj, "jac_" + op)(*xi, constants=const) + outi = jnp.array([Jii @ vii.T for Jii, vii in zip(Ji, vi)]).sum(axis=0) + out.append(outi) + else: + outi = getattr(obj, "jvp_" + op)([_vi for _vi in vi], xi, constants=const).T + out.append(outi) + out = jnp.concatenate(out) + return out From 46926fbd27ac125c722b5d1a89c2cbc36dd43ded Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Sat, 17 Aug 2024 16:23:56 -0400 Subject: [PATCH 12/12] Add comments to explain some stuff better --- desc/objectives/objective_funs.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index 43ce4e7348..59734aa261 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -64,7 +64,7 @@ def __init__( self._name = name def _set_derivatives(self): - """Set up derivatives of the objective functions.""" + """Choose derivative mode based on mode of sub-objectives.""" if self._deriv_mode == "auto": if all((obj._deriv_mode == "fwd") for obj in self.objectives): self._deriv_mode = "batched" @@ -394,7 +394,7 @@ def hess(self, x, constants=None): Derivative(self.compute_scalar, mode="hess")(x, constants).squeeze() ) - def _jac(self, op, x, constants=None): + def _jac_blocked(self, op, x, constants=None): # could also do something similar for grad and hess, but probably not # worth it. grad is already super cheap to eval all at once, and blocked # hess would only be block diag which may miss important interactions. @@ -405,6 +405,9 @@ def _jac(self, op, x, constants=None): xs = jnp.split(x, xs_splits) J = [] assert len(self.objectives) == len(self.constants) + # basic idea is we compute the jacobian of each objective wrt each thing + # one by one, and assemble into big block matrix + # if objective doesn't depend on a given thing, that part is set to 0. 
for k, (obj, const) in enumerate(zip(self.objectives, constants)): # get the xs that go to that objective thing_idx = self._things_per_objective_idx[k] @@ -412,13 +415,16 @@ def _jac(self, op, x, constants=None): Ji_ = getattr(obj, op)(*xi, constants=const) # jac wrt to just those things Ji = [] # jac wrt all things for i, thing in enumerate(self.things): - if i in thing_idx: + if i in thing_idx: # dfi/dxj != 0 Ji += [Ji_[thing_idx.index(i)]] - else: + else: # dfi/dxj == 0 Ji += [jnp.zeros((obj.dim_f, thing.dim_x))] - Ji = jnp.hstack(Ji) + Ji = jnp.hstack(Ji) # something like [df1/dx1, df1/dx2, 0] J += [Ji] - return jnp.vstack(J) + # something like [df1/dx1, df1/dx2, 0] + # [df2/dx1, 0, df2/dx3] # noqa:E800 + J = jnp.vstack(J) + return J @jit def jac_scaled(self, x, constants=None): @@ -431,7 +437,7 @@ def jac_scaled(self, x, constants=None): if self._deriv_mode == "looped": J = Derivative(self.compute_scaled, mode="looped")(x, constants) if self._deriv_mode == "blocked": - J = self._jac("jac_scaled", x, constants) + J = self._jac_blocked("jac_scaled", x, constants) return jnp.atleast_2d(J.squeeze()) @@ -446,7 +452,7 @@ def jac_scaled_error(self, x, constants=None): if self._deriv_mode == "looped": J = Derivative(self.compute_scaled_error, mode="looped")(x, constants) if self._deriv_mode == "blocked": - J = self._jac("jac_scaled_error", x, constants) + J = self._jac_blocked("jac_scaled_error", x, constants) return jnp.atleast_2d(J.squeeze()) @@ -461,7 +467,7 @@ def jac_unscaled(self, x, constants=None): if self._deriv_mode == "looped": J = Derivative(self.compute_unscaled, mode="looped")(x, constants) if self._deriv_mode == "blocked": - J = self._jac("jac_unscaled", x, constants) + J = self._jac_blocked("jac_unscaled", x, constants) return jnp.atleast_2d(J.squeeze()) @@ -847,7 +853,7 @@ def __init__( self._things = flatten_list([things], True) def _set_derivatives(self): - """Set up derivatives of the objective wrt each argument.""" + """Choose derivative mode based on size of inputs/outputs.""" if self._deriv_mode == "auto": # choose based on shape of jacobian. fwd mode is more memory efficient # so we prefer that unless the jacobian is really wide
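The "Choose derivative mode based on size of inputs/outputs" docstring reflects the standard autodiff cost model: forward mode costs roughly one pass per input dimension, reverse mode one pass per output dimension, so tall Jacobians favor fwd and wide ones favor rev. A quick illustration (f_tall and f_wide are hypothetical):

    import jax
    import jax.numpy as jnp

    f_tall = lambda x: jnp.outer(x, x).ravel()  # dim_f >> dim_x: prefer fwd
    f_wide = lambda x: jnp.sum(x ** 2)[None]    # dim_f << dim_x: prefer rev

    jax.jacfwd(f_tall)(jnp.ones(3)).shape       # (9, 3)
    jax.jacrev(f_wide)(jnp.ones(100)).shape     # (1, 100)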