From ee3a68af1cf2cc77bcb0cd116fb0d054bb8e27d9 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Sun, 2 Jun 2024 17:29:48 +0200 Subject: [PATCH] update tests --- tests/ann2data/test_ann2data_by_category.py | 29 ++++++++++++++------- tests/transforms/test_add_edge_index.py | 2 +- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/tests/ann2data/test_ann2data_by_category.py b/tests/ann2data/test_ann2data_by_category.py index badc549..b085397 100644 --- a/tests/ann2data/test_ann2data_by_category.py +++ b/tests/ann2data/test_ann2data_by_category.py @@ -11,10 +11,17 @@ def test_sample_case_ann2data_basic(): # so that the resulting splits number of edges will be the same # as the sum of the number of edges in each cluster func_args = {"radius": 4.0, "coord_type": "generic"} - coordinates[:25, 0] += 100 + cell_type = ["a"] * 20 + ["b"] * 20 + ["c"] * 5 + ["d"] * 5 + image_id = list("xy" * 20) + ["z"] * 10 + # make clusters for each cell type + for i, ct in enumerate(set(cell_type)): + idx = np.where(np.array(cell_type) == ct)[0] + coordinates[idx, 0] += 100 * i + coordinates[idx, 1] += 100 * i + adata_gt = ad.AnnData( np.random.rand(50, 2), - obs={"cell_type": ["a"] * 20 + ["b"] * 20 + ["c"] * 5 + ["d"] * 5, "image_id": list("xy" * 20) + ["z"] * 10}, + obs={"cell_type": cell_type, "image_id": image_id}, obsm={"spatial_init": coordinates}, ) a2d = ann2data.Ann2DataByCategory( @@ -55,11 +62,13 @@ def test_sample_case_ann2data_basic(): assert torch.allclose(torch.cat([d.x for d in datas]), torch.from_numpy(big_adata.X).to(torch.float)) assert sum([d.edge_index.shape[1] for d in datas]) == big_adata.uns["edge_index"].shape[1] adatas = list(iterables.ToCategoryIterator(category="cell_type")(big_adata)) - assert np.allclose( - np.array(adatas[0].obsp["graph_distances"].todense()), - np.array(big_adata.obsp["graph_distances"][0:25, 0:25].todense()), - ) - assert np.allclose( - np.array(adatas[1].obsp["graph_distances"].todense()), - np.array(big_adata.obsp["graph_distances"][25:, 25:].todense()), - ) + assert len(adatas) == 4 + # this line is the for loop version of the last two assertions + + for a in adatas: + ct = a.obs["cell_type"].values[0] + ct_idx = np.where(np.array(cell_type) == ct)[0] + np.allclose( + a.obsp["graph_distances"].todense(), + big_adata.obsp["graph_distances"][ct_idx, :][:, ct_idx].todense(), + ) diff --git a/tests/transforms/test_add_edge_index.py b/tests/transforms/test_add_edge_index.py index c33b9da..0993dc5 100644 --- a/tests/transforms/test_add_edge_index.py +++ b/tests/transforms/test_add_edge_index.py @@ -62,7 +62,7 @@ def test_add_edge_index(): tf = transforms.AddEdgeIndex( spatial_key="spatial_init", key_added="pred", - func_args={"radius": median_dist, "n_neighs": 4}, + func_args={"radius": median_dist, "n_neighs": 4, "coord_type": "generic"}, edge_index_key="edge_index", edge_weight_key="edge_weight", gets_connectivities=False, # gets distances