Skip to content

Commit

Permalink
Merge pull request #469 from aristoteleo/optimize_PCA
Browse files Browse the repository at this point in the history
Optimize pca
  • Loading branch information
Xiaojieqiu authored Apr 5, 2023
2 parents ba43701 + 7801822 commit 9c38bc2
Show file tree
Hide file tree
Showing 14 changed files with 413 additions and 149 deletions.
2 changes: 1 addition & 1 deletion dynamo/external/pearson_residual_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
is_nonnegative_integer_arr,
seurat_get_mean_var,
)
from ..preprocessing.utils import pca_monocle
from ..preprocessing.utils import pca

main_logger = LoggerManager.main_logger

Expand Down
243 changes: 159 additions & 84 deletions dynamo/preprocessing/Preprocessor.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions dynamo/preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
cook_dist,
decode,
filter_genes_by_pattern,
pca_monocle,
pca,
relative2abs,
scale,
top_pca_genes,
Expand Down Expand Up @@ -65,7 +65,7 @@
"cell_cycle_scores",
"basic_stats",
"cook_dist",
"pca_monocle",
"pca",
"top_pca_genes",
"relative2abs",
"scale",
Expand Down
6 changes: 3 additions & 3 deletions dynamo/preprocessing/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
get_sz_exprs,
merge_adata_attrs,
normalize_mat_monocle,
pca_monocle,
pca,
sz_util,
unique_var_obs_adata,
)
Expand Down Expand Up @@ -1285,7 +1285,7 @@ def recipe_monocle(
logger.info("applying %s ..." % (method.upper()))

if method == "pca":
adata = pca_monocle(adata, pca_input, num_dim, "X_" + method.lower())
adata = pca(adata, pca_input, num_dim, "X_" + method.lower())
# TODO remove adata.obsm["X"] in future, use adata.obsm.X_pca instead
adata.obsm["X"] = adata.obsm["X_" + method.lower()]

Expand Down Expand Up @@ -1438,7 +1438,7 @@ def recipe_velocyto(
CM = CM[:, valid_ind]

if method == "pca":
adata, fit, _ = pca_monocle(adata, CM, num_dim, "X_" + method.lower(), return_all=True)
adata, fit, _ = pca(adata, CM, num_dim, "X_" + method.lower(), return_all=True)
# adata.obsm['X_' + method.lower()] = reduce_dim

elif method == "ica":
Expand Down
8 changes: 4 additions & 4 deletions dynamo/preprocessing/preprocessor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
get_sz_exprs,
merge_adata_attrs,
normalize_mat_monocle,
pca_monocle,
pca,
sz_util,
unique_var_obs_adata,
)
Expand All @@ -65,7 +65,7 @@ def is_log1p_transformed_adata(adata: anndata.AnnData) -> bool:
A flag shows whether the adata object is log transformed.
"""

chosen_gene_indices = np.random.choice(adata.n_obs, 10)
chosen_gene_indices = np.random.choice(adata.n_vars, 10)
_has_log1p_transformed = not np.allclose(
np.array(adata.X[:, chosen_gene_indices].sum(1)),
np.array(adata.layers["spliced"][:, chosen_gene_indices].sum(1)),
Expand Down Expand Up @@ -1528,7 +1528,7 @@ def is_nonnegative_integer_arr(mat: Union[np.ndarray, spmatrix, list]) -> bool:
def pca_selected_genes_wrapper(
adata: AnnData, pca_input: Union[np.ndarray, None] = None, n_pca_components: int = 30, key: str = "X_pca"
):
"""A wrapper for pca_monocle function to reduce dimensions of the Adata with PCA.
"""A wrapper for pca function to reduce dimensions of the Adata with PCA.
Args:
adata: an AnnData object.
Expand All @@ -1537,4 +1537,4 @@ def pca_selected_genes_wrapper(
key: the key to store the calculation result. Defaults to "X_pca".
"""

adata = pca_monocle(adata, pca_input, n_pca_components=n_pca_components, pca_key=key)
adata = pca(adata, pca_input, n_pca_components=n_pca_components, pca_key=key)
236 changes: 204 additions & 32 deletions dynamo/preprocessing/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Callable, Iterable, List, Tuple, Union
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

try:
from typing import Literal
Expand All @@ -14,8 +14,12 @@
import scipy.sparse
import statsmodels.api as sm
from anndata import AnnData
from scipy.sparse import csr_matrix, issparse
from scipy.sparse.linalg import LinearOperator, svds
from scipy.sparse import csc_matrix, csr_matrix, issparse
from sklearn.decomposition import PCA, TruncatedSVD
from sklearn.utils import check_random_state
from sklearn.utils.extmath import svd_flip
from sklearn.utils.sparsefuncs import mean_variance_axis

from ..configuration import DKM, DynamoAdataKeyManager
from ..dynamo_logger import (
Expand Down Expand Up @@ -774,38 +778,186 @@ def decode(adata: anndata.AnnData) -> None:
# ---------------------------------------------------------------------------------------------------
# pca

def _truncatedSVD_with_center(
X: Union[csc_matrix, csr_matrix],
n_components: int = 30,
random_state: int = 0,
) -> Dict:
"""Center a sparse matrix and perform truncated SVD on it.
def pca_monocle(
It uses `scipy.sparse.linalg.LinearOperator` to express the centered sparse
input by given matrix-vector and matrix-matrix products. Then truncated
singular value decomposition (SVD) can be solved without calculating the
individual entries of the centered matrix. The right singular vectors after
decomposition represent the principal components. This function is inspired
by the implementation of scanpy (https://github.com/scverse/scanpy).
Args:
X: The input sparse matrix to perform truncated SVD on.
n_components: The number of components to keep. Default is 30.
random_state: Seed for the random number generator. Default is 0.
Returns:
The transformed input matrix and a sklearn PCA object containing the
right singular vectors and amount of variance explained by each
principal component.
"""
random_state = check_random_state(random_state)
np.random.set_state(random_state.get_state())
v0 = random_state.uniform(-1, 1, np.min(X.shape))
n_components = min(n_components, X.shape[1] - 1)

mean = X.mean(0)
X_H = X.T.conj()
mean_H = mean.T.conj()
ones = np.ones(X.shape[0])[None, :].dot

# Following callables implements different type of matrix calculation.
def matvec(x):
"""Matrix-vector multiplication. Performs the operation X_centered*x
where x is a column vector or an 1-D array."""
return X.dot(x) - mean.dot(x)

def matmat(x):
"""Matrix-matrix multiplication. Performs the operation X_centered*x
where x is a matrix or ndarray."""
return X.dot(x) - mean.dot(x)

def rmatvec(x):
"""Adjoint matrix-vector multiplication. Performs the operation
X_centered^H * x where x is a column vector or an 1-d array."""
return X_H.dot(x) - mean_H.dot(ones(x))

def rmatmat(x):
"""Adjoint matrix-matrix multiplication. Performs the operation
X_centered^H * x where x is a matrix or ndarray."""
return X_H.dot(x) - mean_H.dot(ones(x))

# Construct the LinearOperator with callables above.
X_centered = LinearOperator(
shape=X.shape,
matvec=matvec,
matmat=matmat,
rmatvec=rmatvec,
rmatmat=rmatmat,
dtype=X.dtype,
)

# Solve SVD without calculating individuals entries in LinearOperator.
U, Sigma, VT = svds(X_centered, solver='arpack', k=n_components, v0=v0)
Sigma = Sigma[::-1]
U, VT = svd_flip(U[:, ::-1], VT[::-1])
X_transformed = U * Sigma
components_ = VT
exp_var = np.var(X_transformed, axis=0)
_, full_var = mean_variance_axis(X, axis=0)
full_var = full_var.sum()

result_dict = {
"X_pca": X_transformed,
"components_": components_,
"explained_variance_ratio_": exp_var / full_var,
}

fit = PCA(
n_components=n_components,
random_state=random_state,
)
X_pca = result_dict["X_pca"]
fit.components_ = result_dict["components_"]
fit.explained_variance_ratio_ = result_dict[
"explained_variance_ratio_"]

return fit, X_pca

def _pca_fit(
X: np.ndarray,
pca_func: Callable,
n_components: int = 30,
**kwargs,
) -> Tuple:
"""Apply PCA to the input data array X using the specified PCA function.
Args:
X: the input data array of shape (n_samples, n_features).
pca_func: the PCA function to use, which should have a 'fit' and
'transform' method, such as the PCA class or the IncrementalPCA
class from sklearn.decomposition.
n_components: the number of principal components to compute. If
n_components is greater than or equal to the number of features in
X, it will be set to n_features - 1 to avoid overfitting.
**kwargs: any additional keyword arguments that will be passed to the
PCA function.
Returns:
A tuple containing two elements:
- The fitted PCA object, which has a 'fit' and 'transform' method.
- The transformed array X_pca of shape (n_samples, n_components).
"""
fit = pca_func(
n_components=min(n_components, X.shape[1] - 1),
**kwargs,
).fit(X)
X_pca = fit.transform(X)
return fit, X_pca


def pca(
adata: AnnData,
X_data: np.ndarray = None,
n_pca_components: int = 30,
pca_key: str = "X",
pcs_key: str = "PCs",
genes_to_append: Union[List[str], None] = None,
layer: Union[List[str], str, None] = None,
svd_solver: Literal["randomized", "arpack"] = "randomized",
random_state: int = 0,
use_truncated_SVD_threshold: int = 500000,
use_incremental_PCA: bool = False,
incremental_batch_size: Optional[int] = None,
return_all: bool = False,
) -> Union[AnnData, Tuple[AnnData, Union[PCA, TruncatedSVD], np.ndarray]]:
"""Perform PCA reduction for monocle recipe.
When large dataset is used (e.g. 1 million cells are used), Incremental PCA
is recommended to avoid the memory issue. When cell number is less than half
a million, by default PCA or _truncatedSVD_with_center (use sparse matrix
that doesn't explicitly perform centering) will be used. TruncatedSVD is the
fastest method. Unlike other methods which will center the data first, it
performs SVD decomposition on raw input. Only use this when dataset is too
large for other methods.
Args:
adata: an AnnData object.
X_data: the data to perform dimension reduction on. Defaults to None.
n_pca_components: number of PCA components reduced to. Defaults to 30.
pca_key: the key to store the reduced data. Defaults to "X".
pcs_key: the key to store the principle axes in feature space. Defaults to "PCs".
pcs_key: the key to store the principle axes in feature space. Defaults
to "PCs".
genes_to_append: a list of genes should be inspected. Defaults to None.
layer: the layer(s) to perform dimension reduction on. Would be overrided by X_data. Defaults to None.
return_all: whether to return the PCA fit model and the reduced array together with the updated AnnData object.
Defaults to False.
layer: the layer(s) to perform dimension reduction on. Would be
overrided by X_data. Defaults to None.
svd_solver: the svd_solver to solve svd decomposition in PCA.
random_state: the seed used to initialize the random state for PCA.
use_truncated_SVD_threshold: the threshold of observations to use
truncated SVD instead of standard PCA for efficiency.
use_incremental_PCA: whether to use Incremental PCA. Recommend enabling
incremental PCA when dataset is too large to fit in memory.
incremental_batch_size: The number of samples to use for each batch when
performing incremental PCA. If batch_size is None, then batch_size
is inferred from the data and set to 5 * n_features.
return_all: whether to return the PCA fit model and the reduced array
together with the updated AnnData object. Defaults to False.
Raises:
ValueError: layer provided is not invalid.
ValueError: list of genes to append is invalid.
Returns:
The the updated AnnData object with reduced data if `return_all` is False. Otherwise, a tuple (adata, fit,
X_pca), where adata is the updated AnnData object, fit is the fit model for dimension reduction, and X_pca is
the reduced array, will be returned.
The updated AnnData object with reduced data if `return_all` is False.
Otherwise, a tuple (adata, fit, X_pca), where adata is the updated
AnnData object, fit is the fit model for dimension reduction, and X_pca
is the reduced array, will be returned.
"""

# only use genes pass filter (based on use_for_pca) to perform dimension reduction.
Expand Down Expand Up @@ -847,33 +999,53 @@ def pca_monocle(
adata.var.iloc[bad_genes, adata.var.columns.tolist().index("use_for_pca")] = False
X_data = X_data[:, valid_ind]

USE_TRUNCATED_SVD_THRESHOLD = 100000
if adata.n_obs < USE_TRUNCATED_SVD_THRESHOLD:
pca = PCA(
n_components=min(n_pca_components, X_data.shape[1] - 1),
svd_solver="arpack",
random_state=0,
if use_incremental_PCA:
from sklearn.decomposition import IncrementalPCA
fit, X_pca = _pca_fit(
X_data,
pca_func=IncrementalPCA,
n_components=n_pca_components,
batch_size=incremental_batch_size,
)
fit = pca.fit(X_data.toarray()) if issparse(X_data) else pca.fit(X_data)
X_pca = fit.transform(X_data.toarray()) if issparse(X_data) else fit.transform(X_data)
adata.obsm[pca_key] = X_pca
adata.uns[pcs_key] = fit.components_.T
else:
if adata.n_obs < use_truncated_SVD_threshold:
if not issparse(X_data):
fit, X_pca = _pca_fit(
X_data,
pca_func=PCA,
n_components=n_pca_components,
svd_solver=svd_solver,
random_state=random_state,
)
else:
fit, X_pca = _truncatedSVD_with_center(
X_data,
n_components=n_pca_components,
random_state=random_state,
)
else:
# TruncatedSVD is the fastest method we have. It doesn't center the
# data. It only performs SVD decomposition, which is the second part
# in our _truncatedSVD_with_center function.
fit, X_pca = _pca_fit(
X_data,
pca_func=TruncatedSVD,
n_components=n_pca_components + 1,
random_state=random_state
)
# first columns is related to the total UMI (or library size)
X_pca = X_pca[:, 1:]

adata.uns["explained_variance_ratio_"] = fit.explained_variance_ratio_
adata.obsm[pca_key] = X_pca
if use_incremental_PCA or adata.n_obs < use_truncated_SVD_threshold:
adata.uns[pcs_key] = fit.components_.T
adata.uns[
"explained_variance_ratio_"] = fit.explained_variance_ratio_
else:
# unscaled PCA
fit = TruncatedSVD(
n_components=min(n_pca_components + 1, X_data.shape[1] - 1),
random_state=0,
)
# first columns is related to the total UMI (or library size)
X_pca = fit.fit_transform(X_data)[:, 1:]
adata.obsm[pca_key] = X_pca
adata.uns[pcs_key] = fit.components_.T[:, 1:]

adata.uns["explained_variance_ratio_"] = fit.explained_variance_ratio_[1:]

adata.uns["explained_variance_ratio_"] = fit.explained_variance_ratio_[1:]
adata.uns[
"explained_variance_ratio_"] = fit.explained_variance_ratio_[1:]
adata.uns["pca_mean"] = fit.mean_ if hasattr(fit, "mean_") else None

if return_all:
Expand Down
4 changes: 2 additions & 2 deletions dynamo/tools/cell_velocities.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,9 +541,9 @@ def cell_velocities(

if "pca_fit" not in adata.uns_keys() or type(adata.uns["pca_fit"]) == str:
CM = adata.X[:, adata.var.use_for_dynamics.values]
from ..preprocessing.utils import pca_monocle
from ..preprocessing.utils import pca

adata, pca_fit, X_pca = pca_monocle(adata, CM, n_pca_components, "X", return_all=True)
adata, pca_fit, X_pca = pca(adata, CM, n_pca_components, "X", return_all=True)
adata.uns["pca_fit"] = pca_fit

X_pca, pca_fit = adata.obsm["X"], adata.uns["pca_fit"]
Expand Down
Loading

0 comments on commit 9c38bc2

Please sign in to comment.