diff --git a/q2_batch/_batch.py b/q2_batch/_batch.py
index 7fd76ba..5334cd0 100644
--- a/q2_batch/_batch.py
+++ b/q2_batch/_batch.py
@@ -1,22 +1,17 @@
-import argparse
-from biom import load_table
+import biom
 import numpy as np
 import pandas as pd
-import seaborn as sns
 from sklearn.preprocessing import LabelEncoder
-import pickle
 import os
 from skbio.stats.composition import ilr_inv
-import matplotlib.pyplot as plt
-import pickle
 from cmdstanpy import CmdStanModel
+from birdman import BaseModel
 import tempfile
 import json
+import time
 
 
-def _batch_func(counts : np.array, replicates : np.array,
-                batches : np.array, depth : int,
-                mc_samples : int=1000, chains=4) -> dict:
+def _extract_replicates(replicates, batches):
     replicate_encoder = LabelEncoder()
     replicate_encoder.fit(replicates)
     replicate_ids = replicate_encoder.transform(replicates)
@@ -31,44 +26,112 @@ def _batch_func(counts : np.array, replicates : np.array,
     batch_encoder = LabelEncoder()
     batch_encoder.fit(batches)
     batch_ids = batch_encoder.transform(batches)
-    # Actual stan modeling
-    code = os.path.join(os.path.dirname(__file__),
-                        'assets/batch_pln_single.stan')
+    batch_ids = batch_ids.astype(np.int64) + 1
     ref_ids = ref_ids.astype(np.int64) + 1
-    sm = CmdStanModel(stan_file=code)
-    dat = {
-        'N' : counts.shape[0],
-        'R' : int(max(ref_ids) + 1),
-        'B' : int(max(batch_ids) + 1),
-        'depth' : list(np.log(depth)),
-        'y' : list(map(int, counts.astype(np.int64))),
-        'ref_ids' : list(map(int, ref_ids)),
-        'batch_ids' : list(map(int, batch_ids))
-    }
-    with tempfile.TemporaryDirectory() as temp_dir_name:
-        data_path = os.path.join(temp_dir_name, 'data.json')
-        with open(data_path, 'w') as f:
-            json.dump(dat, f)
-        # Obtain an initial guess with MLE
-        # guess = sm.optimize(data=data_path, inits=0)
-        # see https://mattocci27.github.io/assets/poilog.html
-        # for recommended parameters for poisson log normal
-        fit = sm.sample(data=data_path, iter_sampling=mc_samples,
-                        # inits=guess.optimized_params_dict,
-                        chains=chains, iter_warmup=mc_samples // 2,
-                        adapt_delta=0.9, max_treedepth=20)
-        fit.diagnose()
-        mu = fit.stan_variable('mu')
-        sigma = fit.stan_variable('sigma')
-        disp = fit.stan_variable('disp')
-        res = pd.DataFrame({
-            'mu': mu,
-            'sigma': sigma,
-            'disp': disp})
-        # TODO: this doesn't seem to work atm, but its fixed upstream
-        # res = fit.summary()
-        return res
+    return ref_ids, replicate_ids, batch_ids
+
+
+class PoissonLogNormalBatch(BaseModel):
+    """Fit a batch effects estimator with a Poisson log-normal model.
+
+    Parameters
+    ----------
+    table: biom.table.Table
+        Feature table (features x samples)
+    batch_column : str
+        Column that specifies `batches` of interest that
+        cause technical artifacts.
+    replicate_column : str
+        Column that specifies technical replicates that
+        are spread across batches.
+    metadata: pd.DataFrame
+        Metadata for matching and status covariates.
+    num_iter: int
+        Number of posterior sample draws, defaults to 1000
+    num_warmup: int
+        Number of posterior draws used for warmup, defaults to 500
+    chains: int
+        Number of chains to use in MCMC, defaults to 4
+    seed: float
+        Random seed to use for sampling, defaults to 42
+    mu_scale : float
+        Standard deviation of the prior distribution for mu
+    sigma_scale : float
+        Standard deviation of the prior distribution for sigma
+    disp_scale : float
+        Standard deviation of the prior distribution for disp
+    reference_loc : float
+        Mean of the prior distribution for reference samples
+    reference_scale : float
+        Standard deviation of the prior distribution for reference samples
+
+    Notes
+    -----
+    The priors for the reference are defaults for amplicon sequences.
+    See https://github.com/mortonjt/q2-matchmaker/issues/24
+    """
+    def __init__(self,
+                 table: biom.table.Table,
+                 batch_column: str,
+                 replicate_column: str,
+                 metadata: pd.DataFrame,
+                 num_iter: int = 1000,
+                 num_warmup: int = 500,
+                 adapt_delta: float = 0.9,
+                 max_treedepth: int = 20,
+                 chains: int = 4,
+                 seed: float = 42,
+                 mu_scale: float = 1,
+                 sigma_scale: float = 1,
+                 disp_scale: float = 1,
+                 reference_loc: float = -5,
+                 reference_scale: float = 3):
+        model_path = os.path.join(os.path.dirname(__file__),
+                                  'assets/batch_pln_single.stan')
+        super(PoissonLogNormalBatch, self).__init__(
+            table, metadata, model_path,
+            num_iter, num_warmup, chains, seed,
+            parallelize_across="features")
+        # assemble replicate and batch ids
+        metadata = metadata.loc[table.ids()]
+        depth = table.sum(axis='sample')
+        replicates = metadata[replicate_column]
+        batches = metadata[batch_column]
+        ref_ids, replicate_ids, batch_ids = _extract_replicates(
+            replicates, batches)
+        self.dat = {
+            'D': table.shape[0],  # number of features
+            'N': table.shape[1],  # number of samples
+            'R': int(max(replicate_ids) + 1),
+            'B': int(max(batch_ids) + 1),
+            'depth': list(np.log(depth)),
+            'y': table.matrix_data.todense().T.astype(int),
+            'ref_ids': list(map(int, ref_ids)),
+            'batch_ids': list(map(int, batch_ids))
+        }
+        param_dict = {
+            "mu_scale": mu_scale,
+            "sigma_scale": sigma_scale,
+            "disp_scale": disp_scale,
+            "reference_loc": reference_loc,
+            "reference_scale": reference_scale
+        }
+        self.add_parameters(param_dict)
+
+        self.specify_model(
+            params=["mu", "sigma", "disp", "batch", "reference"],
+            dims={
+                "log_lhood": ["tbl_sample", "feature"],
+                "y_predict": ["tbl_sample", "feature"]
+            },
+            coords={
+                "feature": self.feature_names,
+                "tbl_sample": self.sample_names
+            },
+            include_observed_data=True,
+            posterior_predictive="y_predict",
+            log_likelihood="log_lhood"
+        )
 
 
 def _simulate(n=100, d=10, depth=50):
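Note on the refactor: `_extract_replicates` now only builds the integer ids the Stan program consumes, while the sampling itself moves into `PoissonLogNormalBatch`. A minimal sketch (toy labels, not part of the diff) of what the encoding step does; the `+ 1` shift matches Stan's 1-based indexing:

```python
import numpy as np
from sklearn.preprocessing import LabelEncoder

# Hypothetical replicate/batch labels for four samples
replicates = np.array(['mouse1', 'mouse1', 'mouse2', 'mouse2'])
batches = np.array(['run1', 'run2', 'run1', 'run2'])

# 0-based replicate ids, as produced by LabelEncoder
replicate_ids = LabelEncoder().fit_transform(replicates)  # [0 0 1 1]
# batch ids shifted to 1-based for Stan indexing
batch_ids = LabelEncoder().fit_transform(batches) + 1     # [1 2 1 2]
print(replicate_ids, batch_ids)
```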
diff --git a/q2_batch/_method.py b/q2_batch/_method.py
index cc5b6a6..fedee0b 100644
--- a/q2_batch/_method.py
+++ b/q2_batch/_method.py
@@ -5,134 +5,50 @@
 import seaborn as sns
 from sklearn.preprocessing import LabelEncoder
 import pickle
-import pystan
 import dask
-from q2_batch._batch import _batch_func
+from q2_batch._batch import PoissonLogNormalBatch
+from dask.distributed import Client, LocalCluster
+from gneiss.util import match
+import arviz as az
 import xarray as xr
 import qiime2
 
 
-# slow estimator
-def estimate(counts : pd.DataFrame,
+def _poisson_log_normal_estimate(counts,
+                                 replicates,
+                                 batches,
+                                 monte_carlo_samples,
+                                 cores,
+                                 **sampler_args):
+    metadata = pd.DataFrame({'batch': batches, 'reps': replicates})
+    table, metadata = match(counts, metadata)
+    pln = PoissonLogNormalBatch(
+        table=table,
+        replicate_column="reps",
+        batch_column="batch",
+        metadata=metadata,
+        num_iter=monte_carlo_samples,
+        **sampler_args)
+    pln.compile_model()
+    pln.fit_model(dask_args={'n_workers': cores,
+                             'threads_per_worker': 1})
+    samples = pln.to_inference_object()
+    return samples
+
+
+def estimate(counts : biom.Table,
             replicates : qiime2.CategoricalMetadataColumn,
             batches : qiime2.CategoricalMetadataColumn,
             monte_carlo_samples : int = 100,
-            cores : int = 1) -> xr.Dataset:
-    # match everything up
+            cores : int = 1) -> az.InferenceData:
     replicates = replicates.to_series()
     batches = batches.to_series()
-    idx = list(set(counts.index) & set(replicates.index) & set(batches.index))
-    counts, replicates, batches = [x.loc[idx] for x in
-                                   (counts, replicates, batches)]
-    replicates, batches = replicates.values, batches.values
-    depth = counts.sum(axis=1)
-    pfunc = lambda x: _batch_func(np.array(x.values), replicates, batches,
-                                  depth, monte_carlo_samples)
-    if cores > 1:
-        try:
-            import dask.dataframe as dd
-            dcounts = dd.from_pandas(counts.T, npartitions=cores)
-            res = dcounts.apply(pfunc, axis=1)
-            resdf = res.compute(scheduler='processes')
-            data_df = list(resdf.values)
-        except:
-            data_df = list(counts.T.apply(pfunc, axis=1).values)
-    else:
-        data_df = list(counts.T.apply(pfunc, axis=1).values)
-    samples = xr.concat([df.to_xarray() for df in data_df], dim="features")
-    samples = samples.assign_coords(coords={
-        'features' : counts.columns,
-        'monte_carlo_samples' : np.arange(monte_carlo_samples)
-    })
-    return samples
-
-
-# Parallel estimation of batch effects
-def slurm_estimate(counts : pd.DataFrame,
-                   batches : qiime2.CategoricalMetadataColumn,
-                   replicates : qiime2.CategoricalMetadataColumn,
-                   monte_carlo_samples : int,
-                   cores : int = 4,
-                   processes : int = 4,
-                   nodes : int = 2,
-                   memory : str = '16GB',
-                   walltime : str = '01:00:00',
-                   queue : str = '') -> xr.Dataset:
-    from dask_jobqueue import SLURMCluster
-    from dask.distributed import Client
-    import dask.dataframe as dd
-    import logging
-    logging.basicConfig(format='%(levelname)s:%(message)s',
-                        level=logging.DEBUG)
-    cluster = SLURMCluster(cores=cores,
-                           processes=processes,
-                           memory=memory,
-                           walltime=walltime,
-                           interface='ib0',
-                           env_extra=["export TBB_CXX_TYPE=gcc"],
-                           queue=queue)
-    cluster.scale(jobs=nodes)
-    print(cluster.job_script())
-    client = Client(cluster)
-    # match everything up
-    replicates = replicates.to_series()
-    batches = batches.to_series()
-    idx = list(set(counts.index) & set(replicates.index) & set(batches.index))
-    counts, replicates, batches = [x.loc[idx] for x in
-                                   (counts, replicates, batches)]
-    replicates, batches = replicates.values, batches.values
-    depth = counts.sum(axis=1)
-    pfunc = lambda x: _batch_func(np.array(x.values), replicates, batches,
-                                  depth, monte_carlo_samples)
-    print('Partitions', cores * nodes * processes * 4)
-    dcounts = dd.from_pandas(
-        counts.T, npartitions=cores * nodes * processes * 4)
-    dcounts = client.persist(dcounts)
-    res = dcounts.apply(pfunc, axis=1)
-    resdf = res.compute(scheduler='processes')
-    data_df = list(resdf.values)
-    samples = xr.concat([df.to_xarray() for df in data_df], dim="features")
-    samples = samples.assign_coords(coords={
-        'features' : counts.columns,
-        'monte_carlo_samples' : np.arange(monte_carlo_samples)
-    })
-    return samples
-
-
-# Parallel estimation of batch effects
-def parallel_estimate(counts : pd.DataFrame,
-                      batches : qiime2.CategoricalMetadataColumn,
-                      replicates : qiime2.CategoricalMetadataColumn,
-                      monte_carlo_samples : int,
-                      scheduler_json : str,
-                      partitions : int = 100) -> xr.Dataset:
-    from dask.distributed import Client
-    import dask.dataframe as dd
-    import logging
-    logging.basicConfig(format='%(levelname)s:%(message)s',
-                        level=logging.DEBUG)
-    client = Client(scheduler_file=scheduler_json)
-
-    # match everything up
-    replicates = replicates.to_series()
-    batches = batches.to_series()
-    idx = list(set(counts.index) & set(replicates.index) & set(batches.index))
-    counts, replicates, batches = [x.loc[idx] for x in
-                                   (counts, replicates, batches)]
-    replicates, batches = replicates.values, batches.values
-    depth = counts.sum(axis=1)
-    pfunc = lambda x: _batch_func(np.array(x.values), replicates, batches,
-                                  depth, monte_carlo_samples, chains=1)
-    dcounts = dd.from_pandas(counts.T, npartitions=partitions)
-    # dcounts = client.persist(dcounts)
-    res = dcounts.apply(pfunc, axis=1)
-    # resdf = client.compute(res)
-    resdf = res.compute(scheduler='processes')
-    data_df = list(resdf.values)
-
-    samples = xr.concat([df.to_xarray() for df in data_df], dim="features")
-    samples = samples.assign_coords(coords={
-        'features' : counts.columns,
-        'monte_carlo_samples' : np.arange(monte_carlo_samples)
-    })
+    # Build me a cluster!
+    dask_args = {'n_workers': cores, 'threads_per_worker': 1}
+    cluster = LocalCluster(**dask_args)
+    cluster.scale(dask_args['n_workers'])
+    client = Client(cluster)
+    samples = _poisson_log_normal_estimate(
+        counts, replicates, batches,
+        monte_carlo_samples, cores)
    return samples
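`estimate` spins up a `LocalCluster`/`Client` but never tears them down, so the workers linger until interpreter exit. A sketch of one way to scope them, assuming the fit is the only dask work; `fit_with_local_cluster` is a hypothetical helper, not part of this PR:

```python
from dask.distributed import Client, LocalCluster

def fit_with_local_cluster(pln, cores=1):
    # Both LocalCluster and Client are context managers, so the workers
    # are shut down even if compilation or sampling raises.
    with LocalCluster(n_workers=cores, threads_per_worker=1) as cluster:
        with Client(cluster) as client:
            pln.compile_model()
            pln.fit_model(dask_args={'n_workers': cores,
                                     'threads_per_worker': 1})
            return pln.to_inference_object()
```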
diff --git a/q2_batch/assets/batch_pln_single.stan b/q2_batch/assets/batch_pln_single.stan
index 418b8d8..e6765c5 100644
--- a/q2_batch/assets/batch_pln_single.stan
+++ b/q2_batch/assets/batch_pln_single.stan
@@ -1,11 +1,17 @@
 data {
-  int N;             // number of samples
-  int R;             // number of replicates
-  int B;             // number of batchs
-  real depth[N];     // sequencing depths of microbes
-  int y[N];          // observed microbe abundances
+  int N;                 // number of samples
+  int R;                 // number of replicates
+  int B;                 // number of batches
+  real depth[N];         // log sequencing depths of samples
+  int y[N];              // observed microbe abundances
   int ref_ids[N];        // locations of reference replicates
   int batch_ids[N];      // batch ids
+  // Priors
+  real mu_scale;
+  real sigma_scale;
+  real disp_scale;
+  real reference_loc;
+  real reference_scale;
 }
 
 parameters {
@@ -20,11 +26,12 @@
 model {
   vector[N] eta;
   // setting priors ...
-  disp ~ normal(0., 5);         // weak overdispersion prior
-  mu ~ normal(0., 1.);          // strong batch effects mean prior
-  sigma ~ normal(0., 1.);       // strong batch effects variance prior
-  batch ~ normal(mu, sigma);    // random effects
-  reference ~ normal(0., 10.);  // uninformed reference prior
+  disp ~ normal(0., disp_scale);    // weak overdispersion prior
+  mu ~ normal(0., mu_scale);        // batch effects mean prior
+  sigma ~ normal(0., sigma_scale);  // weak batch effects variance prior
+  batch ~ normal(mu, sigma);        // random effects
+  // uninformed reference prior
+  reference ~ normal(reference_loc, reference_scale);
   // generating counts
   for (n in 1:N){
     eta[n] = batch[batch_ids[n]] + reference[ref_ids[n]];
@@ -32,3 +39,14 @@ model {
   lam ~ normal(eta, disp);
   y ~ poisson_log(lam + to_vector(depth));
 }
+
+generated quantities {
+  vector[N] y_predict;
+  vector[N] log_lhood;
+  for (n in 1:N){
+    real eta_ = batch[batch_ids[n]] + reference[ref_ids[n]];
+    real lam_ = normal_rng(eta_, disp);
+    y_predict[n] = poisson_log_rng(lam_ + depth[n]);
+    log_lhood[n] = poisson_log_lpmf(y[n] | lam_ + depth[n]);
+  }
+}
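To make the Stan program concrete: it models each feature's counts as Poisson with a log-normal latent mean built from a batch random effect plus a per-replicate reference intercept, offset by log sequencing depth. A NumPy sketch of that generative process (0-based indexing here versus Stan's 1-based; all values are toy assumptions):

```python
import numpy as np

rng = np.random.default_rng(42)
N, B, R = 6, 2, 3                          # samples, batches, replicates
mu, sigma = 0.0, 1.0                       # batch effect hyperparameters
batch = rng.normal(mu, sigma, size=B)      # batch ~ normal(mu, sigma)
reference = rng.normal(-5, 3, size=R)      # reference ~ normal(loc, scale)
disp = 0.5                                 # overdispersion

batch_ids = np.array([0, 1, 0, 1, 0, 1])   # which batch each sample ran in
ref_ids = np.array([0, 0, 1, 1, 2, 2])     # which replicate each sample is
log_depth = np.log(np.full(N, 1000.0))     # log sequencing depths

eta = batch[batch_ids] + reference[ref_ids]
lam = rng.normal(eta, disp)                # lam ~ normal(eta, disp)
y = rng.poisson(np.exp(lam + log_depth))   # y ~ poisson_log(lam + depth)
print(y)
```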
diff --git a/q2_batch/plugin_setup.py b/q2_batch/plugin_setup.py
index 5dd4cd4..228185e 100644
--- a/q2_batch/plugin_setup.py
+++ b/q2_batch/plugin_setup.py
@@ -5,9 +5,8 @@
                            MetadataColumn, Categorical)
 from q2_batch import __version__
-from q2_batch._method import estimate, slurm_estimate, parallel_estimate
-from q2_differential._type import FeatureTensor
-from q2_differential._format import FeatureTensorNetCDFFormat, FeatureTensorNetCDFDirFmt
+from q2_batch._method import estimate
+from q2_types.feature_data import MonteCarloTensor
 from q2_types.feature_table import FeatureTable, Frequency
 
 
@@ -21,6 +20,7 @@
                   ' for downstream plugins'),
     package='q2-batch')
 
+
 plugin.methods.register_function(
     function=estimate,
     inputs={'counts': FeatureTable[Frequency]},
@@ -31,7 +31,7 @@
         'cores': Int
     },
     outputs=[
-        ('posterior', FeatureTensor)
+        ('posterior', MonteCarloTensor)
     ],
     input_descriptions={
         "counts": "Input table of counts.",
@@ -48,8 +48,8 @@
         ),
         'cores' : 'Number of cpu cores'
     },
-    name='estimation',
-    description=("Computes batch effects from technical replicates"),
+    name='Batch effect estimation',
+    description=("Computes batch effects from technical replicates."),
     citations=[]
 )
diff --git a/q2_batch/tests/quick_test.py b/q2_batch/tests/quick_test.py
new file mode 100644
index 0000000..0ed6e8c
--- /dev/null
+++ b/q2_batch/tests/quick_test.py
@@ -0,0 +1,39 @@
+from dask_jobqueue import SLURMCluster
+from dask.distributed import Client
+from q2_batch._batch import _simulate, PoissonLogNormalBatch
+import biom
+
+table, metadata = _simulate(n=100, d=20, depth=100)
+table = biom.Table(table.values.T,
+                   table.columns, table.index)
+pln = PoissonLogNormalBatch(
+    table=table,
+    replicate_column="reps",
+    batch_column="batch",
+    metadata=metadata,
+    num_warmup=1000,
+    mu_scale=1,
+    reference_scale=5,
+    chains=1,
+    seed=42)
+
+pln.compile_model()
+# pln.fit_model(dask_args={'n_workers': 1, 'threads_per_worker': 1})
+jobs = 4
+cluster = SLURMCluster(cores=4,
+                       processes=jobs,
+                       memory='16GB',
+                       walltime='01:00:00',
+                       interface='ib0',
+                       nanny=True,
+                       death_timeout='300s',
+                       local_directory='/scratch',
+                       shebang='#!/usr/bin/env bash',
+                       env_extra=["export TBB_CXX_TYPE=gcc"],
+                       queue='ccb')
+client = Client(cluster)
+cluster.scale(jobs=jobs)
+client.wait_for_workers(jobs)
+
+pln.fit_model(convert_to_inference=True)
+pln.to_inference_object()
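`quick_test.py` ends by materializing the fit; since birdman's `to_inference_object` returns an ArviZ `InferenceData`, the result can be sanity-checked with standard ArviZ calls. A sketch continuing from the script above:

```python
import arviz as az

inf = pln.to_inference_object()
# Posterior summaries with R-hat / ESS diagnostics for the batch parameters
print(az.summary(inf, var_names=['mu', 'sigma', 'disp']))
az.plot_trace(inf, var_names=['mu'])
```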
diff --git a/q2_batch/tests/test_batch.py b/q2_batch/tests/test_batch.py
index 209f9ea..d432c96 100644
--- a/q2_batch/tests/test_batch.py
+++ b/q2_batch/tests/test_batch.py
@@ -1,19 +1,35 @@
 import unittest
-from q2_batch._batch import _batch_func, _simulate
+from q2_batch._batch import _simulate, PoissonLogNormalBatch
+from dask.distributed import Client, LocalCluster
+import biom
 
 
-class TestBatch(unittest.TestCase):
+class TestPoissonLogNormalBatch(unittest.TestCase):
 
     def setUp(self):
         self.table, self.metadata = _simulate(n=100, d=10, depth=50)
 
     def test_batch(self):
-        res = _batch_func(self.table.values[:, 0],
-                          replicates=self.metadata['reps'].values,
-                          batches=self.metadata['batch'].values,
-                          depth=self.table.sum(axis=1),
-                          mc_samples=2000)
-        self.assertEqual(res.shape, (2000 * 4, 3))
+        dask_args = {'n_workers': 1, 'threads_per_worker': 1}
+        cluster = LocalCluster(**dask_args)
+        cluster.scale(dask_args['n_workers'])
+        client = Client(cluster)
+        table = biom.Table(self.table.values.T,
+                           self.table.columns, self.table.index)
+        pln = PoissonLogNormalBatch(
+            table=table,
+            replicate_column="reps",
+            batch_column="batch",
+            metadata=self.metadata,
+            num_warmup=1000,
+            mu_scale=1,
+            reference_scale=5,
+            chains=1,
+            seed=42)
+        pln.compile_model()
+        pln.fit_model()
+        inf = pln.to_inference_object()
+        self.assertEqual(inf['posterior']['mu'].shape, (10, 1, 1000))
 
 
 if __name__ == '__main__':
diff --git a/q2_batch/tests/test_method.py b/q2_batch/tests/test_method.py
index cfb161c..9c5451e 100644
--- a/q2_batch/tests/test_method.py
+++ b/q2_batch/tests/test_method.py
@@ -3,12 +3,17 @@
 from q2_batch._batch import _simulate
 import qiime2
 import xarray as xr
+from birdman.diagnostics import r2_score
+from biom import Table
 
 
 class TestBatchEstimation(unittest.TestCase):
 
     def setUp(self):
         self.table, self.metadata = _simulate(n=50, d=4, depth=30)
+        self.table = Table(self.table.values.T,
+                           list(self.table.columns),
+                           list(self.table.index))
 
     def test_batch(self):
         res = estimate(
@@ -19,7 +24,8 @@
             cores=1
         )
         self.assertTrue(res is not None)
-        self.assertTrue(isinstance(res, xr.Dataset))
+        scores = r2_score(res)
+        self.assertGreater(scores['r2'], 0.3)
 
 
 if __name__ == '__main__':
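The rewritten `test_method.py` scores the fit with `birdman.diagnostics.r2_score` instead of only type-checking the output. For reference, a rough sketch of the same idea computed directly from the posterior predictive draws; the `observed` variable name is an assumption about how the observed table is stored in the inference object:

```python
import numpy as np

def posterior_predictive_r2(inf):
    """Crude R^2 between mean posterior predictions and observed counts."""
    y_pred = (inf.posterior_predictive['y_predict']
              .mean(dim=['chain', 'draw']).values.ravel())
    y_obs = inf.observed_data['observed'].values.ravel()  # name assumed
    ss_res = np.sum((y_obs - y_pred) ** 2)
    ss_tot = np.sum((y_obs - y_obs.mean()) ** 2)
    return 1.0 - ss_res / ss_tot
```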
diff --git a/setup.py b/setup.py
index c8bd7bc..dd9482c 100644
--- a/setup.py
+++ b/setup.py
@@ -20,7 +20,7 @@
 """
 
 classifiers = [s.strip() for s in classes.split('\n') if s]
 
-description = ('QIIME2 plugin Quick and dirty batch effect correction')
+description = ('QIIME 2 plugin for quick and dirty batch effect correction')
 
 
 setup(name='q2-batch',
@@ -36,10 +36,12 @@
         'pandas',
         'xarray',
         'matplotlib',
-        'q2-differential',
+        'q2-types',
         'cmdstanpy==0.9.68',
+        'dask',
+        'birdman',
+        'gneiss',
-        # 'dask_jobsqueue', # optional
-        # 'dask'
+        # 'dask-jobqueue',  # optional
     ],
     entry_points={
         'qiime2.plugins': ['q2-batch=q2_batch.plugin_setup:plugin']
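One dependency note: `cmdstanpy` is only a wrapper, so the pinned `cmdstanpy==0.9.68` still needs a CmdStan toolchain on the machine. A one-time setup sketch (assuming CmdStan is not already provisioned by CI or a conda environment):

```python
import cmdstanpy

# Download and build CmdStan once; subsequent runs reuse the install
cmdstanpy.install_cmdstan()
print(cmdstanpy.cmdstan_path())  # verify the toolchain is discoverable
```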