diff --git a/api/python/quilt3/formats.py b/api/python/quilt3/formats.py index 216cacf3702..7ca158b193f 100644 --- a/api/python/quilt3/formats.py +++ b/api/python/quilt3/formats.py @@ -69,12 +69,15 @@ import copy import csv import gzip +import importlib import io import json import sys +import tempfile import warnings from abc import ABC, abstractmethod from collections import defaultdict +from pathlib import Path try: from importlib import metadata as importlib_metadata @@ -323,6 +326,7 @@ def all_supported_formats(cls): Python Object Type Serialization Formats [ssv, csv, tsv, parquet] + [.h5ad] [npy, npz] [md, json, rst, txt] [json] @@ -333,19 +337,18 @@ def all_supported_formats(cls): [json] [bin] """ - try: - import numpy as np - except ImportError: - pass - else: - cls.search(np.ndarray) # Force FormatHandlers to register np.ndarray as a supported object type - - try: - import pandas as pd - except ImportError: - pass - else: - cls.search(pd.DataFrame) # Force FormatHandlers to register pd.DataFrame as a supported object type + # Force FormatHandlers to register these classes as supported object types + for mod_name, cls_name in [ + ('numpy', 'ndarray'), + ('pandas', 'DataFrame'), + ('anndata', 'AnnData'), + ]: + try: + mod = importlib.import_module(mod_name) + except ImportError: + pass + else: + cls.search(getattr(mod, cls_name)) type_map = defaultdict(set) for handler in cls.registered_handlers: @@ -1033,6 +1036,55 @@ def deserialize(self, bytes_obj, meta=None, ext=None, **format_opts): ParquetFormatHandler().register() # latest is preferred +class AnnDataFormatHandler(BaseFormatHandler): + """Format for AnnData <--> .h5ad + + Format Opts: + The following options may be used anywhere format opts are accepted, + or directly in metadata under `{'format': {'opts': {...: ...}}}`. + + compression('gzip', 'lzf', None): applies during serialization only. + """ + name = 'h5ad' + handled_extensions = ['h5ad'] + opts = ('compression',) + defaults = dict( + compression='lzf', + ) + + def handles_type(self, typ: type) -> bool: + # don't load module unless we actually have to use it. + if 'anndata' not in sys.modules: + return False + import anndata as ad + self.handled_types.add(ad.AnnData) + return super().handles_type(typ) + + def serialize(self, obj, meta=None, ext=None, **format_opts): + opts = self.get_opts(meta, format_opts) + opts_with_defaults = copy.deepcopy(self.defaults) + opts_with_defaults.update(opts) + + with tempfile.TemporaryDirectory() as td: + path = Path(td) / 'data.h5ad' + obj.write(path, **opts_with_defaults) + data = path.read_bytes() + + return data, self._update_meta(meta, additions=opts_with_defaults) + + def deserialize(self, bytes_obj, meta=None, ext=None, **format_opts): + try: + import anndata as ad + except ImportError: + raise QuiltException("Please install quilt3[anndata]") + + buf = io.BytesIO(bytes_obj) + return ad.read_h5ad(buf) + + +AnnDataFormatHandler().register() + + class CompressionRegistry: """A collection for organizing `CompressionHandler` objects.""" registered_handlers = [] diff --git a/api/python/setup.py b/api/python/setup.py index f314c6283a2..f19bad7c4e4 100644 --- a/api/python/setup.py +++ b/api/python/setup.py @@ -75,10 +75,9 @@ def run(self): 'pandas>=0.19.2', 'pyarrow>=0.14.1', # as of 7/5/19: linux/circleci bugs on 0.14.0 ], + 'anndata': ['anndata>=0.8.0'], 'tests': [ - 'numpy>=1.14.0', # required by pandas, but missing from its dependencies. - 'pandas>=0.19.2', - 'pyarrow>=0.14.1', # as of 7/5/19: linux/circleci bugs on 0.14.0 + 'quilt3[pyarrow,anndata]', 'pytest==6.*', 'pytest-cov', 'coverage==6.4', diff --git a/api/python/tests/data/test.h5ad b/api/python/tests/data/test.h5ad new file mode 100644 index 00000000000..c2acf144747 Binary files /dev/null and b/api/python/tests/data/test.h5ad differ diff --git a/api/python/tests/test_formats.py b/api/python/tests/test_formats.py index 60690f6d1db..fea47907c3a 100644 --- a/api/python/tests/test_formats.py +++ b/api/python/tests/test_formats.py @@ -3,11 +3,13 @@ import numpy as np import pandas as pd import pytest +from anndata import AnnData from quilt3.formats import FormatRegistry from quilt3.util import QuiltException # Constants +data_dir = pathlib.Path(__file__).parent / 'data' # Code @@ -146,6 +148,36 @@ def test_formats_csv_roundtrip(): assert df1.equals(df2) +def test_formats_anndata_roundtrip(): + meta = {'format': {'name': 'h5ad'}} + ad_file = data_dir / 'test.h5ad' + ad: AnnData = FormatRegistry.deserialize(ad_file.read_bytes(), meta) + assert isinstance(ad, AnnData) + + bin, format_meta = FormatRegistry.serialize(ad, meta) + meta2 = {**meta, **format_meta} + ad2: AnnData = FormatRegistry.deserialize(bin, meta2) + np.allclose(ad.X, ad2.X) + ad.obs.equals(ad2.obs) + ad.var.equals(ad2.var) + + +def test_all_supported_formats(): + assert FormatRegistry.all_supported_formats() == { + AnnData: {'h5ad'}, + pd.DataFrame: {'csv', 'parquet', 'ssv', 'tsv'}, + np.ndarray: {'npy', 'npz'}, + str: {'json', 'md', 'rst', 'txt'}, + tuple: {'json'}, + type(None): {'json'}, + dict: {'json'}, + int: {'json'}, + list: {'json'}, + float: {'json'}, + bytes: {'bin'}, + } + + def test_formats_search_fail_notfound(): # a search that finds nothing should raise with an explanation. class Foo: diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index c6f932e6cf9..35045cc0918 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -14,6 +14,10 @@ Entries inside each section should be ordered by type: ## Catalog, Lambdas !--> +# unreleased - YYYY-MM-DD +## Python API +* [Added] Support [AnnData](https://anndata.readthedocs.io/en/latest/) format ([#2974](https://github.com/quiltdata/quilt/pull/2974)) + # 5.2.1 - 2023-04-05 ## Python API * [Fixed] Fixed CSV serialization with pandas 2 ([#3395](https://github.com/quiltdata/quilt/pull/3395))