Skip to content

Commit

Permalink
Merge pull request #598 from Sichao25/new
Browse files Browse the repository at this point in the history
Create chunk option for normalization and gene selection
  • Loading branch information
Xiaojieqiu authored Feb 20, 2024
2 parents d48664e + d2fed88 commit c12f6dd
Show file tree
Hide file tree
Showing 3 changed files with 207 additions and 48 deletions.
64 changes: 63 additions & 1 deletion dynamo/configuration.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import List, Optional, Tuple, Union
from typing import List, Generator, Optional, Tuple, Union

import colorcet
import matplotlib
Expand Down Expand Up @@ -37,6 +37,35 @@ class DynamoAdataKeyManager:
X_LAYER = "X"
PROTEIN_LAYER = "protein"
X_PCA = "X_pca"
RAW = "raw"

def _select_layer_cell_chunked_data(
    mat: np.ndarray,
    chunk_size: int,
) -> Generator:
    """Yield consecutive row (cell) chunks of a matrix.

    Args:
        mat: a 2D matrix supporting `mat[start:end, :]` slicing; it is split along its first axis.
        chunk_size: the maximum number of rows per chunk. Must be a positive integer.

    Yields:
        Tuples `(chunk, start, end)` where `chunk` is `mat[start:end, :]`. Chunks are produced in order and the
        last one holds the remaining rows when the row count is not a multiple of `chunk_size`. Nothing is
        yielded for a matrix with zero rows.

    Raises:
        ValueError: if `chunk_size` is not positive. As this is a generator, the error surfaces when iteration
            starts rather than when the function is called.
    """
    # A zero chunk size used to fail with an opaque ZeroDivisionError and a negative one silently returned the
    # whole matrix as a single chunk; reject both explicitly.
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer, got %s." % (chunk_size,))

    n = mat.shape[0]
    for start in range(0, n, chunk_size):
        # Clamp the upper bound so the reported `end` of the last chunk never exceeds the matrix size.
        end = min(start + chunk_size, n)
        yield (mat[start:end, :], start, end)

def _select_layer_gene_chunked_data(
    mat: np.ndarray,
    chunk_size: int,
) -> Generator:
    """Yield consecutive column (gene) chunks of a matrix.

    Args:
        mat: a 2D matrix supporting `mat[:, start:end]` slicing; it is split along its second axis.
        chunk_size: the maximum number of columns per chunk. Must be a positive integer.

    Yields:
        Tuples `(chunk, start, end)` where `chunk` is `mat[:, start:end]`. Chunks are produced in order and the
        last one holds the remaining columns when the column count is not a multiple of `chunk_size`. Nothing
        is yielded for a matrix with zero columns.

    Raises:
        ValueError: if `chunk_size` is not positive. As this is a generator, the error surfaces when iteration
            starts rather than when the function is called.
    """
    # A zero chunk size used to fail with an opaque ZeroDivisionError and a negative one silently returned the
    # whole matrix as a single chunk; reject both explicitly.
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer, got %s." % (chunk_size,))

    n = mat.shape[1]
    for start in range(0, n, chunk_size):
        # Clamp the upper bound so the reported `end` of the last chunk never exceeds the matrix size.
        end = min(start + chunk_size, n)
        yield (mat[:, start:end], start, end)

def gen_new_layer_key(layer_name, key, sep="_") -> str:
"""utility function for returning a new key name for a specific layer. By convention layer_name should not have the separator as the last character."""
Expand Down Expand Up @@ -83,6 +112,39 @@ def select_layer_data(adata: AnnData, layer: str, copy=False) -> pd.DataFrame:
return res_data.copy()
return res_data

def select_layer_chunked_data(
    adata: AnnData,
    layer: str,
    chunk_size: int,
    chunk_mode: str = "cell",
) -> Generator:
    """Return a generator over chunks of the requested layer of `adata`.

    Args:
        adata: the AnnData object to read from.
        layer: the layer to chunk. `None` falls back to `DynamoAdataKeyManager.X_LAYER`. The X, raw and protein
            keys are read from `adata.X`, `adata.raw.X` and `adata.obsm["protein"]` respectively; any other
            value is looked up in `adata.layers`.
        chunk_size: the chunk size forwarded to the chunking helper.
        chunk_mode: "cell" to chunk with `_select_layer_cell_chunked_data`, "gene" to chunk with
            `_select_layer_gene_chunked_data`. Defaults to "cell".

    Returns:
        The generator created by the selected chunking helper, or `None` when the protein layer is requested
        but "protein" is not among `adata.obsm_keys()`.

    Raises:
        NotImplementedError: if `chunk_mode` is neither "cell" nor "gene".
    """
    if layer is None:
        layer = DynamoAdataKeyManager.X_LAYER

    # Resolve the chunking direction first so an unsupported mode fails before any data is touched.
    if chunk_mode == "cell":
        chunker = DynamoAdataKeyManager._select_layer_cell_chunked_data
    elif chunk_mode == "gene":
        chunker = DynamoAdataKeyManager._select_layer_gene_chunked_data
    else:
        raise NotImplementedError("chunk_mode %s not implemented." % (chunk_mode))

    # Locate the matrix backing the requested layer.
    if layer == DynamoAdataKeyManager.X_LAYER:
        source_mat = adata.X
    elif layer == DynamoAdataKeyManager.RAW:
        source_mat = adata.raw.X
    elif layer == DynamoAdataKeyManager.PROTEIN_LAYER:
        # Protein data is optional; signal its absence with None instead of a generator.
        if "protein" not in adata.obsm_keys():
            return None
        source_mat = adata.obsm["protein"]
    else:
        source_mat = adata.layers[layer]

    return chunker(source_mat, chunk_size=chunk_size)

def set_layer_data(adata: AnnData, layer: str, vals: np.array, var_indices: np.array = None):
if var_indices is None:
var_indices = slice(None)
Expand Down
55 changes: 34 additions & 21 deletions dynamo/preprocessing/gene_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,7 @@ def select_genes_by_seurat_recipe(
nan_replace_val: Union[float, None] = None,
n_top_genes: int = 2000,
algorithm: Literal["seurat_dispersion", "fano_dispersion"] = "seurat_dispersion",
chunk_size: Optional[int] = None,
seurat_min_disp: Union[float, None] = None,
seurat_max_disp: Union[float, None] = None,
seurat_min_mean: Union[float, None] = None,
Expand All @@ -545,7 +546,8 @@ def select_genes_by_seurat_recipe(
layer: the key of a sparse matrix in adata. Defaults to DKM.X_LAYER.
nan_replace_val: your choice of value to replace values in layer. Defaults to None.
n_top_genes: number of genes to select as highly variable genes. Defaults to 2000.
algorithm: a method for selecting genes; must be one of "seurat_dispersion" or "fano".
algorithm: the method used for selecting genes; only "seurat_dispersion" is supported for now.
chunk_size: the size of chunked data. Defaults to None.
seurat_min_disp: seurat dispersion min cutoff. Defaults to None.
seurat_max_disp: seurat dispersion max cutoff. Defaults to None.
seurat_min_mean: seurat mean min cutoff. Defaults to None.
Expand All @@ -569,25 +571,45 @@ def select_genes_by_seurat_recipe(

if len(pass_filter_genes) != len(set(pass_filter_genes)):
main_warning("gene names are not unique, please check your preprocessing procedure.")
subset_adata = adata[:, pass_filter_genes]
if n_top_genes is None:
main_info("n_top_genes is None, reserve all genes and add filter gene information")
n_top_genes = adata.n_vars
layer_mat = DKM.select_layer_data(subset_adata, layer)
if nan_replace_val:
main_info("replacing nan values with: %s" % (nan_replace_val))
_mask = get_nan_or_inf_data_bool_mask(layer_mat)
layer_mat[_mask] = nan_replace_val

chunk_size = chunk_size if chunk_size is not None else adata.n_vars

if algorithm == "seurat_dispersion":
chunked_layer_mats = DKM.select_layer_chunked_data(
adata[:, pass_filter_genes],
layer,
chunk_size=chunk_size,
chunk_mode="gene",
)
mean = np.zeros(len(pass_filter_genes), dtype=adata.X.dtype)
variance = np.zeros(len(pass_filter_genes), dtype=adata.X.dtype)

for mat_data in chunked_layer_mats:
layer_mat = mat_data[0]

if nan_replace_val:
main_info("replacing nan values with: %s" % (nan_replace_val))
_mask = get_nan_or_inf_data_bool_mask(layer_mat)
layer_mat[_mask] = nan_replace_val

chunked_mean, chunked_var = seurat_get_mean_var(layer_mat)

mean[mat_data[1]:mat_data[2]] = chunked_mean
variance[mat_data[1]:mat_data[2]] = chunked_var

mean, variance, highly_variable_mask = select_genes_by_seurat_dispersion(
layer_mat,
mean=mean,
variance=variance,
min_disp=seurat_min_disp,
max_disp=seurat_max_disp,
min_mean=seurat_min_mean,
max_mean=seurat_max_mean,
n_top_genes=n_top_genes,
)

main_info_insert_adata_var(DKM.VAR_GENE_MEAN_KEY)
main_info_insert_adata_var(DKM.VAR_GENE_VAR_KEY)
main_info_insert_adata_var(DKM.VAR_GENE_HIGHLY_VARIABLE_KEY)
Expand All @@ -603,14 +625,6 @@ def select_genes_by_seurat_recipe(
adata.var[DKM.VAR_GENE_HIGHLY_VARIABLE_KEY][pass_filter_genes] = highly_variable_mask
adata.var[DKM.VAR_USE_FOR_PCA][pass_filter_genes] = highly_variable_mask

elif algorithm == "fano_dispersion":
select_genes_monocle(adata, layer=layer, sort_by=algorithm)
# adata = select_genes_by_svr(
# adata,
# layers=layer,
# algorithm=algorithm,
# )
# filter_bool = get_svr_filter(adata, layer=layer, n_top_genes=n_top_genes, return_adata=False)
else:
raise ValueError(f"The algorithm {algorithm} is not existed")

Expand All @@ -621,7 +635,8 @@ def select_genes_by_seurat_recipe(


def select_genes_by_seurat_dispersion(
sparse_layer_mat: csr_matrix,
mean: np.ndarray,
variance: np.ndarray,
n_bins: int = 20,
log_mean_and_dispersion: bool = True,
min_disp: float = None,
Expand All @@ -633,7 +648,8 @@ def select_genes_by_seurat_dispersion(
"""Apply seurat's gene selection recipe by cutoffs.
Args:
sparse_layer_mat: the sparse matrix used for gene selection.
mean: mean of the columns for each gene.
variance: variance of the columns for each gene.
n_bins: the number of bins for normalization. Defaults to 20.
log_mean_and_dispersion: whether log the gene expression values before calculating the dispersion values.
Defaults to True.
Expand Down Expand Up @@ -661,9 +677,6 @@ def select_genes_by_seurat_dispersion(
if max_mean is None:
max_mean = 3

# mean, variance, dispersion = calc_mean_var_dispersion_sparse(sparse_layer_mat) # Dead
sc_mean, sc_var = seurat_get_mean_var(sparse_layer_mat)
mean, variance = sc_mean, sc_var
dispersion = variance / mean

if log_mean_and_dispersion:
Expand Down
Loading

0 comments on commit c12f6dd

Please sign in to comment.