From cc1ca001cd8498cf328b64454b7e510ad47e5eed Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Wed, 28 Feb 2024 14:52:12 -0500 Subject: [PATCH 1/3] Fix jvp signature for Objective --- desc/objectives/objective_funs.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index a76058142d..7967ef4ee7 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -1021,19 +1021,21 @@ def jvp_scaled(self, v, x, constants=None): ---------- v : tuple of ndarray Vectors to right-multiply the Jacobian by. - x : ndarray + x : tuple of ndarray Optimization variables. constants : list Constant parameters passed to sub-objectives. """ v = v if isinstance(v, (tuple, list)) else (v,) + x = x if isinstance(x, (tuple, list)) else (x,) + assert len(x) == len(v) compute_scaled = lambda *x: self.compute_scaled(*x, constants=constants) jvpfun = lambda *dx: Derivative.compute_jvp( compute_scaled, tuple(range(len(x))), dx, *x ) - sig = "(n)" * len(x) + "->(k)" + sig = "(n)," * len(x) + "->(k)" return jnp.vectorize(jvpfun, signature=sig)(*v) def jvp_unscaled(self, v, x, constants=None): @@ -1045,19 +1047,21 @@ def jvp_unscaled(self, v, x, constants=None): ---------- v : tuple of ndarray Vectors to right-multiply the Jacobian by. - x : ndarray + x : tuple of ndarray Optimization variables. constants : list Constant parameters passed to sub-objectives. """ v = v if isinstance(v, (tuple, list)) else (v,) + x = x if isinstance(x, (tuple, list)) else (x,) + assert len(x) == len(v) compute_unscaled = lambda *x: self.compute_unscaled(*x, constants=constants) jvpfun = lambda *dx: Derivative.compute_jvp( compute_unscaled, tuple(range(len(x))), dx, *x ) - sig = "(n)" * len(x) + "->(k)" + sig = "(n)," * len(x) + "->(k)" return jnp.vectorize(jvpfun, signature=sig)(*v) def print_value(self, *args, **kwargs): From 5663560e7eac4ba5e5eb752012039ce365c2dd4e Mon Sep 17 00:00:00 2001 From: YigitElma Date: Wed, 28 Feb 2024 16:24:56 -0500 Subject: [PATCH 2/3] fix comma issue --- desc/objectives/objective_funs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index 7967ef4ee7..12d3a16c41 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -1035,7 +1035,7 @@ def jvp_scaled(self, v, x, constants=None): jvpfun = lambda *dx: Derivative.compute_jvp( compute_scaled, tuple(range(len(x))), dx, *x ) - sig = "(n)," * len(x) + "->(k)" + sig = "(n)," * (len(x) - 1) + "(n)" + "->(k)" return jnp.vectorize(jvpfun, signature=sig)(*v) def jvp_unscaled(self, v, x, constants=None): @@ -1061,7 +1061,7 @@ def jvp_unscaled(self, v, x, constants=None): jvpfun = lambda *dx: Derivative.compute_jvp( compute_unscaled, tuple(range(len(x))), dx, *x ) - sig = "(n)," * len(x) + "->(k)" + sig = "(n)," * (len(x) - 1) + "(n)" + "->(k)" return jnp.vectorize(jvpfun, signature=sig)(*v) def print_value(self, *args, **kwargs): From 99ff2e2d807702850b98e697da5482f65b3d4e88 Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Wed, 28 Feb 2024 20:05:22 -0500 Subject: [PATCH 3/3] Fix signature --- desc/objectives/objective_funs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index 12d3a16c41..ffff8335ce 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -1035,7 +1035,7 @@ def jvp_scaled(self, v, x, constants=None): jvpfun = lambda *dx: Derivative.compute_jvp( compute_scaled, tuple(range(len(x))), dx, *x ) - sig = "(n)," * (len(x) - 1) + "(n)" + "->(k)" + sig = ",".join(f"(n{i})" for i in range(len(x))) + "->(k)" return jnp.vectorize(jvpfun, signature=sig)(*v) def jvp_unscaled(self, v, x, constants=None): @@ -1061,7 +1061,7 @@ def jvp_unscaled(self, v, x, constants=None): jvpfun = lambda *dx: Derivative.compute_jvp( compute_unscaled, tuple(range(len(x))), dx, *x ) - sig = "(n)," * (len(x) - 1) + "(n)" + "->(k)" + sig = ",".join(f"(n{i})" for i in range(len(x))) + "->(k)" return jnp.vectorize(jvpfun, signature=sig)(*v) def print_value(self, *args, **kwargs):