Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEA] Allow cosine distance metric in dbscan #4776

Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion cpp/src/dbscan/dbscan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,14 @@ void dbscanFitImpl(const raft::handle_t& handle,
{
raft::common::nvtx::range fun_scope("ML::Dbscan::Fit");
ML::Logger::get().setLevel(verbosity);
int algo_vd = (metric == raft::distance::Precomputed) ? 2 : 1;
int algo_vd;
if (metric == raft::distance::Precomputed) {
algo_vd = 2;
} else if (metric == raft::distance::CosineExpanded) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than duplicating the call to the epsilon neighborhood primitive for the cosine case, I'd prefer to pass the metric through directly when metric != precomputed and normalize the input conditionally in the case where metric == cosine.

algo_vd = 3;
} else {
algo_vd = 1;
}
int algo_adj = 1;
int algo_ccl = 2;

Expand Down
90 changes: 90 additions & 0 deletions cpp/src/dbscan/vertexdeg/cosine.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should only use the current year for new files.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So should this be:
Copyright (c) 2021-2022, NVIDIA CORPORATION ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, for new files it would just be 2022.

*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cuda_runtime.h>
#include <math.h>
#include <raft/linalg/matrix_vector_op.hpp>
#include <raft/linalg/norm.cuh>
#include <raft/spatial/knn/epsilon_neighborhood.hpp>
#include <rmm/device_uvector.hpp>

#include "pack.h"

namespace ML {
namespace Dbscan {
namespace VertexDeg {
namespace Cosine {

/**
* Calculates the vertex degree array and the epsilon neighborhood adjacency matrix for the batch.
*/
template <typename value_t, typename index_t = int>
void launcher(const raft::handle_t& handle,
Pack<value_t, index_t> data,
index_t start_vertex_id,
index_t batch_size,
cudaStream_t stream)
{
data.resetArray(stream, batch_size + 1);

ASSERT(sizeof(index_t) == 4 || sizeof(index_t) == 8, "index_t should be 4 or 8 bytes");

index_t m = data.N;
index_t n = min(data.N - start_vertex_id, batch_size);
index_t k = data.D;
value_t eps2 = 2 * data.eps;

rmm::device_uvector<value_t> rowNorms(m, stream);
rmm::device_uvector<value_t> l2Normalized(m * n, stream);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have two options for supporting cosine- either we normalize the input or we perform the normalization in the computation. If we normalize the input, we should do so directly to the input and then revert the values back afterwords because this is very expensive.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a problem with attempting to modify the input: the output array address in raft::linalg::matrixVectorOp cannot be a const float *. Note that data.x is of the type const float *.


raft::linalg::rowNorm(rowNorms.data(),
data.x,
k,
m,
raft::linalg::NormType::L2Norm,
true,
stream,
[] __device__(value_t in) { return sqrtf(in); });

raft::linalg::matrixVectorOp(
l2Normalized.data(),
data.x,
rowNorms.data(),
k,
m,
true,
true,
[] __device__(value_t mat_in, value_t vec_in) { return mat_in / vec_in; },
stream);

raft::spatial::knn::epsUnexpL2SqNeighborhood<value_t, index_t>(
data.adj,
data.vd,
l2Normalized.data(),
l2Normalized.data() + start_vertex_id * k,
m,
n,
k,
eps2,
stream);
}

} // namespace Cosine
} // end namespace VertexDeg
} // end namespace Dbscan
} // namespace ML
6 changes: 5 additions & 1 deletion cpp/src/dbscan/vertexdeg/runner.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2021, NVIDIA CORPORATION.
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,7 @@
#pragma once

#include "algo.cuh"
#include "cosine.cuh"
#include "naive.cuh"
#include "pack.h"
#include "precomputed.cuh"
Expand Down Expand Up @@ -47,6 +48,9 @@ void run(const raft::handle_t& handle,
case 2:
Precomputed::launcher<Type_f, Index_>(handle, data, start_vertex_id, batch_size, stream);
break;
case 3:
Cosine::launcher<Type_f, Index_>(handle, data, start_vertex_id, batch_size, stream);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned above, I prefer to use the same "launcher" as the other metrics and pass the metric in directly. This will also make it much easier to support other metrics in the future, rather than having to duplicate the launcher each time.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So should we also pass algo_vd as an argument to the launcher?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so. Ideally Algo::launcher would just accept the distance type.

break;
default: ASSERT(false, "Incorrect algo passed! '%d'", algo);
}
}
Expand Down
3 changes: 2 additions & 1 deletion python/cuml/cluster/dbscan.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class DBSCAN(Base,
min_samples : int (default = 5)
The number of samples in a neighborhood such that this group can be
considered as an important core point (including the point itself).
metric: {'euclidean', 'precomputed'}, default = 'euclidean'
metric: {'euclidean', 'precomputed', 'cosine'}, default = 'euclidean'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a little nitpick: I would prefer if we kept precomputed either before or after the actual distance metrics for clarity. We should also add a little note to the docs here that the input will be modified temporarily when cosine distance is used (and might not match completely afterwards due to numerical rounding).

The metric to use when calculating distances between points.
If metric is 'precomputed', X is assumed to be a distance matrix
and must be square.
Expand Down Expand Up @@ -267,6 +267,7 @@ class DBSCAN(Base,
"L2": DistanceType.L2SqrtUnexpanded,
"euclidean": DistanceType.L2SqrtUnexpanded,
"precomputed": DistanceType.Precomputed,
"cosine": DistanceType.CosineExpanded
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also move this, maybe after euclidean, for readability.

}
if self.metric in metric_parsing:
metric = metric_parsing[self.metric.lower()]
Expand Down
35 changes: 35 additions & 0 deletions python/cuml/tests/test_dbscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,41 @@ def test_dbscan_precomputed(datatype, nrows, max_mbytes_per_batch, out_dtype):
algorithm="brute")
sk_labels = sk_dbscan.fit_predict(X_dist)

print("cu_labels:", cu_labels)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reminder to remove debug prints from tests

print("sk_labels:", sk_labels)

# Check the core points are equal
assert array_equal(cuml_dbscan.core_sample_indices_,
sk_dbscan.core_sample_indices_)

# Check the labels are correct
assert_dbscan_equal(sk_labels, cu_labels, X,
cuml_dbscan.core_sample_indices_, eps)


@pytest.mark.parametrize('max_mbytes_per_batch', [unit_param(1),
quality_param(1e2), stress_param(None)])
@pytest.mark.parametrize('nrows', [unit_param(500), quality_param(5000),
stress_param(10000)])
@pytest.mark.parametrize('out_dtype', ["int32", "int64"])
def test_dbscan_cosine(nrows, max_mbytes_per_batch, out_dtype):
# 2-dimensional dataset for easy distance matrix computation
X, y = make_blobs(n_samples=nrows, cluster_std=0.01,
n_features=2, random_state=0)

eps = 0.1

cuml_dbscan = cuDBSCAN(eps=eps, min_samples=5, metric='cosine',
max_mbytes_per_batch=max_mbytes_per_batch,
output_type='numpy')

cu_labels = cuml_dbscan.fit_predict(X, out_dtype=out_dtype)

sk_dbscan = skDBSCAN(eps=eps, min_samples=5, metric='cosine',
algorithm='brute')

sk_labels = sk_dbscan.fit_predict(X)

# Check the core points are equal
assert array_equal(cuml_dbscan.core_sample_indices_,
sk_dbscan.core_sample_indices_)
Expand Down