diff --git a/desc/equilibrium/equilibrium.py b/desc/equilibrium/equilibrium.py index 7e6faadc0c..b32ae3b0fa 100644 --- a/desc/equilibrium/equilibrium.py +++ b/desc/equilibrium/equilibrium.py @@ -841,7 +841,13 @@ def compute( # we first figure out what needed qtys are flux functions or volume integrals # and compute those first on a full grid p = "desc.equilibrium.equilibrium.Equilibrium" - deps = list(set(get_data_deps(names, obj=p, has_axis=grid.axis.size) + names)) + # If the user wants to compute x which depends on y which in turn depends on z, + # and they pass in y already computed in data, then we shouldn't need to compute + # z at all. + deps = list( + set(get_data_deps(names, obj=p, has_axis=grid.axis.size) + names) + - data.keys() # subtract out y if already computed + ) # TODO: replace this logic with `grid_type` from data_index dep0d = [ dep @@ -877,19 +883,31 @@ def compute( if calc0d and override_grid: grid0d = QuadratureGrid(self.L_grid, self.M_grid, self.N_grid, self.NFP) + data0d_seed = { + key: data[key] + for key in data + if data_index[p][key]["coordinates"] == "" + } data0d = compute_fun( self, dep0d, params=params, transforms=get_transforms(dep0d, obj=self, grid=grid0d, **kwargs), profiles=get_profiles(dep0d, obj=self, grid=grid0d), - data=None, + # If a dependency of something is already computed, use it + # instead of recomputing it on a potentially bad grid. + data=data0d_seed, **kwargs, ) # these should all be 0d quantities so don't need to compress/expand data0d = {key: val for key, val in data0d.items() if key in dep0d} data.update(data0d) + data0d_seed = ( + {key: data[key] for key in data if data_index[p][key]["coordinates"] == ""} + if ((calc1dr or calc1dz) and override_grid) + else {} + ) if calc1dr and override_grid: grid1dr = LinearGrid( rho=grid.nodes[grid.unique_rho_idx, 0], @@ -898,15 +916,20 @@ def compute( NFP=self.NFP, sym=self.sym, ) - # TODO: Pass in data0d as a seed once there are 1d quantities that - # depend on 0d quantities in data_index. + data1dr_seed = { + key: grid1dr.copy_data_from_other(data[key], grid, surface_label="rho") + for key in data + if data_index[p][key]["coordinates"] == "r" + } data1dr = compute_fun( self, dep1dr, params=params, transforms=get_transforms(dep1dr, obj=self, grid=grid1dr, **kwargs), profiles=get_profiles(dep1dr, obj=self, grid=grid1dr), - data=None, + # If a dependency of something is already computed, use it + # instead of recomputing it on a potentially bad grid. + data=data1dr_seed | data0d_seed, **kwargs, ) # need to make this data broadcast with the data on the original grid @@ -925,15 +948,20 @@ def compute( NFP=grid.NFP, # ex: self.NFP>1 but grid.NFP=1 for plot_3d sym=self.sym, ) - # TODO: Pass in data0d as a seed once there are 1d quantities that - # depend on 0d quantities in data_index. + data1dz_seed = { + key: grid1dz.copy_data_from_other(data[key], grid, surface_label="zeta") + for key in data + if data_index[p][key]["coordinates"] == "z" + } data1dz = compute_fun( self, dep1dz, params=params, transforms=get_transforms(dep1dz, obj=self, grid=grid1dz, **kwargs), profiles=get_profiles(dep1dz, obj=self, grid=grid1dz), - data=None, + # If a dependency of something is already computed, use it + # instead of recomputing it on a potentially bad grid. + data=data1dz_seed | data0d_seed, **kwargs, ) # need to make this data broadcast with the data on the original grid diff --git a/tests/inputs/master_compute_data.pkl b/tests/inputs/master_compute_data.pkl index 4e72ac112d..e5ca71d53f 100644 Binary files a/tests/inputs/master_compute_data.pkl and b/tests/inputs/master_compute_data.pkl differ diff --git a/tests/test_axis_limits.py b/tests/test_axis_limits.py index 1316192c8c..9e01fcf0b5 100644 --- a/tests/test_axis_limits.py +++ b/tests/test_axis_limits.py @@ -306,8 +306,7 @@ def test_limit_continuity(self): "alpha_r": {"rtol": 1e-3}, } zero_map = dict.fromkeys(zero_limits, {"desired_at_axis": 0}) - # same as 'weaker_tolerance | zero_limit', but works on Python 3.8 (PEP 584) - kwargs = dict(weaker_tolerance, **zero_map) + kwargs = weaker_tolerance | zero_map # fixed iota eq = get("W7-X") eq.change_resolution(4, 4, 4, 8, 8, 8)