Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

External to preprocess #473

Merged
merged 11 commits into from
May 4, 2023
Empty file added __init__.py
Empty file.
8 changes: 0 additions & 8 deletions dynamo/external/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,14 @@

from .gseapy import enrichr
from .hodge import ddhodge
from .pearson_residual_recipe import (
normalize_layers_pearson_residuals,
select_genes_by_pearson_residuals,
)
from .scifate import scifate_glmnet
from .scribe import coexp_measure, coexp_measure_mat, scribe
from .sctransform import sctransform

__all__ = [
"enrichr",
"ddhodge",
"normalize_layers_pearson_residuals",
"select_genes_by_pearson_residuals",
"scifate_glmnet",
"coexp_measure",
"coexp_measure_mat",
"scribe",
"sctransform",
]
3 changes: 3 additions & 0 deletions dynamo/plot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)
from .scatters import scatters
from .scPotential import show_landscape
from .sctransform import sctransform_plot_fit, plot_residual_var
from .scVectorField import ( # , plot_LIC_gray
cell_wise_vectors,
cell_wise_vectors_3d,
Expand Down Expand Up @@ -150,4 +151,6 @@
"causality",
"comb_logic",
"hessian",
"sctransform_plot_fit",
"plot_residual_var",
]
149 changes: 149 additions & 0 deletions dynamo/plot/sctransform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
from typing import Optional

from anndata import AnnData
from matplotlib.axes import Axes
from matplotlib.figure import Figure
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

def sctransform_plot_fit(
    adata: AnnData,
    xaxis: str = "gmean",
    fig: Optional[Figure] = None,
) -> Figure:
    """Plot the fitting of model parameters in sctransform.

    One panel is drawn per model parameter (``log_umi``, ``Intercept`` and ``theta``). Each panel overlays the
    step-1 single-gene estimates with the regularized estimates, both plotted against
    ``adata.var["log10_gmean_sct"]``. ``theta`` is displayed on a log10 scale.

    Args:
        adata: annotated data matrix after sctransform. The following ``adata.var`` columns are read:
            ``genes_step1_sct``, ``log10_gmean_sct``, ``log_umi_sct``, ``Intercept_sct``, ``theta_sct``,
            ``log_umi_step1_sct``, ``Intercept_step1_sct`` and ``model_pars_theta_step1``.
        xaxis: the gene expression metric is plotted on the x-axis. Only used to build the x-axis label.
        fig: Matplotlib figure object to use for the plot. If not provided, a new figure is created.

    Returns:
        The matplotlib figure object containing the plot.

    Raises:
        KeyError: if one of the required columns is missing from ``adata.var``.
    """
    if fig is None:
        fig = plt.figure(figsize=(12, 3))

    var = adata.var

    # Genes with a non-missing `genes_step1_sct` entry are treated as the step-1 genes.
    step1_mask = ~var["genes_step1_sct"].isna()
    gene_names = var.index[step1_mask.to_numpy()]

    genes_log10_mean = var["log10_gmean_sct"]
    # Select the x-values by gene label so they always pair up with the step-1 estimates, instead of relying on
    # two independently NaN-filtered columns happening to have the same length and order.
    step1_log10_mean = genes_log10_mean.loc[gene_names]

    model_params_fit = pd.concat([var["log_umi_sct"], var["Intercept_sct"], var["theta_sct"]], axis=1).rename(
        columns={"log_umi_sct": "log_umi", "Intercept_sct": "Intercept", "theta_sct": "theta"}
    )
    model_params = pd.concat(
        [var["log_umi_step1_sct"], var["Intercept_step1_sct"], var["model_pars_theta_step1"]],
        axis=1,
    ).rename(
        columns={
            "log_umi_step1_sct": "log_umi",
            "Intercept_step1_sct": "Intercept",
            "model_pars_theta_step1": "theta",
        }
    )
    model_params = model_params.loc[gene_names]

    total_params = model_params_fit.shape[1]

    for position, column in enumerate(model_params_fit.columns, start=1):
        ax = fig.add_subplot(1, total_params, position)
        single_gene = model_params[column]
        regularized = model_params_fit[column]

        if column == "theta":
            # theta is shown on a log10 scale; the other parameters are plotted as-is.
            single_gene = np.log10(single_gene)
            regularized = np.log10(regularized)
            ylabel = "log10(" + column + ")"
        else:
            ylabel = column

        ax.scatter(
            step1_log10_mean,
            single_gene,
            s=1,
            label="single gene estimate",
            color="#2b8cbe",
        )
        ax.scatter(
            genes_log10_mean,
            regularized,
            s=2,
            label="regularized",
            color="#de2d26",
        )
        ax.set_ylabel(ylabel)
        ax.set_xlabel("log10(gene_{})".format(xaxis))
        ax.set_title(column)
        ax.legend(frameon=False)

    fig.tight_layout()
    return fig

def plot_residual_var(
    adata: AnnData,
    topngenes: int = 10,
    label_genes: bool = True,
    ax: Optional[Axes] = None,
) -> Optional[Figure]:
    """Plot the relationship between the mean and variance of gene expression across cells, highlighting the genes with
    the highest residual variance.

    The per-gene mean and variance are computed column-wise from ``adata.X``. The x-axis uses a log scale and a dashed
    horizontal line marks a variance of 1.

    Args:
        adata: annotated data matrix after sctransform. ``adata.var["log10_gmean_sct"]`` must exist and ``adata.X``
            may be a scipy sparse matrix or a dense array.
        topngenes: the number of genes with the highest residual variance to highlight in the plot.
        label_genes: whether to label the highlighted genes in the plot. If `topngenes` is large, labeling genes may
            lead to plotting error because of the space limitation.
        ax: the axes on which to draw the plot. If None, a new figure and axes are created.

    Returns:
        The Figure object if `ax` is not given, else None.

    Raises:
        KeyError: if ``adata.var`` has no ``log10_gmean_sct`` column.
    """

    def _variance(matrix, axis=None):
        """Variance along `axis` computed as ``mean(x**2) - mean(x)**2`` for sparse or dense input."""
        # Sparse matrices expose an element-wise `multiply`; dense arrays go through numpy directly.
        squared = matrix.multiply(matrix) if hasattr(matrix, "multiply") else np.square(matrix)
        return squared.mean(axis) - np.square(matrix.mean(axis))

    if ax is None:
        fig, ax = plt.subplots(figsize=(8, 5))
    else:
        fig = None

    gene_attr = pd.DataFrame(adata.var["log10_gmean_sct"])
    # `np.asarray(...).ravel()` flattens both the (1, n_genes) matrix returned for sparse input and the 1-D array
    # returned for dense input.
    gene_attr["var"] = np.asarray(_variance(adata.X, axis=0)).ravel()
    gene_attr["mean"] = np.asarray(adata.X.mean(axis=0)).ravel()

    # Keep the gene names as the index so labels do not depend on the name of `adata.var.index`.
    ranked = gene_attr.sort_values("var", ascending=False)
    topn = ranked.iloc[:topngenes]
    remaining = ranked.iloc[topngenes:]

    ax.set_xscale("log")
    ax.scatter(remaining["mean"], remaining["var"], s=1.5, color="black")
    ax.scatter(topn["mean"], topn["var"], s=1.5, color="deeppink")
    ax.axhline(1, linestyle="dashed", color="red")
    ax.set_xlabel("mean")
    ax.set_ylabel("var")

    if label_genes:
        # Draw on the target axes (not pyplot's current axes) so a caller-supplied `ax` is respected.
        for gene, row in topn.iterrows():
            ax.text(float(row["mean"]), float(row["var"]), gene)

    # Only adjust the layout of a figure created here; a caller-supplied axes belongs to the caller's figure.
    if fig is not None:
        fig.tight_layout()
    return fig
2 changes: 1 addition & 1 deletion dynamo/preprocessing/Preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
main_info_insert_adata,
main_warning,
)
from ..external import (
from .external import (
normalize_layers_pearson_residuals,
sctransform,
select_genes_by_pearson_residuals,
Expand Down
12 changes: 12 additions & 0 deletions dynamo/preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@

from .cell_cycle import cell_cycle_scores
from .dynast import lambda_correction
from .external import (
harmony_debatch,
integrate,
normalize_layers_pearson_residuals,
sctransform,
select_genes_by_pearson_residuals,
)
from .preprocess import (
calc_sz_factor_legacy,
filter_cells_by_outliers,
Expand Down Expand Up @@ -45,12 +52,14 @@
"normalize_cells",
"lambda_correction",
"calc_sz_factor_legacy",
"normalize_layers_pearson_residuals",
"normalize",
"recipe_monocle",
"recipe_velocyto",
"calc_Gini",
"filter_cells_by_outliers",
"select_genes_monocle",
"select_genes_by_pearson_residuals",
"filter_genes",
"filter_genes_by_outliers",
"filter_genes_by_clusters_",
Expand All @@ -64,6 +73,7 @@
"top_pca_genes",
"relative2abs",
"scale",
"sctransform",
"convert2symbol",
"filter_genes_by_pattern",
"decode",
Expand All @@ -72,4 +82,6 @@
"log1p",
"log1p",
"log1p_adata_layer",
"harmony_debatch",
"integrate",
]
14 changes: 14 additions & 0 deletions dynamo/preprocessing/external/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from .integration import harmony_debatch, integrate
from .pearson_residual_recipe import (
normalize_layers_pearson_residuals,
select_genes_by_pearson_residuals,
)
from .sctransform import sctransform

# Public API of `dynamo.preprocessing.external`: the names re-exported from the sibling modules imported above.
__all__ = [
"normalize_layers_pearson_residuals",
"sctransform",
"select_genes_by_pearson_residuals",
"harmony_debatch",
"integrate",
]
128 changes: 128 additions & 0 deletions dynamo/preprocessing/external/integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from typing import List, Optional, Union

import numpy as np
from anndata import AnnData
from scipy.sparse import csr_matrix, isspmatrix

def to_dense_matrix(X):
    """Convert sparse matrix to dense matrix.

    Args:
        X: a scipy sparse matrix or any array-like object.

    Returns:
        A dense ``numpy.ndarray``: the densified copy of a sparse matrix, otherwise ``np.asarray(X)``.
    """
    return np.array(X.todense()) if isspmatrix(X) else np.asarray(X)

def integrate(
    adatas: List[AnnData],
    batch_key: str = "slices",
    fill_value: Union[int, float] = 0,
) -> AnnData:
    """Concatenating all anndata objects.

    ``X``, ``obs`` and ``var`` are merged with an outer join through ``AnnData.concatenate``. The ``obsm``, ``varm``
    and ``uns`` slots are merged separately here and attached to the result afterwards:

    * ``obsm`` / ``varm`` entries present in every input are densified and stacked along axis 0; entries missing
      from any input are skipped.
    * every ``uns`` key becomes a dict mapping batch label -> that batch's value (``None`` when the batch lacks the
      key), except ``"__type"``, which is copied from the first input that defines it.

    The input objects are left unchanged: their ``obsm``/``varm``/``uns`` are cleared only for the duration of the
    concatenation and restored afterwards.

    Args:
        adatas: AnnData matrices to concatenate with.
        batch_key: the key to add the batch annotation to :attr:`obs`. The first value of ``adata.obs[batch_key]`` of
            each input is used as that input's batch label.
        fill_value: Scalar value to fill newly missing values in arrays with.

    Returns:
        The concatenated AnnData, where adata.obs[batch_key] stores a categorical variable labeling the batch.

    Raises:
        ValueError: if `adatas` is empty.
        KeyError: if `batch_key` is missing from the ``obs`` of any input.
    """
    if not adatas:
        raise ValueError("`adatas` must contain at least one AnnData object.")

    def _shared_keys(mappings):
        """Return the keys present in every mapping, in the order of the first mapping."""
        first, *others = mappings
        return [key for key in first.keys() if all(key in other for other in others)]

    # `.iloc[0]` takes the first row by position; plain `[0]` is a label lookup on the obs index.
    batch_ca = [adata.obs[batch_key].iloc[0] for adata in adatas]

    # Merge the obsm, varm and uns data of all anndata objects separately.
    obsm_dict = {
        key: np.concatenate([to_dense_matrix(adata.obsm[key]) for adata in adatas], axis=0)
        for key in _shared_keys([adata.obsm for adata in adatas])
    }
    # NOTE(review): varm rows are stacked across inputs, which can only match the outer-joined `var` axis when the
    # inputs do not share genes - TODO confirm this is the intended use.
    varm_dict = {
        key: np.concatenate([to_dense_matrix(adata.varm[key]) for adata in adatas], axis=0)
        for key in _shared_keys([adata.varm for adata in adatas])
    }

    uns_dict = {}
    for key in dict.fromkeys(key for adata in adatas for key in adata.uns):
        if key == "__type":
            uns_dict[key] = next(adata.uns[key] for adata in adatas if key in adata.uns)
        else:
            uns_dict[key] = {
                ca: adata.uns[key] if key in adata.uns else None for ca, adata in zip(batch_ca, adatas)
            }

    # Temporarily clear obsm, varm and uns so that the concatenation only handles X, obs and var, then put them
    # back so the caller's objects are not modified.
    stashed = [(dict(adata.obsm), dict(adata.varm), dict(adata.uns)) for adata in adatas]
    try:
        for adata in adatas:
            del adata.obsm, adata.varm, adata.uns

        integrated_adata = adatas[0].concatenate(
            *adatas[1:],
            batch_key=batch_key,
            batch_categories=batch_ca,
            join="outer",
            fill_value=fill_value,
            uns_merge=None,
        )
    finally:
        for adata, (obsm, varm, uns) in zip(adatas, stashed):
            adata.obsm, adata.varm, adata.uns = obsm, varm, uns

    # Add concatenated obsm, varm and uns data to the integrated anndata object.
    for key, value in obsm_dict.items():
        integrated_adata.obsm[key] = value
    for key, value in varm_dict.items():
        integrated_adata.varm[key] = value
    for key, value in uns_dict.items():
        integrated_adata.uns[key] = value

    return integrated_adata

def harmony_debatch(
    adata: AnnData,
    key: str,
    basis: str = "X_pca",
    adjusted_basis: str = "X_pca_harmony",
    max_iter_harmony: int = 10,
    copy: bool = False,
) -> Optional[AnnData]:
    """Use harmonypy [Korunsky19]_ to remove batch effects.

    This function should be run after performing PCA but before computing the neighbor graph. Original Code Repository
    is https://github.com/slowkow/harmonypy. Interesting example: https://slowkow.com/notes/harmony-animation/

    Args:
        adata: An Anndata object.
        key: The name of the column in ``adata.obs`` that differentiates among experiments/batches.
        basis: The name of the field in ``adata.obsm`` where the PCA table is stored.
        adjusted_basis: The name of the field in ``adata.obsm`` where the adjusted PCA table will be stored after
            running this function.
        max_iter_harmony: Maximum number of rounds to run Harmony. One round of Harmony involves one clustering and one
            correction step.
        copy: Whether to copy `adata` or modify it inplace.

    Returns:
        If `copy` is True, a copy of `adata` with the field ``adata.obsm[adjusted_basis]`` containing the principal
        components adjusted by Harmony. Otherwise `adata` is updated in place and None is returned.

    Raises:
        KeyError: if `key` is not a column of ``adata.obs`` or `basis` is not a field of ``adata.obsm``.
        ImportError: if harmonypy is not installed.
    """
    # Validate the inputs first so that a wrong key is reported clearly, before the optional import and before
    # any copy of `adata` is made.
    if key not in adata.obs.columns:
        raise KeyError("`{}` is not a column of adata.obs.".format(key))
    if basis not in adata.obsm:
        raise KeyError("`{}` is not a field of adata.obsm.".format(basis))

    try:
        import harmonypy
    except ImportError as err:
        raise ImportError("\nplease install harmonypy:\n\n\tpip install harmonypy") from err

    adata = adata.copy() if copy else adata

    # Convert sparse matrix to dense matrix.
    basis_matrix = adata.obsm[basis]
    matrix = to_dense_matrix(basis_matrix)

    # Use Harmony to adjust the PCs.
    harmony_out = harmonypy.run_harmony(matrix, adata.obs, key, max_iter_harmony=max_iter_harmony)
    adjusted_matrix = harmony_out.Z_corr.T

    # Store the result in the same (sparse or dense) representation as the input basis.
    if isspmatrix(basis_matrix):
        adjusted_matrix = csr_matrix(adjusted_matrix)

    adata.obsm[adjusted_basis] = adjusted_matrix

    return adata if copy else None
Loading