Skip to content

Commit

Permalink
Added MNN batch correction
Browse files Browse the repository at this point in the history
  • Loading branch information
pinin4fjords committed Aug 13, 2020
1 parent 2a564e7 commit 0faa369
Show file tree
Hide file tree
Showing 5 changed files with 214 additions and 1 deletion.
18 changes: 17 additions & 1 deletion scanpy-scripts-tests.bats
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ setup() {
noharmony_integrated_pca_pdf="${output_dir}/pca_${test_clustering}.pdf"
bbknn_obj="${output_dir}/bbknn.h5ad"
bbknn_opt="--batch-key ${test_clustering} --key-added bbknn"
mnn_obj="${output_dir}/mnn.h5ad"
mnn_opt="--batch-key ${test_clustering}"


if [ ! -d "$data_dir" ]; then
Expand Down Expand Up @@ -498,13 +500,27 @@ setup() {
skip "$bbknn_obj exists and resume is set to 'true'"
fi

run rm -f $bbknn_obj && echo "$scanpy integrate bbknn $bbknn_opt $louvain_obj $bbknn_obj" && eval "$scanpy integrate bbknn $bbknn_opt $louvain_obj $bbknn_obj"
run rm -f $bbknn_obj && eval "$scanpy integrate bbknn $bbknn_opt $louvain_obj $bbknn_obj"

[ "$status" -eq 0 ]
[ -f "$plt_rank_genes_groups_matrix_pdf" ]

}

# Do MNN batch correction, using clustering as batch (just for test purposes)

@test "Run MNN batch integration using clustering as batch" {
# Skip when resuming and the MNN output object already exists.
if [ "$resume" = 'true' ] && [ -f "$mnn_obj" ]; then
skip "$mnn_obj exists and resume is set to 'true'"
fi

# Remove any stale output, then run `integrate mnn_correct` on the louvain
# object with the options held in $mnn_opt, writing to $mnn_obj.
# NOTE(review): only the `rm -f` is invoked through `run`; the integrate
# command is eval'd after `&&`, so `$status` below appears to reflect the
# `rm` rather than the scanpy command — confirm this is intended.
run rm -f $mnn_obj && eval "$scanpy integrate mnn_correct $mnn_opt $louvain_obj $mnn_obj"

# The run must have succeeded and the output file must have been written.
[ "$status" -eq 0 ]
[ -f "$mnn_obj" ]

}

# Local Variables:
# mode: sh
# End:
2 changes: 2 additions & 0 deletions scanpy_scripts/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
PLOT_HEATMAP_CMD,
HARMONY_INTEGRATE_CMD,
BBKNN_CMD,
MNN_CORRECT_CMD,
)


Expand Down Expand Up @@ -109,6 +110,7 @@ def integrate():

# Attach the Harmony, BBKNN and MNN-correct commands to `integrate`.
integrate.add_command(HARMONY_INTEGRATE_CMD)
integrate.add_command(BBKNN_CMD)
integrate.add_command(MNN_CORRECT_CMD)


@cli.group(cls=NaturalOrderGroup)
Expand Down
102 changes: 102 additions & 0 deletions scanpy_scripts/cmd_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -1380,6 +1380,108 @@
COMMON_OPTIONS['random_state'],
],

'mnn_correct': [
*COMMON_OPTIONS['input'],
*COMMON_OPTIONS['output'],
click.option(
'--layer', '-l',
type=click.STRING,
default=None,
show_default=True,
help="Layer to batch correct. By default corrects the contents of .X."
),
click.option(
'--key-added',
type=click.STRING,
default=None,
show_default=True,
help="Key under which to add the computed results. By default a new "
"layer will be created called 'mnnnn', 'bbknn_{layer}' or "
"'bbknn_layer_{key_added}' where those parameters were specified. A value of 'X' "
"causes batch-corrected values to overwrite the original content of .X."
),
click.option(
'--var-subset',
type=(click.STRING, CommaSeparatedText()),
multiple=True,
help="The subset of vars (list of str) to be used when performing "
"MNN correction in the format of '--var-subset <name> <values>'. Typically, use "
"the highly variable genes (HVGs) like '--var-subset highly_variable True'. When "
"unset, uses all vars."
),
click.option(
'--batch-key', 'batch_key',
type=click.STRING,
required=True,
help='adata.obs column name discriminating between your batches.'
),
click.option(
'--n-neighbors', '-k',
type=CommaSeparatedText(click.INT, simplify=True),
default=20,
show_default=True,
help='Number of mutual nearest neighbors.'
),
click.option(
'--sigma',
type=click.FLOAT,
default=1.0,
show_default=True,
help='The bandwidth of the Gaussian smoothing kernel used to '
'compute the correction vectors.'
),
click.option(
'--no-cos_norm_in', 'cos_norm_in',
is_flag=True,
default=True,
help='Default behaviour is to perform cosine normalization on the '
'input data prior to calculating distances between cells. Use this '
'flag to disable that behaviour.'
),
click.option(
'--no-cos_norm_out', 'cos_norm_out',
is_flag=True,
default=True,
help='Default behaviour is to perform cosine normalization prior to '
'computing corrected expression values. Use this flag to disable that '
'behaviour.'
),
click.option(
'--svd-dim',
type=click.INT,
default=None,
show_default=True,
help='The number of dimensions to use for summarizing biological '
'substructure within each batch. If not set, biological components '
'will not be removed from the correction vectors.'
),
click.option(
'--no-var-adj',
is_flag=True,
default=True,
help='Default behaviour is to adjust variance of the correction '
'vectors. Use this flag to disable that behaviour. Note this step takes most '
'computing time.'
),
click.option(
'--compute-angle',
is_flag=True,
default=False,
help='When set, compute the angle between each cell’s correction '
'vector and the biological subspace of the reference batch.'
),
click.option(
'--svd-mode',
type=click.Choice(['svd', 'rsvd', 'irlb']),
default='rsvd',
show_default=True,
help="'svd' computes SVD using a non-randomized SVD-via-ID "
"algorithm, while 'rsvd' uses a randomized version. 'irlb' performs truncated "
"SVD by implicitly restarted Lanczos bidiagonalization (forked from "
"https://github.com/airysen/irlbpy)."
),
],

'bbknn': [
*COMMON_OPTIONS['input'],
*COMMON_OPTIONS['output'],
Expand Down
8 changes: 8 additions & 0 deletions scanpy_scripts/cmds.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .lib._diffmap import diffmap
from .lib._dpt import dpt
from .lib._bbknn import bbknn
from .lib._mnn import mnn_correct

LANG = os.environ.get('LANG', None)

Expand Down Expand Up @@ -226,3 +227,10 @@
cmd_desc='Batch balanced kNN [Polanski19].',
arg_desc=_IO_DESC,
)

# Subcommand object for MNN batch correction, built by make_subcmd() from the
# mnn_correct() wrapper imported from .lib._mnn.
MNN_CORRECT_CMD = make_subcmd(
'mnn_correct',
mnn_correct,
cmd_desc='Correct batch effects by matching mutual nearest neighbors [Haghverdi18] [Kang18].',
arg_desc=_IO_DESC,
)
85 changes: 85 additions & 0 deletions scanpy_scripts/lib/_mnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""
scanpy external mnn
"""

import scanpy.external as sce
import numpy as np
import click

# Wrapper for mnn_correct() allowing correction of a named layer instead of .X

def _parse_bool_token(token):
    """Map the strings 'true'/'false' (any case) to booleans; return anything else unchanged."""
    if isinstance(token, str):
        lowered = token.lower()
        if lowered == 'true':
            return True
        if lowered == 'false':
            return False
    return token


def _resolve_var_subset(adata, var_subset):
    """Turn (column, values) selectors into a list of var names, or None when no selector is given."""
    if not var_subset or var_subset[0] is None:
        return None

    selected = np.zeros(adata.var.shape[0], dtype=bool)

    for name, values in var_subset:
        if name not in adata.var:
            raise click.ClickException(f'Var "{name}" unavailable')

        # Values arrive from the command line as strings; for boolean columns
        # they must be turned into real booleans before they can match.
        if adata.var[name].dtype == 'bool':
            values = [_parse_bool_token(x) for x in values]

        # A var is kept when it matches ANY of the supplied selectors.
        selected |= np.array([x in values for x in adata.var[name]], dtype=bool)

    # Return names in the order of adata.var so the result is deterministic.
    return adata.var.index[selected].to_list()


def mnn_correct(adata, batch_key=None, key_added=None, var_subset=None, layer=None, n_neighbors=None, **kwargs):
    """
    Wrapper function for sce.pp.mnn_correct() that works on a single AnnData
    object and can correct a named layer rather than only .X.

    Parameters
    ----------
    adata
        Annotated data object; it is modified in place and also returned.
    batch_key
        Name of the adata.obs column whose unique values define the batches.
    key_added
        'X' to overwrite .X with the corrected values. Any other value (or
        None) stores them in a layer called 'mnn', 'mnn_{layer}',
        'mnn_{key_added}' or 'mnn_{layer}_{key_added}'.
    var_subset
        Iterable of (var column name, list of values) pairs selecting the vars
        used for the correction. None or empty means all vars.
    layer
        Layer to correct in place of .X.
    n_neighbors
        Number of mutual nearest neighbours. Forwarded to sce.pp.mnn_correct()
        as `k` unless `k` was itself supplied in kwargs.
    **kwargs
        Passed through unchanged to sce.pp.mnn_correct().

    Returns
    -------
    The input adata, with the corrected matrix stored as described above.

    Raises
    ------
    click.ClickException
        If batch_key is not a column of adata.obs, or a var_subset column is
        not present in adata.var.
    """

    if batch_key not in adata.obs:
        raise click.ClickException(f'Batch key "{batch_key}" unavailable')

    # sce.pp.mnn_correct() calls the neighbour count `k`; translate our name
    # for it so the setting is not silently dropped.
    if n_neighbors is not None:
        kwargs.setdefault('k', n_neighbors)

    # mnn will use .X, so we need to put other layers there for processing

    if layer:
        adata.layers['X_backup'] = adata.X
        adata.X = adata.layers[layer]

    # mnn_correct() wants batches in separate adatas

    batches = np.unique(adata.obs[batch_key])
    alldata = [adata[adata.obs[batch_key] == batch, :] for batch in batches]

    # Process var_subset into a list of strings that can be provided to
    # mnn_correct()

    var_subset = _resolve_var_subset(adata, var_subset)
    if var_subset is not None:
        print('Will use %d selected genes for MNN' % len(var_subset))

    # Here's the main bit

    cdata = sce.pp.mnn_correct(*alldata, var_subset=var_subset, do_concatenate=True, index_unique=None, **kwargs)

    # The first element of the result is the concatenated, corrected object;
    # re-index it so obs and var are in the same order as the original.

    corrected = cdata[0][adata.obs.index, adata.var.index].X

    # If user has specified key_added = X then they want us to overwrite .X,
    # otherwise copy the corrected matrix to a named layer of the original
    # object.

    if key_added != 'X':

        mnn_key = 'mnn'
        if layer:
            mnn_key = f"{mnn_key}_{layer}"

            # We had to overwrite .X to run mnn on the layer, and the results
            # are not going into .X, so restore the original from the backup.

            adata.X = adata.layers['X_backup']

        if key_added:
            mnn_key = f"{mnn_key}_{key_added}"

        adata.layers[mnn_key] = corrected

    else:
        adata.X = corrected

    # Delete the backup of .X if we needed one

    if layer:
        del adata.layers['X_backup']

    return adata

0 comments on commit 0faa369

Please sign in to comment.