Skip to content

Commit

Permalink
Merge pull request #473 from aristoteleo/external_to_preprocess
Browse files Browse the repository at this point in the history
External to preprocess
  • Loading branch information
Xiaojieqiu authored May 4, 2023
2 parents ff43afe + a1effbc commit b8aecad
Show file tree
Hide file tree
Showing 10 changed files with 597 additions and 221 deletions.
Empty file added __init__.py
Empty file.
8 changes: 0 additions & 8 deletions dynamo/external/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,14 @@

from .gseapy import enrichr
from .hodge import ddhodge
from .pearson_residual_recipe import (
normalize_layers_pearson_residuals,
select_genes_by_pearson_residuals,
)
from .scifate import scifate_glmnet
from .scribe import coexp_measure, coexp_measure_mat, scribe
from .sctransform import sctransform

__all__ = [
"enrichr",
"ddhodge",
"normalize_layers_pearson_residuals",
"select_genes_by_pearson_residuals",
"scifate_glmnet",
"coexp_measure",
"coexp_measure_mat",
"scribe",
"sctransform",
]
3 changes: 3 additions & 0 deletions dynamo/plot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)
from .scatters import scatters
from .scPotential import show_landscape
from .sctransform import sctransform_plot_fit, plot_residual_var
from .scVectorField import ( # , plot_LIC_gray
cell_wise_vectors,
cell_wise_vectors_3d,
Expand Down Expand Up @@ -150,4 +151,6 @@
"causality",
"comb_logic",
"hessian",
"sctransform_plot_fit",
"plot_residual_var",
]
149 changes: 149 additions & 0 deletions dynamo/plot/sctransform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
from typing import Optional

from anndata import AnnData
from matplotlib.axes import Axes
from matplotlib.figure import Figure
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

def sctransform_plot_fit(
adata: AnnData,
xaxis: str = "gmean",
fig: Optional[Figure] = None,
) -> Figure:
"""Plot the fitting of model parameters in sctransform.
Args:
adata: annotated data matrix after sctransform.
xaxis: the gene expression metric is plotted on the x-axis.
fig: Matplotlib figure object to use for the plot. If not provided, a new figure is created.
Returns:
The matplotlib figure object containing the plot.
"""
if fig is None:
fig = plt.figure(figsize=(12, 3))
gene_names = adata.var['genes_step1_sct'][
~adata.var['genes_step1_sct'].isna()].index

genes_log10_mean = adata.var["log10_gmean_sct"]
genes_log_gmean = genes_log10_mean[~genes_log10_mean.isna()]

model_params_fit = pd.concat(
[adata.var["log_umi_sct"], adata.var["Intercept_sct"], adata.var["theta_sct"]], axis=1)
model_params = pd.concat(
[adata.var["log_umi_step1_sct"], adata.var["Intercept_step1_sct"], adata.var["model_pars_theta_step1"]],
axis=1)
model_params_fit = model_params_fit.rename(
columns={"log_umi_sct": "log_umi", "Intercept_sct": "Intercept", "theta_sct": "theta"})
model_params = model_params.rename(
columns={"log_umi_step1_sct": "log_umi",
"Intercept_step1_sct": "Intercept",
"model_pars_theta_step1": "theta"})

model_params = model_params.loc[gene_names]

total_params = model_params_fit.shape[1]

for index, column in enumerate(model_params_fit.columns):
ax = fig.add_subplot(1, total_params, index + 1)
model_param_col = model_params[column]

# model_param_outliers = is_outlier(model_param_col)
if column != "theta":
ax.scatter(
genes_log_gmean, # [~model_param_outliers],
model_param_col, # [~model_param_outliers],
s=1,
label="single gene estimate",
color="#2b8cbe",
)
ax.scatter(
genes_log10_mean,
model_params_fit[column],
s=2,
label="regularized",
color="#de2d26",
)
ax.set_ylabel(column)
else:
ax.scatter(
genes_log_gmean, # [~model_param_outliers],
np.log10(model_param_col), # [~model_param_outliers],
s=1,
label="single gene estimate",
color="#2b8cbe",
)
ax.scatter(
genes_log10_mean,
np.log10(model_params_fit[column]),
s=2,
label="regularized",
color="#de2d26",
)
ax.set_ylabel("log10(" + column + ")")
if column == "od_factor":
ax.set_ylabel("log10(od_factor)")

ax.set_xlabel("log10(gene_{})".format(xaxis))
ax.set_title(column)
ax.legend(frameon=False)
_ = fig.tight_layout()
return fig

def plot_residual_var(
adata: AnnData,
topngenes: int = 10,
label_genes: bool = True,
ax: Optional[Axes] = None,
) -> Figure:
"""Plot the relationship between the mean and variance of gene expression across cells, highlighting the genes with
the highest residual variance.
Args:
adata: annotated data matrix after sctransform.
topngenes: the number of genes with the highest residual variance to highlight in the plot.
label_genes: whether to label the highlighted genes in the plot. If `topngenes` is large, labeling genes may
lead to plotting error because of the space limitation.
ax: the axes on which to draw the plot. If None, a new figure and axes are created.
Returns:
The Figure object if `ax` is not given, else None.
"""
def vars(a, axis=None):
"""Helper function to calculate variance of sparse matrix by equation: var = mean(a**2) - mean(a)**2"""
a_squared = a.copy()
a_squared.data **= 2
return a_squared.mean(axis) - np.square(a.mean(axis))

if ax is None:
fig, ax = plt.subplots(figsize=(8, 5))
else:
fig = None

gene_attr = pd.DataFrame(adata.var['log10_gmean_sct'])
# gene_attr = gene_attr.loc[gene_names]
gene_attr["var"] = vars(adata.X, axis=0).tolist()[0]
gene_attr["mean"] = adata.X.mean(axis=0).tolist()[0]
gene_attr_sorted = gene_attr.sort_values(
"var", ascending=False
).reset_index()
topn = gene_attr_sorted.iloc[:topngenes]
gene_attr = gene_attr_sorted.iloc[topngenes:]
ax.set_xscale("log")

ax.scatter(
gene_attr["mean"], gene_attr["var"], s=1.5, color="black"
)
ax.scatter(topn["mean"], topn["var"], s=1.5, color="deeppink")
ax.axhline(1, linestyle="dashed", color="red")
ax.set_xlabel("mean")
ax.set_ylabel("var")
if label_genes:
texts = [
plt.text(row["mean"], row["var"], row["index"])
for index, row in topn.iterrows()
]
fig.tight_layout()
return fig
2 changes: 1 addition & 1 deletion dynamo/preprocessing/Preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
main_info_insert_adata,
main_warning,
)
from ..external import (
from .external import (
normalize_layers_pearson_residuals,
sctransform,
select_genes_by_pearson_residuals,
Expand Down
12 changes: 12 additions & 0 deletions dynamo/preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@

from .cell_cycle import cell_cycle_scores
from .dynast import lambda_correction
from .external import (
harmony_debatch,
integrate,
normalize_layers_pearson_residuals,
sctransform,
select_genes_by_pearson_residuals,
)
from .preprocess import (
calc_sz_factor_legacy,
filter_cells_by_outliers,
Expand Down Expand Up @@ -45,12 +52,14 @@
"normalize_cells",
"lambda_correction",
"calc_sz_factor_legacy",
"normalize_layers_pearson_residuals",
"normalize",
"recipe_monocle",
"recipe_velocyto",
"calc_Gini",
"filter_cells_by_outliers",
"select_genes_monocle",
"select_genes_by_pearson_residuals",
"filter_genes",
"filter_genes_by_outliers",
"filter_genes_by_clusters_",
Expand All @@ -64,6 +73,7 @@
"top_pca_genes",
"relative2abs",
"scale",
"sctransform",
"convert2symbol",
"filter_genes_by_pattern",
"decode",
Expand All @@ -72,4 +82,6 @@
"log1p",
"log1p",
"log1p_adata_layer",
"harmony_debatch",
"integrate",
]
14 changes: 14 additions & 0 deletions dynamo/preprocessing/external/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from .integration import harmony_debatch, integrate
from .pearson_residual_recipe import (
normalize_layers_pearson_residuals,
select_genes_by_pearson_residuals,
)
from .sctransform import sctransform

__all__ = [
"normalize_layers_pearson_residuals",
"sctransform",
"select_genes_by_pearson_residuals",
"harmony_debatch",
"integrate",
]
128 changes: 128 additions & 0 deletions dynamo/preprocessing/external/integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from typing import List, Optional, Union

import numpy as np
from anndata import AnnData
from scipy.sparse import csr_matrix, isspmatrix

# Convert sparse matrix to dense matrix.
to_dense_matrix = lambda X: np.array(X.todense()) if isspmatrix(X) else np.asarray(X)

def integrate(
adatas: List[AnnData],
batch_key: str = "slices",
fill_value: Union[int, float] = 0,
) -> AnnData:
"""Concatenating all anndata objects.
Args:
adatas: AnnData matrices to concatenate with.
batch_key: the key to add the batch annotation to :attr:`obs`.
fill_value: Scalar value to fill newly missing values in arrays with.
Returns:
The concatenated AnnData, where adata.obs[batch_key] stores a categorical variable labeling the batch.
"""

batch_ca = [adata.obs[batch_key][0] for adata in adatas]

# Merge the obsm, varm and uns data of all anndata objcets separately.
obsm_dict, varm_dict, uns_dict = {}, {}, {}
obsm_keys, varm_keys, uns_keys = [], [], []
for adata in adatas:
obsm_keys.extend(list(adata.obsm.keys()))
varm_keys.extend(list(adata.varm.keys()))
uns_keys.extend(list(adata.uns_keys()))

obsm_keys, varm_keys, uns_keys = list(set(obsm_keys)), list(set(varm_keys)), list(set(uns_keys))
n_obsm_keys, n_varm_keys, n_uns_keys = len(obsm_keys), len(varm_keys), len(uns_keys)

if n_obsm_keys > 0:
for key in obsm_keys:
obsm_dict[key] = np.concatenate([to_dense_matrix(adata.obsm[key]) for adata in adatas], axis=0)
if n_varm_keys > 0:
for key in varm_keys:
varm_dict[key] = np.concatenate([to_dense_matrix(adata.varm[key]) for adata in adatas], axis=0)
if n_uns_keys > 0:
for key in uns_keys:
if "__type" in uns_keys and key == "__type":
uns_dict["__type"] = adatas[0].uns["__type"]
else:
uns_dict[key] = {
ca: adata.uns[key] if key in adata.uns_keys() else None for ca, adata in zip(batch_ca, adatas)
}

# Delete obsm, varm and uns data.
for adata in adatas:
del adata.obsm, adata.varm, adata.uns

# Concatenating obs and var data which will ignore the uns, obsm, varm attributes.
integrated_adata = adatas[0].concatenate(
*adatas[1:],
batch_key=batch_key,
batch_categories=batch_ca,
join="outer",
fill_value=fill_value,
uns_merge=None,
)

# Add Concatenated obsm data and varm data to integrated anndata object.
if n_obsm_keys > 0:
for key, value in obsm_dict.items():
integrated_adata.obsm[key] = value
if n_varm_keys > 0:
for key, value in varm_dict.items():
integrated_adata.varm[key] = value
if n_uns_keys > 0:
for key, value in uns_dict.items():
integrated_adata.uns[key] = value

return integrated_adata

def harmony_debatch(
adata: AnnData,
key: str,
basis: str = "X_pca",
adjusted_basis: str = "X_pca_harmony",
max_iter_harmony: int = 10,
copy: bool = False,
) -> Optional[AnnData]:
"""Use harmonypy [Korunsky19]_ to remove batch effects.
This function should be run after performing PCA but before computing the neighbor graph. Original Code Repository
is https://github.com/slowkow/harmonypy. Interesting example: https://slowkow.com/notes/harmony-animation/
Args:
adata: An Anndata object.
key: The name of the column in ``adata.obs`` that differentiates among experiments/batches.
basis: The name of the field in ``adata.obsm`` where the PCA table is stored.
adjusted_basis: The name of the field in ``adata.obsm`` where the adjusted PCA table will be stored after
running this function.
max_iter_harmony: Maximum number of rounds to run Harmony. One round of Harmony involves one clustering and one
correction step.
copy: Whether to copy `adata` or modify it inplace.
Returns:
Updates adata with the field ``adata.obsm[adjusted_basis]``, containing principal components adjusted by
Harmony.
"""
try:
import harmonypy
except ImportError:
raise ImportError("\nplease install harmonypy:\n\n\tpip install harmonypy")

adata = adata.copy() if copy else adata

# Convert sparse matrix to dense matrix.
matrix = to_dense_matrix(adata.obsm[basis])

# Use Harmony to adjust the PCs.
harmony_out = harmonypy.run_harmony(matrix, adata.obs, key, max_iter_harmony=max_iter_harmony)
adjusted_matrix = harmony_out.Z_corr.T

# Convert dense matrix to sparse matrix.
if isspmatrix(adata.obsm[basis]):
adjusted_matrix = csr_matrix(adjusted_matrix)

adata.obsm[adjusted_basis] = adjusted_matrix

return adata if copy else None
Loading

0 comments on commit b8aecad

Please sign in to comment.