Merge pull request #11 from mortonjt/distributed
WIP: Distributed
mortonjt authored Apr 15, 2021
2 parents c821ee1 + 5e3e525 commit 161e1d4
Showing 7 changed files with 266 additions and 23 deletions.
4 changes: 2 additions & 2 deletions q2_batch/_batch.py
@@ -16,7 +16,7 @@

def _batch_func(counts : np.array, replicates : np.array,
batches : np.array, depth : int,
mc_samples : int=1000) -> dict:
mc_samples : int=1000, chains=4) -> dict:
replicate_encoder = LabelEncoder()
replicate_encoder.fit(replicates)
replicate_ids = replicate_encoder.transform(replicates)
@@ -56,7 +56,7 @@ def _batch_func(counts : np.array, replicates : np.array,
# for recommended parameters for poisson log normal
fit = sm.sample(data=data_path, iter_sampling=mc_samples,
# inits=guess.optimized_params_dict,
chains=4, iter_warmup=mc_samples // 2,
chains=chains, iter_warmup=mc_samples // 2,
adapt_delta = 0.9, max_treedepth = 20)
fit.diagnose()
mu = fit.stan_variable('mu')
92 changes: 86 additions & 6 deletions q2_batch/_method.py
@@ -47,12 +47,92 @@ def estimate(counts : pd.DataFrame,
return samples


# Parallel estimation of batch effects
def slurm_estimate(counts : pd.DataFrame,
batches : qiime2.CategoricalMetadataColumn,
replicates : qiime2.CategoricalMetadataColumn,
monte_carlo_samples : int,
cores : int = 4,
processes : int = 4,
nodes : int = 2,
memory : str = '16GB',
walltime : str = '01:00:00',
queue : str = '') -> xr.Dataset:
from dask_jobqueue import SLURMCluster
from dask.distributed import Client
import dask.dataframe as dd
import logging
logging.basicConfig(format='%(levelname)s:%(message)s',
level=logging.DEBUG)
cluster = SLURMCluster(cores=cores,
processes=processes,
memory=memory,
walltime=walltime,
interface='ib0',
env_extra=["export TBB_CXX_TYPE=gcc"],
queue=queue)
cluster.scale(jobs=nodes)
print(cluster.job_script())
client = Client(cluster)
# match everything up
replicates = replicates.to_series()
batches = batches.to_series()
idx = list(set(counts.index) & set(replicates.index) & set(batches.index))
counts, replicates, batches = [x.loc[idx] for x in
(counts, replicates, batches)]
replicates, batches = replicates.values, batches.values
depth = counts.sum(axis=1)
pfunc = lambda x: _batch_func(np.array(x.values), replicates, batches,
depth, monte_carlo_samples)
print('Partitions', cores * nodes * processes * 4)
dcounts = dd.from_pandas(
counts.T, npartitions=cores * nodes * processes * 4)
dcounts = client.persist(dcounts)
res = dcounts.apply(pfunc, axis=1)
resdf = res.compute(scheduler='processes')
data_df = list(resdf.values)
samples = xr.concat([df.to_xarray() for df in data_df], dim="features")
samples = samples.assign_coords(coords={
'features' : counts.columns,
'monte_carlo_samples' : np.arange(monte_carlo_samples)
})
return samples


# Parallel estimation of batch effects
def parallel_estimate(counts : pd.DataFrame,
replicate_column : qiime2.CategoricalMetadataColumn,
batch_column : qiime2.CategoricalMetadataColumn,
monte_carlo_samples : int,
cores=16,
memory='16 GB',
processes=4):
from dask_jobqueue import SLURMCluster
def parallel_estimate(counts : pd.DataFrame,
batches : qiime2.CategoricalMetadataColumn,
replicates : qiime2.CategoricalMetadataColumn,
monte_carlo_samples : int,
scheduler_json : str,
partitions : int = 100) -> xr.Dataset:
from dask.distributed import Client
import dask.dataframe as dd
import logging
logging.basicConfig(format='%(levelname)s:%(message)s',
level=logging.DEBUG)
client = Client(scheduler_file=scheduler_json)

# match everything up
replicates = replicates.to_series()
batches = batches.to_series()
idx = list(set(counts.index) & set(replicates.index) & set(batches.index))
counts, replicates, batches = [x.loc[idx] for x in
(counts, replicates, batches)]
replicates, batches = replicates.values, batches.values
depth = counts.sum(axis=1)
pfunc = lambda x: _batch_func(np.array(x.values), replicates, batches,
depth, monte_carlo_samples, chains=1)
dcounts = dd.from_pandas(counts.T, npartitions=partitions)
# dcounts = client.persist(dcounts)
res = dcounts.apply(pfunc, axis=1)
# resdf = client.compute(res)
resdf = res.compute(scheduler='processes')
data_df = list(resdf.values)

samples = xr.concat([df.to_xarray() for df in data_df], dim="features")
samples = samples.assign_coords(coords={
'features' : counts.columns,
'monte_carlo_samples' : np.arange(monte_carlo_samples)
})
return samples
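
For context: parallel_estimate only needs the path to a dask scheduler file. Below is a minimal sketch, not part of this commit, of how that file could be produced and the method driven in-process. It assumes a local dask cluster (in practice the file would typically come from a dask-scheduler --scheduler-file process running on the cluster), and the file names and toy data are illustrative only.

import numpy as np
import pandas as pd
import qiime2
from dask.distributed import Client, LocalCluster
from q2_batch._method import parallel_estimate

# Stand-in for a real scheduler started elsewhere; writes the scheduler file
# that is passed as scheduler_json (the path name is an assumption).
cluster = LocalCluster(n_workers=2, threads_per_worker=1)
client = Client(cluster)
client.write_scheduler_file('scheduler.json')

# Toy inputs: 4 samples x 3 features, two batches with paired technical replicates.
counts = pd.DataFrame(np.random.poisson(5, size=(4, 3)),
                      index=['s1', 's2', 's3', 's4'],
                      columns=['f1', 'f2', 'f3'])
metadata = pd.DataFrame({'batch': ['a', 'a', 'b', 'b'],
                         'reps': ['r1', 'r2', 'r1', 'r2']},
                        index=pd.Index(['s1', 's2', 's3', 's4'], name='sampleid'))

posterior = parallel_estimate(
    counts=counts,
    batches=qiime2.CategoricalMetadataColumn(metadata['batch']),
    replicates=qiime2.CategoricalMetadataColumn(metadata['reps']),
    monte_carlo_samples=100,
    scheduler_json='scheduler.json',
    partitions=3)
print(posterior)
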
4 changes: 2 additions & 2 deletions q2_batch/assets/batch_pln_single.stan
@@ -21,8 +21,8 @@ model {
vector[N] eta;
// setting priors ...
disp ~ normal(0., 5); // weak overdispersion prior
mu ~ normal(0., 10.); // uninformed batch effects mean prior
sigma ~ normal(0., 5); // weak batch effects variance prior
mu ~ normal(0., 1.); // strong batch effects mean prior
sigma ~ normal(0., 1.); // strong batch effects variance prior
batch ~ normal(mu, sigma); // random effects
reference ~ normal(0., 10.); // uninformed reference prior
// generating counts
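
Read as a hierarchical prior, this change tightens the hyperpriors on the batch random effect. A sketch in math form, based only on the lines visible above (Stan's normal(m, s) is parameterized by a standard deviation; any positivity constraint on sigma is declared outside this hunk):

$\mu \sim \mathcal{N}(0,\,10^2) \to \mathcal{N}(0,\,1^2)$, $\quad \sigma \sim \mathcal{N}(0,\,5^2) \to \mathcal{N}(0,\,1^2)$, $\quad \mathrm{batch}_j \sim \mathcal{N}(\mu,\,\sigma^2)$

so the per-batch effects are shrunk more strongly toward a shared mean near zero.
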
81 changes: 79 additions & 2 deletions q2_batch/plugin_setup.py
@@ -5,13 +5,12 @@
MetadataColumn, Categorical)

from q2_batch import __version__
from q2_batch._method import estimate
from q2_batch._method import estimate, slurm_estimate, parallel_estimate
from q2_differential._type import FeatureTensor
from q2_differential._format import FeatureTensorNetCDFFormat, FeatureTensorNetCDFDirFmt
from q2_types.feature_table import FeatureTable, Frequency



plugin = qiime2.plugin.Plugin(
name='batch',
version=__version__,
@@ -53,3 +52,81 @@
description=("Computes batch effects from technical replicates"),
citations=[]
)


plugin.methods.register_function(
function=slurm_estimate,
inputs={'counts': FeatureTable[Frequency]},
parameters={
'batches': MetadataColumn[Categorical],
'replicates': MetadataColumn[Categorical],
'monte_carlo_samples': Int,
'cores': Int,
'processes': Int,
'nodes': Int,
'memory': Str,
'walltime': Str,
'queue': Str
},
outputs=[
('posterior', FeatureTensor)
],
input_descriptions={
"counts": "Input table of counts.",
},
output_descriptions={
'posterior': ('Output posterior distribution of batch effect'),
},
parameter_descriptions={
'batches': ('Specifies the batch ids'),
'replicates': ('Specifies the technical replicates.'),
'monte_carlo_samples': (
'Number of monte carlo samples to draw from '
'posterior distribution.'
),
'cores' : 'Number of cpu cores per process',
'processes' : 'Number of processes',
'nodes' : 'Number of nodes',
'memory' : "Amount of memory per process (default: '16GB')",
'walltime' : "Amount of time to spend on each worker (default: '01:00:00')",
'queue' : "Processing queue"
},
name='parallel estimation on slurm',
description=("Computes batch effects from technical replicates on a slurm cluster"),
citations=[]
)
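
A hedged sketch of how the action registered above could be invoked from the QIIME 2 Artifact API (the plugin is named 'batch' in this file; the artifact and metadata file names, column names, and SLURM queue below are placeholders, not part of this commit):

import qiime2
from qiime2.plugins.batch.methods import slurm_estimate

table = qiime2.Artifact.load('table.qza')        # FeatureTable[Frequency]
metadata = qiime2.Metadata.load('metadata.tsv')

# Parameters mirror the registration above; values are illustrative.
posterior, = slurm_estimate(
    counts=table,
    batches=metadata.get_column('batch'),
    replicates=metadata.get_column('reps'),
    monte_carlo_samples=100,
    cores=4, processes=4, nodes=2,
    memory='16GB', walltime='01:00:00', queue='normal')
posterior.save('posterior.qza')
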


plugin.methods.register_function(
function=parallel_estimate,
inputs={'counts': FeatureTable[Frequency]},
parameters={
'batches': MetadataColumn[Categorical],
'replicates': MetadataColumn[Categorical],
'monte_carlo_samples': Int,
'scheduler_json': Str,
'partitions': Int,
},
outputs=[
('posterior', FeatureTensor)
],
input_descriptions={
"counts": "Input table of counts.",
},
output_descriptions={
'posterior': ('Output posterior distribution of batch effect'),
},
parameter_descriptions={
'batches': ('Specifies the batch ids'),
'replicates': ('Specifies the technical replicates.'),
'monte_carlo_samples': (
'Number of monte carlo samples to draw from '
'posterior distribution.'
),
'scheduler_json' : 'Scheduler details in json format.',
'partitions' : 'Number of partitions to segment data.'
},
name='parallel estimation',
description=("Computes batch effects from technical replicates on a cluster"),
citations=[]
)
11 changes: 0 additions & 11 deletions q2_batch/tests/test_method.py
@@ -21,17 +21,6 @@ def test_batch(self):
self.assertTrue(res is not None)
self.assertTrue(isinstance(res, xr.Dataset))

def test_batch_dask(self):
res = estimate(
self.table,
replicates=qiime2.CategoricalMetadataColumn(self.metadata['reps']),
batches=qiime2.CategoricalMetadataColumn(self.metadata['batch']),
monte_carlo_samples=100,
cores=4
)
self.assertTrue(res is not None)
self.assertTrue(isinstance(res, xr.Dataset))


if __name__ == '__main__':
unittest.main()
96 changes: 96 additions & 0 deletions scripts/batch_estimate_slurm.py
@@ -0,0 +1,96 @@
import qiime2
import argparse
from dask_jobqueue import SLURMCluster
from dask.distributed import Client
import dask
import dask.dataframe as dd
import dask.array as da
from biom import load_table
import pandas as pd
import numpy as np
import xarray as xr
from q2_batch._batch import _batch_func
import time
import logging
logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.DEBUG)

parser = argparse.ArgumentParser()
parser.add_argument(
'--biom-table', help='Biom table of counts.', required=True)
parser.add_argument(
'--metadata-file', help='Sample metadata file.', required=True)
parser.add_argument(
'--batches', help='Column specifying batches.', required=True)
parser.add_argument(
'--replicates', help='Column specifying replicates.', required=True)
parser.add_argument(
'--monte-carlo-samples', help='Number of monte carlo samples.',
type=int, required=False, default=1000)
parser.add_argument(
'--cores', help='Number of cores per process.', type=int, required=False, default=1)
parser.add_argument(
'--processes', help='Number of processes per node.', type=int, required=False, default=1)
parser.add_argument(
'--nodes', help='Number of nodes.', type=int, required=False, default=1)
parser.add_argument(
'--memory', help='Memory allocation size.', type=str, required=False, default='16GB')
parser.add_argument(
'--walltime', help='Walltime.', type=str, required=False, default='01:00:00')
parser.add_argument(
'--interface', help='Interface for communication', type=str, required=False, default='eth0')
parser.add_argument(
'--queue', help='Queue to submit job to.', type=str, required=True)
parser.add_argument(
'--output-tensor', help='Output tensor.', type=str, required=True)

args = parser.parse_args()
print(args)
cluster = SLURMCluster(cores=args.cores,
processes=args.processes,
memory=args.memory,
walltime=args.walltime,
interface=args.interface,
nanny=True,
death_timeout='15s',
local_directory='/tmp',
shebang='#!/usr/bin/env bash',
env_extra=["export TBB_CXX_TYPE=gcc"],
queue=args.queue)
print(cluster.job_script())
cluster.scale(jobs=args.nodes)
client = Client(cluster)
print(client)
client.wait_for_workers(args.nodes)
time.sleep(15)
print(cluster.scheduler.workers)
table = load_table(args.biom_table)
counts = pd.DataFrame(np.array(table.matrix_data.todense()).T,
index=table.ids(),
columns=table.ids(axis='observation'))
metadata = pd.read_table(args.metadata_file, index_col=0)
replicates = metadata[args.replicates]
batches = metadata[args.batches]
# match everything up
idx = list(set(counts.index) & set(replicates.index) & set(batches.index))
counts, replicates, batches = [x.loc[idx] for x in
(counts, replicates, batches)]
replicates, batches = replicates.values, batches.values
depth = counts.sum(axis=1)
pfunc = lambda x: _batch_func(x, replicates, batches,
depth, args.monte_carlo_samples, chains=1)
dcounts = da.from_array(counts.values.T, chunks=(counts.T.shape))
print('Dimensions', counts.shape, dcounts.shape, len(counts.columns))

res = []
for d in range(dcounts.shape[0]):
r = dask.delayed(pfunc)(dcounts[d])
res.append(r)
print('Res length', len(res))
futures = dask.persist(*res)
resdf = dask.compute(futures)
data_df = list(resdf[0])
samples = xr.concat([df.to_xarray() for df in data_df], dim="features")
samples = samples.assign_coords(coords={
'features' : table.ids(axis='observation'),
'monte_carlo_samples' : np.arange(args.monte_carlo_samples)})
samples.to_netcdf(args.output_tensor)
1 change: 1 addition & 0 deletions setup.py
@@ -47,5 +47,6 @@
package_data={
"q2_batch": ['assets/batch_nb_single.stan'],
},
scripts=glob('scripts/*'),
classifiers=classifiers,
)
