Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python] Add export for MultiscaleImage to SpatialData #3355

Merged
merged 4 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions apis/python/requirements_spatial.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
geopandas
tifffile
pillow
spatialdata
xarray>=2024.05.0
spatialdata>=0.2.5
xarray
dask
20 changes: 20 additions & 0 deletions apis/python/src/tiledbsoma/experimental/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,26 @@
return int(value)


def _version_less_than(version: str, upper_bound: Tuple[int, int, int]) -> bool:
split_version = version.split(".")
try:
major = _str_to_int(split_version[0])
minor = _str_to_int(split_version[1])
patch = _str_to_int(split_version[2])
except ValueError as err:
raise ValueError(f"Unable to parse version {version}.") from err

Check warning on line 26 in apis/python/src/tiledbsoma/experimental/_util.py

View check run for this annotation

Codecov / codecov/patch

apis/python/src/tiledbsoma/experimental/_util.py#L25-L26

Added lines #L25 - L26 were not covered by tests
print(f"Actual: {(major, minor, patch)} and Compare: {upper_bound}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
print(f"Actual: {(major, minor, patch)} and Compare: {upper_bound}")

return (
major < upper_bound[0]
or (major == upper_bound[0] and minor < upper_bound[1])
or (
major == upper_bound[0]
and minor == upper_bound[1]
and patch < upper_bound[2]
)
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just as a heads up, tiledbsoma has packaging as a transient dependency. It comes with a robust version parser + even a version type: https://packaging.pypa.io/en/latest/version.html

It's mostly what we use for this kind of thing in scverse packages.



def _read_visium_software_version(
gene_expression_path: Union[str, Path]
) -> Tuple[int, int, int]:
Expand Down
37 changes: 33 additions & 4 deletions apis/python/src/tiledbsoma/experimental/_xarray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,25 @@
#
# Licensed under the MIT License.
import json
from typing import Any, Mapping, Optional, Tuple, Union
import warnings
from typing import Any, Mapping, Optional, Sequence, Tuple, Union

import dask.array as da
import numpy as np
from xarray import DataArray

from ._util import _version_less_than

try:
import spatialdata as sd
from spatialdata.models.models import DataTree
except ImportError as err:
warnings.warn("Experimental spatial exporter requires the spatialdatda package.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
warnings.warn("Experimental spatial exporter requires the spatialdatda package.")
warnings.warn("Experimental spatial exporter requires the spatialdata package.")

raise err

Check warning on line 19 in apis/python/src/tiledbsoma/experimental/_xarray_backend.py

View check run for this annotation

Codecov / codecov/patch

apis/python/src/tiledbsoma/experimental/_xarray_backend.py#L17-L19

Added lines #L17 - L19 were not covered by tests
try:
import xarray as xr
except ImportError as err:
warnings.warn("Experimental spatial exporter requires the spatialdatda package.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
warnings.warn("Experimental spatial exporter requires the spatialdatda package.")
warnings.warn("Experimental spatial exporter requires the spatialdata package.")

raise err

Check warning on line 24 in apis/python/src/tiledbsoma/experimental/_xarray_backend.py

View check run for this annotation

Codecov / codecov/patch

apis/python/src/tiledbsoma/experimental/_xarray_backend.py#L22-L24

Added lines #L22 - L24 were not covered by tests

from .. import DenseNDArray
from ..options._soma_tiledb_context import SOMATileDBContext
Expand Down Expand Up @@ -112,7 +126,7 @@
chunks: Optional[Tuple[int, ...]] = None,
attrs: Optional[Mapping[str, Any]] = None,
context: Optional[SOMATileDBContext] = None,
) -> DataArray:
) -> xr.DataArray:
"""Create a :class:`xarray.DataArray` that accesses a SOMA :class:`DenseNDarray`
through dask.

Expand All @@ -139,4 +153,19 @@
fancy=False,
)

return DataArray(data, dims=dim_names, attrs=attrs)
return xr.DataArray(data, dims=dim_names, attrs=attrs)


def images_to_datatree(image_data_arrays: Sequence[xr.DataArray]) -> DataTree:
# If SpatialData version < 0.2.6 use the legacy xarray_datatree implementation
# of the DataTree.
if _version_less_than(sd.__version__, (0, 2, 5)):
return DataTree.from_dict(

Check warning on line 163 in apis/python/src/tiledbsoma/experimental/_xarray_backend.py

View check run for this annotation

Codecov / codecov/patch

apis/python/src/tiledbsoma/experimental/_xarray_backend.py#L163

Added line #L163 was not covered by tests
{f"scale{index}": image for index, image in enumerate(image_data_arrays)}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why add the "scale" prefix?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

)
return DataTree.from_dict(
{
f"scale{index}": xr.Dataset({"image": image})
for index, image in enumerate(image_data_arrays)
}
)
99 changes: 93 additions & 6 deletions apis/python/src/tiledbsoma/experimental/outgest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,29 @@
# Copyright (c) 2024 TileDB, Inc
#
# Licensed under the MIT License.
from typing import Dict, Optional, Tuple, Union
import warnings
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union

import geopandas as gpd
try:
import geopandas as gpd
except ImportError as err:
warnings.warn("Experimental spatial outgestor requires the geopandas package.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it might make sense to import spatial data first, since geopandas is a hard dep there.

Then there isn't the UX of import, hit missing dep error, install geopandas, import, hit missing spatialdata

raise err

Check warning on line 12 in apis/python/src/tiledbsoma/experimental/outgest.py

View check run for this annotation

Codecov / codecov/patch

apis/python/src/tiledbsoma/experimental/outgest.py#L10-L12

Added lines #L10 - L12 were not covered by tests
import pandas as pd
import somacore
import spatialdata as sd
import xarray as xr

try:
import spatialdata as sd
except ImportError as err:
warnings.warn("Experimental spatial outgestor requires the spatialdata package.")
raise err

Check warning on line 20 in apis/python/src/tiledbsoma/experimental/outgest.py

View check run for this annotation

Codecov / codecov/patch

apis/python/src/tiledbsoma/experimental/outgest.py#L18-L20

Added lines #L18 - L20 were not covered by tests

from .. import MultiscaleImage, PointCloudDataFrame
from .._constants import SOMA_JOINID
from ._xarray_backend import dense_nd_array_to_data_array
from ._xarray_backend import dense_nd_array_to_data_array, images_to_datatree

if TYPE_CHECKING:
from spatialdata.models.models import DataArray, DataTree

Check warning on line 27 in apis/python/src/tiledbsoma/experimental/outgest.py

View check run for this annotation

Codecov / codecov/patch

apis/python/src/tiledbsoma/experimental/outgest.py#L27

Added line #L27 was not covered by tests


def _convert_axis_names(
Expand Down Expand Up @@ -195,7 +207,7 @@
scene_id: str,
scene_dim_map: Dict[str, str],
transform: somacore.CoordinateTransform,
) -> xr.DataArray:
) -> "DataArray":
"""Export a level of a :class:`MultiscaleImage` to a
:class:`spatialdata.Image2DModel` or :class:`spatialdata.Image3DModel`.
"""
Expand Down Expand Up @@ -256,3 +268,78 @@
attrs={"transform": transformations},
context=image.context,
)


def to_spatial_data_multiscale_image(
image: MultiscaleImage,
*,
scene_id: str,
scene_dim_map: Dict[str, str],
transform: somacore.CoordinateTransform,
) -> "DataTree":
"""Export a MultiscaleImage to a DataTree."""

# Check for channel axis.
if not image.has_channel_axis:
raise NotImplementedError(

Check warning on line 284 in apis/python/src/tiledbsoma/experimental/outgest.py

View check run for this annotation

Codecov / codecov/patch

apis/python/src/tiledbsoma/experimental/outgest.py#L284

Added line #L284 was not covered by tests
"Support for exporting a MultiscaleImage to without a channel axis to "
"SpatialData is not yet implemented."
)

# Convert from SOMA axis names to SpatialData axis names.
orig_axis_names = image.coordinate_space.axis_names
if len(orig_axis_names) not in {2, 3}:
raise NotImplementedError(

Check warning on line 292 in apis/python/src/tiledbsoma/experimental/outgest.py

View check run for this annotation

Codecov / codecov/patch

apis/python/src/tiledbsoma/experimental/outgest.py#L292

Added line #L292 was not covered by tests
f"Support for converting a '{len(orig_axis_names)}'D is not yet implemented."
)
new_axis_names, image_dim_map = _convert_axis_names(
orig_axis_names, image.data_axis_order
)

# Get the transformtion from the image level to the scene:
# If the result is a single scale transform (or identity transform), output a
# single transformation. Otherwise, convert to a SpatialData sequence of
# transformations.
inv_transform = transform.inverse_transform()
if isinstance(transform, somacore.ScaleTransform):
# inv_transform @ scale_transform -> applies scale_transform first
spatial_data_transformations = tuple(
_transform_to_spatial_data(
inv_transform @ image.get_transform_from_level(level),
image_dim_map,
scene_dim_map,
)
for level in range(image.level_count)
)

else:
sd_scale_transforms = tuple(
_transform_to_spatial_data(
image.get_transform_from_level(level), image_dim_map, image_dim_map
)
for level in range(1, image.level_count)
)
sd_inv_transform = _transform_to_spatial_data(
inv_transform, image_dim_map, scene_dim_map
)

# First level transform is always the identity, so just directly use
# inv_transform. For remaining transformations,
# Sequence([sd_transform1, sd_transform2]) -> applies sd_transform1 first
spatial_data_transformations = (sd_inv_transform,) + tuple(
sd.transformations.Sequence([scale_transform, sd_inv_transform])
for scale_transform in sd_scale_transforms
)

# Create a sequence of resolution level.
image_data_arrays = tuple(
dense_nd_array_to_data_array(
uri=image.level_uri(index),
dim_names=new_axis_names,
attrs={"transform": {scene_id: spatial_data_transformations[index]}},
context=image.context,
)
for index, (soma_name, val) in enumerate(image.levels().items())
)

return images_to_datatree(image_data_arrays)
89 changes: 87 additions & 2 deletions apis/python/tests/test_export_multiscale_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

soma_outgest = pytest.importorskip("tiledbsoma.experimental.outgest")
sd = pytest.importorskip("spatialdata")
xr = pytest.importorskip("xarray")


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -119,7 +118,7 @@ def test_export_image_level_to_spatial_data(
transform=transform,
)

assert isinstance(image2d, xr.DataArray)
assert isinstance(image2d, sd.models.models.DataArray)

# Validate the model.
schema = sd.models.get_model(image2d)
Expand All @@ -133,3 +132,89 @@ def test_export_image_level_to_spatial_data(
metadata = dict(image2d.attrs)
assert len(metadata) == 1
assert metadata["transform"] == {"scene0": expected_transformation}


@pytest.mark.parametrize(
"transform,expected_transformation",
[
(
somacore.IdentityTransform(("x_scene", "y_scene"), ("x_image", "y_image")),
[
sd.transformations.Identity(),
sd.transformations.Scale([2, 2], ("x", "y")),
sd.transformations.Scale([4, 4], ("x", "y")),
],
),
(
somacore.ScaleTransform(
("x_scene", "y_scene"), ("x_image", "y_image"), [0.25, 0.5]
),
[
sd.transformations.Scale([4, 2], ("x", "y")),
sd.transformations.Scale([8, 4], ("x", "y")),
sd.transformations.Scale([16, 8], ("x", "y")),
],
),
(
somacore.AffineTransform(
("x_scene", "y_scene"), ("x_image", "y_image"), [[1, 0, 1], [0, 1, 2]]
),
[
sd.transformations.Affine(
np.array([[1, 0, -1], [0, 1, -2], [0, 0, 1]]),
("x", "y"),
("x", "y"),
),
sd.transformations.Sequence(
[
sd.transformations.Scale([2, 2], ("x", "y")),
sd.transformations.Affine(
np.array([[1, 0, -1], [0, 1, -2], [0, 0, 1]]),
("x", "y"),
("x", "y"),
),
]
),
sd.transformations.Sequence(
[
sd.transformations.Scale([4, 4], ("x", "y")),
sd.transformations.Affine(
np.array([[1, 0, -1], [0, 1, -2], [0, 0, 1]]),
("x", "y"),
("x", "y"),
),
]
),
],
),
],
)
def test_export_full_image_to_spatial_data(
sample_multiscale_image_2d, sample_2d_data, transform, expected_transformation
):
image2d = soma_outgest.to_spatial_data_multiscale_image(
sample_multiscale_image_2d,
scene_id="scene0",
scene_dim_map={"x_scene": "x", "y_scene": "y"},
transform=transform,
)

assert isinstance(image2d, sd.models.models.DataTree)

# Validate the model.
schema = sd.models.get_model(image2d)
assert schema == sd.models.Image2DModel

# Check the correct data exists.
for index in range(3):
data_array = image2d[f"scale{index}"]["image"]
print(f"{index}: {data_array}")

# Check data.
result = data_array.data.compute()
np.testing.assert_equal(result, sample_2d_data[index])

# Check the metadata.
metadata = dict(data_array.attrs)
assert len(metadata) == 1
assert metadata["transform"] == {"scene0": expected_transformation[index]}