Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for ImageStack #693

Merged
merged 10 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion geoviews/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .annotators import annotate # noqa (API import)
from .element import ( # noqa (API import)
_Element, Feature, Tiles, WMTS, LineContours, FilledContours,
Text, Image, Points, Path, Polygons, Shape, Dataset, RGB,
Text, Image, ImageStack, Points, Path, Polygons, Shape, Dataset, RGB,
Contours, Graph, TriMesh, Nodes, EdgePaths, QuadMesh, VectorField,
HexTiles, Labels, Rectangles, Segments, WindBarbs
)
Expand Down
5 changes: 4 additions & 1 deletion geoviews/element/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
)

from .geo import (_Element, Feature, Tiles, is_geographic, # noqa (API import)
WMTS, Points, Image, Text, LineContours, RGB,
WMTS, Points, Image, ImageStack, Text, LineContours, RGB,
FilledContours, Path, Polygons, Shape, Dataset,
Contours, TriMesh, Graph, Nodes, EdgePaths, QuadMesh,
VectorField, Labels, HexTiles, Rectangles, Segments, WindBarbs)
Expand Down Expand Up @@ -51,6 +51,9 @@ def filledcontours(self, kdims=None, vdims=None, mdims=None, **kwargs):
def image(self, kdims=None, vdims=None, mdims=None, **kwargs):
return self(Image, kdims, vdims, mdims, **kwargs)

def image_stack(self, kdims=None, vdims=None, mdims=None, **kwargs):
return self(ImageStack, kdims, vdims, mdims, **kwargs)

def points(self, kdims=None, vdims=None, mdims=None, **kwargs):
if kdims is None: kdims = self._element.kdims
el_type = Points if is_geographic(self._element, kdims) else HvPoints
Expand Down
3 changes: 2 additions & 1 deletion geoviews/element/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

from holoviews.element.comparison import Comparison as HvComparison

from .geo import Image, Points, LineContours, FilledContours, WindBarbs
from .geo import Image, ImageStack, Points, LineContours, FilledContours, WindBarbs

class Comparison(HvComparison):

@classmethod
def register(cls):
super().register()
cls.equality_type_funcs[Image] = cls.compare_dataset
cls.equality_type_funcs[ImageStack] = cls.compare_dataset
cls.equality_type_funcs[Points] = cls.compare_dataset
cls.equality_type_funcs[LineContours] = cls.compare_dataset
cls.equality_type_funcs[FilledContours] = cls.compare_dataset
Expand Down
39 changes: 36 additions & 3 deletions geoviews/element/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from holoviews.core import util
from holoviews.element import (
Contours as HvContours, Graph as HvGraph, Image as HvImage,
Nodes as HvNodes, Path as HvPath, Polygons as HvPolygons,
RGB as HvRGB, Text as HvText, TriMesh as HvTriMesh,
QuadMesh as HvQuadMesh, Points as HvPoints,
ImageStack as HvImageStack, Nodes as HvNodes, Path as HvPath,
hoxbro marked this conversation as resolved.
Show resolved Hide resolved
Polygons as HvPolygons, RGB as HvRGB, Text as HvText,
TriMesh as HvTriMesh, QuadMesh as HvQuadMesh, Points as HvPoints,
VectorField as HvVectorField, HexTiles as HvHexTiles,
Labels as HvLabels, Rectangles as HvRectangles,
Segments as HvSegments, Geometry as HvGeometry,
Expand Down Expand Up @@ -401,6 +401,39 @@ def from_xarray(cls, da, crs=None, apply_transform=False,
return from_xarray(da, crs, apply_transform, **kwargs)


class ImageStack(_Element, HvImageStack):
"""
ImageStack expands the capabilities of Image to by supporting
multiple layers of images.

As there is many ways to represent multiple layers of images,
the following options are supported:

1) A 3D Numpy array with the shape (y, x, level)
2) A list of 2D Numpy arrays with identical shape (y, x)
3) A dictionary where the keys will be set as the vdims and the
values are 2D Numpy arrays with identical shapes (y, x).
If the dictionary's keys matches the kdims of the element,
they need to be 1D arrays.
4) A tuple containing (x, y, level_0, level_1, ...),
where the level is a 2D Numpy array in the shape of (y, x).
5) An xarray DataArray or Dataset where its `coords` contain the kdims.

If no kdims are supplied, x and y are used.

If no vdims are supplied, and the naming can be inferred like with a dictionary
the levels will be named level_0, level_1, etc.
"""

vdims = param.List(doc="""
The dimension description of the data held in the matrix.""")

group = param.String(default='ImageStack', constant=True)

_ndim = 3

_vdim_reductions = {1: Image}


class QuadMesh(_Element, HvQuadMesh):
"""
Expand Down
10 changes: 8 additions & 2 deletions geoviews/plotting/bokeh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
from holoviews.plotting.bokeh.graphs import TriMeshPlot, GraphPlot
from holoviews.plotting.bokeh.hex_tiles import hex_binning, HexTilesPlot
from holoviews.plotting.bokeh.path import PolygonPlot, PathPlot, ContourPlot
from holoviews.plotting.bokeh.raster import RasterPlot, RGBPlot, QuadMeshPlot
from holoviews.plotting.bokeh.raster import RasterPlot, RGBPlot, QuadMeshPlot, ImageStackPlot

from ...element import (
WMTS, Points, Polygons, Path, Contours, Shape, Image, Feature,
WMTS, Points, Polygons, Path, Contours, Shape, Image, ImageStack, Feature,
Text, RGB, Nodes, EdgePaths, Graph, TriMesh, QuadMesh, VectorField,
Labels, HexTiles, LineContours, FilledContours, Rectangles, Segments
)
Expand Down Expand Up @@ -140,6 +140,11 @@ class GeoRGBPlot(GeoPlot, RGBPlot):
_project_operation = project_image.instance(fast=False)


class GeoImageStackPlot(GeoPlot, ImageStackPlot):

_project_operation = project_image.instance(fast=False)


class GeoPolygonPlot(GeoPlot, PolygonPlot):

_project_operation = project_path
Expand Down Expand Up @@ -293,6 +298,7 @@ def _process(self, element, key=None):
Path: GeoPathPlot,
Shape: GeoShapePlot,
Image: GeoRasterPlot,
ImageStack: GeoImageStackPlot,
RGB: GeoRGBPlot,
LineContours: LineContourPlot,
FilledContours: FilledContourPlot,
Expand Down
22 changes: 20 additions & 2 deletions geoviews/plotting/mpl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


from ...element import (
Image, Points, Feature, WMTS, Tiles, Text, LineContours,
Image, ImageStack, Points, Feature, WMTS, Tiles, Text, LineContours,
FilledContours, is_geographic, Path, Polygons, Shape, RGB,
Contours, Nodes, EdgePaths, Graph, TriMesh, QuadMesh, VectorField,
HexTiles, Labels, Rectangles, Segments, WindBarbs
Expand Down Expand Up @@ -246,7 +246,6 @@ def update_handles(self, *args):
return GeoPlot.update_handles(self, *args)



class GeoQuadMeshPlot(GeoPlot, QuadMeshPlot):

_project_operation = project_quadmesh
Expand Down Expand Up @@ -288,6 +287,24 @@ def update_handles(self, *args):
return GeoPlot.update_handles(self, *args)


class GeoImageStackPlot(GeoImagePlot):

style_opts = ['alpha', 'cmap', 'visible', 'filterrad', 'clims', 'norm']

def __init__(self, element, **params):
super().__init__(element, **params)

def get_data(self, element, ranges, style):
self._norm_kwargs(element, ranges, style, element.vdims[0])
style.pop('interpolation', None)
xs, ys, zs = geo_mesh(element)
xs = GridInterface._infer_interval_breaks(xs)
ys = GridInterface._infer_interval_breaks(ys)
if self.geographic:
style['transform'] = element.crs
return (xs, ys, zs), style, {}
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have no idea why MPL ImageStack is not going through this code; if I add print/raise, it doesn't do anything.



class GeoPointPlot(GeoPlot, PointPlot):
"""
Draws a scatter plot from the data in a Points Element.
Expand Down Expand Up @@ -587,6 +604,7 @@ def draw_annotation(self, axis, data, crs, opts):
Path: GeoPathPlot,
Contours: GeoContourPlot,
RGB: GeoRGBPlot,
ImageStack: GeoImageStackPlot,
Shape: GeoShapePlot,
Graph: GeoGraphPlot,
TriMesh: GeoTriMeshPlot,
Expand Down
2 changes: 2 additions & 0 deletions geoviews/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,8 @@ def from_xarray(da, crs=None, apply_transform=False, nan_nodata=False, **kwargs)
from .element.geo import RGB, HvRGB
el = RGB if 'crs' in kwargs else HvRGB
vdims = el.vdims[:bands]
if bands == 4:
vdims.append("A")
el = el(data, [x, y], vdims, **kwargs)
if hasattr(el.data, 'attrs'):
el.data.attrs = da.attrs
Expand Down
Loading