diff --git a/bin/composition_baseline.py b/bin/composition_baseline.py index 46c3847..fe23eb5 100644 --- a/bin/composition_baseline.py +++ b/bin/composition_baseline.py @@ -100,7 +100,7 @@ def get_local_sample_representation(self): freqs = ( comps.loc[:, [self.sample_key, self.subcluster_key, "freqs"]] .set_index([self.sample_key, self.subcluster_key]) - .squeeze() + .squeeze(axis=1) .unstack() ) freqs_all[unique_cluster] = freqs diff --git a/bin/fit_scviv2.py b/bin/fit_scviv2.py index 855da47..c37f89e 100644 --- a/bin/fit_scviv2.py +++ b/bin/fit_scviv2.py @@ -19,6 +19,7 @@ def fit_scviv2( use_attention_no_prior_mog: str = "false", use_attention_mog: str = "false", use_attention_no_prior_mog_large: str = "false", + use_ibd_config: str = "false", ) -> scvi_v2.MrVI: """ Train a MrVI model. @@ -38,7 +39,10 @@ def fit_scviv2( use_attention_smallu = use_attention_smallu.lower() == "true" use_attention_noprior = use_attention_noprior.lower() == "true" use_attention_no_prior_mog = use_attention_no_prior_mog.lower() == "true" - use_attention_no_prior_mog_large = use_attention_no_prior_mog_large.lower() == "true" + use_attention_no_prior_mog_large = ( + use_attention_no_prior_mog_large.lower() == "true" + ) + use_ibd_config = use_ibd_config.lower() == "true" config = load_config(config_in) batch_key = config.get("batch_key", None) @@ -121,19 +125,19 @@ def fit_scviv2( "use_map": True, "stop_gradients": False, "stop_gradients_mlp": True, - "dropout_rate": 0.03 + "dropout_rate": 0.03, }, "px_kwargs": { "stop_gradients": False, "stop_gradients_mlp": True, "h_activation": nn.softmax, "dropout_rate": 0.03, - "low_dim_batch": True + "low_dim_batch": True, }, "learn_z_u_prior_scale": False, "z_u_prior": True, "u_prior_mixture": False, - } + } ) if use_attention_no_prior_mog: model_kwargs.update( @@ -209,6 +213,31 @@ def fit_scviv2( "u_prior_mixture_k": 20, } ) + if use_ibd_config: + model_kwargs = { + "n_latent": 200, + "n_latent_u": 10, + "qz_nn_flavor": "attention", + "px_nn_flavor": "attention", + "qz_kwargs": { + "use_map": False, + "stop_gradients": False, + "stop_gradients_mlp": True, + "dropout_rate": 0.03, + }, + "px_kwargs": { + "stop_gradients": False, + "stop_gradients_mlp": True, + "h_activation": nn.softmax, + "dropout_rate": 0.03, + "low_dim_batch": True, + }, + "learn_z_u_prior_scale": False, + "z_u_prior": False, + "u_prior_mixture": True, + "u_prior_mixture_k": 100, + } + model = scvi_v2.MrVI(adata, **model_kwargs) model.train(**train_kwargs) diff --git a/bin/get_latent_scviv2.py b/bin/get_latent_scviv2.py index 9b9bc18..c82fc0f 100644 --- a/bin/get_latent_scviv2.py +++ b/bin/get_latent_scviv2.py @@ -62,7 +62,7 @@ def get_latent_scviv2( Path(cell_representations_out).touch() cell_dists = model.get_local_sample_distances( - adata, keep_cell=False, groupby=labels_key + adata, keep_cell=False, groupby=labels_key, batch_size=32 ) make_parents(distance_matrices_out) cell_dists.to_netcdf(distance_matrices_out) @@ -76,6 +76,7 @@ def get_latent_scviv2( normalize_distances=True, keep_cell=False, groupby=labels_key, + batch_size=32, ) cell_normalized_dists.to_netcdf(normalized_distance_matrices_out) del cell_normalized_dists diff --git a/bin/preprocess.py b/bin/preprocess.py index 7e96ac7..0f85008 100644 --- a/bin/preprocess.py +++ b/bin/preprocess.py @@ -402,6 +402,7 @@ def _process_semisynth2( n_genes_for_subclustering = semisynth_config["n_genes_for_subclustering"] if subsample: selected_subsample_cluster = semisynth_config["selected_subsample_cluster"] + selected_oversample_cluster = semisynth_config["selected_oversample_cluster"] subsample_rates = semisynth_config["subsample_rates"] # use SCVI to obtain latent space @@ -517,7 +518,14 @@ def _process_semisynth2( .astype(int) ) - subsampled_adatas = [adata[adata.obs.leiden != str(selected_subsample_cluster)]] + # subsampled_adatas = [adata[adata.obs.leiden != str(selected_subsample_cluster)]] + subsampled_adatas = [ + adata[ + ~adata.obs.leiden.isin( + [str(selected_subsample_cluster), str(selected_oversample_cluster)] + ) + ] + ] subsample_info_df = pd.DataFrame() for rank, subsample_rate in enumerate(subsample_rates, 1): samples_to_subsample = sample_assignment_mapping[ @@ -538,16 +546,31 @@ def _process_semisynth2( (adata.obs.sample_assignment == str(sample)) & (adata.obs["leiden"] == str(selected_subsample_cluster)) ] - subsample_adata = subsample_adata[ - np.random.choice( - subsample_adata.shape[0], - int(subsample_adata.shape[0] * subsample_rate), - replace=False, - ) - ] + n_subsampled = int(subsample_adata.shape[0] * subsample_rate) + n_removed = subsample_adata.shape[0] - n_subsampled + subsampled_idx = np.random.choice( + subsample_adata.shape[0], + n_subsampled, + replace=False, + ) + subsample_adata = subsample_adata[subsampled_idx] subsampled_adatas.append(subsample_adata) + # Add cells to ensure total number of cells is the same + oversample_adata = adata[ + (adata.obs.sample_assignment == str(sample)) + & (adata.obs["leiden"] == str(selected_oversample_cluster)) + ] + oversample_idx = np.random.choice( + oversample_adata.shape[0], + oversample_adata.shape[0] + n_removed, + replace=True, + ) + oversample_adata = oversample_adata[oversample_idx] + subsampled_adatas.append(oversample_adata) + res = sc.concat(subsampled_adatas) + res.obs_names_make_unique() subsample_info_df = subsample_info_df.astype({"sample": str}).set_index( "sample" ) @@ -561,6 +584,14 @@ def _process_semisynth2( res.obs.loc[:, "sample_metadata2"] = ( res.obs[f"subsample_rate_in_leiden{selected_subsample_cluster}"] <= 0.8 ) + + one_hot_groupnames = [] + for unique_group in res.obs["sample_group"].unique(): + new_key = f"group_{unique_group}" + res.obs.loc[:, new_key] = (res.obs["sample_group"] == unique_group).astype( + int + ) + one_hot_groupnames.append(new_key) res = sc.AnnData( X=res.X, obs=res.obs, @@ -568,6 +599,7 @@ def _process_semisynth2( var=adata.var, uns=adata.uns, ) + res.uns["one_hot_groupnames"] = one_hot_groupnames return res return adata diff --git a/bin/produce_figures_haniffa.py b/bin/produce_figures_haniffa.py index 8c7f1f0..321f46a 100644 --- a/bin/produce_figures_haniffa.py +++ b/bin/produce_figures_haniffa.py @@ -1,4 +1,5 @@ # %% +### Imports and basic loading import glob import os @@ -6,35 +7,220 @@ import numpy as np import pandas as pd import plotnine as p9 +import plotly.graph_objects as go +from scipy.spatial.distance import squareform import scanpy as sc import seaborn as sns import xarray as xr from biothings_client import get_client from matplotlib.colors import hex2color, rgb2hex from scib_metrics.benchmark import Benchmarker -from scipy.cluster.hierarchy import fcluster +from scipy.cluster.hierarchy import fcluster, leaves_list from scipy.special import logsumexp from sklearn.cluster import KMeans import scipy.stats as st -from sklearn.decomposition import PCA +from sklearn.decomposition import PCA, KernelPCA from sklearn.manifold import MDS, TSNE from sklearn.metrics import pairwise_distances from tree_utils import hierarchical_clustering from utils import perform_gsea +from matplotlib.patches import Patch +from mrvi import MrVI +import pynndescent +from scipy.cluster.hierarchy import linkage, fcluster, optimal_leaf_ordering +import jax.numpy as jnp +from scvi import REGISTRY_KEYS +from mrvi._constants import MRVI_REGISTRY_KEYS +from scvi.distributions import JaxNegativeBinomialMeanDisp as NegativeBinomial +from tqdm import tqdm + + +def compute_px_from_x( + self, + x, + sample_index, + batch_index, + cf_sample=None, + continuous_covs=None, + label_index=None, + mc_samples=10, +): + """Compute normalized gene expression from observations""" + log_library = 7.0 * jnp.ones_like( + sample_index + ) # placeholder, will be replaced by observed library sizes. + inference_outputs = self.inference( + x, sample_index, mc_samples=mc_samples, cf_sample=cf_sample, use_mean=False + ) + generative_inputs = { + "z": inference_outputs["z"], + "library": log_library, + "batch_index": batch_index, + "continuous_covs": continuous_covs, + "label_index": label_index, + } + generative_outputs = self.generative(**generative_inputs) + return generative_outputs["px"], inference_outputs["u"], log_library + + +def compute_sample_cf_reconstruction_scores( + self, + sample_idx, + adata=None, + indices=None, + batch_size=256, + mc_samples=10, + n_top_neighbors=5, +): + self._check_if_trained(warn=False) + adata = self._validate_anndata(adata) + sample_name = self.sample_order[sample_idx] + sample_adata = adata[adata.obs[self.sample_key] == sample_name] + if sample_adata.shape[0] == 0: + raise ValueError(f"Sample {sample_name} missing from AnnData.") + sample_u = self.get_latent_representation(sample_adata, give_z=False) + sample_index = pynndescent.NNDescent(sample_u) + + scdl = self._make_data_loader( + adata=adata, batch_size=batch_size, indices=indices, iter_ndarray=True + ) + + def _get_all_inputs( + inputs, + ): + x = jnp.array(inputs[REGISTRY_KEYS.X_KEY]) + sample_index = jnp.array(inputs[MRVI_REGISTRY_KEYS.SAMPLE_KEY]) + batch_index = jnp.array(inputs[REGISTRY_KEYS.BATCH_KEY]) + continuous_covs = inputs.get(REGISTRY_KEYS.CONT_COVS_KEY, None) + label_index = inputs.get(REGISTRY_KEYS.LABELS_KEY, None) + if continuous_covs is not None: + continuous_covs = jnp.array(continuous_covs) + return { + "x": x, + "sample_index": sample_index, + "batch_index": batch_index, + "continuous_covs": continuous_covs, + "label_index": label_index, + } + + scores = [] + top_idxs = [] + for array_dict in tqdm(scdl): + vars_in = {"params": self.module.params, **self.module.state} + rngs = self.module.rngs + + inputs = _get_all_inputs(array_dict) + px, u, log_library_placeholder = self.module.apply( + vars_in, + rngs=rngs, + method=compute_px_from_x, + x=inputs["x"], + sample_index=inputs["sample_index"], + batch_index=inputs["batch_index"], + cf_sample=np.ones(inputs["x"].shape[0]) * sample_idx, + continuous_covs=inputs["continuous_covs"], + label_index=inputs["label_index"], + mc_samples=mc_samples, + ) + px_m, px_d = px.mean, px.inverse_dispersion + if px_m.ndim == 2: + px_m, px_d = np.expand_dims(px_m, axis=0), np.expand_dims(px_d, axis=0) + px_m, px_d = np.expand_dims(px_m, axis=2), np.expand_dims(px_d, axis=2) + + mc_log_probs = [] + batch_top_idxs = [] + for mc_sample_i in range(u.shape[0]): + nearest_sample_idxs = sample_index.query(u[mc_sample_i], k=n_top_neighbors)[ + 0 + ] + top_neighbor_counts = ( + sample_adata.X[nearest_sample_idxs.reshape(-1), :] + .toarray() + .reshape( + (nearest_sample_idxs.shape[0], nearest_sample_idxs.shape[1], -1) + ) + ) + new_lib_size = top_neighbor_counts.sum( + axis=-1 + ) # batch_size x n_top_neighbors + corrected_px_m = ( + px_m[mc_sample_i] + / np.exp(log_library_placeholder[:, :, None]) + * new_lib_size[:, :, None] + ) + corrected_px = NegativeBinomial( + mean=corrected_px_m, inverse_dispersion=px_d + ) + log_probs = ( + corrected_px.log_prob(top_neighbor_counts).sum(-1).mean(-1) + ) # 1 x batch_size + mc_log_probs.append(log_probs) + batch_top_idxs.append(nearest_sample_idxs) + full_batch_log_probs = np.concatenate(mc_log_probs, axis=0).mean(0) + top_idxs.append(np.concatenate(batch_top_idxs, axis=1)) + + scores.append(full_batch_log_probs) + + all_scores = np.hstack(scores) + all_top_idxs = np.vstack(top_idxs) + adata_index = adata[indices] if indices is not None else adata + return ( + pd.Series( + all_scores, + index=adata_index.obs_names.to_numpy(), + name=f"{sample_name}_score", + ), + all_top_idxs, + ) + + +def compute_distance_matrices(model, adata=None, dists=None, leiden_resolutions=None): + """ + Computes distance matrices for MrVI and clusters cells based on them. + + Parameters + ---------- + model: + MrVI model. + adata: + AnnData object to compute distance matrices for. By default, uses the model's AnnData object. + dists: + Optional precomputed distance matrices. Useful to avoid recomputing them and considering different leiiden resolutions. + leiden_resolutions: + List of leiden resolutions to use for clustering cells based on distance matrices. + """ + if adata is None: + adata = model.adata + adata.obs.loc[:, "_indices"] = np.arange(adata.shape[0]) + + if leiden_resolutions is None: + # leiden_resolutions = [0.005, 0.01, 0.05, 0.1, 0.5, 1.0] + leiden_resolutions = [0.001, 0.005, 0.01, 0.05] + elif isinstance(leiden_resolutions, float): + leiden_resolutions = [leiden_resolutions] + + if dists is None: + dists = model.get_local_sample_distances(adata, keep_cell=True) + axis = 0 + dmats = dists["cell"].values + dmats = np.array([dmat[np.triu_indices(dmat.shape[0], k=1)] for dmat in dmats]) + dmats = (dmats - dmats.mean(axis=axis, keepdims=True)) / dmats.std( + axis=axis, keepdims=True + ) + adata.obsm["dmat_pca"] = PCA(n_components=50).fit_transform(dmats) + sc.pp.neighbors(adata, use_rep="dmat_pca", n_neighbors=15) + + for leiden_resol in leiden_resolutions: + sc.tl.leiden( + adata, key_added=f"leiden_dmats_{leiden_resol}", resolution=leiden_resol + ) + return adata, dists + INCH_TO_CM = 1 / 2.54 +CT_ANNOTATION_KEY = "initial_clustering" -# %% -gene_sets = [ - "MSigDB_Hallmark_2020", - "WikiPathway_2021_Human", - "Reactome_2022", - "GO_Biological_Process_2023", - "GO_Cellular_Component_2023", - "GO_Molecular_Function_2023", -] -# %% metad = pd.read_excel( "/data1/datasets/41591_2021_1329_MOESM3_ESM.xlsx", sheet_name=1, @@ -60,7 +246,6 @@ axis_title=p9.element_text(family="sans-serif", size=8), ) -# %% sc.set_figure_params(dpi_save=500) plt.rcParams["axes.grid"] = False plt.rcParams["svg.fonttype"] = "none" @@ -78,7 +263,6 @@ ).astype("category") adata_files = glob.glob("../results/aws_pipeline/data/haniffa2.*.final.h5ad") -# %% mg = get_client("gene") gene_conversion = mg.querymany( adata.var_names, @@ -102,25 +286,14 @@ ].values # %% -from scvi_v2 import MrVI - +### Load base model model = MrVI.load( "/data1/scvi-v2-reproducibility/results/aws_pipeline/models/haniffa2.scviv2_attention_mog", adata=adata, ) -# model = MrVI.load( -# "/data1/scvi-v2-reproducibility/results/aws_pipeline/models/haniffa.scviv2_attention_no_prior_mog", adata=adata -# ) - -# %% model.history["elbo_validation"].iloc[50:].plot() - -# %% ax = model.history["reconstruction_loss_validation"].plot() -# modelb.history["reconstruction_loss_validation"].plot(ax=ax) -# modelb.history["reconstruction_loss_validation"].iloc[10:].plot() -# %% donor_info = ( model.adata.obs.drop_duplicates("_scvi_sample") .set_index("_scvi_sample") @@ -131,29 +304,109 @@ ) # %% -# adata_file = '../results/aws_pipeline/data/haniffa.scviv2_attention_noprior.final.h5ad' +### UMAPs +replacer = { + "B_cell": "B cell", + "CD14": "CD14 Monocyte", + "CD16": "CD16 Monocyte", + "CD4": "CD4 T cell", + "CD8": "CD8 T cell", + "DCs": "DC", + "gdT": "gd T cell", + "NK_16hi": "NK", + "NK_56hi": "NK", + "pDC": "pDC", + "Plasmablast": "Other", + "Platelets": "Platelet", + "Treg": "Other", + "HSC": "Other", + "MAIT": "Other", + "Lymph_prolif": "Other", + "RBC": "Other", + "Mono_prolif": "Other", + "Lymph_prolif": "Other", +} +simplified_cts_cmap = {} + adata_file = "../results/aws_pipeline/data/haniffa2.scviv2_attention_mog.final.h5ad" adata_ = sc.read_h5ad(adata_file) -print(adata_.shape) -for obsm_key in adata_.obsm.keys(): - if obsm_key.endswith("mde") & ("scviv2" in obsm_key): - print(obsm_key) - fig = sc.pl.embedding( - adata_, - basis=obsm_key, - color=["initial_clustering", "Status", "Site", "patient_id"], - ncols=1, - show=False, - return_fig=True, - ) - fig.savefig(os.path.join(FIGURE_DIR, f"haniffa.{obsm_key}.svg")) - plt.clf() +adata_.obs.loc[:, "initial_clustering_simplified"] = adata_.obs.initial_clustering.map( + replacer +) +cat_order = [ + "B cell", + "CD14 Monocyte", + "CD16 Monocyte", + "CD4 T cell", + "CD8 T cell", + "gd T cell", + "DC", + "NK", + "Platelet", + "pDC", + # "Other", +] +adata_.obs.loc[:, "initial_clustering_simplified"] = pd.Categorical( + adata_.obs.initial_clustering_simplified, categories=cat_order +) +print("# of CTs in original key:", adata_.obs.initial_clustering.nunique()) +print("# of CTs in simplified key:", adata_.obs.initial_clustering_simplified.nunique()) # %% -donor_info_ = donor_info.set_index("sample_id") +# keys_to_plot = ["initial_clustering", "initial_clustering_simplified", "Status", "Site", "patient_id"] +# for obsm_key in adata_.obsm.keys(): +# if obsm_key.endswith("mde") & ("scviv2" in obsm_key): +# print(obsm_key) +# fig = sc.pl.embedding( +# adata_, +# basis=obsm_key, +# color=keys_to_plot, +# ncols=1, +# show=False, +# return_fig=True, +# ) +# fig.savefig(os.path.join(FIGURE_DIR, f"haniffa.{obsm_key}.svg")) +# plt.show() + +# %% +for obsm_key in ["X_scviv2_attention_mog_u_mde", "X_scviv2_attention_mog_z_mde"]: + fig = sc.pl.embedding( + adata_, + basis=obsm_key, + color="initial_clustering_simplified", + ncols=1, + show=False, + return_fig=True, + ) + fig.savefig( + os.path.join(FIGURE_DIR, f"haniffa.{obsm_key}.simple_celltypes.svg") + ) + + idx_cov = adata_.obs.query("Status=='Covid'").index + idx_heal = adata_.obs.query("Status=='Healthy'").index + n_heal = len(idx_heal) + idx_cov = np.random.choice(idx_cov, size=n_heal, replace=False) + idx_total = np.concatenate([idx_heal, idx_cov]) + idx_total = np.random.permutation(idx_total) + fig2 = sc.pl.embedding( + adata_[idx_total], + basis=obsm_key, + color="Status", + palette=["#9C0101", "#017401"], + ncols=1, + show=False, + return_fig=True, + # alpha=0.5, + ) + fig2.savefig( + os.path.join(FIGURE_DIR, f"haniffa.{obsm_key}.covid_status.svg") + ) +# %% +### Colors +donor_info_ = donor_info.set_index("sample_id") covid_legend = {"Covid": "#9E1800", "Healthy": "#019E5D"} sex_legend = {"Male": "#4791FF", "Female": "#EBA315"} outcome_legend = {"Home": "#466EB8", "Death": "#B80033", "unknown": "#718085"} @@ -204,12 +457,9 @@ dfo_colors = [rgb2hex(rgb) for rgb in dfo_colors] donor_info_["DFO_color"] = dfo_colors dfo_colors = donor_info_["DFO_color"] - colors = pd.concat([color_age, dfo_colors, color_covid, color_worst_status], axis=1) -# %% -from matplotlib.patches import Patch - +## Plot legend for legend_name, my_legend in all_legends.items(): handles = [Patch(facecolor=hex2color(my_legend[name])) for name in my_legend] plt.legend( @@ -231,61 +481,20 @@ # %% -def compute_distance_matrices(model, adata=None, dists=None, leiden_resolutions=None): - """ - Computes distance matrices for MrVI and clusters cells based on them. - - Parameters - ---------- - model: - MrVI model. - adata: - AnnData object to compute distance matrices for. By default, uses the model's AnnData object. - dists: - Optional precomputed distance matrices. Useful to avoid recomputing them and considering different leiiden resolutions. - leiden_resolutions: - List of leiden resolutions to use for clustering cells based on distance matrices. - """ - if adata is None: - adata = model.adata - adata.obs.loc[:, "_indices"] = np.arange(adata.shape[0]) - - if leiden_resolutions is None: - leiden_resolutions = [0.005, 0.01, 0.05, 0.1, 0.5, 1.0] - elif isinstance(leiden_resolutions, float): - leiden_resolutions = [leiden_resolutions] - - if dists is None: - dists = model.get_local_sample_distances(adata, keep_cell=True) - axis = 0 - dmats = dists["cell"].values - dmats = np.array([dmat[np.triu_indices(dmat.shape[0], k=1)] for dmat in dmats]) - dmats = (dmats - dmats.mean(axis=axis, keepdims=True)) / dmats.std( - axis=axis, keepdims=True - ) - adata.obsm["dmat_pca"] = PCA(n_components=50).fit_transform(dmats) - sc.pp.neighbors(adata, use_rep="dmat_pca", n_neighbors=15) - - for leiden_resol in leiden_resolutions: - sc.tl.leiden( - adata, key_added=f"leiden_dmats_{leiden_resol}", resolution=leiden_resol - ) - return adata, dists - - -# %% -CT_ANNOTATION_KEY = "initial_clustering" - -# %% -# Compute cell specific distance matrices, and cluster cells based on them. +### Compute cell specific distance matrices, and cluster cells based on them. adata_mat = model.adata.copy() adata_mat, dmats = compute_distance_matrices(model, adata_mat) - -# %% adata_mat.obsm = adata_embs.obsm +adata_mat.obs.loc[:, "initial_clustering_simplified"] = adata_mat.obs.initial_clustering.map( + replacer +) +adata_mat.obs.loc[:, "initial_clustering_simplified"] = pd.Categorical( + adata_mat.obs.initial_clustering_simplified, categories=cat_order +) # %% -DMAT_CLUSTERING_KEY = "leiden_dmats_0.005" +DMAT_CLUSTERING_KEY = "leiden_dmats_0.001" +# DMAT_CLUSTERING_KEY = "leiden_dmats_0.005" fig = sc.pl.embedding( adata_mat, basis="X_scviv2_attention_mog_u_mde", @@ -294,6 +503,7 @@ def compute_distance_matrices(model, adata=None, dists=None, leiden_resolutions= DMAT_CLUSTERING_KEY, ], return_fig=True, + palette="Set2", ) fig.savefig( os.path.join( @@ -301,8 +511,67 @@ def compute_distance_matrices(model, adata=None, dists=None, leiden_resolutions= "dmat_clusterings.svg", ) ) +# %% +props_per_cluster = ( + adata_mat.obs.groupby(DMAT_CLUSTERING_KEY)["initial_clustering_simplified"] + .value_counts(normalize=True) + .to_frame("prop") + .reset_index() +) +props_per_cluster + + +cmap_clustering = pd.Series( + adata_.uns["initial_clustering_simplified_colors"], + index=adata_.obs["initial_clustering_simplified"].cat.categories.values, +) + +props_per_cluster_ = props_per_cluster.loc[lambda x: x.prop > 0.01] +all_cats = list(props_per_cluster_["initial_clustering_simplified"].astype('category').cat.categories) + list(props_per_cluster_[DMAT_CLUSTERING_KEY].astype('category').cat.categories) +source = pd.Categorical(props_per_cluster_["initial_clustering_simplified"], categories=all_cats).codes +target = pd.Categorical(props_per_cluster_[DMAT_CLUSTERING_KEY], categories=all_cats).codes +values = props_per_cluster_["prop"].values + +colors_for_sankey = cmap_clustering.reindex(all_cats).fillna("#000000") +colors_source = colors_for_sankey.loc[props_per_cluster_["initial_clustering_simplified"].values] +fig = go.Figure( + data=[ + go.Sankey( + node=dict( + pad=15, + thickness=20, + line=dict( + color="black", width=0.5), + label=all_cats, + color=colors_for_sankey.values, + ), + link=dict( + source=source, + target=target, + value=values, + color=colors_source.values, + ) + ) + ] +) +fig.update_layout( + width=400, # Width in pixels + height=500, # Height in pixels + font_size=10 +) +fig.write_image( + os.path.join( + FIGURE_DIR, + "sankey_haniffa.svg", + ), + format='svg' +) +fig.show() # %% +### For each cluster, plot average distance matrix + +## Plot composition of clusters props_per_cluster = ( adata_mat.obs.groupby(DMAT_CLUSTERING_KEY)[CT_ANNOTATION_KEY] .value_counts(normalize=True) @@ -318,11 +587,10 @@ def compute_distance_matrices(model, adata=None, dists=None, leiden_resolutions= + p9.geom_col(position="fill") ) -# %% + VMIN = 0 VMAX = 1 -# cluster_dmats = [] for cluster in adata_mat.obs[DMAT_CLUSTERING_KEY].unique(): print(cluster) cell_indices = adata_mat.obs[adata_mat.obs[DMAT_CLUSTERING_KEY] == cluster].index @@ -344,6 +612,7 @@ def compute_distance_matrices(model, adata=None, dists=None, leiden_resolutions= # red, blue, white ).values donor_cluster_key = f"donor_clusters_{cluster}" + print(donor_cluster_key) adata_mat.obs.loc[:, donor_cluster_key] = adata_mat.obs.patient_id.map( donor_info_.loc[:, "donor_group"] ).values @@ -358,8 +627,9 @@ def compute_distance_matrices(model, adata=None, dists=None, leiden_resolutions= row_colors=colors_, vmin=VMIN, vmax=VMAX, + # cmap="rocket_r", yticklabels=True, - figsize=(20, 20), + # figsize=(20, 20), ) sns_plot.savefig( os.path.join( @@ -367,40 +637,80 @@ def compute_distance_matrices(model, adata=None, dists=None, leiden_resolutions= f"cluster_{cluster}_dmat.svg", ) ) - # cluster_dmats.append(d1.values) # %% -cluster = 1 -CLUSTER_NAME = f"donor_clusters_{cluster}" -de_n_clusters = 500 +donors_to_clusters = adata_mat.obs.loc[:, ["patient_id", "donor_clusters_1"]].drop_duplicates().reset_index(drop=True) +donor_info_.loc[:, "donor_group"] = donors_to_clusters.set_index("patient_id").loc[donor_info_.index, "donor_clusters_1"].values # %% -donor_keys = [ - "Sex", - "Status", - "age_group", -] -adata_mat.obs.loc[:, "is_covid1"] = (adata_mat.obs[CLUSTER_NAME] == "cluster 0").astype( - int +group1 = donor_info_.loc[lambda x: x.donor_group == "cluster 1"] +group2 = donor_info_.loc[lambda x: x.donor_group == "cluster 2"] +print(group1.shape, group2.shape) +print( + group1.Status.value_counts(normalize=True), + group2.Status.value_counts(normalize=True), ) -adata_mat.obs.loc[:, "is_covid2"] = (adata_mat.obs[CLUSTER_NAME] == "cluster 1").astype( - int +plot_df_donor_groups = pd.concat([group1, group2]) + +# %% +from scipy.stats import mannwhitneyu + +mannwhitneyu( + group1.DFO.dropna(), group2.DFO.dropna(), alternative="less" ) -donor_keys_bis = ["is_covid1", "is_covid2"] -obs_df = adata_mat.obs.copy() -obs_df = obs_df.loc[~obs_df._scvi_sample.duplicated("first")] -model.donor_info = obs_df.set_index("_scvi_sample").sort_index() + + # %% -_adata = adata_mat[adata_mat.obs[DMAT_CLUSTERING_KEY] == "1"].copy() +fig = ( + p9.ggplot(plot_df_donor_groups, p9.aes(x="donor_group", y="DFO")) + + p9.geom_boxplot(width=0.2, outlier_alpha=0.0) + + p9.geom_jitter(width=0.2, size=0.5) + + p9.theme_classic() + + p9.theme( + strip_background=p9.element_blank(), + axis_text_x=p9.element_text(rotation=0, hjust=1), + axis_text=p9.element_text(family="sans-serif", size=5), + axis_title=p9.element_text(family="sans-serif", size=6), + figure_size=(2, 2), + ) + + p9.labs( + x="Donor group", + y="Days from onset", + ) +) +fig.save( + os.path.join( + FIGURE_DIR, + "DFO_diffs.svg", + ) +) # %% -ap_res = model.get_outlier_cell_sample_pairs( - adata=_adata, - flavor="ap", - minibatch_size=1000 +### Select one cluster and perform DE/DA + +cluster = 1 +CLUSTER_NAME = f"donor_clusters_{cluster}" +de_n_clusters = 500 + +donor_keys = [ + "Sex", + "Status", + "age_group", +] +adata_mat.obs.loc[:, "is_covid1"] = (adata_mat.obs[CLUSTER_NAME] == "cluster 1").astype( + int +) +adata_mat.obs.loc[:, "is_covid2"] = (adata_mat.obs[CLUSTER_NAME] == "cluster 2").astype( + int ) +donor_keys_bis = ["is_covid1", "is_covid2"] +obs_df = adata_mat.obs.copy() +obs_df = obs_df.loc[~obs_df._scvi_sample.duplicated("first")] +model.donor_info = obs_df.set_index("_scvi_sample").sort_index() +_adata = adata_mat[adata_mat.obs[DMAT_CLUSTERING_KEY] == "1"].copy() # %% +# ## Perform DE analysis multivariate_analysis_kwargs = { "batch_size": 128, @@ -410,29 +720,85 @@ def compute_distance_matrices(model, adata=None, dists=None, leiden_resolutions= "eps_lfc": 1e-4, } -res = model.perform_multivariate_analysis( - donor_keys=donor_keys_bis, - adata=_adata, - **multivariate_analysis_kwargs, -) +# res = model.perform_multivariate_analysis( +# donor_keys=donor_keys_bis, +# adata=_adata, +# **multivariate_analysis_kwargs, +# ) +# # %% +# betas_ = res.lfc.transpose("cell_name", "covariate", "gene") +# betas_ = ( +# betas_.loc[{"covariate": "is_covid2"}].values +# - betas_.loc[{"covariate": "is_covid1"}].values +# ) +# plt.hist(betas_.mean(0), bins=100) +# plt.xlabel("LFC") +# plt.show() + +# # %% +# betas_ = res.lfc.transpose("cell_name", "covariate", "gene") +# betas_ = ( +# betas_.loc[{"covariate": "is_covid2"}].values +# - betas_.loc[{"covariate": "is_covid1"}].values +# ) +# plt.hist(betas_.mean(0), bins=100) +# plt.xlabel("LFC") +# plt.show() +## Perform DE analysis + +donor_subset = obs_df.loc[ + lambda x: (x.is_covid1 == 1) | (x.is_covid2 == 1) +].patient_id.values +donor_keys_ = ["is_covid1"] +multivariate_analysis_kwargs = { + "batch_size": 128, + "normalize_design_matrix": True, + "offset_design_matrix": False, + "store_lfc": True, + "eps_lfc": 1e-4, +} + +# res = model.perform_multivariate_analysis( +# donor_keys=donor_keys_, +# adata=_adata, +# donor_subset=donor_subset, +# **multivariate_analysis_kwargs, +# ) +__adata = _adata.copy() +# sc.pp.subsample(__adata, n_obs=2000) # %% -betas_ = res.lfc.transpose("cell_name", "covariate", "gene") -betas_ = ( - betas_.loc[{"covariate": "is_covid2"}].values - - betas_.loc[{"covariate": "is_covid1"}].values +# res = model.perform_multivariate_analysis( +# donor_keys=donor_keys_, +# adata=__adata, +# donor_subset=donor_subset, +# **multivariate_analysis_kwargs, +# ) +# res.to_netcdf( +# os.path.join( +# FIGURE_DIR, +# f"multivariate_analysis.nc", +# ) +# ) +res = xr.open_dataset( + os.path.join( + FIGURE_DIR, + f"multivariate_analysis.nc", + ) ) +__adata = adata_mat[res.cell_name.values] +# %% +betas_ = res.lfc.values.squeeze(0) plt.hist(betas_.mean(0), bins=100) plt.xlabel("LFC") plt.show() - # %% lfc_df = pd.DataFrame( { "LFC": betas_.mean(0), - "LFC_q0_05": np.quantile(betas_, 0.05, axis=0), - "LFC_q0_95": np.quantile(betas_, 0.95, axis=0), + "LFC_q0_05": np.quantile(betas_, 0.20, axis=0), + "LFC_q0_95": np.quantile(betas_, 0.80, axis=0), "LFC_std": betas_.std(0), "gene": model.adata.var_names, "gene_index": np.arange(model.adata.shape[1]), @@ -442,38 +808,73 @@ def compute_distance_matrices(model, adata=None, dists=None, leiden_resolutions= gene_score=lambda x: np.maximum(x.LFC_q0_95, -x.LFC_q0_05), ) # %% -bins = np.linspace(0, 1, 100) -lfc_df.gene_score.plot.hist(bins=bins) -v500 = lfc_df.gene_score.sort_values().iloc[-500] -plt.vlines(v500, 0, 1000) +adata_t_all = sc.AnnData( + X=betas_.T, + obs=lfc_df, +) + + +lfc_pca = KernelPCA(n_components=5, kernel="cosine") +lfc_pcs = lfc_pca.fit_transform(adata_t_all.X) +adata_t_all.obsm["lfc_mds"] = TSNE( + n_components=2, metric="precomputed", init="random" +).fit_transform(pairwise_distances(lfc_pcs)) + +# %% +dists = pairwise_distances(lfc_pcs) +dists = squareform(dists) +# Z = hierarchical_clustering(dists, method="ward", return_ete=False) +Z = linkage(dists, method="ward") + +# %% +order = optimal_leaf_ordering(Z, dists) # %% -# Cluster and visualize DE genes -cond = lfc_df.sort_values("gene_score", ascending=False).iloc[:500].gene_index.values -betas_de = betas_[:, cond] -obs_de = lfc_df.loc[cond, :].reset_index(drop=True) -obs_de.plot.scatter("LFC", "LFC_std") +# vmax = np.quantile(lfc_df.absLFC.values, 0.95) +# sc.pl.embedding( +# adata_t_all, +# basis="lfc_mds", +# color=["LFC"], +# vmin=-vmax, +# vmax=vmax, +# cmap="coolwarm", +# ) +# # %% +# bins = np.linspace(0, 1, 100) +# lfc_df.gene_score.plot.hist(bins=bins) +# v500 = lfc_df.gene_score.sort_values().iloc[-1000] +# plt.vlines(v500, 0, 1000) + +# # Cluster and visualize DE genes +# cond = lfc_df.sort_values("gene_score", ascending=False).iloc[:500].gene_index.values +# betas_de = betas_[:, cond] +# obs_de = lfc_df.loc[cond, :].reset_index(drop=True) +# obs_de.index = obs_de.gene +# obs_de.plot.scatter("LFC", "LFC_std") # %% +## Cluster and visualize LFCs adata_t = sc.AnnData( X=betas_de.T, obs=obs_de, ) -lfc_pca = PCA(n_components=10) +# lfc_pca = PCA(n_components=4) +lfc_pca = KernelPCA(n_components=5, kernel="cosine") lfc_pcs = lfc_pca.fit_transform(adata_t.X) adata_t.obsm["lfc_pca"] = lfc_pcs adata_t.obsm["lfc_mds"] = TSNE( n_components=2, metric="precomputed", init="random" ).fit_transform(pairwise_distances(lfc_pcs)) +sc.pp.neighbors(adata_t, use_rep="lfc_pca", n_neighbors=10, random_state=0) # %% -sc.pp.neighbors(adata_t, use_rep="lfc_pca", n_neighbors=10) -sc.tl.leiden(adata_t, key_added="lfc_leiden", resolution=0.25) - -# %% -adata_t.obs["lfc_clusters"] = KMeans(n_clusters=5).fit_predict(lfc_pcs) -adata_t.obs["lfc_clusters"] = adata_t.obs["lfc_clusters"].astype(str) +# RESOLUTION = 0.25 +RESOLUTION = 0.15 +sc.tl.leiden(adata_t, key_added="lfc_leiden", resolution=RESOLUTION, random_state=0) +adata_t.obs["lfc_clusters"] = KMeans(n_clusters=4, n_init=1000).fit_predict(lfc_pcs) +adata_t.obs["lfc_clusters"] = adata_t.obs["lfc_leiden"].astype(str) +print("N clusters:", adata_t.obs["lfc_clusters"].nunique()) vmax = np.quantile(obs_de.absLFC.values, 0.95) sc.pl.embedding( @@ -485,6 +886,7 @@ def compute_distance_matrices(model, adata=None, dists=None, leiden_resolutions= cmap="coolwarm", ) plt.tight_layout() +plt.show() sc.pl.embedding( adata_t, @@ -492,25 +894,26 @@ def compute_distance_matrices(model, adata=None, dists=None, leiden_resolutions= color=["lfc_clusters", "gene_score", "LFC_std"], ) plt.tight_layout() - +plt.show() # %% -gene_info_ = adata_t.obs +## Enrichment analysis per cluster -# %% gene_sets = [ "MSigDB_Hallmark_2020", "WikiPathway_2021_Human", "KEGG_2021_Human", - "Reactome_2022", + # "Reactome_2022", "GO_Biological_Process_2023", "GO_Cellular_Component_2023", "GO_Molecular_Function_2023", + # 'COVID-19_Related_Gene_Sets_2021', ] -# %% LFC_CLUSTERING_KEY = "lfc_leiden" +# LFC_CLUSTERING_KEY = "lfc_clusters" +gene_info_ = adata_t.obs beta_module_keys = [] all_enrichr_results = [] gene_info_modules = [] @@ -527,19 +930,18 @@ def compute_distance_matrices(model, adata=None, dists=None, leiden_resolutions= .str.upper() .tolist() ) - gene_indices = gene_info_module.loc[:, "gene_index"].tolist() + # gene_indices = gene_info_module.loc[:, "gene_index"].tolist() + # beta_module = np.mean(betas_[:, gene_indices], 1) gene_info_modules.append(gene_info_module) - beta_module = np.mean(betas_[:, gene_indices], 1) - _adata.obs.loc[:, beta_module_name] = beta_module + beta_module = adata_t[gene_info_module.loc[:, "gene"].values].X.toarray().mean(0) + __adata.obs.loc[:, beta_module_name] = beta_module beta_module_keys.append(beta_module_name) - - enr = perform_gsea(genes, gene_sets=gene_sets).assign(cluster=cluster) - all_enrichr_results.append(enr) -all_enrichr_results = pd.concat(all_enrichr_results).astype({"Gene_set": "category"}) + + # enr = perform_gsea(genes, gene_sets=gene_sets, n_trials_max=3, use_server=False).assign(cluster=cluster) + # all_enrichr_results.append(enr) +# all_enrichr_results = pd.concat(all_enrichr_results).astype({"Gene_set": "category"}) gene_info_modules = pd.concat(gene_info_modules).astype({"gene": "category"}) - -# %% gene_info_modules.to_csv( os.path.join( FIGURE_DIR, @@ -547,26 +949,91 @@ def compute_distance_matrices(model, adata=None, dists=None, leiden_resolutions= ) ) + # %% -fig = sc.pl.embedding( - _adata, - basis="X_scviv2_attention_mog_u_mde", - color=["initial_clustering"], - return_fig=True, -) -# plt.tight_layout() -fig.savefig( - os.path.join( - FIGURE_DIR, - f"initial_clustering_{cluster}.svg", - ) -) +genes_of_interest = [ + "IL1B", + # "LGALS1", + "LGALS2", + "CSF3R", + "HLA-DRA", + "HLA-DRB1", + "IFITM3", + "IFI6", + "IFI27", + "GBP1", + "IRF7", + # "ССL13", + "NFKBIZ", + "SLC25A5", + "TNF", + "RHOB", + "TKT", +] +for gene in genes_of_interest: + if gene in gene_info_modules.index: + print( + f"{gene} found, belongs, to cluster {gene_info_modules.loc[gene, LFC_CLUSTERING_KEY]}", + gene_info_modules.loc[gene, "LFC"] + ) + else: + print(f"{gene} not found") + + +# %% +adata_t2 = sc.AnnData( + X=betas_de, + obs=__adata.obs, + var=obs_de, +) +adata_t2 = adata_t2[adata_t2.obs["initial_clustering"].isin(["CD14", "CD16", "DCs"])].copy() +vmax = 1.0 +sc.pl.heatmap( + adata_t2, + genes_of_interest, + # gene_info_modules.index, + groupby='initial_clustering', + cmap='coolwarm', + vmin=-vmax, + vmax=vmax, + swap_axes=True, + save=f"haniffa_DEGs_heatmap.svg", + # return_fig=True, +) +# fig.savefig( +# os.path.join( +# FIGURE_DIR, +# f"haniffa_DEGs_heatmap.svg", +# ) +# ) +# %% +# _adata2 = __adata.copy() +# sc.pp.neighbors(_adata2, use_rep="X_scviv2_attention_mog_u_mde", n_neighbors=10) +# sc.tl.umap(_adata2, min_dist=0.3, maxiter=1000, init_pos="X_scviv2_attention_mog_u_mde") # %% +## Plot beta modules activity scores & GSEA for beta_module_key in beta_module_keys: cluster = int(beta_module_key.split("_")[-1]) - vmin, vmax = np.quantile(_adata.obs[beta_module_key], [0.05, 0.95]) - if _adata.obs[beta_module_key].mean() > 0: + gene_info_module = ( + gene_info_.loc[gene_info_[LFC_CLUSTERING_KEY] == str(cluster)] + .sort_values("absLFC", ascending=False) + .loc[:, ["gene", "LFC", "absLFC"]] + ) + gene_info_module.to_csv( + os.path.join( + FIGURE_DIR, + f"gene_info_module_{cluster}.csv", + ), + sep="\t", + ) + + vmin, vmax = np.quantile(__adata.obs[beta_module_key], [0.10, 0.90]) + # Option when sorted by abs + # cmap = "viridis" + + # Option when sorted by sign + if __adata.obs[beta_module_key].mean() > 0: cmap = "Reds" vmin = 0 else: @@ -574,7 +1041,7 @@ def compute_distance_matrices(model, adata=None, dists=None, leiden_resolutions= vmax = 0 fig = sc.pl.embedding( - _adata, + __adata, basis="X_scviv2_attention_mog_u_mde", color=beta_module_key, vmin=vmin, @@ -582,997 +1049,425 @@ def compute_distance_matrices(model, adata=None, dists=None, leiden_resolutions= cmap=cmap, return_fig=True, ) + # fig = sc.pl.umap( + # _adata, + # color=beta_module_key, + # vmin=vmin, + # vmax=vmax, + # cmap=cmap, + # return_fig=True, + # ) fig.savefig( os.path.join( FIGURE_DIR, f"{beta_module_key}_{cluster}.svg", ) ) - plt.tight_layout() - - plot_df = ( - all_enrichr_results.loc[lambda x: x.cluster == cluster, :] - .loc[lambda x: x["Adjusted P-value"] < 0.1, :] - .sort_values("Adjusted P-value") - .head(5) - .sort_values("Gene_set") - .assign( - Term=lambda x: x.Term.str.split(r" \(GO", expand=True).loc[:, 0], - ) + # fig = sc.pl.umap( + # _adata2, + # # basis="X_scviv2_attention_mog_u_mde", + # color=beta_module_key, + # vmin=vmin, + # vmax=vmax, + # cmap=cmap, + # return_fig=True, + # ) + # fig.savefig( + # os.path.join( + # FIGURE_DIR, + # f"{beta_module_key}_{cluster}2.svg", + # ) + # ) + # plt.tight_layout() + + # plot_df = ( + # all_enrichr_results.loc[lambda x: x.cluster == cluster, :] + # .loc[lambda x: ~x.Term.isin(["Inflammatory Response"])] + # .loc[lambda x: x["Adjusted P-value"] < 0.1, :] + # .sort_values("Significance score", ascending=False) + # .head(5) + # # .sort_values("Gene_set") + # .assign( + # Term=lambda x: x.Term.str.split(r" \(GO", expand=True).loc[:, 0], + # ) + # ) + # scaler = len(plot_df) + # fig = ( + # p9.ggplot( + # plot_df, + # p9.aes( + # x="Term", + # y="Significance score", + # # fill='Gene_set', + # ), + # ) + # + p9.geom_col(color="grey") + # # + p9.geom_col() + # + p9.labs( + # x="", + # ) + # + p9.theme_classic() + # + p9.scale_y_continuous(expand=(0, 0)) + # + p9.scale_x_discrete(limits=plot_df.Term.tolist()) + # + p9.theme( + # strip_background=p9.element_blank(), + # axis_text_x=p9.element_text(rotation=0, hjust=1), + # axis_text=p9.element_text(family="sans-serif", size=5), + # axis_title=p9.element_text(family="sans-serif", size=6), + # ) + # + p9.coord_flip() + # # + p9.scale_x_discrete(limits=plot_df.Term.tolist()) + # ) + # fig.save( + # os.path.join( + # FIGURE_DIR, + # f"haniffa.{cluster}.beta_modules_cts.{beta_module_key}.gsea.svg", + # ) + # ) + # plt.tight_layout() + # fig.draw(show=True) + # print(beta_module_key) + # for term in plot_df.Term.tolist(): + # print(term) + # print() + +# %% +plot_df_de = ( + __adata.obs.loc[lambda x: x["initial_clustering"].isin(["CD14", "CD16", "DCs"])] + .assign( + initial_clustering=lambda x: x["initial_clustering"].astype(str) ) - scaler = len(plot_df) - fig = ( - p9.ggplot(plot_df, p9.aes(x="Term", y="Significance score")) - + p9.geom_col(color="grey") - + p9.scale_x_discrete(limits=plot_df.Term.tolist()) - + p9.labs( - x="", - ) - + p9.theme_classic() - + p9.scale_y_continuous(expand=(0, 0)) - + p9.theme( - strip_background=p9.element_blank(), - axis_text_x=p9.element_text(rotation=45, hjust=1), - axis_text=p9.element_text(family="sans-serif", size=5), - axis_title=p9.element_text(family="sans-serif", size=6), - # figure_size=(4 * INCH_TO_CM, 4 * INCH_TO_CM), - ) +) +plot_df_de +# %% +sns.violinplot( + plot_df_de, x="initial_clustering", y="beta_module_0", +) +plt.xlabel("") +plt.ylabel("Log-LR") +plt.savefig( + os.path.join( + FIGURE_DIR, + f"DE_violinplot_beta_module_0.svg", ) - # if idx != 0: - # fig = fig + p9.theme(legend_position="none") - fig.save( - os.path.join( - FIGURE_DIR, - f"haniffa.{cluster}.beta_modules_cts.{beta_module_key}.gsea.svg", - ) +) +plt.show() + + +sns.violinplot( + plot_df_de, x="initial_clustering", y="beta_module_1", +) +plt.xlabel("") +plt.ylabel("Log-LR") +plt.savefig( + os.path.join( + FIGURE_DIR, + f"DE_violinplot_beta_module_1.svg", ) - plt.tight_layout() - fig.draw(show=True) +) +plt.show() -# %% -all_enrichr_results +sns.violinplot( + plot_df_de, x="initial_clustering", y="beta_module_2", +) +plt.xlabel("") +plt.ylabel("Log-LR") +plt.savefig( + os.path.join( + FIGURE_DIR, + f"DE_violinplot_beta_module_2.svg", + ) +) +plt.show() -# %% -zs = model.get_local_sample_representation() -u = model.get_latent_representation(give_z=False, use_mean=True) +sns.violinplot( + plot_df_de, x="initial_clustering", y="beta_module_3", +) +plt.xlabel("") +plt.ylabel("Log-LR") +plt.savefig( + os.path.join( + FIGURE_DIR, + f"DE_violinplot_beta_module_3.svg", + ) +) +plt.show() # %% -# eps_ = zs.values - u[:, None] -eps_ = zs.values -eps_ = (eps_ - eps_.mean(axis=1, keepdims=True)) / eps_.std(axis=1, keepdims=True) -eps_ = eps_.reshape(eps_.shape[0], -1) -pca_eps = PCA(n_components=1) -eps_pca = pca_eps.fit_transform(eps_) -adata_embs.obs["eps_pca"] = eps_pca -adata.obs.loc[:, "eps_pca"] = eps_pca +da_res = model.get_outlier_cell_sample_pairs( + flavor="ap", minibatch_size=1000, adata=_adata +) +gp1 = model.donor_info.query('Status == "Covid"').patient_id.values +gp2 = model.donor_info.query('Status == "Healthy"').patient_id.values +log_p1 = da_res.log_probs.loc[{"sample": gp1}] +log_p1 = logsumexp(log_p1, axis=1) - np.log(log_p1.shape[1]) +log_p2 = da_res.log_probs.loc[{"sample": gp2}] +log_p2 = logsumexp(log_p2, axis=1) - np.log(log_p2.shape[1]) +log_ratios_casecontrol = log_p1 - log_p2 + + +gp1 = model.donor_info.query("is_covid2 == 1").patient_id.values +gp2 = model.donor_info.query("is_covid1 == 1").patient_id.values +# Late vs early +log_p1 = da_res.log_probs.loc[{"sample": gp1}] +log_p1 = logsumexp(log_p1, axis=1) - np.log(log_p1.shape[1]) +log_p2 = da_res.log_probs.loc[{"sample": gp2}] +log_p2 = logsumexp(log_p2, axis=1) - np.log(log_p2.shape[1]) +log_ratios_earlylate = log_p1 - log_p2 # %% -print(adata_embs.shape) -for obsm_key in adata_embs.obsm.keys(): - if obsm_key.endswith("mde") & ("scviv2" in obsm_key): - print(obsm_key) - # rdm_perm = np.random.permutation(adata.shape[0]) - fig = sc.pl.embedding( - # adata_embs[rdm_perm], - adata_embs, - basis=obsm_key, - color=["initial_clustering", "Status", "eps_pca", "patient_id"], - ncols=1, - show=True, - return_fig=True, - ) - fig.savefig( - os.path.join(FIGURE_DIR, f"embedding_{obsm_key}.svg"), bbox_inches="tight" - ) # %% -pca_eps = PCA(n_components=50) -eps_pca = pca_eps.fit_transform(eps_) -adata_embs.obsm["eps_PCs"] = eps_pca +_adata.obs.loc[:, "log_ratios_casecontrol"] = log_ratios_casecontrol +_adata.obs.loc[:, "log_ratios_earlylate"] = log_ratios_earlylate -import pymde -mde_kwargs = dict( - embedding_dim=2, - constraint=pymde.Standardized(), - repulsive_fraction=0.7, - device="cuda", - n_neighbors=15, -) -latent_mde = pymde.preserve_neighbors(eps_, **mde_kwargs).embed().cpu().numpy() +# fig = sc.pl.embedding( +# _adata, +# basis="X_scviv2_attention_mog_u_mde", +# color=["initial_clustering", "log_ratios_casecontrol", "log_ratios_earlylate"], +# return_fig=True, +# # vmin="p5", +# # vmax="p95", +# vmin=-0.6, +# vmax=0.6, +# cmap="coolwarm", +# ) +# fig.savefig( +# os.path.join( +# FIGURE_DIR, +# f"initial_clustering_{cluster}_all.svg", +# ) +# ) +# sc.pl.embedding( +# _adata, +# basis="X_scviv2_attention_mog_u_mde", +# color="initial_clustering", +# legend_loc="right", +# ) -adata_embs.obsm["eps_mde"] = latent_mde -sc.pl.embedding( - adata_embs, - basis="eps_mde", - color=["initial_clustering", "Status", "eps_pca", "patient_id"], +plot_df = ( + _adata.obs + .assign( + mde1=_adata.obsm["X_scviv2_attention_mog_u_mde"][:, 0], + mde2=_adata.obsm["X_scviv2_attention_mog_u_mde"][:, 1] + ) + .loc[_adata.obs.initial_clustering.isin(['CD14', 'CD16', 'DCs'])] + .assign( + initial_clustering=lambda x: x.initial_clustering.astype(str) + ) ) -# %% -# my_adata = model.adata.copy() -# sc.pp.subsample(my_adata, n_obs=50000) -adata.obs.loc[:, "_indices"] = np.arange(adata.shape[0]) -dists = model.get_local_sample_distances( - adata, - keep_cell=True, +( + p9.ggplot(plot_df, p9.aes(x="mde1", y="mde2", color="initial_clustering")) + + p9.geom_point() ) # %% -axis = 0 -dmats = dists["cell"].values -dmats = np.array([dmat[np.triu_indices(dmat.shape[0], k=1)] for dmat in dmats]) -dmats = (dmats - dmats.mean(axis=axis, keepdims=True)) / dmats.std( - axis=axis, keepdims=True -) -# dmats = np.argsort(dmats, axis=1) -dmats_ = PCA(n_components=50).fit_transform(dmats) -latent_ = pymde.preserve_neighbors(dmats_, **mde_kwargs).embed().cpu().numpy() -adata.obsm["dmat_mde"] = latent_ -sc.pp.neighbors(adata, use_rep="dmat_mde", n_neighbors=15) -# %% -sc.tl.leiden(adata, key_added="leiden_dmats", resolution=0.005) -# adata.obs.loc[:, "leiden_dmats"] = KMeans(n_clusters=6).fit_predict(adata.obsm["dmat_mde"]) -adata.obs.loc[:, "leiden_dmats"] = adata.obs.loc[:, "leiden_dmats"].astype(str) - -sc.pl.embedding(adata, basis="dmat_mde", color=["initial_clustering", "leiden_dmats"]) +sc.pp.neighbors(_adata, use_rep="X_scviv2_attention_mog_u_mde", n_neighbors=30) +# fibroblast.obsp['normalized_connectivities'] = fibroblast.obsp['connectivities']/fibroblast.obsp['connectivities'].sum(1) # %% -props_per_cluster = ( - adata.obs.groupby("leiden_dmats") - .initial_clustering.value_counts(normalize=True) - .to_frame("prop") - .reset_index() +# fig = plt.figure(figsize=(2.5, 2.5)) +sns.violinplot( + plot_df, x="initial_clustering", y="log_ratios_casecontrol", ) -props_per_cluster - -# %% -( - p9.ggplot( - props_per_cluster, p9.aes(x="leiden_dmats", y="prop", fill="initial_clustering") +plt.xlabel("") +plt.ylabel("Log-LR") +plt.axhline(0, color='red', linestyle='--', linewidth=2) +plt.ylim(-5.0, 5.0) +plt.savefig( + os.path.join( + FIGURE_DIR, + f"DA_violinplot_{cluster}_all_casecontrol.svg", ) - + p9.geom_col(position="fill") ) # %% -( - p9.ggplot( - props_per_cluster, p9.aes(x="initial_clustering", y="prop", fill="leiden_dmats") +sns.violinplot( + plot_df, x="initial_clustering", y="log_ratios_earlylate", +) +plt.axhline(0, color='red', linestyle='--', linewidth=2) +plt.xlabel("") +plt.ylabel("Log-LR") +plt.ylim(-2.0, 2.0) +plt.savefig( + os.path.join( + FIGURE_DIR, + f"DA_violinplot_{cluster}_all_earlylate.svg", ) - + p9.geom_col(position="dodge") - + p9.coord_flip() ) - # %% -mapper = { - "0": "CD14", - "1": "NK", - "2": "CD4", - "3": "CD8", - "4": "CD4/CD8", - "5": "B cell", - "6": "CD16", - # "7": "Platelets", - "7": "Plasmablasts", -} -adata.obs.loc[:, "leiden_names"] = adata.obs.leiden_dmats.map(mapper) +def get_smooth_logratios(adata, log_ratio_key): + adata.obsp["normalized_connectivities"] = adata.obsp["connectivities"] / adata.obsp[ + "connectivities" + ].sum(1) + adata.obs[log_ratio_key + "_smoothed"] = np.asarray( + adata.obsp["normalized_connectivities"] + * np.expand_dims(adata.obs[log_ratio_key], 1) + ).squeeze() -# %% -dmat_files = glob.glob("../results/aws_pipeline/distance_matrices/haniffa2.*.nc") -dmat_files -# %% -# dmat_file = "../results/aws_pipeline/distance_matrices/haniffa2.scviv2_attention.distance_matrices.nc" -# dmat_file = "../results/aws_pipeline/distance_matrices/haniffa2.scviv2_attention_no_prior_mog_large.distance_matrices.nc" -dmat_file = "../results/aws_pipeline/distance_matrices/haniffa2.scviv2_attention_mog.distance_matrices.nc" -d = xr.open_dataset(dmat_file) +for key in ["log_ratios_casecontrol", "log_ratios_earlylate"]: + get_smooth_logratios(_adata, key) + + fig = sc.pl.embedding( + _adata, + basis="X_scviv2_attention_mog_u_mde", + color=[key, key+"_smoothed"], + # vmin=-0.6, + # vmax=0.6, + vmin=-2.5, + vmax=2.5, + cmap="coolwarm", + return_fig=True, + ) + fig.savefig( + os.path.join( + FIGURE_DIR, + f"DA_embs_{key}.svg", + ) + ) # %% -plt.rcParams["axes.grid"] = False +from scipy.stats import ttest_ind +from statsmodels.stats.weightstats import ttest_ind as sm_ttest -VMIN = 0.0 -VMAX = 1.0 -selected_cts = [ - "CD14", - # "B_cell", - # "CD4", -] -n_clusters = [ - 3, - # 3, - # 2, -] -for idx, (selected_ct, n_cluster) in enumerate(zip(selected_cts, n_clusters)): - mask_samples = donor_info_.index - d1 = d.loc[dict(initial_clustering_name=selected_ct)]["initial_clustering"] - d1 = d1.loc[dict(sample_x=mask_samples)].loc[dict(sample_y=mask_samples)] - Z = hierarchical_clustering(d1.values, method="ward", return_ete=False) +for ct in ["CD14", "CD16", "DCs"]: + pop1 = _adata.obs.loc[lambda x: x.initial_clustering == ct] + pop2 = _adata.obs.loc[lambda x: x.initial_clustering != ct] - colors_ = colors.loc[d1.sample_x.values] - donor_info_ = donor_info_.loc[d1.sample_x.values] + test1 = sm_ttest( + pop1.log_ratios_casecontrol, + pop2.log_ratios_casecontrol, + alternative="smaller", + value=-0.1 + ) + test2 = sm_ttest( + pop1.log_ratios_casecontrol, + pop2.log_ratios_casecontrol, + alternative="larger", + value=0.1 + ) + pval = np.minimum(test1[1], test2[1]) + print(f"Cluster {ct}: {pval}") - # Get clusters - clusters = fcluster(Z, t=n_cluster, criterion="maxclust") - donor_info_.loc[:, "donor_group"] = clusters - colors_.loc[:, "cluster"] = clusters - colors_.loc[:, "cluster"] = colors_.cluster.map( - {1: "#eb4034", 2: "#3452eb", 3: "#f7fcf5", 4: "#FF8000"} - # red, blue, white - ).values +# %% +from scipy.stats import ttest_ind +from statsmodels.stats.weightstats import ttest_ind as sm_ttest - donor_cluster_key = f"donor_clusters_{selected_ct}" - adata.obs.loc[:, donor_cluster_key] = adata.obs.patient_id.map( - donor_info_.loc[:, "donor_group"] - ).values - adata.obs.loc[:, donor_cluster_key] = "cluster " + adata.obs.loc[ - :, donor_cluster_key - ].astype(str) - sns.clustermap( - d1.to_pandas(), - row_linkage=Z, - col_linkage=Z, - row_colors=colors_, - vmin=VMIN, - vmax=VMAX, - # cmap="YlGnBu", - yticklabels=True, - figsize=(20, 20), - ) - plt.savefig(os.path.join(FIGURE_DIR, f"clustermap_{selected_ct}.svg")) - - adata_log = adata.copy() - sc.pp.normalize_total(adata_log) - sc.pp.log1p(adata_log) - pop = adata_log[(adata_log.obs.initial_clustering == selected_ct)].copy() - - sc.tl.rank_genes_groups( - pop, - donor_cluster_key, - method="t-test", - n_genes=1000, +for ct in ["CD14", "CD16", "DCs"]: + pop1 = _adata.obs.loc[lambda x: x.initial_clustering == ct] + pop2 = _adata.obs.loc[lambda x: x.initial_clustering != ct] + + test1 = sm_ttest( + pop1.log_ratios_earlylate, + pop2.log_ratios_earlylate, + alternative="smaller", + value=-0.1 ) - fig = sc.pl.rank_genes_groups_dotplot( - pop, - n_genes=5, - min_logfoldchange=0.5, - swap_axes=True, - return_fig=True, + test2 = sm_ttest( + pop1.log_ratios_earlylate, + pop2.log_ratios_earlylate, + alternative="larger", + value=0.1 ) - fig.savefig(os.path.join(FIGURE_DIR, f"DOThaniffa.{selected_ct}.clustered.svg")) + pval = np.minimum(test1[1], test2[1]) + print(f"Cluster {ct}: {pval}") + +# %% +# plot_df = _adata.obs.loc[ +# lambda x: x.initial_clustering.isin(["CD14", "CD16", "DCs"]) +# ].melt( +# id_vars=["initial_clustering"], +# value_vars=["log_ratios_casecontrol", "log_ratios_earlylate"], +# ) + +# ( +# p9.ggplot(plot_df, p9.aes(x="initial_clustering", y="value", fill="variable")) +# + p9.geom_boxplot() +# + p9.ylim(-0.6, 0.6) +# ) + # %% +### OOD experiment ood_res = model.get_outlier_cell_sample_pairs( subsample_size=5000, minibatch_size=256, quantile_threshold=0.05 ) - +sample_name = "newcastle74" +sample_idx = model.sample_order.tolist().index(sample_name) +np.random.seed(42) +random_indices = np.random.choice(adata.shape[0], size=10000, replace=False) +sample_scores, top_idxs = compute_sample_cf_reconstruction_scores( + model, sample_idx, indices=random_indices +) # %% -n_admissible = ( - ood_res.to_dataframe()["is_admissible"].unstack().sum(1).to_frame("n_admissible") +adata_subset = adata[sample_scores.index] +sample_ball_res = ood_res.sel(cell_name=adata_subset.obs_names).sel( + sample=model.sample_order[sample_idx] ) -obs_ = adata.obs.join(n_admissible) - -atleast = 25 -n_donors_with_atleast = ( - obs_.groupby(["initial_clustering"]) - .apply(lambda x: x.loc[x.n_admissible >= atleast].patient_id.nunique()) - .to_frame("n_donors_with_atleast") +sample_adm_log_probs = sample_ball_res.log_probs.to_series() +sample_adm_bool = sample_ball_res.is_admissible.to_series() +is_sample = pd.Series( + adata_subset.obs["sample_id"] == model.sample_order[sample_idx], + name="is_sample", + dtype=bool, +) +sample_log_lib_size = pd.Series( + np.log(adata_subset.X.toarray().sum(axis=1)), + index=adata_subset.obs_names, + name="log_lib_size", +) +cell_category = pd.Series( + ["Not Admissible"] * adata_subset.shape[0], + dtype=str, + name="cell_category", + index=adata_subset.obs_names, ) -n_donors_with_atleast +cell_category[sample_adm_bool.to_numpy()] = "Admissible" +cell_category[is_sample.to_numpy()] = "In Sample" +cell_category = cell_category.astype("category") -n_pred_donors = ( - obs_.groupby(["initial_clustering"]).n_admissible.mean().to_frame("n_pred_donors") +rec_score_plot_df = pd.concat( + ( + sample_adm_log_probs, + sample_adm_bool, + is_sample, + cell_category, + sample_scores, + sample_log_lib_size, + ), + axis=1, +).sample(frac=1, replace=False) + +ax = sns.scatterplot( + rec_score_plot_df, x="log_probs", y=f"{sample_name}_score", hue="cell_category", s=5 ) - -joined = n_donors_with_atleast.join(n_pred_donors) - -# %% -( - p9.ggplot( - joined.reset_index(), p9.aes(x="n_donors_with_atleast", y="n_pred_donors") +plt.xlabel("Admissibility Score") +plt.ylabel("Reconstruction Log Prob of In-Sample NN") +handles, labels = plt.gca().get_legend_handles_labels() +order = [1, 2, 0] +plt.legend([handles[idx] for idx in order], [labels[idx] for idx in order]) +plt.xlim(-100, 30) +plt.savefig( + os.path.join( + FIGURE_DIR, + f"haniffa_{sample_name}_admissibility_vs_reconstruction_w_category.svg", ) - + p9.geom_point() - + p9.geom_abline(intercept=0, slope=1) ) +plt.clf() # %% -# Admissibility vs Counterfactual Reconstruction -import pynndescent -import jax.numpy as jnp -from scvi import REGISTRY_KEYS -from scvi_v2._constants import MRVI_REGISTRY_KEYS -from scvi.distributions import JaxNegativeBinomialMeanDisp as NegativeBinomial -from tqdm import tqdm - -# module level function -def compute_px_from_x(self, x, sample_index, batch_index, cf_sample=None, continuous_covs=None, label_index=None, mc_samples=10): - """Compute normalized gene expression from observations""" - log_library = 7.0 * jnp.ones_like(sample_index) # placeholder, will be replaced by observed library sizes. - inference_outputs = self.inference(x, sample_index, mc_samples=mc_samples, cf_sample=cf_sample, use_mean=False) - generative_inputs = { - "z": inference_outputs["z"], - "library": log_library, - "batch_index": batch_index, - "continuous_covs": continuous_covs, - "label_index": label_index, - } - generative_outputs = self.generative(**generative_inputs) - return generative_outputs["px"], inference_outputs["u"], log_library - - -def compute_sample_cf_reconstruction_scores(self, sample_idx, adata=None, indices=None, batch_size=256, inner_batch_size=8, mc_samples=10, n_top_neighbors=5): - self._check_if_trained(warn=False) - adata = self._validate_anndata(adata) - sample_name = self.sample_order[sample_idx] - sample_adata = adata[adata.obs[self.sample_key] == sample_name] - if sample_adata.shape[0] == 0: - raise ValueError(f"Sample {sample_name} missing from AnnData.") - sample_u = self.get_latent_representation(sample_adata, give_z=False) - sample_index = pynndescent.NNDescent(sample_u) - - scdl = self._make_data_loader(adata=adata, batch_size=batch_size, indices=indices, iter_ndarray=True) - - def _get_all_inputs( - inputs, - ): - x = jnp.array(inputs[REGISTRY_KEYS.X_KEY]) - sample_index = jnp.array(inputs[MRVI_REGISTRY_KEYS.SAMPLE_KEY]) - batch_index = jnp.array(inputs[REGISTRY_KEYS.BATCH_KEY]) - continuous_covs = inputs.get(REGISTRY_KEYS.CONT_COVS_KEY, None) - label_index = inputs.get(REGISTRY_KEYS.LABELS_KEY, None) - if continuous_covs is not None: - continuous_covs = jnp.array(continuous_covs) - return { - "x": x, - "sample_index": sample_index, - "batch_index": batch_index, - "continuous_covs": continuous_covs, - "label_index": label_index, - } - - scores = [] - top_idxs = [] - for array_dict in tqdm(scdl): - vars_in = {"params": self.module.params, **self.module.state} - rngs = self.module.rngs - - inputs = _get_all_inputs(array_dict) - px, u, log_library_placeholder = self.module.apply( - vars_in, - rngs=rngs, - method=compute_px_from_x, - x=inputs["x"], - sample_index=inputs["sample_index"], - batch_index=inputs["batch_index"], - cf_sample=np.ones(inputs["x"].shape[0]) * sample_idx, - continuous_covs=inputs["continuous_covs"], - label_index=inputs["label_index"], - mc_samples=mc_samples, - ) - px_m, px_d = px.mean, px.inverse_dispersion - if px_m.ndim == 2: - px_m, px_d = np.expand_dims(px_m, axis=0), np.expand_dims(px_d, axis=0) - px_m, px_d = np.expand_dims(px_m, axis=2), np.expand_dims(px_d, axis=2) # for inner_batch_size dim - - mc_log_probs = [] - batch_top_idxs = [] - for mc_sample_i in range(u.shape[0]): - nearest_sample_idxs = sample_index.query(u[mc_sample_i], k=n_top_neighbors)[0] - top_neighbor_counts = sample_adata.X[nearest_sample_idxs.reshape(-1), :].toarray().reshape((nearest_sample_idxs.shape[0],nearest_sample_idxs.shape[1], -1)) - new_lib_size = top_neighbor_counts.sum(axis=-1) # batch_size x n_top_neighbors - corrected_px_m = px_m[mc_sample_i] / np.exp(log_library_placeholder[:, :, None]) * new_lib_size[:, :, None] - corrected_px = NegativeBinomial(mean=corrected_px_m, inverse_dispersion=px_d) # mc_samples x batch_size x inner_batch_size x genes - log_probs = corrected_px.log_prob(top_neighbor_counts).sum(-1).mean(-1) # 1 x batch_size - mc_log_probs.append(log_probs) - batch_top_idxs.append(nearest_sample_idxs) - full_batch_log_probs = np.concatenate(mc_log_probs, axis=0).mean(0) - top_idxs.append(np.concatenate(batch_top_idxs, axis=1)) - - scores.append(full_batch_log_probs) - - all_scores = np.hstack(scores) - all_top_idxs = np.vstack(top_idxs) - adata_index = adata[indices] if indices is not None else adata - return pd.Series(all_scores, index=adata_index.obs_names.to_numpy(), name=f"{sample_name}_score"), all_top_idxs - -# %% -sample_name = "newcastle74" -sample_idx = model.sample_order.tolist().index(sample_name) -np.random.seed(42) -random_indices = np.random.choice(adata.shape[0], size=10000, replace=False) -sample_scores, top_idxs = compute_sample_cf_reconstruction_scores(model, sample_idx, indices=random_indices) - -# %% -adata_subset = adata[sample_scores.index] -sample_ball_res = ood_res.sel(cell_name=adata_subset.obs_names).sel(sample=model.sample_order[sample_idx]) -sample_adm_log_probs = sample_ball_res.log_probs.to_series() -sample_adm_bool = sample_ball_res.is_admissible.to_series() -is_sample = pd.Series(adata_subset.obs["sample_id"] == model.sample_order[sample_idx], name="is_sample", dtype=bool) -sample_log_lib_size = pd.Series(np.log(adata_subset.X.toarray().sum(axis=1)), index=adata_subset.obs_names, name="log_lib_size") -cell_category = pd.Series(["Not Admissible"] * adata_subset.shape[0], dtype=str, name="cell_category", index=adata_subset.obs_names) -cell_category[sample_adm_bool.to_numpy()] = "Admissible" -cell_category[is_sample.to_numpy()] = "In Sample" -cell_category = cell_category.astype("category") - -rec_score_plot_df = ( - pd.concat((sample_adm_log_probs, sample_adm_bool, is_sample, cell_category, sample_scores, sample_log_lib_size), axis=1) - .sample(frac=1, replace=False) -) -# %% -sns.scatterplot(rec_score_plot_df, x="log_probs", y=f"{sample_name}_score", hue="cell_category", s=5) -plt.xlabel("Admissibility Score") -plt.ylabel("Reconstruction Log Prob of In-Sample NN") -handles, labels = plt.gca().get_legend_handles_labels() -order = [1, 2, 0] -plt.legend([handles[idx] for idx in order],[labels[idx] for idx in order]) -plt.xlim(-100, 30) -# fig.save(os.path.join(FIGURE_DIR, f"haniffa_{sample_name}_admissibility_vs_reconstruction_w_category.svg")) - -# %% -# adata_embs.obs.loc[:, "n_valid_donors"] = res["is_admissible"].values.sum(axis=1) -# for obsm_key in adata_embs.obsm.keys(): -# if obsm_key.endswith("mde") & ("scviv2" in obsm_key): -# print(obsm_key) -# sc.pl.embedding( -# # adata_embs[rdm_perm], -# adata_embs, -# basis=obsm_key, -# color=["initial_clustering","n_valid_donors"], -# save=f"haniffa.{obsm_key}.svg", -# ncols=1, -# ) - -# %% -donor_keys = [ - "Sex", - "Status", - "age_group", -] -# %% -res = model.perform_multivariate_analysis( - donor_keys=donor_keys, - adata=None, - batch_size=256, - normalize_design_matrix=True, - offset_design_matrix=False, - filter_donors=True, - subsample_size=500, - quantile_threshold=0.05, -) - - -# %% -es_keys = [f"es_{cov}" for cov in res.covariate.values] -is_sig_keys_ = [f"is_sig_{cov}_" for cov in res.covariate.values] -is_sig_keys = [f"is_sig_{cov}" for cov in res.covariate.values] - -adata.obs.loc[:, es_keys] = res["effect_size"].values -adata.obs.loc[:, is_sig_keys_] = res["padj"].values <= 0.1 -adata.obs.loc[:, is_sig_keys] = adata.obs.loc[:, is_sig_keys_].astype(str).values - -adata_embs.obs.loc[:, es_keys] = res["effect_size"].values -adata_embs.obs.loc[:, is_sig_keys_] = res["padj"].values <= 0.1 -adata_embs.obs.loc[:, is_sig_keys] = ( - adata_embs.obs.loc[:, is_sig_keys_].astype(str).values -) - -# %% -for obsm_key in adata_embs.obsm.keys(): - if obsm_key.endswith("mde") & ("scviv2" in obsm_key): - print(obsm_key) - sc.pl.embedding( - # adata_embs[rdm_perm], - adata_embs, - basis=obsm_key, - color=["initial_clustering"] + es_keys, - save=f"haniffa.{obsm_key}.svg", - ncols=1, - ) - -# %% -plot_df = adata.obs.reset_index().melt( - id_vars=["index", "initial_clustering"], - value_vars=es_keys, -) - -n_points = adata.obs.initial_clustering.value_counts().to_frame("n_points") -plot_df = plot_df.merge( - n_points, left_on="initial_clustering", right_index=True, how="left" -).assign( - variable_name=lambda x: x.variable.map( - { - "es_SexMale": "Sex", - "es_StatusHealthy": "Status", - "es_age_group>=60": "Age", - } - ) -) - - -# %% -INCH_TO_CM = 1 / 2.54 - -plt.rcParams["svg.fonttype"] = "none" - -fig = ( - p9.ggplot( - # plot_df.loc[lambda x: x.n_points > 7000, :], - plot_df, - p9.aes(x="initial_clustering", y="value"), - ) - + p9.geom_boxplot(outlier_shape="", fill="#3492eb") - # + p9.facet_wrap("~variable", scales="free_x") - + p9.facet_wrap("~variable_name") - + p9.coord_flip() - + p9.labs(y="Effect size", x="") - + p9.theme( - figure_size=(10 * INCH_TO_CM, 5 * INCH_TO_CM), - axis_text=p9.element_text(size=7), - ) -) -fig.save(os.path.join(FIGURE_DIR, "haniffa_multivariate.svg")) -fig - - -donor_keys = [ - "Sex", - "Status", - "age_group", -] - -# %% -de_res = model.perform_multivariate_analysis( - donor_keys=donor_keys, - adata=None, - batch_size=256, - normalize_design_matrix=True, - offset_design_matrix=False, - filter_donors=True, - subsample_size=500, - quantile_threshold=0.05, -) -da_res = model.get_outlier_cell_sample_pairs(flavor="ap", minibatch_size=1000) - -# %% -gp1 = model.donor_info.query('Status == "Covid"').patient_id.values -gp2 = model.donor_info.query('Status == "Healthy"').patient_id.values -log_p1 = da_res.log_probs.loc[{"sample": gp1}] -log_p1 = logsumexp(log_p1, axis=1) - np.log(log_p1.shape[1]) -log_p2 = da_res.log_probs.loc[{"sample": gp2}] -log_p2 = logsumexp(log_p2, axis=1) - np.log(log_p2.shape[1]) - -log_ratios = log_p1 - log_p2 - -# %% -log_p_general = logsumexp(da_res.log_probs, axis=1) - np.log(da_res.log_probs.shape[1]) -adata.obs.loc[:, "log_p"] = log_p_general -admissibility_threshold = 0.05 -adata.obs.loc[:, "is_admissible"] = log_p_general > np.quantile( - log_p_general, admissibility_threshold -) -adata.obs.loc[:, "is_admissible_"] = adata.obs.loc[:, "is_admissible"].astype(str) - -de_es = de_res["effect_size"].loc[{"covariate": "StatusHealthy"}].values -adata.obs.loc[:, "da_es"] = np.clip( - log_ratios, - a_min=np.quantile(log_ratios, 0.01), - a_max=np.quantile(log_ratios, 0.99), -) -adata.obs.loc[~adata.obs.is_admissible, "da_es"] = 0.0 -adata.obs.loc[:, "de_es"] = de_es - -# %% -sc.pl.embedding( - adata, - basis="X_scviv2_attention_mog_u_mde", - color=["initial_clustering", "da_es", "log_p"], - # vmin=-2, - # vmax=2, - # cmap="coolwarm", -) - -# %% -( - p9.ggplot( - adata.obs, - p9.aes("log_p", "da_es"), - ) - + p9.geom_point() - + p9.xlim(-15, 10) - # + p9.ylim(-2, 2) -) -# %% -( - p9.ggplot( - adata.obs.query("is_admissible"), - p9.aes("da_es", "de_es"), - ) - + p9.geom_point() - # + p9.xlim(-15, 10) - # + p9.ylim(-2, 2) -) - -# %% -my_adata = adata[adata.obs.initial_clustering == "Plasmablast"].copy() -my_adata.obs.da_es.hist(bins=100) -sc.pl.embedding( - my_adata, - basis="X_scviv2_attention_mog_u_mde", - color=["initial_clustering", "da_es", "log_p", "is_admissible_"], - ncols=1, -) +### SCIB metrics -# %% -( - (adata.obs.Status == "Healthy") & (adata.obs.initial_clustering == "Plasmablast") -).mean() - -# %% -((adata.obs.Status == "Healthy") & (adata.obs.initial_clustering == "RBC")).mean() - -# %% -fig = sc.pl.embedding( - adata, - basis="X_scviv2_attention_mog_u_mde", - color=["initial_clustering", "da_es", "is_admissible_"], - vmin=-2, - vmax=2, - cmap="coolwarm", - return_fig=True, - ncols=1, -) -fig.savefig(os.path.join(FIGURE_DIR, f"haniffa.DA_mde.svg")) - -# %% -fig = sc.pl.embedding( - adata, - basis="X_scviv2_attention_mog_u_mde", - color="de_es", - return_fig=True, -) -fig.savefig(os.path.join(FIGURE_DIR, f"haniffa.DE_mde.svg")) - -# %% - -# %% -cross_df = adata.obs.assign( - da_es=log_ratios, - de_es=de_es, -) - -# cross_df_key = "initial_clustering" -# cross_df_key = "leiden_dmats" -cross_df_key = "leiden_names" -cross_df_avg = ( - cross_df.groupby(cross_df_key)[["da_es", "de_es"]] - .median() - .reset_index() - .merge( - cross_df.groupby(cross_df_key).size().to_frame("n_points").reset_index(), - on=cross_df_key, - ) -) - -# %% -fig = ( - p9.ggplot(cross_df_avg, p9.aes(x="da_es", y="de_es")) - + p9.geom_point(size=0.5) - + p9.geom_text(p9.aes(label=cross_df_key), nudge_y=0.1, size=5) - # + p9.xlim(-1, 2) - # + p9.xlim(-1, 8) - + p9.labs( - x="DA score", - y="DE score", - ) - + p9.theme_classic() - + p9.theme( - axis_text=p9.element_text(family="sans-serif", size=5), - axis_title=p9.element_text(family="sans-serif", size=6), - figure_size=(4 * INCH_TO_CM, 4 * INCH_TO_CM), - ) -) -# fig.save(os.path.join(FIGURE_DIR, f"haniffa.DE_DA_cross.svg")) -fig - - -# %% -donor_keys = [ - "Sex", - "Status", - "age_group", -] -adata.obs.loc[:, "is_covid1"] = ( - adata.obs["donor_clusters_CD14"] == "cluster 2" -).astype(int) -adata.obs.loc[:, "is_covid2"] = ( - adata.obs["donor_clusters_CD14"] == "cluster 3" -).astype(int) -donor_keys_bis = ["is_covid1", "is_covid2"] - -obs_df = adata.obs.copy() -obs_df = obs_df.loc[~obs_df._scvi_sample.duplicated("first")] -model.donor_info = obs_df.set_index("_scvi_sample").sort_index() - - -# %% -selected_cluster = "Monocytes" -# adata.obsm = adata_embs.obsm -# adata_ = adata[adata.obs.initial_clustering.isin(["CD14", "CD16"])].copy() - -adata.obsm = adata_embs.obsm -adata_ = adata[adata.obs.leiden_names.isin(["CD14", "CD16"])].copy() - -sc.pp.subsample(adata_, n_obs=50000, random_state=0) -adata_.obs.loc[:, "_indices"] = np.arange(adata_.shape[0]) -adata_log_ = adata_.copy() -sc.pp.log1p(adata_log_) - -res = model.perform_multivariate_analysis( - donor_keys=donor_keys_bis, - adata=adata_, - batch_size=128, - normalize_design_matrix=True, - offset_design_matrix=False, - store_lfc=True, - eps_lfc=1e-4, -) -gene_properties = (adata_.X != 0).mean(axis=0).A1 -gene_properties = pd.DataFrame( - gene_properties, index=adata_.var_names, columns=["sparsity"] -) - - -# %% -betas_ = res.lfc.transpose("cell_name", "covariate", "gene") -betas_ = ( - betas_.loc[{"covariate": "is_covid2"}].values - - betas_.loc[{"covariate": "is_covid1"}].values -) -plt.hist(betas_.mean(0), bins=100) -plt.xlabel("LFC") -plt.show() - -lfc_df = pd.DataFrame( - { - "LFC": betas_.mean(0), - "LFC_std": betas_.std(0), - "gene": adata_.var_names, - "gene_index": np.arange(adata_.shape[1]), - "ensembl_gene": adata_.var["ensembl_gene"], - } -).assign(absLFC=lambda x: np.abs(x.LFC)) - -thresh = np.quantile(lfc_df.absLFC, 0.95) -lfc_df.absLFC.hist(bins=100) -plt.axvline(thresh, color="red") -plt.xlabel("AbsLFC") -plt.show() -print((lfc_df.absLFC > thresh).sum()) - -# %% - -# %% -VMAX = 1.0 -cond = lfc_df.absLFC > thresh -betas_de = betas_[:, cond] -obs_de = lfc_df.loc[cond, :].reset_index(drop=True) -obs_de.LFC.hist(bins=100) - -# %% - -adata_t = sc.AnnData( - X=betas_de.T, - obs=obs_de, -) -adata_t.X = (adata_t.X - adata_t.X.mean(0)) / adata_t.X.std(0) -# adata_t.X = (adata_t.X - adata_t.X.mean(1, keepdims=True)) / adata_t.X.std(1, keepdims=True) -sc.pp.neighbors(adata_t, n_neighbors=50, metric="cosine", use_rep="X") -sc.tl.umap(adata_t, min_dist=0.5) -sc.tl.leiden(adata_t, resolution=0.5) -fig = sc.pl.umap( - adata_t, - color=["leiden", "LFC"], - vmin=-VMAX, - vmax=VMAX, - cmap="coolwarm", - return_fig=True, -) -plt.tight_layout() -fig.savefig( - os.path.join(FIGURE_DIR, f"haniffa.{selected_cluster}.gene_umap.svg"), -) - -fig = sc.pl.umap( - adata_t, - color="LFC_std", - return_fig=True, -) -fig.savefig( - os.path.join(FIGURE_DIR, f"haniffa.{selected_cluster}.gene_umap_std.svg"), -) - -# %% -cov_mat = np.corrcoef(adata_t.X) - -X_pca = PCA(n_components=50).fit_transform(adata_t.X) -dissimilarity = pairwise_distances(X_pca) -clusters = KMeans(n_clusters=10).fit_predict(X_pca) -# dissimilarity = 1 - cov_mat -# dissimilarity = pairwise_distances(adata_t.X, metric="cosine") - -gene_reps = MDS(n_components=2, dissimilarity="precomputed").fit_transform( - dissimilarity -) -# gene_reps = TSNE(n_components=2, metric="precomputed", init="random", perplexity=50).fit_transform(dissimilarity) -adata_t.obsm["gene_reps"] = gene_reps -adata_t.obs.loc[:, "KMeans_clusters"] = clusters -adata_t.obs.loc[:, "KMeans_clusters"] = adata_t.obs.loc[:, "KMeans_clusters"].astype( - str -) -sc.pl.embedding( - adata_t, - basis="gene_reps", - color=["KMeans_clusters", "LFC"], - vmin=-VMAX, - vmax=VMAX, - cmap="coolwarm", -) -fig = sc.pl.embedding( - adata_t, - basis="gene_reps", - color="LFC_std", -) - -# %% -clustering_key = "KMeans_clusters" -# %% -gene_info_ = adata_t.obs - -beta_module_keys = [] -all_enrichr_results = [] -for cluster in np.arange(gene_info_[clustering_key].nunique()): - beta_module_name = f"beta_module_{cluster}" - gene_info_module = gene_info_.loc[ - gene_info_[clustering_key] == str(cluster) - ].sort_values("absLFC", ascending=False) - genes = ( - # gene_info_module.loc[lambda x: ~x["ensembl_gene"].isna(), "gene"] - gene_info_module.loc[:, "gene"] - .str.strip() - .str.split(".", expand=True) - .loc[:, 0] - .str.upper() - .tolist() - ) - gene_indices = gene_info_module.loc[:, "gene_index"].tolist() - - # beta_module = betas_[:, gene_indices].mean(1) - beta_module = np.median(betas_[:, gene_indices], 1) - adata_.obs.loc[:, beta_module_name] = beta_module - beta_module_keys.append(beta_module_name) - - enr = perform_gsea(genes).assign(cluster=cluster) - all_enrichr_results.append(enr) -all_enrichr_results = pd.concat(all_enrichr_results).astype({"Gene_set": "category"}) - - -# %% -fig = sc.pl.embedding( - adata_, - basis="X_scviv2_attention_mog_u_mde", - color=["initial_clustering"], - vmax="p95", - cmap="coolwarm", - return_fig=True, -) -fig.savefig( - os.path.join(FIGURE_DIR, f"haniffa.{selected_cluster}.beta_modules_cts.svg") -) - -for beta_module_key in beta_module_keys: - cluster = int(beta_module_key.split("_")[-1]) - vmin, vmax = np.quantile(adata_.obs[beta_module_key], [0.05, 0.95]) - if adata_.obs[beta_module_key].mean() > 0: - cmap = "Reds" - vmin = 0 - else: - cmap = "Blues_r" - vmax = 0 - - fig = sc.pl.embedding( - adata_, - basis="X_scviv2_attention_mog_u_mde", - color=beta_module_key, - vmin=vmin, - vmax=vmax, - cmap=cmap, - return_fig=True, - ) - plt.tight_layout() - fig.savefig( - os.path.join( - FIGURE_DIR, - f"haniffa.{selected_cluster}.beta_modules_cts.{beta_module_key}.svg", - ) - ) - - genes = gene_info_.query(f"{clustering_key} == '{cluster}'").gene.tolist() - cond = adata_log_.obs.is_covid1.astype(bool) | adata_log_.obs.is_covid2.astype(bool) - adata_log_1 = adata_log_[cond, :] - adata_log_1 = adata_log_1[:, genes].copy() - adata_log_1.X = adata_log_1.X.toarray() - adata_log_1.X = (adata_log_1.X - adata_log_1.X.mean(0)) / ( - 1e-6 + adata_log_1.X.std(0) - ) - adata_log_1.X = adata_log_1.X.clip(-5, 5) - adata_log_1.obs["is_covid2"] = adata_log_1.obs["is_covid2"].astype(str) - fig = sc.pl.heatmap( - adata_log_1, - genes, - groupby="is_covid2", - show=False, - vmin=-2, - vmax=2, - cmap="coolwarm", - show_gene_labels=True, - ) - plt.tight_layout() - ax = fig["heatmap_ax"] - ax.figure.savefig( - os.path.join( - FIGURE_DIR, - f"haniffa.{selected_cluster}.beta_modules_cts.{beta_module_key}.heatmap.svg", - ) - ) - plt.show() - - plot_df = ( - all_enrichr_results.loc[lambda x: x.cluster == cluster, :] - .loc[lambda x: x["Adjusted P-value"] < 0.1, :] - .sort_values("Adjusted P-value") - .head(5) - .sort_values("Gene_set") - .assign( - Term=lambda x: x.Term.str.split(r" \(GO", expand=True).loc[:, 0], - ) - ) - scaler = len(plot_df) - fig = ( - p9.ggplot(plot_df, p9.aes(x="Term", y="Significance score")) - + p9.geom_col(color="grey") - + p9.scale_x_discrete(limits=plot_df.Term.tolist()) - + p9.labs( - x="", - ) - + p9.theme_classic() - + p9.scale_y_continuous(expand=(0, 0)) - + p9.theme( - strip_background=p9.element_blank(), - axis_text_x=p9.element_text(rotation=45, hjust=1), - axis_text=p9.element_text(family="sans-serif", size=5), - axis_title=p9.element_text(family="sans-serif", size=6), - # figure_size=(4 * INCH_TO_CM, 4 * INCH_TO_CM), - ) - ) - if idx != 0: - fig = fig + p9.theme(legend_position="none") - fig.save( - os.path.join( - FIGURE_DIR, - f"haniffa.{selected_cluster}.beta_modules_cts.{beta_module_key}.gsea.svg", - ) - ) - plt.tight_layout() - fig.draw(show=True) - - -# %% -# %% keys_of_interest = { "X_SCVI_clusterkey_subleiden1": "SCVI", "X_PCA_clusterkey_subleiden1": "PCA", @@ -1596,7 +1491,7 @@ def _get_all_inputs( adata_sub = adata.copy() sc.pp.subsample(adata_sub, n_obs=25000) -# %% + bm = Benchmarker( adata_sub, batch_key="patient_id", @@ -1613,3 +1508,5 @@ def _get_all_inputs( min_max_scale=False, save_dir=FIGURE_DIR, ) + +# %% diff --git a/bin/produce_figures_pbmc68k_for_subsample.py b/bin/produce_figures_pbmc68k_for_subsample.py index 246670f..5713e9a 100644 --- a/bin/produce_figures_pbmc68k_for_subsample.py +++ b/bin/produce_figures_pbmc68k_for_subsample.py @@ -8,6 +8,12 @@ from scipy.special import logsumexp import matplotlib.pyplot as plt import seaborn as sns +import scipy.sparse as sp +from scipy.stats import pearsonr +from scipy.stats import ttest_ind +from statsmodels.stats.weightstats import ttest_ind as sm_ttest_ind +from pydeseq2.dds import DeseqDataSet +from tqdm import tqdm from ete3 import Tree from tree_utils import hierarchical_clustering @@ -50,6 +56,13 @@ def compute_ratio(dist, sample_to_mask): adata = sc.read_h5ad("../results/milo/data/pbmcs68k_for_subsample.preprocessed.h5ad") + +# %% +( + adata.obs.loc[lambda x: x.leiden == "0"].groupby(["sample_group", "sample_assignment"]).size() + .loc[lambda x: x > 1] +) + # %% import scipy.cluster.hierarchy as sch @@ -114,10 +127,11 @@ def compute_ratio(dist, sample_to_mask): print(obsm_key) rdm_perm = np.random.permutation(adata.shape[0]) adata_.obs["cell_type"] = adata.obs.leiden.copy().astype(str) - adata_.obs.loc[~adata_.obs.leiden.isin(("0", "1")), "cell_type"] = "other" + adata_.obs.loc[~adata_.obs.leiden.isin(("0", "1", "3")), "cell_type"] = "other" adata_.obs["cell_type"] = adata_.obs["cell_type"].astype("category") adata_.uns["cell_type_colors"] = { - "0": "#7AB5FF", + "0": "#7AA98F", + "3": "#7AB5FF", "1": "#FF7A7A", "other": "lightgray", } @@ -151,6 +165,7 @@ def compute_ratio(dist, sample_to_mask): return_fig=True, show=False, ) + print(os.path.join(FIGURE_DIR, f"{obsm_key}_subcluster.svg")) plt.savefig( os.path.join(FIGURE_DIR, f"{obsm_key}_subcluster.svg"), bbox_inches="tight", @@ -212,6 +227,22 @@ def compute_ratio(dist, sample_to_mask): # %% sample_order = adata.obs["sample_assignment"].cat.categories +dmat_files = [ + '../results/milo/distance_matrices/pbmcs68k_for_subsample.distance_matrices_gt.nc', +# '../results/milo/distance_matrices/pbmcs68k_for_subsample.scviv2_attention_no_prior_mog_large.normalized_distance_matrices.nc', +# '../results/milo/distance_matrices/pbmcs68k_for_subsample.scviv2_attention_mog.distance_matrices.nc', +# '../results/milo/distance_matrices/pbmcs68k_for_subsample.scviv2_attention_mog.normalized_distance_matrices.nc', +# '../results/milo/distance_matrices/pbmcs68k_for_subsample.scviv2_attention_noprior.normalized_distance_matrices.nc', + '../results/milo/distance_matrices/pbmcs68k_for_subsample.scviv2_attention_no_prior_mog.distance_matrices.nc', + '../results/milo/distance_matrices/pbmcs68k_for_subsample.composition_SCVI_clusterkey_subleiden1.distance_matrices.nc', + '../results/milo/distance_matrices/pbmcs68k_for_subsample.composition_PCA_clusterkey_subleiden1.distance_matrices.nc', +# '../results/milo/distance_matrices/pbmcs68k_for_subsample.scviv2_attention_no_prior_mog_large.distance_matrices.nc', +# '../results/milo/distance_matrices/pbmcs68k_for_subsample.composition_SCVI_leiden1_subleiden1.distance_matrices.nc', +# '../results/milo/distance_matrices/pbmcs68k_for_subsample.scviv2_attention_noprior.distance_matrices.nc', +# '../results/milo/distance_matrices/pbmcs68k_for_subsample.scviv2_attention_no_prior_mog.distance_matrices.nc', +# '../results/milo/distance_matrices/pbmcs68k_for_subsample.composition_PCA_leiden1_subleiden1.distance_matrices.nc' +] + all_res = [] for dmat_file in dmat_files: print(dmat_file) @@ -311,7 +342,7 @@ def compute_ratio(dist, sample_to_mask): + SHARED_THEME + p9.labs(x="", y="Intra-cluster distance ratio") ) -fig.save(os.path.join(FIGURE_DIR, "intra_distance_ratios.svg")) +# fig.save(os.path.join(FIGURE_DIR, "intra_distance_ratios.svg")) fig # %% @@ -347,10 +378,23 @@ def compute_ratio(dist, sample_to_mask): + p9.ylim(0, 1.2) + p9.coord_flip() + p9.labs(y="Inter cluster distance ratio", x="") + + p9.geom_point() ) fig.save(os.path.join(FIGURE_DIR, "inter_distance_ratios.svg")) fig +# %% +from scipy.stats import mannwhitneyu + +for model in relative_d.Model.unique(): + if model.startswith("MrVI"): + continue + print("MrVI", model, mannwhitneyu( + relative_d.query("Model == 'MrVI (MoG)'").relative_d, + relative_d.query(f"Model == '{model}'").relative_d, + )) + + # %% # Plot variance of dist to sample 8 per rank sample_to_group_and_rank = pd.DataFrame(sample_to_group).reset_index() @@ -495,7 +539,8 @@ def compute_ratio(dist, sample_to_mask): # %% # DEG Analysis -import scvi_v2 +import mrvi +# import scvi_v2 modelname = "scviv2_attention_mog" @@ -504,8 +549,8 @@ def compute_ratio(dist, sample_to_mask): ) model_path = os.path.join(f"../results/milo/models/pbmcs68k_for_subsample.{modelname}") adata = sc.read(adata_path) -model = scvi_v2.MrVI.load(model_path, adata=adata) -model +# model = mrvi.MrVI.load(model_path, adata=adata) +# model # %% model_out_adata_path = os.path.join( @@ -531,13 +576,207 @@ def compute_ratio(dist, sample_to_mask): adata.obs # %% +WORKDIR = "/data1/scvi-v2-reproducibility/results/milo/models" +filedir = os.path.join(WORKDIR, "pbmcs68k_for_subsample.MILODE.de_analysis.tsv") +mtx_file = os.path.join(WORKDIR, "pbmcs68k_for_subsample.MILODE.assignments.mtx") + +milode_preds = pd.read_csv( + filedir, + sep="\t", +) +milode_neighborhoods = pd.read_csv( + mtx_file, + sep="\s", + skiprows=2, + header=None, +) +cell_idx = milode_neighborhoods.iloc[:, 0].to_numpy() - 1 +neigh_idx = milode_neighborhoods.iloc[:, 1].to_numpy() - 1 +data_idx = np.ones_like(cell_idx) +mtx = sp.csr_matrix((data_idx, (cell_idx, neigh_idx))) + +# %% +milode_nhoods_lfcs = ( + milode_preds + .pivot_table(columns="gene", index="Nhood", values="logFC") + .fillna(0.0) + .sort_index() +) +milode_nhoods_lfcs.columns.name = None +milode_nhoods_pvals = ( + milode_preds + .pivot_table(columns="gene", index="Nhood", values="pval") + .fillna(1.0) + .sort_index() +) +milode_nhoods_pvals.columns.name = None + +milode_cell_pvals_max = [] +milode_cell_pvals_mean = [] +milode_cell_lfcs = [] +for row in tqdm(mtx): + row_ = row.toarray().flatten() + row_ = row_ > 0 + res_lfc = milode_nhoods_lfcs.iloc[row_] + res_pval = milode_nhoods_pvals.iloc[row_] + + milode_cell_lfcs.append(res_lfc.mean(0)) + milode_cell_pvals_mean.append(res_pval.mean(0)) + milode_cell_pvals_max.append(res_pval.max(0)) +milode_cell_lfcs = pd.concat(milode_cell_lfcs, axis=1).T +milode_cell_pvals_max = pd.concat(milode_cell_pvals_max, axis=1).T +milode_cell_pvals_mean = pd.concat(milode_cell_pvals_mean, axis=1).T + + +# %% + +# %% +multivariate_analysis_kwargs = { + "batch_size": 128, + "normalize_design_matrix": True, + "offset_design_matrix": False, + "store_lfc": True, + "eps_lfc": 1e-4, +} + +d_keys = ["group_1"] + mv_deg_res = model.perform_multivariate_analysis( - adata, - donor_keys=[f"group_{group}" for group in sample_to_group.unique()], - store_lfc=True, + adata, donor_keys=d_keys, **multivariate_analysis_kwargs, ) mv_deg_res +# %% +mrvi_lfcs = mv_deg_res.lfc.squeeze("covariate").to_pandas() +milode_cell_lfcs.index = adata.obs_names + + +# adata_pos = adata[adata.obs.leiden == "0"].copy() +# sc.pp.normalize_total(adata_pos, target_sum=1e4) +# sc.pp.log1p(adata_pos) + +# adata_pos_1 = adata_pos[adata_pos.obs.sample_group == 1].copy() +# x_1 = adata_pos_1.X.toarray() +# adata_pos_2 = adata_pos[adata_pos.obs.sample_group != 1].copy() +# x_2 = adata_pos_2.X.toarray() +# lfc_gt = x_1.mean(0) - x_2.mean(0) +# stat_res = ranksums(x_1, x_2) + +# gt_de_analysis = pd.DataFrame( +# { +# "gene": adata_pos_1.var_names, +# "pvalue": stat_res.pvalue, +# "statistic": stat_res.statistic, +# "lfc_gt": lfc_gt, +# "padj": multipletests(stat_res.pvalue, method="fdr_bh")[1], +# } +# ) +# gt_de_analysis.loc[:, "is_gene_for_subclustering"] = gt_de_analysis.padj < 0.05 + +# %% +adata_pos = adata[adata.obs.leiden == "0"].copy() +sc.pp.subsample(adata_pos, n_obs=8000) +counts_df = pd.DataFrame( + adata_pos.X.toarray(), + columns=adata_pos.var_names +) +metadata = (adata_pos.obs.sample_group == 1).astype(int).to_frame("condition") +dds = DeseqDataSet( + counts=counts_df, + metadata=metadata, + design_factors="condition", + refit_cooks=True, + n_cpus=8, +) +dds.deseq2() +# %% +gt_de_analysis = pd.DataFrame( + { + "gene": dds.var_names, + "lfc_gt": dds.varm["LFC"]["condition_1_vs_0"].fillna(0.0).values, + } +) + + +# %% + +# %% + +mrvi_res = gt_de_analysis.merge( + mrvi_lfcs[(adata.obs.leiden == "0").values].mean(0).to_frame("LFC"), + on="gene", + how="left", +) + +milode_res = gt_de_analysis.merge( + milode_cell_lfcs[(adata.obs.leiden == "0").values].mean(0).to_frame("LFC"), + left_on="gene", + right_index=True, +) +r_mrvi = pearsonr( + mrvi_res.LFC.values, + mrvi_res.lfc_gt.values, +) +r_milode = pearsonr( + milode_res.LFC.values, + milode_res.lfc_gt.values, +) +de_comp = ( + pd.concat( + [ + mrvi_res.assign(method=f"MrVI; Pearson r: {r_mrvi[0]:.2f}"), + milode_res.assign(method=f"MILODE; Pearson r: {r_milode[0]:.2f}"), + ] + ) + .merge( + adata.var.loc[:, ["is_gene_for_subclustering"]], + left_on="gene", + right_index=True, + how="left", + ) +) + +# %% +fig = ( + p9.ggplot( + de_comp, + p9.aes(x="LFC", fill="method"), + ) + # + p9.geom_histogram(bins=100, alpha=0.5) + + p9.theme_classic() + + p9.geom_density() + + p9.facet_wrap("~method+is_gene_for_subclustering") + + p9.xlim(-0.5, 0.5) +) +fig.save(os.path.join(FIGURE_DIR, "deg_comparison.svg")) + +# %% +de_comp_ = de_comp.sample(frac=0.3) +de_comp_ +# %% +vmax = 0.5 +fig = ( + p9.ggplot(de_comp_, p9.aes(x="LFC", y="lfc_gt", fill="method")) + + p9.geom_abline(slope=1, intercept=0, color="black", linetype="dashed", size=1) + + p9.geom_point(stroke=0.1, alpha=0.5) + + p9.theme_classic() + + p9.xlim(-vmax, vmax) + + p9.ylim(-vmax, vmax) + + p9.theme( + figure_size=(4 * INCH_TO_CM, 4 * INCH_TO_CM), + axis_text=p9.element_text(size=6), + axis_title=p9.element_text(size=7), + legend_position="none", + ) + + p9.labs( + y="Inferred LFC", + ) +) +fig.save(os.path.join(FIGURE_DIR, "deg_comparison_legend.svg")) +fig + + + # %% group_no = 1 model_out_adata.obs[f"group_{group_no}_eff_size"] = mv_deg_res.effect_size.sel( @@ -546,7 +785,7 @@ def compute_ratio(dist, sample_to_mask): fig = sc.pl.embedding( model_out_adata, basis="X_scviv2_attention_mog_u_mde", - color=f"group_{group_no}_eff_size", + color=[f"group_{group_no}_eff_size", "leiden"], vmax="p95", vmin="p5", return_fig=True, @@ -559,6 +798,8 @@ def compute_ratio(dist, sample_to_mask): plt.show() plt.clf() + + # %% lfcs = mv_deg_res.lfc.sel(covariate="group_1").values lfcs_in_one = lfcs[adata.obs.leiden == "0"].mean(0) @@ -571,54 +812,6 @@ def compute_ratio(dist, sample_to_mask): plt.hist(lfcs_in_others, bins=bins, alpha=0.5, label="other clusters") plt.legend() -# %% -# abs_lfcs = np.abs(lfcs) -# abs_lfcs_in_one = abs_lfcs[adata.obs.leiden == "0"].mean(0) -# abs_lfcs_in_others = abs_lfcs[adata.obs.leiden != "0"].mean(0) -# abs_lfcs_in_one = np.abs(lfcs_in_one) -# abs_lfcs_in_others = np.abs(lfcs_in_others) -# is_gene_used = adata.var["is_gene_for_subclustering"] -# plot_df = pd.concat( -# [ -# pd.DataFrame( -# { -# "abs_lfc": abs_lfcs_in_one, -# "is_gene_for_subclustering": is_gene_used, -# "cluster": "cluster 0", -# } -# ), -# pd.DataFrame( -# { -# "abs_lfc": abs_lfcs_in_others, -# "is_gene_for_subclustering": is_gene_used, -# "cluster": "other_clusters", -# } -# ), -# ] -# ) - -# colors = ["#7AB5FF", "lightgray"] -# fig = ( -# p9.ggplot( -# plot_df, -# p9.aes(x="factor(is_gene_for_subclustering)", y="abs_lfc", fill="cluster"), -# ) -# + p9.geom_boxplot(outlier_alpha=0.0) -# + p9.ylim(0, 0.02) -# + p9.scale_fill_manual(values=colors) -# + p9.coord_flip() -# + p9.labs( -# x="", -# y="Absolute LFC", -# ) -# + p9.theme_classic() -# + p9.theme( -# figure_size=(8 * INCH_TO_CM, 4 * INCH_TO_CM), -# ) -# + SHARED_THEME -# ) -# fig.save(os.path.join(FIGURE_DIR, "DEGs.svg")) -# fig # %% from scipy.stats import wilcoxon, ranksums @@ -627,34 +820,6 @@ def compute_ratio(dist, sample_to_mask): adata_pos = adata[adata.obs.leiden == "0"].copy() sc.pp.normalize_total(adata_pos, target_sum=1e4) sc.pp.log1p(adata_pos) -# gt_de_analysis = [] -# for leiden in adata_pos.obs.sample_group.unique(): -# adata_pos_1 = adata_pos[adata_pos.obs.sample_group == leiden].copy() -# x_1 = adata_pos_1.X.toarray() -# adata_pos_2 = adata_pos[adata_pos.obs.sample_group != leiden].copy() -# x_2 = adata_pos_2.X.toarray() - -# # stat_res = np.array([wilcoxon(x_, y_).pvalue for (x_, y_) in zip(x_1.T, x_2.T)]) -# stat_res = ranksums(x_1, x_2) - -# gt_de_analysis.append( -# pd.DataFrame( -# { -# "gene": adata_pos_1.var_names, -# "pvalue": stat_res.pvalue, -# "statistic": stat_res.statistic, -# "leiden": leiden, -# "padj": multipletests(stat_res.pvalue, method="fdr_bh")[1], -# } -# ) -# ) -# gene_scores = ( -# gt_de_analysis -# .groupby("gene") -# .mean() -# ) -# gt_de_analysis = pd.concat(gt_de_analysis) - adata_pos_1 = adata_pos[adata_pos.obs.sample_group == 1].copy() x_1 = adata_pos_1.X.toarray() @@ -706,7 +871,9 @@ def compute_ratio(dist, sample_to_mask): ) lfcs_ -(p9.ggplot(lfcs_, p9.aes(x="LFC", y="statistic", fill="Group")) + p9.geom_point()) + +# %% + # %% @@ -728,10 +895,11 @@ def compute_ratio(dist, sample_to_mask): ) ) fig.save(os.path.join(FIGURE_DIR, "semisynth_DEGs_legend.svg")) +fig -fig = fig + p9.theme(legend_position="none") -fig.save(os.path.join(FIGURE_DIR, "semisynth_DEGs.svg")) - +fig2 = fig + p9.theme(legend_position="none") +fig2.save(os.path.join(FIGURE_DIR, "semisynth_DEGs.svg")) +fig # %% high_eff_size_cells = model_out_adata[ model_out_adata.obs["group_1_eff_size"] > 400 @@ -795,15 +963,116 @@ def compute_ratio(dist, sample_to_mask): bbox_inches="tight", ) +# Comparison to MILO +milo_analysis = pd.read_csv( + "../results/milo/models/pbmcs68k_for_subsample.MILO.da_analysis.tsv", sep="\t" +) +milo_neighborhoods = pd.read_csv( + "../results/milo/models/pbmcs68k_for_subsample.MILO.assignments.mtx", + sep="\s", + skiprows=2, + header=None, +) +cell_idx = milo_neighborhoods.iloc[:, 0].to_numpy() - 1 +neigh_idx = milo_neighborhoods.iloc[:, 1].to_numpy() - 1 +data_idx = np.ones_like(cell_idx) +mtx = sp.csr_matrix((data_idx, (cell_idx, neigh_idx))) + +print(mtx.shape, milo_analysis.shape) +milo_cell_res = [] +for row in tqdm(mtx): + row_ = row.toarray().flatten() + row_ = row_ > 0 + res_ = milo_analysis.iloc[row_] + lfc = res_["logFC"].mean(0) + pval = res_["PValue"].min(0) + pval_conservative = res_["PValue"].max(0) + abs_lfc = np.abs(lfc) + milo_cell_res.append(dict(lfc=lfc, pval=pval, pval_conservative=pval_conservative, abs_lfc=abs_lfc)) +milo_cell_res = pd.DataFrame(milo_cell_res) + +# %% +from mrvi import MrVI +import flax.linen as nn + + +MrVI.setup_anndata( + adata, batch_key="Site", sample_key="sample_assignment" +) +model = MrVI( + adata, + **{ + "n_latent": 30, + "n_latent_u": 5, + "qz_nn_flavor": "attention", + "px_nn_flavor": "attention", + "qz_kwargs": { + "use_map": True, + "stop_gradients": False, + "stop_gradients_mlp": True, + # "dropout_rate": 0.03, + }, + "px_kwargs": { + "stop_gradients": False, + "stop_gradients_mlp": True, + "h_activation": nn.softmax, + "low_dim_batch": True, + # "dropout_rate": 0.03, + }, + "learn_z_u_prior_scale": False, + "z_u_prior": True, + "u_prior_mixture": True, + "u_prior_mixture_k": 5, + } +) + +import jax + +model.train( + # accelerator="cuda", + devices=jax.devices(), + **{ + "max_epochs": 400, + "batch_size": 1024, + "check_val_every_n_epoch": 1, + "early_stopping": True, + "early_stopping_patience": 100, + "early_stopping_monitor": "elbo_validation", + "plan_kwargs": { + "n_epochs_kl_warmup": 50, + "lr": 1e-2 + } + } +) +# model = MRVI( +# adata, +# model_name="pbmcs68k_for_subsample", +# model_dir="results/mrvi/models", +# use_cuda=False, +# device=0, +# n_epochs=1000, +# lr=0.001, +# weight_decay=0.01, +# alpha=1.0, +# beta=1.0, +# gamma=1.0, +# n_neighbors=30, +# n_layers=1, +# n_hidden=256, +# dropout=0.1, +# batch_size=1000, +# seed=0, +# ) + +# %% +model.history["elbo_train"].plot() +model.history["elbo_validation"].plot() + # %% # Admissibility for rank 1 subsampled (can subsample further to see effect) -model_out_adata[model_out_adata.obs["sample_assignment"] == "1"] ball_res = model.get_outlier_cell_sample_pairs( flavor="ball", quantile_threshold=0.05, minibatch_size=1000 ) -ball_res - -# %% model_out_adata.obs["sample_1_admissibility"] = ball_res.is_admissible.sel( sample="1" ).astype(str) @@ -847,30 +1116,10 @@ def compute_ratio(dist, sample_to_mask): # %% -# Differential abundance rank 1 vs rank 4 -ap_res = model.get_outlier_cell_sample_pairs(flavor="ap", minibatch_size=1000) -ap_res + # %% -model_out_adata.obs["sample_1_4_da"] = ap_res.log_probs.sel( - sample="1" -) - ap_res.log_probs.sel(sample="4") -fig = sc.pl.embedding( - model_out_adata, - basis="X_scviv2_attention_mog_u_mde", - color=["sample_1_4_da", "leiden"], - vmax="p95", - vmin="p5", - vcenter=0, - cmap="RdBu", - show=False, - return_fig=True, -) -plt.savefig( - os.path.join(FIGURE_DIR, f"pres_{modelname}_rank_1_4_da.svg"), - bbox_inches="tight", -) -# %% +ap_res = model.get_outlier_cell_sample_pairs(flavor="ap", minibatch_size=1000) sample_to_metadata = ( adata.obs.loc[:, ["sample_assignment", "sample_metadata2"]] .drop_duplicates() @@ -892,6 +1141,7 @@ def compute_ratio(dist, sample_to_mask): # %% q5, q95 = np.quantile(log_ratios, [0.05, 0.95]) qval = np.maximum(np.abs(q5), np.abs(q95)) +qval = np.abs(q5) fig = sc.pl.embedding( model_out_adata, basis="X_scviv2_attention_mog_u_mde", @@ -908,83 +1158,99 @@ def compute_ratio(dist, sample_to_mask): bbox_inches="tight", ) + # %% -import scipy.sparse as sp -from tqdm import tqdm +# def get_smooth_logratios(adata, log_ratio_key): +# adata.obsp["normalized_connectivities"] = adata.obsp["connectivities"] / adata.obsp[ +# "connectivities" +# ].sum(1) +# adata.obs[log_ratio_key + "_smoothed"] = np.asarray( +# adata.obsp["normalized_connectivities"] +# * np.expand_dims(adata.obs[log_ratio_key], 1) +# ).squeeze() + +# sc.pp.neighbors(model_out_adata, n_neighbors=30, use_rep="X_scviv2_attention_mog_u") +model_out_adata.obs.loc[:, "mrvi_lfc"] = log_ratios +# get_smooth_logratios(model_out_adata, "mrvi_lfc") -# Comparison to MILO -milo_analysis = pd.read_csv( - "../results/milo/models/pbmcs68k_for_subsample.MILO.da_analysis.tsv", sep="\t" -) -milo_analysis -milo_neighborhoods = pd.read_csv( - "../results/milo/models/pbmcs68k_for_subsample.MILO.assignments.mtx", - sep="\s", - skiprows=2, - header=None, -) -cell_idx = milo_neighborhoods.iloc[:, 0].to_numpy() - 1 -neigh_idx = milo_neighborhoods.iloc[:, 1].to_numpy() - 1 -data_idx = np.ones_like(cell_idx) -mtx = sp.csr_matrix((data_idx, (cell_idx, neigh_idx))) # %% -milo_cell_res = [] -for row in tqdm(mtx): - row_ = row.toarray().flatten() - row_ = row_ > 0 - res_ = milo_analysis.iloc[row_] - lfc = res_["logFC"].mean(0) - pval = res_["PValue"].min(0) - milo_cell_res.append(dict(lfc=lfc, pval=pval)) -milo_cell_res = pd.DataFrame(milo_cell_res) - +# model_out_adata.obs.loc[:, "milo_lfc"] = milo_cell_res["lfc"].values +# qval = 2.5 +# fig = sc.pl.embedding( +# model_out_adata, +# basis="X_scviv2_attention_mog_u_mde", +# # color=["milo_lfc", "mrvi_lfc", "mrvi_lfc_smoothed", "leiden"], +# color=["mrvi_lfc", "leiden"], +# vmax=qval, +# vmin=-qval, +# vcenter=0, +# cmap="RdBu", +# show=False, +# return_fig=True, +# ) +# fig.savefig( +# os.path.join( +# FIGURE_DIR, f"DA_comparison_lfcs.svg" +# ) # %% -model_out_adata.obs.loc[:, "milo_lfc"] = milo_cell_res["lfc"].values -model_out_adata.obs.loc[:, "mrvi_lfc"] = log_ratios -fig = sc.pl.embedding( - model_out_adata, - basis="X_scviv2_attention_mog_u_mde", - color=["milo_lfc", "mrvi_lfc", "leiden"], - vmax=qval, - vmin=-qval, - vcenter=0, - cmap="RdBu", - show=False, - return_fig=True, +model_out_adata.obs.loc[:,"cluster_clean"] = model_out_adata.obs["leiden"].map( + { + "0": "subset A", + "1": "subset B", + "2": "Other", + "3": "subset C", + "4": "Other", + "5": "Other", + } +) +model_out_adata.obs["cluster_clean"] = pd.Categorical( + model_out_adata.obs["cluster_clean"], + categories=["subset A", "subset B", "subset C", "Other"], +) + +sns.violinplot( + model_out_adata.obs, x="cluster_clean", y="mrvi_lfc" +) +plt.axhline(0, color="black", linestyle="--") +plt.xticks(rotation=-45) +plt.ylim(-1.5, 1.5) +plt.savefig( + os.path.join(FIGURE_DIR, f"DA_comparison_lfcs_violin.svg") ) -fig.show() # %% -from sklearn.metrics import precision_recall_curve +for key in model_out_adata.obs["cluster_clean"].unique(): + pop1 = model_out_adata[model_out_adata.obs["cluster_clean"] == key].obs["mrvi_lfc"] + pop2 = model_out_adata[model_out_adata.obs["cluster_clean"] != key].obs["mrvi_lfc"] + stat_res = ttest_ind(pop1, pop2) + stat_res1 = sm_ttest_ind(pop1, pop2, alternative="larger", value=0.1) + stat_res2 = sm_ttest_ind(pop1, pop2, alternative="smaller", value=-0.1) + pval = np.minimum(stat_res1[1], stat_res2[1]) + print(key, pval) + +# %% +from sklearn.metrics import precision_recall_curve, auc def plot_pr(y_pred, label=None): - y_true = (adata.obs["leiden"] == "1").values + # y_true = (adata.obs["leiden"] == "1").values + y_true = (adata.obs["leiden"].isin(["1", "3"])).values y_pred_ = y_pred.copy() y_pred_[np.isnan(y_pred_)] = 0.0 pre, rec, _ = precision_recall_curve(y_true, -y_pred_) + prauc = auc(rec, pre) # plt.plot(rec, pre, label=label) - return pd.DataFrame(dict(precision=pre, recall=rec)) + return pd.DataFrame(dict(precision=pre, recall=rec)), prauc -# plot_pr(log_ratios, label="MrVI") -# plot_pr(milo_cell_res["lfc"].values, label="MILO") -# plt.legend() -# plt.xlabel("Recall") -# plt.ylabel("Precision") -# plt.xlim(0.01, 1.0) -# plt.savefig( -# os.path.join(FIGURE_DIR, f"semisynth_pr.svg"), -# ) -# plt.show() # %% -df1 = plot_pr(log_ratios, label="MrVI") -df2 = plot_pr(milo_cell_res["lfc"].values, label="MrVI") +df1, auc_mrvi = plot_pr(-np.abs(log_ratios)) +# df2 = plot_pr(milo_cell_res["pval"].values) +df2, auc_milo = plot_pr(-np.abs(milo_cell_res["lfc"].values)) df = pd.concat([df1.assign(model="MrVI"), df2.assign(model="MILO")]) - fig = ( p9.ggplot(df, p9.aes(x="recall", y="precision", color="model")) + p9.geom_line() @@ -999,121 +1265,175 @@ def plot_pr(y_pred, label=None): ) ) fig.save( - os.path.join(FIGURE_DIR, f"semisynth_pr.svg"), + os.path.join(FIGURE_DIR, f"semisynth_pr2.svg"), ) fig -# %% -plot_df = pd.concat( - [ - pd.DataFrame( - dict( - es=log_ratios, - method="MrVI", - is_cell_affected=(adata.obs["leiden"] == "1").values, - ) - ), - pd.DataFrame( - dict( - es=milo_cell_res["lfc"].values, - method="MILO", - is_cell_affected=(adata.obs["leiden"] == "1").values, - ) - ), - ] -) -( - p9.ggplot(plot_df, p9.aes(x="factor(is_cell_affected)", y="es", fill="method")) - + p9.geom_boxplot() -) # %% -lib_size = adata.X.sum(1).A1 -log_ratios - -df_ = pd.DataFrame( +plot_df = pd.DataFrame( { - "lib_size": lib_size, - "log_ratios": log_ratios, - "population": adata.obs["leiden"].values, + "model": ["MrVI", "MILO"], + "auc": [auc_mrvi, auc_milo], } ) -df_ = ( - df_.assign( - is_cell_affected=lambda x: x["population"] == "1", +fig = ( + p9.ggplot(plot_df, p9.aes(x="model", y="auc")) + + p9.geom_bar(stat="identity", width=0.5) + + p9.theme_classic() + + SHARED_THEME + + p9.labs( + x="", + y="PRAUC", ) - # .query("population != '1'") - # .sort_values("log_ratios") -) - -( - p9.ggplot( - df_.sample(frac=1.0), - p9.aes(x="lib_size", y="log_ratios", fill="is_cell_affected"), + + p9.scale_y_continuous(expand=(0.0, 0.0)) + + p9.theme( + figure_size=(4 * INCH_TO_CM, 4 * INCH_TO_CM), ) - + p9.geom_point() ) +fig.save( + os.path.join(FIGURE_DIR, f"semisynth_auc.svg"), +) +plot_df # %% -df_.query("~is_cell_affected").sort_values("log_ratios").assign( - rank_=lambda x: np.arange(len(x)) -).plot.scatter("rank_", "lib_size") +# df1 = plot_pr(log_ratios, label="MrVI") +# # df2 = plot_pr(milo_cell_res["lfc"].values, label="MrVI") +# df2 = plot_pr(-milo_cell_res["pvalue"].values, label="MrVI") +# df = pd.concat([df1.assign(model="MrVI"), df2.assign(model="MILO")]) +# fig = ( +# p9.ggplot(df, p9.aes(x="recall", y="precision", color="model")) +# + p9.geom_line() +# + p9.theme_classic() +# + SHARED_THEME +# + p9.theme( +# figure_size=(5.8 * INCH_TO_CM, 4 * INCH_TO_CM), +# ) +# + p9.xlim(0.01, 1.0) +# + p9.labs( +# color="", +# ) +# ) +# fig.save( +# os.path.join(FIGURE_DIR, f"semisynth_pr.svg"), +# ) +# fig -# %% -adata_files = glob.glob("../results/milo/data/pbmcs68k_for_subsample*.final.h5ad") -for adata_file in adata_files: - print(adata_file) +# # %% +# plot_df = pd.concat( +# [ +# pd.DataFrame( +# dict( +# es=log_ratios, +# method="MrVI", +# is_cell_affected=(adata.obs["leiden"] == "1").values, +# ) +# ), +# pd.DataFrame( +# dict( +# es=milo_cell_res["lfc"].values, +# method="MILO", +# is_cell_affected=(adata.obs["leiden"] == "1").values, +# ) +# ), +# ] +# ) +# ( +# p9.ggplot(plot_df, p9.aes(x="factor(is_cell_affected)", y="es", fill="method")) +# + p9.geom_boxplot() +# ) -# %% -selected_files = [ - dict( - filename="../results/milo/data/pbmcs68k_for_subsample.scviv2_attention_mog.final.h5ad", - modelname="MrVI", - keyname="X_scviv2_attention_mog_u", - ), - dict( - filename="../results/milo/data/pbmcs68k_for_subsample.composition_SCVI_leiden1_subleiden1.final.h5ad", - modelname="SCVI", - keyname="X_SCVI_leiden1_subleiden1", - ), - dict( - filename="../results/milo/data/pbmcs68k_for_subsample.composition_PCA_leiden1_subleiden1.final.h5ad", - modelname="PCA", - keyname="X_PCA_leiden1_subleiden1", - ), -] -embedding_obsm_keys = [] -for mydic in selected_files: - file = mydic["filename"] - keyname = mydic["keyname"] - modelname = mydic["modelname"] - _adata = sc.read_h5ad(file) - adata.obsm[modelname] = _adata.obsm[keyname] - embedding_obsm_keys.append(modelname) +# # %% +# lib_size = adata.X.sum(1).A1 +# log_ratios -# %% -from scib_metrics.benchmark import Benchmarker -bm = Benchmarker( - adata, - batch_key="sample_assignment", - label_key="leiden", - embedding_obsm_keys=embedding_obsm_keys, - # pre_integrated_embedding_obsm_key="X_pca", - n_jobs=-1, -) +# df_ = pd.DataFrame( +# { +# "lib_size": lib_size, +# "log_ratios": log_ratios, +# "population": adata.obs["leiden"].values, +# } +# ) +# df_ = ( +# df_.assign( +# is_cell_affected=lambda x: x["population"] == "1", +# ) +# # .query("population != '1'") +# # .sort_values("log_ratios") +# ) -# bm.prepare(neighbor_computer=faiss_brute_force_nn) -bm.prepare() -bm.benchmark() -bm.plot_results_table( - min_max_scale=False, - save_dir=FIGURE_DIR, -) +# ( +# p9.ggplot( +# df_.sample(frac=1.0), +# p9.aes(x="lib_size", y="log_ratios", fill="is_cell_affected"), +# ) +# + p9.geom_point() +# ) + +# # %% +# df_.query("~is_cell_affected").sort_values("log_ratios").assign( +# rank_=lambda x: np.arange(len(x)) +# ).plot.scatter("rank_", "lib_size") + + +# # %% +# adata_files = glob.glob("../results/milo/data/pbmcs68k_for_subsample*.final.h5ad") +# for adata_file in adata_files: +# print(adata_file) + +# # %% +# selected_files = [ +# dict( +# filename="../results/milo/data/pbmcs68k_for_subsample.scviv2_attention_mog.final.h5ad", +# modelname="MrVI", +# keyname="X_scviv2_attention_mog_u", +# ), +# dict( +# filename="../results/milo/data/pbmcs68k_for_subsample.composition_SCVI_leiden1_subleiden1.final.h5ad", +# modelname="SCVI", +# keyname="X_SCVI_leiden1_subleiden1", +# ), +# dict( +# filename="../results/milo/data/pbmcs68k_for_subsample.composition_PCA_leiden1_subleiden1.final.h5ad", +# modelname="PCA", +# keyname="X_PCA_leiden1_subleiden1", +# ), +# ] +# embedding_obsm_keys = [] +# for mydic in selected_files: +# file = mydic["filename"] +# keyname = mydic["keyname"] +# modelname = mydic["modelname"] +# _adata = sc.read_h5ad(file) +# adata.obsm[modelname] = _adata.obsm[keyname] +# embedding_obsm_keys.append(modelname) + +# # %% +# from scib_metrics.benchmark import Benchmarker + +# bm = Benchmarker( +# adata, +# batch_key="sample_assignment", +# label_key="leiden", +# embedding_obsm_keys=embedding_obsm_keys, +# # pre_integrated_embedding_obsm_key="X_pca", +# n_jobs=-1, +# ) + +# # bm.prepare(neighbor_computer=faiss_brute_force_nn) +# bm.prepare() +# bm.benchmark() +# bm.plot_results_table( +# min_max_scale=False, +# save_dir=FIGURE_DIR, +# ) + +# # %% +# model.history["reconstruction_loss_validation"].plot() +# # %% +# model.history["elbo_validation"].iloc[20:].plot() # %% -model.history["reconstruction_loss_validation"].plot() -# %% -model.history["elbo_validation"].iloc[20:].plot() diff --git a/bin/run_milode.R b/bin/run_milode.R index 5f67698..18a41cd 100644 --- a/bin/run_milode.R +++ b/bin/run_milode.R @@ -26,7 +26,7 @@ rep_key <- "pca.corrected" file_str <- paste(readLines(config_file), collapse="\n") config <- fromJSON(file_str) sample_key <- config$sample_key -covariate_key <- config$covariate_key +covariate_key <- config$covariate_key_de nuisance_key <- config$batch_key data <- readH5AD(file_name) diff --git a/bin/utils.py b/bin/utils.py index c30e80e..5886092 100644 --- a/bin/utils.py +++ b/bin/utils.py @@ -13,6 +13,7 @@ import numpy as np import pandas as pd import scanpy as sc +from tqdm import tqdm from remote_pdb import RemotePdb @@ -270,7 +271,7 @@ def perform_gsea( if use_server: is_done = False - for _ in range(n_trials_max): + for _ in tqdm(range(n_trials_max)): if is_done: break @@ -284,7 +285,8 @@ def perform_gsea( ) is_done = True except: - time.sleep(3) + print("GSEA failed; retrying...") + time.sleep(1) continue if not is_done: raise ValueError( diff --git a/conf/datasets/crohns_data.json b/conf/datasets/crohns_data.json new file mode 100644 index 0000000..92a8044 --- /dev/null +++ b/conf/datasets/crohns_data.json @@ -0,0 +1,20 @@ +{ + "batch_key": "Layer_Chem", + "labels_key": "Celltype", + "sample_key": "biosample_id", + "mrvi_train_kwargs": { + "max_epochs": 400 + }, + "scviv2_train_kwargs": { + "early_stopping": true, + "plan_kwargs": { + "lr": 1e-3, + "n_epochs_kl_warmup": 25 + } + }, + "composition_scvi_train_kwargs": { + "max_epochs": 400 + }, + "clustering_method": "ward", + "compute_local_representations": false +} \ No newline at end of file diff --git a/conf/datasets/pbmcs68k_for_subsample.json b/conf/datasets/pbmcs68k_for_subsample.json index 7f1597f..efebd52 100644 --- a/conf/datasets/pbmcs68k_for_subsample.json +++ b/conf/datasets/pbmcs68k_for_subsample.json @@ -3,6 +3,7 @@ "labels_key": "leiden", "sample_key": "sample_assignment", "covariate_key": "sample_metadata2", + "covariate_key_de": "group_1", "mrvi_train_kwargs": { "max_epochs": 400 }, @@ -32,6 +33,7 @@ "n_replicates_per_subcluster": 4, "selected_cluster": 0, "selected_subsample_cluster": 1, + "selected_oversample_cluster": 3, "subsample_rates": [ 0.7, 0.8, diff --git a/env/gpu/compute_2dreps.yaml b/env/gpu/compute_2dreps.yaml index 4b831fc..83ad7a7 100644 --- a/env/gpu/compute_2dreps.yaml +++ b/env/gpu/compute_2dreps.yaml @@ -12,6 +12,7 @@ dependencies: - dask=2022.12.1 - netcdf4=1.6.2 - pymde=0.1.18 + - pytorch>=2.0 - pip - pip: - click @@ -19,3 +20,4 @@ dependencies: - scikit-learn==1.0.2 - scikit-misc==0.1.4 - remote-pdb==2.1.0 + - scvi-tools diff --git a/modules/fit_scviv2.nf b/modules/fit_scviv2.nf index 739e85b..cb6d761 100644 --- a/modules/fit_scviv2.nf +++ b/modules/fit_scviv2.nf @@ -9,6 +9,7 @@ process fit_scviv2 { val use_attention_no_prior_mog val use_attention_mog val use_attention_no_prior_mog_large + val use_ibd_config script: adata_name = adata_in.getSimpleName() @@ -38,6 +39,9 @@ process fit_scviv2 { else if (use_attention_no_prior_mog_large) { method_name = "scviv2_attention_no_prior_mog_large" } + else if (use_ibd_config) { + method_name = "scviv2_ibd" + } else { method_name = "scviv2" } @@ -55,6 +59,7 @@ process fit_scviv2 { --use_attention_noprior ${use_attention_noprior} \\ --use_attention_no_prior_mog ${use_attention_no_prior_mog} \\ --use_attention_mog ${use_attention_mog} \\ + --use_ibd_config ${use_ibd_config} \\ --use_attention_no_prior_mog_large ${use_attention_no_prior_mog_large} """ diff --git a/nextflow.config b/nextflow.config index 845e71c..1e7b2d0 100644 --- a/nextflow.config +++ b/nextflow.config @@ -28,13 +28,13 @@ params { env { root = "${projectDir}/env/${params.profile}" preprocess_data = "${params.env.root}/preprocess_data.yaml" - run_models_jax = "${params.env.root}/run_models_jax.yaml" + run_models_jax = "/data1/mambaforge/envs/run-models-jax" run_milo = "${params.env.root}/run_milo.yaml" - run_models_torch = "${params.env.root}/run_models_torch.yaml" - run_models = "${params.env.root}/run_models_jax.yaml" + run_models_torch = "/data1/mambaforge/envs/run-models-torch" + run_models = "${params.env.root}/run_models_torch.yaml" compute_metrics = "${params.env.root}/compute_metrics.yaml" analyze_results = "${params.env.root}/analyze_results.yaml" - compute_2dreps = "${params.env.root}/compute_2dreps.yaml" + compute_2dreps = "/data1/mambaforge/envs/compute-2drep" } conf { root = "${projectDir}/conf" @@ -106,7 +106,7 @@ process { conda = "${params.env.preprocess_data}" } withName: "fit_scviv2" { - conda = "${params.env.run_models}" + conda = "${params.env.run_models_jax}" } withName: "run_milo" { conda = "${params.env.run_milo}" @@ -115,10 +115,10 @@ process { conda = "${params.env.run_milo}" } withName: "get_latent_scviv2" { - conda = "${params.env.run_models}" + conda = "${params.env.run_models_jax}" } withName: "fit_and_get_latent_composition_baseline" { - conda = "${params.env.run_models}" + conda = "${params.env.run_models_torch}" } withName: "scib" { conda = "${params.env.compute_metrics}" diff --git a/subworkflows/run_models/main.nf b/subworkflows/run_models/main.nf index 87f21d8..68efa60 100644 --- a/subworkflows/run_models/main.nf +++ b/subworkflows/run_models/main.nf @@ -8,6 +8,7 @@ include { fit_scviv2 as fit_scviv2_attention_no_prior_mog; fit_scviv2 as fit_scviv2_attention_mog; fit_scviv2 as fit_scviv2_attention_no_prior_mog_large; + fit_scviv2 as fit_scviv2_ibd_config; } from params.modules.fit_scviv2 include { get_latent_scviv2; @@ -19,6 +20,7 @@ include { get_latent_scviv2 as get_latent_scviv2_attention_no_prior_mog; get_latent_scviv2 as get_latent_scviv2_attention_mog; get_latent_scviv2 as get_latent_scviv2_attention_no_prior_mog_large; + get_latent_scviv2 as get_latent_scviv2_ibd_config; } from params.modules.get_latent_scviv2 include { fit_and_get_latent_composition_baseline as fit_and_get_latent_composition_scvi_clusterkey; @@ -41,18 +43,21 @@ workflow run_models { // Step 1: Run models // Run base model - scvi_attention_noprior_outs = fit_scviv2_attention_noprior(adatas_in, false, false, false, false, true, false, false, false) | get_latent_scviv2_attention_noprior + scvi_attention_noprior_outs = fit_scviv2_attention_noprior(adatas_in, false, false, false, false, true, false, false, false, false) | get_latent_scviv2_attention_noprior scvi_attention_noprior_adata = scvi_attention_noprior_outs.adata - scvi_attention_no_prior_mog_outs = fit_scviv2_attention_no_prior_mog(adatas_in, false, false, false, false, false, true, false, false) | get_latent_scviv2_attention_no_prior_mog + scvi_attention_no_prior_mog_outs = fit_scviv2_attention_no_prior_mog(adatas_in, false, false, false, false, false, true, false, false, false) | get_latent_scviv2_attention_no_prior_mog scvi_attention_no_prior_mog_adata = scvi_attention_no_prior_mog_outs.adata - scvi_attention_mog_outs = fit_scviv2_attention_mog(adatas_in, false, false, false, false, false, false, true, false) | get_latent_scviv2_attention_mog + scvi_attention_mog_outs = fit_scviv2_attention_mog(adatas_in, false, false, false, false, false, false, true, false, false) | get_latent_scviv2_attention_mog scvi_attention_mog_adata = scvi_attention_mog_outs.adata - scvi_attention_no_prior_mog_large_outs = fit_scviv2_attention_no_prior_mog_large(adatas_in, false, false, false, false, false, false, false, true) | get_latent_scviv2_attention_no_prior_mog_large + scvi_attention_no_prior_mog_large_outs = fit_scviv2_attention_no_prior_mog_large(adatas_in, false, false, false, false, false, false, false, true, false) | get_latent_scviv2_attention_no_prior_mog_large scvi_attention_no_prior_mog_large_adata = scvi_attention_no_prior_mog_large_outs.adata + scvi_attention_ibd_config_outs = fit_scviv2_ibd_config(adatas_in, false, false, false, false, false, false, false, false, true) | get_latent_scviv2_ibd_config + scvi_attention_ibd_config_adata = scvi_attention_ibd_config_outs.adata + distance_matrices = scvi_attention_no_prior_mog_large_outs.distance_matrices.concat( scvi_attention_no_prior_mog_large_outs.normalized_distance_matrices, ) @@ -71,6 +76,7 @@ workflow run_models { scvi_attention_no_prior_mog_adata, scvi_attention_mog_adata, scvi_attention_no_prior_mog_large_adata, + scvi_attention_ibd_config_adata, ) if ( params.runMILO ) { @@ -79,21 +85,21 @@ workflow run_models { } if ( params.runAllMRVIModels ) { - scvi_outs = fit_scviv2(adatas_in, false, false, false, false, false, false, false) | get_latent_scviv2 + scvi_outs = fit_scviv2(adatas_in, false, false, false, false, false, false, false, false) | get_latent_scviv2 scvi_adata = scvi_outs.adata // Run scviv2 mlp - scvi_mlp_outs = fit_scviv2_mlp(adatas_in, true, false, false, false, false, false, false) | get_latent_scviv2_mlp + scvi_mlp_outs = fit_scviv2_mlp(adatas_in, true, false, false, false, false, false, false, false) | get_latent_scviv2_mlp scvi_mlp_adata = scvi_mlp_outs.adata // Run scviv2 mlp smallu - scvi_mlp_smallu_outs = fit_scviv2_mlp_smallu(adatas_in, false, true, false, false, false, false, false) | get_latent_scviv2_mlp_smallu + scvi_mlp_smallu_outs = fit_scviv2_mlp_smallu(adatas_in, false, true, false, false, false, false, false, false) | get_latent_scviv2_mlp_smallu scvi_mlp_smallu_adata = scvi_mlp_smallu_outs.adata - scvi_attention_outs = fit_scviv2_attention(adatas_in, false, false, true, false, false, false, false) | get_latent_scviv2_attention + scvi_attention_outs = fit_scviv2_attention(adatas_in, false, false, true, false, false, false, false, false) | get_latent_scviv2_attention scvi_attention_adata = scvi_attention_outs.adata - scvi_attention_smallu_outs = fit_scviv2_attention_smallu(adatas_in, false, false, false, true, false, false, false) | get_latent_scviv2_attention_smallu + scvi_attention_smallu_outs = fit_scviv2_attention_smallu(adatas_in, false, false, false, true, false, false, false, false) | get_latent_scviv2_attention_smallu scvi_attention_smallu_adata = scvi_attention_smallu_outs.adata distance_matrices = distance_matrices.concat(