Skip to content

Commit

Permalink
DBSCAN utilize rbc eps_neighbors (rapidsai#5728)
Browse files Browse the repository at this point in the history
This PR enables rbc eps-neighbor computation via raft. The resulting adjacency matrix is sparse and allows to skip the implicit conversion.

Notes:
* the 'algorithm'-parameter was added to the DBSCAN init signature to allow the user to choose (default is 'brute', 'rbc' is optional)
* the memory management is still very conservative, assuming a dense adjacency matrix and therefore selecting comparably small batches
* in case maximum row length of a batch is sufficiently small the CSR structure can be computed in a single pass

CC @tfeher

Authors:
  - Malte Förster (https://github.com/mfoerste4)
  - Tamas Bela Feher (https://github.com/tfeher)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: rapidsai#5728
  • Loading branch information
mfoerste4 authored Mar 11, 2024
1 parent a3bed22 commit a6c0478
Show file tree
Hide file tree
Showing 15 changed files with 357 additions and 112 deletions.
1 change: 1 addition & 0 deletions cpp/examples/dbscan/dbscan_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ int main(int argc, char* argv[])
nullptr,
nullptr,
max_bytes_per_batch,
ML::Dbscan::EpsNnMethod::BRUTE_FORCE,
false);
CUDA_RT_CALL(cudaMemcpyAsync(
h_labels.data(), d_labels, nRows * sizeof(int), cudaMemcpyDeviceToHost, stream));
Expand Down
7 changes: 7 additions & 0 deletions cpp/include/cuml/cluster/dbscan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ class handle_t;
namespace ML {
namespace Dbscan {

enum EpsNnMethod { BRUTE_FORCE, RBC };

/**
* @defgroup DbscanCpp C++ implementation of Dbscan algo
* @brief Fits a DBSCAN model on an input feature matrix and outputs the labels
Expand All @@ -53,6 +55,7 @@ namespace Dbscan {
* @param[in] max_bytes_per_batch the maximum number of megabytes to be used for
* each batch of the pairwise distance calculation. This enables the
* trade off between memory usage and algorithm execution time.
* @param[in] eps_nn_method method for computing epsilon neighborhood
* @param[in] verbosity verbosity level for logging messages during execution
* @param[in] opg whether we are running in a multi-node multi-GPU context
* @{
Expand All @@ -69,6 +72,7 @@ void fit(const raft::handle_t& handle,
int* core_sample_indices = nullptr,
float* sample_weight = nullptr,
size_t max_bytes_per_batch = 0,
EpsNnMethod eps_nn_method = BRUTE_FORCE,
int verbosity = CUML_LEVEL_INFO,
bool opg = false);
void fit(const raft::handle_t& handle,
Expand All @@ -82,6 +86,7 @@ void fit(const raft::handle_t& handle,
int* core_sample_indices = nullptr,
double* sample_weight = nullptr,
size_t max_bytes_per_batch = 0,
EpsNnMethod eps_nn_method = BRUTE_FORCE,
int verbosity = CUML_LEVEL_INFO,
bool opg = false);

Expand All @@ -96,6 +101,7 @@ void fit(const raft::handle_t& handle,
int64_t* core_sample_indices = nullptr,
float* sample_weight = nullptr,
size_t max_bytes_per_batch = 0,
EpsNnMethod eps_nn_method = BRUTE_FORCE,
int verbosity = CUML_LEVEL_INFO,
bool opg = false);
void fit(const raft::handle_t& handle,
Expand All @@ -109,6 +115,7 @@ void fit(const raft::handle_t& handle,
int64_t* core_sample_indices = nullptr,
double* sample_weight = nullptr,
size_t max_bytes_per_batch = 0,
EpsNnMethod eps_nn_method = BRUTE_FORCE,
int verbosity = CUML_LEVEL_INFO,
bool opg = false);

Expand Down
12 changes: 12 additions & 0 deletions cpp/src/dbscan/dbscan.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ void fit(const raft::handle_t& handle,
int* core_sample_indices,
float* sample_weight,
size_t max_bytes_per_batch,
EpsNnMethod eps_nn_method,
int verbosity,
bool opg)
{
Expand All @@ -49,6 +50,7 @@ void fit(const raft::handle_t& handle,
core_sample_indices,
sample_weight,
max_bytes_per_batch,
eps_nn_method,
handle.get_stream(),
verbosity);
else
Expand All @@ -63,6 +65,7 @@ void fit(const raft::handle_t& handle,
core_sample_indices,
sample_weight,
max_bytes_per_batch,
eps_nn_method,
handle.get_stream(),
verbosity);
}
Expand All @@ -78,6 +81,7 @@ void fit(const raft::handle_t& handle,
int* core_sample_indices,
double* sample_weight,
size_t max_bytes_per_batch,
EpsNnMethod eps_nn_method,
int verbosity,
bool opg)
{
Expand All @@ -93,6 +97,7 @@ void fit(const raft::handle_t& handle,
core_sample_indices,
sample_weight,
max_bytes_per_batch,
eps_nn_method,
handle.get_stream(),
verbosity);
else
Expand All @@ -107,6 +112,7 @@ void fit(const raft::handle_t& handle,
core_sample_indices,
sample_weight,
max_bytes_per_batch,
eps_nn_method,
handle.get_stream(),
verbosity);
}
Expand All @@ -122,6 +128,7 @@ void fit(const raft::handle_t& handle,
int64_t* core_sample_indices,
float* sample_weight,
size_t max_bytes_per_batch,
EpsNnMethod eps_nn_method,
int verbosity,
bool opg)
{
Expand All @@ -137,6 +144,7 @@ void fit(const raft::handle_t& handle,
core_sample_indices,
sample_weight,
max_bytes_per_batch,
eps_nn_method,
handle.get_stream(),
verbosity);
else
Expand All @@ -151,6 +159,7 @@ void fit(const raft::handle_t& handle,
core_sample_indices,
sample_weight,
max_bytes_per_batch,
eps_nn_method,
handle.get_stream(),
verbosity);
}
Expand All @@ -166,6 +175,7 @@ void fit(const raft::handle_t& handle,
int64_t* core_sample_indices,
double* sample_weight,
size_t max_bytes_per_batch,
EpsNnMethod eps_nn_method,
int verbosity,
bool opg)
{
Expand All @@ -181,6 +191,7 @@ void fit(const raft::handle_t& handle,
core_sample_indices,
sample_weight,
max_bytes_per_batch,
eps_nn_method,
handle.get_stream(),
verbosity);
else
Expand All @@ -195,6 +206,7 @@ void fit(const raft::handle_t& handle,
core_sample_indices,
sample_weight,
max_bytes_per_batch,
eps_nn_method,
handle.get_stream(),
verbosity);
}
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/dbscan/dbscan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ void dbscanFitImpl(const raft::handle_t& handle,
Index_* core_sample_indices,
T* sample_weight,
size_t max_mbytes_per_batch,
EpsNnMethod eps_nn_method,
cudaStream_t stream,
int verbosity)
{
Expand Down Expand Up @@ -184,6 +185,7 @@ void dbscanFitImpl(const raft::handle_t& handle,
algo_ccl,
NULL,
batch_size,
eps_nn_method,
stream,
metric);

Expand All @@ -206,6 +208,7 @@ void dbscanFitImpl(const raft::handle_t& handle,
algo_ccl,
workspace.data(),
batch_size,
eps_nn_method,
stream,
metric);
}
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/dbscan/dbscan_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ cumlError_t cumlSpDbscanFit(cumlHandle_t handle,
core_sample_indices,
NULL,
max_bytes_per_batch,
ML::Dbscan::EpsNnMethod::BRUTE_FORCE,
verbosity);
}
// TODO: Implement this
Expand Down Expand Up @@ -91,6 +92,7 @@ cumlError_t cumlDpDbscanFit(cumlHandle_t handle,
core_sample_indices,
NULL,
max_bytes_per_batch,
ML::Dbscan::EpsNnMethod::BRUTE_FORCE,
verbosity);
}
// TODO: Implement this
Expand Down
Loading

0 comments on commit a6c0478

Please sign in to comment.