Showing 7 changed files with 627 additions and 0 deletions.
@@ -0,0 +1,169 @@
#!/usr/bin/env python
# coding: utf-8

import sys
import numpy
import pandas
import argparse


def denoiseTable(data, threshold_ratio=0.1, min_low_samples=3, min_sample_ratio=0.15, min_otu_counts=1000, min_candidates=10, max_cross_index=0.02):
    """
    Denoise an OTU table (remove putative Illumina cross-talk counts).
    Parameters: OTU table (DataFrame), threshold_ratio=0.1, min_low_samples=3, min_sample_ratio=0.15, min_otu_counts=1000, min_candidates=10, max_cross_index=0.02
    Output: cleaned OTU table (DataFrame)
    """
    numOtus, numSamples = data.shape

    verbose(' - Total OTUs: {}\n - Total samples: {}'.format(numOtus, numSamples))

    if numSamples * min_sample_ratio > min_low_samples:
        min_low_samples = int(numSamples * min_sample_ratio)
    verbose(' - Min cross talk samples: {}'.format(min_low_samples))

    verbose(' - Scanning OTU table to estimate cross talk index')

    # Cells with a small, non-zero count relative to the OTU mean are cross-talk candidates
    otu_means = data.mean(axis=1)
    low_cells = (0 < data) & data.le(threshold_ratio * otu_means, axis=0)
    num_samples_low = low_cells.sum(axis=1)
    tot_samples_low = (data * low_cells).sum(axis=1)
    row_tot = data.sum(axis=1)

    # Per-OTU cross-talk rate: fraction of reads falling in low cells, scaled by the
    # number of samples per low sample (UNCROSS-style estimate)
    cross_index = (tot_samples_low / row_tot) * (numSamples / num_samples_low)
    # Candidate OTUs: enough low samples, enough total reads, plausible cross-talk rate
    candidates = (num_samples_low > min_low_samples) & (row_tot > min_otu_counts) & (cross_index <= max_cross_index)
    if candidates.sum() < min_candidates:
        verbose('Not enough OTU candidates to estimate cross talk')
        return data

    cross_talk_median = cross_index.loc[candidates].median()

    verbose('Median cross talk: {}'.format(cross_talk_median))

    # Expected cross-talk count per cell for each OTU
    Zi = cross_talk_median * row_tot / data.shape[1]

    # Score each cell: t approaches 1 for counts well below the expected cross-talk
    # and approaches 0 for counts well above it
    dividedData = data.divide(Zi, axis=0)
    t = 2 / (1 + numpy.exp(dividedData.clip(upper=100)))

    # Zero the cells flagged as likely cross-talk, keep the rest
    denoised_data = data.where(t < threshold_ratio, 0)

    return denoised_data
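# Usage sketch (illustrative, not part of the original script): denoiseTable() expects a
# pandas DataFrame with OTUs as rows and samples as columns, and calls verbose(), which
# reads the module-level `opt`; stub it before calling the function outside the CLI.
# The file names below are hypothetical.
#
#   import argparse, pandas
#   opt = argparse.Namespace(verbose=False, debug=False)
#   otu = pandas.read_csv('otutab.tsv', sep='\t', header=0, index_col=0)
#   cleaned = denoiseTable(otu, threshold_ratio=0.1, min_otu_counts=1000)
#   cleaned.to_csv('otutab.cleaned.tsv', sep='\t')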


def compareTables(firstTable, secondTable):
    """Compare two data frames: return the total absolute difference multiplied by the number of differing cells"""
    deltaTable = firstTable - secondTable
    return deltaTable.abs().sum().sum() * (deltaTable != 0).sum().sum()

def denoiseTableWithReference(x):
    """Receive a vector of parameters (threshold_ratio, min_low_samples, min_sample_ratio, max_cross_index)
    and return the distance between the cleaned table and the module-level referenceTable"""
    cleanedTable = denoiseTable(data, threshold_ratio=x[0], min_low_samples=int(x[1]), min_sample_ratio=x[2], max_cross_index=x[3])
    return compareTables(cleanedTable, referenceTable)
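# Sketch of minimising the objective above with SciPy instead of the grid search in
# __main__ (assumes `data` and `referenceTable` are already loaded as module-level
# DataFrames; Nelder-Mead accepts bounds only in recent SciPy releases):
#
#   from scipy.optimize import minimize
#   x0 = numpy.array([0.1, 3, 0.15, 0.02])
#   res = minimize(denoiseTableWithReference, x0, method='Nelder-Mead',
#                  bounds=[(0, 2), (1.0, 6.0), (0.05, 0.3), (0.01, 0.05)])
#   print(res.x)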


def eprint(*args, **kwargs):
    """print to STDERR"""
    print(*args, file=sys.stderr, **kwargs)


def verbose(message):
    """Print a verbose message (if --verbose is enabled)"""
    if opt.verbose:
        eprint(message)


def debug(message):
    """Print a debug message prepending # (requires --debug enabled)"""
    if opt.debug:
        eprint('#{}'.format(message))


if __name__ == '__main__':

    opt_parser = argparse.ArgumentParser(description='Denoise Illumina cross-talk from OTU tables')

    opt_parser.add_argument('-i', '--input',
                            help='OTU table filename',
                            required=True)

    opt_parser.add_argument('-o', '--output',
                            help='Cleaned OTU table filename')

    opt_parser.add_argument('-v', '--verbose',
                            help='Print extra information',
                            action='store_true')

    opt_parser.add_argument('-d', '--debug',
                            help='Print debug information',
                            action='store_true')

    opt_parser.add_argument('--version', action='version', version='%(prog)s 1.0')

    opt = opt_parser.parse_args()
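    # Typical invocation (sketch; the script file name is hypothetical):
    #   python crosstalk_denoise.py -i otu_table.tsv -o otu_table.cleaned.tsv --verbose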

    # Import OTU table in "Qiime Classic format"
    try:
        data = pandas.read_csv(opt.input, sep='\t', header=0, index_col=0)
    except Exception as e:
        eprint("FATAL ERROR: Unable to open {}. {}".format(opt.input, e))
        exit(1)

    # Parameter grid search against a reference table. This branch is disabled
    # (if False) and expects a reference-table option that the parser above does not define.
    if False:
        referenceTable = pandas.read_csv(opt.reference, sep='\t', header=0, index_col=0)
        x0 = numpy.array([0.1, 3, 0.15, 0.02])
        bounds = ([0, 2], [1.0, 6.0], [0.05, 0.3], [0.01, 0.05])
        #res = minimize(denoiseTableWithReference, x0, method='COBYLA', bounds=bounds, options={'disp': True})
        #print(res)
        c = 0
        min_r = None
        max_r = None
        for threshold in numpy.arange(0, 2, 0.1):
            for lowsamples in numpy.arange(1.0, 5.0, 0.5):
                for min_sample_ratio in numpy.arange(0.01, 0.30, 0.025):
                    for max_cross_index in numpy.arange(0.1, 0.5, 0.05):
                        c += 1
                        x = numpy.array([threshold, lowsamples, min_sample_ratio, max_cross_index])
                        r = denoiseTableWithReference(x)
                        if min_r is None:
                            min_r = r
                            max_r = r
                            print('*{} {}: thr={} lowsamples={} minsampleratio={} maxcross={}'.format(
                                c, r, threshold, lowsamples, min_sample_ratio, max_cross_index))

                        if min_r > r:
                            min_r = r
                            print('<{} {}: thr={} lowsamples={} minsampleratio={} maxcross={}'.format(
                                c, r, threshold, lowsamples, min_sample_ratio, max_cross_index))

                        if max_r < r:
                            max_r = r
                            print('>{} {}: thr={} lowsamples={} minsampleratio={} maxcross={}'.format(
                                c, r, threshold, lowsamples, min_sample_ratio, max_cross_index))

    else:
        cleanedTable = denoiseTable(data)
        cleanedTable.to_csv(opt.input + '.cleaned' if opt.output is None else opt.output, sep='\t')

    exit()
@@ -0,0 +1,204 @@
#!/usr/bin/env python
"""
USAGE: transpose.py [options] Table.tsv

INPUT (percentages, but this is not assumed):
Family                  NHP6    NHP18   NHP15   NHP11   All
"Prevotellaceae"        19.5    0.0106  58.4    41.4    24.4
(Unassigned)            14.1    3.69    5.64    11.4    21.5
Ruminococcaceae         14.8    0       7.33    12      15.2
Lachnospiraceae         31      0.201   11.7    14.5    14.6
Veillonellaceae         5.16    0.0053  3.77    7.02    4.0
"Porphyromonadaceae"    0.469   18.6    0.587   0.6     2.6

Import a TSV table (taxonomy summary by USEARCH) with samples in columns and OTUs/ASVs in rows:
- use the first column as index key (feature name),
- remove rows where the sum is below a threshold,
- transpose (samples as rows),
- sort rows by name (sample name).
"""
import numpy as np
import pandas
import sys
import argparse
#from IPython import embed
import matplotlib.pyplot as plt
import matplotlib
from scipy.cluster.hierarchy import dendrogram, linkage

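# Sketch of the table-preparation steps described in the module docstring (the input
# file name and the 1.0 threshold are illustrative; the __main__ block below reads the
# table and plots it without applying these filters):
#
#   t = pandas.read_csv('taxonomy_summary.tsv', sep='\t', index_col=0)
#   t = t.drop(columns='All', errors='ignore')   # drop the "All" summary column
#   t = t[t.sum(axis=1) >= 1.0]                  # remove rows below a threshold
#   t = t.transpose().sort_index()               # samples as rows, sorted by name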
def heatmap(data, row_labels, col_labels, ax=None,
            cbar_kw={}, cbarlabel="", **kwargs):
    """
    Create a heatmap from a numpy array and two lists of labels.

    Parameters
    ----------
    data
        A 2D numpy array of shape (N, M).
    row_labels
        A list or array of length N with the labels for the rows.
    col_labels
        A list or array of length M with the labels for the columns.
    ax
        A `matplotlib.axes.Axes` instance to which the heatmap is plotted. If
        not provided, use current axes or create a new one. Optional.
    cbar_kw
        A dictionary with arguments to `matplotlib.Figure.colorbar`. Optional.
    cbarlabel
        The label for the colorbar. Optional.
    **kwargs
        All other arguments are forwarded to `imshow`.
    """

    if not ax:
        ax = plt.gca()

    # Plot the heatmap
    im = ax.imshow(data, **kwargs)

    # Create colorbar
    cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
    cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    # We want to show all ticks...
    ax.set_xticks(np.arange(data.shape[1]))
    ax.set_yticks(np.arange(data.shape[0]))
    # ... and label them with the respective list entries.
    ax.set_xticklabels(col_labels)
    ax.set_yticklabels(row_labels)

    # Let the horizontal axes labeling appear on top.
    ax.tick_params(top=True, bottom=False,
                   labeltop=True, labelbottom=False)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=-30, ha="right",
             rotation_mode="anchor")

    # Turn spines off and create white grid.
    for edge, spine in ax.spines.items():
        spine.set_visible(False)

    ax.set_xticks(np.arange(data.shape[1]+1)-.5, minor=True)
    ax.set_yticks(np.arange(data.shape[0]+1)-.5, minor=True)
    ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
    ax.tick_params(which="minor", bottom=False, left=False)

    return im, cbar
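# Minimal usage sketch for heatmap() on random data (labels and colormap are illustrative):
#
#   fig, ax = plt.subplots()
#   values = np.random.rand(3, 4)
#   im, cbar = heatmap(values, ['r1', 'r2', 'r3'], ['c1', 'c2', 'c3', 'c4'],
#                      ax=ax, cmap='YlGn', cbarlabel='abundance')
#   plt.show()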


def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
                     textcolors=["black", "white"],
                     threshold=None, **textkw):
    """
    A function to annotate a heatmap.

    Parameters
    ----------
    im
        The AxesImage to be labeled.
    data
        Data used to annotate. If None, the image's data is used. Optional.
    valfmt
        The format of the annotations inside the heatmap. This should either
        use the string format method, e.g. "$ {x:.2f}", or be a
        `matplotlib.ticker.Formatter`. Optional.
    textcolors
        A list or array of two color specifications. The first is used for
        values below a threshold, the second for those above. Optional.
    threshold
        Value in data units according to which the colors from textcolors are
        applied. If None (the default) uses the middle of the colormap as
        separation. Optional.
    **kwargs
        All other arguments are forwarded to each call to `text` used to create
        the text labels.
    """

    if not isinstance(data, (list, np.ndarray)):
        data = im.get_array()

    # Normalize the threshold to the image's color range.
    if threshold is not None:
        threshold = im.norm(threshold)
    else:
        threshold = im.norm(data.max())/2.

    # Set default alignment to center, but allow it to be
    # overwritten by textkw.
    kw = dict(horizontalalignment="center",
              verticalalignment="center")
    kw.update(textkw)

    # Get the formatter in case a string is supplied
    if isinstance(valfmt, str):
        valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)

    # Loop over the data and create a `Text` for each "pixel".
    # Change the text's color depending on the data.
    texts = []
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
            text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
            texts.append(text)

    return texts


def eprint(*args, **kwargs):
    print(*args, file=sys.stderr, **kwargs)


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Transpose table for MultiQC')
    parser.add_argument('TABLE', help='Input file name')
    parser.add_argument('-o', "--output", help='Output file name', default="plot.png")
    parser.add_argument('-w', "--width", help='Plot width (inches)', default=9)
    parser.add_argument("--height", help='Plot height (inches)', default=5)
    args = parser.parse_args()

    try:
        # Import TSV, use the first column as index
        # note: to set row names a posteriori: set_index(list(table)[0]), or by name: set_index('Column_name')
        table = pandas.read_csv(args.TABLE, delimiter='\t', encoding='utf-8', index_col=0)
        eprint(f" * Imported {args.TABLE}: {table.shape}")
    except Exception as e:
        eprint(f"Error trying to import {args.TABLE}:\n{e}")
        exit(1)

    try:
        plt.figure(figsize=(float(args.width), float(args.height)))
        plt.xticks(rotation=90)
        linked = linkage(table.transpose(), 'single')
        labels = table.keys().tolist()
        dendrogram(linked,
                   orientation='left',
                   distance_sort='descending',
                   show_leaf_counts=True,
                   labels=labels,
                   leaf_rotation=0,
                   truncate_mode='level',
                   leaf_font_size=8)
        plt.tight_layout()
        plt.savefig(args.output + '_dendrogram.png', bbox_inches='tight')
    except Exception as errorMessage:
        eprint(f"Error generating dendrogram:\n{errorMessage}")

    try:
        plt.clf()
        # Create the figure and axes in one call so the requested size is actually used
        fig, ax = plt.subplots(figsize=(40, 20))
        im, cbar = heatmap(table.transpose(), table.columns.values, table.index.values, ax=ax,
                           cmap="YlGn", cbarlabel="xxx")

        texts = annotate_heatmap(im, valfmt="{x:.1f}")

        fig.tight_layout()
        # Save before showing, so the file is written even if the interactive window is closed
        plt.savefig(args.output + '_heatmap.png')
        plt.show()

    except Exception as errorMessage:
        eprint(f"Skipping heatmap:\n{errorMessage}")
Oops, something went wrong.