Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug with recomputing quantities on incorrect grid #1006

Merged
merged 9 commits into from
May 2, 2024
31 changes: 27 additions & 4 deletions desc/equilibrium/equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,13 @@ def compute(
params=params,
transforms=get_transforms(dep0d, obj=self, grid=grid0d, **kwargs),
profiles=get_profiles(dep0d, obj=self, grid=grid0d),
data=None,
# If a dependency of something is already computed, use it
# instead of recomputing it on a potentially bad grid.
data={
key: data[key]
for key in data
if data_index[p][key]["coordinates"] == ""
},
**kwargs,
)
# these should all be 0d quantities so don't need to compress/expand
Expand All @@ -899,14 +905,22 @@ def compute(
sym=self.sym,
)
# TODO: Pass in data0d as a seed once there are 1d quantities that
# depend on 0d quantities in data_index.
# depend on 0d quantities in data_index.
data1dr = compute_fun(
self,
dep1dr,
params=params,
transforms=get_transforms(dep1dr, obj=self, grid=grid1dr, **kwargs),
profiles=get_profiles(dep1dr, obj=self, grid=grid1dr),
data=None,
# If a dependency of something is already computed, use it
# instead of recomputing it on a potentially bad grid.
data={
key: grid1dr.copy_data_from_other(
data[key], grid, surface_label="rho"
)
for key in data
if data_index[p][key]["coordinates"] == "r"
},
**kwargs,
)
# need to make this data broadcast with the data on the original grid
Expand All @@ -915,6 +929,7 @@ def compute(
for key, val in data1dr.items()
if key in dep1dr
}

data.update(data1dr)

if calc1dz and override_grid:
Expand All @@ -933,7 +948,15 @@ def compute(
params=params,
transforms=get_transforms(dep1dz, obj=self, grid=grid1dz, **kwargs),
profiles=get_profiles(dep1dz, obj=self, grid=grid1dz),
data=None,
# If a dependency of something is already computed, use it
# instead of recomputing it on a potentially bad grid.
data={
key: grid1dz.copy_data_from_other(
data[key], grid, surface_label="zeta"
)
for key in data
if data_index[p][key]["coordinates"] == "z"
},
**kwargs,
)
# need to make this data broadcast with the data on the original grid
Expand Down
Loading