Skip to content

Commit

Permalink
Merge branch 'master' into rc/freeb
Browse files Browse the repository at this point in the history
  • Loading branch information
f0uriest committed Feb 6, 2024
2 parents ba11e41 + 3707e08 commit 969774a
Show file tree
Hide file tree
Showing 13 changed files with 79 additions and 17 deletions.
2 changes: 1 addition & 1 deletion desc/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def main(cl_args=sys.argv[1:]):
and (inputs[-1]["pres_ratio"] is None)
and (inputs[-1]["bdry_ratio"] is None)
):
eq = Equilibrium(**inputs[-1])
eq = Equilibrium(**inputs[-1], check_kwargs=False)
equil_fam = EquilibriaFamily.solve_continuation_automatic(
eq,
objective=inputs[-1]["objective"],
Expand Down
39 changes: 39 additions & 0 deletions desc/compute/_equil.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,27 @@ def _e_sup_helical(params, transforms, profiles, data, **kwargs):
return data


@register_compute_fun(
name="e^helical*sqrt(g)",
label=" \\sqrt{g}(B^{\\theta} \\nabla \\zeta - B^{\\zeta} \\nabla \\theta)",
units="T \\cdot m^{2}",
units_long="Tesla * square meter",
description="Helical basis vector weighted by 3-D volume Jacobian",
dim=3,
params=[],
transforms={},
profiles=[],
coordinates="rtz",
data=["B^theta", "B^zeta", "e^theta*sqrt(g)", "e^zeta", "sqrt(g)"],
)
def _e_sup_helical_times_sqrt_g(params, transforms, profiles, data, **kwargs):
data["e^helical*sqrt(g)"] = (
data["B^zeta"] * data["e^theta*sqrt(g)"].T
- (data["sqrt(g)"] * data["B^theta"]) * data["e^zeta"].T
).T
return data


@register_compute_fun(
name="|e^helical|",
label="|B^{\\theta} \\nabla \\zeta - B^{\\zeta} \\nabla \\theta|",
Expand All @@ -580,6 +601,24 @@ def _e_sup_helical_mag(params, transforms, profiles, data, **kwargs):
return data


@register_compute_fun(
name="|e^helical*sqrt(g)|",
label="|\\sqrt{g}(B^{\\theta} \\nabla \\zeta - B^{\\zeta} \\nabla \\theta)|",
units="T \\cdot m^{2}",
units_long="Tesla * square meter",
description="Magnitude of helical basis vector weighted by 3-D volume Jacobian",
dim=1,
params=[],
transforms={},
profiles=[],
coordinates="rtz",
data=["e^helical*sqrt(g)"],
)
def _e_sup_helical_times_sqrt_g_mag(params, transforms, profiles, data, **kwargs):
data["|e^helical*sqrt(g)|"] = jnp.linalg.norm(data["e^helical*sqrt(g)"], axis=-1)
return data


@register_compute_fun(
name="F_anisotropic",
label="F_{anisotropic}",
Expand Down
16 changes: 13 additions & 3 deletions desc/equilibrium/equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,12 @@ def __init__(
self.set_initial_guess(ensure_nested=ensure_nested)
if check_orientation:
ensure_positive_jacobian(self)
if kwargs.get("check_kwargs", True):
errorif(
len(kwargs),
TypeError,
f"Equilibrium got unexpected kwargs: {kwargs.keys()}",
)

def _set_up(self):
"""Set unset attributes after loading.
Expand Down Expand Up @@ -1630,7 +1636,7 @@ def from_near_axis(
"M": M,
"N": N,
"sym": not na_eq.lasym,
"spectral_indexing ": spectral_indexing,
"spectral_indexing": spectral_indexing,
"pressure": np.array([[0, -na_eq.p2 * r**2], [2, na_eq.p2 * r**2]]),
"iota": None,
"current": np.array([[2, 2 * np.pi / mu_0 * na_eq.I2 * r**2]]),
Expand Down Expand Up @@ -2148,7 +2154,9 @@ def __init__(self, *args):
# ensure that first step is nested
ensure_nested_bool = True if i == 0 else False
self.equilibria.append(
Equilibrium(**inp, ensure_nested=ensure_nested_bool)
Equilibrium(
**inp, ensure_nested=ensure_nested_bool, check_kwargs=False
)
)
else:
for i, arg in enumerate(args):
Expand All @@ -2157,7 +2165,9 @@ def __init__(self, *args):
elif isinstance(arg, dict):
ensure_nested_bool = True if i == 0 else False
self.equilibria.append(
Equilibrium(**arg, ensure_nested=ensure_nested_bool)
Equilibrium(
**arg, ensure_nested=ensure_nested_bool, check_kwargs=False
)
)
else:
raise TypeError(
Expand Down
4 changes: 2 additions & 2 deletions desc/objectives/_equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def build(self, use_jit=True, verbose=1):
"|grad(rho)|",
"sqrt(g)",
"F_helical",
"|e^helical|",
"|e^helical*sqrt(g)|",
]

timer = Timer()
Expand Down Expand Up @@ -180,7 +180,7 @@ def compute(self, params, constants=None):
profiles=constants["profiles"],
)
fr = data["F_rho"] * data["|grad(rho)|"] * data["sqrt(g)"]
fb = data["F_helical"] * data["|e^helical|"] * data["sqrt(g)"]
fb = data["F_helical"] * data["|e^helical*sqrt(g)|"]

return jnp.concatenate([fr, fb])

Expand Down
2 changes: 0 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ markers=
filterwarnings=
error
ignore::pytest.PytestUnraisableExceptionWarning
ignore::RuntimeWarning:desc.compute
# Ignore division by zero warnings.
ignore:numpy.ndarray size changed:RuntimeWarning
# ignore benign Cython warnings on ndarray size
ignore::DeprecationWarning:ml_dtypes.*
Expand Down
2 changes: 0 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,6 @@ def DummyStellarator(tmpdir_factory):
],
),
"axis": np.array([[-1, 0, -0.2], [0, 3.4, 0], [1, 0.2, 0]]),
"objective": "force",
"optimizer": "lsq-exact",
}
eq = Equilibrium(**inputs)
eq.save(output_path)
Expand Down
Binary file modified tests/inputs/master_compute_data.pkl
Binary file not shown.
4 changes: 2 additions & 2 deletions tests/test_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def test_from_input_file(self):
path = "tests/inputs/input.QSC_r2_5.5_desc"

curve1 = FourierRZCurve.from_input_file(path)
curve2 = Equilibrium(**InputReader(path).inputs[0]).axis
curve2 = Equilibrium(**InputReader(path).inputs[0], check_kwargs=False).axis
curve1.change_resolution(curve2.N)

np.testing.assert_allclose(curve1.R_n, curve2.R_n)
Expand All @@ -245,7 +245,7 @@ def test_from_input_file(self):

with pytest.warns(UserWarning):
curve3 = FourierRZCurve.from_input_file(path)
curve4 = Equilibrium(**InputReader(path).inputs[0]).axis
curve4 = Equilibrium(**InputReader(path).inputs[0], check_kwargs=False).axis
curve3.change_resolution(curve4.N)

np.testing.assert_allclose(curve3.R_n, curve4.R_n)
Expand Down
11 changes: 11 additions & 0 deletions tests/test_equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from desc.grid import Grid, LinearGrid
from desc.io import InputReader
from desc.objectives import get_equilibrium_objective
from desc.profiles import PowerSeriesProfile

from .utils import area_difference, compute_coords

Expand Down Expand Up @@ -408,3 +409,13 @@ def test_error_when_ndarray_or_integer_passed():
eq.compute("R", grid=1)
with pytest.raises(TypeError):
eq.compute("R", grid=np.linspace(0, 1, 10))


@pytest.mark.unit
def test_equilibrium_unused_kwargs():
"""Test that invalid kwargs raise an error, for gh issue #850."""
pres = PowerSeriesProfile()
curr = PowerSeriesProfile()
with pytest.raises(TypeError):
_ = Equilibrium(pres=pres, curr=curr)
_ = Equilibrium(pressure=pres, current=curr)
2 changes: 0 additions & 2 deletions tests/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,8 +530,6 @@ def test_quad_grid_volume_integration(self):
"iota": np.array([[0, 0]]),
"surface": np.array([[0, 0, 0, R, 0], [0, 1, 0, r, 0], [0, -1, 0, 0, -r]]),
"spectral_indexing": "ansi",
"bdry_mode": "lcfs",
"node_pattern": "quad",
}

eq = Equilibrium(**inputs)
Expand Down
2 changes: 0 additions & 2 deletions tests/test_linear_objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,6 @@ def test_LambdaGauge_asym():
],
),
"axis": np.array([[-1, 0, -0.2], [0, 3.4, 0], [1, 0.2, 0]]),
"objective": "force",
"optimizer": "lsq-exact",
}
eq = Equilibrium(**inputs)
lam_con = FixLambdaGauge(eq)
Expand Down
10 changes: 10 additions & 0 deletions tests/test_objective_funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2060,3 +2060,13 @@ def test_asymmetric_normalization():
assert np.all(np.isfinite(val))
for val in scales_eq.values():
assert np.all(np.isfinite(val))


@pytest.mark.unit
def test_force_balance_axis_error():
"""Test that ForceBalance objective is not NaN if the grid contains axis."""
eq = get("SOLOVEV")
grid = LinearGrid(L=2, M=2, N=2, axis=True)
obj = ForceBalance(eq, grid=grid)
obj.build()
assert not np.any(np.isnan(obj.compute_unscaled(*obj.xs(eq))))
2 changes: 1 addition & 1 deletion tests/test_profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_same_result(self):
input_path = "./tests/inputs/SOLOVEV"
ir = InputReader(input_path)

eq1 = Equilibrium(**ir.inputs[-1])
eq1 = Equilibrium(**ir.inputs[-1], check_kwargs=False)
eq2 = eq1.copy()
eq2.pressure = eq1.pressure.to_spline()
eq2.iota = eq1.iota.to_spline()
Expand Down

0 comments on commit 969774a

Please sign in to comment.