From dadec84e03e70157f6a19ef7e10759d14d7c8229 Mon Sep 17 00:00:00 2001 From: Andrew Huang Date: Wed, 2 Oct 2024 16:42:29 -0700 Subject: [PATCH] Support datashade hover --- hvplot/converter.py | 79 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 75 insertions(+), 4 deletions(-) diff --git a/hvplot/converter.py b/hvplot/converter.py index 0f90fc800..3bb31c175 100644 --- a/hvplot/converter.py +++ b/hvplot/converter.py @@ -45,7 +45,7 @@ from holoviews.plotting.bokeh import OverlayPlot, colormap_generator from holoviews.plotting.util import process_cmap from holoviews.operation import histogram, apply_when -from holoviews.streams import Buffer, Pipe +from holoviews.streams import Buffer, Pipe, PointerXY from holoviews.util.transform import dim, lon_lat_to_easting_northing from pandas import DatetimeIndex, MultiIndex @@ -842,13 +842,22 @@ def __init__( if kind == 'errorbars': hover = False elif hover is None: - hover = not self.datashade + hover = True + if hover and not any( t for t in tools if isinstance(t, HoverTool) or t in ['hover', 'vline', 'hline'] ): if hover in {'vline', 'hline'}: plot_opts['hover_mode'] = hover - tools.append('hover') + self.hover_mode = hover + else: + self.hover_mode = 'mouse' + if not self.datashade: + tools.append('hover') + + self.hover = bool(hover) + self.hover_tooltips = hover_tooltips + self.hover_formatters = hover_formatters if 'hover' in tools: if hover_tooltips: plot_opts['hover_tooltips'] = hover_tooltips @@ -1760,7 +1769,7 @@ def method_wrapper(ds, x, y): return layers import_datashader() - from holoviews.operation.datashader import datashade, rasterize, dynspread + from holoviews.operation.datashader import datashade, rasterize, dynspread, inspect_points categorical, agg = self._process_categorical_datashader() if agg: @@ -1819,11 +1828,73 @@ def method_wrapper(ds, x, y): threshold=self.kwds.get('threshold', 0.5), ) + # a workaround to show hover info for datashaded points + if self.hover and self.datashade and self.kind == 'points': + inspector = inspect_points.instance( + streams=[PointerXY], transform=self._datashade_hover_transform + ) + processed *= inspector(processed).opts( + size=10, + alpha=0, + tools=['hover'], + hover_mode=self.hover_mode, + hover_tooltips=self.hover_tooltips, + hover_formatters=self.hover_formatters, + ) + opts = filter_opts(eltype, dict(self._plot_opts, **style), backend='bokeh') layers = self._apply_layers(processed).opts(eltype, **opts, backend='bokeh') layers = _transfer_opts_cur_backend(layers) return layers + def _datashade_hover_transform(self, df): + if not len(df): + return df + + # show at least the x and y columns + cols = self.hover_cols.copy() + if self.x not in cols: + cols.append(self.x) + if self.y not in cols: + cols.append(self.y) + + # handle aggregator, e.g. ds.sum('column') or ds.count_cat('column') + agg_col = None + agg_series_map = {} + if self.aggregator and not isinstance(self.aggregator, str) and self.aggregator.column: + agg_col = self.aggregator.column + agg_op = type(self.aggregator).__name__ + if hasattr(df, agg_op): # df.sum/df.count + agg_value = df.agg({agg_col: agg_op}) + elif agg_op == 'count_cat': + agg_value = df[agg_col].value_counts() + + if agg_col in cols: + cols.remove(agg_col) + + # take the mean of numeric columns + num_series = df[cols].select_dtypes(include=['number']).mean() + if len(num_series): + agg_series_map['number_cols'] = num_series + + # take the first value of object columns + obj_series = df[cols].select_dtypes(exclude=['number']).iloc[0] + if len(obj_series): + agg_series_map['object_cols'] = obj_series + + # to preserve order of other columns, add this last + if agg_col: + agg_series_map[agg_col] = agg_value + + # concat all series into a single dataframe which has one row + df_hover = pd.concat(agg_series_map.values()).to_frame().transpose() + + # remove index if it wasn't in the original dataset + if 'index' not in self.data.columns: + df_hover = df_hover.drop(columns=['index'], errors='ignore') + + return df_hover + def _resample_obj(self, operation, obj, opts): def exceeds_resample_when(plot): return len(plot) > self.resample_when