Skip to content

Commit

Permalink
Merge pull request #346 from widdowquinn/issue_292
Browse files Browse the repository at this point in the history
Issue 292: Parallelise plot creation using `multiprocessing`
  • Loading branch information
baileythegreen authored Nov 30, 2021
2 parents 9ae767f + 122f6d4 commit 4dd4f83
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 16 deletions.
9 changes: 9 additions & 0 deletions pyani/scripts/parsers/plot_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,4 +104,13 @@ def build(
help="graphics method to use for plotting",
choices=["seaborn", "mpl", "plotly"],
)
parser.add_argument(
"--workers",
dest="workers",
action="store",
default=None,
type=int,
help="Number of worker processes for multiprocessing "
"(default None, meaning use all available cores)",
)
parser.set_defaults(func=subcommands.subcmd_plot)
56 changes: 40 additions & 16 deletions pyani/scripts/subcommands/subcmd_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@

import logging
import os
import multiprocessing

from argparse import Namespace
from pathlib import Path
Expand Down Expand Up @@ -90,14 +91,12 @@ def subcmd_plot(args: Namespace) -> int:
run_ids = [int(run) for run in args.run_id.split(",")]
logger.debug("Generating graphics for runs: %s", run_ids)
for run_id in run_ids:
write_run_heatmaps(run_id, session, outfmts, args)
write_run_plots(run_id, session, outfmts, args)

return 0


def write_run_heatmaps(
run_id: int, session, outfmts: List[str], args: Namespace
) -> None:
def write_run_plots(run_id: int, session, outfmts: List[str], args: Namespace) -> None:
"""Write all plots (heatmaps, distributions and scatter plot) for a specified run to file.
:param run_id: int, run identifier in database session
Expand All @@ -118,7 +117,13 @@ def write_run_heatmaps(
f"Have {len(result_label_dict)} labels and {len(result_class_dict)} classes"
)

# Write heatmap for each results matrix
# Write heatmap and distribution plot for each results matrix

# Create worker pool and empty command list
pool = multiprocessing.Pool(processes=args.workers)
plotting_commands = []

# Build and collect the plotting commands
for matdata in [
MatrixData(*_)
for _ in [
Expand All @@ -129,20 +134,39 @@ def write_run_heatmaps(
("hadamard", pd.read_json(results.df_hadamard), {}),
]
]:
write_heatmap(
run_id, matdata, result_label_dict, result_class_dict, outfmts, args
plotting_commands.append(
(
write_heatmap,
[run_id, matdata, result_label_dict, result_class_dict, outfmts, args],
)
)
plotting_commands.append((write_distribution, [run_id, matdata, outfmts, args]))

id_matrix = MatrixData("identity", pd.read_json(results.df_identity), {})
cov_matrix = MatrixData("coverage", pd.read_json(results.df_coverage), {})
plotting_commands.append(
(
write_scatter,
[
run_id,
id_matrix,
cov_matrix,
result_label_dict,
result_class_dict,
outfmts,
args,
],
)
write_distribution(run_id, matdata, outfmts, args)
write_scatter(
run_id,
MatrixData("identity", pd.read_json(results.df_identity), {}),
MatrixData("coverage", pd.read_json(results.df_coverage), {}),
result_label_dict,
result_class_dict,
outfmts,
args,
)

# Run the plotting commands
for func, options in plotting_commands:
pool.apply_async(func, options, {})

# Close worker pool
pool.close()
pool.join()


def write_distribution(
run_id: int,
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
biopython
ete3
matplotlib
namedlist
networkx
numpy
openpyxl
pandas
Pillow
PyQt5
scipy
seaborn
sqlalchemy==1.3.10
Expand Down
8 changes: 8 additions & 0 deletions tests/test_subcmd_06_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,55 +80,63 @@ def setUp(self):
dbpath=self.dbpath,
formats="pdf",
method="mpl",
workers=None,
),
"mpl_png": Namespace(
outdir=self.outdir / "mpl",
run_id=self.run_id,
dbpath=self.dbpath,
formats="png",
method="mpl",
workers=None,
),
"mpl_svg": Namespace(
outdir=self.outdir / "mpl",
run_id=self.run_id,
dbpath=self.dbpath,
formats="svg",
method="mpl",
workers=None,
),
"mpl_jpg": Namespace(
outdir=self.outdir / "mpl",
run_id=self.run_id,
dbpath=self.dbpath,
formats="jpg",
method="mpl",
workers=None,
),
"seaborn_pdf": Namespace(
outdir=self.outdir / "seaborn",
run_id=self.run_id,
dbpath=self.dbpath,
formats="pdf",
method="seaborn",
workers=None,
),
"seaborn_png": Namespace(
outdir=self.outdir / "seaborn",
run_id=self.run_id,
dbpath=self.dbpath,
formats="png",
method="seaborn",
workers=None,
),
"seaborn_svg": Namespace(
outdir=self.outdir / "seaborn",
run_id=self.run_id,
dbpath=self.dbpath,
formats="svg",
method="seaborn",
workers=None,
),
"seaborn_jpg": Namespace(
outdir=self.outdir / "seaborn",
run_id=self.run_id,
dbpath=self.dbpath,
formats="jpg",
method="seaborn",
workers=None,
),
}

Expand Down

0 comments on commit 4dd4f83

Please sign in to comment.