Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEA] add cosine distance for DBSCAN #4775

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion cpp/src/dbscan/dbscan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,15 @@ void dbscanFitImpl(const raft::handle_t& handle,
{
raft::common::nvtx::range fun_scope("ML::Dbscan::Fit");
ML::Logger::get().setLevel(verbosity);
int algo_vd = (metric == raft::distance::Precomputed) ? 2 : 1;
// int algo_vd = (metric == raft::distance::Precomputed) ? 2 : 1;
int algo_vd;
if (metric == raft::distance::Precomputed) {
algo_vd = 2;
} else if (metric == raft::distance::CosineExpanded) {
algo_vd = 3;
} else {
algo_vd = 1;
}
int algo_adj = 1;
int algo_ccl = 2;

Expand Down
90 changes: 90 additions & 0 deletions cpp/src/dbscan/vertexdeg/cosine.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cuda_runtime.h>
#include <math.h>
#include <raft/linalg/matrix_vector_op.hpp>
#include <raft/linalg/norm.cuh>
#include <raft/spatial/knn/epsilon_neighborhood.hpp>
#include <rmm/device_uvector.hpp>

#include "pack.h"

namespace ML {
namespace Dbscan {
namespace VertexDeg {
namespace Cosine {

/**
 * Calculates the vertex degree array and the epsilon neighborhood adjacency matrix for the batch
 * using the cosine distance.
 *
 * Every row of the input is divided by its L2 norm and the neighborhood query is then answered
 * by the squared (unexpanded) L2 primitive. For unit vectors a and b,
 * ||a - b||^2 = 2 * (1 - cos(a, b)), so a cosine-distance threshold of `eps` is equivalent to a
 * squared-L2 threshold of `2 * eps`.
 *
 * @tparam value_t floating point type of the input values
 * @tparam index_t integer type used for sizes and indices (must be 4 or 8 bytes wide)
 * @param[in]    handle          raft handle (not used directly by this function)
 * @param[inout] data            pack whose N, D, eps and x members are read and whose adj and vd
 *                               members are handed to the neighborhood primitive as outputs
 * @param[in]    start_vertex_id index of the first row belonging to the current batch
 * @param[in]    batch_size      maximum number of rows in the batch
 * @param[in]    stream          CUDA stream on which all work is enqueued
 */
template <typename value_t, typename index_t = int>
void launcher(const raft::handle_t& handle,
              Pack<value_t, index_t> data,
              index_t start_vertex_id,
              index_t batch_size,
              cudaStream_t stream)
{
  data.resetArray(stream, batch_size + 1);

  ASSERT(sizeof(index_t) == 4 || sizeof(index_t) == 8, "index_t should be 4 or 8 bytes");

  index_t m = data.N;
  // The last batch may contain fewer than batch_size rows.
  index_t n = min(data.N - start_vertex_id, batch_size);
  index_t k = data.D;
  // Cosine threshold expressed as a squared-L2 threshold on unit vectors (see header comment).
  value_t eps2 = 2 * data.eps;

  rmm::device_uvector<value_t> rowNorms(m, stream);
  // The normalized copy has the same extent as the input: m rows of k values each. It must NOT be
  // sized by the batch size n, otherwise the buffer is overrun whenever n < k. The product is
  // formed in size_t so that it cannot overflow a 32-bit index_t.
  rmm::device_uvector<value_t> l2Normalized(static_cast<size_t>(m) * static_cast<size_t>(k),
                                            stream);

  // Per-row L2 norm. The final op takes the square root of the value produced by the reduction;
  // the overloaded sqrt keeps full precision when value_t is double (sqrtf would truncate).
  raft::linalg::rowNorm(rowNorms.data(),
                        data.x,
                        k,
                        m,
                        raft::linalg::NormType::L2Norm,
                        true,
                        stream,
                        [] __device__(value_t in) { return sqrt(in); });

  // Divide each matrix element by the matching entry of rowNorms. A row with zero norm yields a
  // division by zero here; such rows are not special-cased.
  // NOTE(review): rowNorms holds one value per row (length m). Confirm that the two boolean
  // flags below (rowMajor, bcastAlongRows) make matrixVectorOp index the vector per row rather
  // than per column in the raft version in use.
  raft::linalg::matrixVectorOp(
    l2Normalized.data(),
    data.x,
    rowNorms.data(),
    k,
    m,
    true,
    true,
    [] __device__(value_t mat_in, value_t vec_in) { return mat_in / vec_in; },
    stream);

  // Compare all m normalized rows against the n rows of the current batch. The row offset is
  // computed in size_t to avoid overflowing a 32-bit index_t on large inputs.
  raft::spatial::knn::epsUnexpL2SqNeighborhood<value_t, index_t>(
    data.adj,
    data.vd,
    l2Normalized.data(),
    l2Normalized.data() + static_cast<size_t>(start_vertex_id) * static_cast<size_t>(k),
    m,
    n,
    k,
    eps2,
    stream);
}

} // namespace Cosine
} // end namespace VertexDeg
} // end namespace Dbscan
} // namespace ML
4 changes: 4 additions & 0 deletions cpp/src/dbscan/vertexdeg/runner.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include "algo.cuh"
#include "cosine.cuh"
#include "naive.cuh"
#include "pack.h"
#include "precomputed.cuh"
Expand Down Expand Up @@ -47,6 +48,9 @@ void run(const raft::handle_t& handle,
case 2:
Precomputed::launcher<Type_f, Index_>(handle, data, start_vertex_id, batch_size, stream);
break;
case 3:
Cosine::launcher<Type_f, Index_>(handle, data, start_vertex_id, batch_size, stream);
break;
default: ASSERT(false, "Incorrect algo passed! '%d'", algo);
}
}
Expand Down
3 changes: 2 additions & 1 deletion python/cuml/cluster/dbscan.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class DBSCAN(Base,
min_samples : int (default = 5)
The number of samples in a neighborhood such that this group can be
considered as an important core point (including the point itself).
metric: {'euclidean', 'precomputed'}, default = 'euclidean'
metric: {'euclidean', 'precomputed', 'cosine'}, default = 'euclidean'
The metric to use when calculating distances between points.
If metric is 'precomputed', X is assumed to be a distance matrix
and must be square.
Expand Down Expand Up @@ -267,6 +267,7 @@ class DBSCAN(Base,
"L2": DistanceType.L2SqrtUnexpanded,
"euclidean": DistanceType.L2SqrtUnexpanded,
"precomputed": DistanceType.Precomputed,
"cosine": DistanceType.CosineExpanded
}
if self.metric in metric_parsing:
metric = metric_parsing[self.metric.lower()]
Expand Down
26 changes: 23 additions & 3 deletions python/cuml/ensemble/randomforestclassifier.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#

# distutils: language = c++

import cupy as cp
import numpy as np
import nvtx
import rmm
Expand Down Expand Up @@ -52,7 +54,6 @@ cimport cuml.common.cuda

cimport cython


cdef extern from "cuml/ensemble/randomforest.hpp" namespace "ML":

cdef void fit(handle_t& handle,
Expand Down Expand Up @@ -208,7 +209,7 @@ class RandomForestClassifier(BaseRandomForestModel,
node to be spilt.
max_batch_size : int (default = 4096)
Maximum number of nodes that can be processed in a given batch.
random_state : int (default = None)
random_state : int, RandomState instance or None, optional (default=None)
Seed for the random number generator. Unseeded by default. Does not
currently fully guarantee the exact same results.
handle : cuml.Handle
Expand Down Expand Up @@ -449,7 +450,26 @@ class RandomForestClassifier(BaseRandomForestModel,
if self.random_state is None:
seed_val = <uintptr_t>NULL
else:
seed_val = <uintptr_t>self.random_state
if isinstance(self.random_state, np.uintp):
seed_val = <uintptr_t>self.random_state
else:
rs = self.random_state
if isinstance(rs, np.random.RandomState) or \
isinstance(rs, cp.random.RandomState):
seed_val = <uintptr_t>rs.randint(
low=0,
high=np.iinfo(np.uintp).max,
dtype=np.uintp)
elif isinstance(rs, np.random.Generator):
seed_val = <uintptr_t>rs.integers(
low=0,
high=np.iinfo(np.uintp).max,
dtype=np.uintp)
else:
seed_val = <uintptr_t>np.random.default_rng(rs).integers(
low=0,
high=np.iinfo(np.uintp).max,
dtype=np.uintp)

rf_params = set_rf_params(<int> self.max_depth,
<int> self.max_leaves,
Expand Down
29 changes: 26 additions & 3 deletions python/cuml/ensemble/randomforestregressor.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

# distutils: language = c++

import cupy as cp
import numpy as np
import nvtx
import rmm
Expand Down Expand Up @@ -214,7 +215,7 @@ class RandomForestRegressor(BaseRandomForestModel,
* for mean square error' : ``'mse'``
max_batch_size : int (default = 4096)
Maximum number of nodes that can be processed in a given batch.
random_state : int (default = None)
random_state : int, RandomState instance or None, optional (default=None)
Seed for the random number generator. Unseeded by default. Does not
currently fully guarantee the exact same results.
handle : cuml.Handle
Expand Down Expand Up @@ -436,9 +437,31 @@ class RandomForestRegressor(BaseRandomForestModel,
new RandomForestMetaData[double, double]()
self.rf_forest64 = <uintptr_t> rf_forest64
if self.random_state is None:
seed_val = <uintptr_t>NULL
seed_val = <uint64_t>NULL
else:
seed_val = <uintptr_t>self.random_state
if isinstance(self.random_state, np.uint64):
seed_val = <uint64_t>self.random_state
# Otherwise draw a new np.uint64 seed from the given (or a freshly
# created) random number generator
else:
rs = self.random_state
if isinstance(rs, np.random.RandomState) or \
isinstance(rs, cp.random.RandomState):
seed_val = <uint64_t>rs.randint(
low=0,
high=np.iinfo(np.uint64).max,
dtype=np.uint64)
elif isinstance(self.random_state, np.random.Generator):
seed_val = <uint64_t>rs.integers(
low=0,
high=np.iinfo(np.uint64).max,
dtype=np.uint64)
else:
seed_val = <uint64_t>np.random.default_rng(rs).integers(
low=0,
high=np.iinfo(np.uint64).max,
dtype=np.uint64)


rf_params = set_rf_params(<int> self.max_depth,
<int> self.max_leaves,
Expand Down
35 changes: 35 additions & 0 deletions python/cuml/tests/test_dbscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,41 @@ def test_dbscan_precomputed(datatype, nrows, max_mbytes_per_batch, out_dtype):
algorithm="brute")
sk_labels = sk_dbscan.fit_predict(X_dist)

print("cu_labels:", cu_labels)
print("sk_labels:", sk_labels)

# Check the core points are equal
assert array_equal(cuml_dbscan.core_sample_indices_,
sk_dbscan.core_sample_indices_)

# Check the labels are correct
assert_dbscan_equal(sk_labels, cu_labels, X,
cuml_dbscan.core_sample_indices_, eps)


@pytest.mark.parametrize('max_mbytes_per_batch', [unit_param(1),
quality_param(1e2), stress_param(None)])
@pytest.mark.parametrize('nrows', [unit_param(500), quality_param(5000),
stress_param(10000)])
@pytest.mark.parametrize('out_dtype', ["int32", "int64"])
def test_dbscan_cosine(nrows, max_mbytes_per_batch, out_dtype):
# 2-dimensional dataset for easy distance matrix computation
X, y = make_blobs(n_samples=nrows, cluster_std=0.01,
n_features=2, random_state=0)

eps = 0.1

cuml_dbscan = cuDBSCAN(eps=eps, min_samples=5, metric='cosine',
max_mbytes_per_batch=max_mbytes_per_batch,
output_type='numpy')

cu_labels = cuml_dbscan.fit_predict(X, out_dtype=out_dtype)

sk_dbscan = skDBSCAN(eps=eps, min_samples=5, metric='cosine',
algorithm='brute')

sk_labels = sk_dbscan.fit_predict(X)

# Check the core points are equal
assert array_equal(cuml_dbscan.core_sample_indices_,
sk_dbscan.core_sample_indices_)
Expand Down
61 changes: 54 additions & 7 deletions python/cuml/tests/test_random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import warnings
import cudf
import cupy as cp
import numpy as np
import random
import json
Expand Down Expand Up @@ -380,7 +381,10 @@ def test_rf_regression(


@pytest.mark.parametrize("datatype", [np.float32, np.float64])
def test_rf_classification_seed(small_clf, datatype):
@pytest.mark.parametrize("rs_class",
[int, np.uintp, np.random.RandomState,
cp.random.RandomState, np.random.default_rng])
def test_rf_classification_seed(small_clf, datatype, rs_class):

X, y = small_clf
X = X.astype(datatype)
Expand All @@ -391,30 +395,28 @@ def test_rf_classification_seed(small_clf, datatype):

for i in range(8):
seed = random.randint(100, 1e5)
cu_class_seed = rs_class(seed)
cu_class2_seed = rs_class(seed)
# Initialize, fit and predict using cuML's
# random forest classification model
cu_class = curfc(random_state=seed, n_streams=1)
cu_class = curfc(random_state=cu_class_seed, n_streams=1)
cu_class.fit(X_train, y_train)

# predict using FIL
fil_preds_orig = cu_class.predict(X_test, predict_model="GPU")
cu_preds_orig = cu_class.predict(X_test, predict_model="CPU")
cu_acc_orig = accuracy_score(y_test, cu_preds_orig)
fil_preds_orig = np.reshape(fil_preds_orig, np.shape(cu_preds_orig))

fil_acc_orig = accuracy_score(y_test, fil_preds_orig)

# Initialize, fit and predict using cuML's
# random forest classification model
cu_class2 = curfc(random_state=seed, n_streams=1)
cu_class2 = curfc(random_state=cu_class2_seed, n_streams=1)
cu_class2.fit(X_train, y_train)

# predict using FIL
fil_preds_rerun = cu_class2.predict(X_test, predict_model="GPU")
cu_preds_rerun = cu_class2.predict(X_test, predict_model="CPU")
cu_acc_rerun = accuracy_score(y_test, cu_preds_rerun)
fil_preds_rerun = np.reshape(fil_preds_rerun, np.shape(cu_preds_rerun))

fil_acc_rerun = accuracy_score(y_test, fil_preds_rerun)

assert fil_acc_orig == fil_acc_rerun
Expand All @@ -423,6 +425,51 @@ def test_rf_classification_seed(small_clf, datatype):
assert (cu_preds_orig == cu_preds_rerun).all()


@pytest.mark.parametrize("datatype", [np.float32, np.float64])
@pytest.mark.parametrize("rs_class",
                         [int, np.uint64, np.random.RandomState,
                          cp.random.RandomState, np.random.default_rng])
def test_rf_regression_seed(special_reg, datatype, rs_class):
    """Fit two regressors from equally-seeded random states and compare them.

    For each of eight random seeds, two ``random_state`` objects are built
    with ``rs_class(seed)`` and passed to two separate regressors. The R2
    scores of both models (GPU and CPU prediction paths) must agree to
    within 0.02.
    """
    X, y = special_reg
    X = X.astype(datatype)
    y = y.astype(datatype)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, train_size=0.8, random_state=0
    )

    for _ in range(8):
        # Integer bounds only: random.randint deprecates float arguments
        # (Python 3.10) and newer Python versions reject them outright.
        seed = random.randint(100, 100000)
        cu_reg_seed = rs_class(seed)
        cu_reg2_seed = rs_class(seed)
        # Initialize, fit and predict using cuML's
        # random forest regression model
        cu_reg = curfr(random_state=cu_reg_seed, n_streams=1)
        cu_reg.fit(X_train, y_train)

        # predict using FIL
        fil_preds_orig = cu_reg.predict(X_test, predict_model="GPU")
        cu_preds_orig = cu_reg.predict(X_test, predict_model="CPU")

        cu_r2_orig = r2_score(y_test, cu_preds_orig, convert_dtype=datatype)
        fil_r2_orig = r2_score(y_test, fil_preds_orig, convert_dtype=datatype)

        # Second model, built from an independent but identically seeded
        # random state object
        cu_reg2 = curfr(random_state=cu_reg2_seed, n_streams=1)
        cu_reg2.fit(X_train, y_train)

        # predict using FIL
        fil_preds_rerun = cu_reg2.predict(X_test, predict_model="GPU")
        cu_preds_rerun = cu_reg2.predict(X_test, predict_model="CPU")

        cu_r2_rerun = r2_score(y_test, cu_preds_rerun,
                               convert_dtype=datatype)
        fil_r2_rerun = r2_score(y_test, fil_preds_rerun,
                                convert_dtype=datatype)

        assert abs(fil_r2_orig - fil_r2_rerun) <= 0.02
        assert abs(cu_r2_orig - cu_r2_rerun) <= 0.02


@pytest.mark.parametrize(
"datatype", [(np.float64, np.float32), (np.float32, np.float64)]
)
Expand Down