diff --git a/bengrn/base.py b/bengrn/base.py index 44b9ebd..2eb436e 100644 --- a/bengrn/base.py +++ b/bengrn/base.py @@ -4,10 +4,13 @@ import json import logging +import os import os.path +import tarfile import urllib.request from typing import Optional, Union +import gdown import gseapy as gp import matplotlib.pyplot as plt import numpy as np @@ -29,9 +32,6 @@ from sklearn.linear_model import LogisticRegression, RidgeClassifier from sklearn.metrics import PrecisionRecallDisplay, auc, precision_recall_curve from sklearn.model_selection import train_test_split -import gdown -import os -import tarfile from .tools import GENIE3 @@ -409,13 +409,13 @@ def get_scenicplus( def get_sroy_gt( - get: str = "main", join: str = "outer", species: str = "human", gt: str = "full" + get: str = "mine", join: str = "outer", species: str = "human", gt: str = "full" ) -> GRNAnnData: """ This function retrieves the ground truth data from the McCall et al.'s paper. Args: - get (str): The specific dataset to retrieve. Options include "main", "liu", and "chen". + get (str): The specific dataset to retrieve. Options include "mine", "liu", and "chen". join (str, optional): The type of join to be performed when concatenating the data. Default is "outer". species (str, optional): The species of the dataset. Default is "human". gt (str, optional): The type of ground truth data to retrieve. Options include "full", "chip", and "ko". Default is "full". diff --git a/bengrn/tools/genie3.py b/bengrn/tools/genie3.py index 6a6b641..e753b80 100644 --- a/bengrn/tools/genie3.py +++ b/bengrn/tools/genie3.py @@ -17,7 +17,7 @@ def compute_feature_importances(estimator): for e in estimator.estimators_ ] importances = array(importances) - return sum(importances, axis=0) / len(estimator) + return importances.sum(0) / len(estimator.estimators_) def get_link_list( diff --git a/tests/test_base.py b/tests/test_base.py index 503df37..528cfa0 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -1,12 +1,22 @@ import os import numpy as np +import pandas as pd import pytest import scanpy as sc from grnndata import GRNAnnData from scipy.sparse import csr_matrix -from bengrn.base import NAME, BenGRN +from bengrn.base import ( + NAME, + BenGRN, + compute_epr, + compute_genie3, + get_GT_db, + get_perturb_gt, + get_sroy_gt, + train_classifier, +) def test_base(): @@ -20,5 +30,30 @@ def test_base(): grn = GRNAnnData(adata.copy(), grn=sparse_random_matrix) grn.var.index = grn.var.symbol.astype(str) _ = BenGRN(grn, doplot=False).scprint_benchmark() + + # Test get_sroy_gt function + sroy_gt = get_sroy_gt(get="liu") + assert isinstance( + sroy_gt, GRNAnnData + ), "get_sroy_gt should return a GRNAnnData object" + + # Test get_perturb_gt function + perturb_gt = get_perturb_gt() + assert isinstance( + perturb_gt, GRNAnnData + ), "get_perturb_gt should return a GRNAnnData object" + + # Test compute_genie3 function + genie3_result = compute_genie3(adata[:, :100], ntrees=10, nthreads=1) + assert isinstance( + genie3_result, GRNAnnData + ), "compute_genie3 should return a GRNAnnData object" + + # Test train_classifier function + random_matrix = np.random.rand(4, 10000).reshape(100, 100, 4) + subgrn = grn[:, :100] + subgrn.varp["GRN"] = random_matrix + classifier, metrics, clf = train_classifier(subgrn) + except Exception as e: pytest.fail(f"An exception occurred: {str(e)}") diff --git a/uv.lock b/uv.lock index 160ed27..6796e97 100644 --- a/uv.lock +++ b/uv.lock @@ -257,7 +257,7 @@ wheels = [ [[package]] name = "bengrn" -version = "1.1.0" +version = "1.2.2" source = { editable = "." } dependencies = [ { name = "anndata" }, @@ -266,12 +266,14 @@ dependencies = [ { name = "ctxcore" }, { name = "dask-expr" }, { name = "decoupler" }, + { name = "gdown" }, { name = "grnndata" }, { name = "gseapy" }, { name = "numpy" }, { name = "omnipath" }, { name = "pandas" }, { name = "pyscenic" }, + { name = "rich" }, { name = "scikit-learn" }, { name = "scipy" }, { name = "seaborn" }, @@ -301,6 +303,7 @@ requires-dist = [ { name = "ctxcore", specifier = ">=0.1.1" }, { name = "dask-expr", specifier = ">=1.0.0" }, { name = "decoupler", specifier = ">=1.2.0" }, + { name = "gdown", specifier = ">=4.7.1" }, { name = "gitchangelog", marker = "extra == 'dev'", specifier = ">=3.0.4" }, { name = "grnndata", specifier = ">=0.1.0" }, { name = "gseapy", specifier = ">=0.10.0" }, @@ -316,6 +319,7 @@ requires-dist = [ { name = "pyscenic", specifier = ">=0.12.0" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.4.3" }, { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.1.0" }, + { name = "rich", specifier = ">=13.5.0" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.6.4" }, { name = "scikit-learn", specifier = ">=1.0.0" }, { name = "scipy", specifier = ">=1.7.0" }, @@ -1191,6 +1195,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1d/a0/6aaea0c2fbea2f89bfd5db25fb1e3481896a423002ebe4e55288907a97a3/fsspec-2024.9.0-py3-none-any.whl", hash = "sha256:a0947d552d8a6efa72cc2c730b12c41d043509156966cca4fb157b0f2a0c574b", size = 179253 }, ] +[[package]] +name = "gdown" +version = "5.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "beautifulsoup4" }, + { name = "filelock" }, + { name = "requests", extra = ["socks"] }, + { name = "tqdm" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/09/6a/37e6b70c5bda3161e40265861e63b64a86bfc6ca6a8f1c35328a675c84fd/gdown-5.2.0.tar.gz", hash = "sha256:2145165062d85520a3cd98b356c9ed522c5e7984d408535409fd46f94defc787", size = 284647 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/70/e07c381e6488a77094f04c85c9caf1c8008cdc30778f7019bc52e5285ef0/gdown-5.2.0-py3-none-any.whl", hash = "sha256:33083832d82b1101bdd0e9df3edd0fbc0e1c5f14c9d8c38d2a35bf1683b526d6", size = 18235 }, +] + [[package]] name = "ghp-import" version = "2.1.0" @@ -2957,8 +2976,6 @@ version = "6.0.0" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/18/c7/8c6872f7372eb6a6b2e4708b88419fb46b857f7a2e1892966b851cc79fc9/psutil-6.0.0.tar.gz", hash = "sha256:8faae4f310b6d969fa26ca0545338b21f73c6b15db7c4a8d934a5482faa818f2", size = 508067 } wheels = [ - { url = "https://files.pythonhosted.org/packages/c5/66/78c9c3020f573c58101dc43a44f6855d01bbbd747e24da2f0c4491200ea3/psutil-6.0.0-cp27-none-win32.whl", hash = "sha256:02b69001f44cc73c1c5279d02b30a817e339ceb258ad75997325e0e6169d8b35", size = 249766 }, - { url = "https://files.pythonhosted.org/packages/e1/3f/2403aa9558bea4d3854b0e5e567bc3dd8e9fbc1fc4453c0aa9aafeb75467/psutil-6.0.0-cp27-none-win_amd64.whl", hash = "sha256:21f1fb635deccd510f69f485b87433460a603919b45e2a324ad65b0cc74f8fb1", size = 253024 }, { url = "https://files.pythonhosted.org/packages/0b/37/f8da2fbd29690b3557cca414c1949f92162981920699cd62095a984983bf/psutil-6.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:c588a7e9b1173b6e866756dde596fd4cad94f9399daf99ad8c3258b3cb2b47a0", size = 250961 }, { url = "https://files.pythonhosted.org/packages/35/56/72f86175e81c656a01c4401cd3b1c923f891b31fbcebe98985894176d7c9/psutil-6.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ed2440ada7ef7d0d608f20ad89a04ec47d2d3ab7190896cd62ca5fc4fe08bf0", size = 287478 }, { url = "https://files.pythonhosted.org/packages/19/74/f59e7e0d392bc1070e9a70e2f9190d652487ac115bb16e2eff6b22ad1d24/psutil-6.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd9a97c8e94059b0ef54a7d4baf13b405011176c3b6ff257c247cae0d560ecd", size = 290455 }, @@ -3206,6 +3223,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5c/e3/6a0eaf46a897da829c896f0a034fce82133fce72f95d314bea81287c4279/pyscenic-0.12.1-py3-none-any.whl", hash = "sha256:a250d682e073e67dc80505843764d9cade68dada45a40a622e1aefbae78756e9", size = 7099773 }, ] +[[package]] +name = "pysocks" +version = "1.7.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bd/11/293dd436aea955d45fc4e8a35b6ae7270f5b8e00b53cf6c024c83b657a11/PySocks-1.7.1.tar.gz", hash = "sha256:3f8804571ebe159c380ac6de37643bb4685970655d3bba243530d6558b799aa0", size = 284429 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8d/59/b4572118e098ac8e46e399a1dd0f2d85403ce8bbaad9ec79373ed6badaf9/PySocks-1.7.1-py3-none-any.whl", hash = "sha256:2725bd0a9925919b9b51739eea5f9e2bae91e83288108a9ad338b2e3a4435ee5", size = 16725 }, +] + [[package]] name = "pytest" version = "8.3.3" @@ -3549,6 +3575,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f9/9b/335f9764261e915ed497fcdeb11df5dfd6f7bf257d4a6a2a686d80da4d54/requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6", size = 64928 }, ] +[package.optional-dependencies] +socks = [ + { name = "pysocks" }, +] + +[[package]] +name = "rich" +version = "13.9.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown-it-py" }, + { name = "pygments" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/aa/9e/1784d15b057b0075e5136445aaea92d23955aad2c93eaede673718a40d95/rich-13.9.2.tar.gz", hash = "sha256:51a2c62057461aaf7152b4d611168f93a9fc73068f8ded2790f29fe2b5366d0c", size = 222843 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/67/91/5474b84e505a6ccc295b2d322d90ff6aa0746745717839ee0c5fb4fdcceb/rich-13.9.2-py3-none-any.whl", hash = "sha256:8c82a3d3f8dcfe9e734771313e606b39d8247bb6b826e196f4914b333b743cf1", size = 242117 }, +] + [[package]] name = "rpds-py" version = "0.20.0"