
Don't generate duplicate cluster_id columns if `generate_and_save_cluster_masks` is run twice (#1110)

* Ensure cluster_id column isn't created twice if re-run

* Add debugging statements to try to pinpoint the issue with column ordering

* Also add print statements to see how the CI cache_dir is being affected

* Add more print statements

* More print statements, and debugging of the debugging (if that makes sense)

* Even more print statements

* Added ez_seg_data in scripts/get_example_dataset.py

* Track down that norm_data!

* Add a few more assert statements to check

* Did the columns get reordered?

* Adjust test

* Update print statements and remove assertion so we get further

* Fix tests for removed FOV

* PYCODESTYLE import fix

* Remove unnecessary print statements

* Adjusted testing adata fixture

---------

Co-authored-by: Sricharan Reddy Varra <[email protected]>
alex-l-kong and srivarra authored Feb 27, 2024
1 parent 63ee8ff commit d3e9e4a
Showing 5 changed files with 68 additions and 61 deletions.
9 changes: 2 additions & 7 deletions src/ark/phenotyping/pixie_preprocessing.py
@@ -427,13 +427,8 @@ def create_pixel_matrix(fovs, channels, base_dir, tiff_dir, seg_dir,
         quant_dat_fov.index.name = "channel"

         # update the file with the newly processed fov quantile values
-        quant_dat_all = quant_dat_all.merge(
-            quant_dat_fov,
-            how="outer",
-            left_index=True,
-            right_index=True
-        )
-
+        quant_dat_all = quant_dat_all.merge(quant_dat_fov, how="outer",
+                                            left_index=True, right_index=True)
         quant_dat_all.to_csv(quantile_path)

         # update number of fovs processed
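The hunk above only reflows the call; behavior is unchanged. As a minimal sketch of what this outer merge accumulates, here are made-up channel quantiles (the real `quant_dat_fov` frames are built per FOV inside `create_pixel_matrix`; the channel names and values below are invented):

```python
import pandas as pd

# Start with an empty frame indexed by channel, as create_pixel_matrix does
# before any FOV has been processed.
quant_dat_all = pd.DataFrame(index=pd.Index([], name="channel"))

# Invented per-FOV quantile values for two channels.
for fov, values in {"fov0": [1.2, 3.4], "fov1": [1.1, 3.9]}.items():
    quant_dat_fov = pd.DataFrame(
        {fov: values}, index=pd.Index(["CD3", "CD8"], name="channel")
    )
    # the outer join keeps every channel seen so far and adds one column per FOV
    quant_dat_all = quant_dat_all.merge(quant_dat_fov, how="outer",
                                        left_index=True, right_index=True)

print(quant_dat_all)  # channels as rows, one quantile column per FOV
```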
3 changes: 3 additions & 0 deletions src/ark/utils/data_utils.py
@@ -446,6 +446,9 @@ def generate_and_save_cell_cluster_masks(
     cluster_map = cmd.mapping.filter([cmd.cluster_column, cmd.cluster_id_column])
     cluster_map = cluster_map.drop_duplicates()

+    # drop the cluster_id column from gui_map if it already exists, otherwise do nothing
+    gui_map = gui_map.drop(columns="cluster_id", errors="ignore")
+
     # add a cluster_id column corresponding to the new mask integers
     updated_cluster_map = gui_map.merge(cluster_map, on=[cmd.cluster_column], how="left")
     updated_cluster_map.to_csv(cluster_id_to_name_path, index=False)
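This drop is the heart of the fix: it makes the merge idempotent. A short sketch of the failure mode with hypothetical mapping tables (in the real code, `gui_map` is re-read from `cluster_id_to_name_path` on each run and the join key comes from `cmd.cluster_column`; the column names below are illustrative):

```python
import pandas as pd

# After a first run, the saved mapping already carries a cluster_id column.
gui_map = pd.DataFrame({"cluster": ["CD4 T", "B"], "cluster_id": [1, 2]})
cluster_map = pd.DataFrame({"cluster": ["CD4 T", "B"], "cluster_id": [1, 2]})

# Without the drop, pandas disambiguates the duplicated column on re-run:
naive = gui_map.merge(cluster_map, on=["cluster"], how="left")
print(naive.columns.tolist())  # ['cluster', 'cluster_id_x', 'cluster_id_y']

# With the drop (errors="ignore" makes it a no-op on the first run),
# every run produces the same single cluster_id column:
gui_map = gui_map.drop(columns="cluster_id", errors="ignore")
fixed = gui_map.merge(cluster_map, on=["cluster"], how="left")
print(fixed.columns.tolist())  # ['cluster', 'cluster_id']
```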
1 change: 1 addition & 0 deletions src/ark/utils/example_dataset.py
@@ -180,6 +180,7 @@ def get_example_dataset(dataset: str, save_dir: Union[str, pathlib.Path],
             * `"LDA_training_inference"`
             * `"neighborhood_analysis"`
             * `"pairwise_spatial_enrichment"`
+            * `"ez_seg_data"`
         save_dir (Union[str, pathlib.Path]): The path to save the dataset files in.
         overwrite_existing (bool): The option to overwrite existing configs of the `dataset`
             downloaded. Defaults to True.
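Once registered, the new dataset can be fetched like the existing ones. A usage sketch (the save path is illustrative):

```python
from ark.utils.example_dataset import get_example_dataset

# Download the ez_seg_data example dataset into a local directory.
get_example_dataset(
    dataset="ez_seg_data",
    save_dir="../data/example_dataset",
    overwrite_existing=True,
)
```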
109 changes: 57 additions & 52 deletions tests/utils/data_utils_test.py
@@ -359,33 +359,35 @@ def test_generate_and_save_cell_cluster_masks(tmp_path: pathlib.Path, sub_dir, n
     cluster_mapping.to_csv(os.path.join(tmp_path, 'cluster_mapping.csv'), index=False)

     # test various batch_sizes, no sub_dir, name_suffix = ''.
-    data_utils.generate_and_save_cell_cluster_masks(
-        fovs=fovs,
-        save_dir=os.path.join(tmp_path, 'cell_masks'),
-        seg_dir=tmp_path,
-        cell_data=consensus_data_som,
-        cluster_id_to_name_path=mapping_file_path,
-        fov_col=settings.FOV_ID,
-        label_col=settings.CELL_LABEL,
-        cell_cluster_col='cell_som_cluster',
-        seg_suffix='_whole_cell.tiff',
-        sub_dir=sub_dir,
-        name_suffix=name_suffix
-    )
+    # NOTE: test is run twice to ensure results are the same even if an existing cluster_id is found
+    for i in np.arange(2):
+        data_utils.generate_and_save_cell_cluster_masks(
+            fovs=fovs,
+            save_dir=os.path.join(tmp_path, 'cell_masks'),
+            seg_dir=tmp_path,
+            cell_data=consensus_data_som,
+            cluster_id_to_name_path=mapping_file_path,
+            fov_col=settings.FOV_ID,
+            label_col=settings.CELL_LABEL,
+            cell_cluster_col='cell_som_cluster',
+            seg_suffix='_whole_cell.tiff',
+            sub_dir=sub_dir,
+            name_suffix=name_suffix
+        )

-    # open each cell mask and make sure the shape and values are valid
-    if sub_dir is None:
-        sub_dir = ''
+        # open each cell mask and make sure the shape and values are valid
+        if sub_dir is None:
+            sub_dir = ''

-    for i, fov in enumerate(fovs):
-        fov_name = fov + name_suffix + ".tiff"
-        cell_mask = io.imread(os.path.join(tmp_path, 'cell_masks', sub_dir, fov_name))
-        actual_img_dims = (40, 40) if i < fov_size_split else (20, 20)
-        assert cell_mask.shape == actual_img_dims
-        assert np.all(cell_mask <= 5)
+        for i, fov in enumerate(fovs):
+            fov_name = fov + name_suffix + ".tiff"
+            cell_mask = io.imread(os.path.join(tmp_path, 'cell_masks', sub_dir, fov_name))
+            actual_img_dims = (40, 40) if i < fov_size_split else (20, 20)
+            assert cell_mask.shape == actual_img_dims
+            assert np.all(cell_mask <= 5)

-    new_cluster_mapping = pd.read_csv(mapping_file_path)
-    assert "cluster_id" in new_cluster_mapping.columns
+        new_cluster_mapping = pd.read_csv(mapping_file_path)
+        assert "cluster_id" in new_cluster_mapping.columns


 def test_generate_pixel_cluster_mask():
@@ -851,35 +853,35 @@ def test_convert_to_adata(self):
         assert pathlib.Path(fov_adata_path).exists()


-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def testing_anndatas(
     tmp_path_factory: pytest.TempPathFactory
-) -> Callable[[int, pathlib.Path], Tuple[List[str], AnnCollection]]:
-    def create_adatas(n_fovs, save_dir: pathlib.Path):
-        fov_names, ann_collection = ark_test_utils.generate_anncollection(
-            fovs=n_fovs,
-            n_vars=10,
-            n_obs=100,
-            obs_properties=4,
-            obs_categorical_properties=2,
-            random_n_obs=True,
-            join_obs="inner",
-            join_obsm="inner"
-        )
+) -> Iterator[Tuple[List[str], AnnCollection, pathlib.Path]]:
+
+    fov_names, ann_collection = ark_test_utils.generate_anncollection(
+        fovs=5,
+        n_vars=10,
+        n_obs=100,
+        obs_properties=4,
+        obs_categorical_properties=2,
+        random_n_obs=True,
+        join_obs="inner",
+        join_obsm="inner",
+    )

-        for fov_name, fov_adata in zip(fov_names, ann_collection.adatas):
-            fov_adata.write_zarr(os.path.join(save_dir, f"{fov_name}.zarr"))
-        return fov_names, ann_collection
+    save_dir = tmp_path_factory.mktemp("anndatas")

-    yield create_adatas
+    for fov_name, fov_adata in zip(fov_names, ann_collection.adatas):
+        fov_adata.write_zarr(os.path.join(save_dir, f"{fov_name}.zarr"))

+    yield fov_names, ann_collection, save_dir


-def test_load_anndatas(testing_anndatas, tmp_path_factory):
-    ann_collection_path = tmp_path_factory.mktemp("anndatas")
+def test_load_anndatas(testing_anndatas):

-    fov_names, ann_collection = testing_anndatas(n_fovs=5, save_dir=ann_collection_path)
+    fov_names, ann_collection, anndata_dir = testing_anndatas

-    ac = data_utils.load_anndatas(ann_collection_path, join_obs="inner", join_obsm="inner")
+    ac = data_utils.load_anndatas(anndata_dir, join_obs="inner", join_obsm="inner")

     assert isinstance(ac, AnnCollection)
     assert len(ac.adatas) == len(fov_names)
@@ -888,21 +890,24 @@ def test_load_anndatas(testing_anndatas, tmp_path_factory):
     # Assert that each AnnData component of an AnnCollection is the same as the one on disk.
     for fov_name, fov_adata in zip(fov_names, ann_collection.adatas):
         anndata.tests.helpers.assert_adata_equal(
-            a=read_zarr(ann_collection_path / f"{fov_name}.zarr"),
+            a=read_zarr(anndata_dir / f"{fov_name}.zarr"),
             b=fov_adata
         )


-def test_AnnDataIterDataPipe(testing_anndatas, tmp_path_factory):
-    ann_collection_path = tmp_path_factory.mktemp("anndatas")
+def test_AnnDataIterDataPipe(testing_anndatas):

-    _ = testing_anndatas(n_fovs=5, save_dir=ann_collection_path)
-    ac = data_utils.load_anndatas(ann_collection_path, join_obs="inner", join_obsm="inner")
+    fov_names, _, anndata_dir = testing_anndatas
+    ac = data_utils.load_anndatas(anndata_dir, join_obs="inner", join_obsm="inner")

     a_idp = data_utils.AnnDataIterDataPipe(fovs=ac)

     from torchdata.datapipes.iter import IterDataPipe
     assert isinstance(a_idp, IterDataPipe)

-    for fov in a_idp:
-        assert isinstance(fov, AnnData)
+    for fov_name, fov_adata in zip(fov_names, a_idp):
+        assert isinstance(fov_adata, AnnData)
+        anndata.tests.helpers.assert_adata_equal(
+            a=read_zarr(anndata_dir / f"{fov_name}.zarr"),
+            b=fov_adata
+        )
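The fixture change above replaces a module-scoped factory (tests previously called `testing_anndatas(n_fovs=..., save_dir=...)` themselves) with a function-scoped fixture that builds the data once per test and yields it along with the directory it wrote to. A generic sketch of the two patterns with toy data (fixture and test names here are invented for illustration):

```python
import pathlib
from typing import Iterator, List, Tuple

import pytest


# Before: a module-scoped factory fixture; each test calls the returned
# function with its own arguments, and state persists across tests.
@pytest.fixture(scope="module")
def make_items_factory(tmp_path_factory: pytest.TempPathFactory):
    def make_items(n: int, save_dir: pathlib.Path) -> List[pathlib.Path]:
        return [save_dir / f"item_{i}.txt" for i in range(n)]
    yield make_items


# After: a function-scoped fixture that builds the data itself, including the
# temporary directory, so every test sees a fresh, self-contained copy.
@pytest.fixture(scope="function")
def items_with_dir(
    tmp_path_factory: pytest.TempPathFactory
) -> Iterator[Tuple[List[pathlib.Path], pathlib.Path]]:
    save_dir = tmp_path_factory.mktemp("items")
    items = [save_dir / f"item_{i}.txt" for i in range(3)]
    yield items, save_dir


def test_items(items_with_dir):
    items, save_dir = items_with_dir
    assert all(p.parent == save_dir for p in items)
```

Yielding the directory alongside the data is what lets `test_load_anndatas` and `test_AnnDataIterDataPipe` drop their own `tmp_path_factory` bookkeeping.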
7 changes: 5 additions & 2 deletions tests/utils/example_dataset_test.py
@@ -183,11 +183,14 @@ def test_download_example_dataset(self, dataset_download: ExampleDataset):
         Args:
             dataset_download (ExampleDataset): Fixture for the dataset, respective to each
         """
+        import os
         dataset_names = list(
-            dataset_download.dataset_paths[dataset_download.dataset].keys())
+            dataset_download.dataset_paths[dataset_download.dataset].keys()
+        )
         for ds_n in dataset_names:
             dataset_cache_path = pathlib.Path(
-                dataset_download.dataset_paths[dataset_download.dataset][ds_n])
+                dataset_download.dataset_paths[dataset_download.dataset][ds_n]
+            )
             self.dataset_test_fns[ds_n](dir_p=dataset_cache_path)

     @pytest.mark.parametrize("_overwrite_existing", [True, False])
