Skip to content

Commit

Permalink
Merge pull request rapidsai#4683 from rapidsai/branch-22.04
Browse files Browse the repository at this point in the history
[gpuCI] Forward-merge branch-22.04 to branch-22.06 [skip gpuci]
  • Loading branch information
GPUtester authored Apr 4, 2022
2 parents 3ce72e6 + 9f73269 commit 2de9800
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
2 changes: 1 addition & 1 deletion python/cuml/neighbors/kernel_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def sample(self, n_samples=1, random_state=None):

supported_kernels = ["gaussian", "tophat"]
if (self.kernel not in supported_kernels
or self.metric != "euclidian"):
or self.metric != "euclidean"):
raise NotImplementedError(
"Only {} kernels, and the euclidean"
" metric are supported.".format(supported_kernels))
Expand Down
12 changes: 9 additions & 3 deletions python/cuml/test/test_kernel_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,18 @@ def test_kernel_density(arrays, kernel, metric, bandwidth):
assert np.allclose(np.exp(as_type("numpy", cuml_prob_test)),
ref_prob_test, rtol=tol, atol=tol, equal_nan=True)

if kernel in ["gaussian", "tophat"] and metric == "euclidian":
if kernel in ["gaussian", "tophat"] and metric == "euclidean":
sample = kde.sample(100, random_state=32).get()
nearest = skl_pairwise_distances(sample, X, metric=metric)
nearest = skl_pairwise_distances(sample, X_np, metric=metric)
nearest = nearest.min(axis=1)
if kernel == "gaussian":
assert np.all(nearest < 5 * bandwidth)
from scipy.stats import chi
# The euclidean distance of each sample from its cluster
# follows a chi distribution (not squared) with DoF=dimension
# and scale = bandwidth
# Fail the test if the largest observed distance
# is vanishingly unlikely
assert chi.sf(nearest.max(), X.shape[1], scale=bandwidth) > 1e-8
elif kernel == "tophat":
assert np.all(nearest <= bandwidth)
else:
Expand Down

0 comments on commit 2de9800

Please sign in to comment.