Skip to content

Commit

Permalink
Merge pull request #66 from dyson-ai/feature/surface_fx
Browse files Browse the repository at this point in the history
Feature/surface fx
  • Loading branch information
blooop authored Jul 21, 2023
2 parents 0030653 + aea8d27 commit ca71897
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 12 deletions.
4 changes: 4 additions & 0 deletions bencher/plotting/plot_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from bencher.plotting.plots.scatterplot import Scatter
from bencher.plotting.plots.tables import Tables
from bencher.plotting.plots.volume import VolumePlot
from bencher.plotting.plots.surface import SurfacePlot

from bencher.plotting.plots.hv_interactive import HvInteractive

Expand All @@ -30,6 +31,8 @@ def setup_sources() -> PlotCollection:
plt_col.add_plotter_source(Scatter())
plt_col.add_plotter_source(HvInteractive())
plt_col.add_plotter_source(VolumePlot())
plt_col.add_plotter_source(SurfacePlot())

return plt_col

@staticmethod
Expand All @@ -54,6 +57,7 @@ def default() -> PlotCollection:
plt_col.add(PlotTypes.cone_plotly)
# plt_col.add(PlotTypes.lineplot_hv_subplot)
plt_col.add(PlotTypes.scatter2D_sns)
plt_col.add(PlotTypes.surface_hv)
# plt_col.add(PlotTypes.hv_interactive)

return plt_col
Expand Down
1 change: 1 addition & 0 deletions bencher/plotting/plot_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,6 @@ class PlotTypes(StrEnum):

volume_plotly = auto()
cone_plotly = auto()
surface_hv = auto()

# hv_interactive = auto()
3 changes: 3 additions & 0 deletions bencher/plotting/plots/hv_interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ def scatter_hv(self, pl_in: PlotInput) -> Optional[pn.panel]:

def lineplot_hv(self, pl_in: PlotInput) -> Optional[pn.panel]:
if False & self.lineplot_filter.matches(pl_in.plt_cnt_cfg):
# print(pl_in.bench_cfg.get_hv_dataset())
# print(pl_in.bench_cfg.get_dataframe(False))
# return pn.Column(pl_in.bench_cfg.get_hv_dataset().to(hv.Table))
print(pl_in.bench_cfg.get_hv_dataset())
return pn.Column(pl_in.bench_cfg.to_curve(), name=PlotTypes.lineplot_hv)
return None
Expand Down
106 changes: 106 additions & 0 deletions bencher/plotting/plots/surface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from typing import Optional

import panel as pn
import logging
import xarray as xr

from bencher.bench_vars import ResultVar
from bencher.plotting.plot_filter import PlotFilter, PlotInput, VarRange, PltCntCfg
from bencher.plt_cfg import PltCfgBase
from bencher.plotting.plot_types import PlotTypes

from bencher.plotting_functions import wrap_long_time_labels
import holoviews as hv
from holoviews import opts


def plot_float_cnt_2(plt_cnt_cfg: PltCntCfg, rv: ResultVar, debug: bool) -> PltCfgBase:
"""A function for determining the plot settings if there are 2 float variable and updates the PltCfgBase
Args:
sns_cfg (PltCfgBase): See PltCfgBase definition
plt_cnt_cfg (PltCntCfg): See PltCntCfg definition
Returns:
PltCfgBase: See PltCfgBase definition
"""
xr_cfg = PltCfgBase()
if plt_cnt_cfg.float_cnt == 2:
logging.info(f"surface plot: {rv.name}")
xr_cfg.plot_callback_xra = xr.plot.plot
xr_cfg.x = plt_cnt_cfg.float_vars[0].name
xr_cfg.y = plt_cnt_cfg.float_vars[1].name
xr_cfg.xlabel = f"{xr_cfg.x} [{plt_cnt_cfg.float_vars[0].units}]"
xr_cfg.ylabel = f"{xr_cfg.y} [{plt_cnt_cfg.float_vars[1].units}]"
xr_cfg.zlabel = f"{rv.name} [{rv.units}]"
xr_cfg.title = f"{rv.name} vs ({xr_cfg.x} and {xr_cfg.y})"

if plt_cnt_cfg.cat_cnt >= 1:
logging.info("surface plot with 1 categorical")
xr_cfg.row = plt_cnt_cfg.cat_vars[0].name
xr_cfg.num_rows = len(plt_cnt_cfg.cat_vars[0].values(debug))
if plt_cnt_cfg.cat_cnt >= 2:
logging.info("surface plot with 2> categorical")
xr_cfg.col = plt_cnt_cfg.cat_vars[1].name
xr_cfg.num_cols = len(plt_cnt_cfg.cat_vars[1].values(debug))
return xr_cfg


class SurfacePlot:
def surface_hv(self, pl_in: PlotInput) -> Optional[pn.panel]:
"""Given a benchCfg generate a 2D surface plot
Args:
bench_cfg (BenchCfg): description of benchmark
rv (ParametrizedSweep): result variable to plot
xr_cfg (PltCfgBase): config of x,y variables
Returns:
pn.pane.holoview: A 2d surface plot as a holoview in a pane
"""
if False & PlotFilter(
float_range=VarRange(2, 2),
cat_range=VarRange(-1, None),
vector_len=VarRange(1, 1),
result_vars=VarRange(1, 1),
).matches(pl_in.plt_cnt_cfg):
xr_cfg = plot_float_cnt_2(pl_in.plt_cnt_cfg, pl_in.rv, pl_in.bench_cfg.debug)
hv.extension("plotly")
bench_cfg = pl_in.bench_cfg
rv = pl_in.rv

bench_cfg = wrap_long_time_labels(bench_cfg)

alpha = 0.3

da = bench_cfg.ds[rv.name]

mean = da.mean("repeat")

opts.defaults(
opts.Surface(
colorbar=True,
width=800,
height=800,
zlabel=xr_cfg.zlabel,
title=xr_cfg.title,
# image_rtol=0.002,
)
)
# TODO a warning suggests setting this parameter, but it does not seem to help as expected, leaving here to fix in the future
# hv.config.image_rtol = 1.0

ds = hv.Dataset(mean)
surface = ds.to(hv.Surface)

if bench_cfg.repeats > 1:
std_dev = da.std("repeat")
surface *= (
hv.Dataset(mean + std_dev).to(hv.Surface).opts(alpha=alpha, colorbar=False)
)
surface *= (
hv.Dataset(mean - std_dev).to(hv.Surface).opts(alpha=alpha, colorbar=False)
)
return pn.Column(surface, name=PlotTypes.surface_hv)

return None
7 changes: 2 additions & 5 deletions bencher/plotting_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
import pandas as pd
import panel as pn
import plotly.express as px
import plotly.graph_objs as go
import seaborn as sns
from holoviews import opts

from bencher.bench_cfg import BenchCfg, PltCfgBase
from bencher.bench_vars import ParametrizedSweep, ResultList, ResultVar, ResultVec
import plotly.graph_objs as go


hv.extension("plotly")

Expand Down Expand Up @@ -252,12 +253,10 @@ def plot_surface_plotly(
bench_cfg: BenchCfg, rv: ParametrizedSweep, xr_cfg: PltCfgBase
) -> pn.pane.Plotly:
"""Given a benchCfg generate a 2D surface plot
Args:
bench_cfg (BenchCfg): description of benchmark
rv (ParametrizedSweep): result variable to plot
xr_cfg (PltCfgBase): config of x,y variables
Returns:
pn.pane.Plotly: A 2d surface plot as a holoview in a pane
"""
Expand Down Expand Up @@ -305,12 +304,10 @@ def plot_surface_holo(
bench_cfg: BenchCfg, rv: ParametrizedSweep, xr_cfg: PltCfgBase
) -> pn.pane.Plotly:
"""Given a benchCfg generate a 2D surface plot
Args:
bench_cfg (BenchCfg): description of benchmark
rv (ParametrizedSweep): result variable to plot
xr_cfg (PltCfgBase): config of x,y variables
Returns:
pn.pane.holoview: A 2d surface plot as a holoview in a pane
"""
Expand Down
14 changes: 7 additions & 7 deletions bencher/plt_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

import panel as pn
import seaborn as sns
import xarray as xr

import bencher.plotting_functions as plt_func
from bencher.bench_cfg import BenchCfg, PltCfgBase, PltCntCfg, describe_benchmark
from bencher.bench_vars import ParametrizedSweep, ResultVar, ResultVec
from bencher.bench_vars import ParametrizedSweep, ResultVec, ResultVar
from bencher.optuna_conversions import collect_optuna_plots
import xarray as xr


class BenchPlotter:
Expand Down Expand Up @@ -68,6 +68,8 @@ def plot(bench_cfg: BenchCfg, main_tab=None, append_cols=None) -> pn.pane:

if append_cols is not None:
plot_cols.extend(append_cols)
# plot_cols.append(pn.Column(pn.Row()))#attempt to add spacer to stop overlapping but does not work todo

plot_cols.append(pn.pane.Markdown(f"{bench_cfg.post_description}"))

tabs = pn.Tabs(plot_cols, name=bench_cfg.title)
Expand Down Expand Up @@ -118,9 +120,9 @@ def plot_results_row(bench_cfg: BenchCfg) -> pn.Row:
Returns:
pn.Row: A panel row with plots in it
"""
plot_rows = pn.Row(
name=bench_cfg.bench_name, scroll=True, height=1000
) # todo remove the scroll and make it resize dynamically
# todo remove the scroll and make it resize dynamically
plot_rows = pn.Row(name=bench_cfg.bench_name, scroll=True)

plt_cnt_cfg = BenchPlotter.generate_plt_cnt_cfg(bench_cfg)

for rv in bench_cfg.result_vars:
Expand Down Expand Up @@ -322,11 +324,9 @@ def plot_float_cnt_1(sns_cfg: PltCfgBase, plt_cnt_cfg: PltCntCfg) -> PltCfgBase:
@staticmethod
def plot_float_cnt_2(plt_cnt_cfg: PltCntCfg, rv: ResultVar, debug: bool) -> PltCfgBase:
"""A function for determining the plot settings if there are 2 float variable and updates the PltCfgBase
Args:
sns_cfg (PltCfgBase): See PltCfgBase definition
plt_cnt_cfg (PltCntCfg): See PltCntCfg definition
Returns:
PltCfgBase: See PltCfgBase definition
"""
Expand Down

0 comments on commit ca71897

Please sign in to comment.