From 4605ad1afb829c7943415ca70f29427a00e5c302 Mon Sep 17 00:00:00 2001 From: Kaya Unalmis Date: Tue, 8 Aug 2023 22:06:40 -0500 Subject: [PATCH] Add new surface quantities needed after merging #583... - Switch test marked regression to unit for same reason as git commit 44f2877ac90c5b62a21a9ffafc413ca1061a6e41 - Remove unneded test for limits with fix iota --- desc/compute/_basis_vectors.py | 4 + desc/compute/_metric.py | 16 ++++ desc/compute/_surface.py | 158 ++++++++++++++++++++++++++++++++ desc/compute/data_index.py | 2 +- tests/test_axis_limits.py | 119 ++++++++++++++---------- tests/test_compute_funs.py | 4 +- tests/test_constrain_current.py | 4 +- 7 files changed, 255 insertions(+), 52 deletions(-) diff --git a/desc/compute/_basis_vectors.py b/desc/compute/_basis_vectors.py index 9d34eaa79..6bfe5b83b 100644 --- a/desc/compute/_basis_vectors.py +++ b/desc/compute/_basis_vectors.py @@ -196,6 +196,10 @@ def _e_sup_theta(params, transforms, profiles, data, **kwargs): profiles=[], coordinates="rtz", data=["e_rho", "e_zeta"], + parameterization=[ + "desc.equilibrium.equilibrium.Equilibrium", + "desc.geometry.core.Surface", + ], ) def _e_sup_theta_times_sqrt_g(params, transforms, profiles, data, **kwargs): data["e^theta*sqrt(g)"] = cross(data["e_zeta"], data["e_rho"]) diff --git a/desc/compute/_metric.py b/desc/compute/_metric.py index 4c8ba3690..35262506e 100644 --- a/desc/compute/_metric.py +++ b/desc/compute/_metric.py @@ -90,6 +90,10 @@ def _e_theta_x_e_zeta(params, transforms, profiles, data, **kwargs): profiles=[], coordinates="rtz", data=["e_theta", "e_zeta", "e_theta_r", "e_zeta_r"], + parameterization=[ + "desc.equilibrium.equilibrium.Equilibrium", + "desc.geometry.core.Surface", + ], ) def _e_theta_x_e_zeta_r(params, transforms, profiles, data, **kwargs): a = cross(data["e_theta"], data["e_zeta"]) @@ -120,6 +124,10 @@ def _e_theta_x_e_zeta_r(params, transforms, profiles, data, **kwargs): profiles=[], coordinates="rtz", data=["e_theta", "e_zeta", "e_theta_r", "e_zeta_r", "e_theta_rr", "e_zeta_rr"], + parameterization=[ + "desc.equilibrium.equilibrium.Equilibrium", + "desc.geometry.core.Surface", + ], ) def _e_theta_x_e_zeta_rr(params, transforms, profiles, data, **kwargs): a = cross(data["e_theta"], data["e_zeta"]) @@ -201,6 +209,10 @@ def _e_rho_x_e_theta(params, transforms, profiles, data, **kwargs): profiles=[], coordinates="rtz", data=["e_rho", "e_theta", "e_rho_r", "e_theta_r"], + parameterization=[ + "desc.equilibrium.equilibrium.Equilibrium", + "desc.geometry.core.Surface", + ], ) def _e_rho_x_e_theta_r(params, transforms, profiles, data, **kwargs): a = cross(data["e_rho"], data["e_theta"]) @@ -231,6 +243,10 @@ def _e_rho_x_e_theta_r(params, transforms, profiles, data, **kwargs): profiles=[], coordinates="rtz", data=["e_rho", "e_theta", "e_rho_r", "e_theta_r", "e_rho_rr", "e_theta_rr"], + parameterization=[ + "desc.equilibrium.equilibrium.Equilibrium", + "desc.geometry.core.Surface", + ], ) def _e_rho_x_e_theta_rr(params, transforms, profiles, data, **kwargs): a = cross(data["e_rho"], data["e_theta"]) diff --git a/desc/compute/_surface.py b/desc/compute/_surface.py index c6b3981e7..7c45d7df1 100644 --- a/desc/compute/_surface.py +++ b/desc/compute/_surface.py @@ -141,6 +141,30 @@ def _e_rho_r_FourierRZToroidalSurface(params, transforms, profiles, data, **kwar return data +@register_compute_fun( + name="e_rho_rr", + label="\\partial_{\\rho \\rho} \\mathbf{e}_{\\rho}", + units="m", + units_long="meters", + description="Covariant radial basis vector," + " second derivative wrt radial coordinate", + dim=3, + params=[], + transforms={ + "grid": [], + }, + profiles=[], + coordinates="tz", + data=[], + parameterization="desc.geometry.surface.FourierRZToroidalSurface", + basis="basis", +) +def _e_rho_rr_FourierRZToroidalSurface(params, transforms, profiles, data, **kwargs): + coords = jnp.zeros((transforms["grid"].num_nodes, 3)) + data["e_rho_rr"] = coords + return data + + @register_compute_fun( name="e_rho_t", label="\\partial_{\\theta} \\mathbf{e}_{\\rho}", @@ -210,6 +234,30 @@ def _e_theta_r_FourierRZToroidalSurface(params, transforms, profiles, data, **kw return data +@register_compute_fun( + name="e_theta_rr", + label="\\partial_{\\rho \\rho} \\mathbf{e}_{\\theta}", + units="m", + units_long="meters", + description="Covariant poloidal basis vector," + " second derivative wrt radial coordinate", + dim=3, + params=[], + transforms={ + "grid": [], + }, + profiles=[], + coordinates="tz", + data=[], + parameterization="desc.geometry.surface.FourierRZToroidalSurface", + basis="basis", +) +def _e_theta_rr_FourierRZToroidalSurface(params, transforms, profiles, data, **kwargs): + coords = jnp.zeros((transforms["grid"].num_nodes, 3)) + data["e_theta_rr"] = coords + return data + + @register_compute_fun( name="e_theta_t", label="\\partial_{\\theta} \\mathbf{e}_{\\theta}", @@ -293,6 +341,30 @@ def _e_zeta_r_FourierRZToroidalSurface(params, transforms, profiles, data, **kwa return data +@register_compute_fun( + name="e_zeta_rr", + label="\\partial_{\\rho \\rho} \\mathbf{e}_{\\zeta}", + units="m", + units_long="meters", + description="Covariant toroidal basis vector," + " second derivative wrt radial coordinate", + dim=3, + params=[], + transforms={ + "grid": [], + }, + profiles=[], + coordinates="tz", + data=[], + parameterization="desc.geometry.surface.FourierRZToroidalSurface", + basis="basis", +) +def _e_zeta_rr_FourierRZToroidalSurface(params, transforms, profiles, data, **kwargs): + coords = jnp.zeros((transforms["grid"].num_nodes, 3)) + data["e_zeta_rr"] = coords + return data + + @register_compute_fun( name="e_zeta_t", label="\\partial_{\\theta} \\mathbf{e}_{\\zeta}", @@ -498,6 +570,37 @@ def _e_rho_r_ZernikeRZToroidalSection(params, transforms, profiles, data, **kwar return data +@register_compute_fun( + name="e_rho_rr", + label="\\partial_{\\rho \\rho} \\mathbf{e}_{\\rho}", + units="m", + units_long="meters", + description="Covariant radial basis vector," + " second derivative wrt radial coordinate", + dim=3, + params=["R_lmn", "Z_lmn"], + transforms={ + "R": [[3, 0, 0]], + "Z": [[3, 0, 0]], + "grid": [], + }, + profiles=[], + coordinates="rt", + data=[], + parameterization="desc.geometry.surface.ZernikeRZToroidalSection", + basis="basis", +) +def _e_rho_rr_ZernikeRZToroidalSection(params, transforms, profiles, data, **kwargs): + R = transforms["R"].transform(params["R_lmn"], dr=3) + Z = transforms["Z"].transform(params["Z_lmn"], dr=3) + phi = jnp.zeros(transforms["grid"].num_nodes) + coords = jnp.stack([R, phi, Z], axis=1) + if kwargs.get("basis", "rpz").lower() == "xyz": + coords = rpz2xyz(coords) + data["e_rho_rr"] = coords + return data + + @register_compute_fun( name="e_rho_t", label="\\partial_{\\theta} \\mathbf{e}_{\\rho}", @@ -581,6 +684,37 @@ def _e_theta_r_ZernikeRZToroidalSection(params, transforms, profiles, data, **kw return data +@register_compute_fun( + name="e_theta_rr", + label="\\partial_{\\rho \\rho} \\mathbf{e}_{\\theta}", + units="m", + units_long="meters", + description="Covariant poloidal basis vector," + " second derivative wrt radial coordinate", + dim=3, + params=["R_lmn", "Z_lmn"], + transforms={ + "R": [[2, 1, 0]], + "Z": [[2, 1, 0]], + "grid": [], + }, + profiles=[], + coordinates="rt", + data=[], + parameterization="desc.geometry.surface.ZernikeRZToroidalSection", + basis="basis", +) +def _e_theta_rr_ZernikeRZToroidalSection(params, transforms, profiles, data, **kwargs): + R = transforms["R"].transform(params["R_lmn"], dr=2, dt=1) + Z = transforms["Z"].transform(params["Z_lmn"], dr=2, dt=1) + phi = jnp.zeros(transforms["grid"].num_nodes) + coords = jnp.stack([R, phi, Z], axis=1) + if kwargs.get("basis", "rpz").lower() == "xyz": + coords = rpz2xyz(coords) + data["e_theta_rr"] = coords + return data + + @register_compute_fun( name="e_theta_t", label="\\partial_{\\theta} \\mathbf{e}_{\\theta}", @@ -657,6 +791,30 @@ def _e_zeta_r_ZernikeRZToroidalSection(params, transforms, profiles, data, **kwa return data +@register_compute_fun( + name="e_zeta_rr", + label="\\partial_{\\rho \\rho} \\mathbf{e}_{\\zeta}", + units="m", + units_long="meters", + description="Covariant toroidal basis vector," + " second derivative wrt radial coordinate", + dim=3, + params=[], + transforms={ + "grid": [], + }, + profiles=[], + coordinates="rt", + data=[], + parameterization="desc.geometry.surface.ZernikeRZToroidalSection", + basis="basis", +) +def _e_zeta_rr_ZernikeRZToroidalSection(params, transforms, profiles, data, **kwargs): + coords = jnp.zeros((transforms["grid"].num_nodes, 3)) + data["e_zeta_rr"] = coords + return data + + @register_compute_fun( name="e_zeta_t", label="\\partial_{\\theta} \\mathbf{e}_{\\zeta}", diff --git a/desc/compute/data_index.py b/desc/compute/data_index.py index 5d595326f..59a9025e1 100644 --- a/desc/compute/data_index.py +++ b/desc/compute/data_index.py @@ -46,7 +46,7 @@ def register_compute_fun( a flux function, etc. data : list of str Names of other items in the data index needed to compute qty. - parameterization: str + parameterization: str or list of str Name of desc types the method is valid for. eg 'desc.geometry.FourierXYZCurve' or `desc.equilibrium.Equilibrium`. axis_limit_data : list of str diff --git a/tests/test_axis_limits.py b/tests/test_axis_limits.py index f4b99718a..5e1512e79 100644 --- a/tests/test_axis_limits.py +++ b/tests/test_axis_limits.py @@ -7,6 +7,7 @@ import desc.compute from desc.compute import data_index +from desc.compute.data_index import _class_inheritance from desc.compute.utils import surface_integrals_map from desc.equilibrium import Equilibrium from desc.examples import get @@ -16,7 +17,7 @@ # made to compute the magnetic axis limit can be reduced to assuming that these # functions tend toward zero as the magnetic axis is approached and that # d^2šœ“/(dšœŒ)^2 and šœ•āˆšš‘”/šœ•šœŒ are both finite nonzero at the magnetic axis. -# Also d^nšœ“/(dšœŒ)^n for n > 3 is assumed zero everywhere. +# Also, d^nšœ“/(dšœŒ)^n for n > 3 is assumed zero everywhere. zero_limits = {"rho", "psi", "psi_r", "e_theta", "sqrt(g)", "B_t"} not_finite_limits = { @@ -216,10 +217,23 @@ def get_matches(fun, pattern): # attempt to remove comments src = "\n".join(line.partition("#")[0] for line in src.splitlines()) matches = pattern.findall(src) - matches = {s.replace("'", "").replace('"', "") for s in matches} + matches = {s.strip().strip('"') for s in matches} return matches +def get_parameterization(fun, default="desc.equilibrium.equilibrium.Equilibrium"): + """Get parameterization of thing computed by function ``fun``.""" + pattern = re.compile(r'parameterization=(?:\[([^]]+)]|"([^"]+)")') + decorator = inspect.getsource(fun).partition("def ")[0] + matches = pattern.findall(decorator) + # if list was found, split strings in list, else string was found so just get that + matches = [match[0].split(",") if match[0] else [match[1]] for match in matches] + # flatten the list + matches = {s.strip().strip('"') for sublist in matches for s in sublist} + matches.discard("") + return matches if matches else {default} + + class TestAxisLimits: """Tests for compute functions evaluated at limits.""" @@ -227,41 +241,55 @@ class TestAxisLimits: def test_data_index_deps(self): """Ensure developers do not add extra (or forget needed) dependencies.""" queried_deps = {} - pattern_keys = re.compile(r"(?