diff --git a/src/singler/_utils.py b/src/singler/_utils.py index 7b397b1..603edd1 100644 --- a/src/singler/_utils.py +++ b/src/singler/_utils.py @@ -8,12 +8,6 @@ import warnings -def _factorize(x: Sequence) -> Tuple[list, numpy.ndarray]: - f = biocutils.Factor.from_sequence(x, sort_levels=False) - print(f) - return f.levels, numpy.array(f.codes, dtype=numpy.uint32) - - def _create_map(x: Sequence) -> dict: mapping = {} for i, val in enumerate(x): diff --git a/src/singler/annotate_integrated.py b/src/singler/annotate_integrated.py index e1de142..a574c00 100644 --- a/src/singler/annotate_integrated.py +++ b/src/singler/annotate_integrated.py @@ -1,6 +1,7 @@ from typing import Any, Optional, Sequence, Tuple, Union import biocframe +import warnings from ._utils import _clean_matrix, _restrict_features from .train_single import train_single @@ -132,6 +133,15 @@ def annotate_integrated( all_ref_labels = [] all_ref_features = [] for r in range(nrefs): + curref_labels = ref_labels[r] + if isinstance(curref_labels, str): + warnings.warn( + "setting 'ref_labels' to a column name of the column data is deprecated", + category=DeprecationWarning + ) + curref_labels = ref_data[r].get_column_data().column(curref_labels) + all_ref_labels.append(curref_labels) + curref_data, curref_features = _clean_matrix( ref_data[r], ref_features[r], @@ -161,7 +171,7 @@ def annotate_integrated( for r in range(nrefs): curbuilt = train_single( ref_data=all_ref_data[r], - ref_labels=ref_labels[r], + ref_labels=all_ref_labels[r], ref_features=all_ref_features[r], test_features=test_features, **train_single_args, diff --git a/src/singler/annotate_single.py b/src/singler/annotate_single.py index 11914c6..ae25c70 100644 --- a/src/singler/annotate_single.py +++ b/src/singler/annotate_single.py @@ -91,6 +91,13 @@ def annotate_single( A :py:class:`~biocframe.BiocFrame.BiocFrame` of labelling results, see :py:func:`~singler.classify_single.classify_single` for details. """ + if isinstance(ref_labels, str): + warnings.warn( + "setting 'ref_labels' to a column name of the column data is deprecated", + category=DeprecationWarning + ) + ref_labels = ref_data.get_column_data().column(ref_labels) + test_data, test_features = _clean_matrix( test_data, test_features, diff --git a/src/singler/train_single.py b/src/singler/train_single.py index 79dc958..1eaaaab 100644 --- a/src/singler/train_single.py +++ b/src/singler/train_single.py @@ -8,7 +8,7 @@ import warnings from . import lib_singler as lib -from ._utils import _clean_matrix, _factorize, _restrict_features, _stable_intersect +from ._utils import _clean_matrix, _restrict_features, _stable_intersect from .get_classic_markers import get_classic_markers @@ -176,6 +176,12 @@ def train_single( The pre-built reference, ready for use in downstream methods like :py:meth:`~singler.classify_single_reference.classify_single`. """ + if isinstance(ref_labels, str): + warnings.warn( + "setting 'ref_labels' to a column name of the column data is deprecated", + category=DeprecationWarning + ) + ref_labels = ref_data.get_column_data().column(ref_labels) ref_data, ref_features = _clean_matrix( ref_data, @@ -196,13 +202,13 @@ def train_single( ref_features = biocutils.subset_sequence(ref_features, keep) ref_data = delayedarray.DelayedArray(ref_data)[keep,:] - if isinstance(ref_labels, str): - warnings.warn( - "setting 'labels' to a column name of the column data is deprecated", - category=DeprecationWarning - ) - ref_labels = ref_data.get_column_data().column(ref_labels) - unique_labels, label_idx = _factorize(ref_labels) + for f in ref_labels: + if f is None: + raise ValueError("entries of 'ref_labels' cannot be missing") + if not isinstance(ref_labels, biocutils.Factor): # TODO: move over to biocutils so coercion can no-op. + ref_labels = biocutils.Factor.from_sequence(ref_labels, sort_levels=False) # TODO: add a dtype= option. + unique_labels = ref_labels.levels + label_idx = ref_labels.codes.astype(dtype=numpy.uint32, copy=False) markers = _identify_genes( ref_data=ref_data, diff --git a/tests/test_integrated_with_celldex.py b/tests/test_integrated_with_celldex.py index 8181d66..dc53c0c 100644 --- a/tests/test_integrated_with_celldex.py +++ b/tests/test_integrated_with_celldex.py @@ -16,19 +16,17 @@ def test_with_minimal_args(): ) immune_cell_ref = celldex.fetch_reference("dice", "2024-02-26", realize_assays=True) - with pytest.raises(Exception): - singler.annotate_integrated( - test_data=sce.assays["counts"], - ref_data_list=(blueprint_ref, immune_cell_ref), - ref_labels_list="label.main", - num_threads=6, - ) - single, integrated = singler.annotate_integrated( test_data=sce, - ref_data_list=(blueprint_ref, immune_cell_ref), - ref_labels_list="label.main", - num_threads=6, + ref_data=( + blueprint_ref, + immune_cell_ref + ), + ref_labels=[ + blueprint_ref.get_column_data().column("label.main"), + immune_cell_ref.get_column_data().column("label.main") + ], + num_threads=2 ) assert len(single) == 2 assert isinstance(integrated, BiocFrame) @@ -43,14 +41,21 @@ def test_with_all_supplied(): immune_cell_ref = celldex.fetch_reference("dice", "2024-02-26", realize_assays=True) single, integrated = singler.annotate_integrated( - test_data=sce, + test_data=sce.assays["counts"], test_features=sce.get_row_names(), - ref_data_list=(blueprint_ref, immune_cell_ref), - ref_labels_list=[ - x.get_column_data().column("label.main") - for x in (blueprint_ref, immune_cell_ref) + ref_data=( + blueprint_ref, + immune_cell_ref + ), + ref_labels=[ + blueprint_ref.get_column_data().column("label.main"), + immune_cell_ref.get_column_data().column("label.main") + ], + ref_features=[ + blueprint_ref.get_row_names(), + immune_cell_ref.get_row_names() ], - ref_features_list=[x.get_row_names() for x in (blueprint_ref, immune_cell_ref)], + num_threads=2 ) assert len(single) == 2 @@ -67,8 +72,9 @@ def test_with_colname(): single, integrated = singler.annotate_integrated( test_data=sce, - ref_data_list=(blueprint_ref, immune_cell_ref), - ref_labels_list="label.main", + ref_data=(blueprint_ref, immune_cell_ref), + ref_labels=["label.main"] * 2, + num_threads=2 ) assert len(single) == 2 diff --git a/tests/test_single_with_celldex.py b/tests/test_single_with_celldex.py index b5cd6db..eeef181 100644 --- a/tests/test_single_with_celldex.py +++ b/tests/test_single_with_celldex.py @@ -4,21 +4,12 @@ import scrnaseq import pandas as pd import scipy -import pytest from biocframe import BiocFrame def test_with_minimal_args(): sce = scrnaseq.fetch_dataset("zeisel-brain-2015", "2023-12-14", realize_assays=True) - immgen_ref = celldex.fetch_reference("immgen", "2024-02-26", realize_assays=True) - with pytest.raises(Exception): - matches = singler.annotate_single( - test_data=sce.assays["counts"], - ref_data=immgen_ref, - ref_labels=immgen_ref.get_column_data().column("label.main"), - ) - matches = singler.annotate_single( test_data=sce, ref_data=immgen_ref, @@ -32,11 +23,10 @@ def test_with_minimal_args(): def test_with_all_supplied(): sce = scrnaseq.fetch_dataset("zeisel-brain-2015", "2023-12-14", realize_assays=True) - immgen_ref = celldex.fetch_reference("immgen", "2024-02-26", realize_assays=True) matches = singler.annotate_single( - test_data=sce, + test_data=sce.assays["counts"], test_features=sce.get_row_names(), ref_data=immgen_ref, ref_labels=immgen_ref.get_column_data().column("label.main"), @@ -50,7 +40,6 @@ def test_with_all_supplied(): def test_with_colname(): sce = scrnaseq.fetch_dataset("zeisel-brain-2015", "2023-12-14", realize_assays=True) - immgen_ref = celldex.fetch_reference("immgen", "2024-02-26", realize_assays=True) matches = singler.annotate_single( diff --git a/tests/test_train_integrated.py b/tests/test_train_integrated.py index 064d8b4..7e2a581 100644 --- a/tests/test_train_integrated.py +++ b/tests/test_train_integrated.py @@ -30,7 +30,7 @@ def test_train_integrated(): test_features, ref_prebuilt=[built1, built2], ref_names=["FOO", "BAR"], - num_threads=3, + num_threads=2, ) assert pintegrated.reference_names == ["FOO", "BAR"] diff --git a/tests/test_utils.py b/tests/test_utils.py index e112230..4d525f0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,5 +1,4 @@ from singler._utils import ( - _factorize, _stable_intersect, _stable_union, _clean_matrix, @@ -8,17 +7,6 @@ import summarizedexperiment -def test_factorize(): - lev, ind = _factorize([1, 3, 5, 5, 3, 1]) - assert list(lev) == ["1", "3", "5"] - assert (ind == [0, 1, 2, 2, 1, 0]).all() - - # Preserves the order. - lev, ind = _factorize(["C", "D", "A", "B", "C", "A"]) - assert list(lev) == ["C", "D", "A", "B"] - assert (ind == [0, 1, 2, 3, 0, 2]).all() - - def test_intersect(): # Preserves the order in the first argument. out = _stable_intersect(["B", "C", "A", "D", "E"], ["A", "C", "E"])