Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add linear operators computing gradients in non-Cartesian coordinates #536

Merged
merged 111 commits into from
Jul 19, 2024
Merged
Changes from 1 commit
Commits
Show all changes
111 commits
Select commit Hold shift + click to select a range
1903072
Bump version number
bwohlberg Sep 30, 2021
5d3a221
Trivial edit
bwohlberg Sep 30, 2021
86a7c67
Add non-Cartesian gradient linops
bwohlberg Sep 30, 2021
a8b1e77
Bug fix
bwohlberg Sep 30, 2021
9a0d6e9
Merge branch 'main' into brendt/polartv
bwohlberg Oct 7, 2021
c1d1ebe
Merge branch 'main' into brendt/polartv
bwohlberg Oct 9, 2021
8b334a1
Merge branch 'main' into brendt/polartv
bwohlberg Oct 13, 2021
e06dbbd
Merge branch 'main' into brendt/polartv
bwohlberg Oct 18, 2021
51093f5
Merge branch 'main' into brendt/polartv
bwohlberg Oct 18, 2021
6475211
Merge branch 'main' into brendt/polartv
bwohlberg Oct 23, 2021
ca3e345
Merge branch 'main' into brendt/polartv
bwohlberg Oct 27, 2021
1408170
Merge branch 'main' into brendt/polartv
bwohlberg Oct 29, 2021
10daa71
Merge branch 'main' into brendt/polartv
bwohlberg Nov 5, 2021
2c25777
Merge branch 'main' into brendt/polartv
bwohlberg Nov 13, 2021
87ef78d
Merge branch 'main' into brendt/polartv
bwohlberg Nov 17, 2021
f0360f1
Merge branch 'main' into brendt/polartv
bwohlberg Dec 7, 2021
16b6d9e
Style compliance
bwohlberg Dec 7, 2021
ee921ac
Merge branch 'main' into brendt/polartv
bwohlberg Dec 16, 2021
fb31ec2
Merge branch 'main' into brendt/polartv
bwohlberg Dec 21, 2021
158a722
Merge branch 'main' into brendt/polartv
bwohlberg Dec 24, 2021
14f1bd4
Merge branch 'main' into brendt/polartv
bwohlberg Jan 6, 2022
02c9f7d
Merge branch 'main' into brendt/polartv
bwohlberg Feb 8, 2022
587de87
Merge branch 'main' into brendt/polartv
bwohlberg Feb 9, 2022
57b7252
Merge branch 'main' into brendt/polartv
bwohlberg Feb 15, 2022
56e8039
Merge branch 'main' into brendt/polartv
bwohlberg Feb 15, 2022
f5c71bc
Merge branch 'main' into brendt/polartv
bwohlberg Feb 23, 2022
384fd88
Merge branch 'main' into brendt/polartv
bwohlberg Feb 25, 2022
f46326b
Merge branch 'main' into brendt/polartv
bwohlberg Mar 4, 2022
b5e8d14
Clean up
bwohlberg Mar 5, 2022
2e31db2
Minor docstring cleanup
bwohlberg Mar 7, 2022
869274c
Trivial edit
bwohlberg Mar 8, 2022
ea0629e
Merge branch 'main' into brendt/polartv
bwohlberg Mar 8, 2022
e8c1134
Typing fixes
bwohlberg Mar 8, 2022
a953ef2
Clean up logic
bwohlberg Mar 8, 2022
f7fac52
Trivial edits
bwohlberg Mar 8, 2022
bb97851
Add support and corresponding tests for non-increasing axes indices
bwohlberg Mar 8, 2022
5ae5120
Trivial edit
bwohlberg Mar 8, 2022
b360abf
Docstring edits
bwohlberg Mar 8, 2022
5a65c9a
Docs improvements
bwohlberg Mar 8, 2022
4b4167b
Merge branch 'main' into brendt/polartv
bwohlberg Mar 8, 2022
271205d
Update submodule
bwohlberg Mar 22, 2022
41ac860
Move local docs figures directory
bwohlberg Mar 22, 2022
8d6ab32
Merge branch 'main' into brendt/polartv
bwohlberg Mar 22, 2022
3aa24c3
Remove multiple occurrences
bwohlberg Mar 22, 2022
f4134bd
Merge remote-tracking branch 'origin/main' into brendt/polartv
bwohlberg Mar 22, 2022
ac17509
Merge branch 'main' into brendt/polartv
bwohlberg Mar 29, 2022
f1452b7
Merge branch 'main' into brendt/polartv
Apr 5, 2022
3f843bd
Merge branch 'main' into brendt/polartv
bwohlberg Apr 22, 2022
b223a84
Merge branch 'main' into brendt/polartv
bwohlberg May 1, 2022
dfc8646
Merge branch 'main' into brendt/polartv
bwohlberg May 6, 2022
d8cadc4
Merge branch 'main' into brendt/polartv
bwohlberg May 11, 2022
bae466e
Merge branch 'main' into brendt/polartv
bwohlberg May 16, 2022
c327b30
Merge branch 'main' into brendt/polartv
bwohlberg May 24, 2022
977d10b
Merge branch 'main' into brendt/polartv
bwohlberg Jun 1, 2022
38612c9
Merge branch 'main' into brendt/polartv
bwohlberg Aug 4, 2022
93ca97a
Merge branch 'main' into brendt/polartv
bwohlberg Sep 2, 2022
694309b
Merge branch 'main' into brendt/polartv
bwohlberg Sep 14, 2022
b128967
Changes required by new blockarray implementation
bwohlberg Sep 14, 2022
d83106c
Fix merge error
bwohlberg Sep 14, 2022
8e2ad43
Merge branch 'main' into brendt/polartv
bwohlberg Sep 15, 2022
247e927
Merge branch 'main' into brendt/polartv
bwohlberg Sep 23, 2022
47cecbe
Merge branch 'main' into brendt/polartv
bwohlberg Nov 23, 2022
8e71b01
Merge branch 'main' into brendt/polartv
bwohlberg Feb 8, 2023
e16cf3b
Merge branch 'main' into brendt/polartv
bwohlberg Feb 8, 2023
652394c
Update submodule
bwohlberg Feb 8, 2023
796f2d2
Merge branch 'main' into brendt/polartv
bwohlberg Apr 12, 2023
7f2d49e
Merge branch 'main' into brendt/polartv
bwohlberg Apr 21, 2023
56f9e67
Merge branch 'main' into brendt/polartv
bwohlberg May 3, 2023
05ac637
Replace JaxArray with jax.Array and other cleanup
bwohlberg May 3, 2023
cf36c0d
Fix tests
bwohlberg May 3, 2023
78a33f9
Use snp instead of np functions
bwohlberg May 3, 2023
2d496c0
Fix dtype handling when double float enabled
bwohlberg May 3, 2023
734a25e
Merge branch 'main' into brendt/polartv
bwohlberg May 5, 2023
93f842d
Merge branch 'main' into brendt/polartv
bwohlberg May 19, 2023
8ce6281
Add gradient type option
bwohlberg May 19, 2023
095feb3
Polar TV example script
bwohlberg May 19, 2023
b10c0f3
Merge branch 'main' into brendt/polartv
bwohlberg May 26, 2023
4f8fd3d
Merge branch 'main' into brendt/polartv
bwohlberg Jun 19, 2023
2124db7
Restore extension removed in merge
bwohlberg Jun 19, 2023
527d999
Merge branch 'main' into brendt/polartv
bwohlberg Jul 19, 2023
f87f8fc
Merge branch 'main' into brendt/polartv
bwohlberg Jul 20, 2023
daea17b
Merge branch 'main' into brendt/polartv
bwohlberg Aug 4, 2023
66699f6
Merge branch 'main' into brendt/polartv
bwohlberg Nov 4, 2023
584070b
Merge branch 'main' into brendt/polartv
bwohlberg Nov 10, 2023
33bf73e
Merge branch 'main' into brendt/polartv
bwohlberg Nov 16, 2023
ad0dec4
Merge branch 'main' into brendt/polartv
bwohlberg Dec 1, 2023
7f5fa75
Merge branch 'main' into brendt/polartv
bwohlberg Feb 20, 2024
82b249e
Merge branch 'main' into brendt/polartv
bwohlberg Mar 1, 2024
63c6c42
Merge branch 'main' into brendt/polartv
bwohlberg May 7, 2024
9cd5e65
Merge branch 'main' into brendt/polartv
bwohlberg May 28, 2024
4bc07c2
Merge branch 'main' into brendt/polartv
bwohlberg Jun 11, 2024
ae4e91c
Don't include source link in figures
bwohlberg Jun 12, 2024
7d41360
Minor improvement
bwohlberg Jun 12, 2024
9d3679b
Merge branch 'main' into brendt/polartv
bwohlberg Jun 25, 2024
b0fe8f2
Move files to scico-data submodule
bwohlberg Jun 26, 2024
78a364e
Update figure path
bwohlberg Jun 26, 2024
18149d9
Add stub bibtex entry
bwohlberg Jun 26, 2024
d1b68e4
Update submodule
bwohlberg Jun 26, 2024
0e627c7
Clean up example scripts
bwohlberg Jun 26, 2024
ef80011
Add polar tv example to index
bwohlberg Jun 26, 2024
a28c44a
Update generated index files
bwohlberg Jun 26, 2024
2b4a98a
Update submodule
bwohlberg Jun 26, 2024
11cd312
Update change summary
bwohlberg Jun 26, 2024
7095558
Add arxiv reference
bwohlberg Jun 27, 2024
08f2a63
Fix typing errors
bwohlberg Jun 27, 2024
f30f647
Fix pygraphviz version in docs build requirements
bwohlberg Jun 27, 2024
10642d7
Address furo config warning
bwohlberg Jun 27, 2024
0bea44e
Update submodule
bwohlberg Jun 27, 2024
19c0bb8
Merge branch 'main' into brendt/polartv
bwohlberg Jul 9, 2024
1c5e3be
Update submodule
bwohlberg Jul 19, 2024
3da95c4
Merge branch 'main' into brendt/polartv
bwohlberg Jul 19, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add non-Cartesian gradient linops
bwohlberg committed Sep 30, 2021
commit 86a7c674de5d03ddb1c94ebec6301d918ab24e3a
2 changes: 1 addition & 1 deletion docs/docs_requirements.txt
Original file line number Diff line number Diff line change
@@ -6,6 +6,6 @@ sphinx-autodoc-typehints
faculty-sphinx-theme
nbsphinx
py2jn
pygraphviz>=1.7
pygraphviz>=2
pandoc
docutils==0.16
7 changes: 7 additions & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
@@ -79,6 +79,7 @@ def patched_parse(self):
"sphinx.ext.viewcode",
"sphinxcontrib.bibtex",
"sphinx.ext.inheritance_diagram",
"matplotlib.sphinxext.plot_directive",
"sphinx.ext.mathjax",
"sphinx.ext.todo",
"nbsphinx",
@@ -124,6 +125,7 @@ def patched_parse(self):
"mb": [r"\mathbf{#1}", 1],
"mbs": [r"\boldsymbol{#1}", 1],
"mbb": [r"\mathbb{#1}", 1],
"mrm": [r"\mathrm{#1}", 1],
"norm": [r"\lVert #1 \rVert", 1],
"abs": [r"\left| #1 \right|", 1],
"argmin": [r"\mathop{\mathrm{argmin}}"],
@@ -276,6 +278,11 @@ def patched_parse(self):
fillcolor='"#f4f4ffff"',
)

plot_include_source = False
plot_html_show_source_link = False
plot_formats = ["svg"]
plot_html_show_formats = False


# -- Options for manual page output ---------------------------------------

48 changes: 48 additions & 0 deletions docs/source/figures/cylindgrad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import numpy as np

import scico.linop as scl
from scico import plot

input_shape = (7, 7, 7)
centre = (np.array(input_shape) - 1) / 2
end = np.array(input_shape) - centre
g0, g1, g2 = np.mgrid[-centre[0] : end[0], -centre[1] : end[1], -centre[2] : end[2]]

cg = scl.CylindricalGradient(input_shape=input_shape)

ang = cg.coord[0]
rad = cg.coord[1]
axi = cg.coord[2]

theta = np.arctan2(g0, g1)
clr = theta
# See https://stackoverflow.com/a/49888126
clr = (clr.ravel() - clr.min()) / clr.ptp()
clr = np.concatenate((clr, np.repeat(clr, 2)))
clr = plot.plt.cm.plasma(clr)

plot.plt.rcParams["savefig.transparent"] = True

fig = plot.plt.figure(figsize=(20, 6))
ax = fig.add_subplot(1, 3, 1, projection="3d")
ax.quiver(g0, g1, g2, ang[0], ang[1], ang[2], colors=clr, length=0.9)
ax.set_title("Angular local coordinate axis")
ax.set_xlabel("$x$")
ax.set_ylabel("$y$")
ax.set_zlabel("$z$")
ax = fig.add_subplot(1, 3, 2, projection="3d")
ax.quiver(g0, g1, g2, rad[0], rad[1], rad[2], colors=clr, length=0.9)
ax.set_title("Radial local coordinate axis")
ax.set_xlabel("$x$")
ax.set_ylabel("$y$")
ax.set_zlabel("$z$")
ax = fig.add_subplot(1, 3, 3, projection="3d")
ax.quiver(g0, g1, g2, axi[0], axi[1], axi[2], colors=clr[0], length=0.9)
ax.set_title("Axial local coordinate axis")
ax.set_xlabel("$x$")
ax.set_ylabel("$y$")
ax.set_zlabel("$z$")
fig.tight_layout()
plot.plt.show()

input()
35 changes: 35 additions & 0 deletions docs/source/figures/polargrad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import numpy as np

import scico.linop as scl
from scico import plot

input_shape = (21, 21)
centre = (np.array(input_shape) - 1) / 2
end = np.array(input_shape) - centre
g0, g1 = np.mgrid[-centre[0] : end[0], -centre[1] : end[1]]

pg = scl.PolarGradient(input_shape=input_shape)

ang = pg.coord[0]
rad = pg.coord[1]

clr = (np.arctan2(ang[1], ang[0]) + np.pi) / (2 * np.pi)

plot.plt.rcParams["image.cmap"] = "plasma"
plot.plt.rcParams["savefig.transparent"] = True

fig, ax = plot.plt.subplots(nrows=1, ncols=2, figsize=(13, 6))
ax[0].quiver(g0, g1, ang[0], ang[1], clr)
ax[0].set_title("Angular local coordinate axis")
ax[0].set_xlabel("$x$")
ax[0].set_ylabel("$y$")
ax[0].xaxis.set_ticks((-10, -5, 0, 5, 10))
ax[0].yaxis.set_ticks((-10, -5, 0, 5, 10))
ax[1].quiver(g0, g1, rad[0], rad[1], clr)
ax[1].set_title("Radial local coordinate axis")
ax[1].set_xlabel("$x$")
ax[1].set_ylabel("$y$")
ax[1].xaxis.set_ticks((-10, -5, 0, 5, 10))
ax[1].yaxis.set_ticks((-10, -5, 0, 5, 10))
fig.tight_layout()
plot.plt.show()
3 changes: 3 additions & 0 deletions docs/source/figures/projgrad.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
47 changes: 47 additions & 0 deletions docs/source/figures/spheregrad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import numpy as np

import scico.linop as scl
from scico import plot

input_shape = (7, 7, 7)
centre = (np.array(input_shape) - 1) / 2
end = np.array(input_shape) - centre
g0, g1, g2 = np.mgrid[-centre[0] : end[0], -centre[1] : end[1], -centre[2] : end[2]]

sg = scl.SphericalGradient(input_shape=input_shape)

azi = sg.coord[0]
pol = sg.coord[1]
rad = sg.coord[2]

theta = np.arctan2(g0, g1)
phi = np.arctan2(np.sqrt(g0 ** 2 + g1 ** 2), g2)
clr = theta * phi
# See https://stackoverflow.com/a/49888126
clr = (clr.ravel() - clr.min()) / clr.ptp()
clr = np.concatenate((clr, np.repeat(clr, 2)))
clr = plot.plt.cm.plasma(clr)

plot.plt.rcParams["savefig.transparent"] = True

fig = plot.plt.figure(figsize=(20, 6))
ax = fig.add_subplot(1, 3, 1, projection="3d")
ax.quiver(g0, g1, g2, azi[0], azi[1], azi[2], colors=clr, length=0.9)
ax.set_title("Azimuthal local coordinate axis")
ax.set_xlabel("$x$")
ax.set_ylabel("$y$")
ax.set_zlabel("$z$")
ax = fig.add_subplot(1, 3, 2, projection="3d")
ax.quiver(g0, g1, g2, pol[0], pol[1], pol[2], colors=clr, length=0.9)
ax.set_title("Polar local coordinate axis")
ax.set_xlabel("$x$")
ax.set_ylabel("$y$")
ax.set_zlabel("$z$")
ax = fig.add_subplot(1, 3, 3, projection="3d")
ax.quiver(g0, g1, g2, rad[0], rad[1], rad[2], colors=clr, length=0.9)
ax.set_title("Radial local coordinate axis")
ax.set_xlabel("$x$")
ax.set_ylabel("$y$")
ax.set_zlabel("$z$")
fig.tight_layout()
plot.plt.show()
5 changes: 5 additions & 0 deletions scico/linop/__init__.py
Original file line number Diff line number Diff line change
@@ -14,6 +14,7 @@
from ._linop import Diagonal, Identity, power_iteration, Sum
from ._matrix import MatrixOperator
from ._diff import FiniteDifference
from ._grad import ProjectedGradient, PolarGradient, CylindricalGradient, SphericalGradient
from ._convolve import Convolve, ConvolveByX
from ._circconv import CircularConvolve
from ._dft import DFT
@@ -26,6 +27,10 @@
"Diagonal",
"MatrixOperator",
"FiniteDifference",
"ProjectedGradient",
"PolarGradient",
"CylindricalGradient",
"SphericalGradient",
"Convolve",
"CircularConvolve",
"DFT",
383 changes: 383 additions & 0 deletions scico/linop/_grad.py

Large diffs are not rendered by default.

220 changes: 220 additions & 0 deletions scico/test/linop/test_grad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
from itertools import combinations

import numpy as np

import jax

import pytest
from jaxlib.xla_extension import DeviceArray

import scico.numpy as snp
from scico.blockarray import BlockArray
from scico.linop import CylindricalGradient, PolarGradient, SphericalGradient
from scico.random import randn


class TestPolarGradient:
def setup_method(self, method):
self.key = jax.random.PRNGKey(12345)

@pytest.mark.parametrize("jit", [True, False])
@pytest.mark.parametrize("input_dtype", [np.float32, np.complex64])
@pytest.mark.parametrize("outflags", [(True, True), (True, False), (False, True)])
@pytest.mark.parametrize("center", [None, (-2, 3), (1.2, -3.5)])
@pytest.mark.parametrize(
"shape_axes",
[
((20, 20), None),
((20, 20), (0, 1)),
((17, 18), None),
((17, 17), None),
((16, 17, 3), (0, 1)),
((2, 17, 16), (1, 2)),
],
)
def test_eval(self, shape_axes, center, outflags, input_dtype, jit):

input_shape, axes = shape_axes
if axes is None:
testaxes = (0, 1)
else:
testaxes = axes
if center is not None:
axes_shape = [input_shape[ax] for ax in testaxes]
center = (snp.array(axes_shape) - 1) / 2 + snp.array(center)
angular, radial = outflags
x, key = randn(input_shape, dtype=input_dtype, key=self.key)
A = PolarGradient(
input_shape,
axes=axes,
center=center,
angular=angular,
radial=radial,
input_dtype=input_dtype,
jit=jit,
)
Ax = A @ x
if angular and radial:
assert type(Ax) is BlockArray
assert len(Ax.shape) == 2
assert Ax[0].shape == input_shape
assert Ax[1].shape == input_shape
else:
assert type(Ax) is DeviceArray
assert Ax.shape == input_shape
assert Ax.dtype == input_dtype

# Test orthogonality of coordinate axes
coord = A.coord
for n0, n1 in combinations(range(len(coord)), 2):
c0 = coord[n0]
c1 = coord[n1]
assert snp.abs(c0 @ c1) < 1e-5


class TestCylindricalGradient:
def setup_method(self, method):
self.key = jax.random.PRNGKey(12345)

@pytest.mark.parametrize("jit", [True, False])
@pytest.mark.parametrize("input_dtype", [np.float32, np.complex64])
@pytest.mark.parametrize(
"outflags",
[
(True, True, True),
(True, True, False),
(True, False, True),
(True, False, False),
(False, True, True),
(False, True, False),
(False, False, True),
],
)
@pytest.mark.parametrize("center", [None, (-2, 3, 0), (1.2, -3.5, 1.5)])
@pytest.mark.parametrize(
"shape_axes",
[
((20, 20, 20), None),
((20, 20, 21), (0, 1, 2)),
((17, 18, 19), None),
((17, 17, 18), None),
((16, 17, 18, 3), (0, 1, 2)),
((2, 17, 16, 15), (1, 2, 3)),
((17, 2, 16, 15), (0, 2, 3)),
],
)
def test_eval(self, shape_axes, center, outflags, input_dtype, jit):

input_shape, axes = shape_axes
if axes is None:
testaxes = (0, 1, 2)
else:
testaxes = axes
if center is not None:
axes_shape = [input_shape[ax] for ax in testaxes]
center = (snp.array(axes_shape) - 1) / 2 + snp.array(center)
angular, radial, axial = outflags
x, key = randn(input_shape, dtype=input_dtype, key=self.key)
A = CylindricalGradient(
input_shape,
axes=axes,
center=center,
angular=angular,
radial=radial,
axial=axial,
input_dtype=input_dtype,
jit=jit,
)
Ax = A @ x
Nc = sum([angular, radial, axial])
if Nc > 1:
assert type(Ax) is BlockArray
assert Ax.num_blocks == Nc
for n in range(Nc):
assert Ax[n].shape == input_shape
else:
assert type(Ax) is DeviceArray
assert Ax.shape == input_shape
assert Ax.dtype == input_dtype

# Test orthogonality of coordinate axes
coord = A.coord
for n0, n1 in combinations(range(len(coord)), 2):
c0 = coord[n0]
c1 = coord[n1]
s = sum([c0[m] * c1[m] for m in range(c0.num_blocks)]).sum()
assert snp.abs(s) < 1e-5


class TestSphericalGradient:
def setup_method(self, method):
self.key = jax.random.PRNGKey(12345)

@pytest.mark.parametrize("jit", [True, False])
@pytest.mark.parametrize("input_dtype", [np.float32, np.complex64])
@pytest.mark.parametrize(
"outflags",
[
(True, True, True),
(True, True, False),
(True, False, True),
(True, False, False),
(False, True, True),
(False, True, False),
(False, False, True),
],
)
@pytest.mark.parametrize("center", [None, (-2, 3, 0), (1.2, -3.5, 1.5)])
@pytest.mark.parametrize(
"shape_axes",
[
((20, 20, 20), None),
((20, 20, 21), (0, 1, 2)),
((17, 18, 19), None),
((17, 17, 18), None),
((16, 17, 18, 3), (0, 1, 2)),
((2, 17, 16, 15), (1, 2, 3)),
((17, 2, 16, 15), (0, 2, 3)),
],
)
def test_eval(self, shape_axes, center, outflags, input_dtype, jit):

input_shape, axes = shape_axes
if axes is None:
testaxes = (0, 1, 2)
else:
testaxes = axes
if center is not None:
axes_shape = [input_shape[ax] for ax in testaxes]
center = (snp.array(axes_shape) - 1) / 2 + snp.array(center)
azimuthal, polar, radial = outflags
x, key = randn(input_shape, dtype=input_dtype, key=self.key)
A = SphericalGradient(
input_shape,
axes=axes,
center=center,
azimuthal=azimuthal,
polar=polar,
radial=radial,
input_dtype=input_dtype,
jit=jit,
)
Ax = A @ x
Nc = sum([azimuthal, polar, radial])
if Nc > 1:
assert type(Ax) is BlockArray
assert Ax.num_blocks == Nc
for n in range(Nc):
assert Ax[n].shape == input_shape
else:
assert type(Ax) is DeviceArray
assert Ax.shape == input_shape
assert Ax.dtype == input_dtype

# Test orthogonality of coordinate axes
coord = A.coord
for n0, n1 in combinations(range(len(coord)), 2):
c0 = coord[n0]
c1 = coord[n1]
s = sum([c0[m] * c1[m] for m in range(c0.num_blocks)]).sum()
assert snp.abs(s) < 1e-5