Skip to content

Commit

Permalink
Got more tests to pass.
Browse files Browse the repository at this point in the history
  • Loading branch information
LTLA committed Dec 13, 2024
1 parent bdbb9a8 commit dec4971
Show file tree
Hide file tree
Showing 8 changed files with 59 additions and 59 deletions.
6 changes: 0 additions & 6 deletions src/singler/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,6 @@
import warnings


def _factorize(x: Sequence) -> Tuple[list, numpy.ndarray]:
    """Convert a sequence of labels into levels and integer codes.

    Args:
        x: Sequence of labels to be factorized.

    Returns:
        Tuple of two elements. The first holds the factor levels as produced
        by ``biocutils.Factor.from_sequence`` with ``sort_levels=False``
        (i.e., the levels are not sorted). The second is a NumPy array of
        ``uint32`` codes, one per entry of ``x``, built from the factor's
        ``codes``.
    """
    # NOTE: no diagnostic printing here; this helper is called from library
    # code and must not write to stdout.
    factor = biocutils.Factor.from_sequence(x, sort_levels=False)
    return factor.levels, numpy.array(factor.codes, dtype=numpy.uint32)


def _create_map(x: Sequence) -> dict:
mapping = {}
for i, val in enumerate(x):
Expand Down
12 changes: 11 additions & 1 deletion src/singler/annotate_integrated.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Optional, Sequence, Tuple, Union

import biocframe
import warnings

from ._utils import _clean_matrix, _restrict_features
from .train_single import train_single
Expand Down Expand Up @@ -132,6 +133,15 @@ def annotate_integrated(
all_ref_labels = []
all_ref_features = []
for r in range(nrefs):
curref_labels = ref_labels[r]
if isinstance(curref_labels, str):
warnings.warn(
"setting 'ref_labels' to a column name of the column data is deprecated",
category=DeprecationWarning
)
curref_labels = ref_data[r].get_column_data().column(curref_labels)
all_ref_labels.append(curref_labels)

curref_data, curref_features = _clean_matrix(
ref_data[r],
ref_features[r],
Expand Down Expand Up @@ -161,7 +171,7 @@ def annotate_integrated(
for r in range(nrefs):
curbuilt = train_single(
ref_data=all_ref_data[r],
ref_labels=ref_labels[r],
ref_labels=all_ref_labels[r],
ref_features=all_ref_features[r],
test_features=test_features,
**train_single_args,
Expand Down
7 changes: 7 additions & 0 deletions src/singler/annotate_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,13 @@ def annotate_single(
A :py:class:`~biocframe.BiocFrame.BiocFrame` of labelling results, see
:py:func:`~singler.classify_single.classify_single` for details.
"""
if isinstance(ref_labels, str):
warnings.warn(
"setting 'ref_labels' to a column name of the column data is deprecated",
category=DeprecationWarning
)
ref_labels = ref_data.get_column_data().column(ref_labels)

test_data, test_features = _clean_matrix(
test_data,
test_features,
Expand Down
22 changes: 14 additions & 8 deletions src/singler/train_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import warnings

from . import lib_singler as lib
from ._utils import _clean_matrix, _factorize, _restrict_features, _stable_intersect
from ._utils import _clean_matrix, _restrict_features, _stable_intersect
from .get_classic_markers import get_classic_markers


Expand Down Expand Up @@ -176,6 +176,12 @@ def train_single(
The pre-built reference, ready for use in downstream methods like
:py:meth:`~singler.classify_single_reference.classify_single`.
"""
if isinstance(ref_labels, str):
warnings.warn(
"setting 'ref_labels' to a column name of the column data is deprecated",
category=DeprecationWarning
)
ref_labels = ref_data.get_column_data().column(ref_labels)

ref_data, ref_features = _clean_matrix(
ref_data,
Expand All @@ -196,13 +202,13 @@ def train_single(
ref_features = biocutils.subset_sequence(ref_features, keep)
ref_data = delayedarray.DelayedArray(ref_data)[keep,:]

if isinstance(ref_labels, str):
warnings.warn(
"setting 'labels' to a column name of the column data is deprecated",
category=DeprecationWarning
)
ref_labels = ref_data.get_column_data().column(ref_labels)
unique_labels, label_idx = _factorize(ref_labels)
for f in ref_labels:
if f is None:
raise ValueError("entries of 'ref_labels' cannot be missing")
if not isinstance(ref_labels, biocutils.Factor): # TODO: move over to biocutils so coercion can no-op.
ref_labels = biocutils.Factor.from_sequence(ref_labels, sort_levels=False) # TODO: add a dtype= option.
unique_labels = ref_labels.levels
label_idx = ref_labels.codes.astype(dtype=numpy.uint32, copy=False)

markers = _identify_genes(
ref_data=ref_data,
Expand Down
44 changes: 25 additions & 19 deletions tests/test_integrated_with_celldex.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,17 @@ def test_with_minimal_args():
)
immune_cell_ref = celldex.fetch_reference("dice", "2024-02-26", realize_assays=True)

with pytest.raises(Exception):
singler.annotate_integrated(
test_data=sce.assays["counts"],
ref_data_list=(blueprint_ref, immune_cell_ref),
ref_labels_list="label.main",
num_threads=6,
)

single, integrated = singler.annotate_integrated(
test_data=sce,
ref_data_list=(blueprint_ref, immune_cell_ref),
ref_labels_list="label.main",
num_threads=6,
ref_data=(
blueprint_ref,
immune_cell_ref
),
ref_labels=[
blueprint_ref.get_column_data().column("label.main"),
immune_cell_ref.get_column_data().column("label.main")
],
num_threads=2
)
assert len(single) == 2
assert isinstance(integrated, BiocFrame)
Expand All @@ -43,14 +41,21 @@ def test_with_all_supplied():
immune_cell_ref = celldex.fetch_reference("dice", "2024-02-26", realize_assays=True)

single, integrated = singler.annotate_integrated(
test_data=sce,
test_data=sce.assays["counts"],
test_features=sce.get_row_names(),
ref_data_list=(blueprint_ref, immune_cell_ref),
ref_labels_list=[
x.get_column_data().column("label.main")
for x in (blueprint_ref, immune_cell_ref)
ref_data=(
blueprint_ref,
immune_cell_ref
),
ref_labels=[
blueprint_ref.get_column_data().column("label.main"),
immune_cell_ref.get_column_data().column("label.main")
],
ref_features=[
blueprint_ref.get_row_names(),
immune_cell_ref.get_row_names()
],
ref_features_list=[x.get_row_names() for x in (blueprint_ref, immune_cell_ref)],
num_threads=2
)

assert len(single) == 2
Expand All @@ -67,8 +72,9 @@ def test_with_colname():

single, integrated = singler.annotate_integrated(
test_data=sce,
ref_data_list=(blueprint_ref, immune_cell_ref),
ref_labels_list="label.main",
ref_data=(blueprint_ref, immune_cell_ref),
ref_labels=["label.main"] * 2,
num_threads=2
)

assert len(single) == 2
Expand Down
13 changes: 1 addition & 12 deletions tests/test_single_with_celldex.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,12 @@
import scrnaseq
import pandas as pd
import scipy
import pytest
from biocframe import BiocFrame

def test_with_minimal_args():
sce = scrnaseq.fetch_dataset("zeisel-brain-2015", "2023-12-14", realize_assays=True)

immgen_ref = celldex.fetch_reference("immgen", "2024-02-26", realize_assays=True)

with pytest.raises(Exception):
matches = singler.annotate_single(
test_data=sce.assays["counts"],
ref_data=immgen_ref,
ref_labels=immgen_ref.get_column_data().column("label.main"),
)

matches = singler.annotate_single(
test_data=sce,
ref_data=immgen_ref,
Expand All @@ -32,11 +23,10 @@ def test_with_minimal_args():

def test_with_all_supplied():
sce = scrnaseq.fetch_dataset("zeisel-brain-2015", "2023-12-14", realize_assays=True)

immgen_ref = celldex.fetch_reference("immgen", "2024-02-26", realize_assays=True)

matches = singler.annotate_single(
test_data=sce,
test_data=sce.assays["counts"],
test_features=sce.get_row_names(),
ref_data=immgen_ref,
ref_labels=immgen_ref.get_column_data().column("label.main"),
Expand All @@ -50,7 +40,6 @@ def test_with_all_supplied():

def test_with_colname():
sce = scrnaseq.fetch_dataset("zeisel-brain-2015", "2023-12-14", realize_assays=True)

immgen_ref = celldex.fetch_reference("immgen", "2024-02-26", realize_assays=True)

matches = singler.annotate_single(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_train_integrated.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_train_integrated():
test_features,
ref_prebuilt=[built1, built2],
ref_names=["FOO", "BAR"],
num_threads=3,
num_threads=2,
)

assert pintegrated.reference_names == ["FOO", "BAR"]
Expand Down
12 changes: 0 additions & 12 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from singler._utils import (
_factorize,
_stable_intersect,
_stable_union,
_clean_matrix,
Expand All @@ -8,17 +7,6 @@
import summarizedexperiment


def test_factorize():
    """Check the levels and codes returned by ``_factorize``."""
    scenarios = [
        # Integer inputs: levels are reported as strings.
        ([1, 3, 5, 5, 3, 1], ["1", "3", "5"], [0, 1, 2, 2, 1, 0]),
        # Preserves the order.
        (["C", "D", "A", "B", "C", "A"], ["C", "D", "A", "B"], [0, 1, 2, 3, 0, 2]),
    ]

    for values, expected_levels, expected_codes in scenarios:
        levels, codes = _factorize(values)
        assert list(levels) == expected_levels
        assert (codes == expected_codes).all()


def test_intersect():
# Preserves the order in the first argument.
out = _stable_intersect(["B", "C", "A", "D", "E"], ["A", "C", "E"])
Expand Down

0 comments on commit dec4971

Please sign in to comment.