Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 175: Plot identity vs coverage scatter plot #319

Merged
merged 9 commits into from
Sep 8, 2021
48 changes: 48 additions & 0 deletions pyani/pyani_graphics/mpl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,3 +358,51 @@ def heatmap(dfr, outfilename=None, title=None, params=None):
if outfilename:
fig.savefig(outfilename)
return fig


def scatter(
dfr1,
dfr2,
outfilename=None,
matname1="identity",
matname2="coverage",
title=None,
params=None,
):
"""Return matplotlib scatterplot.

:param dfr1: pandas DataFrame with x-axis data
:param dfr2: pandas DataFrame with y-axis data
:param outfilename: path to output file (indicates output format)
:param matname1: name of x-axis data
:param matname2: name of y-axis data
:param title: title for the plot
:param params: a list of parameters for plotting: [colormap, vmin, vmax]
"""
# Make an empty dataframe to collect the input data in
combined = pd.DataFrame()

# Add data
combined[matname1] = dfr1.values.flatten()
combined[matname2] = dfr2.values.flatten()

# Add label information, if available
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that label info will end up being incorporated here. Labels apply to single genomes, not pairwise comparisons. Could just delete this commented-out section for now. We can discuss ideas for colouring points on the plot elsewhere.

# if params.labels:
# hue = "labels"
# combined['labels'] = # add labels to dataframe; unsure of their configuration at this point
# else:
hue = None

fig, ax = plt.subplots(figsize=(8, 8))
fig.suptitle(title)
ax.set_xlabel(f"{matname1.title()}")
ax.set_ylabel(f"{matname2.title()}")

plt.scatter(matname1, matname2, data=combined, c=hue, s=2)

# Return figure output, and write, if required
plt.subplots_adjust(top=0.85) # Leave room for title
fig.set_tight_layout(True)
if outfilename:
fig.savefig(outfilename)
return fig
53 changes: 53 additions & 0 deletions pyani/pyani_graphics/sns/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,56 @@ def distribution(dfr, outfilename, matname, title=None):
fig.savefig(outfilename)

return fig


def scatter(
    dfr1,
    dfr2,
    outfilename=None,
    matname1="identity",
    matname2="coverage",
    title=None,
    params=None,
):
    """Return seaborn scatterplot of one matrix's values against another's.

    Both input dataframes are flattened and paired up element by element,
    so each point on the plot is one cell of dfr1 (x) against the cell at
    the same position in dfr2 (y).

    :param dfr1:  pandas DataFrame with x-axis data
    :param dfr2:  pandas DataFrame with y-axis data
    :param outfilename:  path to output file (indicates output format);
        nothing is written when this is falsy
    :param matname1:  name of x-axis data (column name and axis label)
    :param matname2:  name of y-axis data (column name and axis label)
    :param title:  title for the plot
    :param params:  plotting parameters; currently unused by this function

    Returns the object produced by sns.lmplot().
    """
    # Collect the flattened matrices as two columns of a single dataframe
    x_values = dfr1.values.flatten()
    y_values = dfr2.values.flatten()
    paired = pd.DataFrame({matname1: x_values, matname2: y_values})

    # Draw the points only: no regression fit, and no hue grouping
    grid = sns.lmplot(
        x=matname1,
        y=matname2,
        data=paired,
        hue=None,
        fit_reg=False,
        scatter_kws={"s": 2},
    )
    grid.set(xlabel=matname1.title(), ylabel=matname2.title())
    plt.title(title)

    # Write to file, if a path was given
    if outfilename:
        grid.savefig(outfilename)

    # Return the lmplot result
    return grid
65 changes: 62 additions & 3 deletions pyani/scripts/subcommands/subcmd_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@
from pyani import pyani_config, pyani_orm, pyani_graphics
from pyani.pyani_tools import termcolor, MatrixData


# Distribution dictionary of matrix graphics methods
GMETHODS = {"mpl": pyani_graphics.mpl.heatmap, "seaborn": pyani_graphics.sns.heatmap}
SMETHODS = {"mpl": pyani_graphics.mpl.scatter, "seaborn": pyani_graphics.sns.scatter}
# Distribution dictionary of distribution graphics methods
DISTMETHODS = {
"mpl": pyani_graphics.mpl.distribution,
Expand Down Expand Up @@ -114,7 +114,9 @@ def write_run_heatmaps(
)
result_label_dict = pyani_orm.get_matrix_labels_for_run(session, args.run_id)
result_class_dict = pyani_orm.get_matrix_classes_for_run(session, args.run_id)
logger.debug(f"Have {len(result_label_dict)} labels and {len(result_class_dict)} classes")
logger.debug(
f"Have {len(result_label_dict)} labels and {len(result_class_dict)} classes"
)

# Write heatmap for each results matrix
for matdata in [
Expand All @@ -131,10 +133,22 @@ def write_run_heatmaps(
run_id, matdata, result_label_dict, result_class_dict, outfmts, args
)
write_distribution(run_id, matdata, outfmts, args)
write_scatter(
run_id,
MatrixData("identity", pd.read_json(results.df_identity), {}),
MatrixData("coverage", pd.read_json(results.df_coverage), {}),
result_label_dict,
result_class_dict,
outfmts,
args,
)


def write_distribution(
run_id: int, matdata: MatrixData, outfmts: List[str], args: Namespace,
run_id: int,
matdata: MatrixData,
outfmts: List[str],
args: Namespace,
) -> None:
"""Write distribution plots for each matrix type.

Expand Down Expand Up @@ -192,3 +206,48 @@ def write_heatmap(

# Be tidy with matplotlib caches
plt.close("all")


def write_scatter(
    run_id: int,
    matdata1: MatrixData,
    matdata2: MatrixData,
    result_labels: Dict,
    result_classes: Dict,
    outfmts: List[str],
    args: Namespace,
) -> None:
    """Write one scatterplot per output format for a pyani run.

    :param run_id:  int, run_id for this run
    :param matdata1:  MatrixData object providing the x-axis data and name
    :param matdata2:  MatrixData object providing the y-axis data and name
    :param result_labels:  dict of result labels
    :param result_classes:  dict of result classes
    :param outfmts:  list of output formats for files
    :param args:  Namespace for command-line arguments (outdir and method
        are read here)
    """
    logger = logging.getLogger(__name__)
    logger.info("Writing %s vs %s scatterplot", matdata1.name, matdata2.name)

    # The colormap is derived from the first matrix only
    cmap = pyani_config.get_colormap(matdata1.data, matdata1.name)

    for fmt in outfmts:
        target_dir = Path(args.outdir)
        outfname = (
            target_dir
            / f"scatter_{matdata1.name}_vs_{matdata2.name}_run{run_id}.{fmt}"
        )
        logger.debug("\tWriting graphics to %s", outfname)
        params = pyani_graphics.Params(cmap, result_labels, result_classes)

        # Pick the plotting backend requested on the command line and draw
        draw_scatter = SMETHODS[args.method]
        plot_title = f"{matdata1.name.title()} vs {matdata2.name.title()}"
        draw_scatter(
            matdata1.data,
            matdata2.data,
            outfname,
            matdata1.name,
            matdata2.name,
            title=plot_title,
            params=params,
        )

    # Be tidy with matplotlib caches
    plt.close("all")
10 changes: 10 additions & 0 deletions tests/test_graphics.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def draw_format_method(fmt, mth, graphics_inputs, tmp_path):
"""Render graphics format and method output."""
df = pd.read_csv(graphics_inputs.filename, index_col=0, sep="\t")
fn = {"mpl": pyani_graphics.mpl.heatmap, "seaborn": pyani_graphics.sns.heatmap}
sc = {"mpl": pyani_graphics.mpl.scatter, "seaborn": pyani_graphics.sns.scatter}
params = {"mpl": pyani_config.params_mpl, "seaborn": pyani_config.params_mpl}
method_params = pyani_graphics.Params(
params[mth](df)["ANIm_percentage_identity"],
Expand All @@ -90,6 +91,15 @@ def draw_format_method(fmt, mth, graphics_inputs, tmp_path):
fn[mth](
df, tmp_path / f"{mth}.{fmt}", title=f"{mth}:{fmt} test", params=method_params
)
sc[mth](
df,
df,
tmp_path / f"{mth}.{fmt}",
"matrix1",
"matrix2",
title=f"{mth}:{fmt} test",
params=method_params,
)


def test_png_mpl(graphics_inputs, tmp_path):
Expand Down