Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated tests. #42

Merged
merged 2 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion popv/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def make_agreement_plots(
prediction_keys: list,
popv_prediction_key: str | None = "popv_prediction",
save_folder: str | None = None,
show: bool = True,
):
"""
Create plot of confusion matrix for different popv methods and consensus prediction.
Expand All @@ -216,6 +217,8 @@ def make_agreement_plots(
Key in adata.obs for consensus prediction.
save_folder
Path to a folder for storing the plot. Defaults to None and plot is not stored.
show
If True, the plot will be shown in the console. If False, the plot will not be shown.

Returns
-------
Expand All @@ -237,6 +240,7 @@ def make_agreement_plots(
x_label=method,
y_label=popv_prediction_key,
res_dir=save_folder,
show=show,
)


Expand All @@ -247,6 +251,7 @@ def _prediction_eval(
x_label="",
y_label="",
res_dir="./",
show=True,
):
"""Generate confusion matrix."""
types, _ = np.unique(np.concatenate([labels, pred]), return_inverse=True)
Expand All @@ -266,4 +271,5 @@ def _prediction_eval(
for fig in range(1, plt.gcf().number + 1):
pdf.savefig(fig)
pdf.close()
plt.show()
if show:
plt.show()
19 changes: 15 additions & 4 deletions tests/core/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from popv.reproducibility import _accuracy


def _get_test_anndata(cl_obo_folder="resources/ontology/"):
def _get_test_anndata(cl_obo_folder="resources/ontology/", mode='retrain'):
print("UUU", os.getcwd())
save_folder = "tests/tmp_testing/popv_test_results/"
fn = save_folder + "annotated_query.h5ad"
Expand All @@ -34,6 +34,8 @@ def _get_test_anndata(cl_obo_folder="resources/ontology/"):
# Lesser used parameters
query_labels_key = None
unknown_celltype_label = "unknown"
hvg = 4000 if mode == "retrain" else None

adata = Process_Query(
query_adata,
ref_adata,
Expand All @@ -44,13 +46,13 @@ def _get_test_anndata(cl_obo_folder="resources/ontology/"):
unknown_celltype_label=unknown_celltype_label,
save_path_trained_models=save_folder,
cl_obo_folder=cl_obo_folder,
prediction_mode="retrain",
prediction_mode=mode,
n_samples_per_label=n_samples_per_label,
compute_embedding=True,
return_probabilities=True,
accelerator="cpu",
devices="auto",
hvg=4000,
hvg=hvg,
)

return adata
Expand Down Expand Up @@ -183,7 +185,7 @@ def test_annotation():
save_path="tests/tmp_testing/popv_test_results/")
popv.visualization.agreement_score_bar_plot(adata)
popv.visualization.prediction_score_bar_plot(adata)
popv.visualization.make_agreement_plots(adata, prediction_keys=adata.uns["prediction_keys"])
popv.visualization.make_agreement_plots(adata, prediction_keys=adata.uns["prediction_keys"], show=False)
popv.visualization.celltype_ratio_bar_plot(adata)
obo_fn = "resources/ontology/cl.obo"
_accuracy._ontology_accuracy(adata[adata.obs['_dataset']=='ref'], obofile=obo_fn, gt_key='cell_ontology_class', pred_key='popv_prediction')
Expand All @@ -192,6 +194,15 @@ def test_annotation():
assert "popv_majority_vote_prediction" in adata.obs.columns
assert not adata.obs["popv_majority_vote_prediction"].isnull().any()

adata = _get_test_anndata(mode='inference').adata
popv.annotation.annotate_data(
adata, save_path="tests/tmp_testing/popv_test_results/")

adata = _get_test_anndata(mode='fast').adata
popv.annotation.annotate_data(
adata, methods=["svm", "rf"],
save_path="tests/tmp_testing/popv_test_results/")


def test_annotation_no_ontology():
"""Test Annotation and Plotting pipeline without ontology."""
Expand Down
Loading