Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for ImageStack #693

Merged
merged 10 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion geoviews/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .annotators import annotate # noqa (API import)
from .element import ( # noqa (API import)
_Element, Feature, Tiles, WMTS, LineContours, FilledContours,
Text, Image, Points, Path, Polygons, Shape, Dataset, RGB,
Text, Image, ImageStack, Points, Path, Polygons, Shape, Dataset, RGB,
Contours, Graph, TriMesh, Nodes, EdgePaths, QuadMesh, VectorField,
HexTiles, Labels, Rectangles, Segments, WindBarbs
)
Expand Down
5 changes: 4 additions & 1 deletion geoviews/element/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
)

from .geo import (_Element, Feature, Tiles, is_geographic, # noqa (API import)
WMTS, Points, Image, Text, LineContours, RGB,
WMTS, Points, Image, ImageStack, Text, LineContours, RGB,
FilledContours, Path, Polygons, Shape, Dataset,
Contours, TriMesh, Graph, Nodes, EdgePaths, QuadMesh,
VectorField, Labels, HexTiles, Rectangles, Segments, WindBarbs)
Expand Down Expand Up @@ -51,6 +51,9 @@ def filledcontours(self, kdims=None, vdims=None, mdims=None, **kwargs):
def image(self, kdims=None, vdims=None, mdims=None, **kwargs):
return self(Image, kdims, vdims, mdims, **kwargs)

def image_stack(self, kdims=None, vdims=None, mdims=None, **kwargs):
return self(ImageStack, kdims, vdims, mdims, **kwargs)

def points(self, kdims=None, vdims=None, mdims=None, **kwargs):
if kdims is None: kdims = self._element.kdims
el_type = Points if is_geographic(self._element, kdims) else HvPoints
Expand Down
3 changes: 2 additions & 1 deletion geoviews/element/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

from holoviews.element.comparison import Comparison as HvComparison

from .geo import Image, Points, LineContours, FilledContours, WindBarbs
from .geo import Image, ImageStack, Points, LineContours, FilledContours, WindBarbs

class Comparison(HvComparison):

@classmethod
def register(cls):
super().register()
cls.equality_type_funcs[Image] = cls.compare_dataset
cls.equality_type_funcs[ImageStack] = cls.compare_dataset
cls.equality_type_funcs[Points] = cls.compare_dataset
cls.equality_type_funcs[LineContours] = cls.compare_dataset
cls.equality_type_funcs[FilledContours] = cls.compare_dataset
Expand Down
46 changes: 45 additions & 1 deletion geoviews/element/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,15 @@
Labels as HvLabels, Rectangles as HvRectangles,
Segments as HvSegments, Geometry as HvGeometry,
)
from holoviews import __version__ as _hv_version
try:
from holoviews import ImageStack as HvImageStack
except ImportError:
class HvImageStack:
# will check version below
pass
ahuang11 marked this conversation as resolved.
Show resolved Hide resolved
from holoviews.element.selection import Selection2DExpr


from shapely.geometry.base import BaseGeometry
from shapely.geometry import (
box, GeometryCollection, MultiPolygon, LineString, MultiLineString,
Expand Down Expand Up @@ -401,6 +407,44 @@ def from_xarray(cls, da, crs=None, apply_transform=False,
return from_xarray(da, crs, apply_transform, **kwargs)


class ImageStack(_Element, HvImageStack):
"""
ImageStack expands the capabilities of Image to by supporting
multiple layers of images.

As there is many ways to represent multiple layers of images,
the following options are supported:

1) A 3D Numpy array with the shape (y, x, level)
2) A list of 2D Numpy arrays with identical shape (y, x)
3) A dictionary where the keys will be set as the vdims and the
values are 2D Numpy arrays with identical shapes (y, x).
If the dictionary's keys matches the kdims of the element,
they need to be 1D arrays.
4) A tuple containing (x, y, level_0, level_1, ...),
where the level is a 2D Numpy array in the shape of (y, x).
5) An xarray DataArray or Dataset where its `coords` contain the kdims.

If no kdims are supplied, x and y are used.

If no vdims are supplied, and the naming can be inferred like with a dictionary
the levels will be named level_0, level_1, etc.
"""

def __init__(self, data, kdims=None, vdims=None, **params):
if _hv_version < '1.18':
ahuang11 marked this conversation as resolved.
Show resolved Hide resolved
raise ImportError('ImageStack requires HoloViews 1.18 or greater.')
super().__init__(data, kdims=kdims, vdims=vdims, **params)

vdims = param.List(doc="""
The dimension description of the data held in the matrix.""")

group = param.String(default='ImageStack', constant=True)

_ndim = 3

_vdim_reductions = {1: Image}


class QuadMesh(_Element, HvQuadMesh):
"""
Expand Down
15 changes: 13 additions & 2 deletions geoviews/plotting/bokeh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
from holoviews.plotting.bokeh.hex_tiles import hex_binning, HexTilesPlot
from holoviews.plotting.bokeh.path import PolygonPlot, PathPlot, ContourPlot
from holoviews.plotting.bokeh.raster import RasterPlot, RGBPlot, QuadMeshPlot

from ...element import (
WMTS, Points, Polygons, Path, Contours, Shape, Image, Feature,
WMTS, Points, Polygons, Path, Contours, Shape, Image, ImageStack, Feature,
Text, RGB, Nodes, EdgePaths, Graph, TriMesh, QuadMesh, VectorField,
Labels, HexTiles, LineContours, FilledContours, Rectangles, Segments
)
Expand All @@ -29,7 +28,13 @@
from ...util import poly_types, line_types
from .plot import GeoPlot, GeoOverlayPlot
from . import callbacks # noqa
try:
from holoviews.plotting.bokeh.raster import ImageStackPlot
except ImportError:
class ImageStackPlot:

def __init__(self, *args, **kwargs):
raise ImportError('ImageStackPlot requires HoloViews>=1.18.0')
hoxbro marked this conversation as resolved.
Show resolved Hide resolved

class TilePlot(GeoPlot):

Expand Down Expand Up @@ -140,6 +145,11 @@ class GeoRGBPlot(GeoPlot, RGBPlot):
_project_operation = project_image.instance(fast=False)


class GeoImageStackPlot(GeoPlot, ImageStackPlot):

_project_operation = project_image.instance(fast=False)


class GeoPolygonPlot(GeoPlot, PolygonPlot):

_project_operation = project_path
Expand Down Expand Up @@ -293,6 +303,7 @@ def _process(self, element, key=None):
Path: GeoPathPlot,
Shape: GeoShapePlot,
Image: GeoRasterPlot,
ImageStack: GeoImageStackPlot,
RGB: GeoRGBPlot,
LineContours: LineContourPlot,
FilledContours: FilledContourPlot,
Expand Down
8 changes: 4 additions & 4 deletions geoviews/plotting/mpl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


from ...element import (
Image, Points, Feature, WMTS, Tiles, Text, LineContours,
Image, ImageStack, Points, Feature, WMTS, Tiles, Text, LineContours,
FilledContours, is_geographic, Path, Polygons, Shape, RGB,
Contours, Nodes, EdgePaths, Graph, TriMesh, QuadMesh, VectorField,
HexTiles, Labels, Rectangles, Segments, WindBarbs
Expand Down Expand Up @@ -246,7 +246,6 @@ def update_handles(self, *args):
return GeoPlot.update_handles(self, *args)



class GeoQuadMeshPlot(GeoPlot, QuadMeshPlot):

_project_operation = project_quadmesh
Expand All @@ -268,9 +267,10 @@ class GeoRGBPlot(GeoImagePlot):
def get_data(self, element, ranges, style):
self._norm_kwargs(element, ranges, style, element.vdims[0])
style.pop('interpolation', None)
zs = get_raster_array(element)[::-1]
zs = get_raster_array(element)
hoxbro marked this conversation as resolved.
Show resolved Hide resolved
l, b, r, t = element.bounds.lbrt()
style['extent'] = [l, r, b, t]
style['origin'] = 'upper'
if self.geographic:
style['transform'] = element.crs
return (zs,), style, {}
Expand Down Expand Up @@ -587,6 +587,7 @@ def draw_annotation(self, axis, data, crs, opts):
Path: GeoPathPlot,
Contours: GeoContourPlot,
RGB: GeoRGBPlot,
ImageStack: GeoRGBPlot,
Shape: GeoShapePlot,
Graph: GeoGraphPlot,
TriMesh: GeoTriMeshPlot,
Expand All @@ -595,7 +596,6 @@ def draw_annotation(self, axis, data, crs, opts):
HexTiles: GeoHexTilesPlot,
QuadMesh: GeoQuadMeshPlot}, 'matplotlib')


# Define plot and style options
options = Store.options(backend='matplotlib')

Expand Down
24 changes: 24 additions & 0 deletions geoviews/tests/plotting/bokeh/test_bokeh_chart.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@

import numpy as np
import geoviews as gv

class TestImageStackPlot:

def test_image_stack_crs(self):
x = np.arange(-120, -115)
y = np.arange(40, 43)
a = np.random.rand(len(y), len(x))
b = np.random.rand(len(y), len(x))

img_stack = gv.ImageStack(
(x, y, a, b), kdims=["x", "y"], vdims=["a", "b"],
)
data = img_stack.data
np.testing.assert_almost_equal(data["x"], x)
np.testing.assert_almost_equal(data["y"], y)
np.testing.assert_almost_equal(data["a"], a)
np.testing.assert_almost_equal(data["b"], b)

fig = gv.render(img_stack, backend="bokeh")
assert fig.x_range
assert fig.y_range
23 changes: 23 additions & 0 deletions geoviews/tests/plotting/mpl/test_chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,26 @@
"'flagcolor' and 'barbcolor'; ignoring 'flagcolor' and 'barbcolor'.\n"
)
self.assertEqual(log_msg, warning)


class TestImageStackPlot(TestMPLPlot):

def test_image_stack_crs(self):
x = np.arange(-120, -115)
y = np.arange(40, 43)
a = np.random.rand(len(y), len(x))
b = np.random.rand(len(y), len(x))

img_stack = gv.ImageStack(
(x, y, a, b), kdims=["x", "y"], vdims=["a", "b"],
)
data = img_stack.data
np.testing.assert_almost_equal(data["x"], x)
np.testing.assert_almost_equal(data["y"], y)
np.testing.assert_almost_equal(data["a"], a)
np.testing.assert_almost_equal(data["b"], b)

fig = gv.render(img_stack)

Check failure on line 217 in geoviews/tests/plotting/mpl/test_chart.py

View workflow job for this annotation

GitHub Actions / Core tests on Python 3.12, ubuntu-latest

TestImageStackPlot.test_image_stack_crs ImportError: Flattening ImageStacks requires datashader.
hoxbro marked this conversation as resolved.
Show resolved Hide resolved
mpl_img = fig.axes[0].get_children()[0]
np.testing.assert_almost_equal(mpl_img.get_extent(), (-120.5, -115.5, 39.5, 42.5))
assert np.sum(mpl_img.get_array()) > 0
2 changes: 2 additions & 0 deletions geoviews/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,8 @@ def from_xarray(da, crs=None, apply_transform=False, nan_nodata=False, **kwargs)
from .element.geo import RGB, HvRGB
el = RGB if 'crs' in kwargs else HvRGB
vdims = el.vdims[:bands]
if bands == 4:
vdims.append("A")
el = el(data, [x, y], vdims, **kwargs)
if hasattr(el.data, 'attrs'):
el.data.attrs = da.attrs
Expand Down
Loading