diff --git a/pyani/scripts/parsers/plot_parser.py b/pyani/scripts/parsers/plot_parser.py index 1de228e3..c5a4da6c 100644 --- a/pyani/scripts/parsers/plot_parser.py +++ b/pyani/scripts/parsers/plot_parser.py @@ -104,4 +104,13 @@ def build( help="graphics method to use for plotting", choices=["seaborn", "mpl", "plotly"], ) + parser.add_argument( + "--workers", + dest="workers", + action="store", + default=None, + type=int, + help="Number of worker processes for multiprocessing " + "(default zero, meaning use all available cores)", + ) parser.set_defaults(func=subcommands.subcmd_plot) diff --git a/pyani/scripts/subcommands/subcmd_plot.py b/pyani/scripts/subcommands/subcmd_plot.py index 8c3d9887..dbd156cd 100644 --- a/pyani/scripts/subcommands/subcmd_plot.py +++ b/pyani/scripts/subcommands/subcmd_plot.py @@ -41,6 +41,7 @@ import logging import os +import multiprocessing from argparse import Namespace from pathlib import Path @@ -90,14 +91,12 @@ def subcmd_plot(args: Namespace) -> int: run_ids = [int(run) for run in args.run_id.split(",")] logger.debug("Generating graphics for runs: %s", run_ids) for run_id in run_ids: - write_run_heatmaps(run_id, session, outfmts, args) + write_run_plots(run_id, session, outfmts, args) return 0 -def write_run_heatmaps( - run_id: int, session, outfmts: List[str], args: Namespace -) -> None: +def write_run_plots(run_id: int, session, outfmts: List[str], args: Namespace) -> None: """Write all heatmaps for a specified run to file. :param run_id: int, run identifier in database session @@ -118,7 +117,13 @@ def write_run_heatmaps( f"Have {len(result_label_dict)} labels and {len(result_class_dict)} classes" ) - # Write heatmap for each results matrix + # Write heatmap and distribution plot for each results matrix + + # Create worker pool and empty command list + pool = multiprocessing.Pool(processes=args.workers) + plotting_commands = [] + + # Build and collect the plotting commands for matdata in [ MatrixData(*_) for _ in [ @@ -129,20 +134,39 @@ def write_run_heatmaps( ("hadamard", pd.read_json(results.df_hadamard), {}), ] ]: - write_heatmap( - run_id, matdata, result_label_dict, result_class_dict, outfmts, args + plotting_commands.append( + ( + write_heatmap, + [run_id, matdata, result_label_dict, result_class_dict, outfmts, args], + ) + ) + plotting_commands.append((write_distribution, [run_id, matdata, outfmts, args])) + + id_matrix = MatrixData("identity", pd.read_json(results.df_identity), {}) + cov_matrix = MatrixData("coverage", pd.read_json(results.df_coverage), {}) + plotting_commands.append( + ( + write_scatter, + [ + run_id, + id_matrix, + cov_matrix, + result_label_dict, + result_class_dict, + outfmts, + args, + ], ) - write_distribution(run_id, matdata, outfmts, args) - write_scatter( - run_id, - MatrixData("identity", pd.read_json(results.df_identity), {}), - MatrixData("coverage", pd.read_json(results.df_coverage), {}), - result_label_dict, - result_class_dict, - outfmts, - args, ) + # Run the plotting commands + for func, options in plotting_commands: + pool.apply_async(func, options, {}) + + # Close worker pool + pool.close() + pool.join() + def write_distribution( run_id: int, diff --git a/requirements.txt b/requirements.txt index 79f24435..15d97e59 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ biopython +ete3 matplotlib namedlist networkx @@ -6,6 +7,7 @@ numpy openpyxl pandas Pillow +PyQt5 scipy seaborn sqlalchemy==1.3.10 diff --git a/tests/test_subcmd_06_plot.py b/tests/test_subcmd_06_plot.py index 16b96bbe..989a1e7f 100644 --- a/tests/test_subcmd_06_plot.py +++ b/tests/test_subcmd_06_plot.py @@ -80,6 +80,7 @@ def setUp(self): dbpath=self.dbpath, formats="pdf", method="mpl", + workers=None, ), "mpl_png": Namespace( outdir=self.outdir / "mpl", @@ -87,6 +88,7 @@ def setUp(self): dbpath=self.dbpath, formats="png", method="mpl", + workers=None, ), "mpl_svg": Namespace( outdir=self.outdir / "mpl", @@ -94,6 +96,7 @@ def setUp(self): dbpath=self.dbpath, formats="svg", method="mpl", + workers=None, ), "mpl_jpg": Namespace( outdir=self.outdir / "mpl", @@ -101,6 +104,7 @@ def setUp(self): dbpath=self.dbpath, formats="jpg", method="mpl", + workers=None, ), "seaborn_pdf": Namespace( outdir=self.outdir / "seaborn", @@ -108,6 +112,7 @@ def setUp(self): dbpath=self.dbpath, formats="pdf", method="seaborn", + workers=None, ), "seaborn_png": Namespace( outdir=self.outdir / "seaborn", @@ -115,6 +120,7 @@ def setUp(self): dbpath=self.dbpath, formats="png", method="seaborn", + workers=None, ), "seaborn_svg": Namespace( outdir=self.outdir / "seaborn", @@ -122,6 +128,7 @@ def setUp(self): dbpath=self.dbpath, formats="svg", method="seaborn", + workers=None, ), "seaborn_jpg": Namespace( outdir=self.outdir / "seaborn", @@ -129,6 +136,7 @@ def setUp(self): dbpath=self.dbpath, formats="jpg", method="seaborn", + workers=None, ), }