Skip to content

Commit

Permalink
Fix subset operation on primary samples (#21)
Browse files Browse the repository at this point in the history
This fixes a few inconsistencies when subsetting a MAE across its three dimensions: rows, columns, and experiment names. In addition, it simplifies the code for `complete_cases` and `replicated`.

Adds a new method, `get_with_column_data`, for consistency with the equivalent function in R's MAE implementation.

Update tests and docstrings.
  • Loading branch information
jkanche authored Jan 16, 2024
1 parent fba0775 commit 461c635
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 44 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ install_requires =
importlib-metadata; python_version<"3.8"
biocframe>=0.5.6,<0.6.0
biocutils>=0.1.4,<0.2.0
summarizedexperiment>=0.4.0,<0.5.0
summarizedexperiment>=0.4.1,<0.5.0

[options.packages.find]
where = src
Expand Down
103 changes: 69 additions & 34 deletions src/multiassayexperiment/MultiAssayExperiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ def _validate_sample_map_with_expts(sample_map, experiments):
if (len(unique_expt_names) != len(smap_unique_assays)) or (
unique_expt_names != smap_unique_assays
):
raise ValueError("'assays' mismatch between `sample_map` and `experiments`.")
warn(
"'experiments' contains names not represented in 'sample_map' or vice-versa.",
UserWarning,
)

# check if colnames exist
agroups = sample_map.split("assay")
Expand Down Expand Up @@ -488,6 +491,13 @@ def experiment(self, name: str, with_sample_data: bool = False) -> Any:

return expt

def get_with_column_data(self, name: str) -> Any:
    """Fetch an experiment through :py:meth:`~experiment` with sample data requested.

    Convenience wrapper that forwards ``name`` to :py:meth:`~experiment`
    with ``with_sample_data=True``. It exists for naming consistency with
    Bioconductor's implementation of the same functionality.

    Args:
        name:
            Name of the experiment to retrieve.

    Returns:
        Whatever :py:meth:`~experiment` returns for ``name`` when called
        with ``with_sample_data=True``.
    """
    experiment_with_samples = self.experiment(name, with_sample_data=True)
    return experiment_with_samples

############################
######>> sample map <<######
############################
Expand Down Expand Up @@ -570,8 +580,8 @@ def set_column_data(
"""
column_data = _sanitize_frame(column_data)

self._validate_column_data(column_data)
self._validate_sample_map_with_column_data(self._sample_map, column_data)
_validate_column_data(column_data)
_validate_sample_map_with_column_data(self._sample_map, column_data)

output = self._define_output(in_place)
output._column_data = column_data
Expand Down Expand Up @@ -650,6 +660,28 @@ def metadata(self, metadata: dict):
######>> subset <<#######
#########################

def _normalize_column_slice(self, columns: Union[str, int, bool, Sequence, slice]):
_scalar = None
if columns != slice(None):
columns, _scalar = ut.normalize_subscript(
columns, len(self._column_data), self._column_data.row_names
)

return columns, _scalar

def _filter_sample_map(self, columns: Union[str, int, bool, Sequence, slice]):
_samples_to_filter = self._column_data[columns,].row_names

column_names_to_keep = {}
for i in self.experiment_names:
column_names_to_keep[i] = []

for _, row in self._sample_map:
if row["primary"] in _samples_to_filter:
column_names_to_keep[row["assay"]].append(row["colname"])

return column_names_to_keep

def subset_experiments(
self,
rows: Optional[Union[str, int, bool, Sequence]],
Expand All @@ -667,14 +699,14 @@ def subset_experiments(
:py:meth:`~biocutils.normalize_subscript.normalize_subscript`.
columns:
Column indices to subset.
Column indices (from :py:attr:`~column_data`) to subset.
Integer indices, a boolean filter, or (if the current object is
named) names specifying the ranges to be extracted, see
:py:meth:`~biocutils.normalize_subscript.normalize_subsc
:py:meth:`~biocutils.normalize_subscript.normalize_subscript`.
experiment_names:
Experiment name to keep.
Experiment names to keep.
Integer indices, a boolean filter, or (if the current object is
named) names specifying the ranges to be extracted, see
Expand Down Expand Up @@ -703,9 +735,21 @@ def subset_experiments(

_expts_copy = new_expt

if rows != slice(None) and columns != slice(None):
if rows != slice(None):
for k, v in _expts_copy.items():
_expts_copy[k] = v[rows, columns]
_expts_copy[k] = v[rows,]

columns, _ = self._normalize_column_slice(columns)
if columns != slice(None):
_col_dict = self._filter_sample_map(columns)

for k, v in _expts_copy.items():
if k in _col_dict:
if len(_col_dict[k]) != 0:
_matched_indices = ut.match(_col_dict[k], v.column_names)
else:
_matched_indices = []
_expts_copy[k] = v[:, list(_matched_indices)]

return _expts_copy

Expand Down Expand Up @@ -751,6 +795,10 @@ def _generic_slice(

if columns is None:
columns = slice(None)
columns, _ = self._normalize_column_slice(columns)

# filter column_data
_new_column_data = self._column_data[columns,]

if experiments is None:
experiments = slice(None)
Expand All @@ -764,20 +812,15 @@ def _generic_slice(
for expname, expt in _new_experiments.items():
counter = 0
for _, row in self._sample_map:
if row["assay"] == expname and row["colname"] in expt.column_names:
if (
row["assay"] == expname
and row["primary"] in _new_column_data.row_names
and row["colname"] in expt.column_names
):
smap_indices_to_keep.append(counter)
counter += 1
_new_sample_map = self._sample_map[list(set(smap_indices_to_keep)),]

# filter column_data
subset_primary = list(set(_new_sample_map.get_column("primary")))
coldata_indices_to_keep = []
for idx, row in enumerate(self._column_data._row_names):
if row in subset_primary:
coldata_indices_to_keep.append(idx)

_new_column_data = self._column_data[list(set(coldata_indices_to_keep)),]

return SlicerResult(_new_experiments, _new_sample_map, _new_column_data)

def subset_by_experiments(
Expand Down Expand Up @@ -908,7 +951,7 @@ def complete_cases(self) -> Sequence[bool]:
"""Identify samples that have data across all experiments.
Returns:
A boolean vector same as the number of samples in column_data,
A boolean vector same as the number of samples in 'column_data',
where each element is True if sample is present in all experiments or False.
"""
vec = []
Expand All @@ -921,7 +964,6 @@ def complete_cases(self) -> Sequence[bool]:
smap_indices_to_keep.append(rdx)

subset = self.sample_map[list(set(smap_indices_to_keep)),]

vec.append(set(subset.get_column("assay")) == set(self.experiment_names))

return vec
Expand All @@ -934,15 +976,14 @@ def replicated(self) -> Dict[str, Dict[str, Sequence[bool]]]:
are keys and values specify if the sample is replicated within each experiment.
"""
replicates = {}
all_samples = self.column_data.row_names
all_samples = self._column_data.row_names
for expname, expt in self._experiments.items():
if expname not in replicates:
replicates[expname] = {}

for s in all_samples:
replicates[expname][s] = []
replicates[expname][s] = [False] * expt.shape[1]

colnames = expt.column_names
smap_indices_to_keep = []

_assay = self._sample_map.get_column("assay")
Expand All @@ -952,17 +993,11 @@ def replicated(self) -> Dict[str, Dict[str, Sequence[bool]]]:

subset_smap = self.sample_map[list(set(smap_indices_to_keep)),]

for x in colnames:
_subset_smap_colnames = subset_smap.get_column("colname")
_indices = []
for cdx in range(len(_subset_smap_colnames)):
if _subset_smap_colnames[cdx] == x:
_indices.append(cdx)

__subset_smap = subset_smap[_indices,]

for s in all_samples:
replicates[expname][s].append(__subset_smap.get_column("primary"))
counter = 0
for _, row in subset_smap:
if row["assay"] == expname:
replicates[expname][row["primary"]][counter] = True
counter += 1

return replicates

Expand Down
20 changes: 11 additions & 9 deletions tests/test_slices.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,18 @@ def test_MAE_slice():
assert len(set(muMAE.sample_map["assay"])) == 3
assert len(set(muMAE.sample_map["primary"])) == 3

sliced_MAE = muMAE[1:3, 1:3]
sliced_MAE = muMAE[1:3, 1:2]
assert sliced_MAE is not None
assert isinstance(sliced_MAE, mae.MultiAssayExperiment)

assert sliced_MAE.experiments is not None
assert sliced_MAE.sample_map is not None
assert sliced_MAE.column_data is not None

assert len(set(sliced_MAE.sample_map["assay"])) == 3
assert len(set(sliced_MAE.sample_map["primary"])) == 3
assert sliced_MAE.sample_map.shape[0] == 6
assert len(set(sliced_MAE.sample_map["assay"])) == 1
assert len(set(sliced_MAE.sample_map["primary"])) == 1
assert sliced_MAE.sample_map.shape[0] != muMAE.sample_map.shape[0]
assert sliced_MAE.sample_map.shape[0] == 1000

sliced_MAE_assay = muMAE[None, None, ["rna", "spatial"]]
assert sliced_MAE_assay is not None
Expand All @@ -77,7 +78,7 @@ def test_MAE_slice():

assert len(set(sliced_MAE_assay.sample_map["assay"])) == 1
assert len(set(sliced_MAE_assay.sample_map["primary"])) == 1
assert sliced_MAE_assay.sample_map.shape[0] == 5
assert sliced_MAE_assay.sample_map.shape[0] == 1000


# def test_MAE_slice_dict():
Expand Down Expand Up @@ -137,17 +138,18 @@ def test_MAE_subset_by_column():
assert len(set(muMAE.sample_map["assay"])) == 3
assert len(set(muMAE.sample_map["primary"])) == 3

sliced_MAE = muMAE.subset_by_column(columns=[10, 2, 5])
sliced_MAE = muMAE.subset_by_column(columns=[1, 2])
assert sliced_MAE is not None
assert isinstance(sliced_MAE, mae.MultiAssayExperiment)

assert sliced_MAE.experiments is not None
assert sliced_MAE.sample_map is not None
assert sliced_MAE.column_data is not None

assert len(set(sliced_MAE.sample_map["assay"])) == 3
assert len(set(sliced_MAE.sample_map["primary"])) == 3
assert sliced_MAE.sample_map.shape == (2030, 3)
assert len(set(sliced_MAE.sample_map["assay"])) == 2
assert len(set(sliced_MAE.sample_map["primary"])) == 2
assert sliced_MAE.sample_map.shape == (1030, 3)
assert len(sliced_MAE.experiment_names) == len(muMAE.experiment_names)


def test_MAE_subsetByExpt():
Expand Down

0 comments on commit 461c635

Please sign in to comment.