
Commit

Merge pull request #13 from mortonjt/birdman
WIP : Birdman dependency
mortonjt authored May 26, 2021
2 parents 161e1d4 + 7a47000 commit f44a193
Showing 8 changed files with 260 additions and 188 deletions.
156 changes: 114 additions & 42 deletions q2_batch/_batch.py
@@ -1,22 +1,23 @@
import argparse
from biom import load_table
import biom
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.preprocessing import LabelEncoder
import pickle
import os
from skbio.stats.composition import ilr_inv
import matplotlib.pyplot as plt
import pickle
from cmdstanpy import CmdStanModel
from birdman import BaseModel
import tempfile
import json
import time


def _extract_replicates(replicates, batches):
replicate_encoder = LabelEncoder()
replicate_encoder.fit(replicates)
replicate_ids = replicate_encoder.transform(replicates)
@@ -31,44 +32,115 @@ def _batch_func(counts : np.array, replicates : np.array,
batch_encoder = LabelEncoder()
batch_encoder.fit(batches)
batch_ids = batch_encoder.transform(batches)
# Actual stan modeling
code = os.path.join(os.path.dirname(__file__),
'assets/batch_pln_single.stan')

batch_ids = batch_ids.astype(np.int64) + 1
ref_ids = ref_ids.astype(np.int64) + 1
sm = CmdStanModel(stan_file=code)
dat = {
'N' : counts.shape[0],
'R' : int(max(ref_ids) + 1),
'B' : int(max(batch_ids) + 1),
'depth' : list(np.log(depth)),
'y' : list(map(int, counts.astype(np.int64))),
'ref_ids' : list(map(int, ref_ids)),
'batch_ids' : list(map(int, batch_ids))
}
with tempfile.TemporaryDirectory() as temp_dir_name:
data_path = os.path.join(temp_dir_name, 'data.json')
with open(data_path, 'w') as f:
json.dump(dat, f)
# Obtain an initial guess with MLE
# guess = sm.optimize(data=data_path, inits=0)
# see https://mattocci27.github.io/assets/poilog.html
# for recommended parameters for poisson log normal
fit = sm.sample(data=data_path, iter_sampling=mc_samples,
# inits=guess.optimized_params_dict,
chains=chains, iter_warmup=mc_samples // 2,
adapt_delta=0.9, max_treedepth=20)
fit.diagnose()
mu = fit.stan_variable('mu')
sigma = fit.stan_variable('sigma')
disp = fit.stan_variable('disp')
res = pd.DataFrame({
'mu': mu,
'sigma': sigma,
'disp': disp})
# TODO: this doesn't seem to work at the moment, but it's fixed upstream
# res = fit.summary()
return res
return ref_ids, replicate_ids, batch_ids
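
# Illustrative sketch (not part of this commit): the replicate and batch labels
# are label-encoded into the integer ids consumed by the Stan model; the ref_ids
# construction lives in the collapsed lines above. Assuming two replicates spread
# across two batches:
#
#     reps = np.array(['sampleA', 'sampleA', 'sampleB', 'sampleB'])
#     bats = np.array(['batch1', 'batch2', 'batch1', 'batch2'])
#     ref_ids, replicate_ids, batch_ids = _extract_replicates(reps, bats)
#     # replicate_ids -> array([0, 0, 1, 1]), batch_ids -> array([0, 1, 0, 1])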


class PoissonLogNormalBatch(BaseModel):
"""Fit Batch effects estimator with Poisson Log Normal
Parameters
----------
table: biom.table.Table
Feature table (features x samples)
batch_column : str
Column that specifies `batches` of interest that
cause technical artifacts.
replicate_column : str
Column that specifies technical replicates that
are spread across batches.
metadata: pd.DataFrame
Metadata for matching and status covariates.
num_iter: int
Number of posterior sample draws, defaults to 1000
num_warmup: int
Number of posterior draws used for warmup, defaults to 500
chains: int
Number of chains to use in MCMC, defaults to 4
seed: float
Random seed to use for sampling, defaults to 42
mu_scale : float
Standard deviation for prior distribution for mu
sigma_scale : float
Standard deviation for prior distribution for sigma
disp_scale : float
Standard deviation for prior distribution for disp
reference_loc : float
    Mean for prior distribution for reference samples
reference_scale : float
    Standard deviation for prior distribution for reference samples
Notes
-----
The priors for the reference are defaults for amplicon sequences
See https://github.com/mortonjt/q2-matchmaker/issues/24
"""
def __init__(self,
table: biom.table.Table,
batch_column: str,
replicate_column: str,
metadata: pd.DataFrame,
num_iter: int = 1000,
num_warmup: int = 500,
adapt_delta: float = 0.9,
max_treedepth: int = 20,
chains: int = 4,
seed: float = 42,
mu_scale: float = 1,
sigma_scale: float = 1,
disp_scale: float = 1,
reference_loc: float = -5,
reference_scale: float = 3):
model_path = os.path.join(os.path.dirname(__file__),
'assets/batch_pln_single.stan')
super(PoissonLogNormalBatch, self).__init__(
table, metadata, model_path,
num_iter, num_warmup, chains, seed,
parallelize_across="features")
# assemble replicate and batch ids
metadata = metadata.loc[table.ids()]
depth = table.sum(axis='sample')
replicates = metadata[replicate_column]
batches = metadata[batch_column]
ref_ids, replicate_ids, batch_ids = _extract_replicates(
replicates, batches)
self.dat = {
'D': table.shape[0], # number of features
'N': table.shape[1], # number of samples
'R': int(max(replicate_ids) + 1),
'B': int(max(batch_ids) + 1),
'depth': list(np.log(depth)),
"y": table.matrix_data.todense().T.astype(int),
'ref_ids': list(map(int, ref_ids)),
'batch_ids': list(map(int, batch_ids))
}
param_dict = {
"mu_scale": mu_scale,
"sigma_scale": sigma_scale,
"disp_scale": disp_scale,
"reference_loc": reference_loc,
"reference_scale": reference_scale
}
self.add_parameters(param_dict)

self.specify_model(
params=["mu", "sigma", "disp", "batch", "reference"],
dims={
"beta": ["covariate", "feature"],
"phi": ["feature"],
"log_lhood": ["tbl_sample", "feature"],
"y_predict": ["tbl_sample", "feature"]
},
coords={
"feature": self.feature_names,
"tbl_sample": self.sample_names
},
include_observed_data=True,
posterior_predictive="y_predict",
log_likelihood="log_lhood"
)
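
# Usage sketch (not part of this commit), mirroring how q2_batch._method drives
# this estimator; `table` (a biom.Table) and `metadata` (a pd.DataFrame indexed
# by sample id with 'reps' and 'batch' columns) are assumed to be defined:
#
#     pln = PoissonLogNormalBatch(table=table,
#                                 replicate_column="reps",
#                                 batch_column="batch",
#                                 metadata=metadata,
#                                 num_iter=1000)
#     pln.compile_model()
#     pln.fit_model(dask_args={'n_workers': 1, 'threads_per_worker': 1})
#     inference = pln.to_inference_object()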


def _simulate(n=100, d=10, depth=50):
Expand Down
150 changes: 32 additions & 118 deletions q2_batch/_method.py
@@ -5,134 +5,48 @@
import seaborn as sns
from sklearn.preprocessing import LabelEncoder
import pickle
import pystan
import dask
from q2_batch._batch import _batch_func
from q2_batch._batch import PoissonLogNormalBatch
from dask.distributed import Client, LocalCluster
from gneiss.util import match
import xarray as xr
import qiime2


# slow estimator
def estimate(counts : pd.DataFrame,
def _poisson_log_normal_estimate(counts,
replicates,
batches,
monte_carlo_samples,
cores,
**sampler_args):
metadata = pd.DataFrame({'batch': batches, 'reps': replicates})
table, metadata = match(counts, metadata)
pln = PoissonLogNormalBatch(
table=table,
replicate_column="reps",
batch_column="batch",
metadata=metadata,
**sampler_args)
pln.compile_model()
pln.fit_model(dask_args={'n_workers': cores, 'threads_per_worker': 1})
samples = pln.to_inference_object()
return samples
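
# Downstream inspection sketch (not part of this commit): to_inference_object()
# is expected to yield an arviz InferenceData whose groups follow the
# specify_model() call in q2_batch._batch (posterior, posterior_predictive,
# log_likelihood, observed_data). Assuming that holds, posterior draws for the
# batch parameters could be summarized roughly as:
#
#     inference = _poisson_log_normal_estimate(counts, replicates, batches,
#                                              monte_carlo_samples=100, cores=1)
#     mu_mean = inference.posterior['mu'].mean(dim=('chain', 'draw'))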


def estimate(counts : biom.Table,
replicates : qiime2.CategoricalMetadataColumn,
batches : qiime2.CategoricalMetadataColumn,
monte_carlo_samples : int = 100,
cores : int = 1) -> xr.Dataset:
# match everything up
replicates = replicates.to_series()
batches = batches.to_series()
idx = list(set(counts.index) & set(replicates.index) & set(batches.index))
counts, replicates, batches = [x.loc[idx] for x in
(counts, replicates, batches)]
replicates, batches = replicates.values, batches.values
depth = counts.sum(axis=1)
pfunc = lambda x: _batch_func(np.array(x.values), replicates, batches,
depth, monte_carlo_samples)
if cores > 1:
try:
import dask.dataframe as dd
dcounts = dd.from_pandas(counts.T, npartitions=cores)
res = dcounts.apply(pfunc, axis=1)
resdf = res.compute(scheduler='processes')
data_df = list(resdf.values)
except Exception:
data_df = list(counts.T.apply(pfunc, axis=1).values)
else:
data_df = list(counts.T.apply(pfunc, axis=1).values)
samples = xr.concat([df.to_xarray() for df in data_df], dim="features")
samples = samples.assign_coords(coords={
'features' : counts.columns,
'monte_carlo_samples' : np.arange(monte_carlo_samples)
})
return samples


# Parallel estimation of batch effects
def slurm_estimate(counts : pd.DataFrame,
batches : qiime2.CategoricalMetadataColumn,
replicates : qiime2.CategoricalMetadataColumn,
monte_carlo_samples : int,
cores : int = 4,
processes : int = 4,
nodes : int = 2,
memory : str = '16GB',
walltime : str = '01:00:00',
queue : str = '') -> xr.Dataset:
from dask_jobqueue import SLURMCluster
from dask.distributed import Client
import dask.dataframe as dd
import logging
logging.basicConfig(format='%(levelname)s:%(message)s',
level=logging.DEBUG)
cluster = SLURMCluster(cores=cores,
processes=processes,
memory=memory,
walltime=walltime,
interface='ib0',
env_extra=["export TBB_CXX_TYPE=gcc"],
queue=queue)
cluster.scale(jobs=nodes)
print(cluster.job_script())
# Build me a cluster!
dask_args={'n_workers': cores, 'threads_per_worker': 1}
cluster = LocalCluster(**dask_args)
cluster.scale(dask_args['n_workers'])
client = Client(cluster)
# match everything up
replicates = replicates.to_series()
batches = batches.to_series()
idx = list(set(counts.index) & set(replicates.index) & set(batches.index))
counts, replicates, batches = [x.loc[idx] for x in
(counts, replicates, batches)]
replicates, batches = replicates.values, batches.values
depth = counts.sum(axis=1)
pfunc = lambda x: _batch_func(np.array(x.values), replicates, batches,
depth, monte_carlo_samples)
print('Partitions', cores * nodes * processes * 4)
dcounts = dd.from_pandas(
counts.T, npartitions=cores * nodes * processes * 4)
dcounts = client.persist(dcounts)
res = dcounts.apply(pfunc, axis=1)
resdf = res.compute(scheduler='processes')
data_df = list(resdf.values)
samples = xr.concat([df.to_xarray() for df in data_df], dim="features")
samples = samples.assign_coords(coords={
'features' : counts.columns,
'monte_carlo_samples' : np.arange(monte_carlo_samples)
})
return samples


# Parallel estimation of batch effects
def parallel_estimate(counts : pd.DataFrame,
batches : qiime2.CategoricalMetadataColumn,
replicates : qiime2.CategoricalMetadataColumn,
monte_carlo_samples : int,
scheduler_json : str,
partitions : int = 100) -> xr.Dataset:
from dask.distributed import Client
import dask.dataframe as dd
import logging
logging.basicConfig(format='%(levelname)s:%(message)s',
level=logging.DEBUG)
client = Client(scheduler_file=scheduler_json)

# match everything up
replicates = replicates.to_series()
batches = batches.to_series()
idx = list(set(counts.index) & set(replicates.index) & set(batches.index))
counts, replicates, batches = [x.loc[idx] for x in
(counts, replicates, batches)]
replicates, batches = replicates.values, batches.values
depth = counts.sum(axis=1)
pfunc = lambda x: _batch_func(np.array(x.values), replicates, batches,
depth, monte_carlo_samples, chains=1)
dcounts = dd.from_pandas(counts.T, npartitions=partitions)
# dcounts = client.persist(dcounts)
res = dcounts.apply(pfunc, axis=1)
# resdf = client.compute(res)
resdf = res.compute(scheduler='processes')
data_df = list(resdf.values)

samples = xr.concat([df.to_xarray() for df in data_df], dim="features")
samples = samples.assign_coords(coords={
'features' : counts.columns,
'monte_carlo_samples' : np.arange(monte_carlo_samples)
})
samples = _poisson_log_normal_estimate(
counts, replicates, batches,
monte_carlo_samples, cores,
**sampler_args)
return samples
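
A minimal end-to-end sketch of the new estimate entry point follows (not part of this commit); the file names and the metadata column names ('replicate', 'batch') are placeholders and assumed to exist in the sample metadata:

import biom
import qiime2
from q2_batch._method import estimate

table = biom.load_table('table.biom')
md = qiime2.Metadata.load('metadata.tsv')
posterior = estimate(table,
                     replicates=md.get_column('replicate'),
                     batches=md.get_column('batch'),
                     monte_carlo_samples=100,
                     cores=1)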