Skip to content

Commit

Permalink
fix error in benchmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
dpanici committed Mar 2, 2024
1 parent 5879320 commit 3d6d073
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions tests/benchmarks/benchmark_cpu_small.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ def setup():
jax.clear_caches()
eq = desc.examples.get("DSHAPE_current")
objective = LinearConstraintProjection(
get_equilibrium_objective(eq), get_fixed_boundary_constraints(eq)
get_equilibrium_objective(eq),
ObjectiveFunction(get_fixed_boundary_constraints(eq)),
)
objective.build(eq)
args = (
Expand All @@ -149,7 +150,8 @@ def setup():
jax.clear_caches()
eq = desc.examples.get("ATF")
objective = LinearConstraintProjection(
get_equilibrium_objective(eq), get_fixed_boundary_constraints(eq)
get_equilibrium_objective(eq),
ObjectiveFunction(get_fixed_boundary_constraints(eq)),
)
objective.build(eq)
args = (objective, eq)
Expand All @@ -169,7 +171,8 @@ def test_objective_compute_dshape_current(benchmark):
jax.clear_caches()
eq = desc.examples.get("DSHAPE_current")
objective = LinearConstraintProjection(
get_equilibrium_objective(eq), get_fixed_boundary_constraints(eq)
get_equilibrium_objective(eq),
ObjectiveFunction(get_fixed_boundary_constraints(eq)),
)
objective.build(eq)
objective.compile()
Expand All @@ -188,7 +191,8 @@ def test_objective_compute_atf(benchmark):
jax.clear_caches()
eq = desc.examples.get("ATF")
objective = LinearConstraintProjection(
get_equilibrium_objective(eq), get_fixed_boundary_constraints(eq)
get_equilibrium_objective(eq),
ObjectiveFunction(get_fixed_boundary_constraints(eq)),
)
objective.build(eq)
objective.compile()
Expand All @@ -207,7 +211,8 @@ def test_objective_jac_dshape_current(benchmark):
jax.clear_caches()
eq = desc.examples.get("DSHAPE_current")
objective = LinearConstraintProjection(
get_equilibrium_objective(eq), get_fixed_boundary_constraints(eq)
get_equilibrium_objective(eq),
ObjectiveFunction(get_fixed_boundary_constraints(eq)),
)
objective.build(eq)
objective.compile()
Expand All @@ -226,7 +231,8 @@ def test_objective_jac_atf(benchmark):
jax.clear_caches()
eq = desc.examples.get("ATF")
objective = LinearConstraintProjection(
get_equilibrium_objective(eq), get_fixed_boundary_constraints(eq)
get_equilibrium_objective(eq),
ObjectiveFunction(get_fixed_boundary_constraints(eq)),
)
objective.build(eq)
objective.compile()
Expand All @@ -248,7 +254,7 @@ def setup():
eq = desc.examples.get("SOLOVEV")
objective = get_equilibrium_objective(eq)
objective.build(eq)
constraints = get_fixed_boundary_constraints(eq)
constraints = ObjectiveFunction(get_fixed_boundary_constraints(eq))
tr_ratio = [0.01, 0.25, 0.25]
dp = np.zeros_like(eq.p_l)
dp[np.array([0, 2])] = 8e3 * np.array([1, -1])
Expand Down Expand Up @@ -281,7 +287,7 @@ def setup():
eq = desc.examples.get("SOLOVEV")
objective = get_equilibrium_objective(eq)
objective.build(eq)
constraints = get_fixed_boundary_constraints(eq)
constraints = ObjectiveFunction(get_fixed_boundary_constraints(eq))
tr_ratio = [0.01, 0.25, 0.25]
dp = np.zeros_like(eq.p_l)
dp[np.array([0, 2])] = 8e3 * np.array([1, -1])
Expand Down

0 comments on commit 3d6d073

Please sign in to comment.