Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add sample_weight parameter to dbscan.fit #5574

Merged
merged 8 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cpp/bench/sg/dbscan.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -59,6 +59,7 @@ class Dbscan : public BlobsFixture<D, int> {
raft::distance::L2SqrtUnexpanded,
this->data.y.data(),
this->core_sample_indices,
nullptr,
dParams.max_bytes_per_batch);
state.SetItemsProcessed(this->params.nrows * this->params.ncols);
});
Expand Down
3 changes: 2 additions & 1 deletion cpp/examples/dbscan/dbscan_example.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -205,6 +205,7 @@ int main(int argc, char* argv[])
raft::distance::L2SqrtUnexpanded,
d_labels,
nullptr,
nullptr,
max_bytes_per_batch,
false);
CUDA_RT_CALL(cudaMemcpyAsync(
Expand Down
10 changes: 9 additions & 1 deletion cpp/include/cuml/cluster/dbscan.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
* Copyright (c) 2018-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -43,6 +43,10 @@ namespace Dbscan {
* indices of each core point. If the number of core points is less
* than n_rows, the right will be padded with -1. Setting this to
* NULL will prevent calculating the core sample indices
* @param[in] sample_weight (size n_rows) input array containing the
* weight of each sample. When provided, a weighted sum is used in
* place of the plain neighbor count when checking the min_pts
* criterion for core points. Passing NULL is equivalent to a weight
* of 1 for every sample
* @param[in] max_bytes_per_batch the maximum number of megabytes to be used for
* each batch of the pairwise distance calculation. This enables the
* trade off between memory usage and algorithm execution time.
Expand All @@ -60,6 +64,7 @@ void fit(const raft::handle_t& handle,
raft::distance::DistanceType metric,
int* labels,
int* core_sample_indices = nullptr,
float* sample_weight = nullptr,
size_t max_bytes_per_batch = 0,
int verbosity = CUML_LEVEL_INFO,
bool opg = false);
Expand All @@ -72,6 +77,7 @@ void fit(const raft::handle_t& handle,
raft::distance::DistanceType metric,
int* labels,
int* core_sample_indices = nullptr,
double* sample_weight = nullptr,
size_t max_bytes_per_batch = 0,
int verbosity = CUML_LEVEL_INFO,
bool opg = false);
Expand All @@ -85,6 +91,7 @@ void fit(const raft::handle_t& handle,
raft::distance::DistanceType metric,
int64_t* labels,
int64_t* core_sample_indices = nullptr,
float* sample_weight = nullptr,
size_t max_bytes_per_batch = 0,
int verbosity = CUML_LEVEL_INFO,
bool opg = false);
Expand All @@ -97,6 +104,7 @@ void fit(const raft::handle_t& handle,
raft::distance::DistanceType metric,
int64_t* labels,
int64_t* core_sample_indices = nullptr,
double* sample_weight = nullptr,
size_t max_bytes_per_batch = 0,
int verbosity = CUML_LEVEL_INFO,
bool opg = false);
Expand Down
10 changes: 5 additions & 5 deletions cpp/src/dbscan/corepoints/compute.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -28,16 +28,16 @@ namespace CorePoints {
/**
* Compute the core points from the vertex degrees and min_pts criterion
* @param[in] handle cuML handle
* @param[in] vd Vertex degrees
* @param[in] vd Vertex degrees (optionally weighted)
* @param[out] mask Boolean core point mask
* @param[in] min_pts Core point criterion
* @param[in] start_vertex_id First point of the batch
* @param[in] batch_size Batch size
* @param[in] stream CUDA stream
*/
template <typename Index_ = int>
template <typename Values_ = int, typename Index_ = int>
void compute(const raft::handle_t& handle,
const Index_* vd,
const Values_* vd,
bool* mask,
Index_ min_pts,
Index_ start_vertex_id,
Expand All @@ -47,7 +47,7 @@ void compute(const raft::handle_t& handle,
auto counting = thrust::make_counting_iterator<Index_>(0);
thrust::for_each(
handle.get_thrust_policy(), counting, counting + batch_size, [=] __device__(Index_ idx) {
mask[idx + start_vertex_id] = vd[idx] >= min_pts;
mask[idx + start_vertex_id] = (Index_)vd[idx] >= min_pts;
});
}

Expand Down
14 changes: 13 additions & 1 deletion cpp/src/dbscan/dbscan.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
* Copyright (c) 2018-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -31,6 +31,7 @@ void fit(const raft::handle_t& handle,
raft::distance::DistanceType metric,
int* labels,
int* core_sample_indices,
float* sample_weight,
size_t max_bytes_per_batch,
int verbosity,
bool opg)
Expand All @@ -45,6 +46,7 @@ void fit(const raft::handle_t& handle,
metric,
labels,
core_sample_indices,
sample_weight,
max_bytes_per_batch,
handle.get_stream(),
verbosity);
Expand All @@ -58,6 +60,7 @@ void fit(const raft::handle_t& handle,
metric,
labels,
core_sample_indices,
sample_weight,
max_bytes_per_batch,
handle.get_stream(),
verbosity);
Expand All @@ -72,6 +75,7 @@ void fit(const raft::handle_t& handle,
raft::distance::DistanceType metric,
int* labels,
int* core_sample_indices,
double* sample_weight,
size_t max_bytes_per_batch,
int verbosity,
bool opg)
Expand All @@ -86,6 +90,7 @@ void fit(const raft::handle_t& handle,
metric,
labels,
core_sample_indices,
sample_weight,
max_bytes_per_batch,
handle.get_stream(),
verbosity);
Expand All @@ -99,6 +104,7 @@ void fit(const raft::handle_t& handle,
metric,
labels,
core_sample_indices,
sample_weight,
max_bytes_per_batch,
handle.get_stream(),
verbosity);
Expand All @@ -113,6 +119,7 @@ void fit(const raft::handle_t& handle,
raft::distance::DistanceType metric,
int64_t* labels,
int64_t* core_sample_indices,
float* sample_weight,
size_t max_bytes_per_batch,
int verbosity,
bool opg)
Expand All @@ -127,6 +134,7 @@ void fit(const raft::handle_t& handle,
metric,
labels,
core_sample_indices,
sample_weight,
max_bytes_per_batch,
handle.get_stream(),
verbosity);
Expand All @@ -140,6 +148,7 @@ void fit(const raft::handle_t& handle,
metric,
labels,
core_sample_indices,
sample_weight,
max_bytes_per_batch,
handle.get_stream(),
verbosity);
Expand All @@ -154,6 +163,7 @@ void fit(const raft::handle_t& handle,
raft::distance::DistanceType metric,
int64_t* labels,
int64_t* core_sample_indices,
double* sample_weight,
size_t max_bytes_per_batch,
int verbosity,
bool opg)
Expand All @@ -168,6 +178,7 @@ void fit(const raft::handle_t& handle,
metric,
labels,
core_sample_indices,
sample_weight,
max_bytes_per_batch,
handle.get_stream(),
verbosity);
Expand All @@ -181,6 +192,7 @@ void fit(const raft::handle_t& handle,
metric,
labels,
core_sample_indices,
sample_weight,
max_bytes_per_batch,
handle.get_stream(),
verbosity);
Expand Down
5 changes: 4 additions & 1 deletion cpp/src/dbscan/dbscan.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
* Copyright (c) 2018-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -104,6 +104,7 @@ void dbscanFitImpl(const raft::handle_t& handle,
raft::distance::DistanceType metric,
Index_* labels,
Index_* core_sample_indices,
T* sample_weight,
size_t max_mbytes_per_batch,
cudaStream_t stream,
int verbosity)
Expand Down Expand Up @@ -177,6 +178,7 @@ void dbscanFitImpl(const raft::handle_t& handle,
min_pts,
labels,
core_sample_indices,
sample_weight,
algo_vd,
algo_adj,
algo_ccl,
Expand All @@ -198,6 +200,7 @@ void dbscanFitImpl(const raft::handle_t& handle,
min_pts,
labels,
core_sample_indices,
sample_weight,
algo_vd,
algo_adj,
algo_ccl,
Expand Down
4 changes: 3 additions & 1 deletion cpp/src/dbscan/dbscan_api.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
* Copyright (c) 2018-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -47,6 +47,7 @@ cumlError_t cumlSpDbscanFit(cumlHandle_t handle,
raft::distance::L2SqrtUnexpanded,
labels,
core_sample_indices,
NULL,
max_bytes_per_batch,
verbosity);
}
Expand Down Expand Up @@ -88,6 +89,7 @@ cumlError_t cumlDpDbscanFit(cumlHandle_t handle,
raft::distance::L2SqrtUnexpanded,
labels,
core_sample_indices,
NULL,
max_bytes_per_batch,
verbosity);
}
Expand Down
54 changes: 46 additions & 8 deletions cpp/src/dbscan/runner.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
* Copyright (c) 2018-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -110,6 +110,7 @@ std::size_t run(const raft::handle_t& handle,
Index_ min_pts,
Index_* labels,
Index_* core_indices,
const Type_f* sample_weight,
int algo_vd,
int algo_adj,
int algo_ccl,
Expand Down Expand Up @@ -146,6 +147,8 @@ std::size_t run(const raft::handle_t& handle,
std::size_t ex_scan_size = raft::alignTo<std::size_t>(sizeof(Index_) * batch_size, align);
std::size_t row_cnt_size = raft::alignTo<std::size_t>(sizeof(Index_) * batch_size, align);
std::size_t labels_size = raft::alignTo<std::size_t>(sizeof(Index_) * N, align);
std::size_t wght_sum_size =
sample_weight != nullptr ? raft::alignTo<std::size_t>(sizeof(Type_f) * batch_size, align) : 0;

Index_ MAX_LABEL = std::numeric_limits<Index_>::max();

Expand All @@ -157,8 +160,8 @@ std::size_t run(const raft::handle_t& handle,
(unsigned long)batch_size);

if (workspace == NULL) {
auto size =
adj_size + core_pts_size + m_size + vd_size + ex_scan_size + row_cnt_size + 2 * labels_size;
auto size = adj_size + core_pts_size + m_size + vd_size + ex_scan_size + row_cnt_size +
2 * labels_size + wght_sum_size;
return size;
}

Expand All @@ -183,6 +186,11 @@ std::size_t run(const raft::handle_t& handle,
temp += labels_size;
Index_* work_buffer = (Index_*)temp;
temp += labels_size;
Type_f* wght_sum = nullptr;
if (sample_weight != nullptr) {
wght_sum = (Type_f*)temp;
temp += wght_sum_size;
}

// Compute the mask
// 1. Compute the part owned by this worker (reversed order of batches to
Expand All @@ -196,13 +204,31 @@ std::size_t run(const raft::handle_t& handle,

CUML_LOG_DEBUG("--> Computing vertex degrees");
raft::common::nvtx::push_range("Trace::Dbscan::VertexDeg");
VertexDeg::run<Type_f, Index_>(
handle, adj, vd, x, eps, N, D, algo_vd, start_vertex_id, n_points, stream, metric);
VertexDeg::run<Type_f, Index_>(handle,
adj,
vd,
wght_sum,
x,
sample_weight,
eps,
N,
D,
algo_vd,
start_vertex_id,
n_points,
stream,
metric);
raft::common::nvtx::pop_range();

CUML_LOG_DEBUG("--> Computing core point mask");
raft::common::nvtx::push_range("Trace::Dbscan::CorePoints");
CorePoints::compute<Index_>(handle, vd, core_pts, min_pts, start_vertex_id, n_points, stream);
if (wght_sum != nullptr) {
CorePoints::compute<Type_f, Index_>(
handle, wght_sum, core_pts, min_pts, start_vertex_id, n_points, stream);
} else {
CorePoints::compute<Index_, Index_>(
handle, vd, core_pts, min_pts, start_vertex_id, n_points, stream);
}
raft::common::nvtx::pop_range();
}
// 2. Exchange with the other workers
Expand All @@ -224,8 +250,20 @@ std::size_t run(const raft::handle_t& handle,
if (i > 0) {
CUML_LOG_DEBUG("--> Computing vertex degrees");
raft::common::nvtx::push_range("Trace::Dbscan::VertexDeg");
VertexDeg::run<Type_f, Index_>(
handle, adj, vd, x, eps, N, D, algo_vd, start_vertex_id, n_points, stream, metric);
VertexDeg::run<Type_f, Index_>(handle,
adj,
vd,
nullptr,
x,
nullptr,
eps,
N,
D,
algo_vd,
start_vertex_id,
n_points,
stream,
metric);
raft::common::nvtx::pop_range();
}
raft::update_host(&curradjlen, vd + n_points, 1, stream);
Expand Down
Loading