Skip to content

Commit

Permalink
Avoid casting dict to array
Browse files Browse the repository at this point in the history
  • Loading branch information
f0uriest committed Oct 13, 2023
1 parent ec3abc7 commit f756622
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
6 changes: 3 additions & 3 deletions desc/compute/_surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ def _Phi_CurrentPotentialField(params, transforms, profiles, data, **kwargs):
data["Phi"] = transforms["potential"](
transforms["grid"].nodes[:, 1],
transforms["grid"].nodes[:, 2],
**params["params"][0]
**params["params"]
)
return data

Expand All @@ -600,7 +600,7 @@ def _Phi_t_CurrentPotentialField(params, transforms, profiles, data, **kwargs):
data["Phi_t"] = transforms["potential_dtheta"](
transforms["grid"].nodes[:, 1],
transforms["grid"].nodes[:, 2],
**params["params"][0]
**params["params"]
)
return data

Expand All @@ -623,7 +623,7 @@ def _Phi_z_CurrentPotentialField(params, transforms, profiles, data, **kwargs):
data["Phi_z"] = transforms["potential_dzeta"](
transforms["grid"].nodes[:, 1],
transforms["grid"].nodes[:, 2],
**params["params"][0]
**params["params"]
)
return data

Expand Down
10 changes: 8 additions & 2 deletions desc/compute/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,14 @@ def get_params(keys, obj, has_axis=False, **kwargs):
params = _sort_args(list(set(params)))
if isinstance(obj, str) or inspect.isclass(obj):
return params
params = {name: np.atleast_1d(getattr(obj, name)).copy() for name in params}
return params
temp_params = {}
for name in params:
p = getattr(obj, name)
if isinstance(p, dict):
temp_params[name] = p.copy()
else:
temp_params[name] = jnp.atleast_1d(p)
return temp_params


def get_transforms(keys, obj, grid, jitable=False, **kwargs):
Expand Down

0 comments on commit f756622

Please sign in to comment.