diff --git a/desc/grid.py b/desc/grid.py index 06ce1329ae..ee471e5d1b 100644 --- a/desc/grid.py +++ b/desc/grid.py @@ -216,7 +216,7 @@ def is_meshgrid(self): Let the tuple (r, p, t) ∈ R³ denote a radial, poloidal, and toroidal coordinate value. The is_meshgrid flag denotes whether any coordinate can be iterated over along the relevant axis of the reshaped grid: - nodes.reshape(num_radial, num_poloidal, num_toroidal, 3). + nodes.reshape((num_poloidal, num_radial, num_toroidal, 3), order="F"). """ return self.__dict__.setdefault("_is_meshgrid", False) @@ -598,6 +598,52 @@ def replace_at_axis(self, x, y, copy=False, **kwargs): ) return x + def meshgrid_reshape(self, x, order): + """Reshape data to match grid coordinates. + + Given flattened data on a tensor product grid, reshape the data such that + the axes of the array correspond to coordinate values on the grid. + + Parameters + ---------- + x : ndarray, shape(N,) or shape(N,3) + Data to reshape. + order : str + Desired order of axes for returned data. Should be a permutation of + ``grid.coordinates``, eg ``order="rtz"`` has the first axis of the returned + data correspond to different rho coordinates, the second axis to different + theta, etc. ``order="trz"`` would have the first axis correspond to theta, + and so on. + + Returns + ------- + x : ndarray + Data reshaped to align with grid nodes. + """ + errorif( + not self.is_meshgrid, + ValueError, + "grid is not a tensor product grid, so meshgrid_reshape doesn't " + "make any sense", + ) + errorif( + sorted(order) != sorted(self.coordinates), + ValueError, + f"order should be a permutation of {self.coordinates}, got {order}", + ) + shape = (self.num_poloidal, self.num_rho, self.num_zeta) + vec = False + if x.ndim > 1: + vec = True + shape += (-1,) + x = x.reshape(shape, order="F") + x = jnp.moveaxis(x, 1, 0) # now shape rtz/raz etc + newax = tuple(self.coordinates.index(c) for c in order) + if vec: + newax += (3,) + x = jnp.transpose(x, newax) + return x + class Grid(_Grid): """Collocation grid with custom node placement. @@ -632,7 +678,7 @@ class Grid(_Grid): Let the tuple (r, p, t) ∈ R³ denote a radial, poloidal, and toroidal coordinate value. The is_meshgrid flag denotes whether any coordinate can be iterated over along the relevant axis of the reshaped grid: - nodes.reshape(num_radial, num_poloidal, num_toroidal, 3). + nodes.reshape((num_poloidal, num_radial, num_toroidal, 3), order="F"). jitable : bool Whether to skip certain checks and conditionals that don't work under jit. Allows grid to be created on the fly with custom nodes, but weights, symmetry @@ -762,11 +808,16 @@ def create_meshgrid( dc = _periodic_spacing(c, period[2])[1] * NFP else: da, db, dc = spacing + + bb, aa, cc = jnp.meshgrid(b, a, c, indexing="ij") + nodes = jnp.column_stack( - list(map(jnp.ravel, jnp.meshgrid(a, b, c, indexing="ij"))) + [aa.flatten(order="F"), bb.flatten(order="F"), cc.flatten(order="F")] ) + bb, aa, cc = jnp.meshgrid(db, da, dc, indexing="ij") + spacing = jnp.column_stack( - list(map(jnp.ravel, jnp.meshgrid(da, db, dc, indexing="ij"))) + [aa.flatten(order="F"), bb.flatten(order="F"), cc.flatten(order="F")] ) weights = ( spacing.prod(axis=1) @@ -776,19 +827,18 @@ def create_meshgrid( else None ) - unique_a_idx = jnp.arange(a.size) * b.size * c.size - unique_b_idx = jnp.arange(b.size) * c.size - unique_c_idx = jnp.arange(c.size) - inverse_a_idx = repeat( - unique_a_idx // (b.size * c.size), - b.size * c.size, - total_repeat_length=a.size * b.size * c.size, + unique_a_idx = jnp.arange(a.size) * b.size + unique_b_idx = jnp.arange(b.size) + unique_c_idx = jnp.arange(c.size) * a.size * b.size + inverse_a_idx = jnp.tile( + repeat(unique_a_idx // b.size, b.size, total_repeat_length=a.size * b.size), + c.size, ) inverse_b_idx = jnp.tile( - repeat(unique_b_idx // c.size, c.size, total_repeat_length=b.size * c.size), - a.size, + unique_b_idx, + a.size * c.size, ) - inverse_c_idx = jnp.tile(unique_c_idx, a.size * b.size) + inverse_c_idx = repeat(unique_c_idx // (a.size * b.size), (a.size * b.size)) return Grid( nodes=nodes, spacing=spacing, @@ -908,6 +958,7 @@ def __init__( self._toroidal_endpoint = False self._node_pattern = "linear" self._coordinates = "rtz" + self._is_meshgrid = True self._period = (np.inf, 2 * np.pi, 2 * np.pi / self._NFP) self._nodes, self._spacing = self._create_nodes( L=L, @@ -1200,6 +1251,7 @@ def __init__(self, L, M, N, NFP=1): self._sym = False self._node_pattern = "quad" self._coordinates = "rtz" + self._is_meshgrid = True self._period = (np.inf, 2 * np.pi, 2 * np.pi / self._NFP) self._nodes, self._spacing = self._create_nodes(L=L, M=M, N=N, NFP=NFP) # symmetry is never enforced for Quadrature Grid @@ -1341,6 +1393,7 @@ def __init__(self, L, M, N, NFP=1, sym=False, axis=False, node_pattern="jacobi") self._sym = sym self._node_pattern = node_pattern self._coordinates = "rtz" + self._is_meshgrid = False self._period = (np.inf, 2 * np.pi, 2 * np.pi / self._NFP) self._nodes, self._spacing = self._create_nodes( L=L, M=M, N=N, NFP=NFP, axis=axis, node_pattern=node_pattern diff --git a/tests/test_grid.py b/tests/test_grid.py index 9d298b8c85..160c6aac9c 100644 --- a/tests/test_grid.py +++ b/tests/test_grid.py @@ -738,20 +738,81 @@ def test_meshgrid(self): """Test meshgrid constructor.""" R = np.linspace(0, 1, 4) A = np.linspace(0, 2 * np.pi, 2) - Z = np.linspace(0, 10 * np.pi, 3) + Z = np.linspace(0, 2 * np.pi, 3) grid = Grid.create_meshgrid( - [R, A, Z], coordinates="raz", period=(np.inf, 2 * np.pi, np.inf) + [R, A, Z], coordinates="raz", period=(np.inf, 2 * np.pi, 2 * np.pi) ) + # treating theta == alpha just for grid construction + grid1 = LinearGrid(rho=R, theta=A, zeta=Z) + # atol=1e-12 bc Grid by default shifts points away from the axis a tiny bit + np.testing.assert_allclose(grid1.nodes, grid.nodes, atol=1e-12) + # want radial/poloidal/toroidal nodes sorted in the same order for both + np.testing.assert_allclose(grid1.unique_rho_idx, grid.unique_rho_idx) + np.testing.assert_allclose(grid1.unique_theta_idx, grid.unique_alpha_idx) + np.testing.assert_allclose(grid1.unique_zeta_idx, grid.unique_zeta_idx) + np.testing.assert_allclose(grid1.inverse_rho_idx, grid.inverse_rho_idx) + np.testing.assert_allclose(grid1.inverse_theta_idx, grid.inverse_alpha_idx) + np.testing.assert_allclose(grid1.inverse_zeta_idx, grid.inverse_zeta_idx) + + @pytest.mark.unit + def test_meshgrid_reshape(self): + """Test that reshaping meshgrids works correctly.""" + grid = LinearGrid(2, 3, 4) + + r = grid.nodes[grid.unique_rho_idx, 0] + t = grid.nodes[grid.unique_theta_idx, 1] + z = grid.nodes[grid.unique_zeta_idx, 2] + + # user regular allclose for broadcasting to work correctly + # reshaping rtz should have rho along first axis + assert np.allclose( + grid.meshgrid_reshape(grid.nodes[:, 0], "rtz"), r[:, None, None] + ) + # reshaping rzt should have theta along last axis + assert np.allclose( + grid.meshgrid_reshape(grid.nodes[:, 1], "rzt"), t[None, None, :] + ) + # reshaping tzr should have zeta along 2nd axis + assert np.allclose( + grid.meshgrid_reshape(grid.nodes, "tzr")[:, :, :, 2], z[None, :, None] + ) + + # coordinates are rtz, not raz + with pytest.raises(ValueError): + grid.meshgrid_reshape(grid.nodes[:, 0], "raz") + + # not a meshgrid + grid = ConcentricGrid(2, 3, 4) + with pytest.raises(ValueError): + grid.meshgrid_reshape(grid.nodes[:, 0], "rtz") + + rho = np.linspace(0, 1, 3) + alpha = np.linspace(0, 2 * np.pi, 4) + zeta = np.linspace(0, 6 * np.pi, 5) + grid = Grid.create_meshgrid([rho, alpha, zeta], coordinates="raz") r, a, z = grid.nodes.T - _, unique, inverse = np.unique(r, return_index=True, return_inverse=True) - np.testing.assert_allclose(grid.unique_rho_idx, unique) - np.testing.assert_allclose(grid.inverse_rho_idx, inverse) - _, unique, inverse = np.unique(a, return_index=True, return_inverse=True) - np.testing.assert_allclose(grid.unique_alpha_idx, unique) - np.testing.assert_allclose(grid.inverse_alpha_idx, inverse) - _, unique, inverse = np.unique(z, return_index=True, return_inverse=True) - np.testing.assert_allclose(grid.unique_zeta_idx, unique) - np.testing.assert_allclose(grid.inverse_zeta_idx, inverse) + r = grid.meshgrid_reshape(r, "raz") + a = grid.meshgrid_reshape(a, "raz") + z = grid.meshgrid_reshape(z, "raz") + # functions of zeta should separate along first two axes + # since those are contiguous, this should work + f = z.reshape(-1, zeta.size) + for i in range(1, f.shape[0]): + np.testing.assert_allclose(f[i - 1], f[i]) + # likewise for rho + f = r.reshape(rho.size, -1) + for i in range(1, f.shape[-1]): + np.testing.assert_allclose(f[:, i - 1], f[:, i]) + # test reshaping result won't mix data + f = (a**2 + z).reshape(rho.size, alpha.size, zeta.size) + for i in range(1, f.shape[0]): + np.testing.assert_allclose(f[i - 1], f[i]) + f = (r**2 + z).reshape(rho.size, alpha.size, zeta.size) + for i in range(1, f.shape[1]): + np.testing.assert_allclose(f[:, i - 1], f[:, i]) + f = (r**2 + a).reshape(rho.size, alpha.size, zeta.size) + for i in range(1, f.shape[-1]): + np.testing.assert_allclose(f[..., i - 1], f[..., i]) @pytest.mark.unit