Skip to content

Commit

Permalink
Add AnnData as format (#2974)
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep authored and Rob Newman committed Apr 6, 2023
1 parent 31c9193 commit 9700342
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 16 deletions.
78 changes: 65 additions & 13 deletions api/python/quilt3/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,15 @@
import copy
import csv
import gzip
import importlib
import io
import json
import sys
import tempfile
import warnings
from abc import ABC, abstractmethod
from collections import defaultdict
from pathlib import Path

try:
from importlib import metadata as importlib_metadata
Expand Down Expand Up @@ -323,6 +326,7 @@ def all_supported_formats(cls):
Python Object Type Serialization Formats
<class 'pandas.core.frame.DataFrame'> [ssv, csv, tsv, parquet]
<class 'anndata.AnnData'> [.h5ad]
<class 'numpy.ndarray'> [npy, npz]
<class 'str'> [md, json, rst, txt]
<class 'dict'> [json]
Expand All @@ -333,19 +337,18 @@ def all_supported_formats(cls):
<class 'float'> [json]
<class 'bytes'> [bin]
"""
try:
import numpy as np
except ImportError:
pass
else:
cls.search(np.ndarray) # Force FormatHandlers to register np.ndarray as a supported object type

try:
import pandas as pd
except ImportError:
pass
else:
cls.search(pd.DataFrame) # Force FormatHandlers to register pd.DataFrame as a supported object type
# Force FormatHandlers to register these classes as supported object types
for mod_name, cls_name in [
('numpy', 'ndarray'),
('pandas', 'DataFrame'),
('anndata', 'AnnData'),
]:
try:
mod = importlib.import_module(mod_name)
except ImportError:
pass
else:
cls.search(getattr(mod, cls_name))

type_map = defaultdict(set)
for handler in cls.registered_handlers:
Expand Down Expand Up @@ -1033,6 +1036,55 @@ def deserialize(self, bytes_obj, meta=None, ext=None, **format_opts):
ParquetFormatHandler().register() # latest is preferred


class AnnDataFormatHandler(BaseFormatHandler):
"""Format for AnnData <--> .h5ad
Format Opts:
The following options may be used anywhere format opts are accepted,
or directly in metadata under `{'format': {'opts': {...: ...}}}`.
compression('gzip', 'lzf', None): applies during serialization only.
"""
name = 'h5ad'
handled_extensions = ['h5ad']
opts = ('compression',)
defaults = dict(
compression='lzf',
)

def handles_type(self, typ: type) -> bool:
# don't load module unless we actually have to use it.
if 'anndata' not in sys.modules:
return False
import anndata as ad
self.handled_types.add(ad.AnnData)
return super().handles_type(typ)

def serialize(self, obj, meta=None, ext=None, **format_opts):
opts = self.get_opts(meta, format_opts)
opts_with_defaults = copy.deepcopy(self.defaults)
opts_with_defaults.update(opts)

with tempfile.TemporaryDirectory() as td:
path = Path(td) / 'data.h5ad'
obj.write(path, **opts_with_defaults)
data = path.read_bytes()

return data, self._update_meta(meta, additions=opts_with_defaults)

def deserialize(self, bytes_obj, meta=None, ext=None, **format_opts):
try:
import anndata as ad
except ImportError:
raise QuiltException("Please install quilt3[anndata]")

buf = io.BytesIO(bytes_obj)
return ad.read_h5ad(buf)


AnnDataFormatHandler().register()


class CompressionRegistry:
"""A collection for organizing `CompressionHandler` objects."""
registered_handlers = []
Expand Down
5 changes: 2 additions & 3 deletions api/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,9 @@ def run(self):
'pandas>=0.19.2',
'pyarrow>=0.14.1', # as of 7/5/19: linux/circleci bugs on 0.14.0
],
'anndata': ['anndata>=0.8.0'],
'tests': [
'numpy>=1.14.0', # required by pandas, but missing from its dependencies.
'pandas>=0.19.2',
'pyarrow>=0.14.1', # as of 7/5/19: linux/circleci bugs on 0.14.0
'quilt3[pyarrow,anndata]',
'pytest==6.*',
'pytest-cov',
'coverage==6.4',
Expand Down
Binary file added api/python/tests/data/test.h5ad
Binary file not shown.
32 changes: 32 additions & 0 deletions api/python/tests/test_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import numpy as np
import pandas as pd
import pytest
from anndata import AnnData

from quilt3.formats import FormatRegistry
from quilt3.util import QuiltException

# Constants
data_dir = pathlib.Path(__file__).parent / 'data'


# Code
Expand Down Expand Up @@ -146,6 +148,36 @@ def test_formats_csv_roundtrip():
assert df1.equals(df2)


def test_formats_anndata_roundtrip():
meta = {'format': {'name': 'h5ad'}}
ad_file = data_dir / 'test.h5ad'
ad: AnnData = FormatRegistry.deserialize(ad_file.read_bytes(), meta)
assert isinstance(ad, AnnData)

bin, format_meta = FormatRegistry.serialize(ad, meta)
meta2 = {**meta, **format_meta}
ad2: AnnData = FormatRegistry.deserialize(bin, meta2)
np.allclose(ad.X, ad2.X)
ad.obs.equals(ad2.obs)
ad.var.equals(ad2.var)


def test_all_supported_formats():
assert FormatRegistry.all_supported_formats() == {
AnnData: {'h5ad'},
pd.DataFrame: {'csv', 'parquet', 'ssv', 'tsv'},
np.ndarray: {'npy', 'npz'},
str: {'json', 'md', 'rst', 'txt'},
tuple: {'json'},
type(None): {'json'},
dict: {'json'},
int: {'json'},
list: {'json'},
float: {'json'},
bytes: {'bin'},
}


def test_formats_search_fail_notfound():
# a search that finds nothing should raise with an explanation.
class Foo:
Expand Down
4 changes: 4 additions & 0 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ Entries inside each section should be ordered by type:
## Catalog, Lambdas
!-->
# unreleased - YYYY-MM-DD
## Python API
* [Added] Support [AnnData](https://anndata.readthedocs.io/en/latest/) format ([#2974](https://github.com/quiltdata/quilt/pull/2974))

# 5.2.1 - 2023-04-05
## Python API
* [Fixed] Fixed CSV serialization with pandas 2 ([#3395](https://github.com/quiltdata/quilt/pull/3395))
Expand Down

0 comments on commit 9700342

Please sign in to comment.