Skip to content

Commit

Permalink
Update k nearest neighbor, always return k results
Browse files Browse the repository at this point in the history
  • Loading branch information
sebhahn committed May 24, 2022
1 parent 75827c0 commit fdc80d2
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 18 deletions.
17 changes: 8 additions & 9 deletions src/pygeogrids/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,26 +445,25 @@ def find_k_nearest_gpi(self, lon, lat, max_dist=np.Inf, k=1):
-------
gpi : np.ndarray
Grid point indices.
distance : np.ndarray
dist : np.ndarray
Distance of gpi(s) to given lon, lat.
At the moment not on a great circle but in spherical
cartesian coordinates.
"""
if self.kdTree is None:
self._setup_kdtree()

distance, ind = self.kdTree.find_nearest_index(lon, lat,
max_dist=max_dist, k=k)
mask = np.isinf(distance)
ind = ind[~mask]
distance = distance[~mask]
dist, ind = self.kdTree.find_nearest_index(lon, lat,
max_dist=max_dist, k=k)
mask = np.isinf(dist)
gpi = np.zeros(dist.shape, dtype=np.int32) + np.iinfo(np.int32).max

if self.gpidirect and self.allpoints or len(ind) == 0:
gpi = ind
gpi[~mask] = ind[~mask]
else:
gpi = self.activegpis[ind]
gpi[~mask] = self.activegpis[ind[~mask]]

return gpi, distance
return gpi, dist

def gpi2lonlat(self, gpi):
"""
Expand Down
15 changes: 6 additions & 9 deletions tests/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,9 @@ def test_k_nearest_neighbor(self):
assert lat == 18.5

with pytest.warns(UserWarning):
gpi, dist = self.grid.find_k_nearest_gpi(14.3, 18.5, k=2,
gpi, dist = self.grid.find_k_nearest_gpi(14.3, 18.5, k=2,
max_dist=25000)
assert len(gpi) == len(dist) == 1
assert np.all(np.isfinite(dist))
assert gpi == 25754
assert gpi.shape == dist.shape == (1, 2)


def test_k_nearest_neighbor_list(self):
Expand All @@ -161,15 +159,14 @@ def test_nearest_neighbor_max_dist(self):

# test with maxdist lower than nearest point
gpi, dist = self.grid.find_nearest_gpi(14.3, 18.5, max_dist=10000)
assert len(gpi) == 0
assert len(dist) == 0
assert gpi == np.iinfo(np.int32).max
assert dist == np.inf

# test with custom gpi, see issue #68
grid = grids.BasicGrid(lon=[16,17], lat=[45,46], gpis=[100,200])
gpi, dist = grid.find_nearest_gpi(0,0, max_dist=1000)
assert len(gpi) == 0
assert len(dist) == 0

assert gpi == np.iinfo(np.int32).max
assert dist == np.inf

class TestCellGridNotGpiDirect(unittest.TestCase):

Expand Down

0 comments on commit fdc80d2

Please sign in to comment.