From 754eeaee0caa6b1c5054cf18ab1bd1d66c163a41 Mon Sep 17 00:00:00 2001
From: mortonjt
Date: Mon, 24 May 2021 14:25:13 -0600
Subject: [PATCH 1/6] some flake8 and stuff

---
 q2_batch/_batch.py                    | 122 ++++++++++++++++++++++----
 q2_batch/_method.py                   |  18 ++--
 q2_batch/assets/batch_pln_single.stan |  36 +++++---
 q2_batch/plugin_setup.py              |   7 +-
 q2_batch/tests/test_batch.py          |  18 +++-
 setup.py                              |   6 +-
 6 files changed, 160 insertions(+), 47 deletions(-)

diff --git a/q2_batch/_batch.py b/q2_batch/_batch.py
index 81a294b..7a840f6 100644
--- a/q2_batch/_batch.py
+++ b/q2_batch/_batch.py
@@ -1,22 +1,16 @@
-import argparse
-from biom import load_table
+import biom
 import numpy as np
 import pandas as pd
-import seaborn as sns
 from sklearn.preprocessing import LabelEncoder
-import pickle
 import os
 from skbio.stats.composition import ilr_inv
-import matplotlib.pyplot as plt
-import pickle
 from cmdstanpy import CmdStanModel
+from birdman import BaseModel
 import tempfile
 import json


-def _batch_func(counts : np.array, replicates : np.array,
-                batches : np.array, depth : int,
-                mc_samples : int=1000) -> dict:
+def _extract_replicates(replicates, batches):
     replicate_encoder = LabelEncoder()
     replicate_encoder.fit(replicates)
     replicate_ids = replicate_encoder.transform(replicates)
@@ -31,20 +25,30 @@ def _batch_func(counts : np.array, replicates : np.array,
     batch_encoder = LabelEncoder()
     batch_encoder.fit(batches)
     batch_ids = batch_encoder.transform(batches)
+
+    batch_ids = batch_ids.astype(np.int64) + 1
+    ref_ids = ref_ids.astype(np.int64) + 1
+    return ref_ids, replicate_ids, batch_ids
+
+
+def _batch_func(counts: np.array, replicates: np.array,
+                batches: np.array, depth: int,
+                mc_samples: int = 1000) -> dict:
+
+    ref_ids, replicate_ids, batch_ids = _extract_replicates(
+        replicates, batches)
     # Actual stan modeling
     code = os.path.join(os.path.dirname(__file__),
                         'assets/batch_pln_single.stan')
-    batch_ids = batch_ids.astype(np.int64) + 1
-    ref_ids = ref_ids.astype(np.int64) + 1
     sm = CmdStanModel(stan_file=code)
     dat = {
-        'N' : len(counts),
-        'R' : int(max(replicate_ids) + 1),
-        'B' : int(max(batch_ids) + 1),
-        'depth' : list(np.log(depth)),
-        'y' : list(map(int, counts.astype(np.int64))),
-        'ref_ids' : list(map(int, ref_ids )),
-        'batch_ids' : list(map(int, batch_ids))
+        'N': len(counts),
+        'R': int(max(replicate_ids) + 1),
+        'B': int(max(batch_ids) + 1),
+        'depth': list(np.log(depth)),
+        'y': list(map(int, counts.astype(np.int64))),
+        'ref_ids': list(map(int, ref_ids)),
+        'batch_ids': list(map(int, batch_ids))
     }
     with tempfile.TemporaryDirectory() as temp_dir_name:
         data_path = os.path.join(temp_dir_name, 'data.json')
@@ -57,7 +61,7 @@ def _batch_func(counts : np.array, replicates : np.array,
         fit = sm.sample(data=data_path, iter_sampling=mc_samples,
                         # inits=guess.optimized_params_dict,
                         chains=4, iter_warmup=mc_samples // 2,
-                        adapt_delta = 0.9, max_treedepth = 20)
+                        adapt_delta=0.9, max_treedepth=20)
         fit.diagnose()
         mu = fit.stan_variable('mu')
         sigma = fit.stan_variable('sigma')
@@ -71,6 +75,86 @@ def _batch_func(counts : np.array, replicates : np.array,
     return res


+class PoissonLogNormalBatch(BaseModel):
+    """Fit Batch effects estimator with Poisson Log Normal
+
+    Parameters
+    ----------
+    table: biom.table.Table
+        Feature table (features x samples)
+    batch_column : str
+        Column that specifies `batches` of interest that
+        cause technical artifacts.
+    replicate_column : str
+        Column that specifies technical replicates that
+        are spread across batches.
+    metadata: pd.DataFrame
+        Metadata for matching and status covariates.
+    num_iter: int
+        Number of posterior sample draws, defaults to 1000
+    num_warmup: int
+        Number of posterior draws used for warmup, defaults to 500
+    chains: int
+        Number of chains to use in MCMC, defaults to 4
+    seed: float
+        Random seed to use for sampling, defaults to 42
+    mu_scale : float
+        Standard deviation for prior distribution for mu
+    sigma_scale : float
+        Standard deviation for prior distribution for sigma
+    disp_scale : float
+        Standard deviation for prior distribution for disp
+    reference_scale : float
+        Standard deviation for prior distribution for reference
+    """
+    def __init__(self,
+                 table: biom.table.Table,
+                 batch_column: str,
+                 replicate_column: str,
+                 metadata: pd.DataFrame,
+                 num_iter: int = 1000,
+                 num_warmup: int = 500,
+                 adapt_delta: float = 0.9,
+                 max_treedepth: float = 20,
+                 chains: int = 4,
+                 seed: float = 42,
+                 mu_scale: float = 1,
+                 sigma_scale: float = 1,
+                 disp_scale: float = 1,
+                 reference_scale: float = 10):
+        model_path = os.path.join(os.path.dirname(__file__),
+                                  'assets/batch_pln_single.stan')
+        super(PoissonLogNormalBatch, self).__init__(
+            table, metadata, model_path,
+            num_iter, num_warmup, chains, seed,
+            parallelize_across="features")
+        # assemble replicate and batch ids
+        metadata = metadata.loc[table.ids()]
+        depth = table.sum(axis='sample')
+        replicates = metadata[replicate_column]
+        batches = metadata[batch_column]
+        ref_ids, replicate_ids, batch_ids = _extract_replicates(
+            replicates, batches)
+        self.dat = {
+            'D': table.shape[0],  # number of features
+            'N': table.shape[1],  # number of samples
+            'R': int(max(replicate_ids) + 1),
+            'B': int(max(batch_ids) + 1),
+            'depth': list(np.log(depth)),
+            "y": table.matrix_data.todense().T.astype(int),
+            'ref_ids': list(map(int, ref_ids)),
+            'batch_ids': list(map(int, batch_ids))
+        }
+        self.param_names = ["mu", "sigma", "disp", "batch", "reference"]
+        param_dict = {
+            "mu_scale": mu_scale,
+            "sigma_scale": sigma_scale,
+            "disp_scale": disp_scale,
+            "reference_scale": reference_scale
+        }
+        self.add_parameters(param_dict)
+
+
 def _simulate(n=100, d=10, depth=50):
     """ Simulate batch effects from Multinomial distribution

diff --git a/q2_batch/_method.py b/q2_batch/_method.py
index e8aef57..81f925f 100644
--- a/q2_batch/_method.py
+++ b/q2_batch/_method.py
@@ -5,14 +5,12 @@
 import seaborn as sns
 from sklearn.preprocessing import LabelEncoder
 import pickle
-import pystan
 import dask
 from q2_batch._batch import _batch_func
 import xarray as xr
 import qiime2


-# slow estimator
 def estimate(counts : pd.DataFrame,
              replicates : qiime2.CategoricalMetadataColumn,
              batches : qiime2.CategoricalMetadataColumn,
@@ -44,12 +42,12 @@ def estimate(counts : pd.DataFrame,
     return samples


-# Parallel estimation of batch effects
 def parallel_estimate(counts : pd.DataFrame,
-                      replicate_column : qiime2.CategoricalMetadataColumn,
-                      batch_column : qiime2.CategoricalMetadataColumn,
-                      monte_carlo_samples : int,
-                      cores=16,
-                      memory='16 GB',
-                      processes=4):
-    from dask_jobqueue import SLURMCluster
+                      replicates : pd.Series,
+                      batches : pd.Series,
+                      monte_carlo_samples : int = 100,
+                      cores : int = 1) -> xr.Dataset:
+    pass
+
+def phylogenetic_impute():
+    pass

diff --git a/q2_batch/assets/batch_pln_single.stan b/q2_batch/assets/batch_pln_single.stan
index 11a05f3..e01e0c5 100644
--- a/q2_batch/assets/batch_pln_single.stan
+++ b/q2_batch/assets/batch_pln_single.stan
@@ -1,11 +1,16 @@
 data {
-  int N;  // number of samples
-  int R;  // number of replicates
-  int B;  // number of batchs
-  real depth[N];  // sequencing depths of microbes
-  int y[N];  // observed microbe abundances
+  int N;            // number of samples
+  int R;            // number of replicates
+  int B;            // number of batches
+  real depth[N];    // log sequencing depths of the samples
+  int y[N];         // observed microbe abundances
   int ref_ids[N];   // locations of reference replicates
   int batch_ids[N]; // batch ids
+  // Priors
+  real mu_scale;
+  real sigma_scale;
+  real disp_scale;
+  real reference_scale;
 }

 parameters {
@@ -20,11 +25,11 @@ parameters {
 model {
   vector[N] eta;
   // setting priors ...
-  disp ~ normal(0., 5);        // weak overdispersion prior
-  mu ~ normal(0., 10.);        // uninformed batch effects mean prior
-  sigma ~ normal(0., 5);       // weak batch effects variance prior
-  batch ~ normal(mu, sigma);   // random effects
-  reference ~ normal(0., 10.); // uninformed reference prior
+  disp ~ normal(0., disp_scale);    // weak overdispersion prior
+  mu ~ normal(0., mu_scale);        // uninformed batch effects mean prior
+  sigma ~ normal(0., sigma_scale);  // weak batch effects variance prior
+  batch ~ normal(mu, sigma);        // random effects
+  reference ~ normal(0., reference_scale);  // uninformed reference prior
   // generating counts
   for (n in 1:N){
     eta[n] = batch[batch_ids[n]] + reference[ref_ids[n]];
@@ -32,3 +37,14 @@ model {
   lam ~ normal(eta, disp);
   y ~ poisson_log(lam + to_vector(depth));
 }
+
+generated quantities {
+  vector[N] y_predict;
+  vector[N] log_lhood;
+  for (n in 1:N){
+    real eta = batch[batch_ids[n]] + reference[ref_ids[n]];
+    real lam = normal_rng(eta, disp)
+    y_predict[n] = poisson_log_rng(lam + depth[n]);
+    log_lhood[n] = poisson_log_lpmf(y[n] | lam + depth[n]);
+  }
+}

diff --git a/q2_batch/plugin_setup.py b/q2_batch/plugin_setup.py
index 94be9e8..5690d98 100644
--- a/q2_batch/plugin_setup.py
+++ b/q2_batch/plugin_setup.py
@@ -6,12 +6,10 @@
 from q2_batch import __version__
 from q2_batch._method import estimate
-from q2_differential._type import FeatureTensor
-from q2_differential._format import FeatureTensorNetCDFFormat, FeatureTensorNetCDFDirFmt
+from q2_types.feature_data import MonteCarloTensor
 from q2_types.feature_table import FeatureTable, Frequency


-
 plugin = qiime2.plugin.Plugin(
     name='batch',
     version=__version__,
@@ -22,6 +20,7 @@
                  ' for downstream plugins'),
     package='q2-batch')

+
 plugin.methods.register_function(
     function=estimate,
     inputs={'counts': FeatureTable[Frequency]},
@@ -32,7 +31,7 @@
         'cores': Int
     },
     outputs=[
-        ('posterior', FeatureTensor)
+        ('posterior', MonteCarloTensor)
     ],
     input_descriptions={
         "counts": "Input table of counts.",

diff --git a/q2_batch/tests/test_batch.py b/q2_batch/tests/test_batch.py
index 209f9ea..ac8dc09 100644
--- a/q2_batch/tests/test_batch.py
+++ b/q2_batch/tests/test_batch.py
@@ -1,5 +1,5 @@
 import unittest
-from q2_batch._batch import _batch_func, _simulate
+from q2_batch._batch import _batch_func, _simulate, PoissonLogNormalBatch


 class TestBatch(unittest.TestCase):
@@ -16,5 +16,21 @@ def test_batch(self):
         self.assertEqual(res.shape, (2000 * 4, 3))


+class TestPoissonLogNormalBatch(TestBatch):
+    def test_batch(self):
+        pln = PoissonLogNormalBatch(
+            table=biom_table,
+            replicate_column="reps",
+            batch_column="batch",
+            metadata=metadata,
+            num_warmup=1000,
+            mu_scale=1,
+            reference_scale=5,
+            chains=1,
+            seed=42)
+        pln.compile_model()
+        pln.fit_model(jobs=4)
+
+
 if __name__ == '__main__':
     unittest.main()

diff --git a/setup.py b/setup.py
index 64d3426..e32e265 100644
--- a/setup.py
+++ b/setup.py
@@ -20,7 +20,7 @@
 """
 classifiers = [s.strip() for s in classes.split('\n') if s]
-description = ('QIIME2 plugin Quick and dirty batch effect correction')
+description = ('QIIME2 plugin for quick and dirty batch effect correction')


 setup(name='q2-batch',
@@ -36,10 +36,10 @@
             'pandas',
             'xarray',
             'matplotlib',
-            'q2-differential',
+            'q2-types',
             'cmdstanpy==0.9.68',
+            'dask',
             # 'dask_jobsqueue', # optional
-            # 'dask'
         ],
     entry_points={
         'qiime2.plugins': ['q2-batch=q2_batch.plugin_setup:plugin']

From 85da3b338cbb39a76c1150d15cc1bae2181df180 Mon Sep 17 00:00:00 2001
From: mortonjt
Date: Mon, 24 May 2021 17:33:23 -0600
Subject: [PATCH 2/6] adding unittest for PLN batch correction

---
 q2_batch/_batch.py                    | 38 ++++++++++++++++++++++++++-
 q2_batch/assets/batch_pln_single.stan |  8 +++---
 q2_batch/tests/test_batch.py          | 12 ++++++---
 3 files changed, 49 insertions(+), 9 deletions(-)

diff --git a/q2_batch/_batch.py b/q2_batch/_batch.py
index 7a840f6..5b2259c 100644
--- a/q2_batch/_batch.py
+++ b/q2_batch/_batch.py
@@ -5,9 +5,11 @@
 import os
 from skbio.stats.composition import ilr_inv
 from cmdstanpy import CmdStanModel
+from dask.distributed import Client, LocalCluster
 from birdman import BaseModel
 import tempfile
 import json
+import time


 def _extract_replicates(replicates, batches):
@@ -145,7 +147,6 @@ def __init__(self,
             'ref_ids': list(map(int, ref_ids)),
             'batch_ids': list(map(int, batch_ids))
         }
-        self.param_names = ["mu", "sigma", "disp", "batch", "reference"]
         param_dict = {
             "mu_scale": mu_scale,
             "sigma_scale": sigma_scale,
@@ -154,6 +155,41 @@ def __init__(self,
         }
         self.add_parameters(param_dict)

+        self.specify_model(
+            params=["mu", "sigma", "disp", "batch", "reference"],
+            dims={
+                "beta": ["covariate", "feature"],
+                "phi": ["feature"],
+                "log_lhood": ["tbl_sample", "feature"],
+                "y_predict": ["tbl_sample", "feature"]
+            },
+            coords={
+                "feature": self.feature_names,
+                "tbl_sample": self.sample_names
+            },
+            include_observed_data=True,
+            posterior_predictive="y_predict",
+            log_likelihood="log_lhood"
+        )
+
+    def fit_model(self, cluster_type: str = 'local',
+                  sampler_args: dict = {},
+                  dask_args: dict = {},
+                  convert_to_inference: bool = False):
+        if cluster_type == 'local':
+            cluster = LocalCluster(**dask_args)
+            cluster.scale(dask_args['n_workers'])
+            client = Client(cluster)
+        elif cluster_type == 'slurm':
+            from dask_jobqueue import SLURMCluster
+            cluster = SLURMCluster(**dask_args)
+            cluster.scale(dask_args['n_workers'])
+            client = Client(cluster)
+            client.wait_for_workers(dask_args['n_workers'])
+            time.sleep(60)
+        super().fit_model(**sampler_args,
+                          convert_to_inference=convert_to_inference)
+

 def _simulate(n=100, d=10, depth=50):
     """ Simulate batch effects from Multinomial distribution

diff --git a/q2_batch/assets/batch_pln_single.stan b/q2_batch/assets/batch_pln_single.stan
index e01e0c5..0465e57 100644
--- a/q2_batch/assets/batch_pln_single.stan
+++ b/q2_batch/assets/batch_pln_single.stan
@@ -42,9 +42,9 @@ generated quantities {
   vector[N] y_predict;
   vector[N] log_lhood;
   for (n in 1:N){
-    real eta = batch[batch_ids[n]] + reference[ref_ids[n]];
-    real lam = normal_rng(eta, disp)
-    y_predict[n] = poisson_log_rng(lam + depth[n]);
-    log_lhood[n] = poisson_log_lpmf(y[n] | lam + depth[n]);
+    real eta_ = batch[batch_ids[n]] + reference[ref_ids[n]];
+    real lam_ = normal_rng(eta_, disp);
+    y_predict[n] = poisson_log_rng(lam_ + depth[n]);
+    log_lhood[n] = poisson_log_lpmf(y[n] | lam_ + depth[n]);
   }
 }

diff --git a/q2_batch/tests/test_batch.py b/q2_batch/tests/test_batch.py
index ac8dc09..e147a3d 100644
--- a/q2_batch/tests/test_batch.py
+++ b/q2_batch/tests/test_batch.py
@@ -1,6 +1,6 @@
 import unittest
 from q2_batch._batch import _batch_func, _simulate, PoissonLogNormalBatch
-
+import biom


 class TestBatch(unittest.TestCase):
@@ -18,18 +18,22 @@ class TestPoissonLogNormalBatch(TestBatch):
     def test_batch(self):
+        table = biom.Table(self.table.values.T,
+                           self.table.columns, self.table.index)
         pln = PoissonLogNormalBatch(
-            table=biom_table,
+            table=table,
             replicate_column="reps",
             batch_column="batch",
-            metadata=metadata,
+            metadata=self.metadata,
             num_warmup=1000,
             mu_scale=1,
             reference_scale=5,
             chains=1,
             seed=42)
         pln.compile_model()
-        pln.fit_model(jobs=4)
+        pln.fit_model(dask_args={'n_workers': 1, 'threads_per_worker': 1})
+        inf = pln.to_inference_object()
+        self.assertEqual(inf['posterior']['mu'].shape, (10, 1, 1000))


 if __name__ == '__main__':
     unittest.main()

From fd4256f8981c21d3dd600b0d634343820bf6fc81 Mon Sep 17 00:00:00 2001
From: mortonjt
Date: Mon, 24 May 2021 17:38:55 -0600
Subject: [PATCH 3/6] adding quick test script

---
 q2_batch/tests/quick_test.py | 33 +++++++++++++++++++++++++++++++++
 1 file changed, 33 insertions(+)
 create mode 100644 q2_batch/tests/quick_test.py

diff --git a/q2_batch/tests/quick_test.py b/q2_batch/tests/quick_test.py
new file mode 100644
index 0000000..efd239d
--- /dev/null
+++ b/q2_batch/tests/quick_test.py
@@ -0,0 +1,33 @@
+from q2_batch._batch import _batch_func, _simulate, PoissonLogNormalBatch
+import biom
+
+table, metadata = _simulate(n=100, d=100, depth=100)
+table = biom.Table(table.values.T,
+                   table.columns, table.index)
+pln = PoissonLogNormalBatch(
+    table=table,
+    replicate_column="reps",
+    batch_column="batch",
+    metadata=metadata,
+    num_warmup=1000,
+    mu_scale=1,
+    reference_scale=5,
+    chains=1,
+    seed=42)
+pln.compile_model()
+# pln.fit_model(dask_args={'n_workers': 1, 'threads_per_worker': 1})
+pln.fit_model(dask_args=dict(
+    cores=4,
+    processes=4,
+    memory='16GB',
+    walltime='10:00:00',
+    interface='ib0',
+    nanny=True,
+    death_timeout='300s',
+    local_directory='/scratch',
+    shebang='#!/usr/bin/env bash',
+    env_extra=["export TBB_CXX_TYPE=gcc"],
+    queue='ccb'
+)
+)
+pln.to_inference_object()

From 483812bb6c590b85f3e236ee32b38d2f4eb77417 Mon Sep 17 00:00:00 2001
From: mortonjt
Date: Mon, 24 May 2021 20:38:51 -0400
Subject: [PATCH 4/6] reshuffling around cluster support

---
 q2_batch/_batch.py           | 19 -------------------
 q2_batch/tests/quick_test.py | 36 +++++++++++++++++++++---------------
 q2_batch/tests/test_batch.py |  7 ++++++-
 3 files changed, 27 insertions(+), 35 deletions(-)

diff --git a/q2_batch/_batch.py b/q2_batch/_batch.py
index 5b2259c..19ccf05 100644
--- a/q2_batch/_batch.py
+++ b/q2_batch/_batch.py
@@ -5,7 +5,6 @@
 import os
 from skbio.stats.composition import ilr_inv
 from cmdstanpy import CmdStanModel
-from dask.distributed import Client, LocalCluster
 from birdman import BaseModel
 import tempfile
 import json
@@ -172,24 +171,6 @@ def __init__(self,
             log_likelihood="log_lhood"
         )

-    def fit_model(self, cluster_type: str = 'local',
-                  sampler_args: dict = {},
-                  dask_args: dict = {},
-                  convert_to_inference: bool = False):
-        if cluster_type == 'local':
-            cluster = LocalCluster(**dask_args)
-            cluster.scale(dask_args['n_workers'])
-            client = Client(cluster)
-        elif cluster_type == 'slurm':
-            from dask_jobqueue import SLURMCluster
-            cluster = SLURMCluster(**dask_args)
-            cluster.scale(dask_args['n_workers'])
-            client = Client(cluster)
-            client.wait_for_workers(dask_args['n_workers'])
-            time.sleep(60)
-        super().fit_model(**sampler_args,
-                          convert_to_inference=convert_to_inference)
-

 def _simulate(n=100, d=10, depth=50):
     """ Simulate batch effects from Multinomial distribution
diff --git a/q2_batch/tests/quick_test.py b/q2_batch/tests/quick_test.py
index efd239d..0ed6e8c 100644
--- a/q2_batch/tests/quick_test.py
+++ b/q2_batch/tests/quick_test.py
@@ -1,7 +1,9 @@
+from dask_jobqueue import SLURMCluster
+from dask.distributed import Client
 from q2_batch._batch import _batch_func, _simulate, PoissonLogNormalBatch
 import biom

-table, metadata = _simulate(n=100, d=100, depth=100)
+table, metadata = _simulate(n=100, d=20, depth=100)
 table = biom.Table(table.values.T,
                    table.columns, table.index)
 pln = PoissonLogNormalBatch(
@@ -14,20 +16,24 @@
     reference_scale=5,
     chains=1,
     seed=42)
+
 pln.compile_model()
 # pln.fit_model(dask_args={'n_workers': 1, 'threads_per_worker': 1})
-pln.fit_model(dask_args=dict(
-    cores=4,
-    processes=4,
-    memory='16GB',
-    walltime='10:00:00',
-    interface='ib0',
-    nanny=True,
-    death_timeout='300s',
-    local_directory='/scratch',
-    shebang='#!/usr/bin/env bash',
-    env_extra=["export TBB_CXX_TYPE=gcc"],
-    queue='ccb'
-)
-)
+jobs = 4
+cluster = SLURMCluster(cores=4,
+                       processes=jobs,
+                       memory='16GB',
+                       walltime='01:00:00',
+                       interface='ib0',
+                       nanny=True,
+                       death_timeout='300s',
+                       local_directory='/scratch',
+                       shebang='#!/usr/bin/env bash',
+                       env_extra=["export TBB_CXX_TYPE=gcc"],
+                       queue='ccb')
+client = Client(cluster)
+cluster.scale(jobs=jobs)
+client.wait_for_workers(jobs)
+
+pln.fit_model(convert_to_inference=True)
 pln.to_inference_object()

diff --git a/q2_batch/tests/test_batch.py b/q2_batch/tests/test_batch.py
index e147a3d..d2cff9d 100644
--- a/q2_batch/tests/test_batch.py
+++ b/q2_batch/tests/test_batch.py
@@ -1,5 +1,6 @@
 import unittest
 from q2_batch._batch import _batch_func, _simulate, PoissonLogNormalBatch
+from dask.distributed import Client, LocalCluster
 import biom


 class TestBatch(unittest.TestCase):
@@ -31,7 +32,11 @@ def test_batch(self):
             chains=1,
             seed=42)
         pln.compile_model()
-        pln.fit_model(dask_args={'n_workers': 1, 'threads_per_worker': 1})
+        dask_args = {'n_workers': 1, 'threads_per_worker': 1}
+        cluster = LocalCluster(**dask_args)
+        cluster.scale(dask_args['n_workers'])
+        client = Client(cluster)
+        pln.fit_model()
         inf = pln.to_inference_object()
         self.assertEqual(inf['posterior']['mu'].shape, (10, 1, 1000))

From 22f0a31c7005ea0213f084d39a3a738b2527fe70 Mon Sep 17 00:00:00 2001
From: mortonjt
Date: Tue, 25 May 2021 22:36:57 -0600
Subject: [PATCH 5/6] basic tests passing

---
 q2_batch/_batch.py                    | 75 ++++-----------------------
 q2_batch/_method.py                   | 70 ++++++++++++-------------
 q2_batch/assets/batch_pln_single.stan |  4 +-
 q2_batch/tests/test_batch.py          | 22 ++++----
 q2_batch/tests/test_method.py         | 17 +++----
 5 files changed, 64 insertions(+), 124 deletions(-)

diff --git a/q2_batch/_batch.py b/q2_batch/_batch.py
index 5b2259c..c016a27 100644
--- a/q2_batch/_batch.py
+++ b/q2_batch/_batch.py
@@ -33,50 +33,6 @@ def _extract_replicates(replicates, batches):
     return ref_ids, replicate_ids, batch_ids


-def _batch_func(counts: np.array, replicates: np.array,
-                batches: np.array, depth: int,
-                mc_samples: int = 1000) -> dict:
-
-    ref_ids, replicate_ids, batch_ids = _extract_replicates(
-        replicates, batches)
-    # Actual stan modeling
-    code = os.path.join(os.path.dirname(__file__),
-                        'assets/batch_pln_single.stan')
-    sm = CmdStanModel(stan_file=code)
-    dat = {
-        'N': len(counts),
-        'R': int(max(replicate_ids) + 1),
-        'B': int(max(batch_ids) + 1),
-        'depth': list(np.log(depth)),
-        'y': list(map(int, counts.astype(np.int64))),
-        'ref_ids': list(map(int, ref_ids)),
-        'batch_ids': list(map(int, batch_ids))
-    }
-    with tempfile.TemporaryDirectory() as temp_dir_name:
-        data_path = os.path.join(temp_dir_name, 'data.json')
-        with open(data_path, 'w') as f:
-            json.dump(dat, f)
-        # Obtain an initial guess with MLE
-        # guess = sm.optimize(data=data_path, inits=0)
-        # see https://mattocci27.github.io/assets/poilog.html
-        # for recommended parameters for poisson log normal
-        fit = sm.sample(data=data_path, iter_sampling=mc_samples,
-                        # inits=guess.optimized_params_dict,
-                        chains=4, iter_warmup=mc_samples // 2,
-                        adapt_delta=0.9, max_treedepth=20)
-        fit.diagnose()
-        mu = fit.stan_variable('mu')
-        sigma = fit.stan_variable('sigma')
-        disp = fit.stan_variable('disp')
-        res = pd.DataFrame({
-            'mu': mu,
-            'sigma': sigma,
-            'disp': disp})
-        # TODO: this doesn't seem to work atm, but its fixed upstream
-        # res = fit.summary()
-        return res
-
-
 class PoissonLogNormalBatch(BaseModel):
     """Fit Batch effects estimator with Poisson Log Normal
@@ -107,7 +63,14 @@ class PoissonLogNormalBatch(BaseModel):
     disp_scale : float
         Standard deviation for prior distribution for disp
-    reference_scale : float
-        Standard deviation for prior distribution for reference
+    reference_loc : float
+        Mean for prior distribution for reference samples
+    reference_scale : float
+        Standard deviation for prior distribution for reference samples
+
+    Notes
+    -----
+    The priors for the reference are defaults for amplicon sequences.
+    See https://github.com/mortonjt/q2-matchmaker/issues/24
     """
     def __init__(self,
                  table: biom.table.Table,
                  batch_column: str,
@@ -123,7 +86,8 @@ def __init__(self,
                  mu_scale: float = 1,
                  sigma_scale: float = 1,
                  disp_scale: float = 1,
-                 reference_scale: float = 10):
+                 reference_loc: float = -5,
+                 reference_scale: float = 3):
         model_path = os.path.join(os.path.dirname(__file__),
                                   'assets/batch_pln_single.stan')
@@ -151,6 +115,7 @@ def __init__(self,
             "mu_scale": mu_scale,
             "sigma_scale": sigma_scale,
             "disp_scale": disp_scale,
+            "reference_loc": reference_loc,
             "reference_scale": reference_scale
         }
         self.add_parameters(param_dict)

diff --git a/q2_batch/_method.py b/q2_batch/_method.py
index 81f925f..ce706f4 100644
--- a/q2_batch/_method.py
+++ b/q2_batch/_method.py
@@ -6,48 +6,48 @@
 from sklearn.preprocessing import LabelEncoder
 import pickle
 import dask
-from q2_batch._batch import _batch_func
+from q2_batch._batch import PoissonLogNormalBatch
+from dask.distributed import Client, LocalCluster
+from gneiss.util import match
 import xarray as xr
 import qiime2


+def _poisson_log_normal_estimate(counts,
+                                 replicates,
+                                 batches,
+                                 monte_carlo_samples,
+                                 cores,
+                                 **sampler_args):
+    metadata = pd.DataFrame({'batch': batches, 'reps': replicates})
+    table, metadata = match(counts, metadata)
+    pln = PoissonLogNormalBatch(
+        table=table,
+        replicate_column="reps",
+        batch_column="batch",
+        metadata=metadata,
+        **sampler_args)
+    pln.compile_model()
+    pln.fit_model()
+    samples = pln.to_inference_object()
+    return samples
+
+
-def estimate(counts : pd.DataFrame,
+def estimate(counts : biom.Table,
              replicates : qiime2.CategoricalMetadataColumn,
              batches : qiime2.CategoricalMetadataColumn,
              monte_carlo_samples : int = 100,
             cores : int = 1) -> xr.Dataset:
-    replicates = replicates.to_series().values
-    batches = batches.to_series().values
-    # TODO: need to speed this up with either joblib or something
-    depth = counts.sum(axis=1)
-    pfunc = lambda x: _batch_func(np.array(x.values), replicates, batches,
-                                  depth, monte_carlo_samples)
-    if cores > 1:
-        try:
-            import dask.dataframe as dd
-            dcounts = dd.from_pandas(counts.T, npartitions=cores)
-            res = dcounts.apply(pfunc, axis=1)
-            resdf = res.compute(scheduler='processes')
-            data_df = list(resdf.values)
-        except:
-            data_df = list(counts.T.apply(pfunc, axis=1).values)
-    else:
-        data_df = list(counts.T.apply(pfunc, axis=1).values)
-    samples = xr.concat([df.to_xarray() for df in data_df], dim="features")
-    samples = samples.assign_coords(coords={
-        'features' : counts.columns,
-        'monte_carlo_samples' : np.arange(monte_carlo_samples)
-    })
+    replicates = replicates.to_series()
+    batches = batches.to_series()
+    # Build me a cluster!
+    dask_args = {'n_workers': cores, 'threads_per_worker': 1}
+    cluster = LocalCluster(**dask_args)
+    cluster.scale(dask_args['n_workers'])
+    client = Client(cluster)
+    samples = _poisson_log_normal_estimate(
+        counts, replicates, batches,
+        monte_carlo_samples, cores)
     return samples
-
-
-def parallel_estimate(counts : pd.DataFrame,
-                      replicates : pd.Series,
-                      batches : pd.Series,
-                      monte_carlo_samples : int = 100,
-                      cores : int = 1) -> xr.Dataset:
-    pass
-
-def phylogenetic_impute():
-    pass

diff --git a/q2_batch/assets/batch_pln_single.stan b/q2_batch/assets/batch_pln_single.stan
index 0465e57..e6765c5 100644
--- a/q2_batch/assets/batch_pln_single.stan
+++ b/q2_batch/assets/batch_pln_single.stan
@@ -10,6 +10,7 @@ data {
   real mu_scale;
   real sigma_scale;
   real disp_scale;
+  real reference_loc;
   real reference_scale;
 }

@@ -29,7 +30,8 @@ model {
   mu ~ normal(0., mu_scale);        // uninformed batch effects mean prior
   sigma ~ normal(0., sigma_scale);  // weak batch effects variance prior
   batch ~ normal(mu, sigma);        // random effects
-  reference ~ normal(0., reference_scale);  // uninformed reference prior
+  // uninformed reference prior
+  reference ~ normal(reference_loc, reference_scale);
   // generating counts
   for (n in 1:N){
     eta[n] = batch[batch_ids[n]] + reference[ref_ids[n]];

diff --git a/q2_batch/tests/test_batch.py b/q2_batch/tests/test_batch.py
index e147a3d..9943d21 100644
--- a/q2_batch/tests/test_batch.py
+++ b/q2_batch/tests/test_batch.py
@@ -1,23 +1,19 @@
 import unittest
-from q2_batch._batch import _batch_func, _simulate, PoissonLogNormalBatch
+from q2_batch._batch import _simulate, PoissonLogNormalBatch
+from dask.distributed import Client, LocalCluster
 import biom


-class TestBatch(unittest.TestCase):
+
+class TestPoissonLogNormalBatch(unittest.TestCase):

     def setUp(self):
         self.table, self.metadata = _simulate(n=100, d=10, depth=50)

     def test_batch(self):
-        res = _batch_func(self.table.values[:, 0],
-                          replicates=self.metadata['reps'].values,
-                          batches=self.metadata['batch'].values,
-                          depth=self.table.sum(axis=1),
-                          mc_samples=2000)
-        self.assertEqual(res.shape, (2000 * 4, 3))
-
-
-class TestPoissonLogNormalBatch(TestBatch):
-    def test_batch(self):
+        dask_args = {'n_workers': 1, 'threads_per_worker': 1}
+        cluster = LocalCluster(**dask_args)
+        cluster.scale(dask_args['n_workers'])
+        client = Client(cluster)
         table = biom.Table(self.table.values.T,
                            self.table.columns, self.table.index)
         pln = PoissonLogNormalBatch(
             table=table,
             replicate_column="reps",
             batch_column="batch",
             metadata=self.metadata,
             num_warmup=1000,
             mu_scale=1,
             reference_scale=5,
             chains=1,
             seed=42)
         pln.compile_model()
-        pln.fit_model(dask_args={'n_workers': 1, 'threads_per_worker': 1})
+        pln.fit_model()
         inf = pln.to_inference_object()
         self.assertEqual(inf['posterior']['mu'].shape, (10, 1, 1000))

diff --git a/q2_batch/tests/test_method.py b/q2_batch/tests/test_method.py
index d4e1f97..1d8d975 100644
--- a/q2_batch/tests/test_method.py
+++ b/q2_batch/tests/test_method.py
@@ -3,12 +3,16 @@
 from q2_batch._batch import _simulate
 import qiime2
 import xarray as xr
+from birdman.diagnostics import r2_score
+from biom import Table


 class TestBatchEstimation(unittest.TestCase):

     def setUp(self):
         self.table, self.metadata = _simulate(n=50, d=4, depth=30)
+        self.table = Table(self.table.values.T,
+                           list(self.table.columns),
+                           list(self.table.index))

     def test_batch(self):
         res = estimate(
@@ -19,18 +23,9 @@ def test_batch(self):
             cores=1
         )
         self.assertTrue(res is not None)
-        self.assertTrue(isinstance(res, xr.Dataset))
+        res = r2_score(res)
+        self.assertGreater(res['r2'], 0.3)

-    def test_batch_dask(self):
-        res = estimate(
-            self.table,
-            replicates=qiime2.CategoricalMetadataColumn(self.metadata['reps']),
-            batches=qiime2.CategoricalMetadataColumn(self.metadata['batch']),
-            monte_carlo_samples=100,
-            cores=4
-        )
-        self.assertTrue(res is not None)
-        self.assertTrue(isinstance(res, xr.Dataset))


 if __name__ == '__main__':
     unittest.main()

From aeb2fdd9770c463866968d79bda4e12bbec5348c Mon Sep 17 00:00:00 2001
From: mortonjt
Date: Tue, 25 May 2021 22:37:30 -0600
Subject: [PATCH 6/6] some doc updates

---
 q2_batch/plugin_setup.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/q2_batch/plugin_setup.py b/q2_batch/plugin_setup.py
index 5690d98..1ec069f 100644
--- a/q2_batch/plugin_setup.py
+++ b/q2_batch/plugin_setup.py
@@ -48,7 +48,7 @@
         ),
         'cores' : 'Number of cpu cores'
     },
-    name='estimation',
-    description=("Computes batch effects from technical replicates"),
+    name='Batch effect estimation',
+    description=("Computes batch effects from technical replicates."),
     citations=[]
 )
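
Usage note (editor's addition, not part of the patch series): after PATCH 5, PoissonLogNormalBatch is the supported entry point, and a dask client must be registered before fit_model() is called. The sketch below strings the pieces together the same way q2_batch/tests/test_batch.py does; the simulated table, the "reps"/"batch" column names, and the single-worker LocalCluster are assumptions lifted from that test rather than a documented API.

    # Hypothetical end-to-end driver mirroring the post-PATCH-5 unit test.
    import biom
    from dask.distributed import Client, LocalCluster
    from q2_batch._batch import _simulate, PoissonLogNormalBatch

    # Simulate n=100 samples x d=10 features, with technical replicates
    # ("reps") spread across batches ("batch").
    table, metadata = _simulate(n=100, d=10, depth=50)
    table = biom.Table(table.values.T, table.columns, table.index)

    # birdman parallelizes across features, so the dask cluster has to
    # exist before sampling starts (see PATCH 4/5).
    cluster = LocalCluster(n_workers=1, threads_per_worker=1)
    client = Client(cluster)

    pln = PoissonLogNormalBatch(
        table=table,
        replicate_column="reps",
        batch_column="batch",
        metadata=metadata,
        num_warmup=1000,
        mu_scale=1,
        reference_scale=5,
        chains=1,
        seed=42)
    pln.compile_model()
    pln.fit_model()

    # Posterior draws per feature; in the unit test
    # inf['posterior']['mu'] has shape (features, chains, draws) = (10, 1, 1000).
    inf = pln.to_inference_object()

If this reflects the intended workflow, it may be worth promoting into the README, since estimate() currently hides the cluster setup inside _method.py.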