Skip to content

Commit

Permalink
Added support for RGB plots (#137)
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr authored Jan 6, 2019
1 parent c900550 commit 55c643a
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 5 deletions.
22 changes: 22 additions & 0 deletions hvplot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,28 @@ def image(self, x=None, y=None, z=None, colorbar=True, **kwds):
"""
return self(x, y, z=z, kind='image', colorbar=colorbar, **kwds)

def rgb(self, x=None, y=None, z=None, bands=None, **kwds):
"""
RGB plot
Parameters
----------
x, y : string, optional
The coordinate variable along the x- and y-axis
bands : string, optional
The coordinate variable to draw the RGB channels from
z : string, optional
The data variable to plot
**kwds : optional
Keyword arguments to pass on to
:py:meth:`hvplot.converter.HoloViewsConverter`.
Returns
-------
obj : HoloViews object
The HoloViews representation of the plot.
"""
return self(x, y, z=z, bands=bands, kind='rgb', **kwds)

def quadmesh(self, x=None, y=None, z=None, colorbar=True, **kwds):
"""
QuadMesh plot
Expand Down
43 changes: 38 additions & 5 deletions hvplot/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from holoviews.element import (
Curve, Scatter, Area, Bars, BoxWhisker, Dataset, Distribution,
Table, HeatMap, Image, HexTiles, QuadMesh, Bivariate, Histogram,
Violin, Contours, Polygons, Points, Path, Labels
Violin, Contours, Polygons, Points, Path, Labels, RGB
)
from holoviews.plotting.util import process_cmap
from holoviews.operation import histogram
Expand Down Expand Up @@ -147,7 +147,7 @@ class HoloViewsConverter(object):
Declares a minimum sampling density beyond.
"""

_gridded_types = ['image', 'contour', 'contourf', 'quadmesh']
_gridded_types = ['image', 'contour', 'contourf', 'quadmesh', 'rgb']

_geom_types = ['paths', 'polygons']

Expand Down Expand Up @@ -178,6 +178,7 @@ class HoloViewsConverter(object):
'dataset' : ['columns'],
'table' : ['columns'],
'image' : ['z', 'logz'],
'rgb' : ['z', 'bands'],
'quadmesh' : ['z', 'logz'],
'contour' : ['z', 'levels', 'logz'],
'contourf' : ['z', 'levels', 'logz'],
Expand All @@ -193,7 +194,7 @@ class HoloViewsConverter(object):
'kde': Distribution, 'area': Area, 'box': BoxWhisker, 'violin': Violin,
'bar': Bars, 'barh': Bars, 'contour': Contours, 'contourf': Polygons,
'points': Points, 'polygons': Polygons, 'paths': Path, 'step': Curve,
'labels': Labels
'labels': Labels, 'rgb': RGB
}

_colorbar_types = ['image', 'hexbin', 'heatmap', 'quadmesh', 'bivariate',
Expand Down Expand Up @@ -452,7 +453,10 @@ def _process_data(self, kind, data, x, y, by, groupby, row, col,
dims = [c for c in data.coords if data[c].shape != ()
and c not in ignore]
if kind is None and (not (x or y) or all(c in data.coords for c in (x, y))):
if len(dims) == 1:
if list(data.coords) == ['band', 'y', 'x']:
kind = 'rgb'
gridded = True
elif len(dims) == 1:
kind = 'line'
elif len(dims) == 2 or (x and y):
kind = 'image'
Expand All @@ -468,7 +472,7 @@ def _process_data(self, kind, data, x, y, by, groupby, row, col,
use_dask, persist, gridded,
label, value_label)

if kind not in self._stats_types:
if kind not in self._stats_types and kind != 'rgb':
if by is None: by = by_new
if groupby is None: groupby = groupby_new

Expand Down Expand Up @@ -1106,6 +1110,35 @@ def image(self, x=None, y=None, z=None, data=None):
if self.geo: params['crs'] = self.crs
return element(data, [x, y], z, **params).redim(**self._redim).redim.range(**ranges).opts(**opts)

def rgb(self, x=None, y=None, bands=None, data=None):
data = self.data if data is None else data

coords = list(data.coords)
if len(coords) < 3:
raise ValueError('Data must be 3D array to be converted to RGB.')
x = x or coords[2]
y = y or coords[1]
bands = bands or coords[0]
z = self.kwds.get('z', list(data.data_vars)[0])
nbands = len(data.coords[bands])
if nbands < 3:
raise ValueError('Selected bands coordinate (%s) has only %d channels,'
'expected at least three channels to convert to RGB.' %
(bands, nbands))
data = data[z]

params = dict(self._relabel)
opts = dict(plot=self._plot_opts, style=self._style_opts, norm=self._norm_opts)
if self.geo: params['crs'] = self.crs
xres, yres = data.attrs['res'] if 'res' in data.attrs else (1, 1)
xs = data.coords[x][::-1] if xres < 0 else data.coords[x]
ys = data.coords[y][::-1] if yres < 0 else data.coords[y]
eldata = (xs, ys)
for b in range(nbands):
eldata += (data[b].values,)
rgb = RGB(eldata, [x, y], RGB.vdims[:nbands], **params)
return rgb.redim(**self._redim).opts(**opts)

def quadmesh(self, x=None, y=None, z=None, data=None):
import xarray as xr
data = self.data if data is None else data
Expand Down

0 comments on commit 55c643a

Please sign in to comment.