Skip to content

Commit

Permalink
Merge pull request #44 from dyson-ai/feature/refactor_plot_return_types
Browse files Browse the repository at this point in the history
Feature/refactor plot return types
  • Loading branch information
mfinean authored Jun 29, 2023
2 parents 0a6d815 + 9f712de commit 537a2a7
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 56 deletions.
27 changes: 19 additions & 8 deletions bencher/plotting/plot_collection.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations
import panel as pn

import inspect
from typing import List, Callable
import logging
from typing import Callable, List

import panel as pn

from bencher.bench_cfg import PltCntCfg, BenchCfg
from bencher.bench_cfg import BenchCfg, PltCntCfg
from bencher.bench_vars import ParametrizedSweep
from bencher.plotting.plot_filter import PlotProvider, PlotInput
from bencher.plotting.plot_filter import PlotInput, PlotProvider


class PlotCollection:
Expand Down Expand Up @@ -86,14 +87,24 @@ def gather_plots(
Args:
bench_cfg (BenchCfg): A config of the input vars
rv (ParametrizedSweep): a config of the result variable
plt_cnt_cfg (PltCntCfg): A config of how many input types there are"""
plt_cnt_cfg (PltCntCfg): A config of how many input types there are
Raises:
ValueError: If the plot does not inherit from a pn.viewable.Viewable type
"""

tabs = pn.Accordion()
for plt_fn in self.plotters.values():
plots = plt_fn(PlotInput(bench_cfg, rv, plt_cnt_cfg))
for plt_instance in plots:
logging.info(f"plotting: {plt_instance.name}")
tabs.append(plt_instance)
if plots is not None:
if type(plots) != list:
plots = [plots]
for plt_instance in plots:
if not isinstance(plt_instance, pn.viewable.Viewable):
raise ValueError(
"The plot must be a viewable type (pn.viewable.Viewable or pn.panel)"
)
logging.info(f"plotting: {plt_instance.name}")
tabs.append(plt_instance)
if len(tabs) > 0:
tabs.active = [0] # set the first plot as active
return tabs
30 changes: 16 additions & 14 deletions bencher/plotting/plots/catplot.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import List, Tuple
import seaborn as sns
import panel as pn
from typing import Optional, Tuple

import matplotlib.pyplot as plt
import pandas as pd
from bencher.plotting.plot_filter import PlotFilter, VarRange, PlotInput
from bencher.plt_cfg import PltCfgBase
import panel as pn
import seaborn as sns

from bencher.plotting.plot_filter import PlotFilter, PlotInput, VarRange
from bencher.plotting.plot_types import PlotTypes
from bencher.plt_cfg import PltCfgBase


class Catplot:
Expand All @@ -31,32 +33,32 @@ def plot_setup(pl_in: PlotInput) -> Tuple[pd.DataFrame, PltCfgBase]:
return df, sns_cfg

@staticmethod
def plot_postprocess(fg: plt.figure, sns_cfg: PltCfgBase, name: str) -> List[pn.panel]:
def plot_postprocess(fg: plt.figure, sns_cfg: PltCfgBase, name: str) -> Optional[pn.panel]:
fg.fig.suptitle(sns_cfg.title)
fg.set_xlabels(label=sns_cfg.xlabel, clear_inner=True)
fg.set_ylabels(label=sns_cfg.ylabel, clear_inner=True)
plt.tight_layout()
return [pn.panel(fg.fig, name=name)]
return pn.panel(fg.fig, name=name)

def catplot_common(self, pl_in: PlotInput, kind: str, name: str) -> List[pn.panel]:
def catplot_common(self, pl_in: PlotInput, kind: str, name: str) -> Optional[pn.panel]:
if self.float_0_cat_at_least1_vec_1_res_1.matches(pl_in.plt_cnt_cfg):
df, sns_cfg = self.plot_setup(pl_in)
sns_cfg.kind = kind
fg = sns.catplot(df, **sns_cfg.as_sns_args())
return self.plot_postprocess(fg, sns_cfg, name)
return []
return None

def swarmplot(self, pl_in: PlotInput) -> List[pn.panel]:
def swarmplot(self, pl_in: PlotInput) -> Optional[pn.panel]:
return self.catplot_common(pl_in, "swarm", PlotTypes.swarmplot)

def violinplot(self, pl_in: PlotInput) -> List[pn.panel]:
def violinplot(self, pl_in: PlotInput) -> Optional[pn.panel]:
return self.catplot_common(pl_in, "violin", PlotTypes.violinplot)

def boxplot(self, pl_in: PlotInput) -> List[pn.panel]:
def boxplot(self, pl_in: PlotInput) -> Optional[pn.panel]:
return self.catplot_common(pl_in, "box", PlotTypes.boxplot)

def barplot(self, pl_in: PlotInput) -> List[pn.panel]:
def barplot(self, pl_in: PlotInput) -> Optional[pn.panel]:
return self.catplot_common(pl_in, "bar", PlotTypes.barplot)

def boxenplot(self, pl_in: PlotInput) -> List[pn.panel]:
def boxenplot(self, pl_in: PlotInput) -> Optional[pn.panel]:
return self.catplot_common(pl_in, "boxen", PlotTypes.boxenplot)
27 changes: 13 additions & 14 deletions bencher/plotting/plots/plots_2D.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import List
from typing import Optional

import panel as pn
import plotly.express as px

from bencher.plotting.plot_filter import PlotFilter, VarRange, PlotInput
from bencher.plotting.plot_filter import PlotFilter, PlotInput, VarRange
from bencher.plotting.plot_types import PlotTypes


Expand All @@ -15,7 +16,7 @@ class Plots2D:
result_vars=VarRange(1, 1),
)

def imshow(self, pl_in: PlotInput) -> List[pn.panel]:
def imshow(self, pl_in: PlotInput) -> Optional[pn.panel]:
"""use the imshow plotting method to display 2D data
Args:
Expand All @@ -33,14 +34,12 @@ def imshow(self, pl_in: PlotInput) -> List[pn.panel]:
xlabel = f"{fv_x.name} [{fv_x.units}]"
ylabel = f"{fv_y.name} [{fv_y.units}]"
color_label = f"{pl_in.rv.name} [{pl_in.rv.units}]"
return [
pn.panel(
px.imshow(
mean,
title=title,
labels={"x": xlabel, "y": ylabel, "color": color_label},
),
name=PlotTypes.imshow,
)
]
return []
return pn.panel(
px.imshow(
mean,
title=title,
labels={"x": xlabel, "y": ylabel, "color": color_label},
),
name=PlotTypes.imshow,
)
return None
20 changes: 10 additions & 10 deletions bencher/plotting/plots/tables.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,27 @@
from typing import List
import panel as pn
from bencher.plotting.plot_types import PlotTypes

from bencher.plotting.plot_filter import PlotInput
from bencher.plotting.plot_types import PlotTypes


class Tables:
"""A class to display the result data in tabular form"""

def dataframe_flat(self, pl_in: PlotInput) -> List[pn.panel]:
def dataframe_flat(self, pl_in: PlotInput) -> pn.panel:
"""Returns a list of panel objects containing a flat dataframe."""
df = pl_in.bench_cfg.get_dataframe()
return [pn.pane.DataFrame(df, name=PlotTypes.dataframe_flat)]
return pn.pane.DataFrame(df, name=PlotTypes.dataframe_flat)

def dataframe_multi_index(self, pl_in: PlotInput) -> List[pn.panel]:
def dataframe_multi_index(self, pl_in: PlotInput) -> pn.panel:
"""Returns a list of panel objects containing a multi-index dataframe."""
df = pl_in.bench_cfg.ds.to_dataframe()
return [pn.pane.DataFrame(df, name=PlotTypes.dataframe_multi_index)]
return pn.pane.DataFrame(df, name=PlotTypes.dataframe_multi_index)

def dataframe_mean(self, pl_in: PlotInput) -> List[pn.panel]:
def dataframe_mean(self, pl_in: PlotInput) -> pn.panel:
"""Returns a list of panel objects containing a mean dataframe."""
df = pl_in.bench_cfg.ds.mean("repeat").to_dataframe().reset_index()
return [pn.pane.DataFrame(df, name=PlotTypes.dataframe_mean)]
return pn.pane.DataFrame(df, name=PlotTypes.dataframe_mean)

def xarray(self, pl_in: PlotInput) -> List[pn.panel]:
def xarray(self, pl_in: PlotInput) -> pn.panel:
"""Returns a list of panel objects containing an xarray object."""
return [pn.panel(pl_in.bench_cfg.ds, name=PlotTypes.xarray)]
return pn.panel(pl_in.bench_cfg.ds, name=PlotTypes.xarray)
6 changes: 3 additions & 3 deletions test/plots/test_plots_common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest

import panel as pn

import bencher as bch
Expand Down Expand Up @@ -33,6 +34,5 @@ def basic_plot_asserts(self, result: bch.BenchCfg, plot_name: str) -> None:
plot_name (str): expected name of the plot
"""

self.assertIsInstance(result, list)
self.assertIsInstance(result[0], pn.viewable.Viewable)
self.assertEqual(result[0].name, plot_name)
self.assertIsInstance(result, pn.viewable.Viewable)
self.assertEqual(result.name, plot_name)
10 changes: 6 additions & 4 deletions test/plots/test_tables.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from hypothesis import given, settings
from hypothesis import strategies as st

from bencher.bench_cfg import PltCntCfg
from bencher.example.benchmark_data import ExampleBenchCfgOut
from bencher.plotting.plot_collection import PlotInput
from bencher.plotting.plot_types import PlotTypes
from bencher.plotting.plots.tables import Tables
from bencher.bench_cfg import PltCntCfg
from bencher.example.benchmark_data import ExampleBenchCfgOut
from hypothesis import given, settings, strategies as st

from .test_plots_common import TestPlotsCommon


class TestCatPlot(TestPlotsCommon):
class TestTables(TestPlotsCommon):
@settings(deadline=10000)
@given(
st.sampled_from(
Expand Down
8 changes: 5 additions & 3 deletions test/test_plot_collection.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
# (Mostly) Generated by CodiumAI, sanity checked and fixed by a human
import unittest
from typing import Optional

import panel as pn
from typing import List

from bencher.plotting.plot_collection import PlotCollection, PlotProvider


class TestPlotProvider:
def plot_1(self) -> List[pn.panel]:
def plot_1(self) -> Optional[pn.panel]:
return [pn.pane.Markdown("Test plot 1", name="plot_1")]

def plot_2(self) -> List[pn.panel]:
def plot_2(self) -> Optional[pn.panel]:
return [pn.pane.Markdown("Test plot 2", name="plot_2")]


Expand Down

0 comments on commit 537a2a7

Please sign in to comment.