Skip to content

Commit

Permalink
Merge pull request #319 from widdowquinn/issue_175
Browse files Browse the repository at this point in the history
Issue 175: Plot identity vs coverage scatter plot
  • Loading branch information
widdowquinn authored Sep 8, 2021
2 parents 9435a5e + 30e5c5f commit c4aa68b
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 3 deletions.
48 changes: 48 additions & 0 deletions pyani/pyani_graphics/mpl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,3 +358,51 @@ def heatmap(dfr, outfilename=None, title=None, params=None):
if outfilename:
fig.savefig(outfilename)
return fig


def scatter(
dfr1,
dfr2,
outfilename=None,
matname1="identity",
matname2="coverage",
title=None,
params=None,
):
"""Return matplotlib scatterplot.
:param dfr1: pandas DataFrame with x-axis data
:param dfr2: pandas DataFrame with y-axis data
:param outfilename: path to output file (indicates output format)
:param matname1: name of x-axis data
:param matname2: name of y-axis data
:param title: title for the plot
:param params: a list of parameters for plotting: [colormap, vmin, vmax]
"""
# Make an empty dataframe to collect the input data in
combined = pd.DataFrame()

# Add data
combined[matname1] = dfr1.values.flatten()
combined[matname2] = dfr2.values.flatten()

# Add lable information, if available
# if params.labels:
# hue = "labels"
# combined['labels'] = # add labels to dataframe; unsure of their configuration at this point
# else:
hue = None

fig, ax = plt.subplots(figsize=(8, 8))
fig.suptitle(title)
ax.set_xlabel(f"{matname1.title()}")
ax.set_ylabel(f"{matname2.title()}")

plt.scatter(matname1, matname2, data=combined, c=hue, s=2)

# Return figure output, and write, if required
plt.subplots_adjust(top=0.85) # Leave room for title
fig.set_tight_layout(True)
if outfilename:
fig.savefig(outfilename)
return fig
53 changes: 53 additions & 0 deletions pyani/pyani_graphics/sns/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,56 @@ def distribution(dfr, outfilename, matname, title=None):
fig.savefig(outfilename)

return fig


def scatter(
dfr1,
dfr2,
outfilename=None,
matname1="identity",
matname2="coverage",
title=None,
params=None,
):
"""Return seaborn scatterplot.
:param dfr1: pandas DataFrame with x-axis data
:param dfr2: pandas DataFrame with y-axis data
:param outfilename: path to output file (indicates output format)
:param matname1: name of x-axis data
:param matname2: name of y-axis data
:param title: title for the plot
:param params: a list of parameters for plotting: [colormap, vmin, vmax]
"""
# Make an empty dataframe to collect the input data in
combined = pd.DataFrame()

# Add data
combined[matname1] = dfr1.values.flatten()
combined[matname2] = dfr2.values.flatten()

# Add lable information, if available
# if params.labels:
# hue = "labels"
# combined['labels'] = # add labels to dataframe; unsure of their configuration at this point
# else:
hue = None

# Create the plot
fig = sns.lmplot(
x=matname1,
y=matname2,
data=combined,
hue=hue,
fit_reg=False,
scatter_kws={"s": 2},
)
fig.set(xlabel=matname1.title(), ylabel=matname2.title())
plt.title(title)

# Save to file
if outfilename:
fig.savefig(outfilename)

# Return clustermap
return fig
65 changes: 62 additions & 3 deletions pyani/scripts/subcommands/subcmd_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@
from pyani import pyani_config, pyani_orm, pyani_graphics
from pyani.pyani_tools import termcolor, MatrixData


# Distribution dictionary of matrix graphics methods
GMETHODS = {"mpl": pyani_graphics.mpl.heatmap, "seaborn": pyani_graphics.sns.heatmap}
SMETHODS = {"mpl": pyani_graphics.mpl.scatter, "seaborn": pyani_graphics.sns.scatter}
# Distribution dictionary of distribution graphics methods
DISTMETHODS = {
"mpl": pyani_graphics.mpl.distribution,
Expand Down Expand Up @@ -114,7 +114,9 @@ def write_run_heatmaps(
)
result_label_dict = pyani_orm.get_matrix_labels_for_run(session, args.run_id)
result_class_dict = pyani_orm.get_matrix_classes_for_run(session, args.run_id)
logger.debug(f"Have {len(result_label_dict)} labels and {len(result_class_dict)} classes")
logger.debug(
f"Have {len(result_label_dict)} labels and {len(result_class_dict)} classes"
)

# Write heatmap for each results matrix
for matdata in [
Expand All @@ -131,10 +133,22 @@ def write_run_heatmaps(
run_id, matdata, result_label_dict, result_class_dict, outfmts, args
)
write_distribution(run_id, matdata, outfmts, args)
write_scatter(
run_id,
MatrixData("identity", pd.read_json(results.df_identity), {}),
MatrixData("coverage", pd.read_json(results.df_coverage), {}),
result_label_dict,
result_class_dict,
outfmts,
args,
)


def write_distribution(
run_id: int, matdata: MatrixData, outfmts: List[str], args: Namespace,
run_id: int,
matdata: MatrixData,
outfmts: List[str],
args: Namespace,
) -> None:
"""Write distribution plots for each matrix type.
Expand Down Expand Up @@ -192,3 +206,48 @@ def write_heatmap(

# Be tidy with matplotlib caches
plt.close("all")


def write_scatter(
run_id: int,
matdata1: MatrixData,
matdata2: MatrixData,
result_labels: Dict,
result_classes: Dict,
outfmts: List[str],
args: Namespace,
) -> None:
"""Write a single scatterplot for a pyani run.
:param run_id: int, run_id for this run
:param matdata1: MatrixData object for this scatterplot
:param matdata2: MatrixData object for this scatterplot
:param result_labels: dict of result labels
:param result_classes: dict of result classes
:param args: Namespace for command-line arguments
:param outfmts: list of output formats for files
"""
logger = logging.getLogger(__name__)

logger.info("Writing %s vs %s scatterplot", matdata1.name, matdata2.name)
cmap = pyani_config.get_colormap(matdata1.data, matdata1.name)
for fmt in outfmts:
outfname = (
Path(args.outdir)
/ f"scatter_{matdata1.name}_vs_{matdata2.name}_run{run_id}.{fmt}"
)
logger.debug("\tWriting graphics to %s", outfname)
params = pyani_graphics.Params(cmap, result_labels, result_classes)
# Draw scatterplot
SMETHODS[args.method](
matdata1.data,
matdata2.data,
outfname,
matdata1.name,
matdata2.name,
title=f"{matdata1.name.title()} vs {matdata2.name.title()}",
params=params,
)

# Be tidy with matplotlib caches
plt.close("all")
10 changes: 10 additions & 0 deletions tests/test_graphics.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def draw_format_method(fmt, mth, graphics_inputs, tmp_path):
"""Render graphics format and method output."""
df = pd.read_csv(graphics_inputs.filename, index_col=0, sep="\t")
fn = {"mpl": pyani_graphics.mpl.heatmap, "seaborn": pyani_graphics.sns.heatmap}
sc = {"mpl": pyani_graphics.mpl.scatter, "seaborn": pyani_graphics.sns.scatter}
params = {"mpl": pyani_config.params_mpl, "seaborn": pyani_config.params_mpl}
method_params = pyani_graphics.Params(
params[mth](df)["ANIm_percentage_identity"],
Expand All @@ -90,6 +91,15 @@ def draw_format_method(fmt, mth, graphics_inputs, tmp_path):
fn[mth](
df, tmp_path / f"{mth}.{fmt}", title=f"{mth}:{fmt} test", params=method_params
)
sc[mth](
df,
df,
tmp_path / f"{mth}.{fmt}",
"matrix1",
"matrix2",
title=f"{mth}:{fmt} test",
params=method_params,
)


def test_png_mpl(graphics_inputs, tmp_path):
Expand Down

0 comments on commit c4aa68b

Please sign in to comment.