From b36814afbdbbe25509ed238839b772b055cadbf0 Mon Sep 17 00:00:00 2001 From: sichao Date: Wed, 23 Aug 2023 12:22:27 -0400 Subject: [PATCH] update params in scatters --- dynamo/plot/scatters.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index 85f0d57d8..5e0cf339a 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -23,7 +23,7 @@ from ..dynamo_logger import main_debug, main_info, main_warning from ..preprocessing.utils import affine_transform, gen_rotation_2d from ..tools.moments import calc_1nd_moment -from ..tools.utils import flatten, get_mapper, update_dict +from ..tools.utils import flatten, get_mapper, get_vel_params, update_dict, update_vel_params from .utils import ( _datashade_points, _get_adata_color_vec, @@ -808,9 +808,11 @@ def _plot_basis_layer(cur_b, cur_l): points.iloc[:, 0].max() * 0.80, ) k_name = "gamma_k" if _adata.uns["dynamics"]["experiment_type"] == "one-shot" else "gamma" - if k_name in _adata.var.columns: - if not ("gamma_b" in _adata.var.columns) or all(_adata.var.gamma_b.isna()): - _adata.var.loc[:, "gamma_b"] = 0 + vel_params_df = get_vel_params(_adata) + if k_name in vel_params_df.columns: + if not ("gamma_b" in vel_params_df.columns) or all(vel_params_df.gamma_b.isna()): + vel_params_df.loc[:, "gamma_b"] = 0 + update_vel_params(_adata, params_df=vel_params_df) ax.plot( xnew, xnew * _adata[:, cur_b].var.loc[:, k_name].unique() @@ -841,9 +843,11 @@ def _plot_basis_layer(cur_b, cur_l): + group_adata[:, cur_b].var.loc[:, group_b_key].unique() ) ax.annotate(group + "_" + cur_group, xy=(group_xnew[-1], group_ynew[-1])) - if group_k_name in group_adata.var.columns: - if not (group_b_key in group_adata.var.columns) or all(group_adata.var[group_b_key].isna()): - group_adata.var.loc[:, group_b_key] = 0 + vel_params_df = get_vel_params(group_adata) + if group_k_name in vel_params_df.columns: + if not (group_b_key in vel_params_df.columns) or all(vel_params_df[group_b_key].isna()): + vel_params_df.loc[:, group_b_key] = 0 + update_vel_params(group_adata, params_df=vel_params_df) main_info("No %s found, setting all bias terms to zero" % group_b_key) ax.plot( group_xnew,