Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/surface fx #66

Merged
merged 11 commits into from
Jul 21, 2023
Merged
4 changes: 4 additions & 0 deletions bencher/plotting/plot_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from bencher.plotting.plots.scatterplot import Scatter
from bencher.plotting.plots.tables import Tables
from bencher.plotting.plots.volume import VolumePlot
from bencher.plotting.plots.surface import SurfacePlot

from bencher.plotting.plots.hv_interactive import HvInteractive

Expand All @@ -30,6 +31,8 @@ def setup_sources() -> PlotCollection:
plt_col.add_plotter_source(Scatter())
plt_col.add_plotter_source(HvInteractive())
plt_col.add_plotter_source(VolumePlot())
plt_col.add_plotter_source(SurfacePlot())

return plt_col

@staticmethod
Expand All @@ -54,6 +57,7 @@ def default() -> PlotCollection:
plt_col.add(PlotTypes.cone_plotly)
# plt_col.add(PlotTypes.lineplot_hv_subplot)
plt_col.add(PlotTypes.scatter2D_sns)
plt_col.add(PlotTypes.surface_hv)
# plt_col.add(PlotTypes.hv_interactive)

return plt_col
Expand Down
1 change: 1 addition & 0 deletions bencher/plotting/plot_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,6 @@ class PlotTypes(StrEnum):

volume_plotly = auto()
cone_plotly = auto()
surface_hv = auto()

# hv_interactive = auto()
3 changes: 3 additions & 0 deletions bencher/plotting/plots/hv_interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ def scatter_hv(self, pl_in: PlotInput) -> Optional[pn.panel]:

def lineplot_hv(self, pl_in: PlotInput) -> Optional[pn.panel]:
if False & self.lineplot_filter.matches(pl_in.plt_cnt_cfg):
# print(pl_in.bench_cfg.get_hv_dataset())
# print(pl_in.bench_cfg.get_dataframe(False))
# return pn.Column(pl_in.bench_cfg.get_hv_dataset().to(hv.Table))
print(pl_in.bench_cfg.get_hv_dataset())
return pn.Column(pl_in.bench_cfg.to_curve(), name=PlotTypes.lineplot_hv)
return None
Expand Down
106 changes: 106 additions & 0 deletions bencher/plotting/plots/surface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from typing import Optional

import panel as pn
import logging
import xarray as xr

from bencher.bench_vars import ResultVar
from bencher.plotting.plot_filter import PlotFilter, PlotInput, VarRange, PltCntCfg
from bencher.plt_cfg import PltCfgBase
from bencher.plotting.plot_types import PlotTypes

from bencher.plotting_functions import wrap_long_time_labels
import holoviews as hv
from holoviews import opts


def plot_float_cnt_2(plt_cnt_cfg: PltCntCfg, rv: ResultVar, debug: bool) -> PltCfgBase:
"""A function for determining the plot settings if there are 2 float variable and updates the PltCfgBase

Args:
sns_cfg (PltCfgBase): See PltCfgBase definition
plt_cnt_cfg (PltCntCfg): See PltCntCfg definition

Returns:
PltCfgBase: See PltCfgBase definition
"""
xr_cfg = PltCfgBase()
if plt_cnt_cfg.float_cnt == 2:
logging.info(f"surface plot: {rv.name}")
xr_cfg.plot_callback_xra = xr.plot.plot
xr_cfg.x = plt_cnt_cfg.float_vars[0].name
xr_cfg.y = plt_cnt_cfg.float_vars[1].name
xr_cfg.xlabel = f"{xr_cfg.x} [{plt_cnt_cfg.float_vars[0].units}]"
xr_cfg.ylabel = f"{xr_cfg.y} [{plt_cnt_cfg.float_vars[1].units}]"
xr_cfg.zlabel = f"{rv.name} [{rv.units}]"
xr_cfg.title = f"{rv.name} vs ({xr_cfg.x} and {xr_cfg.y})"

if plt_cnt_cfg.cat_cnt >= 1:
logging.info("surface plot with 1 categorical")
xr_cfg.row = plt_cnt_cfg.cat_vars[0].name
xr_cfg.num_rows = len(plt_cnt_cfg.cat_vars[0].values(debug))
if plt_cnt_cfg.cat_cnt >= 2:
logging.info("surface plot with 2> categorical")
xr_cfg.col = plt_cnt_cfg.cat_vars[1].name
xr_cfg.num_cols = len(plt_cnt_cfg.cat_vars[1].values(debug))
return xr_cfg


class SurfacePlot:
def surface_hv(self, pl_in: PlotInput) -> Optional[pn.panel]:
"""Given a benchCfg generate a 2D surface plot

Args:
bench_cfg (BenchCfg): description of benchmark
rv (ParametrizedSweep): result variable to plot
xr_cfg (PltCfgBase): config of x,y variables

Returns:
pn.pane.holoview: A 2d surface plot as a holoview in a pane
"""
if False & PlotFilter(
float_range=VarRange(2, 2),
cat_range=VarRange(-1, None),
vector_len=VarRange(1, 1),
result_vars=VarRange(1, 1),
).matches(pl_in.plt_cnt_cfg):
xr_cfg = plot_float_cnt_2(pl_in.plt_cnt_cfg, pl_in.rv, pl_in.bench_cfg.debug)
hv.extension("plotly")
bench_cfg = pl_in.bench_cfg
rv = pl_in.rv

bench_cfg = wrap_long_time_labels(bench_cfg)

alpha = 0.3

da = bench_cfg.ds[rv.name]

mean = da.mean("repeat")

opts.defaults(
opts.Surface(
colorbar=True,
width=800,
height=800,
zlabel=xr_cfg.zlabel,
title=xr_cfg.title,
# image_rtol=0.002,
)
)
# TODO a warning suggests setting this parameter, but it does not seem to help as expected, leaving here to fix in the future
# hv.config.image_rtol = 1.0

ds = hv.Dataset(mean)
surface = ds.to(hv.Surface)

if bench_cfg.repeats > 1:
std_dev = da.std("repeat")
surface *= (
hv.Dataset(mean + std_dev).to(hv.Surface).opts(alpha=alpha, colorbar=False)
)
surface *= (
hv.Dataset(mean - std_dev).to(hv.Surface).opts(alpha=alpha, colorbar=False)
)
return pn.Column(surface, name=PlotTypes.surface_hv)

return None
7 changes: 2 additions & 5 deletions bencher/plotting_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
import pandas as pd
import panel as pn
import plotly.express as px
import plotly.graph_objs as go
import seaborn as sns
from holoviews import opts

from bencher.bench_cfg import BenchCfg, PltCfgBase
from bencher.bench_vars import ParametrizedSweep, ResultList, ResultVar, ResultVec
import plotly.graph_objs as go


hv.extension("plotly")

Expand Down Expand Up @@ -252,12 +253,10 @@ def plot_surface_plotly(
bench_cfg: BenchCfg, rv: ParametrizedSweep, xr_cfg: PltCfgBase
) -> pn.pane.Plotly:
"""Given a benchCfg generate a 2D surface plot

Args:
bench_cfg (BenchCfg): description of benchmark
rv (ParametrizedSweep): result variable to plot
xr_cfg (PltCfgBase): config of x,y variables

Returns:
pn.pane.Plotly: A 2d surface plot as a holoview in a pane
"""
Expand Down Expand Up @@ -305,12 +304,10 @@ def plot_surface_holo(
bench_cfg: BenchCfg, rv: ParametrizedSweep, xr_cfg: PltCfgBase
) -> pn.pane.Plotly:
"""Given a benchCfg generate a 2D surface plot

Args:
bench_cfg (BenchCfg): description of benchmark
rv (ParametrizedSweep): result variable to plot
xr_cfg (PltCfgBase): config of x,y variables

Returns:
pn.pane.holoview: A 2d surface plot as a holoview in a pane
"""
Expand Down
14 changes: 7 additions & 7 deletions bencher/plt_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

import panel as pn
import seaborn as sns
import xarray as xr

import bencher.plotting_functions as plt_func
from bencher.bench_cfg import BenchCfg, PltCfgBase, PltCntCfg, describe_benchmark
from bencher.bench_vars import ParametrizedSweep, ResultVar, ResultVec
from bencher.bench_vars import ParametrizedSweep, ResultVec, ResultVar
from bencher.optuna_conversions import collect_optuna_plots
import xarray as xr


class BenchPlotter:
Expand Down Expand Up @@ -68,6 +68,8 @@ def plot(bench_cfg: BenchCfg, main_tab=None, append_cols=None) -> pn.pane:

if append_cols is not None:
plot_cols.extend(append_cols)
# plot_cols.append(pn.Column(pn.Row()))#attempt to add spacer to stop overlapping but does not work todo

plot_cols.append(pn.pane.Markdown(f"{bench_cfg.post_description}"))

tabs = pn.Tabs(plot_cols, name=bench_cfg.title)
Expand Down Expand Up @@ -118,9 +120,9 @@ def plot_results_row(bench_cfg: BenchCfg) -> pn.Row:
Returns:
pn.Row: A panel row with plots in it
"""
plot_rows = pn.Row(
name=bench_cfg.bench_name, scroll=True, height=1000
) # todo remove the scroll and make it resize dynamically
# todo remove the scroll and make it resize dynamically
plot_rows = pn.Row(name=bench_cfg.bench_name, scroll=True)

plt_cnt_cfg = BenchPlotter.generate_plt_cnt_cfg(bench_cfg)

for rv in bench_cfg.result_vars:
Expand Down Expand Up @@ -322,11 +324,9 @@ def plot_float_cnt_1(sns_cfg: PltCfgBase, plt_cnt_cfg: PltCntCfg) -> PltCfgBase:
@staticmethod
def plot_float_cnt_2(plt_cnt_cfg: PltCntCfg, rv: ResultVar, debug: bool) -> PltCfgBase:
"""A function for determining the plot settings if there are 2 float variable and updates the PltCfgBase

Args:
sns_cfg (PltCfgBase): See PltCfgBase definition
plt_cnt_cfg (PltCntCfg): See PltCntCfg definition

Returns:
PltCfgBase: See PltCfgBase definition
"""
Expand Down