diff --git a/cpp/src/dbscan/dbscan.cuh b/cpp/src/dbscan/dbscan.cuh index 8664fb0a65..804651d37f 100644 --- a/cpp/src/dbscan/dbscan.cuh +++ b/cpp/src/dbscan/dbscan.cuh @@ -110,7 +110,15 @@ void dbscanFitImpl(const raft::handle_t& handle, { raft::common::nvtx::range fun_scope("ML::Dbscan::Fit"); ML::Logger::get().setLevel(verbosity); - int algo_vd = (metric == raft::distance::Precomputed) ? 2 : 1; + // int algo_vd = (metric == raft::distance::Precomputed) ? 2 : 1; + int algo_vd; + if (metric == raft::distance::Precomputed) { + algo_vd = 2; + } else if (metric == raft::distance::CosineExpanded) { + algo_vd = 3; + } else { + algo_vd = 1; + } int algo_adj = 1; int algo_ccl = 2; diff --git a/cpp/src/dbscan/vertexdeg/cosine.cuh b/cpp/src/dbscan/vertexdeg/cosine.cuh new file mode 100644 index 0000000000..a21dceece6 --- /dev/null +++ b/cpp/src/dbscan/vertexdeg/cosine.cuh @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "pack.h" + +namespace ML { +namespace Dbscan { +namespace VertexDeg { +namespace Cosine { + +/** + * Calculates the vertex degree array and the epsilon neighborhood adjacency matrix for the batch. + */ +template +void launcher(const raft::handle_t& handle, + Pack data, + index_t start_vertex_id, + index_t batch_size, + cudaStream_t stream) +{ + data.resetArray(stream, batch_size + 1); + + ASSERT(sizeof(index_t) == 4 || sizeof(index_t) == 8, "index_t should be 4 or 8 bytes"); + + index_t m = data.N; + index_t n = min(data.N - start_vertex_id, batch_size); + index_t k = data.D; + value_t eps2 = 2 * data.eps; + + rmm::device_uvector rowNorms(m, stream); + rmm::device_uvector l2Normalized(m * n, stream); + + raft::linalg::rowNorm(rowNorms.data(), + data.x, + k, + m, + raft::linalg::NormType::L2Norm, + true, + stream, + [] __device__(value_t in) { return sqrtf(in); }); + + raft::linalg::matrixVectorOp( + l2Normalized.data(), + data.x, + rowNorms.data(), + k, + m, + true, + true, + [] __device__(value_t mat_in, value_t vec_in) { return mat_in / vec_in; }, + stream); + + raft::spatial::knn::epsUnexpL2SqNeighborhood( + data.adj, + data.vd, + l2Normalized.data(), + l2Normalized.data() + start_vertex_id * k, + m, + n, + k, + eps2, + stream); +} + +} // namespace Cosine +} // end namespace VertexDeg +} // end namespace Dbscan +} // namespace ML diff --git a/cpp/src/dbscan/vertexdeg/runner.cuh b/cpp/src/dbscan/vertexdeg/runner.cuh index 082a2ac46f..3a60da69ea 100644 --- a/cpp/src/dbscan/vertexdeg/runner.cuh +++ b/cpp/src/dbscan/vertexdeg/runner.cuh @@ -17,6 +17,7 @@ #pragma once #include "algo.cuh" +#include "cosine.cuh" #include "naive.cuh" #include "pack.h" #include "precomputed.cuh" @@ -47,6 +48,9 @@ void run(const raft::handle_t& handle, case 2: Precomputed::launcher(handle, data, start_vertex_id, batch_size, stream); break; + case 3: + Cosine::launcher(handle, data, start_vertex_id, batch_size, stream); + break; default: ASSERT(false, "Incorrect algo passed! '%d'", algo); } } diff --git a/python/cuml/cluster/dbscan.pyx b/python/cuml/cluster/dbscan.pyx index d00df0a822..26f8dd8db8 100644 --- a/python/cuml/cluster/dbscan.pyx +++ b/python/cuml/cluster/dbscan.pyx @@ -147,7 +147,7 @@ class DBSCAN(Base, min_samples : int (default = 5) The number of samples in a neighborhood such that this group can be considered as an important core point (including the point itself). - metric: {'euclidean', 'precomputed'}, default = 'euclidean' + metric: {'euclidean', 'precomputed', 'cosine'}, default = 'euclidean' The metric to use when calculating distances between points. If metric is 'precomputed', X is assumed to be a distance matrix and must be square. @@ -267,6 +267,7 @@ class DBSCAN(Base, "L2": DistanceType.L2SqrtUnexpanded, "euclidean": DistanceType.L2SqrtUnexpanded, "precomputed": DistanceType.Precomputed, + "cosine": DistanceType.CosineExpanded } if self.metric in metric_parsing: metric = metric_parsing[self.metric.lower()] diff --git a/python/cuml/ensemble/randomforestclassifier.pyx b/python/cuml/ensemble/randomforestclassifier.pyx index 2416a04613..5cfb39656c 100644 --- a/python/cuml/ensemble/randomforestclassifier.pyx +++ b/python/cuml/ensemble/randomforestclassifier.pyx @@ -16,6 +16,8 @@ # # distutils: language = c++ + +import cupy as cp import numpy as np import nvtx import rmm @@ -52,7 +54,6 @@ cimport cuml.common.cuda cimport cython - cdef extern from "cuml/ensemble/randomforest.hpp" namespace "ML": cdef void fit(handle_t& handle, @@ -208,7 +209,7 @@ class RandomForestClassifier(BaseRandomForestModel, node to be spilt. max_batch_size : int (default = 4096) Maximum number of nodes that can be processed in a given batch. - random_state : int (default = None) + random_state : int, RandomState instance or None, optional (default=None) Seed for the random number generator. Unseeded by default. Does not currently fully guarantee the exact same results. handle : cuml.Handle @@ -449,7 +450,26 @@ class RandomForestClassifier(BaseRandomForestModel, if self.random_state is None: seed_val = NULL else: - seed_val = self.random_state + if isinstance(self.random_state, np.uintp): + seed_val = self.random_state + else: + rs = self.random_state + if isinstance(rs, np.random.RandomState) or \ + isinstance(rs, cp.random.RandomState): + seed_val = rs.randint( + low=0, + high=np.iinfo(np.uintp).max, + dtype=np.uintp) + elif isinstance(rs, np.random.Generator): + seed_val = rs.integers( + low=0, + high=np.iinfo(np.uintp).max, + dtype=np.uintp) + else: + seed_val = np.random.default_rng(rs).integers( + low=0, + high=np.iinfo(np.uintp).max, + dtype=np.uintp) rf_params = set_rf_params( self.max_depth, self.max_leaves, diff --git a/python/cuml/ensemble/randomforestregressor.pyx b/python/cuml/ensemble/randomforestregressor.pyx index ef32c8f917..9bd1922fbd 100644 --- a/python/cuml/ensemble/randomforestregressor.pyx +++ b/python/cuml/ensemble/randomforestregressor.pyx @@ -16,6 +16,7 @@ # distutils: language = c++ +import cupy as cp import numpy as np import nvtx import rmm @@ -214,7 +215,7 @@ class RandomForestRegressor(BaseRandomForestModel, * for mean square error' : ``'mse'`` max_batch_size : int (default = 4096) Maximum number of nodes that can be processed in a given batch. - random_state : int (default = None) + random_state : int, RandomState instance or None, optional (default=None) Seed for the random number generator. Unseeded by default. Does not currently fully guarantee the exact same results. handle : cuml.Handle @@ -436,9 +437,31 @@ class RandomForestRegressor(BaseRandomForestModel, new RandomForestMetaData[double, double]() self.rf_forest64 = rf_forest64 if self.random_state is None: - seed_val = NULL + seed_val = NULL else: - seed_val = self.random_state + if isinstance(self.random_state, np.uint64): + seed_val = self.random_state + # Otherwise create a RandomState instance to generate a new + # np.uintp + else: + rs = self.random_state + if isinstance(rs, np.random.RandomState) or \ + isinstance(rs, cp.random.RandomState): + seed_val = rs.randint( + low=0, + high=np.iinfo(np.uint64).max, + dtype=np.uint64) + elif isinstance(self.random_state, np.random.Generator): + seed_val = rs.integers( + low=0, + high=np.iinfo(np.uint64).max, + dtype=np.uint64) + else: + seed_val = np.random.default_rng(rs).integers( + low=0, + high=np.iinfo(np.uint64).max, + dtype=np.uint64) + rf_params = set_rf_params( self.max_depth, self.max_leaves, diff --git a/python/cuml/tests/test_dbscan.py b/python/cuml/tests/test_dbscan.py index 8c8027d7ec..b23b2af67c 100644 --- a/python/cuml/tests/test_dbscan.py +++ b/python/cuml/tests/test_dbscan.py @@ -107,6 +107,41 @@ def test_dbscan_precomputed(datatype, nrows, max_mbytes_per_batch, out_dtype): algorithm="brute") sk_labels = sk_dbscan.fit_predict(X_dist) + print("cu_labels:", cu_labels) + print("sk_labels:", sk_labels) + + # Check the core points are equal + assert array_equal(cuml_dbscan.core_sample_indices_, + sk_dbscan.core_sample_indices_) + + # Check the labels are correct + assert_dbscan_equal(sk_labels, cu_labels, X, + cuml_dbscan.core_sample_indices_, eps) + + +@pytest.mark.parametrize('max_mbytes_per_batch', [unit_param(1), + quality_param(1e2), stress_param(None)]) +@pytest.mark.parametrize('nrows', [unit_param(500), quality_param(5000), + stress_param(10000)]) +@pytest.mark.parametrize('out_dtype', ["int32", "int64"]) +def test_dbscan_cosine(nrows, max_mbytes_per_batch, out_dtype): + # 2-dimensional dataset for easy distance matrix computation + X, y = make_blobs(n_samples=nrows, cluster_std=0.01, + n_features=2, random_state=0) + + eps = 0.1 + + cuml_dbscan = cuDBSCAN(eps=eps, min_samples=5, metric='cosine', + max_mbytes_per_batch=max_mbytes_per_batch, + output_type='numpy') + + cu_labels = cuml_dbscan.fit_predict(X, out_dtype=out_dtype) + + sk_dbscan = skDBSCAN(eps=eps, min_samples=5, metric='cosine', + algorithm='brute') + + sk_labels = sk_dbscan.fit_predict(X) + # Check the core points are equal assert array_equal(cuml_dbscan.core_sample_indices_, sk_dbscan.core_sample_indices_) diff --git a/python/cuml/tests/test_random_forest.py b/python/cuml/tests/test_random_forest.py index 85551aefb7..50a8759190 100644 --- a/python/cuml/tests/test_random_forest.py +++ b/python/cuml/tests/test_random_forest.py @@ -17,6 +17,7 @@ import warnings import cudf +import cupy as cp import numpy as np import random import json @@ -380,7 +381,10 @@ def test_rf_regression( @pytest.mark.parametrize("datatype", [np.float32, np.float64]) -def test_rf_classification_seed(small_clf, datatype): +@pytest.mark.parametrize("rs_class", + [int, np.uintp, np.random.RandomState, + cp.random.RandomState, np.random.default_rng]) +def test_rf_classification_seed(small_clf, datatype, rs_class): X, y = small_clf X = X.astype(datatype) @@ -391,30 +395,28 @@ def test_rf_classification_seed(small_clf, datatype): for i in range(8): seed = random.randint(100, 1e5) + cu_class_seed = rs_class(seed) + cu_class2_seed = rs_class(seed) # Initialize, fit and predict using cuML's # random forest classification model - cu_class = curfc(random_state=seed, n_streams=1) + cu_class = curfc(random_state=cu_class_seed, n_streams=1) cu_class.fit(X_train, y_train) # predict using FIL fil_preds_orig = cu_class.predict(X_test, predict_model="GPU") cu_preds_orig = cu_class.predict(X_test, predict_model="CPU") cu_acc_orig = accuracy_score(y_test, cu_preds_orig) - fil_preds_orig = np.reshape(fil_preds_orig, np.shape(cu_preds_orig)) - fil_acc_orig = accuracy_score(y_test, fil_preds_orig) # Initialize, fit and predict using cuML's # random forest classification model - cu_class2 = curfc(random_state=seed, n_streams=1) + cu_class2 = curfc(random_state=cu_class2_seed, n_streams=1) cu_class2.fit(X_train, y_train) # predict using FIL fil_preds_rerun = cu_class2.predict(X_test, predict_model="GPU") cu_preds_rerun = cu_class2.predict(X_test, predict_model="CPU") cu_acc_rerun = accuracy_score(y_test, cu_preds_rerun) - fil_preds_rerun = np.reshape(fil_preds_rerun, np.shape(cu_preds_rerun)) - fil_acc_rerun = accuracy_score(y_test, fil_preds_rerun) assert fil_acc_orig == fil_acc_rerun @@ -423,6 +425,51 @@ def test_rf_classification_seed(small_clf, datatype): assert (cu_preds_orig == cu_preds_rerun).all() +@pytest.mark.parametrize("datatype", [np.float32, np.float64]) +@pytest.mark.parametrize("rs_class", + [int, np.uint64, np.random.RandomState, + cp.random.RandomState, np.random.default_rng]) +def test_rf_regression_seed(special_reg, datatype, rs_class): + + X, y = special_reg + X = X.astype(datatype) + y = y.astype(datatype) + X_train, X_test, y_train, y_test = train_test_split( + X, y, train_size=0.8, random_state=0 + ) + + for i in range(8): + seed = random.randint(100, 1e5) + cu_reg_seed = rs_class(seed) + cu_reg2_seed = rs_class(seed) + # Initialize, fit and predict using cuML's + # random forest classification model + cu_reg = curfr(random_state=cu_reg_seed, n_streams=1) + cu_reg.fit(X_train, y_train) + + # predict using FIL + fil_preds_orig = cu_reg.predict(X_test, predict_model="GPU") + cu_preds_orig = cu_reg.predict(X_test, predict_model="CPU") + + cu_r2_orig = r2_score(y_test, cu_preds_orig, convert_dtype=datatype) + fil_r2_orig = r2_score(y_test, fil_preds_orig, convert_dtype=datatype) + + cu_reg2 = curfr(random_state=cu_reg2_seed, n_streams=1) + cu_reg2.fit(X_train, y_train) + + # predict using FIL + fil_preds_rerun = cu_reg2.predict(X_test, predict_model="GPU") + cu_preds_rerun = cu_reg2.predict(X_test, predict_model="CPU") + + cu_r2_rerun = r2_score(y_test, cu_preds_rerun, + convert_dtype=datatype) + fil_r2_rerun = r2_score(y_test, fil_preds_rerun, + convert_dtype=datatype) + + assert abs(fil_r2_orig - fil_r2_rerun) <= 0.02 + assert abs(cu_r2_orig - cu_r2_rerun) <= 0.02 + + @pytest.mark.parametrize( "datatype", [(np.float64, np.float32), (np.float32, np.float64)] )