diff --git a/holoviews/core/data/grid.py b/holoviews/core/data/grid.py index d7d9f5a25b..f61b0aa087 100644 --- a/holoviews/core/data/grid.py +++ b/holoviews/core/data/grid.py @@ -188,7 +188,7 @@ def groupby(cls, dataset, dim_names, container_type, group_type, **kwargs): group_kwargs['kdims'] = kdims group_kwargs.update(kwargs) - drop_dim = len(group_kwargs['kdims']) != len(kdims) + drop_dim = any(d not in group_kwargs['kdims'] for d in kdims) # Find all the keys along supplied dimensions keys = [dataset.data[d.name] for d in dimensions] @@ -206,7 +206,7 @@ def groupby(cls, dataset, dim_names, container_type, group_type, **kwargs): group_data = {dataset.vdims[0].name: np.atleast_1d(group_data)} for dim, v in zip(dim_names, unique_key): group_data[dim] = np.atleast_1d(v) - else: + elif not drop_dim: for vdim in dataset.vdims: group_data[vdim.name] = np.squeeze(group_data[vdim.name]) group_data = group_type(group_data, **group_kwargs) diff --git a/holoviews/core/data/iris.py b/holoviews/core/data/iris.py index 1a0f409789..376b12dde2 100644 --- a/holoviews/core/data/iris.py +++ b/holoviews/core/data/iris.py @@ -117,7 +117,8 @@ def init(cls, eltype, data, kdims, vdims): @classmethod def validate(cls, dataset): - pass + if len(dataset.vdims) > 1: + raise ValueError("Iris cubes do not support more than one value dimension") @classmethod @@ -187,7 +188,7 @@ def groupby(cls, dataset, dims, container_type=HoloMap, group_type=None, **kwarg group_kwargs['kdims'] = slice_dims group_kwargs.update(kwargs) - drop_dim = len(group_kwargs['kdims']) != len(slice_dims) + drop_dim = any(d not in group_kwargs['kdims'] for d in slice_dims) unique_coords = product(*[cls.values(dataset, d, expanded=False) for d in dims]) diff --git a/holoviews/core/data/xarray.py b/holoviews/core/data/xarray.py index 4dbba497c3..90f5001e06 100644 --- a/holoviews/core/data/xarray.py +++ b/holoviews/core/data/xarray.py @@ -105,7 +105,7 @@ def groupby(cls, dataset, dimensions, container_type, group_type, **kwargs): kdims=element_dims) group_kwargs.update(kwargs) - drop_dim = len(group_kwargs['kdims']) != len(element_dims) + drop_dim = any(d not in group_kwargs['kdims'] for d in element_dims) # XArray 0.7.2 does not support multi-dimensional groupby # Replace custom implementation when diff --git a/tests/testdataset.py b/tests/testdataset.py index 48e902041f..c5ea095d89 100644 --- a/tests/testdataset.py +++ b/tests/testdataset.py @@ -877,6 +877,22 @@ def test_dataset_groupby_drop_dims_dynamic(self): partial = ds.to(Dataset, kdims=['x'], vdims=['Val'], groupby='y', dynamic=True) self.assertEqual(partial[19]['Val'], array[:, -1, :].T.flatten()) + def test_dataset_groupby_drop_dims_with_vdim(self): + array = np.random.rand(3, 20, 10) + ds = Dataset({'x': range(10), 'y': range(20), 'z': range(3), 'Val': array, 'Val2': array*2}, + kdims=['x', 'y', 'z'], vdims=['Val', 'Val2']) + with DatatypeContext([self.datatype, 'columns', 'dataframe']): + partial = ds.to(Dataset, kdims=['Val'], vdims=['Val2'], groupby='y') + self.assertEqual(partial.last['Val'], array[:, -1, :].T.flatten()) + + def test_dataset_groupby_drop_dims_dynamic_with_vdim(self): + array = np.random.rand(3, 20, 10) + ds = Dataset({'x': range(10), 'y': range(20), 'z': range(3), 'Val': array, 'Val2': array*2}, + kdims=['x', 'y', 'z'], vdims=['Val', 'Val2']) + with DatatypeContext([self.datatype, 'columns', 'dataframe']): + partial = ds.to(Dataset, kdims=['Val'], vdims=['Val2'], groupby='y', dynamic=True) + self.assertEqual(partial[19]['Val'], array[:, -1, :].T.flatten()) + class IrisDatasetTest(GridDatasetTest): """ @@ -929,6 +945,12 @@ def test_dataset_sample_hm(self): def test_dataset_sample_hm_alias(self): raise SkipTest("Not supported") + def test_dataset_groupby_drop_dims_with_vdim(self): + raise SkipTest("Not supported") + + def test_dataset_groupby_drop_dims_dynamic_with_vdim(self): + raise SkipTest("Not supported") + class XArrayDatasetTest(GridDatasetTest): """