Allow cosine distance metric in dbscan (#4776)
closes #4210 
Added the cosine distance metric for computing the epsilon neighborhood in DBSCAN. The cosine distance is computed via the squared L2 distance between L2-normalized vectors, and the epsilon value is adjusted accordingly.
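
For reference, the identity behind the epsilon adjustment (see `eps2 = 2 * data.eps` in the diff below): for L2-normalized vectors \(\hat{a}\) and \(\hat{b}\),

\[
\lVert \hat{a} - \hat{b} \rVert_2^2
  = \lVert \hat{a} \rVert_2^2 + \lVert \hat{b} \rVert_2^2 - 2\,\hat{a}^{\top}\hat{b}
  = 2\bigl(1 - \cos(a, b)\bigr),
\]

so thresholding the cosine distance at eps is equivalent to thresholding the squared L2 distance of the normalized vectors at 2 * eps.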

Authors:
  - Tarang Jain (https://github.com/tarang-jain)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #4776
tarang-jain authored Jul 7, 2022
1 parent 892558b commit c8aebc3
Showing 6 changed files with 111 additions and 17 deletions.
6 changes: 4 additions & 2 deletions cpp/src/dbscan/dbscan.cuh
@@ -180,7 +180,8 @@ void dbscanFitImpl(const raft::handle_t& handle,
algo_ccl,
NULL,
batch_size,
stream);
stream,
metric);

CUML_LOG_DEBUG("Workspace size: %lf MB", (double)workspaceSize * 1e-6);

@@ -200,7 +201,8 @@ void dbscanFitImpl(const raft::handle_t& handle,
algo_ccl,
workspace.data(),
batch_size,
stream);
stream,
metric);
}

} // namespace Dbscan
7 changes: 4 additions & 3 deletions cpp/src/dbscan/runner.cuh
@@ -119,7 +119,8 @@ std::size_t run(const raft::handle_t& handle,
int algo_ccl,
void* workspace,
std::size_t batch_size,
cudaStream_t stream)
cudaStream_t stream,
raft::distance::DistanceType metric)
{
const std::size_t align = 256;
Index_ n_batches = raft::ceildiv((std::size_t)n_owned_rows, batch_size);
@@ -196,7 +197,7 @@ std::size_t run(const raft::handle_t& handle,
CUML_LOG_DEBUG("--> Computing vertex degrees");
raft::common::nvtx::push_range("Trace::Dbscan::VertexDeg");
VertexDeg::run<Type_f, Index_>(
handle, adj, vd, x, eps, N, D, algo_vd, start_vertex_id, n_points, stream);
handle, adj, vd, x, eps, N, D, algo_vd, start_vertex_id, n_points, stream, metric);
raft::common::nvtx::pop_range();

CUML_LOG_DEBUG("--> Computing core point mask");
Expand Down Expand Up @@ -224,7 +225,7 @@ std::size_t run(const raft::handle_t& handle,
CUML_LOG_DEBUG("--> Computing vertex degrees");
raft::common::nvtx::push_range("Trace::Dbscan::VertexDeg");
VertexDeg::run<Type_f, Index_>(
handle, adj, vd, x, eps, N, D, algo_vd, start_vertex_id, n_points, stream);
handle, adj, vd, x, eps, N, D, algo_vd, start_vertex_id, n_points, stream, metric);
raft::common::nvtx::pop_range();
}
raft::update_host(&curradjlen, vd + n_points, 1, stream);
68 changes: 61 additions & 7 deletions cpp/src/dbscan/vertexdeg/algo.cuh
@@ -18,7 +18,10 @@

#include <cuda_runtime.h>
#include <math.h>
#include <raft/linalg/matrix_vector_op.hpp>
#include <raft/linalg/norm.cuh>
#include <raft/spatial/knn/epsilon_neighborhood.hpp>
#include <rmm/device_uvector.hpp>

#include "pack.h"

@@ -35,19 +38,70 @@ void launcher(const raft::handle_t& handle,
Pack<value_t, index_t> data,
index_t start_vertex_id,
index_t batch_size,
cudaStream_t stream)
cudaStream_t stream,
raft::distance::DistanceType metric)
{
data.resetArray(stream, batch_size + 1);

ASSERT(sizeof(index_t) == 4 || sizeof(index_t) == 8, "index_t should be 4 or 8 bytes");

index_t m = data.N;
index_t n = min(data.N - start_vertex_id, batch_size);
index_t k = data.D;
value_t eps2 = data.eps * data.eps;
index_t m = data.N;
index_t n = min(data.N - start_vertex_id, batch_size);
index_t k = data.D;
value_t eps2;

raft::spatial::knn::epsUnexpL2SqNeighborhood<value_t, index_t>(
data.adj, data.vd, data.x, data.x + start_vertex_id * k, m, n, k, eps2, stream);
if (metric == raft::distance::DistanceType::CosineExpanded) {
rmm::device_uvector<value_t> rowNorms(m, stream);

raft::linalg::rowNorm(rowNorms.data(),
data.x,
k,
m,
raft::linalg::NormType::L2Norm,
true,
stream,
[] __device__(value_t in) { return sqrtf(in); });

/* Cast away constness because the output matrix for normalization cannot be of const type.
* Input matrix will be modified due to normalization.
*/
raft::linalg::matrixVectorOp(
const_cast<value_t*>(data.x),
data.x,
rowNorms.data(),
k,
m,
true,
true,
[] __device__(value_t mat_in, value_t vec_in) { return mat_in / vec_in; },
stream);

eps2 = 2 * data.eps;

raft::spatial::knn::epsUnexpL2SqNeighborhood<value_t, index_t>(
data.adj, data.vd, data.x, data.x + start_vertex_id * k, m, n, k, eps2, stream);

/**
* Restoring the input matrix after normalization.
*/
raft::linalg::matrixVectorOp(
const_cast<value_t*>(data.x),
data.x,
rowNorms.data(),
k,
m,
true,
true,
[] __device__(value_t mat_in, value_t vec_in) { return mat_in * vec_in; },
stream);
}

else {
eps2 = data.eps * data.eps;

raft::spatial::knn::epsUnexpL2SqNeighborhood<value_t, index_t>(
data.adj, data.vd, data.x, data.x + start_vertex_id * k, m, n, k, eps2, stream);
}
}

} // namespace Algo
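
Below is a minimal NumPy sketch (not part of this commit) of the strategy in the CosineExpanded branch above: L2-normalize the rows, build the epsilon neighborhood from squared L2 distances with the threshold doubled, and check that it agrees with a direct cosine-distance threshold.

import numpy as np

def cosine_eps_neighborhood(X, eps):
    # Row-normalize (mirrors the rowNorm + matrixVectorOp steps in the diff).
    Xn = X / np.linalg.norm(X, axis=1, keepdims=True)
    # Pairwise squared L2 distances between the normalized rows.
    sq_l2 = np.sum((Xn[:, None, :] - Xn[None, :, :]) ** 2, axis=2)
    # For unit vectors, ||a - b||^2 = 2 * (1 - cos(a, b)), so compare against 2 * eps.
    return sq_l2 <= 2 * eps

def direct_cosine_neighborhood(X, eps):
    Xn = X / np.linalg.norm(X, axis=1, keepdims=True)
    return (1.0 - Xn @ Xn.T) <= eps

rng = np.random.default_rng(0)
X = rng.random((200, 8))
adj_a = cosine_eps_neighborhood(X, eps=0.1)
adj_b = direct_cosine_neighborhood(X, eps=0.1)
print("agreement:", (adj_a == adj_b).mean())  # ~1.0, up to rounding right at the threshold

The kernel above normalizes the input in place and then restores it rather than working on a copy; the sketch allocates a normalized copy instead for clarity, which is why it needs no restore step.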
7 changes: 4 additions & 3 deletions cpp/src/dbscan/vertexdeg/runner.cuh
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2021, NVIDIA CORPORATION.
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -36,13 +36,14 @@ void run(const raft::handle_t& handle,
int algo,
Index_ start_vertex_id,
Index_ batch_size,
cudaStream_t stream)
cudaStream_t stream,
raft::distance::DistanceType metric)
{
Pack<Type_f, Index_> data = {vd, adj, x, eps, N, D};
switch (algo) {
case 0: Naive::launcher<Type_f, Index_>(data, start_vertex_id, batch_size, stream); break;
case 1:
Algo::launcher<Type_f, Index_>(handle, data, start_vertex_id, batch_size, stream);
Algo::launcher<Type_f, Index_>(handle, data, start_vertex_id, batch_size, stream, metric);
break;
case 2:
Precomputed::launcher<Type_f, Index_>(handle, data, start_vertex_id, batch_size, stream);
8 changes: 6 additions & 2 deletions python/cuml/cluster/dbscan.pyx
@@ -147,10 +147,13 @@ class DBSCAN(Base,
min_samples : int (default = 5)
The number of samples in a neighborhood such that this group can be
considered as an important core point (including the point itself).
metric: {'euclidean', 'precomputed'}, default = 'euclidean'
metric: {'euclidean', 'cosine', 'precomputed'}, default = 'euclidean'
The metric to use when calculating distances between points.
If metric is 'precomputed', X is assumed to be a distance matrix
and must be square.
When the cosine metric is used, the input is modified temporarily,
and the restored input matrix may not match the original exactly
due to numerical rounding.
verbose : int or boolean, default=False
Sets logging level. It must be one of `cuml.common.logger.level_*`.
See :ref:`verbosity-levels` for more info.
@@ -266,7 +269,8 @@ class DBSCAN(Base,
metric_parsing = {
"L2": DistanceType.L2SqrtUnexpanded,
"euclidean": DistanceType.L2SqrtUnexpanded,
"precomputed": DistanceType.Precomputed,
"cosine": DistanceType.CosineExpanded,
"precomputed": DistanceType.Precomputed
}
if self.metric in metric_parsing:
metric = metric_parsing[self.metric.lower()]
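
A brief usage sketch of the new option from the Python API (not part of this commit; data and parameter values are illustrative):

import numpy as np
from cuml.cluster import DBSCAN

X = np.random.rand(500, 16).astype(np.float32)

# Two points are neighbors when their cosine distance is <= eps.
db = DBSCAN(eps=0.05, min_samples=5, metric='cosine', output_type='numpy')
labels = db.fit_predict(X)
print(np.unique(labels))
print(db.core_sample_indices_[:5])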
32 changes: 32 additions & 0 deletions python/cuml/tests/test_dbscan.py
@@ -116,6 +116,38 @@ def test_dbscan_precomputed(datatype, nrows, max_mbytes_per_batch, out_dtype):
cuml_dbscan.core_sample_indices_, eps)


@pytest.mark.parametrize('max_mbytes_per_batch', [unit_param(1),
quality_param(1e2), stress_param(None)])
@pytest.mark.parametrize('nrows', [unit_param(500), quality_param(5000),
stress_param(10000)])
@pytest.mark.parametrize('out_dtype', ["int32", "int64"])
def test_dbscan_cosine(nrows, max_mbytes_per_batch, out_dtype):
# 2-dimensional dataset for easy distance matrix computation
X, y = make_blobs(n_samples=nrows, cluster_std=0.01,
n_features=2, random_state=0)

eps = 0.1

cuml_dbscan = cuDBSCAN(eps=eps, min_samples=5, metric='cosine',
max_mbytes_per_batch=max_mbytes_per_batch,
output_type='numpy')

cu_labels = cuml_dbscan.fit_predict(X, out_dtype=out_dtype)

sk_dbscan = skDBSCAN(eps=eps, min_samples=5, metric='cosine',
algorithm='brute')

sk_labels = sk_dbscan.fit_predict(X)

# Check the core points are equal
assert array_equal(cuml_dbscan.core_sample_indices_,
sk_dbscan.core_sample_indices_)

# Check the labels are correct
assert_dbscan_equal(sk_labels, cu_labels, X,
cuml_dbscan.core_sample_indices_, eps)


@pytest.mark.parametrize("name", [
'noisy_moons',
'blobs',
