Skip to content

Commit

Permalink
Merge pull request #852 from PlasmaControl/rc/hotfix
Browse files Browse the repository at this point in the history
Fix for unused kwargs in Equilibrium init
  • Loading branch information
f0uriest authored Feb 6, 2024
2 parents dc7d2eb + be3714a commit 3707e08
Show file tree
Hide file tree
Showing 8 changed files with 28 additions and 13 deletions.
2 changes: 1 addition & 1 deletion desc/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def main(cl_args=sys.argv[1:]):
and (inputs[-1]["pres_ratio"] is None)
and (inputs[-1]["bdry_ratio"] is None)
):
eq = Equilibrium(**inputs[-1])
eq = Equilibrium(**inputs[-1], check_kwargs=False)
equil_fam = EquilibriaFamily.solve_continuation_automatic(
eq,
objective=inputs[-1]["objective"],
Expand Down
16 changes: 13 additions & 3 deletions desc/equilibrium/equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,12 @@ def __init__(
self.set_initial_guess(ensure_nested=ensure_nested)
if check_orientation:
ensure_positive_jacobian(self)
if kwargs.get("check_kwargs", True):
errorif(
len(kwargs),
TypeError,
f"Equilibrium got unexpected kwargs: {kwargs.keys()}",
)

def _set_up(self):
"""Set unset attributes after loading.
Expand Down Expand Up @@ -1630,7 +1636,7 @@ def from_near_axis(
"M": M,
"N": N,
"sym": not na_eq.lasym,
"spectral_indexing ": spectral_indexing,
"spectral_indexing": spectral_indexing,
"pressure": np.array([[0, -na_eq.p2 * r**2], [2, na_eq.p2 * r**2]]),
"iota": None,
"current": np.array([[2, 2 * np.pi / mu_0 * na_eq.I2 * r**2]]),
Expand Down Expand Up @@ -2148,7 +2154,9 @@ def __init__(self, *args):
# ensure that first step is nested
ensure_nested_bool = True if i == 0 else False
self.equilibria.append(
Equilibrium(**inp, ensure_nested=ensure_nested_bool)
Equilibrium(
**inp, ensure_nested=ensure_nested_bool, check_kwargs=False
)
)
else:
for i, arg in enumerate(args):
Expand All @@ -2157,7 +2165,9 @@ def __init__(self, *args):
elif isinstance(arg, dict):
ensure_nested_bool = True if i == 0 else False
self.equilibria.append(
Equilibrium(**arg, ensure_nested=ensure_nested_bool)
Equilibrium(
**arg, ensure_nested=ensure_nested_bool, check_kwargs=False
)
)
else:
raise TypeError(
Expand Down
2 changes: 0 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,6 @@ def DummyStellarator(tmpdir_factory):
],
),
"axis": np.array([[-1, 0, -0.2], [0, 3.4, 0], [1, 0.2, 0]]),
"objective": "force",
"optimizer": "lsq-exact",
}
eq = Equilibrium(**inputs)
eq.save(output_path)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def test_from_input_file(self):
path = "tests/inputs/input.QSC_r2_5.5_desc"

curve1 = FourierRZCurve.from_input_file(path)
curve2 = Equilibrium(**InputReader(path).inputs[0]).axis
curve2 = Equilibrium(**InputReader(path).inputs[0], check_kwargs=False).axis
curve1.change_resolution(curve2.N)

np.testing.assert_allclose(curve1.R_n, curve2.R_n)
Expand All @@ -245,7 +245,7 @@ def test_from_input_file(self):

with pytest.warns(UserWarning):
curve3 = FourierRZCurve.from_input_file(path)
curve4 = Equilibrium(**InputReader(path).inputs[0]).axis
curve4 = Equilibrium(**InputReader(path).inputs[0], check_kwargs=False).axis
curve3.change_resolution(curve4.N)

np.testing.assert_allclose(curve3.R_n, curve4.R_n)
Expand Down
11 changes: 11 additions & 0 deletions tests/test_equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from desc.grid import Grid, LinearGrid
from desc.io import InputReader
from desc.objectives import get_equilibrium_objective
from desc.profiles import PowerSeriesProfile

from .utils import area_difference, compute_coords

Expand Down Expand Up @@ -408,3 +409,13 @@ def test_error_when_ndarray_or_integer_passed():
eq.compute("R", grid=1)
with pytest.raises(TypeError):
eq.compute("R", grid=np.linspace(0, 1, 10))


@pytest.mark.unit
def test_equilibrium_unused_kwargs():
"""Test that invalid kwargs raise an error, for gh issue #850."""
pres = PowerSeriesProfile()
curr = PowerSeriesProfile()
with pytest.raises(TypeError):
_ = Equilibrium(pres=pres, curr=curr)
_ = Equilibrium(pressure=pres, current=curr)
2 changes: 0 additions & 2 deletions tests/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,8 +530,6 @@ def test_quad_grid_volume_integration(self):
"iota": np.array([[0, 0]]),
"surface": np.array([[0, 0, 0, R, 0], [0, 1, 0, r, 0], [0, -1, 0, 0, -r]]),
"spectral_indexing": "ansi",
"bdry_mode": "lcfs",
"node_pattern": "quad",
}

eq = Equilibrium(**inputs)
Expand Down
2 changes: 0 additions & 2 deletions tests/test_linear_objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,6 @@ def test_LambdaGauge_asym():
],
),
"axis": np.array([[-1, 0, -0.2], [0, 3.4, 0], [1, 0.2, 0]]),
"objective": "force",
"optimizer": "lsq-exact",
}
eq = Equilibrium(**inputs)
lam_con = FixLambdaGauge(eq)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_same_result(self):
input_path = "./tests/inputs/SOLOVEV"
ir = InputReader(input_path)

eq1 = Equilibrium(**ir.inputs[-1])
eq1 = Equilibrium(**ir.inputs[-1], check_kwargs=False)
eq2 = eq1.copy()
eq2.pressure = eq1.pressure.to_spline()
eq2.iota = eq1.iota.to_spline()
Expand Down

0 comments on commit 3707e08

Please sign in to comment.