From 353219b8e4b42683e6f5cbda5fc295581dced0cc Mon Sep 17 00:00:00 2001 From: sichao Date: Mon, 9 Oct 2023 11:44:57 -0400 Subject: [PATCH] fix bug and update docstr --- dynamo/tools/dynamics.py | 9 ++++++--- dynamo/tools/pseudotime_velocity.py | 4 ++-- dynamo/tools/utils.py | 14 +++++++------- 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/dynamo/tools/dynamics.py b/dynamo/tools/dynamics.py index 66d479323..0471c2d53 100755 --- a/dynamo/tools/dynamics.py +++ b/dynamo/tools/dynamics.py @@ -480,9 +480,12 @@ def dynamics( est_method = "gmm" if model.lower() == "stochastic" else "ols" if experiment_type.lower() == "one-shot": - vel_params_df = get_vel_params(subset_adata) - beta = vel_params_df.beta if "beta" in vel_params_df.columns else None - gamma = vel_params_df.gamma if "gamma" in vel_params_df.columns else None + try: + vel_params_df = get_vel_params(subset_adata) + beta = vel_params_df.beta if "beta" in vel_params_df.columns else None + gamma = vel_params_df.gamma if "gamma" in vel_params_df.columns else None + except KeyError: + beta, gamma = None, None ss_estimation_kwargs = {"beta": beta, "gamma": gamma} else: ss_estimation_kwargs = {} diff --git a/dynamo/tools/pseudotime_velocity.py b/dynamo/tools/pseudotime_velocity.py index 341b37e14..660646f90 100644 --- a/dynamo/tools/pseudotime_velocity.py +++ b/dynamo/tools/pseudotime_velocity.py @@ -235,5 +235,5 @@ def pseudotime_velocity( logger.info("set gamma to be 0 in .var. so that velocity_S = unspliced RNA.") logger.info_insert_adata("gamma", "var", indent_level=2) - adata.varm["vel_params"] = np.zeros((adata.n_vars, 2)) - adata.uns["vel_params_names"] = ["gamma", "gamma_b"] + adata.varm["pseudotime_vel_params"] = np.zeros((adata.n_vars, 2)) + adata.uns["pseudotime_vel_params_names"] = ["gamma", "gamma_b"] diff --git a/dynamo/tools/utils.py b/dynamo/tools/utils.py index 182ad6548..be42ace3c 100755 --- a/dynamo/tools/utils.py +++ b/dynamo/tools/utils.py @@ -1630,7 +1630,7 @@ def get_vel_params( adata: the anndata object which contains the parameters. params: the names of parameters to query. If set to None, the entire velocity parameters DataFrame from `.varm` will be returned. - kin_param_pre: the prefix used in dynamics when estimating the parameters. + kin_param_pre: the prefix used to kinetic parameters related to RNA dynamics. skip_cell_wise: whether to skip the detected cell wise parameters. If set to True, the mean will be returned instead of cell wise parameters. @@ -1641,7 +1641,7 @@ def get_vel_params( params = [params] if kin_param_pre + "vel_params" not in adata.varm.keys(): - raise KeyError("No velocity parameters found.") + raise KeyError("The key of velocity related parameters are not found in varm.") array_data = adata.varm[kin_param_pre + "vel_params"] df_columns = adata.uns[kin_param_pre + "vel_params_names"] @@ -1673,15 +1673,15 @@ def get_vel_params( def update_vel_params(adata: AnnData, params_df: pd.DataFrame, kin_param_pre: str = "") -> None: - """Update the velocity parameters. + """Update the kinetic parameters related to RNA velocity calculation. Args: - adata: the AnnData obejct to update. - params_df: the new velocity parameters. - kin_param_pre: the prefix to locate the corresponding parameters. + adata: the AnnData object whose kinetic parameters related to RNA velocity calculation will be updated. + params_df: the dataframe of kinetic parameters related to RNA velocity calculation. + kin_param_pre: the prefix used to kinetic parameters related to RNA dynamics. Returns: - The anndata object will be updated. + The anndata object will be updated with parameters and columns names from given dataframe. """ adata.varm[kin_param_pre + "vel_params"] = params_df.to_numpy() adata.uns[kin_param_pre + "vel_params_names"] = list(params_df.columns)