From 4f23154996c56b00f54002adda079cf27da3d3cd Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 13 Jun 2022 10:13:12 -0700 Subject: [PATCH 1/7] Allow cosine distance metric in dbscan --- cpp/src/dbscan/dbscan.cuh | 10 +++- cpp/src/dbscan/vertexdeg/cosine.cuh | 90 +++++++++++++++++++++++++++++ cpp/src/dbscan/vertexdeg/runner.cuh | 4 ++ python/cuml/cluster/dbscan.pyx | 3 +- python/cuml/tests/test_dbscan.py | 35 +++++++++++ 5 files changed, 140 insertions(+), 2 deletions(-) create mode 100644 cpp/src/dbscan/vertexdeg/cosine.cuh diff --git a/cpp/src/dbscan/dbscan.cuh b/cpp/src/dbscan/dbscan.cuh index 8664fb0a65..804651d37f 100644 --- a/cpp/src/dbscan/dbscan.cuh +++ b/cpp/src/dbscan/dbscan.cuh @@ -110,7 +110,15 @@ void dbscanFitImpl(const raft::handle_t& handle, { raft::common::nvtx::range fun_scope("ML::Dbscan::Fit"); ML::Logger::get().setLevel(verbosity); - int algo_vd = (metric == raft::distance::Precomputed) ? 2 : 1; + // int algo_vd = (metric == raft::distance::Precomputed) ? 2 : 1; + int algo_vd; + if (metric == raft::distance::Precomputed) { + algo_vd = 2; + } else if (metric == raft::distance::CosineExpanded) { + algo_vd = 3; + } else { + algo_vd = 1; + } int algo_adj = 1; int algo_ccl = 2; diff --git a/cpp/src/dbscan/vertexdeg/cosine.cuh b/cpp/src/dbscan/vertexdeg/cosine.cuh new file mode 100644 index 0000000000..a21dceece6 --- /dev/null +++ b/cpp/src/dbscan/vertexdeg/cosine.cuh @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "pack.h" + +namespace ML { +namespace Dbscan { +namespace VertexDeg { +namespace Cosine { + +/** + * Calculates the vertex degree array and the epsilon neighborhood adjacency matrix for the batch. + */ +template +void launcher(const raft::handle_t& handle, + Pack data, + index_t start_vertex_id, + index_t batch_size, + cudaStream_t stream) +{ + data.resetArray(stream, batch_size + 1); + + ASSERT(sizeof(index_t) == 4 || sizeof(index_t) == 8, "index_t should be 4 or 8 bytes"); + + index_t m = data.N; + index_t n = min(data.N - start_vertex_id, batch_size); + index_t k = data.D; + value_t eps2 = 2 * data.eps; + + rmm::device_uvector rowNorms(m, stream); + rmm::device_uvector l2Normalized(m * n, stream); + + raft::linalg::rowNorm(rowNorms.data(), + data.x, + k, + m, + raft::linalg::NormType::L2Norm, + true, + stream, + [] __device__(value_t in) { return sqrtf(in); }); + + raft::linalg::matrixVectorOp( + l2Normalized.data(), + data.x, + rowNorms.data(), + k, + m, + true, + true, + [] __device__(value_t mat_in, value_t vec_in) { return mat_in / vec_in; }, + stream); + + raft::spatial::knn::epsUnexpL2SqNeighborhood( + data.adj, + data.vd, + l2Normalized.data(), + l2Normalized.data() + start_vertex_id * k, + m, + n, + k, + eps2, + stream); +} + +} // namespace Cosine +} // end namespace VertexDeg +} // end namespace Dbscan +} // namespace ML diff --git a/cpp/src/dbscan/vertexdeg/runner.cuh b/cpp/src/dbscan/vertexdeg/runner.cuh index 082a2ac46f..3a60da69ea 100644 --- a/cpp/src/dbscan/vertexdeg/runner.cuh +++ b/cpp/src/dbscan/vertexdeg/runner.cuh @@ -17,6 +17,7 @@ #pragma once #include "algo.cuh" +#include "cosine.cuh" #include "naive.cuh" #include "pack.h" #include "precomputed.cuh" @@ -47,6 +48,9 @@ void run(const raft::handle_t& handle, case 2: Precomputed::launcher(handle, data, start_vertex_id, batch_size, stream); break; + case 3: + Cosine::launcher(handle, data, start_vertex_id, batch_size, stream); + break; default: ASSERT(false, "Incorrect algo passed! '%d'", algo); } } diff --git a/python/cuml/cluster/dbscan.pyx b/python/cuml/cluster/dbscan.pyx index d00df0a822..26f8dd8db8 100644 --- a/python/cuml/cluster/dbscan.pyx +++ b/python/cuml/cluster/dbscan.pyx @@ -147,7 +147,7 @@ class DBSCAN(Base, min_samples : int (default = 5) The number of samples in a neighborhood such that this group can be considered as an important core point (including the point itself). - metric: {'euclidean', 'precomputed'}, default = 'euclidean' + metric: {'euclidean', 'precomputed', 'cosine'}, default = 'euclidean' The metric to use when calculating distances between points. If metric is 'precomputed', X is assumed to be a distance matrix and must be square. @@ -267,6 +267,7 @@ class DBSCAN(Base, "L2": DistanceType.L2SqrtUnexpanded, "euclidean": DistanceType.L2SqrtUnexpanded, "precomputed": DistanceType.Precomputed, + "cosine": DistanceType.CosineExpanded } if self.metric in metric_parsing: metric = metric_parsing[self.metric.lower()] diff --git a/python/cuml/tests/test_dbscan.py b/python/cuml/tests/test_dbscan.py index 8c8027d7ec..b23b2af67c 100644 --- a/python/cuml/tests/test_dbscan.py +++ b/python/cuml/tests/test_dbscan.py @@ -107,6 +107,41 @@ def test_dbscan_precomputed(datatype, nrows, max_mbytes_per_batch, out_dtype): algorithm="brute") sk_labels = sk_dbscan.fit_predict(X_dist) + print("cu_labels:", cu_labels) + print("sk_labels:", sk_labels) + + # Check the core points are equal + assert array_equal(cuml_dbscan.core_sample_indices_, + sk_dbscan.core_sample_indices_) + + # Check the labels are correct + assert_dbscan_equal(sk_labels, cu_labels, X, + cuml_dbscan.core_sample_indices_, eps) + + +@pytest.mark.parametrize('max_mbytes_per_batch', [unit_param(1), + quality_param(1e2), stress_param(None)]) +@pytest.mark.parametrize('nrows', [unit_param(500), quality_param(5000), + stress_param(10000)]) +@pytest.mark.parametrize('out_dtype', ["int32", "int64"]) +def test_dbscan_cosine(nrows, max_mbytes_per_batch, out_dtype): + # 2-dimensional dataset for easy distance matrix computation + X, y = make_blobs(n_samples=nrows, cluster_std=0.01, + n_features=2, random_state=0) + + eps = 0.1 + + cuml_dbscan = cuDBSCAN(eps=eps, min_samples=5, metric='cosine', + max_mbytes_per_batch=max_mbytes_per_batch, + output_type='numpy') + + cu_labels = cuml_dbscan.fit_predict(X, out_dtype=out_dtype) + + sk_dbscan = skDBSCAN(eps=eps, min_samples=5, metric='cosine', + algorithm='brute') + + sk_labels = sk_dbscan.fit_predict(X) + # Check the core points are equal assert array_equal(cuml_dbscan.core_sample_indices_, sk_dbscan.core_sample_indices_) From b41b02d8c83457d78ce57d1c2fe8dfd922949fe8 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 13 Jun 2022 10:52:23 -0700 Subject: [PATCH 2/7] removed commented line --- cpp/src/dbscan/dbscan.cuh | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/src/dbscan/dbscan.cuh b/cpp/src/dbscan/dbscan.cuh index 804651d37f..478bf1a534 100644 --- a/cpp/src/dbscan/dbscan.cuh +++ b/cpp/src/dbscan/dbscan.cuh @@ -110,7 +110,6 @@ void dbscanFitImpl(const raft::handle_t& handle, { raft::common::nvtx::range fun_scope("ML::Dbscan::Fit"); ML::Logger::get().setLevel(verbosity); - // int algo_vd = (metric == raft::distance::Precomputed) ? 2 : 1; int algo_vd; if (metric == raft::distance::Precomputed) { algo_vd = 2; From eb8df1057ef2eb449f6f2b4688331790eb5f3c24 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 13 Jun 2022 11:12:39 -0700 Subject: [PATCH 3/7] fixed styling (copyright) --- cpp/src/dbscan/vertexdeg/runner.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/dbscan/vertexdeg/runner.cuh b/cpp/src/dbscan/vertexdeg/runner.cuh index 3a60da69ea..6a9b5024e4 100644 --- a/cpp/src/dbscan/vertexdeg/runner.cuh +++ b/cpp/src/dbscan/vertexdeg/runner.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021, NVIDIA CORPORATION. + * Copyright (c) 2018-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. From 1ce61c746f2425c8993bb6f57ebb46b3aa6bf423 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 14 Jun 2022 09:47:01 -0700 Subject: [PATCH 4/7] design changes --- python/cuml/cluster/dbscan.pyx | 4 ++-- python/cuml/tests/test_dbscan.py | 3 --- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/python/cuml/cluster/dbscan.pyx b/python/cuml/cluster/dbscan.pyx index 26f8dd8db8..212295bc64 100644 --- a/python/cuml/cluster/dbscan.pyx +++ b/python/cuml/cluster/dbscan.pyx @@ -266,8 +266,8 @@ class DBSCAN(Base, metric_parsing = { "L2": DistanceType.L2SqrtUnexpanded, "euclidean": DistanceType.L2SqrtUnexpanded, - "precomputed": DistanceType.Precomputed, - "cosine": DistanceType.CosineExpanded + "cosine": DistanceType.CosineExpanded, + "precomputed": DistanceType.Precomputed } if self.metric in metric_parsing: metric = metric_parsing[self.metric.lower()] diff --git a/python/cuml/tests/test_dbscan.py b/python/cuml/tests/test_dbscan.py index b23b2af67c..e2cb8f27b2 100644 --- a/python/cuml/tests/test_dbscan.py +++ b/python/cuml/tests/test_dbscan.py @@ -107,9 +107,6 @@ def test_dbscan_precomputed(datatype, nrows, max_mbytes_per_batch, out_dtype): algorithm="brute") sk_labels = sk_dbscan.fit_predict(X_dist) - print("cu_labels:", cu_labels) - print("sk_labels:", sk_labels) - # Check the core points are equal assert array_equal(cuml_dbscan.core_sample_indices_, sk_dbscan.core_sample_indices_) From aabe1170dce2c25136dfc87b7b31d0a4db4b08ec Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 14 Jun 2022 16:10:12 -0700 Subject: [PATCH 5/7] Updates after PR Review --- cpp/src/dbscan/dbscan.cuh | 15 ++--- cpp/src/dbscan/runner.cuh | 7 ++- cpp/src/dbscan/vertexdeg/algo.cuh | 68 +++++++++++++++++++--- cpp/src/dbscan/vertexdeg/cosine.cuh | 90 ----------------------------- cpp/src/dbscan/vertexdeg/runner.cuh | 9 +-- python/cuml/cluster/dbscan.pyx | 5 +- 6 files changed, 77 insertions(+), 117 deletions(-) delete mode 100644 cpp/src/dbscan/vertexdeg/cosine.cuh diff --git a/cpp/src/dbscan/dbscan.cuh b/cpp/src/dbscan/dbscan.cuh index 478bf1a534..24595ff931 100644 --- a/cpp/src/dbscan/dbscan.cuh +++ b/cpp/src/dbscan/dbscan.cuh @@ -110,14 +110,7 @@ void dbscanFitImpl(const raft::handle_t& handle, { raft::common::nvtx::range fun_scope("ML::Dbscan::Fit"); ML::Logger::get().setLevel(verbosity); - int algo_vd; - if (metric == raft::distance::Precomputed) { - algo_vd = 2; - } else if (metric == raft::distance::CosineExpanded) { - algo_vd = 3; - } else { - algo_vd = 1; - } + int algo_vd = (metric == raft::distance::Precomputed) ? 2 : 1; int algo_adj = 1; int algo_ccl = 2; @@ -187,7 +180,8 @@ void dbscanFitImpl(const raft::handle_t& handle, algo_ccl, NULL, batch_size, - stream); + stream, + metric); CUML_LOG_DEBUG("Workspace size: %lf MB", (double)workspaceSize * 1e-6); @@ -207,7 +201,8 @@ void dbscanFitImpl(const raft::handle_t& handle, algo_ccl, workspace.data(), batch_size, - stream); + stream, + metric); } } // namespace Dbscan diff --git a/cpp/src/dbscan/runner.cuh b/cpp/src/dbscan/runner.cuh index e3e8dcd8aa..acef6c9785 100644 --- a/cpp/src/dbscan/runner.cuh +++ b/cpp/src/dbscan/runner.cuh @@ -114,7 +114,8 @@ std::size_t run(const raft::handle_t& handle, int algo_ccl, void* workspace, std::size_t batch_size, - cudaStream_t stream) + cudaStream_t stream, + raft::distance::DistanceType metric) { const std::size_t align = 256; Index_ n_batches = raft::ceildiv((std::size_t)n_owned_rows, batch_size); @@ -191,7 +192,7 @@ std::size_t run(const raft::handle_t& handle, CUML_LOG_DEBUG("--> Computing vertex degrees"); raft::common::nvtx::push_range("Trace::Dbscan::VertexDeg"); VertexDeg::run( - handle, adj, vd, x, eps, N, D, algo_vd, start_vertex_id, n_points, stream); + handle, adj, vd, x, eps, N, D, algo_vd, start_vertex_id, n_points, stream, metric); raft::common::nvtx::pop_range(); CUML_LOG_DEBUG("--> Computing core point mask"); @@ -219,7 +220,7 @@ std::size_t run(const raft::handle_t& handle, CUML_LOG_DEBUG("--> Computing vertex degrees"); raft::common::nvtx::push_range("Trace::Dbscan::VertexDeg"); VertexDeg::run( - handle, adj, vd, x, eps, N, D, algo_vd, start_vertex_id, n_points, stream); + handle, adj, vd, x, eps, N, D, algo_vd, start_vertex_id, n_points, stream, metric); raft::common::nvtx::pop_range(); } raft::update_host(&curradjlen, vd + n_points, 1, stream); diff --git a/cpp/src/dbscan/vertexdeg/algo.cuh b/cpp/src/dbscan/vertexdeg/algo.cuh index b4458e9008..147203a240 100644 --- a/cpp/src/dbscan/vertexdeg/algo.cuh +++ b/cpp/src/dbscan/vertexdeg/algo.cuh @@ -18,7 +18,10 @@ #include #include +#include +#include #include +#include #include "pack.h" @@ -35,19 +38,70 @@ void launcher(const raft::handle_t& handle, Pack data, index_t start_vertex_id, index_t batch_size, - cudaStream_t stream) + cudaStream_t stream, + raft::distance::DistanceType metric) { data.resetArray(stream, batch_size + 1); ASSERT(sizeof(index_t) == 4 || sizeof(index_t) == 8, "index_t should be 4 or 8 bytes"); - index_t m = data.N; - index_t n = min(data.N - start_vertex_id, batch_size); - index_t k = data.D; - value_t eps2 = data.eps * data.eps; + index_t m = data.N; + index_t n = min(data.N - start_vertex_id, batch_size); + index_t k = data.D; + value_t eps2; - raft::spatial::knn::epsUnexpL2SqNeighborhood( - data.adj, data.vd, data.x, data.x + start_vertex_id * k, m, n, k, eps2, stream); + if (metric == raft::distance::DistanceType::Precomputed) { + rmm::device_uvector rowNorms(m, stream); + + raft::linalg::rowNorm(rowNorms.data(), + data.x, + k, + m, + raft::linalg::NormType::L2Norm, + true, + stream, + [] __device__(value_t in) { return sqrtf(in); }); + + /* Cast away constness because the output matrix for normalization cannot be of const type. + * Input matrix will be modified due to normalization. + */ + raft::linalg::matrixVectorOp( + const_cast(data.x), + data.x, + rowNorms.data(), + k, + m, + true, + true, + [] __device__(value_t mat_in, value_t vec_in) { return mat_in / vec_in; }, + stream); + + eps2 = 2 * data.eps; + + raft::spatial::knn::epsUnexpL2SqNeighborhood( + data.adj, data.vd, data.x, data.x + start_vertex_id * k, m, n, k, eps2, stream); + + /** + * Restoring the input matrix after normalization. + */ + raft::linalg::matrixVectorOp( + const_cast(data.x), + data.x, + rowNorms.data(), + k, + m, + true, + true, + [] __device__(value_t mat_in, value_t vec_in) { return mat_in * vec_in; }, + stream); + } + + else { + eps2 = data.eps * data.eps; + + raft::spatial::knn::epsUnexpL2SqNeighborhood( + data.adj, data.vd, data.x, data.x + start_vertex_id * k, m, n, k, eps2, stream); + } } } // namespace Algo diff --git a/cpp/src/dbscan/vertexdeg/cosine.cuh b/cpp/src/dbscan/vertexdeg/cosine.cuh deleted file mode 100644 index a21dceece6..0000000000 --- a/cpp/src/dbscan/vertexdeg/cosine.cuh +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include - -#include "pack.h" - -namespace ML { -namespace Dbscan { -namespace VertexDeg { -namespace Cosine { - -/** - * Calculates the vertex degree array and the epsilon neighborhood adjacency matrix for the batch. - */ -template -void launcher(const raft::handle_t& handle, - Pack data, - index_t start_vertex_id, - index_t batch_size, - cudaStream_t stream) -{ - data.resetArray(stream, batch_size + 1); - - ASSERT(sizeof(index_t) == 4 || sizeof(index_t) == 8, "index_t should be 4 or 8 bytes"); - - index_t m = data.N; - index_t n = min(data.N - start_vertex_id, batch_size); - index_t k = data.D; - value_t eps2 = 2 * data.eps; - - rmm::device_uvector rowNorms(m, stream); - rmm::device_uvector l2Normalized(m * n, stream); - - raft::linalg::rowNorm(rowNorms.data(), - data.x, - k, - m, - raft::linalg::NormType::L2Norm, - true, - stream, - [] __device__(value_t in) { return sqrtf(in); }); - - raft::linalg::matrixVectorOp( - l2Normalized.data(), - data.x, - rowNorms.data(), - k, - m, - true, - true, - [] __device__(value_t mat_in, value_t vec_in) { return mat_in / vec_in; }, - stream); - - raft::spatial::knn::epsUnexpL2SqNeighborhood( - data.adj, - data.vd, - l2Normalized.data(), - l2Normalized.data() + start_vertex_id * k, - m, - n, - k, - eps2, - stream); -} - -} // namespace Cosine -} // end namespace VertexDeg -} // end namespace Dbscan -} // namespace ML diff --git a/cpp/src/dbscan/vertexdeg/runner.cuh b/cpp/src/dbscan/vertexdeg/runner.cuh index 6a9b5024e4..561c98ab12 100644 --- a/cpp/src/dbscan/vertexdeg/runner.cuh +++ b/cpp/src/dbscan/vertexdeg/runner.cuh @@ -17,7 +17,6 @@ #pragma once #include "algo.cuh" -#include "cosine.cuh" #include "naive.cuh" #include "pack.h" #include "precomputed.cuh" @@ -37,20 +36,18 @@ void run(const raft::handle_t& handle, int algo, Index_ start_vertex_id, Index_ batch_size, - cudaStream_t stream) + cudaStream_t stream, + raft::distance::DistanceType metric) { Pack data = {vd, adj, x, eps, N, D}; switch (algo) { case 0: Naive::launcher(data, start_vertex_id, batch_size, stream); break; case 1: - Algo::launcher(handle, data, start_vertex_id, batch_size, stream); + Algo::launcher(handle, data, start_vertex_id, batch_size, stream, metric); break; case 2: Precomputed::launcher(handle, data, start_vertex_id, batch_size, stream); break; - case 3: - Cosine::launcher(handle, data, start_vertex_id, batch_size, stream); - break; default: ASSERT(false, "Incorrect algo passed! '%d'", algo); } } diff --git a/python/cuml/cluster/dbscan.pyx b/python/cuml/cluster/dbscan.pyx index 212295bc64..7727b66574 100644 --- a/python/cuml/cluster/dbscan.pyx +++ b/python/cuml/cluster/dbscan.pyx @@ -147,10 +147,13 @@ class DBSCAN(Base, min_samples : int (default = 5) The number of samples in a neighborhood such that this group can be considered as an important core point (including the point itself). - metric: {'euclidean', 'precomputed', 'cosine'}, default = 'euclidean' + metric: {'euclidean', 'cosine', 'precomputed'}, default = 'euclidean' The metric to use when calculating distances between points. If metric is 'precomputed', X is assumed to be a distance matrix and must be square. + The input will be modified temporarily when cosine distance is used + and the restored input matrix might not match completely + due to numerical rounding. verbose : int or boolean, default=False Sets logging level. It must be one of `cuml.common.logger.level_*`. See :ref:`verbosity-levels` for more info. From 695cafbef81cb4f201b35faba8785b83ce7af5a1 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 14 Jun 2022 16:39:33 -0700 Subject: [PATCH 6/7] Minor correction in vertex deg computation --- cpp/src/dbscan/vertexdeg/algo.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/dbscan/vertexdeg/algo.cuh b/cpp/src/dbscan/vertexdeg/algo.cuh index 147203a240..8cb13e8c4d 100644 --- a/cpp/src/dbscan/vertexdeg/algo.cuh +++ b/cpp/src/dbscan/vertexdeg/algo.cuh @@ -50,7 +50,7 @@ void launcher(const raft::handle_t& handle, index_t k = data.D; value_t eps2; - if (metric == raft::distance::DistanceType::Precomputed) { + if (metric == raft::distance::DistanceType::CosineExpanded) { rmm::device_uvector rowNorms(m, stream); raft::linalg::rowNorm(rowNorms.data(), From 0481785b72dae290108e46a446dd0f8d36a18a62 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 16 Jun 2022 10:42:31 -0700 Subject: [PATCH 7/7] Style fix --- python/cuml/cluster/dbscan.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/cuml/cluster/dbscan.pyx b/python/cuml/cluster/dbscan.pyx index 7727b66574..255ed01bf5 100644 --- a/python/cuml/cluster/dbscan.pyx +++ b/python/cuml/cluster/dbscan.pyx @@ -153,7 +153,7 @@ class DBSCAN(Base, and must be square. The input will be modified temporarily when cosine distance is used and the restored input matrix might not match completely - due to numerical rounding. + due to numerical rounding. verbose : int or boolean, default=False Sets logging level. It must be one of `cuml.common.logger.level_*`. See :ref:`verbosity-levels` for more info.