Skip to content

Commit

Permalink
COMPAT: follow geopandas in unary_union -> union_all change (#291)
Browse files Browse the repository at this point in the history
  • Loading branch information
jorisvandenbossche authored May 15, 2024
1 parent 2607381 commit aa1b52f
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 6 deletions.
34 changes: 33 additions & 1 deletion dask_geopandas/core.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

from packaging.version import Version

import numpy as np
Expand Down Expand Up @@ -25,6 +27,7 @@

DASK_2022_8_1 = Version(dask.__version__) >= Version("2022.8.1")
GEOPANDAS_0_12 = Version(geopandas.__version__) >= Version("0.12.0")
GEOPANDAS_1_0 = Version(geopandas.__version__) >= Version("1.0.0a0")
PANDAS_2_0_0 = Version(pd.__version__) >= Version("2.0.0")


Expand Down Expand Up @@ -265,6 +268,18 @@ def sindex(self):
@property
@derived_from(geopandas.base.GeoPandasBase)
def unary_union(self):
warnings.warn(
"The 'unary_union' attribute is deprecated, "
"use the 'union_all()' method instead.",
FutureWarning,
stacklevel=2,
)
if GEOPANDAS_1_0:
return self.union_all()
else:
return self._unary_union()

def _unary_union(self):
attr = "unary_union"
meta = BaseGeometry()

Expand All @@ -275,6 +290,20 @@ def unary_union(self):
meta=meta,
)

def union_all(self):
if not GEOPANDAS_1_0:
return self._unary_union()

attr = "union_all"
meta = BaseGeometry()

return self.reduction(
lambda x: x.union_all(),
token=attr,
aggregate=lambda x: geopandas.GeoSeries(x).union_all(),
meta=meta,
)

@derived_from(geopandas.base.GeoPandasBase)
def representative_point(self):
return self.map_partitions(
Expand Down Expand Up @@ -659,7 +688,10 @@ def dissolve(self, by=None, aggfunc="first", split_out=1, **kwargs):
drop = [by, self.geometry.name]

def union(block):
merged_geom = block.unary_union
if GEOPANDAS_1_0:
merged_geom = block.union_all()
else:
merged_geom = block.unary_union
return merged_geom

merge_geometries = dd.Aggregation(
Expand Down
34 changes: 32 additions & 2 deletions dask_geopandas/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

DASK_2022_8_1 = Version(dask.__version__) >= Version("2022.8.1")
GEOPANDAS_0_12 = Version(geopandas.__version__) >= Version("0.12.0")
GEOPANDAS_1_0 = Version(geopandas.__version__) >= Version("1.0.0a0")
PANDAS_2_0_0 = Version(pd.__version__) >= Version("2.0.0")


Expand All @@ -56,6 +57,17 @@ def aggregate(cls, inputs, **kwargs):
return s.unary_union


class UnionAll(ApplyConcatApply):
@classmethod
def chunk(cls, df, **kwargs):
return df.union_all()

@classmethod
def aggregate(cls, inputs, **kwargs):
s = geopandas.GeoSeries(inputs)
return s.union_all()


class TotalBounds(ApplyConcatApply):
@classmethod
def chunk(cls, df, **kwargs):
Expand Down Expand Up @@ -302,7 +314,22 @@ def sindex(self):
@property
@derived_from(geopandas.base.GeoPandasBase)
def unary_union(self):
return new_collection(UnaryUnion(self.expr))
warnings.warn(
"The 'unary_union' attribute is deprecated, "
"use the 'union_all()' method instead.",
FutureWarning,
stacklevel=2,
)
if GEOPANDAS_1_0:
return new_collection(UnionAll(self.expr))
else:
return new_collection(UnaryUnion(self.expr))

def union_all(self):
if GEOPANDAS_1_0:
return new_collection(UnionAll(self.expr))
else:
return new_collection(UnaryUnion(self.expr))

@derived_from(geopandas.base.GeoPandasBase)
def representative_point(self):
Expand Down Expand Up @@ -693,7 +720,10 @@ def dissolve(self, by=None, aggfunc="first", split_out=1, **kwargs):

def union(block):
block = geopandas.GeoSeries(block)
merged_geom = block.unary_union
if GEOPANDAS_1_0:
merged_geom = block.union_all()
else:
merged_geom = block.unary_union
return merged_geom

merge_geometries = dd.Aggregation(
Expand Down
22 changes: 19 additions & 3 deletions dask_geopandas/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from dask_geopandas.hilbert_distance import _hilbert_distance
from dask_geopandas.morton_distance import _morton_distance
from dask_geopandas.geohash import _geohash
from dask_geopandas.core import PANDAS_2_0_0
from dask_geopandas.core import PANDAS_2_0_0, GEOPANDAS_1_0


@pytest.fixture
Expand Down Expand Up @@ -107,10 +107,26 @@ def test_geoseries_properties(geoseries_polygons, attr):


def test_geoseries_unary_union(geoseries_points):
original = getattr(geoseries_points, "unary_union")
if GEOPANDAS_1_0:
original = geoseries_points.union_all()
else:
original = geoseries_points.unary_union

dask_obj = dask_geopandas.from_geopandas(geoseries_points, npartitions=2)
with pytest.warns(FutureWarning):
daskified = dask_obj.unary_union
assert isinstance(daskified, Scalar)
assert original.equals(daskified.compute())


def test_geoseries_union_all(geoseries_points):
if GEOPANDAS_1_0:
original = geoseries_points.union_all()
else:
original = geoseries_points.unary_union

dask_obj = dask_geopandas.from_geopandas(geoseries_points, npartitions=2)
daskified = dask_obj.unary_union
daskified = dask_obj.union_all()
assert isinstance(daskified, Scalar)
assert original.equals(daskified.compute())

Expand Down

0 comments on commit aa1b52f

Please sign in to comment.