Skip to content

Commit

Permalink
Merge pull request #474 from Ukyeon/fano
Browse files Browse the repository at this point in the history
Fano
  • Loading branch information
Xiaojieqiu authored May 4, 2023
2 parents 5f0aaf2 + b37ef81 commit 9a079f5
Show file tree
Hide file tree
Showing 18 changed files with 1,613 additions and 1,320 deletions.
4 changes: 2 additions & 2 deletions dynamo/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from cycler import cycler
from matplotlib import cm, colors, rcParams

from .dynamo_logger import main_info, main_warning
from .dynamo_logger import main_debug, main_info


class DynamoAdataKeyManager:
Expand Down Expand Up @@ -847,5 +847,5 @@ def set_pub_style_mpltex():

# initialize DynamoSaveConfig and DynamoVisConfig mode defaults
DynamoAdataConfig.update_data_store_mode("full")
main_info("setting visualization default mode in dynamo. Your customized matplotlib settings might be overritten.")
main_debug("setting visualization default mode in dynamo. Your customized matplotlib settings might be overwritten.")
DynamoVisConfig.set_default_mode()
11 changes: 5 additions & 6 deletions dynamo/dynamo_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ def error(self, message, indent_level=1, *args, **kwargs):

def info_insert_adata(self, key, adata_attr="obsm", indent_level=1, *args, **kwargs):
message = "<insert> %s to %s in AnnData Object." % (key, adata_attr)
message = format_logging_message(message, logging.INFO, indent_level=indent_level)
return self.logger.error(message, *args, **kwargs)
message = format_logging_message(message, logging.DEBUG, indent_level=indent_level)
return self.logger.debug(message, *args, **kwargs)

def info_insert_adata_var(self, key, indent_level=1, *args, **kwargs):
    """Log that `key` is being inserted into the `var` attribute of an AnnData object.

    Thin wrapper around `self.info_insert_adata` with `adata_attr` fixed to "var".

    Args:
        key: name of the entry being inserted.
        indent_level: indentation level forwarded to `info_insert_adata`.
        *args: extra positional arguments forwarded to `info_insert_adata`.
        **kwargs: extra keyword arguments forwarded to `info_insert_adata`.

    Returns:
        Whatever `self.info_insert_adata` returns.
    """
    # `self.info_insert_adata` is already bound, so `self` must not be passed again:
    # doing so shifts `key` into the `adata_attr` slot and raises
    # "TypeError: got multiple values for argument 'adata_attr'".
    # Arguments are passed positionally so that any extra *args still land after
    # `indent_level`, and the caller's `indent_level` is honored instead of a fixed 1.
    return self.info_insert_adata(key, "var", indent_level, *args, **kwargs)
Expand Down Expand Up @@ -189,18 +189,17 @@ def report_progress(self, percent=None, count=None, total=None, progress_name=""

def finish_progress(self, progress_name="", time_unit="s", indent_level=1):
self.log_time()
self.report_progress(percent=100, progress_name=progress_name)
# self.report_progress(percent=100, progress_name=progress_name)

saved_terminator = self.logger_stream_handler.terminator
self.logger_stream_handler.terminator = ""
self.logger.info("\n")
self.logger_stream_handler.flush()
self.logger_stream_handler.terminator = saved_terminator

if time_unit == "s":
self.info("[%s] finished [%.4fs]" % (progress_name, self.time_passed), indent_level=indent_level)
self.info("[%s] completed [%.4fs]" % (progress_name, self.time_passed), indent_level=indent_level)
elif time_unit == "ms":
self.info("[%s] finished [%.4fms]" % (progress_name, self.time_passed * 1e3), indent_level=indent_level)
self.info("[%s] completed [%.4fms]" % (progress_name, self.time_passed * 1e3), indent_level=indent_level)
else:
raise NotImplementedError
# self.logger.info("|")
Expand Down
25 changes: 16 additions & 9 deletions dynamo/external/sctransform.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
# =================================================================

import os
from multiprocessing import Manager, Pool

import numpy as np
import pandas as pd
Expand All @@ -22,6 +21,7 @@

from ..configuration import DKM
from ..dynamo_logger import main_info, main_info_insert_adata_layer
from ..preprocessing.utils import get_gene_selection_filter

_EPS = np.finfo(float).eps

Expand Down Expand Up @@ -128,6 +128,8 @@ def sctransform_core(
"""
A re-implementation of SCTransform from the Satija lab.
"""
import multiprocessing

main_info("sctransform adata on layer: %s" % (layer))
X = DKM.select_layer_data(adata, layer).copy()
X = sp.sparse.csr_matrix(X)
Expand All @@ -139,10 +141,8 @@ def sctransform_core(
genes_ix = genes.copy()

X = X[:, genes]
Xraw = X.copy()
gene_names = gene_names[genes]
genes = np.arange(X.shape[1])
genes_cell_count = X.sum(0).A.flatten()

genes_log_gmean = np.log10(gmean(X, axis=0, eps=gmean_eps))

Expand Down Expand Up @@ -188,7 +188,10 @@ def sctransform_core(
bin_ind = np.ceil(np.arange(1, genes_step1.size + 1) / bin_size)
max_bin = max(bin_ind)

ps = Manager().dict()
ps = multiprocessing.Manager().dict()

# Use the "fork" start method, in which each worker process starts as a copy of the current Python process.
ctx = multiprocessing.get_context("fork")

for i in range(1, int(max_bin) + 1):
genes_bin_regress = genes_step1[bin_ind == i]
Expand All @@ -197,7 +200,9 @@ def sctransform_core(
mm = np.vstack((np.ones(data_step1.shape[0]), data_step1["log_umi"].values.flatten())).T

pc_chunksize = umi_bin.shape[1] // os.cpu_count() + 1
pool = Pool(os.cpu_count(), _parallel_init, [genes_bin_regress, umi_bin, gene_names, mm, ps])

pool = ctx.Pool(os.cpu_count(), _parallel_init, [genes_bin_regress, umi_bin, gene_names, mm, ps])

try:
pool.map(_parallel_wrapper, range(umi_bin.shape[1]), chunksize=pc_chunksize)
finally:
Expand Down Expand Up @@ -254,10 +259,6 @@ def sctransform_core(
full_model_pars["theta"] = theta
del full_model_pars["dispersion"]

model_pars_outliers = outliers

regressor_data = np.vstack((np.ones(cell_attrs.shape[0]), cell_attrs["log_umi"].values)).T

d = X.data
x, y = X.nonzero()
mud = np.exp(full_model_pars.values[:, 0][y] + full_model_pars.values[:, 1][y] * cell_attrs["log_umi"].values[x])
Expand Down Expand Up @@ -331,3 +332,9 @@ def sctransform(adata: AnnData, layers: str = [DKM.X_LAYER], output_layer: str =
"""a wrapper calls sctransform_core and set dynamo style keys in adata"""
for layer in layers:
sctransform_core(adata, layer=layer, n_genes=n_top_genes, **kwargs)
if adata.X.shape[1] > n_top_genes:
X_squared = adata.X.copy()
X_squared.data **= 2
variance = X_squared.mean(0) - np.square(adata.X.mean(0))
adata.var["sct_score"] = variance.A1
adata.var["use_for_pca"] = get_gene_selection_filter(adata.var["sct_score"], n_top_genes=n_top_genes)
61 changes: 16 additions & 45 deletions dynamo/plot/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ..configuration import DynamoAdataKeyManager
from ..dynamo_logger import main_warning
from ..preprocessing import preprocess as pp
from ..preprocessing.preprocess_monocle_utils import top_table
from ..preprocessing.gene_selection import get_prediction_by_svr
from ..preprocessing.utils import detect_experiment_datatype
from ..tools.utils import get_mapper, update_dict
from .utils import save_fig
Expand Down Expand Up @@ -649,47 +649,36 @@ def feature_genes(
save_show_or_return: str = "show",
save_kwargs: dict = {},
):
"""Plot selected feature genes on top of the mean vs. dispersion scatterplot.
"""Plot selected feature genes on top of the mean vs. dispersion scatter plot.
Parameters
----------
Args:
adata: :class:`~anndata.AnnData`
AnnData object
layer: `str` (default: `X`)
The data from a particular layer (include X) used for making the feature gene plot.
mode: None or `str` (default: `None`)
The method to select the feature genes (can be either `dispersion`, `gini` or `SVR`).
The method to select the feature genes (can be either `cv_dispersion`, `fano_dispersion` or `gini`).
figsize: `string` (default: (4, 3))
Figure size of each facet.
save_show_or_return: {'show', 'save', 'return'} (default: `show`)
Whether to save, show or return the figure.
save_kwargs: `dict` (default: `{}`)
A dictionary that will passed to the save_fig function. By default it is an empty dictionary and the
A dictionary that will be passed to the save_fig function. By default, it is an empty dictionary and the
save_fig function will use the {"path": None, "prefix": 'feature_genes', "dpi": None, "ext": 'pdf',
"transparent": True, "close": True, "verbose": True} as its parameters. Otherwise you can provide a
"transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a
dictionary that properly modifies those keys according to your needs.
Returns
-------
Returns:
Nothing but plots the selected feature genes via the mean, CV plot.
"""

import matplotlib.pyplot as plt

mode = adata.uns["feature_selection"] if mode is None else mode

layer = DynamoAdataKeyManager.get_available_layer_keys(adata, layer, include_protein=False)[0]

uns_store_key = None
if mode == "dispersion":
uns_store_key = "dispFitInfo" if layer in ["raw", "X"] else layer + "_dispFitInfo"

table = top_table(adata, layer)
x_min, x_max = (
np.nanmin(table["mean_expression"]),
np.nanmax(table["mean_expression"]),
)
elif mode == "SVR":
if "_dispersion" in mode: # "cv_dispersion", "fano_dispersion"
prefix = "" if layer == "X" else layer + "_"
uns_store_key = "velocyto_SVR" if layer == "raw" or layer == "X" else layer + "_velocyto_SVR"

Expand All @@ -709,11 +698,12 @@ def feature_genes(
ordering_genes = adata.var["use_for_pca"] if "use_for_pca" in adata.var.columns else None

mu_linspace = np.linspace(x_min, x_max, num=1000)
fit = (
adata.uns[uns_store_key]["disp_func"](mu_linspace)
if mode == "dispersion"
else adata.uns[uns_store_key]["SVR"](mu_linspace.reshape(-1, 1))
)
if "_dispersion" in mode:
mean = adata.uns[uns_store_key]["mean"]
cv = adata.uns[uns_store_key]["cv"]
svr_gamma = adata.uns[uns_store_key]["svr_gamma"]
fit, _ = get_prediction_by_svr(mean, cv, svr_gamma)
fit = fit(mu_linspace.reshape(-1, 1))

plt.figure(figsize=figsize)
plt.plot(mu_linspace, fit, alpha=0.4, color="r")
Expand All @@ -724,15 +714,7 @@ def feature_genes(
)

valid_disp_table = table.iloc[valid_ind, :]
if mode == "dispersion":
ax = plt.scatter(
valid_disp_table["mean_expression"],
valid_disp_table["dispersion_empirical"],
s=3,
alpha=1,
color="xkcd:red",
)
elif mode == "SVR":
if "_dispersion" in mode:
ax = plt.scatter(
valid_disp_table[prefix + "log_m"],
valid_disp_table[prefix + "log_cv"],
Expand All @@ -743,15 +725,7 @@ def feature_genes(

neg_disp_table = table.iloc[~valid_ind, :]

if mode == "dispersion":
ax = plt.scatter(
neg_disp_table["mean_expression"],
neg_disp_table["dispersion_empirical"],
s=3,
alpha=0.5,
color="xkcd:grey",
)
elif mode == "SVR":
if "_dispersion" in mode:
ax = plt.scatter(
neg_disp_table[prefix + "log_m"],
neg_disp_table[prefix + "log_cv"],
Expand All @@ -760,9 +734,6 @@ def feature_genes(
color="xkcd:grey",
)

# plt.xlim((0, 100))
if mode == "dispersion":
plt.xscale("log")
plt.yscale("log")
plt.xlabel("Mean (log)")
plt.ylabel("Dispersion (log)") if mode == "dispersion" else plt.ylabel("CV (log)")
Expand Down
Loading

0 comments on commit 9a079f5

Please sign in to comment.