Skip to content

Commit

Permalink
Fixed bug computing categorical datashader aggregates (#2295)
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr authored and jlstevens committed Feb 5, 2018
1 parent 4f29a5f commit 05f4545
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 13 deletions.
4 changes: 2 additions & 2 deletions holoviews/core/data/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,14 @@ def irregular(cls, dataset, dim):
@classmethod
def shape(cls, dataset, gridded=False):
if gridded:
return dataset.data.shape
return dataset.data.shape[:2]
else:
return cls.length(dataset), len(dataset.dimensions())


@classmethod
def length(cls, dataset):
return np.product(dataset.data.shape)
return np.product(dataset.data.shape[:2])


@classmethod
Expand Down
3 changes: 2 additions & 1 deletion holoviews/core/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def deephash(obj):
if sys.version_info.major == 3:
basestring = str
unicode = str
long = int
generator_types = (zip, range, types.GeneratorType)
else:
basestring = basestring
Expand Down Expand Up @@ -1570,7 +1571,7 @@ def dt_to_int(value, time_unit='us'):
value = value.to_pydatetime()
elif isinstance(value, np.datetime64):
value = value.tolist()
if isinstance(value, int):
if isinstance(value, (int, long)):
# Handle special case of nanosecond precision which cannot be
# represented by python datetime
return value * 10**-(np.log10(tscale)-3)
Expand Down
27 changes: 19 additions & 8 deletions holoviews/operation/datashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,7 @@ def get_agg_data(cls, obj, category=None):
df = df.copy()
for d in (x, y):
if df[d.name].dtype.kind == 'M':
df[d.name] = df[d.name].astype('datetime64[ns]').astype('int64') * 10e-4

df[d.name] = df[d.name].astype('datetime64[ns]').astype('int64') * 1000.
return x, y, Dataset(df, kdims=kdims, vdims=vdims), glyph


Expand Down Expand Up @@ -374,9 +373,8 @@ def _process(self, element, key=None):
raise ValueError("Aggregation column %s not found on %s element. "
"Ensure the aggregator references an existing "
"dimension." % (column,element))
if isinstance(agg_fn, ds.count_cat):
name = '%s Count' % agg_fn.column
vdims = [dims[0](column)]
name = '%s Count' % column if isinstance(agg_fn, ds.count_cat) else column
vdims = [dims[0](name)]
else:
vdims = Dimension('Count')
params = dict(get_param_values(element), kdims=[x, y],
Expand All @@ -400,7 +398,7 @@ def _process(self, element, key=None):
for c in agg.coords[column].data:
cagg = agg.sel(**{column: c})
eldata = cagg if ds_version > '0.5.0' else (xs, ys, cagg.data)
layers[c] = self.p.element_type(eldata, **params)
layers[c] = self.p.element_type(eldata, **dict(params, vdims=vdims))
return NdOverlay(layers, kdims=[data.get_dimension(column)])


Expand Down Expand Up @@ -725,12 +723,12 @@ def concatenate(cls, overlay):
"""
if not isinstance(overlay, NdOverlay):
raise ValueError('Only NdOverlays can be concatenated')
xarr = xr.concat([v.data.T for v in overlay.values()],
xarr = xr.concat([v.data.transpose() for v in overlay.values()],
pd.Index(overlay.keys(), name=overlay.kdims[0].name))
params = dict(get_param_values(overlay.last),
vdims=overlay.last.vdims,
kdims=overlay.kdims+overlay.last.kdims)
return Dataset(xarr.T, datatype=['xarray'], **params)
return Dataset(xarr.transpose(), datatype=['xarray'], **params)


@classmethod
Expand All @@ -751,7 +749,20 @@ def rgb2hex(cls, rgb):
return "#{0:02x}{1:02x}{2:02x}".format(*(int(v*255) for v in rgb))


@classmethod
def to_xarray(cls, element):
if issubclass(element.interface, XArrayInterface):
return element
data = tuple(element.dimension_values(kd, expanded=False)
for kd in element.kdims)
data += tuple(element.dimension_values(vd, flat=False)
for vd in element.vdims)
dtypes = [dt for dt in element.datatype if dt != 'xarray']
return element.clone(data, datatype=['xarray']+dtypes)


def _process(self, element, key=None):
element = element.map(self.to_xarray, Image)
if isinstance(element, NdOverlay):
bounds = element.last.bounds
element = self.concatenate(element)
Expand Down
Empty file added tests/operation/__init__.py
Empty file.
63 changes: 61 additions & 2 deletions tests/testdatashader.py → tests/operation/testdatashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
from nose.plugins.attrib import attr

import numpy as np
from holoviews import Curve, Points, Image, Dataset, RGB, Path, Graph, TriMesh, QuadMesh
from holoviews import (Dimension, Curve, Points, Image, Dataset, RGB, Path,
Graph, TriMesh, QuadMesh, NdOverlay)
from holoviews.element.comparison import ComparisonTestCase
from holoviews.core.util import pd

try:
import datashader as ds
from holoviews.operation.datashader import (
aggregate, regrid, ds_version, stack, directly_connect_edges, rasterize
aggregate, regrid, ds_version, stack, directly_connect_edges,
shade, rasterize
)
except:
ds_version = None
Expand Down Expand Up @@ -43,6 +46,17 @@ def test_aggregate_points_sampling(self):
x_sampling=0.5, y_sampling=0.5)
self.assertEqual(img, expected)

def test_aggregate_points_categorical(self):
points = Points([(0.2, 0.3, 'A'), (0.4, 0.7, 'B'), (0, 0.99, 'C')], vdims='z')
img = aggregate(points, dynamic=False, x_range=(0, 1), y_range=(0, 1),
width=2, height=2, aggregator=ds.count_cat('z'))
xs, ys = [0.25, 0.75], [0.25, 0.75]
expected = NdOverlay({'A': Image((xs, ys, [[1, 0], [0, 0]]), vdims='z Count'),
'B': Image((xs, ys, [[0, 0], [1, 0]]), vdims='z Count'),
'C': Image((xs, ys, [[0, 0], [1, 0]]), vdims='z Count')},
kdims=['z'])
self.assertEqual(img, expected)

def test_aggregate_curve(self):
curve = Curve([(0.2, 0.3), (0.4, 0.7), (0.8, 0.99)])
expected = Image(([0.25, 0.75], [0.25, 0.75], [[1, 0], [1, 1]]),
Expand All @@ -51,6 +65,18 @@ def test_aggregate_curve(self):
width=2, height=2)
self.assertEqual(img, expected)

def test_aggregate_curve_datetimes(self):
dates = pd.date_range(start="2016-01-01", end="2016-01-03", freq='1D')
curve = Curve((dates, [1, 2, 3]))
img = aggregate(curve, width=2, height=2, dynamic=False)
bounds = (np.datetime64('2015-12-31T23:59:59.723518000'), 1.0,
np.datetime64('2016-01-03T00:00:00.276482000'), 3.0)
dates = [np.datetime64('2016-01-01T12:00:00.000000000'),
np.datetime64('2016-01-02T12:00:00.000000000')]
expected = Image((dates, [1.5, 2.5], [[1, 0], [0, 2]]),
datatype=['xarray'], bounds=bounds, vdims='Count')
self.assertEqual(img, expected)

def test_aggregate_ndoverlay(self):
ds = Dataset([(0.2, 0.3, 0), (0.4, 0.7, 1), (0, 0.99, 2)], kdims=['x', 'y', 'z'])
ndoverlay = ds.to(Points, ['x', 'y'], [], 'z').overlay()
Expand All @@ -77,6 +103,39 @@ def test_aggregate_dframe_nan_path(self):
self.assertEqual(img, expected)


@attr(optional=1)
class DatashaderShadeTests(ComparisonTestCase):

def test_shade_categorical_images_xarray(self):
xs, ys = [0.25, 0.75], [0.25, 0.75]
data = NdOverlay({'A': Image((xs, ys, [[1, 0], [0, 0]]), datatype=['xarray'], vdims='z Count'),
'B': Image((xs, ys, [[0, 0], [1, 0]]), datatype=['xarray'], vdims='z Count'),
'C': Image((xs, ys, [[0, 0], [1, 0]]), datatype=['xarray'], vdims='z Count')},
kdims=['z'])
shaded = shade(data)
r = [[228, 255], [66, 255]]
g = [[26, 255], [150, 255]]
b = [[28, 255], [129, 255]]
a = [[40, 0], [255, 0]]
expected = RGB((xs, ys, r, g, b, a), datatype=['grid'],
vdims=RGB.vdims+[Dimension('A', range=(0, 1))])
self.assertEqual(shaded, expected)

def test_shade_categorical_images_grid(self):
xs, ys = [0.25, 0.75], [0.25, 0.75]
data = NdOverlay({'A': Image((xs, ys, [[1, 0], [0, 0]]), datatype=['grid'], vdims='z Count'),
'B': Image((xs, ys, [[0, 0], [1, 0]]), datatype=['grid'], vdims='z Count'),
'C': Image((xs, ys, [[0, 0], [1, 0]]), datatype=['grid'], vdims='z Count')},
kdims=['z'])
shaded = shade(data)
r = [[228, 255], [66, 255]]
g = [[26, 255], [150, 255]]
b = [[28, 255], [129, 255]]
a = [[40, 0], [255, 0]]
expected = RGB((xs, ys, r, g, b, a), datatype=['grid'],
vdims=RGB.vdims+[Dimension('A', range=(0, 1))])
self.assertEqual(shaded, expected)



@attr(optional=1)
Expand Down

0 comments on commit 05f4545

Please sign in to comment.