Skip to content

Commit

Permalink
Merge pull request #909 from PlasmaControl/rc/hotfix
Browse files Browse the repository at this point in the history
Fix jvp signature for Objective
  • Loading branch information
f0uriest authored Feb 29, 2024
2 parents 569d35e + 99ff2e2 commit 664a562
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions desc/objectives/objective_funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,19 +1021,21 @@ def jvp_scaled(self, v, x, constants=None):
----------
v : tuple of ndarray
Vectors to right-multiply the Jacobian by.
x : ndarray
x : tuple of ndarray
Optimization variables.
constants : list
Constant parameters passed to sub-objectives.
"""
v = v if isinstance(v, (tuple, list)) else (v,)
x = x if isinstance(x, (tuple, list)) else (x,)
assert len(x) == len(v)

compute_scaled = lambda *x: self.compute_scaled(*x, constants=constants)
jvpfun = lambda *dx: Derivative.compute_jvp(
compute_scaled, tuple(range(len(x))), dx, *x
)
sig = "(n)" * len(x) + "->(k)"
sig = ",".join(f"(n{i})" for i in range(len(x))) + "->(k)"
return jnp.vectorize(jvpfun, signature=sig)(*v)

def jvp_unscaled(self, v, x, constants=None):
Expand All @@ -1045,19 +1047,21 @@ def jvp_unscaled(self, v, x, constants=None):
----------
v : tuple of ndarray
Vectors to right-multiply the Jacobian by.
x : ndarray
x : tuple of ndarray
Optimization variables.
constants : list
Constant parameters passed to sub-objectives.
"""
v = v if isinstance(v, (tuple, list)) else (v,)
x = x if isinstance(x, (tuple, list)) else (x,)
assert len(x) == len(v)

compute_unscaled = lambda *x: self.compute_unscaled(*x, constants=constants)
jvpfun = lambda *dx: Derivative.compute_jvp(
compute_unscaled, tuple(range(len(x))), dx, *x
)
sig = "(n)" * len(x) + "->(k)"
sig = ",".join(f"(n{i})" for i in range(len(x))) + "->(k)"
return jnp.vectorize(jvpfun, signature=sig)(*v)

def print_value(self, *args, **kwargs):
Expand Down

0 comments on commit 664a562

Please sign in to comment.