diff --git a/bencher/plotting/plot_collection.py b/bencher/plotting/plot_collection.py index 77fa5596..0b621b83 100644 --- a/bencher/plotting/plot_collection.py +++ b/bencher/plotting/plot_collection.py @@ -1,13 +1,14 @@ from __future__ import annotations -import panel as pn + import inspect -from typing import List, Callable import logging +from typing import Callable, List +import panel as pn -from bencher.bench_cfg import PltCntCfg, BenchCfg +from bencher.bench_cfg import BenchCfg, PltCntCfg from bencher.bench_vars import ParametrizedSweep -from bencher.plotting.plot_filter import PlotProvider, PlotInput +from bencher.plotting.plot_filter import PlotInput, PlotProvider class PlotCollection: @@ -86,14 +87,24 @@ def gather_plots( Args: bench_cfg (BenchCfg): A config of the input vars rv (ParametrizedSweep): a config of the result variable - plt_cnt_cfg (PltCntCfg): A config of how many input types there are""" + plt_cnt_cfg (PltCntCfg): A config of how many input types there are + Raises: + ValueError: If the plot does not inherit from a pn.viewable.Viewable type + """ tabs = pn.Accordion() for plt_fn in self.plotters.values(): plots = plt_fn(PlotInput(bench_cfg, rv, plt_cnt_cfg)) - for plt_instance in plots: - logging.info(f"plotting: {plt_instance.name}") - tabs.append(plt_instance) + if plots is not None: + if type(plots) != list: + plots = [plots] + for plt_instance in plots: + if not isinstance(plt_instance, pn.viewable.Viewable): + raise ValueError( + "The plot must be a viewable type (pn.viewable.Viewable or pn.panel)" + ) + logging.info(f"plotting: {plt_instance.name}") + tabs.append(plt_instance) if len(tabs) > 0: tabs.active = [0] # set the first plot as active return tabs diff --git a/bencher/plotting/plots/catplot.py b/bencher/plotting/plots/catplot.py index 205d1c43..276d7bc3 100644 --- a/bencher/plotting/plots/catplot.py +++ b/bencher/plotting/plots/catplot.py @@ -1,11 +1,13 @@ -from typing import List, Tuple -import seaborn as sns -import panel as pn +from typing import Optional, Tuple + import matplotlib.pyplot as plt import pandas as pd -from bencher.plotting.plot_filter import PlotFilter, VarRange, PlotInput -from bencher.plt_cfg import PltCfgBase +import panel as pn +import seaborn as sns + +from bencher.plotting.plot_filter import PlotFilter, PlotInput, VarRange from bencher.plotting.plot_types import PlotTypes +from bencher.plt_cfg import PltCfgBase class Catplot: @@ -31,32 +33,32 @@ def plot_setup(pl_in: PlotInput) -> Tuple[pd.DataFrame, PltCfgBase]: return df, sns_cfg @staticmethod - def plot_postprocess(fg: plt.figure, sns_cfg: PltCfgBase, name: str) -> List[pn.panel]: + def plot_postprocess(fg: plt.figure, sns_cfg: PltCfgBase, name: str) -> Optional[pn.panel]: fg.fig.suptitle(sns_cfg.title) fg.set_xlabels(label=sns_cfg.xlabel, clear_inner=True) fg.set_ylabels(label=sns_cfg.ylabel, clear_inner=True) plt.tight_layout() - return [pn.panel(fg.fig, name=name)] + return pn.panel(fg.fig, name=name) - def catplot_common(self, pl_in: PlotInput, kind: str, name: str) -> List[pn.panel]: + def catplot_common(self, pl_in: PlotInput, kind: str, name: str) -> Optional[pn.panel]: if self.float_0_cat_at_least1_vec_1_res_1.matches(pl_in.plt_cnt_cfg): df, sns_cfg = self.plot_setup(pl_in) sns_cfg.kind = kind fg = sns.catplot(df, **sns_cfg.as_sns_args()) return self.plot_postprocess(fg, sns_cfg, name) - return [] + return None - def swarmplot(self, pl_in: PlotInput) -> List[pn.panel]: + def swarmplot(self, pl_in: PlotInput) -> Optional[pn.panel]: return self.catplot_common(pl_in, "swarm", PlotTypes.swarmplot) - def violinplot(self, pl_in: PlotInput) -> List[pn.panel]: + def violinplot(self, pl_in: PlotInput) -> Optional[pn.panel]: return self.catplot_common(pl_in, "violin", PlotTypes.violinplot) - def boxplot(self, pl_in: PlotInput) -> List[pn.panel]: + def boxplot(self, pl_in: PlotInput) -> Optional[pn.panel]: return self.catplot_common(pl_in, "box", PlotTypes.boxplot) - def barplot(self, pl_in: PlotInput) -> List[pn.panel]: + def barplot(self, pl_in: PlotInput) -> Optional[pn.panel]: return self.catplot_common(pl_in, "bar", PlotTypes.barplot) - def boxenplot(self, pl_in: PlotInput) -> List[pn.panel]: + def boxenplot(self, pl_in: PlotInput) -> Optional[pn.panel]: return self.catplot_common(pl_in, "boxen", PlotTypes.boxenplot) diff --git a/bencher/plotting/plots/plots_2D.py b/bencher/plotting/plots/plots_2D.py index ba0b6c17..528735d4 100644 --- a/bencher/plotting/plots/plots_2D.py +++ b/bencher/plotting/plots/plots_2D.py @@ -1,8 +1,9 @@ -from typing import List +from typing import Optional + import panel as pn import plotly.express as px -from bencher.plotting.plot_filter import PlotFilter, VarRange, PlotInput +from bencher.plotting.plot_filter import PlotFilter, PlotInput, VarRange from bencher.plotting.plot_types import PlotTypes @@ -15,7 +16,7 @@ class Plots2D: result_vars=VarRange(1, 1), ) - def imshow(self, pl_in: PlotInput) -> List[pn.panel]: + def imshow(self, pl_in: PlotInput) -> Optional[pn.panel]: """use the imshow plotting method to display 2D data Args: @@ -33,14 +34,12 @@ def imshow(self, pl_in: PlotInput) -> List[pn.panel]: xlabel = f"{fv_x.name} [{fv_x.units}]" ylabel = f"{fv_y.name} [{fv_y.units}]" color_label = f"{pl_in.rv.name} [{pl_in.rv.units}]" - return [ - pn.panel( - px.imshow( - mean, - title=title, - labels={"x": xlabel, "y": ylabel, "color": color_label}, - ), - name=PlotTypes.imshow, - ) - ] - return [] + return pn.panel( + px.imshow( + mean, + title=title, + labels={"x": xlabel, "y": ylabel, "color": color_label}, + ), + name=PlotTypes.imshow, + ) + return None diff --git a/bencher/plotting/plots/tables.py b/bencher/plotting/plots/tables.py index 397671ac..6c8fdc40 100644 --- a/bencher/plotting/plots/tables.py +++ b/bencher/plotting/plots/tables.py @@ -1,27 +1,27 @@ -from typing import List import panel as pn -from bencher.plotting.plot_types import PlotTypes + from bencher.plotting.plot_filter import PlotInput +from bencher.plotting.plot_types import PlotTypes class Tables: """A class to display the result data in tabular form""" - def dataframe_flat(self, pl_in: PlotInput) -> List[pn.panel]: + def dataframe_flat(self, pl_in: PlotInput) -> pn.panel: """Returns a list of panel objects containing a flat dataframe.""" df = pl_in.bench_cfg.get_dataframe() - return [pn.pane.DataFrame(df, name=PlotTypes.dataframe_flat)] + return pn.pane.DataFrame(df, name=PlotTypes.dataframe_flat) - def dataframe_multi_index(self, pl_in: PlotInput) -> List[pn.panel]: + def dataframe_multi_index(self, pl_in: PlotInput) -> pn.panel: """Returns a list of panel objects containing a multi-index dataframe.""" df = pl_in.bench_cfg.ds.to_dataframe() - return [pn.pane.DataFrame(df, name=PlotTypes.dataframe_multi_index)] + return pn.pane.DataFrame(df, name=PlotTypes.dataframe_multi_index) - def dataframe_mean(self, pl_in: PlotInput) -> List[pn.panel]: + def dataframe_mean(self, pl_in: PlotInput) -> pn.panel: """Returns a list of panel objects containing a mean dataframe.""" df = pl_in.bench_cfg.ds.mean("repeat").to_dataframe().reset_index() - return [pn.pane.DataFrame(df, name=PlotTypes.dataframe_mean)] + return pn.pane.DataFrame(df, name=PlotTypes.dataframe_mean) - def xarray(self, pl_in: PlotInput) -> List[pn.panel]: + def xarray(self, pl_in: PlotInput) -> pn.panel: """Returns a list of panel objects containing an xarray object.""" - return [pn.panel(pl_in.bench_cfg.ds, name=PlotTypes.xarray)] + return pn.panel(pl_in.bench_cfg.ds, name=PlotTypes.xarray) diff --git a/test/plots/test_plots_common.py b/test/plots/test_plots_common.py index de87f022..4d4609d9 100644 --- a/test/plots/test_plots_common.py +++ b/test/plots/test_plots_common.py @@ -1,4 +1,5 @@ import unittest + import panel as pn import bencher as bch @@ -33,6 +34,5 @@ def basic_plot_asserts(self, result: bch.BenchCfg, plot_name: str) -> None: plot_name (str): expected name of the plot """ - self.assertIsInstance(result, list) - self.assertIsInstance(result[0], pn.viewable.Viewable) - self.assertEqual(result[0].name, plot_name) + self.assertIsInstance(result, pn.viewable.Viewable) + self.assertEqual(result.name, plot_name) diff --git a/test/plots/test_tables.py b/test/plots/test_tables.py index f86fe21f..e5bcfa96 100644 --- a/test/plots/test_tables.py +++ b/test/plots/test_tables.py @@ -1,14 +1,16 @@ +from hypothesis import given, settings +from hypothesis import strategies as st + +from bencher.bench_cfg import PltCntCfg +from bencher.example.benchmark_data import ExampleBenchCfgOut from bencher.plotting.plot_collection import PlotInput from bencher.plotting.plot_types import PlotTypes from bencher.plotting.plots.tables import Tables -from bencher.bench_cfg import PltCntCfg -from bencher.example.benchmark_data import ExampleBenchCfgOut -from hypothesis import given, settings, strategies as st from .test_plots_common import TestPlotsCommon -class TestCatPlot(TestPlotsCommon): +class TestTables(TestPlotsCommon): @settings(deadline=10000) @given( st.sampled_from( diff --git a/test/test_plot_collection.py b/test/test_plot_collection.py index 00e92cd3..585c12a4 100644 --- a/test/test_plot_collection.py +++ b/test/test_plot_collection.py @@ -1,15 +1,17 @@ # (Mostly) Generated by CodiumAI, sanity checked and fixed by a human import unittest +from typing import Optional + import panel as pn -from typing import List + from bencher.plotting.plot_collection import PlotCollection, PlotProvider class TestPlotProvider: - def plot_1(self) -> List[pn.panel]: + def plot_1(self) -> Optional[pn.panel]: return [pn.pane.Markdown("Test plot 1", name="plot_1")] - def plot_2(self) -> List[pn.panel]: + def plot_2(self) -> Optional[pn.panel]: return [pn.pane.Markdown("Test plot 2", name="plot_2")]