Skip to content

Commit

Permalink
Ensure that annotate_integrated uses the same matrix consistently.
Browse files Browse the repository at this point in the history
In other words, we need to make sure that it uses the same assay_type
and check_missing across both single and integrated classifications.
This is most easily done by running _clean_matrix externally on test/
ref datasets and then passing the tatami pointer into each of the
lower-level functions, which will then skip all NaN and assay checks.
  • Loading branch information
LTLA committed Sep 20, 2023
1 parent ec4258f commit f23d809
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 65 deletions.
110 changes: 70 additions & 40 deletions src/singler/annotate_integrated.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,20 @@
from .classify_single_reference import classify_single_reference
from .build_integrated_references import build_integrated_references
from .classify_integrated_references import classify_integrated_references
from .annotate_single import _build_reference
from .annotate_single import _resolve_reference, _attach_markers
from ._utils import _clean_matrix


def annotate_integrated(
test_data: Any,
test_features: Sequence,
ref_data: Sequence[Union[Any, str]],
ref_labels: Union[str, Sequence[Union[Sequence, str]]],
ref_features: Union[str, Sequence[Union[Sequence, str]]],
ref_data_list: Sequence[Union[Any, str]],
ref_labels_list: Union[str, Sequence[Union[Sequence, str]]],
ref_features_list: Union[str, Sequence[Union[Sequence, str]]],
test_assay_type: Union[str, int] = 0,
test_check_missing: bool = True,
ref_assay_type: Union[str, int] = "logcounts",
ref_check_missing: bool = True,
cache_dir: Optional[str] = None,
build_single_args: dict = {},
classify_single_args: dict = {},
Expand All @@ -40,7 +43,7 @@ def annotate_integrated(
test_features: Sequence of length equal to the number of rows in
``test_data``, containing the feature identifier for each row.
ref_data:
ref_data_list:
Sequence consisting of one or more of the following:
- A matrix-like object representing the reference dataset, where rows
Expand All @@ -54,36 +57,42 @@ def annotate_integrated(
:py:meth:`~singler.fetch_reference.fetch_github_reference`.
This will use the specified dataset as the reference.
ref_labels:
ref_labels_list:
Sequence of the same length as ``ref_data``, where the contents
depend on the type of value in the corresponding entry of ``ref_data``:
- If ``ref_data[i]`` is a matrix-like object, ``ref_labels[i]`` should be
a sequence of length equal to the number of columns of ``ref_data[i]``,
- If ``ref_data_list[i]`` is a matrix-like object, ``ref_labels_list[i]`` should be
a sequence of length equal to the number of columns of ``ref_data_list[i]``,
containing the label associated with each column.
- If ``ref_data[i]`` is a string, ``ref_labels[i]`` should be a string
- If ``ref_data_list[i]`` is a string, ``ref_labels_list[i]`` should be a string
specifying the label type to use, e.g., "main", "fine", "ont".
If a single string is supplied, it is recycled for all ``ref_data``.
ref_features:
Sequence of the same length as ``ref_data``, where the contents
ref_features_list:
Sequence of the same length as ``ref_data_list``, where the contents
depend on the type of value in the corresponding entry of ``ref_data``:
- If ``ref_data[i]`` is a matrix-like object, ``ref_features[i]`` should be
a sequence of length equal to the number of rows of ``ref_data``,
- If ``ref_data_list[i]`` is a matrix-like object, ``ref_features_list[i]`` should be
a sequence of length equal to the number of rows of ``ref_data_list[i]``,
containing the feature identifier associated with each row.
- If ``ref_data[i]`` is a string, ``ref_features[i]`` should be a string
- If ``ref_data_list[i]`` is a string, ``ref_features_list[i]`` should be a string
specifying the feature type to use, e.g., "ensembl", "symbol".
If a single string is supplied, it is recycled for all ``ref_data``.
test_assay_type:
Assay of ``test_data`` containing the expression matrix, if ``test_data`` is a
:py:class:`~summarizedexperiment.SummarizedExperiment.SummarizedExperiment`.
test_check_missing:
Whether to check for and remove missing (i.e., NaN) values from the test dataset.
ref_assay_type:
Assay containing the expression matrix for any entry of ``ref_data_list`` that is a
:py:class:`~summarizedexperiment.SummarizedExperiment.SummarizedExperiment`.
ref_check_missing:
Whether to check for and remove missing (i.e., NaN) values from the reference datasets.
cache_dir:
Path to a cache directory for downloading reference files, see
:py:meth:`~singler.fetch_reference.fetch_github_reference` for details.
Expand Down Expand Up @@ -116,15 +125,23 @@ def annotate_integrated(
(i.e., a BiocFrame from
:py:meth:`~singler.classify_integrated_references.classify_integrated_references`).
"""
nrefs = len(ref_data)
if isinstance(ref_labels, str):
ref_labels = [ref_labels] * nrefs
elif nrefs != len(ref_labels):
raise ValueError("'ref_data' and 'ref_labels' must be the same length")
if isinstance(ref_features, str):
ref_features = [ref_features] * nrefs
elif nrefs != len(ref_features):
raise ValueError("'ref_data' and 'ref_features' must be the same length")
nrefs = len(ref_data_list)
if isinstance(ref_labels_list, str):
ref_labels_list = [ref_labels_list] * nrefs
elif nrefs != len(ref_labels_list):
raise ValueError("'ref_data_list' and 'ref_labels_list' must be the same length")
if isinstance(ref_features_list, str):
ref_features_list = [ref_features_list] * nrefs
elif nrefs != len(ref_features_list):
raise ValueError("'ref_data_list' and 'ref_features_list' must be the same length")

test_ptr, test_features = _clean_matrix(
test_data,
test_features,
assay_type = test_assay_type,
check_missing = test_check_missing,
num_threads = num_threads,
)

all_ref_data = []
all_ref_labels = []
Expand All @@ -134,14 +151,30 @@ def annotate_integrated(
test_features_set = set(test_features)

for r in range(nrefs):
curref_data, curref_labels, curref_features, curbuilt = _build_reference(
ref_data=ref_data[r],
ref_labels=ref_labels[r],
ref_features=ref_features[r],
test_features_set=test_features_set,
curref_mat, curref_labels, curref_features, curref_markers = _resolve_reference(
ref_data=ref_data_list[r],
ref_labels=ref_labels_list[r],
ref_features=ref_features_list[r],
cache_dir=cache_dir,
build_args=build_single_args,
build_single_args=build_single_args,
test_features_set=test_features_set,
)

curref_ptr, curref_features = _clean_matrix(
curref_mat,
curref_features,
assay_type = ref_assay_type,
check_missing = ref_check_missing,
num_threads = num_threads,
)

bargs = _attach_markers(curref_markers, build_single_args)
curbuilt = build_single_reference(
ref_data=curref_ptr,
ref_labels=curref_labels,
ref_features=curref_features,
restrict_to=test_features_set,
**bargs,
num_threads=num_threads,
)

Expand All @@ -150,38 +183,35 @@ def annotate_integrated(
test_features=test_features,
ref_prebuilt=curbuilt,
**classify_single_args,
assay_type = test_assay_type,
num_threads=num_threads,
)

res.metadata = {
"markers": curbuilt.markers,
"unique_markers": curbuilt.marker_subset(),
}

all_ref_data.append(curref_data)
all_ref_data.append(curref_ptr)
all_ref_labels.append(curref_labels)
all_ref_features.append(curref_features)
all_built.append(curbuilt)
all_results.append(res)

res.metadata = {
"markers": curbuilt.markers,
"unique_markers": curbuilt.marker_subset(),
}

ibuilt = build_integrated_references(
test_features=test_features,
ref_data_list=all_ref_data,
ref_labels_list=all_ref_labels,
ref_features_list=all_ref_features,
ref_prebuilt_list=all_built,
assay_type = ref_assay_type,
num_threads=num_threads,
**build_integrated_args,
num_threads=num_threads,
)

ires = classify_integrated_references(
test_data=test_data,
test_data=test_ptr,
results=all_results,
integrated_prebuilt=ibuilt,
**classify_integrated_args,
assay_type = test_assay_type,
num_threads=num_threads,
)

Expand Down
50 changes: 28 additions & 22 deletions src/singler/annotate_single.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from typing import Union, Sequence, Optional, Any
from biocframe import BiocFrame
from copy import copy

from .fetch_reference import fetch_github_reference, realize_github_markers
from .build_single_reference import build_single_reference
from .classify_single_reference import classify_single_reference
from ._utils import _clean_matrix


def _build_reference(ref_data, ref_labels, ref_features, test_features_set, cache_dir, build_args, num_threads):
def _resolve_reference(ref_data, ref_labels, ref_features, cache_dir, build_args, test_features_set):
if isinstance(ref_data, str):
ref = fetch_github_reference(ref_data, cache_dir=cache_dir)
ref_features = ref.row_data.column(ref_features)
Expand All @@ -17,7 +19,7 @@ def _build_reference(ref_data, ref_labels, ref_features, test_features_set, cach
if "num_de" in marker_args:
num_de = marker_args["num_de"]

markers = realize_github_markers(
ref_markers = realize_github_markers(
ref.metadata[ref_labels],
ref_features,
num_markers=num_de,
Expand All @@ -26,26 +28,19 @@ def _build_reference(ref_data, ref_labels, ref_features, test_features_set, cach

ref_data = ref.assay("ranks")
ref_labels=ref.col_data.column(ref_labels)
built = build_single_reference(
ref_data=ref_data,
ref_labels=ref_labels,
ref_features=ref_features,
markers=markers,
num_threads=num_threads,
**build_args,
)

else:
built = build_single_reference(
ref_data=ref_data,
ref_labels=ref_labels,
ref_features=ref_features,
restrict_to=test_features_set,
num_threads=num_threads,
**build_args,
)
ref_markers = None

return ref_data, ref_labels, ref_features, ref_markers


return ref_data, ref_labels, ref_features, built
def _attach_markers(markers, build_args):
if markers is not None and "markers" not in build_args:
tmp = copy(build_args)
tmp["markers"] = markers
print(tmp)
return tmp
return build_args


def annotate_single(
Expand Down Expand Up @@ -123,13 +118,24 @@ def annotate_single(
specifying the markers that were used for each pairwise comparison
between labels; and a list of ``unique_markers`` across all labels.
"""
ref_data, ref_labels, ref_features, built = _build_reference(
test_features_set = set(test_features)

ref_data, ref_labels, ref_features, markers = _resolve_reference(
ref_data=ref_data,
ref_labels=ref_labels,
ref_features=ref_features,
test_features_set=set(test_features),
cache_dir=cache_dir,
build_args=build_args,
test_features_set=test_features_set,
)

bargs = _attach_markers(markers, build_args)
built = build_single_reference(
ref_data=ref_data,
ref_labels=ref_labels,
ref_features=ref_features,
restrict_to=test_features_set,
**bargs,
num_threads=num_threads,
)

Expand Down
6 changes: 3 additions & 3 deletions tests/test_annotate_integrated.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ def test_annotate_integrated():
single_results, integrated_results = singler.annotate_integrated(
test,
test_features=test_features,
ref_data=[ref1, ref2],
ref_labels=[labels1, labels2],
ref_features=[features1, features2],
ref_data_list=[ref1, ref2],
ref_labels_list=[labels1, labels2],
ref_features_list=[features1, features2],
)

assert len(single_results) == 2
Expand Down

0 comments on commit f23d809

Please sign in to comment.