diff --git a/desc/objectives/_generic.py b/desc/objectives/_generic.py index 8570c2389..737d68437 100644 --- a/desc/objectives/_generic.py +++ b/desc/objectives/_generic.py @@ -38,6 +38,9 @@ class ExternalObjective(_Objective, ABC): It does not need to be JAX transformable. dim_f : int Dimension of the output of ``fun``. + vectorized : bool + Set to False if ``fun`` takes a single Equilibrium as its positional argument. + Set to True if ``fun`` instead takes a list of Equilibria. target : {float, ndarray}, optional Target value(s) of the objective. Only used if bounds is None. Must be broadcastable to Objective.dim_f. Defaults to ``target=0``. @@ -59,9 +62,6 @@ class ExternalObjective(_Objective, ABC): Loss function to apply to the objective values once computed. This loss function is called on the raw compute value, before any shifting, scaling, or normalization. - vectorized : bool, optional - Set to False if ``fun`` takes a single Equilibrium as its positional argument. - Set to True if ``fun`` instead takes a list of Equilibria. Default = False. abs_step : float, optional Absolute finite difference step size. Default = 1e-4. Total step size is ``abs_step + rel_step * mean(abs(x))``. @@ -84,13 +84,13 @@ def __init__( eq, fun, dim_f, + vectorized, target=None, bounds=None, weight=1, normalize=False, normalize_target=False, loss_function=None, - vectorized=False, abs_step=1e-4, rel_step=0, name="external", diff --git a/tests/test_examples.py b/tests/test_examples.py index 932bea374..955573e38 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1972,7 +1972,13 @@ def data_from_vmec(eq, path="", surfs=8): path = dir.join("wout_result.nc") objective = ObjectiveFunction( ExternalObjective( - eq=eq0, fun=data_from_vmec, dim_f=4, target=target, path=path, surfs=8 + eq=eq0, + fun=data_from_vmec, + dim_f=4, + vectorized=False, + target=target, + path=path, + surfs=8, ) ) constraints = FixParameters(