From 595065241580c88772cd22a99d900c6885f040ff Mon Sep 17 00:00:00 2001 From: Austin Gregg-Smith Date: Fri, 21 Jul 2023 14:51:30 +0000 Subject: [PATCH 1/9] surfaces work --- bencher/plotting/plot_library.py | 4 ++ bencher/plotting/plot_types.py | 1 + bencher/plotting/plots/surface.py | 75 +++++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+) create mode 100644 bencher/plotting/plots/surface.py diff --git a/bencher/plotting/plot_library.py b/bencher/plotting/plot_library.py index 1e6e94eb..eb1fe9c8 100644 --- a/bencher/plotting/plot_library.py +++ b/bencher/plotting/plot_library.py @@ -8,6 +8,7 @@ from bencher.plotting.plots.scatterplot import Scatter from bencher.plotting.plots.tables import Tables from bencher.plotting.plots.volume import VolumePlot +from bencher.plotting.plots.surface import SurfacePlot from bencher.plotting.plots.hv_interactive import HvInteractive @@ -30,6 +31,8 @@ def setup_sources() -> PlotCollection: plt_col.add_plotter_source(Scatter()) plt_col.add_plotter_source(HvInteractive()) plt_col.add_plotter_source(VolumePlot()) + plt_col.add_plotter_source(SurfacePlot()) + return plt_col @staticmethod @@ -54,6 +57,7 @@ def default() -> PlotCollection: plt_col.add(PlotTypes.cone_plotly) # plt_col.add(PlotTypes.lineplot_hv_subplot) plt_col.add(PlotTypes.scatter2D_sns) + plt_col.add(PlotTypes.surface_hv) # plt_col.add(PlotTypes.hv_interactive) return plt_col diff --git a/bencher/plotting/plot_types.py b/bencher/plotting/plot_types.py index d630ddc0..6c3c504f 100644 --- a/bencher/plotting/plot_types.py +++ b/bencher/plotting/plot_types.py @@ -35,5 +35,6 @@ class PlotTypes(StrEnum): volume_plotly = auto() cone_plotly = auto() + surface_hv = auto() # hv_interactive = auto() diff --git a/bencher/plotting/plots/surface.py b/bencher/plotting/plots/surface.py new file mode 100644 index 00000000..57df1b3b --- /dev/null +++ b/bencher/plotting/plots/surface.py @@ -0,0 +1,75 @@ +from typing import Optional + +import panel as pn +import plotly.graph_objs as go +import logging +import xarray as xr + +from bencher.bench_cfg import BenchCfg +from bencher.bench_vars import ParametrizedSweep +from bencher.plotting.plot_filter import PlotFilter, PlotInput, VarRange, PltCntCfg +from bencher.plt_cfg import PltCfgBase, BenchPlotter +from bencher.plotting.plot_types import PlotTypes + +from bencher.plotting_functions import wrap_long_time_labels, plot_surface_holo +import holoviews as hv +from holoviews import opts + + +class SurfacePlot: + def surface_hv(self, pl_in: PlotInput) -> Optional[pn.panel]: + """Given a benchCfg generate a 2D surface plot + + Args: + bench_cfg (BenchCfg): description of benchmark + rv (ParametrizedSweep): result variable to plot + xr_cfg (PltCfgBase): config of x,y variables + + Returns: + pn.pane.holoview: A 2d surface plot as a holoview in a pane + """ + if PlotFilter( + float_range=VarRange(2, 2), + cat_range=VarRange(-1, None), + vector_len=VarRange(1, 1), + result_vars=VarRange(1, 1), + ).matches(pl_in.plt_cnt_cfg): + xr_cfg = BenchPlotter.plot_float_cnt_2( + pl_in.plt_cnt_cfg, pl_in.rv, pl_in.bench_cfg.debug + ) + hv.extension("plotly") + bench_cfg = pl_in.bench_cfg + rv = pl_in.rv + + bench_cfg = wrap_long_time_labels(bench_cfg) + + alpha = 0.3 + + da = bench_cfg.ds[rv.name] + + mean = da.mean("repeat") + + opts.defaults( + opts.Surface( + colorbar=True, + width=800, + height=800, + zlabel=xr_cfg.zlabel, + title=xr_cfg.title, + # image_rtol=0.002, + ) + ) + # TODO a warning suggests setting this parameter, but it does not seem to help as expected, leaving here to fix in the future + # hv.config.image_rtol = 1.0 + + ds = hv.Dataset(mean) + surface = ds.to(hv.Surface) + + if bench_cfg.repeats > 1: + std_dev = da.std("repeat") + upper = hv.Dataset(mean + std_dev).to(hv.Surface).opts(alpha=alpha, colorbar=False) + lower = hv.Dataset(mean - std_dev).to(hv.Surface).opts(alpha=alpha, colorbar=False) + return surface * upper * lower + return pn.panel(surface, name=PlotTypes.surface_hv) + + return None From 7230b893cc81cc19d004d3e26012dbd83035a78d Mon Sep 17 00:00:00 2001 From: Austin Gregg-Smith Date: Fri, 21 Jul 2023 15:00:30 +0000 Subject: [PATCH 2/9] finalise surface plots --- bencher/plotting/plots/surface.py | 40 ++++++++++-- bencher/plotting_functions.py | 101 ------------------------------ bencher/plt_cfg.py | 48 +------------- 3 files changed, 37 insertions(+), 152 deletions(-) diff --git a/bencher/plotting/plots/surface.py b/bencher/plotting/plots/surface.py index 57df1b3b..c3870b53 100644 --- a/bencher/plotting/plots/surface.py +++ b/bencher/plotting/plots/surface.py @@ -6,16 +6,48 @@ import xarray as xr from bencher.bench_cfg import BenchCfg -from bencher.bench_vars import ParametrizedSweep +from bencher.bench_vars import ParametrizedSweep, ResultVar from bencher.plotting.plot_filter import PlotFilter, PlotInput, VarRange, PltCntCfg from bencher.plt_cfg import PltCfgBase, BenchPlotter from bencher.plotting.plot_types import PlotTypes -from bencher.plotting_functions import wrap_long_time_labels, plot_surface_holo +from bencher.plotting_functions import wrap_long_time_labels import holoviews as hv from holoviews import opts +def plot_float_cnt_2(plt_cnt_cfg: PltCntCfg, rv: ResultVar, debug: bool) -> PltCfgBase: + """A function for determining the plot settings if there are 2 float variable and updates the PltCfgBase + + Args: + sns_cfg (PltCfgBase): See PltCfgBase definition + plt_cnt_cfg (PltCntCfg): See PltCntCfg definition + + Returns: + PltCfgBase: See PltCfgBase definition + """ + xr_cfg = PltCfgBase() + if plt_cnt_cfg.float_cnt == 2: + logging.info(f"surface plot: {rv.name}") + xr_cfg.plot_callback_xra = xr.plot.plot + xr_cfg.x = plt_cnt_cfg.float_vars[0].name + xr_cfg.y = plt_cnt_cfg.float_vars[1].name + xr_cfg.xlabel = f"{xr_cfg.x} [{plt_cnt_cfg.float_vars[0].units}]" + xr_cfg.ylabel = f"{xr_cfg.y} [{plt_cnt_cfg.float_vars[1].units}]" + xr_cfg.zlabel = f"{rv.name} [{rv.units}]" + xr_cfg.title = f"{rv.name} vs ({xr_cfg.x} and {xr_cfg.y})" + + if plt_cnt_cfg.cat_cnt >= 1: + logging.info("surface plot with 1 categorical") + xr_cfg.row = plt_cnt_cfg.cat_vars[0].name + xr_cfg.num_rows = len(plt_cnt_cfg.cat_vars[0].values(debug)) + if plt_cnt_cfg.cat_cnt >= 2: + logging.info("surface plot with 2> categorical") + xr_cfg.col = plt_cnt_cfg.cat_vars[1].name + xr_cfg.num_cols = len(plt_cnt_cfg.cat_vars[1].values(debug)) + return xr_cfg + + class SurfacePlot: def surface_hv(self, pl_in: PlotInput) -> Optional[pn.panel]: """Given a benchCfg generate a 2D surface plot @@ -34,9 +66,7 @@ def surface_hv(self, pl_in: PlotInput) -> Optional[pn.panel]: vector_len=VarRange(1, 1), result_vars=VarRange(1, 1), ).matches(pl_in.plt_cnt_cfg): - xr_cfg = BenchPlotter.plot_float_cnt_2( - pl_in.plt_cnt_cfg, pl_in.rv, pl_in.bench_cfg.debug - ) + xr_cfg = plot_float_cnt_2(pl_in.plt_cnt_cfg, pl_in.rv, pl_in.bench_cfg.debug) hv.extension("plotly") bench_cfg = pl_in.bench_cfg rv = pl_in.rv diff --git a/bencher/plotting_functions.py b/bencher/plotting_functions.py index 5c9919be..c64ff722 100644 --- a/bencher/plotting_functions.py +++ b/bencher/plotting_functions.py @@ -246,104 +246,3 @@ def save_fig( f"This figname {figname} already exists, please define a unique benchmark name or don't run the same benchmark twice" ) plt.savefig(figpath) - - -def plot_surface_plotly( - bench_cfg: BenchCfg, rv: ParametrizedSweep, xr_cfg: PltCfgBase -) -> pn.pane.Plotly: - """Given a benchCfg generate a 2D surface plot - - Args: - bench_cfg (BenchCfg): description of benchmark - rv (ParametrizedSweep): result variable to plot - xr_cfg (PltCfgBase): config of x,y variables - - Returns: - pn.pane.Plotly: A 2d surface plot as a holoview in a pane - """ - - if type(rv) == ResultVec: - return plot_scatter3D_px(bench_cfg, rv) - - bench_cfg = wrap_long_time_labels(bench_cfg) - - da = bench_cfg.ds[rv.name].transpose() - - mean = da.mean("repeat") - - x = da.coords[xr_cfg.x] - y = da.coords[xr_cfg.y] - - opacity = 0.3 - - surfaces = [go.Surface(x=x, y=y, z=mean)] - - if bench_cfg.repeats > 1: - std_dev = da.std("repeat") - surfaces.append(go.Surface(x=x, y=y, z=mean + std_dev, showscale=False, opacity=opacity)) - surfaces.append(go.Surface(x=x, y=y, z=mean - std_dev, showscale=False, opacity=opacity)) - - eye_dis = 1.7 - layout = go.Layout( - title=xr_cfg.title, - width=700, - height=700, - scene=dict( - xaxis_title=xr_cfg.xlabel, - yaxis_title=xr_cfg.ylabel, - zaxis_title=xr_cfg.zlabel, - camera={"eye": {"x": eye_dis, "y": eye_dis, "z": eye_dis}}, - ), - ) - - fig = {"data": surfaces, "layout": layout} - - return pn.pane.Plotly(fig) - - -def plot_surface_holo( - bench_cfg: BenchCfg, rv: ParametrizedSweep, xr_cfg: PltCfgBase -) -> pn.pane.Plotly: - """Given a benchCfg generate a 2D surface plot - - Args: - bench_cfg (BenchCfg): description of benchmark - rv (ParametrizedSweep): result variable to plot - xr_cfg (PltCfgBase): config of x,y variables - - Returns: - pn.pane.holoview: A 2d surface plot as a holoview in a pane - """ - - hv.extension("plotly") - - bench_cfg = wrap_long_time_labels(bench_cfg) - - alpha = 0.3 - - da = bench_cfg.ds[rv.name] - - mean = da.mean("repeat") - - opts.defaults( - opts.Surface( - colorbar=True, - width=800, - height=800, - zlabel=xr_cfg.zlabel, - title=xr_cfg.title, - # image_rtol=0.002, - ) - ) - # TODO a warning suggests setting this parameter, but it does not seem to help as expected, leaving here to fix in the future - # hv.config.image_rtol = 1.0 - - ds = hv.Dataset(mean) - surface = ds.to(hv.Surface) - - if bench_cfg.repeats > 1: - std_dev = da.std("repeat") - upper = hv.Dataset(mean + std_dev).to(hv.Surface).opts(alpha=alpha, colorbar=False) - lower = hv.Dataset(mean - std_dev).to(hv.Surface).opts(alpha=alpha, colorbar=False) - return surface * upper * lower - return surface diff --git a/bencher/plt_cfg.py b/bencher/plt_cfg.py index 20f59d71..1b23bcdb 100644 --- a/bencher/plt_cfg.py +++ b/bencher/plt_cfg.py @@ -68,6 +68,8 @@ def plot(bench_cfg: BenchCfg, main_tab=None, append_cols=None) -> pn.pane: if append_cols is not None: plot_cols.extend(append_cols) + # plot_cols.append(pn.Column(pn.Row()))#attempt to add spacer to stop overlapping but does not work todo + plot_cols.append(pn.pane.Markdown(f"{bench_cfg.post_description}")) tabs = pn.Tabs(plot_cols, name=bench_cfg.title) @@ -212,20 +214,6 @@ def plot_result_variable( sns_cfg = BenchPlotter.plot_float_cnt_1(sns_cfg, plt_cnt_cfg) sns_cfg = BenchPlotter.get_axes_and_title(rv, sns_cfg, plt_cnt_cfg) surf_col.append(plt_func.plot_sns(bench_cfg, rv, sns_cfg)) - else: - if plt_cnt_cfg.float_cnt == 2: - xr_cfg = BenchPlotter.plot_float_cnt_2(plt_cnt_cfg, rv, bench_cfg.debug) - if plt_cnt_cfg.cat_cnt == 0: - surf_col.append(plt_func.plot_surface_plotly(bench_cfg, rv, xr_cfg)) - else: - try: - surf_col.append(plt_func.plot_surface_holo(bench_cfg, rv, xr_cfg)) - except (TypeError, KeyError) as e: - surf_col.append( - pn.pane.Markdown( - f"3D (cat,float,cat) inputs -> (float) output plots are not supported yet, error:{e}" - ) - ) return surf_col @staticmethod @@ -318,35 +306,3 @@ def plot_float_cnt_1(sns_cfg: PltCfgBase, plt_cnt_cfg: PltCntCfg) -> PltCfgBase: sns_cfg = BenchPlotter.axis_mapping(cat_axis_order, sns_cfg, plt_cnt_cfg) return sns_cfg - - @staticmethod - def plot_float_cnt_2(plt_cnt_cfg: PltCntCfg, rv: ResultVar, debug: bool) -> PltCfgBase: - """A function for determining the plot settings if there are 2 float variable and updates the PltCfgBase - - Args: - sns_cfg (PltCfgBase): See PltCfgBase definition - plt_cnt_cfg (PltCntCfg): See PltCntCfg definition - - Returns: - PltCfgBase: See PltCfgBase definition - """ - xr_cfg = PltCfgBase() - if plt_cnt_cfg.float_cnt == 2: - logging.info(f"surface plot: {rv.name}") - xr_cfg.plot_callback_xra = xr.plot.plot - xr_cfg.x = plt_cnt_cfg.float_vars[0].name - xr_cfg.y = plt_cnt_cfg.float_vars[1].name - xr_cfg.xlabel = f"{xr_cfg.x} [{plt_cnt_cfg.float_vars[0].units}]" - xr_cfg.ylabel = f"{xr_cfg.y} [{plt_cnt_cfg.float_vars[1].units}]" - xr_cfg.zlabel = f"{rv.name} [{rv.units}]" - xr_cfg.title = f"{rv.name} vs ({xr_cfg.x} and {xr_cfg.y})" - - if plt_cnt_cfg.cat_cnt >= 1: - logging.info("surface plot with 1 categorical") - xr_cfg.row = plt_cnt_cfg.cat_vars[0].name - xr_cfg.num_rows = len(plt_cnt_cfg.cat_vars[0].values(debug)) - if plt_cnt_cfg.cat_cnt >= 2: - logging.info("surface plot with 2> categorical") - xr_cfg.col = plt_cnt_cfg.cat_vars[1].name - xr_cfg.num_cols = len(plt_cnt_cfg.cat_vars[1].values(debug)) - return xr_cfg From 8b825c60ff8c75a662b509a9ed9df67a5a1879e0 Mon Sep 17 00:00:00 2001 From: Austin Gregg-Smith Date: Fri, 21 Jul 2023 15:01:39 +0000 Subject: [PATCH 3/9] fix linting --- bencher/plotting/plots/surface.py | 6 ++---- bencher/plotting_functions.py | 1 - bencher/plt_cfg.py | 3 +-- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/bencher/plotting/plots/surface.py b/bencher/plotting/plots/surface.py index c3870b53..6698630f 100644 --- a/bencher/plotting/plots/surface.py +++ b/bencher/plotting/plots/surface.py @@ -1,14 +1,12 @@ from typing import Optional import panel as pn -import plotly.graph_objs as go import logging import xarray as xr -from bencher.bench_cfg import BenchCfg -from bencher.bench_vars import ParametrizedSweep, ResultVar +from bencher.bench_vars import ResultVar from bencher.plotting.plot_filter import PlotFilter, PlotInput, VarRange, PltCntCfg -from bencher.plt_cfg import PltCfgBase, BenchPlotter +from bencher.plt_cfg import PltCfgBase from bencher.plotting.plot_types import PlotTypes from bencher.plotting_functions import wrap_long_time_labels diff --git a/bencher/plotting_functions.py b/bencher/plotting_functions.py index c64ff722..58bb7108 100644 --- a/bencher/plotting_functions.py +++ b/bencher/plotting_functions.py @@ -9,7 +9,6 @@ import pandas as pd import panel as pn import plotly.express as px -import plotly.graph_objs as go import seaborn as sns from holoviews import opts diff --git a/bencher/plt_cfg.py b/bencher/plt_cfg.py index 1b23bcdb..43f75666 100644 --- a/bencher/plt_cfg.py +++ b/bencher/plt_cfg.py @@ -3,11 +3,10 @@ import panel as pn import seaborn as sns -import xarray as xr import bencher.plotting_functions as plt_func from bencher.bench_cfg import BenchCfg, PltCfgBase, PltCntCfg, describe_benchmark -from bencher.bench_vars import ParametrizedSweep, ResultVar, ResultVec +from bencher.bench_vars import ParametrizedSweep, ResultVec from bencher.optuna_conversions import collect_optuna_plots From a3aa0d7bfecbb80335864179d2778ea0ce3492bd Mon Sep 17 00:00:00 2001 From: Austin Gregg-Smith Date: Fri, 21 Jul 2023 15:05:46 +0000 Subject: [PATCH 4/9] minor fixes --- bencher/plotting/plots/surface.py | 2 +- bencher/plt_cfg.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/bencher/plotting/plots/surface.py b/bencher/plotting/plots/surface.py index 6698630f..18035cdc 100644 --- a/bencher/plotting/plots/surface.py +++ b/bencher/plotting/plots/surface.py @@ -98,6 +98,6 @@ def surface_hv(self, pl_in: PlotInput) -> Optional[pn.panel]: upper = hv.Dataset(mean + std_dev).to(hv.Surface).opts(alpha=alpha, colorbar=False) lower = hv.Dataset(mean - std_dev).to(hv.Surface).opts(alpha=alpha, colorbar=False) return surface * upper * lower - return pn.panel(surface, name=PlotTypes.surface_hv) + return pn.Column(surface, name=PlotTypes.surface_hv) return None diff --git a/bencher/plt_cfg.py b/bencher/plt_cfg.py index 6c2f6589..aa7eb2ea 100644 --- a/bencher/plt_cfg.py +++ b/bencher/plt_cfg.py @@ -122,8 +122,7 @@ def plot_results_row(bench_cfg: BenchCfg) -> pn.Row: plot_rows = pn.Row( name=bench_cfg.bench_name, scroll=True, height=1000 ) # todo remove the scroll and make it resize dynamically - name=bench_cfg.bench_name, scroll=True, height=1000 - ) # todo remove the scroll and make it resize dynamically + plt_cnt_cfg = BenchPlotter.generate_plt_cnt_cfg(bench_cfg) for rv in bench_cfg.result_vars: From a57136b8523cb2091ad9ed51a1996dac1b88bde7 Mon Sep 17 00:00:00 2001 From: Austin Gregg-Smith Date: Fri, 21 Jul 2023 15:13:34 +0000 Subject: [PATCH 5/9] add long time labels to bench_cfg --- bencher/bench_cfg.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/bencher/bench_cfg.py b/bencher/bench_cfg.py index b1f8ca0f..08715521 100644 --- a/bencher/bench_cfg.py +++ b/bencher/bench_cfg.py @@ -11,6 +11,7 @@ from str2bool import str2bool import holoviews as hv import numpy as np +import pandas as pd import bencher as bch from bencher.bench_vars import OptDir, TimeEvent, TimeSnapshot, describe_variable, hash_sha1 @@ -496,6 +497,29 @@ def get_best_trial_params(self): def get_pareto_front_params(self): return [p.params for p in self.studies[0].trials] + def wrap_long_time_labels(self) -> BenchCfg: + """Takes a benchCfg and wraps any index labels that are too long to be plotted easily + + Args: + bench_cfg (BenchCfg): + + Returns: + BenchCfg: updated config with wrapped labels + """ + if self.over_time: + if self.ds.coords["over_time"].dtype == np.datetime64: + # plotly catastrophically fails to plot anything with the default long string representation of time, so convert to a shorter time representation + self.ds.coords["over_time"] = [ + pd.to_datetime(t).strftime("%d-%m-%y %H-%M-%S") + for t in self.ds.coords.coords["over_time"].values + ] + # wrap very long time event labels because otherwise the graphs are unreadable + if self.time_event is not None: + self.ds.coords["over_time"] = [ + "\n".join(wrap(t, 20)) for t in self.ds.coords["over_time"].values + ] + return self + def get_hv_dataset(self, reduce=None): ds = convert_dataset_bool_dims_to_str(self.ds) if reduce is None: From b5097cb1ee4da1da482e1356c18bc710e394abdf Mon Sep 17 00:00:00 2001 From: Austin Gregg-Smith Date: Fri, 21 Jul 2023 15:16:29 +0000 Subject: [PATCH 6/9] Revert "add long time labels to bench_cfg" This reverts commit a57136b8523cb2091ad9ed51a1996dac1b88bde7. --- bencher/bench_cfg.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/bencher/bench_cfg.py b/bencher/bench_cfg.py index 08715521..b1f8ca0f 100644 --- a/bencher/bench_cfg.py +++ b/bencher/bench_cfg.py @@ -11,7 +11,6 @@ from str2bool import str2bool import holoviews as hv import numpy as np -import pandas as pd import bencher as bch from bencher.bench_vars import OptDir, TimeEvent, TimeSnapshot, describe_variable, hash_sha1 @@ -497,29 +496,6 @@ def get_best_trial_params(self): def get_pareto_front_params(self): return [p.params for p in self.studies[0].trials] - def wrap_long_time_labels(self) -> BenchCfg: - """Takes a benchCfg and wraps any index labels that are too long to be plotted easily - - Args: - bench_cfg (BenchCfg): - - Returns: - BenchCfg: updated config with wrapped labels - """ - if self.over_time: - if self.ds.coords["over_time"].dtype == np.datetime64: - # plotly catastrophically fails to plot anything with the default long string representation of time, so convert to a shorter time representation - self.ds.coords["over_time"] = [ - pd.to_datetime(t).strftime("%d-%m-%y %H-%M-%S") - for t in self.ds.coords.coords["over_time"].values - ] - # wrap very long time event labels because otherwise the graphs are unreadable - if self.time_event is not None: - self.ds.coords["over_time"] = [ - "\n".join(wrap(t, 20)) for t in self.ds.coords["over_time"].values - ] - return self - def get_hv_dataset(self, reduce=None): ds = convert_dataset_bool_dims_to_str(self.ds) if reduce is None: From 76dd8f3d4058f1761c2db4c08f9ad491ce2b8cdb Mon Sep 17 00:00:00 2001 From: Austin Gregg-Smith Date: Fri, 21 Jul 2023 15:35:23 +0000 Subject: [PATCH 7/9] add methods back --- bencher/plotting/plots/hv_interactive.py | 3 + bencher/plotting_functions.py | 97 ++++++++++++++++++++++++ bencher/plt_cfg.py | 46 ++++++++++- 3 files changed, 145 insertions(+), 1 deletion(-) diff --git a/bencher/plotting/plots/hv_interactive.py b/bencher/plotting/plots/hv_interactive.py index af5478b9..6d9fd9c7 100644 --- a/bencher/plotting/plots/hv_interactive.py +++ b/bencher/plotting/plots/hv_interactive.py @@ -81,6 +81,9 @@ def scatter_hv(self, pl_in: PlotInput) -> Optional[pn.panel]: def lineplot_hv(self, pl_in: PlotInput) -> Optional[pn.panel]: if self.lineplot_filter.matches(pl_in.plt_cnt_cfg): print(pl_in.bench_cfg.get_hv_dataset()) + print(pl_in.bench_cfg.get_dataframe(False)) + return pn.Column(pl_in.bench_cfg.get_hv_dataset().to(hv.Table)) + return pn.Column(pl_in.bench_cfg.to_curve(), name=PlotTypes.lineplot_hv) return None diff --git a/bencher/plotting_functions.py b/bencher/plotting_functions.py index 58bb7108..b0818a76 100644 --- a/bencher/plotting_functions.py +++ b/bencher/plotting_functions.py @@ -245,3 +245,100 @@ def save_fig( f"This figname {figname} already exists, please define a unique benchmark name or don't run the same benchmark twice" ) plt.savefig(figpath) + + +def plot_surface_plotly( + bench_cfg: BenchCfg, rv: ParametrizedSweep, xr_cfg: PltCfgBase +) -> pn.pane.Plotly: + """Given a benchCfg generate a 2D surface plot + Args: + bench_cfg (BenchCfg): description of benchmark + rv (ParametrizedSweep): result variable to plot + xr_cfg (PltCfgBase): config of x,y variables + Returns: + pn.pane.Plotly: A 2d surface plot as a holoview in a pane + """ + + if type(rv) == ResultVec: + return plot_scatter3D_px(bench_cfg, rv) + + bench_cfg = wrap_long_time_labels(bench_cfg) + + da = bench_cfg.ds[rv.name].transpose() + + mean = da.mean("repeat") + + x = da.coords[xr_cfg.x] + y = da.coords[xr_cfg.y] + + opacity = 0.3 + + surfaces = [go.Surface(x=x, y=y, z=mean)] + + if bench_cfg.repeats > 1: + std_dev = da.std("repeat") + surfaces.append(go.Surface(x=x, y=y, z=mean + std_dev, showscale=False, opacity=opacity)) + surfaces.append(go.Surface(x=x, y=y, z=mean - std_dev, showscale=False, opacity=opacity)) + + eye_dis = 1.7 + layout = go.Layout( + title=xr_cfg.title, + width=700, + height=700, + scene=dict( + xaxis_title=xr_cfg.xlabel, + yaxis_title=xr_cfg.ylabel, + zaxis_title=xr_cfg.zlabel, + camera={"eye": {"x": eye_dis, "y": eye_dis, "z": eye_dis}}, + ), + ) + + fig = {"data": surfaces, "layout": layout} + + return pn.pane.Plotly(fig) + + +def plot_surface_holo( + bench_cfg: BenchCfg, rv: ParametrizedSweep, xr_cfg: PltCfgBase +) -> pn.pane.Plotly: + """Given a benchCfg generate a 2D surface plot + Args: + bench_cfg (BenchCfg): description of benchmark + rv (ParametrizedSweep): result variable to plot + xr_cfg (PltCfgBase): config of x,y variables + Returns: + pn.pane.holoview: A 2d surface plot as a holoview in a pane + """ + + hv.extension("plotly") + + bench_cfg = wrap_long_time_labels(bench_cfg) + + alpha = 0.3 + + da = bench_cfg.ds[rv.name] + + mean = da.mean("repeat") + + opts.defaults( + opts.Surface( + colorbar=True, + width=800, + height=800, + zlabel=xr_cfg.zlabel, + title=xr_cfg.title, + # image_rtol=0.002, + ) + ) + # TODO a warning suggests setting this parameter, but it does not seem to help as expected, leaving here to fix in the future + # hv.config.image_rtol = 1.0 + + ds = hv.Dataset(mean) + surface = ds.to(hv.Surface) + + if bench_cfg.repeats > 1: + std_dev = da.std("repeat") + upper = hv.Dataset(mean + std_dev).to(hv.Surface).opts(alpha=alpha, colorbar=False) + lower = hv.Dataset(mean - std_dev).to(hv.Surface).opts(alpha=alpha, colorbar=False) + return surface * upper * lower + return surface \ No newline at end of file diff --git a/bencher/plt_cfg.py b/bencher/plt_cfg.py index aa7eb2ea..7a55bf79 100644 --- a/bencher/plt_cfg.py +++ b/bencher/plt_cfg.py @@ -6,7 +6,7 @@ import bencher.plotting_functions as plt_func from bencher.bench_cfg import BenchCfg, PltCfgBase, PltCntCfg, describe_benchmark -from bencher.bench_vars import ParametrizedSweep, ResultVec +from bencher.bench_vars import ParametrizedSweep, ResultVec,ResultVar from bencher.optuna_conversions import collect_optuna_plots @@ -214,6 +214,20 @@ def plot_result_variable( sns_cfg = BenchPlotter.plot_float_cnt_1(sns_cfg, plt_cnt_cfg) sns_cfg = BenchPlotter.get_axes_and_title(rv, sns_cfg, plt_cnt_cfg) surf_col.append(plt_func.plot_sns(bench_cfg, rv, sns_cfg)) + else: + if plt_cnt_cfg.float_cnt == 2: + xr_cfg = plot_float_cnt_2(plt_cnt_cfg, rv, bench_cfg.debug) + if plt_cnt_cfg.cat_cnt == 0: + surf_col.append(plt_func.plot_surface_plotly(bench_cfg, rv, xr_cfg)) + else: + try: + surf_col.append(plt_func.plot_surface_holo(bench_cfg, rv, xr_cfg)) + except (TypeError, KeyError) as e: + surf_col.append( + pn.pane.Markdown( + f"3D (cat,float,cat) inputs -> (float) output plots are not supported yet, error:{e}" + ) + ) return surf_col @staticmethod @@ -306,3 +320,33 @@ def plot_float_cnt_1(sns_cfg: PltCfgBase, plt_cnt_cfg: PltCntCfg) -> PltCfgBase: sns_cfg = BenchPlotter.axis_mapping(cat_axis_order, sns_cfg, plt_cnt_cfg) return sns_cfg + + @staticmethod + def plot_float_cnt_2(plt_cnt_cfg: PltCntCfg, rv: ResultVar, debug: bool) -> PltCfgBase: + """A function for determining the plot settings if there are 2 float variable and updates the PltCfgBase + Args: + sns_cfg (PltCfgBase): See PltCfgBase definition + plt_cnt_cfg (PltCntCfg): See PltCntCfg definition + Returns: + PltCfgBase: See PltCfgBase definition + """ + xr_cfg = PltCfgBase() + if plt_cnt_cfg.float_cnt == 2: + logging.info(f"surface plot: {rv.name}") + xr_cfg.plot_callback_xra = xr.plot.plot + xr_cfg.x = plt_cnt_cfg.float_vars[0].name + xr_cfg.y = plt_cnt_cfg.float_vars[1].name + xr_cfg.xlabel = f"{xr_cfg.x} [{plt_cnt_cfg.float_vars[0].units}]" + xr_cfg.ylabel = f"{xr_cfg.y} [{plt_cnt_cfg.float_vars[1].units}]" + xr_cfg.zlabel = f"{rv.name} [{rv.units}]" + xr_cfg.title = f"{rv.name} vs ({xr_cfg.x} and {xr_cfg.y})" + + if plt_cnt_cfg.cat_cnt >= 1: + logging.info("surface plot with 1 categorical") + xr_cfg.row = plt_cnt_cfg.cat_vars[0].name + xr_cfg.num_rows = len(plt_cnt_cfg.cat_vars[0].values(debug)) + if plt_cnt_cfg.cat_cnt >= 2: + logging.info("surface plot with 2> categorical") + xr_cfg.col = plt_cnt_cfg.cat_vars[1].name + xr_cfg.num_cols = len(plt_cnt_cfg.cat_vars[1].values(debug)) + return xr_cfg \ No newline at end of file From 0bba9d5836372264260060fdd74d1ed4bbf94a90 Mon Sep 17 00:00:00 2001 From: Austin Gregg-Smith Date: Fri, 21 Jul 2023 15:37:21 +0000 Subject: [PATCH 8/9] fix lint --- bencher/plotting_functions.py | 4 +++- bencher/plt_cfg.py | 7 ++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/bencher/plotting_functions.py b/bencher/plotting_functions.py index b0818a76..a21e1f15 100644 --- a/bencher/plotting_functions.py +++ b/bencher/plotting_functions.py @@ -14,6 +14,8 @@ from bencher.bench_cfg import BenchCfg, PltCfgBase from bencher.bench_vars import ParametrizedSweep, ResultList, ResultVar, ResultVec +import plotly.graph_objs as go + hv.extension("plotly") @@ -341,4 +343,4 @@ def plot_surface_holo( upper = hv.Dataset(mean + std_dev).to(hv.Surface).opts(alpha=alpha, colorbar=False) lower = hv.Dataset(mean - std_dev).to(hv.Surface).opts(alpha=alpha, colorbar=False) return surface * upper * lower - return surface \ No newline at end of file + return surface diff --git a/bencher/plt_cfg.py b/bencher/plt_cfg.py index 7a55bf79..439bf650 100644 --- a/bencher/plt_cfg.py +++ b/bencher/plt_cfg.py @@ -6,8 +6,9 @@ import bencher.plotting_functions as plt_func from bencher.bench_cfg import BenchCfg, PltCfgBase, PltCntCfg, describe_benchmark -from bencher.bench_vars import ParametrizedSweep, ResultVec,ResultVar +from bencher.bench_vars import ParametrizedSweep, ResultVec, ResultVar from bencher.optuna_conversions import collect_optuna_plots +import xarray as xr class BenchPlotter: @@ -216,7 +217,7 @@ def plot_result_variable( surf_col.append(plt_func.plot_sns(bench_cfg, rv, sns_cfg)) else: if plt_cnt_cfg.float_cnt == 2: - xr_cfg = plot_float_cnt_2(plt_cnt_cfg, rv, bench_cfg.debug) + xr_cfg = BenchPlotter.plot_float_cnt_2(plt_cnt_cfg, rv, bench_cfg.debug) if plt_cnt_cfg.cat_cnt == 0: surf_col.append(plt_func.plot_surface_plotly(bench_cfg, rv, xr_cfg)) else: @@ -349,4 +350,4 @@ def plot_float_cnt_2(plt_cnt_cfg: PltCntCfg, rv: ResultVar, debug: bool) -> PltC logging.info("surface plot with 2> categorical") xr_cfg.col = plt_cnt_cfg.cat_vars[1].name xr_cfg.num_cols = len(plt_cnt_cfg.cat_vars[1].values(debug)) - return xr_cfg \ No newline at end of file + return xr_cfg From 862b9113fb666de882eb121cb5b53ad2b9adfaf7 Mon Sep 17 00:00:00 2001 From: Austin Gregg-Smith Date: Fri, 21 Jul 2023 15:41:25 +0000 Subject: [PATCH 9/9] restore lineplot --- bencher/plotting/plots/hv_interactive.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bencher/plotting/plots/hv_interactive.py b/bencher/plotting/plots/hv_interactive.py index 6d9fd9c7..a9f2f2bf 100644 --- a/bencher/plotting/plots/hv_interactive.py +++ b/bencher/plotting/plots/hv_interactive.py @@ -80,9 +80,9 @@ def scatter_hv(self, pl_in: PlotInput) -> Optional[pn.panel]: def lineplot_hv(self, pl_in: PlotInput) -> Optional[pn.panel]: if self.lineplot_filter.matches(pl_in.plt_cnt_cfg): - print(pl_in.bench_cfg.get_hv_dataset()) - print(pl_in.bench_cfg.get_dataframe(False)) - return pn.Column(pl_in.bench_cfg.get_hv_dataset().to(hv.Table)) + # print(pl_in.bench_cfg.get_hv_dataset()) + # print(pl_in.bench_cfg.get_dataframe(False)) + # return pn.Column(pl_in.bench_cfg.get_hv_dataset().to(hv.Table)) return pn.Column(pl_in.bench_cfg.to_curve(), name=PlotTypes.lineplot_hv) return None