Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate axes for writer #123

Merged
merged 8 commits into from
Nov 8, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 48 additions & 19 deletions ome_zarr/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,47 @@
LOGGER = logging.getLogger("ome_zarr.writer")


def validate_axes_names(
sbesson marked this conversation as resolved.
Show resolved Hide resolved
ndim: int, axes: Union[str, List[str]] = None, fmt: Format = CurrentFormat()
) -> Union[None, List[str]]:

if fmt.version not in ("0.1", "0.2"):
if axes is None:
if ndim == 2:
axes = ["y", "x"]
elif ndim == 5:
axes = ["t", "c", "z", "y", "x"]
else:
raise ValueError(
"axes must be provided. Can't be guessed for 3D or 4D data"
)

if isinstance(axes, str):
axes = list(axes)

if axes is not None:
if len(axes) != ndim:
raise ValueError("axes length must match number of dimensions")
# from https://github.com/constantinpape/ome-ngff-implementations/
val_axes = tuple(axes)
if ndim == 2:
assert val_axes == ("y", "x"), str(val_axes)
elif ndim == 3:
assert val_axes in [("z", "y", "x"), ("c", "y", "x"), ("t", "y", "x")], str(
val_axes
)
elif ndim == 4:
assert val_axes in [
("t", "z", "y", "x"),
("c", "z", "y", "x"),
("t", "c", "y", "x"),
], str(val_axes)
else:
assert val_axes == ("t", "c", "z", "y", "x"), str(val_axes)

return axes


def write_multiscale(
pyramid: List,
group: zarr.Group,
Expand All @@ -28,6 +69,8 @@ def write_multiscale(
----------
pyramid: List of np.ndarray
the image data to save. Largest level first
All image arrays MUST be up to 5-dimensional with dimensions
ordered (t, c, z, y, x)
group: zarr.Group
the group within the zarr store to store the data in
chunks: int or tuple of ints,
Expand All @@ -41,25 +84,7 @@ def write_multiscale(
"""

dims = len(pyramid[0].shape)
if fmt.version not in ("0.1", "0.2"):
if axes is None:
if dims == 2:
axes = ["y", "x"]
elif dims == 5:
axes = ["t", "c", "z", "y", "x"]
else:
raise ValueError(
"axes must be provided. Can't be guessed for 3D or 4D data"
)
if len(axes) != dims:
raise ValueError("axes length must match number of dimensions")

if isinstance(axes, str):
axes = list(axes)

for dim in axes:
if dim not in ("t", "c", "z", "y", "x"):
raise ValueError("axes must each be one of 'x', 'y', 'z', 'c' or 't'")
axes = validate_axes_names(dims, axes, fmt)

paths = []
for path, dataset in enumerate(pyramid):
Expand Down Expand Up @@ -90,6 +115,8 @@ def write_image(
image: np.ndarray
the image data to save. A downsampling of the data will be computed
if the scaler argument is non-None.
Image array MUST be up to 5-dimensional with dimensions
ordered (t, c, z, y, x)
group: zarr.Group
the group within the zarr store to store the data in
chunks: int or tuple of ints,
Expand All @@ -115,6 +142,8 @@ def write_image(
# v0.1 and v0.2 are strictly 5D
shape_5d: Tuple[Any, ...] = (*(1,) * (5 - image.ndim), *image.shape)
image = image.reshape(shape_5d)
# and we don't need axes
axes = None

if chunks is not None:
chunks = _retuple(chunks, image.shape)
Expand Down
33 changes: 32 additions & 1 deletion tests/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ome_zarr.io import parse_url
from ome_zarr.reader import Multiscales, Reader
from ome_zarr.scale import Scaler
from ome_zarr.writer import write_image
from ome_zarr.writer import validate_axes_names, write_image


class TestWriter:
Expand Down Expand Up @@ -77,3 +77,34 @@ def test_writer(self, shape, scaler, format_version):
else:
assert node.data[0].ndim == 5
assert np.allclose(data, node.data[0][...].compute())

def test_dim_names(self):
will-moore marked this conversation as resolved.
Show resolved Hide resolved

v03 = FormatV03()

# v0.3 MUST specify axes for 3D or 4D data
with pytest.raises(ValueError):
validate_axes_names(3, axes=None, fmt=v03)

# ndims must match axes length
with pytest.raises(ValueError):
validate_axes_names(3, axes="yx", fmt=v03)

# axes must be ordered tczyx
with pytest.raises(AssertionError):
validate_axes_names(3, axes="yxt", fmt=v03)
with pytest.raises(AssertionError):
validate_axes_names(2, axes=["x", "y"], fmt=v03)

validate_axes_names(2, axes=["y", "x"], fmt=v03)
sbesson marked this conversation as resolved.
Show resolved Hide resolved
validate_axes_names(5, axes="tczyx", fmt=v03)
sbesson marked this conversation as resolved.
Show resolved Hide resolved

# check that write_image is checking axes
data = self.create_data((125, 125))
with pytest.raises(ValueError):
write_image(
image=data,
group=self.group,
fmt=v03,
axes="xyz",
)