Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save velocity parameters to .varm instead .var #579

Merged
merged 22 commits into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 72 additions & 67 deletions dynamo/plot/dynamics.py

Large diffs are not rendered by default.

18 changes: 11 additions & 7 deletions dynamo/plot/scatters.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ..dynamo_logger import main_debug, main_info, main_warning
from ..preprocessing.utils import affine_transform, gen_rotation_2d
from ..tools.moments import calc_1nd_moment
from ..tools.utils import flatten, get_mapper, update_dict
from ..tools.utils import flatten, get_mapper, get_vel_params, update_dict, update_vel_params
from .utils import (
_datashade_points,
_get_adata_color_vec,
Expand Down Expand Up @@ -808,9 +808,11 @@ def _plot_basis_layer(cur_b, cur_l):
points.iloc[:, 0].max() * 0.80,
)
k_name = "gamma_k" if _adata.uns["dynamics"]["experiment_type"] == "one-shot" else "gamma"
if k_name in _adata.var.columns:
if not ("gamma_b" in _adata.var.columns) or all(_adata.var.gamma_b.isna()):
_adata.var.loc[:, "gamma_b"] = 0
vel_params_df = get_vel_params(_adata)
if k_name in vel_params_df.columns:
if not ("gamma_b" in vel_params_df.columns) or all(vel_params_df.gamma_b.isna()):
vel_params_df.loc[:, "gamma_b"] = 0
update_vel_params(_adata, params_df=vel_params_df)
ax.plot(
xnew,
xnew * _adata[:, cur_b].var.loc[:, k_name].unique()
Expand Down Expand Up @@ -841,9 +843,11 @@ def _plot_basis_layer(cur_b, cur_l):
+ group_adata[:, cur_b].var.loc[:, group_b_key].unique()
)
ax.annotate(group + "_" + cur_group, xy=(group_xnew[-1], group_ynew[-1]))
if group_k_name in group_adata.var.columns:
if not (group_b_key in group_adata.var.columns) or all(group_adata.var[group_b_key].isna()):
group_adata.var.loc[:, group_b_key] = 0
vel_params_df = get_vel_params(group_adata)
if group_k_name in vel_params_df.columns:
if not (group_b_key in vel_params_df.columns) or all(vel_params_df[group_b_key].isna()):
vel_params_df.loc[:, group_b_key] = 0
update_vel_params(group_adata, params_df=vel_params_df)
main_info("No %s found, setting all bias terms to zero" % group_b_key)
ax.plot(
group_xnew,
Expand Down
7 changes: 4 additions & 3 deletions dynamo/plot/utils_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
prepare_data_mix_no_splicing,
prepare_data_no_splicing,
)
from ..tools.utils import get_mapper
from ..tools.utils import get_mapper, get_vel_params
from .utils import _to_hex


Expand Down Expand Up @@ -1685,8 +1685,9 @@ def plot_kin_twostep(

colors = pd.Series(T).map(new_color_key).values

r2 = adata[:, genes].var["gamma_r2"]
mean_R2 = adata[:, genes].var["mean_R2"]
vel_params_df = get_vel_params(adata)
r2 = vel_params_df.loc[genes, "gamma_r2"]
mean_R2 = vel_params_df.loc[genes, "mean_R2"]

for i, gene_name in enumerate(genes):
cur_X_data, cur_X_fit_data, cur_logLL = (
Expand Down
3 changes: 2 additions & 1 deletion dynamo/prediction/tscRNA_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from scipy.sparse import csr_matrix

from ..dynamo_logger import LoggerManager, main_exception, main_warning
from ..tools.utils import get_vel_params
from ..utils import copy_adata
from .utils import init_r0_pulse

Expand Down Expand Up @@ -67,7 +68,7 @@ def get_pulse_r0(
% (tkey, nkey, gamma_k_key)
)
R, L = adata[:, gene_names].layers[tkey], adata[:, gene_names].layers[nkey]
K = adata[:, gene_names].var[gamma_k_key].values.astype(float)
K = get_vel_params(adata, params=gamma_k_key, genes=gene_names).astype(float)

logger.info("Calculate initial total RNA via r0 = (r - l) / (1 - k)")
res = init_r0_pulse(R, L, K[None, :])
Expand Down
1 change: 1 addition & 0 deletions dynamo/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
AnnDataPredicate,
cell_norm,
compute_smallest_distance,
get_vel_params,
index_gene,
select,
select_cell,
Expand Down
9 changes: 7 additions & 2 deletions dynamo/tools/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
get_data_for_kin_params_estimation,
get_U_S_for_velocity_estimation,
get_valid_bools,
get_vel_params,
one_shot_alpha_matrix,
remove_2nd_moments,
set_param_kinetic,
Expand Down Expand Up @@ -479,8 +480,12 @@ def dynamics(
est_method = "gmm" if model.lower() == "stochastic" else "ols"

if experiment_type.lower() == "one-shot":
beta = subset_adata.var.beta if "beta" in subset_adata.var.keys() else None
gamma = subset_adata.var.gamma if "gamma" in subset_adata.var.keys() else None
try:
vel_params_df = get_vel_params(subset_adata)
beta = vel_params_df.beta if "beta" in vel_params_df.columns else None
gamma = vel_params_df.gamma if "gamma" in vel_params_df.columns else None
except KeyError:
beta, gamma = None, None
ss_estimation_kwargs = {"beta": beta, "gamma": gamma}
else:
ss_estimation_kwargs = {}
Expand Down
4 changes: 2 additions & 2 deletions dynamo/tools/pseudotime_velocity.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,5 +235,5 @@ def pseudotime_velocity(

logger.info("set gamma to be 0 in .var. so that velocity_S = unspliced RNA.")
logger.info_insert_adata("gamma", "var", indent_level=2)
adata.var["gamma"] = 0
adata.var["gamma_b"] = 0
adata.varm["pseudotime_vel_params"] = np.zeros((adata.n_vars, 2))
adata.uns["pseudotime_vel_params_names"] = ["gamma", "gamma_b"]
15 changes: 9 additions & 6 deletions dynamo/tools/recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .dimension_reduction import reduceDimension
from .dynamics import dynamics
from .moments import moments
from .utils import set_transition_genes
from .utils import get_vel_params, set_transition_genes, update_vel_params

# add recipe_csc_data()

Expand Down Expand Up @@ -325,9 +325,10 @@ def recipe_deg_data(
set_transition_genes(adata)
cell_velocities(adata, enforce=True, vkey=vkey, ekey=ekey, basis=basis)
except BaseException:
vel_params_df = get_vel_params(adata)
cell_velocities(
adata,
min_r2=adata.var.gamma_r2.min(),
min_r2=vel_params_df.gamma_r2.min(),
enforce=True,
vkey=vkey,
ekey=ekey,
Expand Down Expand Up @@ -708,6 +709,7 @@ def velocity_N(

var_columns = adata.var.columns
layer_keys = adata.layers.keys()
vel_params_df = get_vel_params(adata)

# check velocity_N, velocity_T, X_new, X_total
if not np.all([i in layer_keys for i in ["X_new", "X_total"]]):
Expand Down Expand Up @@ -743,8 +745,8 @@ def velocity_N(
"beta_k",
"gamma_k",
]:
if i in var_columns:
del adata.var[i]
if i in vel_params_df.columns:
del vel_params_df[i]

Comment on lines -746 to 750
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what adata.uns["vel_params_names"]?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The adata.uns["vel_params_names"] will be updated later with update_vel_params(adata, params_df=vel_params_df).

if group is not None:
group_prefixes = [group + "_" + str(i) + "_" for i in adata.obs[group].unique()]
Expand Down Expand Up @@ -773,8 +775,9 @@ def velocity_N(
"beta_k",
"gamma_k",
]:
if i + j in var_columns:
del adata.var[i + j]
if i + j in vel_params_df.columns:
del vel_params_df[i + j]
update_vel_params(adata, params_df=vel_params_df)

# now let us first run pca with new RNA
if recalculate_pca:
Expand Down
Loading
Loading