Skip to content

Commit

Permalink
Merge pull request #667 from Sichao25/root_folder
Browse files Browse the repository at this point in the history
Docstring and type hints for root folder python files
  • Loading branch information
Xiaojieqiu authored Feb 26, 2024
2 parents f5756a3 + d6d75d3 commit 6f52dc6
Show file tree
Hide file tree
Showing 6 changed files with 400 additions and 385 deletions.
186 changes: 89 additions & 97 deletions dynamo/configuration.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import List, Generator, Optional, Tuple, Union
from typing import Any, List, Generator, Optional, Tuple, Union

import colorcet
import matplotlib
Expand All @@ -14,6 +14,7 @@


class DynamoAdataKeyManager:
"""A class to manage the keys used in anndata object for dynamo."""
VAR_GENE_MEAN_KEY = "pp_gene_mean"
VAR_GENE_VAR_KEY = "pp_gene_variance"
VAR_GENE_HIGHLY_VARIABLE_KEY = "gene_highly_variable"
Expand Down Expand Up @@ -67,8 +68,9 @@ def _select_layer_gene_chunked_data(
if start < n:
yield (mat[:, start:n], start, n)

def gen_new_layer_key(layer_name, key, sep="_") -> str:
"""utility function for returning a new key name for a specific layer. By convention layer_name should not have the separator as the last character."""
def gen_new_layer_key(layer_name: str, key: str, sep: str = "_") -> str:
"""Utility function for returning a new key name for a specific layer. By convention layer_name should not have
the separator as the last character."""
if layer_name == "":
return key
if layer_name[-1] == sep:
Expand All @@ -79,14 +81,15 @@ def gen_layer_pp_key(*keys):
"""Generate dynamo style keys for adata.uns[pp][key0_key1_key2...]"""
return "_".join(keys)

def gen_layer_X_key(key):
"""Generate dynamo style keys for adata.layer[X_*], used later in dynamics"""
def gen_layer_X_key(key: str) -> str:
"""Generate dynamo style keys for adata.layer[X_*], used later in dynamics."""
return DynamoAdataKeyManager.gen_new_layer_key("X", key)

def is_layer_X_key(key):
def is_layer_X_key(key: str) -> bool:
"""Check if the key is a layer key for X layer."""
return key[:2] == "X_"

def gen_layer_pearson_residual_key(layer: str):
def gen_layer_pearson_residual_key(layer: str) -> str:
"""Generate dynamo style keys for adata.uns[pp][key0_key1_key2...]"""
return DynamoAdataKeyManager.gen_layer_pp_key(
layer, DynamoAdataKeyManager.UNS_PP_PEARSON_RESIDUAL_NORMALIZATION
Expand Down Expand Up @@ -145,7 +148,8 @@ def select_layer_chunked_data(
else:
raise NotImplementedError("chunk_mode %s not implemented." % (chunk_mode))

def set_layer_data(adata: AnnData, layer: str, vals: np.array, var_indices: np.array = None):
def set_layer_data(adata: AnnData, layer: str, vals: np.array, var_indices: np.array = None) -> None:
"""This utility provides a unified interface for setting data to layers."""
if var_indices is None:
var_indices = slice(None)
if layer == DynamoAdataKeyManager.X_LAYER:
Expand All @@ -158,6 +162,7 @@ def set_layer_data(adata: AnnData, layer: str, vals: np.array, var_indices: np.a
adata.layers[layer] = vals

def check_if_layer_exist(adata: AnnData, layer: str) -> bool:
"""Check if the layer exists in adata."""
if layer == DynamoAdataKeyManager.X_LAYER:
# assume always exist
return True
Expand All @@ -166,8 +171,11 @@ def check_if_layer_exist(adata: AnnData, layer: str) -> bool:

return layer in adata.layers

def get_available_layer_keys(adata, layers="all", remove_pp_layers=True, include_protein=True):
"""Get the list of available layers' keys. If `layers` is set to all, return a list of all available layers; if `layers` is set to a list, then the intersetion of available layers and `layers` will be returned."""
def get_available_layer_keys(
adata: AnnData, layers: str = "all", remove_pp_layers: bool = True, include_protein: bool = True,
) -> List[str]:
"""Get the list of available layers' keys. If `layers` is set to all, return a list of all available layers; if
`layers` is set to a list, then the intersetion of available layers and `layers` will be returned."""
layer_keys = list(adata.layers.keys())
if layers is None: # layers=adata.uns["pp"]["experiment_layers"], in calc_sz_factor
layers = "X"
Expand All @@ -182,28 +190,32 @@ def get_available_layer_keys(adata, layers="all", remove_pp_layers=True, include
res_layers = list(set(res_layers).difference(["matrix", "ambiguous", "spanning"]))
return res_layers

def allowed_layer_raw_names():
def allowed_layer_raw_names() -> Tuple[List[str], List[str], List[str]]:
"""Return a list of allowed layer names in raw data."""
only_splicing = ["spliced", "unspliced"]
only_labeling = ["new", "total"]
splicing_and_labeling = ["uu", "ul", "su", "sl"]
return only_splicing, only_labeling, splicing_and_labeling

def get_raw_data_layers(adata: AnnData) -> str:
"""Get the list of raw data layers names in adata."""
only_splicing, only_labeling, splicing_and_labeling = DKM.allowed_layer_raw_names()
# select layers in adata to be normalized
res = only_splicing + only_labeling + splicing_and_labeling
res = set(res).intersection(adata.layers.keys()).union("X")
res = list(res)
return res

def allowed_X_layer_names():
def allowed_X_layer_names() -> Tuple[List[str], List[str], List[str]]:
"""Return a list of allowed layer names in X layers data."""
only_splicing = ["X_spliced", "X_unspliced"]
only_labeling = ["X_new", "X_total"]
splicing_and_labeling = ["X_uu", "X_ul", "X_su", "X_sl"]

return only_splicing, only_labeling, splicing_and_labeling

def init_uns_pp_namespace(adata: AnnData):
def init_uns_pp_namespace(adata: AnnData) -> None:
"""Initialize the uns[pp] namespace in adata."""
adata.uns[DynamoAdataKeyManager.UNS_PP_KEY] = {}

def get_excluded_layers(X_total_layers: bool = False, splicing_total_layers: bool = False) -> List:
Expand Down Expand Up @@ -265,12 +277,14 @@ def aggregate_layers_into_total(


class DynamoVisConfig:
"""Dynamo visualization config class holding static variables to change behaviors of functions globally."""
def set_default_mode(background="white"):
"""Set the default mode for dynamo visualization."""
set_figure_params("dynamo", background=background)


class DynamoAdataConfig:
"""dynamo anndata object config class holding static variables to change behaviors of functions globally."""
"""Dynamo anndata object config class holding static variables to change behaviors of functions globally."""

# set the adata store mode.
# saving memory or storing more results
Expand Down Expand Up @@ -314,22 +328,17 @@ class DynamoAdataConfig:
# config_key_to_values contains _key to values for config values
config_key_to_values = None

def use_default_var_if_none(val, key, replace_val=None):
"""if `val` is equal to `replace_val`, then a config value will be returned according to `key` stored in dynamo configuration. Otherwise return the original `val` value.
def use_default_var_if_none(val: Any, key: str, replace_val: Optional[Any] = None) -> Any:
"""If `val` is equal to `replace_val`, then a config value will be returned according to `key` stored in dynamo
configuration. Otherwise return the original `val` value.
Parameters
----------
val :
The input value to check against.
key :
`key` stored in the dynamo configuration. E.g DynamoAdataConfig.RECIPE_MONOCLE_KEEP_RAW_LAYERS_KEY
replace_val :
the target value to replace, by default None
Args:
val: The input value to check against.
key: `key` stored in the dynamo configuration. E.g DynamoAdataConfig.RECIPE_MONOCLE_KEEP_RAW_LAYERS_KEY.
replace_val: The target value to replace, by default None.
Returns
-------
Returns:
`val` or config value set in DynamoAdataConfig according to the method description above.
"""
if not key in DynamoAdataConfig.config_key_to_values:
assert KeyError("Config %s not exist in DynamoAdataConfig." % (key))
Expand All @@ -339,7 +348,8 @@ def use_default_var_if_none(val, key, replace_val=None):
return config_val
return val

def update_data_store_mode(mode):
def update_data_store_mode(mode: str) -> None:
"""Update the data store mode for dynamo anndata object."""
DynamoAdataConfig.data_store_mode = mode

# default succinct for recipe*, except for recipe_monocle
Expand Down Expand Up @@ -373,7 +383,8 @@ def update_data_store_mode(mode):
}


def update_data_store_mode(mode):
def update_data_store_mode(mode: str) -> None:
"""Update the data store mode for dynamo anndata object."""
DynamoAdataConfig.update_data_store_mode(mode)


Expand Down Expand Up @@ -557,7 +568,8 @@ def update_data_store_mode(mode):
# }


def dyn_theme(background="white"):
def dyn_theme(background: str = "white") -> None:
"""Set the dynamo theme for matplotlib.rcParams."""
# https://github.com/matplotlib/matplotlib/blob/master/lib/matplotlib/mpl-data/stylelib/dark_background.mplstyle

if background == "black":
Expand Down Expand Up @@ -601,30 +613,23 @@ def dyn_theme(background="white"):


def config_dynamo_rcParams(
background="white",
prop_cycle=zebrafish_256,
fontsize=8,
color_map=None,
frameon=None,
):
background: str = "white",
prop_cycle: List[str] = zebrafish_256,
fontsize: float = 8,
color_map: Optional[str] = None,
frameon: Optional[bool] = None,
) -> None:
"""Configure matplotlib.rcParams to dynamo defaults (based on ggplot style and scanpy).
Parameters
----------
background: `str` (default: `white`)
The background color of the plot. By default we use the white ground
which is suitable for producing figures for publication. Setting it to `black` background will
be great for presentation.
prop_cycle: `list` (default: zebrafish_256)
A list with hex color codes
fontsize: float (default: 6)
Size of font
color_map: `plt.cm` or None (default: None)
Color map
frameon: `bool` or None (default: None)
Whether to have frame for the figure.
Returns
-------
Args:
background: The background color of the plot. By default we use the white ground which is suitable for producing
figures for publication. Setting it to `black` background will be great for presentation.
prop_cycle: A list with hex color codes.
fontsize: Size of font.
color_map: Color map.
frameon: Whether to have frame for the figure.
Returns:
Nothing but configure the rcParams globally.
"""

Expand Down Expand Up @@ -739,50 +744,37 @@ def config_dynamo_rcParams(


def set_figure_params(
dynamo=True,
background="white",
fontsize=8,
figsize=(6, 4),
dpi=None,
dpi_save=None,
frameon=None,
vector_friendly=True,
color_map=None,
format="pdf",
transparent=False,
ipython_format="png2x",
dynamo: bool = True,
background: str = "white",
fontsize: float = 8,
figsize: Tuple[float, float] = (6, 4),
dpi: Optional[int] = None,
dpi_save: Optional[int] = None,
frameon: Optional[bool] = None,
vector_friendly: bool = True,
color_map: str = None,
format: str = "pdf",
transparent: bool = False,
ipython_format: str = "png2x",
):
"""Set resolution/size, styling and format of figures.
This function is adapted from: https://github.com/theislab/scanpy/blob/f539870d7484675876281eb1c475595bf4a69bdb/scanpy/_settings.py
Arguments
---------
dynamo: `bool` (default: `True`)
Init default values for :obj:`matplotlib.rcParams` suited for dynamo.
background: `str` (default: `white`)
The background color of the plot. By default we use the white ground
which is suitable for producing figures for publication. Setting it to `black` background will
be great for presentation.
fontsize: `[float, float]` or None (default: `6`)
figsize: `(float, float)` (default: `(6.5, 5)`)
Width and height for default figure size.
dpi: `int` or None (default: `None`)
Resolution of rendered figures - this influences the size of figures in notebooks.
dpi_save: `int` or None (default: `None`)
Resolution of saved figures. This should typically be higher to achieve
publication quality.
frameon: `bool` or None (default: `None`)
Add frames and axes labels to scatter plots.
vector_friendly: `bool` (default: `True`)
Plot scatter plots using `png` backend even when exporting as `pdf` or `svg`.
color_map: `str` (default: `None`)
Convenience method for setting the default color map.
format: {'png', 'pdf', 'svg', etc.} (default: 'pdf')
This sets the default format for saving figures: `file_format_figs`.
transparent: `bool` (default: `False`)
Save figures with transparent back ground. Sets `rcParams['savefig.transparent']`.
ipython_format : list of `str` (default: 'png2x')
Only concerns the notebook/IPython environment; see
`IPython.core.display.set_matplotlib_formats` for more details.
Args:
dynamo: Init default values for :obj:`matplotlib.rcParams` suited for dynamo.
background: The background color of the plot. By default we use the white ground which is suitable for producing
figures for publication. Setting it to `black` background will be great for presentation.
fontsize: Size of font.
figsize: Width and height for default figure size.
dpi: Resolution of rendered figures - this influences the size of figures in notebooks.
dpi_save: Resolution of saved figures. This should typically be higher to achieve publication quality.
frameon: Add frames and axes labels to scatter plots.
vector_friendly: Plot scatter plots using `png` backend even when exporting as `pdf` or `svg`.
color_map: Convenience method for setting the default color map.
format: This sets the default format for saving figures: `file_format_figs`. This can be `png`, `pdf`, `svg`, etc.
transparent: Save figures with transparent background. Sets `rcParams['savefig.transparent']`.
ipython_format: Only concerns the notebook/IPython environment; see `IPython.core.display.set_matplotlib_formats`
for more details.
"""

try:
Expand Down Expand Up @@ -824,8 +816,8 @@ def reset_rcParams():
rcParams.update(rcParamsDefault)


def set_pub_style(scaler=1):
"""formatting helper function that can be used to save publishable figures"""
def set_pub_style(scaler: float = 1) -> None:
"""Formatting helper function that can be used to save publishable figures."""
set_figure_params("dynamo", background="white")
matplotlib.use("cairo")
matplotlib.rcParams.update({"font.size": 4 * scaler})
Expand All @@ -843,8 +835,8 @@ def set_pub_style(scaler=1):
matplotlib.rcParams.update(params)


def set_pub_style_mpltex():
"""formatting helper function based on mpltex package that can be used to save publishable figures"""
def set_pub_style_mpltex() -> None:
"""Formatting helper function based on mpltex package that can be used to save publishable figures."""
set_figure_params("dynamo", background="white")
matplotlib.use("cairo")
# the following code is adapted from https://github.com/liuyxpp/mpltex
Expand Down
Loading

0 comments on commit 6f52dc6

Please sign in to comment.