Skip to content

Commit

Permalink
Merge branch 'master' into yge/print
Browse files Browse the repository at this point in the history
  • Loading branch information
YigitElma authored Aug 22, 2024
2 parents 7bd3df7 + 1c076fc commit 9ebb660
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 25 deletions.
81 changes: 67 additions & 14 deletions desc/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def is_meshgrid(self):
Let the tuple (r, p, t) ∈ R³ denote a radial, poloidal, and toroidal
coordinate value. The is_meshgrid flag denotes whether any coordinate
can be iterated over along the relevant axis of the reshaped grid:
nodes.reshape(num_radial, num_poloidal, num_toroidal, 3).
nodes.reshape((num_poloidal, num_radial, num_toroidal, 3), order="F").
"""
return self.__dict__.setdefault("_is_meshgrid", False)

Expand Down Expand Up @@ -598,6 +598,52 @@ def replace_at_axis(self, x, y, copy=False, **kwargs):
)
return x

def meshgrid_reshape(self, x, order):
"""Reshape data to match grid coordinates.
Given flattened data on a tensor product grid, reshape the data such that
the axes of the array correspond to coordinate values on the grid.
Parameters
----------
x : ndarray, shape(N,) or shape(N,3)
Data to reshape.
order : str
Desired order of axes for returned data. Should be a permutation of
``grid.coordinates``, eg ``order="rtz"`` has the first axis of the returned
data correspond to different rho coordinates, the second axis to different
theta, etc. ``order="trz"`` would have the first axis correspond to theta,
and so on.
Returns
-------
x : ndarray
Data reshaped to align with grid nodes.
"""
errorif(
not self.is_meshgrid,
ValueError,
"grid is not a tensor product grid, so meshgrid_reshape doesn't "
"make any sense",
)
errorif(
sorted(order) != sorted(self.coordinates),
ValueError,
f"order should be a permutation of {self.coordinates}, got {order}",
)
shape = (self.num_poloidal, self.num_rho, self.num_zeta)
vec = False
if x.ndim > 1:
vec = True
shape += (-1,)
x = x.reshape(shape, order="F")
x = jnp.moveaxis(x, 1, 0) # now shape rtz/raz etc
newax = tuple(self.coordinates.index(c) for c in order)
if vec:
newax += (3,)
x = jnp.transpose(x, newax)
return x


class Grid(_Grid):
"""Collocation grid with custom node placement.
Expand Down Expand Up @@ -632,7 +678,7 @@ class Grid(_Grid):
Let the tuple (r, p, t) ∈ R³ denote a radial, poloidal, and toroidal
coordinate value. The is_meshgrid flag denotes whether any coordinate
can be iterated over along the relevant axis of the reshaped grid:
nodes.reshape(num_radial, num_poloidal, num_toroidal, 3).
nodes.reshape((num_poloidal, num_radial, num_toroidal, 3), order="F").
jitable : bool
Whether to skip certain checks and conditionals that don't work under jit.
Allows grid to be created on the fly with custom nodes, but weights, symmetry
Expand Down Expand Up @@ -762,11 +808,16 @@ def create_meshgrid(
dc = _periodic_spacing(c, period[2])[1] * NFP
else:
da, db, dc = spacing

bb, aa, cc = jnp.meshgrid(b, a, c, indexing="ij")

nodes = jnp.column_stack(
list(map(jnp.ravel, jnp.meshgrid(a, b, c, indexing="ij")))
[aa.flatten(order="F"), bb.flatten(order="F"), cc.flatten(order="F")]
)
bb, aa, cc = jnp.meshgrid(db, da, dc, indexing="ij")

spacing = jnp.column_stack(
list(map(jnp.ravel, jnp.meshgrid(da, db, dc, indexing="ij")))
[aa.flatten(order="F"), bb.flatten(order="F"), cc.flatten(order="F")]
)
weights = (
spacing.prod(axis=1)
Expand All @@ -776,19 +827,18 @@ def create_meshgrid(
else None
)

unique_a_idx = jnp.arange(a.size) * b.size * c.size
unique_b_idx = jnp.arange(b.size) * c.size
unique_c_idx = jnp.arange(c.size)
inverse_a_idx = repeat(
unique_a_idx // (b.size * c.size),
b.size * c.size,
total_repeat_length=a.size * b.size * c.size,
unique_a_idx = jnp.arange(a.size) * b.size
unique_b_idx = jnp.arange(b.size)
unique_c_idx = jnp.arange(c.size) * a.size * b.size
inverse_a_idx = jnp.tile(
repeat(unique_a_idx // b.size, b.size, total_repeat_length=a.size * b.size),
c.size,
)
inverse_b_idx = jnp.tile(
repeat(unique_b_idx // c.size, c.size, total_repeat_length=b.size * c.size),
a.size,
unique_b_idx,
a.size * c.size,
)
inverse_c_idx = jnp.tile(unique_c_idx, a.size * b.size)
inverse_c_idx = repeat(unique_c_idx // (a.size * b.size), (a.size * b.size))
return Grid(
nodes=nodes,
spacing=spacing,
Expand Down Expand Up @@ -908,6 +958,7 @@ def __init__(
self._toroidal_endpoint = False
self._node_pattern = "linear"
self._coordinates = "rtz"
self._is_meshgrid = True
self._period = (np.inf, 2 * np.pi, 2 * np.pi / self._NFP)
self._nodes, self._spacing = self._create_nodes(
L=L,
Expand Down Expand Up @@ -1200,6 +1251,7 @@ def __init__(self, L, M, N, NFP=1):
self._sym = False
self._node_pattern = "quad"
self._coordinates = "rtz"
self._is_meshgrid = True
self._period = (np.inf, 2 * np.pi, 2 * np.pi / self._NFP)
self._nodes, self._spacing = self._create_nodes(L=L, M=M, N=N, NFP=NFP)
# symmetry is never enforced for Quadrature Grid
Expand Down Expand Up @@ -1341,6 +1393,7 @@ def __init__(self, L, M, N, NFP=1, sym=False, axis=False, node_pattern="jacobi")
self._sym = sym
self._node_pattern = node_pattern
self._coordinates = "rtz"
self._is_meshgrid = False
self._period = (np.inf, 2 * np.pi, 2 * np.pi / self._NFP)
self._nodes, self._spacing = self._create_nodes(
L=L, M=M, N=N, NFP=NFP, axis=axis, node_pattern=node_pattern
Expand Down
83 changes: 72 additions & 11 deletions tests/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,20 +738,81 @@ def test_meshgrid(self):
"""Test meshgrid constructor."""
R = np.linspace(0, 1, 4)
A = np.linspace(0, 2 * np.pi, 2)
Z = np.linspace(0, 10 * np.pi, 3)
Z = np.linspace(0, 2 * np.pi, 3)
grid = Grid.create_meshgrid(
[R, A, Z], coordinates="raz", period=(np.inf, 2 * np.pi, np.inf)
[R, A, Z], coordinates="raz", period=(np.inf, 2 * np.pi, 2 * np.pi)
)
# treating theta == alpha just for grid construction
grid1 = LinearGrid(rho=R, theta=A, zeta=Z)
# atol=1e-12 bc Grid by default shifts points away from the axis a tiny bit
np.testing.assert_allclose(grid1.nodes, grid.nodes, atol=1e-12)
# want radial/poloidal/toroidal nodes sorted in the same order for both
np.testing.assert_allclose(grid1.unique_rho_idx, grid.unique_rho_idx)
np.testing.assert_allclose(grid1.unique_theta_idx, grid.unique_alpha_idx)
np.testing.assert_allclose(grid1.unique_zeta_idx, grid.unique_zeta_idx)
np.testing.assert_allclose(grid1.inverse_rho_idx, grid.inverse_rho_idx)
np.testing.assert_allclose(grid1.inverse_theta_idx, grid.inverse_alpha_idx)
np.testing.assert_allclose(grid1.inverse_zeta_idx, grid.inverse_zeta_idx)

@pytest.mark.unit
def test_meshgrid_reshape(self):
"""Test that reshaping meshgrids works correctly."""
grid = LinearGrid(2, 3, 4)

r = grid.nodes[grid.unique_rho_idx, 0]
t = grid.nodes[grid.unique_theta_idx, 1]
z = grid.nodes[grid.unique_zeta_idx, 2]

# user regular allclose for broadcasting to work correctly
# reshaping rtz should have rho along first axis
assert np.allclose(
grid.meshgrid_reshape(grid.nodes[:, 0], "rtz"), r[:, None, None]
)
# reshaping rzt should have theta along last axis
assert np.allclose(
grid.meshgrid_reshape(grid.nodes[:, 1], "rzt"), t[None, None, :]
)
# reshaping tzr should have zeta along 2nd axis
assert np.allclose(
grid.meshgrid_reshape(grid.nodes, "tzr")[:, :, :, 2], z[None, :, None]
)

# coordinates are rtz, not raz
with pytest.raises(ValueError):
grid.meshgrid_reshape(grid.nodes[:, 0], "raz")

# not a meshgrid
grid = ConcentricGrid(2, 3, 4)
with pytest.raises(ValueError):
grid.meshgrid_reshape(grid.nodes[:, 0], "rtz")

rho = np.linspace(0, 1, 3)
alpha = np.linspace(0, 2 * np.pi, 4)
zeta = np.linspace(0, 6 * np.pi, 5)
grid = Grid.create_meshgrid([rho, alpha, zeta], coordinates="raz")
r, a, z = grid.nodes.T
_, unique, inverse = np.unique(r, return_index=True, return_inverse=True)
np.testing.assert_allclose(grid.unique_rho_idx, unique)
np.testing.assert_allclose(grid.inverse_rho_idx, inverse)
_, unique, inverse = np.unique(a, return_index=True, return_inverse=True)
np.testing.assert_allclose(grid.unique_alpha_idx, unique)
np.testing.assert_allclose(grid.inverse_alpha_idx, inverse)
_, unique, inverse = np.unique(z, return_index=True, return_inverse=True)
np.testing.assert_allclose(grid.unique_zeta_idx, unique)
np.testing.assert_allclose(grid.inverse_zeta_idx, inverse)
r = grid.meshgrid_reshape(r, "raz")
a = grid.meshgrid_reshape(a, "raz")
z = grid.meshgrid_reshape(z, "raz")
# functions of zeta should separate along first two axes
# since those are contiguous, this should work
f = z.reshape(-1, zeta.size)
for i in range(1, f.shape[0]):
np.testing.assert_allclose(f[i - 1], f[i])
# likewise for rho
f = r.reshape(rho.size, -1)
for i in range(1, f.shape[-1]):
np.testing.assert_allclose(f[:, i - 1], f[:, i])
# test reshaping result won't mix data
f = (a**2 + z).reshape(rho.size, alpha.size, zeta.size)
for i in range(1, f.shape[0]):
np.testing.assert_allclose(f[i - 1], f[i])
f = (r**2 + z).reshape(rho.size, alpha.size, zeta.size)
for i in range(1, f.shape[1]):
np.testing.assert_allclose(f[:, i - 1], f[:, i])
f = (r**2 + a).reshape(rho.size, alpha.size, zeta.size)
for i in range(1, f.shape[-1]):
np.testing.assert_allclose(f[..., i - 1], f[..., i])


@pytest.mark.unit
Expand Down

0 comments on commit 9ebb660

Please sign in to comment.