Skip to content

Commit

Permalink
Fix signature
Browse files Browse the repository at this point in the history
  • Loading branch information
f0uriest committed Feb 29, 2024
1 parent 5663560 commit 99ff2e2
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions desc/objectives/objective_funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,7 +1035,7 @@ def jvp_scaled(self, v, x, constants=None):
jvpfun = lambda *dx: Derivative.compute_jvp(
compute_scaled, tuple(range(len(x))), dx, *x
)
sig = "(n)," * (len(x) - 1) + "(n)" + "->(k)"
sig = ",".join(f"(n{i})" for i in range(len(x))) + "->(k)"

Check warning on line 1038 in desc/objectives/objective_funs.py

View check run for this annotation

Codecov / codecov/patch

desc/objectives/objective_funs.py#L1038

Added line #L1038 was not covered by tests
return jnp.vectorize(jvpfun, signature=sig)(*v)

def jvp_unscaled(self, v, x, constants=None):
Expand All @@ -1061,7 +1061,7 @@ def jvp_unscaled(self, v, x, constants=None):
jvpfun = lambda *dx: Derivative.compute_jvp(
compute_unscaled, tuple(range(len(x))), dx, *x
)
sig = "(n)," * (len(x) - 1) + "(n)" + "->(k)"
sig = ",".join(f"(n{i})" for i in range(len(x))) + "->(k)"

Check warning on line 1064 in desc/objectives/objective_funs.py

View check run for this annotation

Codecov / codecov/patch

desc/objectives/objective_funs.py#L1064

Added line #L1064 was not covered by tests
return jnp.vectorize(jvpfun, signature=sig)(*v)

def print_value(self, *args, **kwargs):
Expand Down

0 comments on commit 99ff2e2

Please sign in to comment.