Skip to content

Commit

Permalink
Use SciPy's KDTree instead of deprecated cKDTree (#733)
Browse files Browse the repository at this point in the history
Functionality is equivalent since SciPy 1.6 (see note [here](https://docs.scipy.org/doc/scipy-1.13.0/reference/generated/scipy.spatial.cKDTree.html)).

I pinned SciPy to >= 1.6 (released Dec 31, 2020)

This MR also adds some missing test cases for the `_ensure_spacing` helper function used by `cucim.skimage.feature.peak_local_max`. The new tests revealed a bug in that function in the case of non-scalar `spacing` which is now fixed.

Also, CuPy recently added KDTree so we can hopefully improve performance by moving to that in the future. I opened issue #732 as a reminder to investigate that.

Authors:
  - Gregory Lee (https://github.com/grlee77)

Approvers:
  - Ray Douglass (https://github.com/raydouglass)
  - https://github.com/jakirkham

URL: #733
  • Loading branch information
grlee77 authored Apr 30, 2024
1 parent edd1da3 commit fb94e16
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 11 deletions.
2 changes: 1 addition & 1 deletion conda/environments/all_cuda-118_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ dependencies:
- pywavelets>=1.0
- recommonmark
- scikit-image>=0.19.0,<0.23.0a0
- scipy
- scipy>=1.6.0
- sphinx<6
- sysroot_linux-64==2.17
- tifffile>=2022.7.28
Expand Down
2 changes: 1 addition & 1 deletion conda/environments/all_cuda-122_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ dependencies:
- pywavelets>=1.0
- recommonmark
- scikit-image>=0.19.0,<0.23.0a0
- scipy
- scipy>=1.6.0
- sphinx<6
- sysroot_linux-64==2.17
- tifffile>=2022.7.28
Expand Down
4 changes: 2 additions & 2 deletions conda/recipes/cucim/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ requirements:
- numpy 1.23
- python
- scikit-image >=0.19.0,<0.23.0a0
- scipy
- scipy >=1.6
run:
- {{ pin_compatible('cuda-version', max_pin='x', min_pin='x') }}
{% if cuda_major != "11" %}
Expand All @@ -79,7 +79,7 @@ requirements:
- libcucim ={{ version }}
- python
- scikit-image >=0.19.0,<0.23.0a0
- scipy
- scipy >=1.6
run_constrained:
- openslide-python >=1.3.0

Expand Down
2 changes: 1 addition & 1 deletion dependencies.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ dependencies:
- lazy_loader>=0.1
- numpy>=1.23.4,<2.0a0
- scikit-image>=0.19.0,<0.23.0a0
- scipy
- scipy>=1.6.0
- output_types: conda
packages:
- cupy>=12.0.0
Expand Down
2 changes: 1 addition & 1 deletion python/cucim/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ dependencies = [
"lazy_loader>=0.1",
"numpy>=1.23.4,<2.0a0",
"scikit-image>=0.19.0,<0.23.0a0",
"scipy",
"scipy>=1.6.0",
] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`.
classifiers = [
"Development Status :: 4 - Beta",
Expand Down
8 changes: 5 additions & 3 deletions python/cucim/src/cucim/skimage/_shared/coord.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import cupy as cp
import numpy as np
from scipy.spatial import cKDTree, distance
from scipy.spatial import KDTree, distance


# TODO: avoid host/device transfers (currently needed for cKDTree)
# TODO: avoid host/device transfers (currently needed for KDTree)
def _ensure_spacing(coord, spacing, p_norm, max_out):
"""Returns a subset of coord where a minimum spacing is guaranteed.
Expand All @@ -30,7 +30,7 @@ def _ensure_spacing(coord, spacing, p_norm, max_out):
"""

# Use KDtree to find the peaks that are too close to each other
tree = cKDTree(coord)
tree = KDTree(coord)

indices = tree.query_ball_point(coord, r=spacing, p=p_norm)
rejected_peaks_indices = set()
Expand Down Expand Up @@ -106,6 +106,8 @@ def ensure_spacing(
if len(coords):
coords = cp.atleast_2d(coords)
coords = cp.asnumpy(coords)
if not np.isscalar(spacing):
spacing = cp.asnumpy(spacing)
if min_split_size is None:
batch_list = [coords]
else:
Expand Down
98 changes: 98 additions & 0 deletions python/cucim/src/cucim/skimage/_shared/tests/test_coord.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import time

import cupy as cp
import numpy as np
import pytest
from scipy.spatial.distance import minkowski, pdist

from cucim.skimage._shared.coord import ensure_spacing


@pytest.mark.parametrize("p", [1, 2, np.inf])
@pytest.mark.parametrize("size", [30, 50, None])
def test_ensure_spacing_trivial(p, size):
    """Check cases where ``ensure_spacing`` is expected to drop nothing.

    Covers an empty input, a single point, zero spacing, and a spacing
    equal to half of the smallest pairwise distance.
    """
    # Empty input: the result must be empty as well.
    empty_result = ensure_spacing(cp.asarray([]), p_norm=p)
    assert empty_result.size == 0

    # A single point must be returned as-is.
    single = cp.random.randn(1, 2)
    single_result = ensure_spacing(single, p_norm=p, min_split_size=size)
    assert cp.array_equal(single, single_result)

    points = cp.random.randn(100, 2)

    # Zero spacing: every point is expected back.
    zero_result = ensure_spacing(
        points, spacing=0, p_norm=p, min_split_size=size
    )
    assert cp.array_equal(points, zero_result)

    # Half of the minimum pairwise distance: again every point is expected
    # back. pdist runs on the host, so the coordinates are copied there.
    pairwise = pdist(cp.asnumpy(points), metric=minkowski, p=p)
    half_min = cp.asarray(pairwise.min() * 0.5)
    kept = ensure_spacing(
        points, spacing=half_min, p_norm=p, min_split_size=size
    )
    assert cp.array_equal(points, kept)


@pytest.mark.parametrize("ndim", [1, 2, 3, 4, 5])
@pytest.mark.parametrize("size", [2, 10, None])
def test_ensure_spacing_nD(ndim, size):
    """Five identical ``ndim``-dimensional points must reduce to one point."""
    duplicates = cp.ones((5, ndim))
    result = ensure_spacing(duplicates, min_split_size=size)
    assert cp.array_equal(result, cp.ones((1, ndim)))


@pytest.mark.parametrize("p", [1, 2, np.inf])
@pytest.mark.parametrize("size", [50, 100, None])
def test_ensure_spacing_batch_processing(p, size):
    """Passing ``min_split_size`` must not change the selected points.

    The result of the call with an explicit ``min_split_size`` is compared
    against the result of the same call that omits that argument.
    """
    points_host = np.random.randn(100, 2)

    # Use the median pairwise distance between the points as the spacing.
    pairwise = pdist(points_host, metric=minkowski, p=p)
    spacing = cp.asarray(np.median(pairwise))
    points = cp.asarray(points_host)

    reference = ensure_spacing(points, spacing=spacing, p_norm=p)
    with_split = ensure_spacing(
        points, spacing=spacing, p_norm=p, min_split_size=size
    )

    cp.testing.assert_array_equal(with_split, reference)


def test_max_batch_size():
    """Small batches are slow, large batches -> large allocations -> also slow.

    Times ``ensure_spacing`` with ``max_split_size=2000`` and with
    ``max_split_size=20000`` on the same coordinates and checks that the
    first run is not substantially slower than the second.

    https://github.com/scikit-image/scikit-image/pull/6035#discussion_r751518691
    """
    coords = cp.random.randint(low=0, high=1848, size=(40000, 2))

    # time.perf_counter() is used instead of time.time(): it is the clock
    # intended for measuring short durations and is not affected by system
    # clock adjustments, which makes the comparison below less flaky.
    tstart = time.perf_counter()
    ensure_spacing(coords, spacing=100, min_split_size=50, max_split_size=2000)
    dur1 = time.perf_counter() - tstart

    tstart = time.perf_counter()
    ensure_spacing(coords, spacing=100, min_split_size=50, max_split_size=20000)
    dur2 = time.perf_counter() - tstart

    # Originally checked dur1 < dur2 to assert that the default batch size was
    # faster than a much larger batch size. However, on rare occasion a CI test
    # case would fail with dur1 ~5% larger than dur2. To be more robust to
    # variable load or differences across architectures, we relax this here.
    assert dur1 < 1.33 * dur2


@pytest.mark.parametrize("p", [1, 2, np.inf])
@pytest.mark.parametrize("size", [30, 50, None])
def test_ensure_spacing_p_norm(p, size):
    """Points kept by ``ensure_spacing`` must be more than ``spacing`` apart.

    Distances are measured with the Minkowski metric of order ``p``, the
    same value passed to ``ensure_spacing`` as ``p_norm``.
    """
    points_host = np.random.randn(100, 2)

    # Use the median pairwise distance between the points as the spacing.
    pairwise = pdist(points_host, metric=minkowski, p=p)
    spacing = cp.asarray(np.median(pairwise))
    points = cp.asarray(points_host)
    kept = ensure_spacing(
        points, spacing=spacing, p_norm=p, min_split_size=size
    )

    # pdist runs on the host, so the surviving points are copied back first.
    kept_distances = pdist(cp.asnumpy(kept), metric=minkowski, p=p)
    assert kept_distances.min() > spacing
4 changes: 2 additions & 2 deletions python/cucim/src/cucim/skimage/feature/corner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import cupy as cp
import numpy as np
from scipy import spatial # TODO: use cuSpatial if cKDTree becomes available
from scipy import spatial # TODO: use cuSpatial if KDTree becomes available

import cucim.skimage._vendored.ndimage as ndi
from cucim.skimage.util import img_as_float
Expand Down Expand Up @@ -1370,7 +1370,7 @@ def corner_peaks(
coords = cp.asnumpy(coords)

# Use KDtree to find the peaks that are too close to each other
tree = spatial.cKDTree(coords)
tree = spatial.KDTree(coords)

rejected_peaks_indices = set()
for idx, point in enumerate(coords):
Expand Down

0 comments on commit fb94e16

Please sign in to comment.