Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEA] reduce memory pressure in membership vector computation #5268

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
e49d06a
membership_vector initial commit
tarang-jain Feb 18, 2023
436b180
Further updates to membership_vector
tarang-jain Feb 22, 2023
48030b8
Merge branch 'branch-23.04' into fea-membership-vector
tarang-jain Feb 22, 2023
7912dba
Initial testing membership_vector
tarang-jain Feb 23, 2023
4b41edb
Debug statements
tarang-jain Feb 23, 2023
fe0fd34
Merge branch 'fea-membership-vector' of https://github.com/tarang-jai…
tarang-jain Feb 23, 2023
9d5badc
debugging membership_vector
tarang-jain Feb 24, 2023
19f9dd8
membership_vector first working impl
tarang-jain Feb 28, 2023
a4b565c
GoogleTest intermediate commit
tarang-jain Feb 28, 2023
1f4bf78
GTest working
tarang-jain Feb 28, 2023
fdf100b
working tests and styling changes
tarang-jain Feb 28, 2023
e18096a
replace with raft mdspan primitives and add FastIntDiv
tarang-jain Mar 1, 2023
c2aa77e
Merge branch 'branch-23.04' into fea-membership-vector
tarang-jain Mar 1, 2023
182ba31
cpu support
tarang-jain Mar 1, 2023
366ef26
Fix failing pytest
tarang-jain Mar 7, 2023
b60d869
Merge branch 'branch-23.04' into fea-membership-vector
tarang-jain Mar 7, 2023
6bfaae2
modification after merge
tarang-jain Mar 7, 2023
c4e0bf1
Update softmax with raft::linalg reduction
tarang-jain Mar 8, 2023
fb634e4
Remove sync stream
tarang-jain Mar 9, 2023
a49ba87
memory study commit (to be reversed)
tarang-jain Mar 11, 2023
4ed9fd7
Merge branch 'branch-23.04' of github.com:rapidsai/cuml into fea-memb…
tarang-jain Mar 11, 2023
d1712c0
first commit (working)
tarang-jain Mar 13, 2023
f41416a
set batch_size as an arg
tarang-jain Mar 14, 2023
333077a
Merge branch 'branch-23.04' of github.com:rapidsai/cuml into fea-new-…
tarang-jain Mar 14, 2023
71217e2
working build, styling changes
tarang-jain Mar 14, 2023
bdaefa5
batch_size added to membership_vector
tarang-jain Mar 17, 2023
04e76cb
Merge branch 'branch-23.04' of github.com:rapidsai/cuml into fea-new-…
tarang-jain Mar 17, 2023
fa7b44e
Style fix
tarang-jain Mar 17, 2023
45f8ca4
Merge branch 'branch-23.04' of github.com:rapidsai/cuml into fea-memb…
tarang-jain Mar 17, 2023
367de04
Remove print debug statements
tarang-jain Mar 17, 2023
eeb52c2
Resolved failing pytest
tarang-jain Mar 20, 2023
612afb1
Merge branch 'branch-23.04' of github.com:rapidsai/cuml into fea-new-…
tarang-jain Mar 20, 2023
980b1f7
Merge branch 'branch-23.04' of github.com:rapidsai/cuml into fea-memb…
tarang-jain Mar 20, 2023
0bf779b
copyright changes
tarang-jain Mar 20, 2023
98aa237
Merge branch 'branch-23.04' into fea-membership-vector
tarang-jain Mar 20, 2023
3a38769
Merge branch 'branch-23.04' into fea-new-reduce-memory-pressure-apmv
tarang-jain Mar 20, 2023
d387026
Merge branch 'branch-23.04' into fea-membership-vector
tarang-jain Mar 27, 2023
ed40e22
Updates after PR reviews
tarang-jain Mar 28, 2023
387cde8
Merge branch 'fea-membership-vector' of https://github.com/tarang-jai…
tarang-jain Mar 28, 2023
092b3f8
Merge branch 'branch-23.04' of github.com:rapidsai/cuml into fea-memb…
tarang-jain Mar 28, 2023
ef85fd3
Update height_argmax
tarang-jain Mar 28, 2023
52eda5c
Intermediate merge commit
tarang-jain Mar 29, 2023
d8da560
Merge branch 'branch-23.04' of github.com:rapidsai/cuml into fea-new-…
tarang-jain Mar 29, 2023
7a95bfe
Update after merge membership_vector
tarang-jain Mar 29, 2023
dc92f90
Updates after PR Reviews
tarang-jain Mar 30, 2023
615ad10
Merge branch 'branch-23.04' of github.com:rapidsai/cuml into fea-new-…
tarang-jain Mar 30, 2023
7b89484
Merge branch 'branch-23.04' into fea-new-reduce-memory-pressure-apmv
tarang-jain Mar 30, 2023
7f7f0a4
Resolve merge conflicts
tarang-jain Mar 31, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions cpp/include/cuml/cluster/hdbscan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,8 @@ void compute_all_points_membership_vectors(
HDBSCAN::Common::PredictionData<int, float>& prediction_data,
const float* X,
raft::distance::DistanceType metric,
float* membership_vec);
float* membership_vec,
int batch_size);

void compute_membership_vector(const raft::handle_t& handle,
HDBSCAN::Common::CondensedHierarchy<int, float>& condensed_tree,
Expand All @@ -468,7 +469,8 @@ void compute_membership_vector(const raft::handle_t& handle,
size_t n_prediction_points,
int min_samples,
raft::distance::DistanceType metric,
float* membership_vec);
float* membership_vec,
int batch_size);

void out_of_sample_predict(const raft::handle_t& handle,
HDBSCAN::Common::CondensedHierarchy<int, float>& condensed_tree,
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/hdbscan/detail/predict.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down
82 changes: 48 additions & 34 deletions cpp/src/hdbscan/detail/soft_clustering.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -36,23 +36,15 @@
#include <raft/distance/distance_types.hpp>
#include <raft/label/classlabels.cuh>
#include <raft/linalg/matrix_vector_op.cuh>
#include <raft/linalg/norm.cuh>
#include <raft/matrix/argmax.cuh>
#include <raft/util/fast_int_div.cuh>

#include <algorithm>
#include <cmath>
#include <limits>

#include <thrust/copy.h>
#include <thrust/execution_policy.h>
#include <thrust/fill.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/reduce.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <thrust/transform.h>
#include <thrust/unique.h>

#include <rmm/device_uvector.hpp>
#include <rmm/exec_policy.hpp>
Expand All @@ -62,7 +54,7 @@ namespace HDBSCAN {
namespace detail {
namespace Predict {

// This function is used for both -- all_points_membership_vectors and membership_vector
// Computing distance based membership for points in the original clustering on which the clusterer was trained and new points outside of the training data.
template <typename value_idx, typename value_t>
void dist_membership_vector(const raft::handle_t& handle,
const value_t* X,
Expand All @@ -75,6 +67,7 @@ void dist_membership_vector(const raft::handle_t& handle,
value_idx* exemplar_label_offsets,
value_t* dist_membership_vec,
raft::distance::DistanceType metric,
int batch_size,
bool softmax = false)
{
auto stream = handle.get_stream();
Expand All @@ -86,34 +79,46 @@ void dist_membership_vector(const raft::handle_t& handle,
raft::matrix::copyRows<value_t, value_idx, size_t>(
X, n_exemplars, n, exemplars_dense.data(), exemplar_idx, n_exemplars, stream, true);

// compute the distances using raft API
rmm::device_uvector<value_t> dist(n_queries * n_exemplars, stream);
// compute the number of batches based on the batch size
value_idx n_batches;

switch (metric) {
case raft::distance::DistanceType::L2SqrtExpanded:
raft::distance::
distance<raft::distance::DistanceType::L2SqrtExpanded, value_t, value_t, value_t, int>(
handle, query, exemplars_dense.data(), dist.data(), n_queries, n_exemplars, n, true);
if (batch_size == 0) {
n_batches = 1;
batch_size = n_queries;
}
else {
n_batches = raft::ceildiv((int)n_queries, (int)batch_size);
}
for(value_idx bid = 0; bid < n_batches; bid++) {
value_idx batch_offset = bid * batch_size;
value_idx samples_per_batch = min(batch_size, (int)n_queries - batch_offset);
rmm::device_uvector<value_t> dist(samples_per_batch * n_exemplars, stream);

// compute the distances using raft API
switch (metric) {
case raft::distance::DistanceType::L2SqrtExpanded:
raft::distance::
distance<raft::distance::DistanceType::L2SqrtExpanded, value_t, value_t, value_t, int>(
handle, query + batch_offset * n, exemplars_dense.data(), dist.data(), samples_per_batch, n_exemplars, n, true);
break;
case raft::distance::DistanceType::L1:
raft::distance::distance<raft::distance::DistanceType::L1, value_t, value_t, value_t, int>(
handle, query, exemplars_dense.data(), dist.data(), n_queries, n_exemplars, n, true);
handle, query + batch_offset * n, exemplars_dense.data(), dist.data(), samples_per_batch, n_exemplars, n, true);
break;
case raft::distance::DistanceType::CosineExpanded:
raft::distance::
distance<raft::distance::DistanceType::CosineExpanded, value_t, value_t, value_t, int>(
handle, query, exemplars_dense.data(), dist.data(), n_queries, n_exemplars, n, true);
handle, query + batch_offset * n, exemplars_dense.data(), dist.data(), samples_per_batch, n_exemplars, n, true);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice! This is a good solution to provide some significant savings in memory.

break;
default: ASSERT(false, "Incorrect metric passed!");
}

// compute the minimum distances to exemplars of each cluster
value_idx n_elements = n_queries * n_selected_clusters;
value_idx n_elements = samples_per_batch * n_selected_clusters;
auto min_dist = raft::make_device_vector<value_t, value_idx>(handle, n_elements);

// thrust::fill(exec_policy, min_dist.data_handle(), min_dist.data_handle + n_elements, );

auto reduction_op = [dist = dist.data(),
batch_offset,
divisor = raft::util::FastIntDiv(n_selected_clusters),
n_selected_clusters,
n_exemplars,
Expand All @@ -138,8 +143,8 @@ void dist_membership_vector(const raft::handle_t& handle,
if (softmax) {
thrust::transform(exec_policy,
min_dist.data_handle(),
min_dist.data_handle() + n_queries * n_selected_clusters,
dist_membership_vec,
min_dist.data_handle() + samples_per_batch * n_selected_clusters,
dist_membership_vec + batch_offset * n_selected_clusters,
[=] __device__(value_t val) {
if (val != 0) { return value_t(exp(1.0 / val)); }
return std::numeric_limits<value_t>::max();
Expand All @@ -150,17 +155,17 @@ void dist_membership_vector(const raft::handle_t& handle,
else {
thrust::transform(exec_policy,
min_dist.data_handle(),
min_dist.data_handle() + n_queries * n_selected_clusters,
dist_membership_vec,
min_dist.data_handle() + samples_per_batch * n_selected_clusters,
dist_membership_vec + batch_offset * n_selected_clusters,
[=] __device__(value_t val) {
if (val > 0) { return value_t(1.0 / val); }
return std::numeric_limits<value_t>::max() / n_selected_clusters;
});
}

}
// Normalize the obtained result to sum to 1.0
Utils::normalize(dist_membership_vec, n_selected_clusters, n_queries, stream);
};
}

template <typename value_idx, typename value_t, int tpb = 256>
void all_points_outlier_membership_vector(
Expand Down Expand Up @@ -378,27 +383,31 @@ void prob_in_some_cluster(const raft::handle_t& handle,
* @param[in] X all points (size m * n)
* @param[in] metric distance metric
* @param[out] membership_vec output membership vectors (size m * n_selected_clusters)
* @param[in] batch_size batch size to be used while computing distance based memberships
*/
template <typename value_idx, typename value_t>
void all_points_membership_vectors(const raft::handle_t& handle,
Common::CondensedHierarchy<value_idx, value_t>& condensed_tree,
Common::PredictionData<value_idx, value_t>& prediction_data,
const value_t* X,
raft::distance::DistanceType metric,
value_t* membership_vec)
value_t* membership_vec,
value_idx batch_size)
{
auto stream = handle.get_stream();
auto exec_policy = handle.get_thrust_policy();

size_t m = prediction_data.n_rows;
size_t n = prediction_data.n_cols;
RAFT_EXPECTS(0 <= batch_size && batch_size <= m, "Invalid batch_size. batch_size should be >= 0 and <= the number of samples in the training data");

auto parents = condensed_tree.get_parents();
auto children = condensed_tree.get_children();
auto lambdas = condensed_tree.get_lambdas();
auto n_edges = condensed_tree.get_n_edges();
auto n_clusters = condensed_tree.get_n_clusters();
auto n_leaves = condensed_tree.get_n_leaves();

size_t m = prediction_data.n_rows;
size_t n = prediction_data.n_cols;
value_idx n_selected_clusters = prediction_data.get_n_selected_clusters();
value_t* deaths = prediction_data.get_deaths();
value_idx* selected_clusters = prediction_data.get_selected_clusters();
Expand All @@ -421,7 +430,8 @@ void all_points_membership_vectors(const raft::handle_t& handle,
prediction_data.get_exemplar_idx(),
prediction_data.get_exemplar_label_offsets(),
dist_membership_vec.data(),
metric);
metric,
batch_size);

rmm::device_uvector<value_t> merge_heights(m * n_selected_clusters, stream);

Expand Down Expand Up @@ -485,6 +495,7 @@ void all_points_membership_vectors(const raft::handle_t& handle,
* @param[in] metric distance metric
* @param[in] min_samples neighborhood size during training (includes self-loop)
* @param[out] membership_vec output membership vectors (size n_prediction_points * n_selected_clusters)
* @param[in] batch_size batch size to be used while computing distance based memberships
*/
template <typename value_idx, typename value_t, int tpb = 256>
void membership_vector(const raft::handle_t& handle,
Expand All @@ -495,10 +506,12 @@ void membership_vector(const raft::handle_t& handle,
size_t n_prediction_points,
raft::distance::DistanceType metric,
int min_samples,
value_t* membership_vec)
value_t* membership_vec,
value_idx batch_size)
{
RAFT_EXPECTS(metric == raft::distance::DistanceType::L2SqrtExpanded,
"Currently only L2 expanded distance is supported");
RAFT_EXPECTS(0 <= batch_size && batch_size <= n_prediction_points, "Invalid batch_size. batch_size should be >= 0 and <= the number of points to predict");

auto stream = handle.get_stream();
auto exec_policy = handle.get_thrust_policy();
Expand All @@ -525,7 +538,8 @@ void membership_vector(const raft::handle_t& handle,
prediction_data.get_exemplar_idx(),
prediction_data.get_exemplar_label_offsets(),
dist_membership_vec.data(),
raft::distance::DistanceType::L2SqrtExpanded);
raft::distance::DistanceType::L2SqrtExpanded,
batch_size);

auto prediction_lambdas = raft::make_device_vector<value_t, value_idx>(handle, n_prediction_points);
rmm::device_uvector<value_idx> min_mr_inds(n_prediction_points, stream);
Expand Down
3 changes: 1 addition & 2 deletions cpp/src/hdbscan/detail/utils.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -213,7 +213,6 @@ void normalize(value_t* data, value_idx n, size_t m, cudaStream_t stream)
false,
[] __device__(value_t mat_in, value_t vec_in) { return mat_in / vec_in; },
stream);

}

/**
Expand Down
13 changes: 8 additions & 5 deletions cpp/src/hdbscan/hdbscan.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -92,10 +92,11 @@ void compute_all_points_membership_vectors(
HDBSCAN::Common::PredictionData<int, float>& prediction_data,
const float* X,
raft::distance::DistanceType metric,
float* membership_vec)
float* membership_vec,
int batch_size)
{
HDBSCAN::detail::Predict::all_points_membership_vectors(
handle, condensed_tree, prediction_data, X, metric, membership_vec);
handle, condensed_tree, prediction_data, X, metric, membership_vec, batch_size);
}

void compute_membership_vector(const raft::handle_t& handle,
Expand All @@ -106,7 +107,8 @@ void compute_membership_vector(const raft::handle_t& handle,
size_t n_prediction_points,
int min_samples,
raft::distance::DistanceType metric,
float* membership_vec)
float* membership_vec,
int batch_size)
{
// Note that (min_samples+1) is parsed to the approximate_predict function. This was done for the
// core distance computation to consistent with Scikit learn Contrib.
Expand All @@ -118,7 +120,8 @@ void compute_membership_vector(const raft::handle_t& handle,
n_prediction_points,
metric,
min_samples + 1,
membership_vec);
membership_vec,
batch_size);
}

void out_of_sample_predict(const raft::handle_t& handle,
Expand Down
6 changes: 4 additions & 2 deletions cpp/test/sg/hdbscan_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,8 @@ class AllPointsMembershipVectorsTest
prediction_data_,
data.data(),
raft::distance::DistanceType::L2SqrtExpanded,
membership_vec.data());
membership_vec.data(),
0);

ASSERT_TRUE(MLCommon::devArrMatch(membership_vec.data(),
params.expected_probabilities.data(),
Expand Down Expand Up @@ -754,7 +755,8 @@ class MembershipVectorTest : public ::testing::TestWithParam<MembershipVectorInp
params.n_points_to_predict,
params.min_samples,
raft::distance::DistanceType::L2SqrtExpanded,
membership_vec.data());
membership_vec.data(),
0);

ASSERT_TRUE(MLCommon::devArrMatch(membership_vec.data(),
params.expected_probabilities.data(),
Expand Down
Loading