Skip to content

Commit

Permalink
Removing cub symbol from libraft-distance instantiation. (rapidsai#887)
Browse files Browse the repository at this point in the history
The recent template instantiations for `fusedL2NN` included `cub::KeyValuePair` in the symbols. Unfortunately, the cub symbols end up including all of the cuda architectures upon which it is built in the symbol name. For example, if libraft-distance is built for 5 architectures then someone building locally using that libraft-distance package for only their architecture will get an undefined symbol error. The best solution here is to not allow symbols from cub to leak through any publicly exposed APIs (even if those APIs are template instantiations of implementaiton details). 

I'm also adding a new `raft::KeyValuePair` object which has a conversion constructor so we can still easily make use of `cub::KeyValuePair` inside RAFT functions which use our new KVP object.

Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Seunghwa Kang (https://github.com/seunghwak)
  - Brad Rees (https://github.com/BradReesWork)

URL: rapidsai#887
  • Loading branch information
cjnolet authored Oct 5, 2022
1 parent 844a919 commit 55953f3
Show file tree
Hide file tree
Showing 16 changed files with 203 additions and 165 deletions.
34 changes: 17 additions & 17 deletions cpp/bench/spatial/fused_l2_nn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -53,28 +53,28 @@ struct fused_l2_nn : public fixture {
uniform(handle, r, y.data(), p.n * p.k, T(-1.0), T(1.0));
raft::linalg::rowNorm(xn.data(), x.data(), p.k, p.m, raft::linalg::L2Norm, true, stream);
raft::linalg::rowNorm(yn.data(), y.data(), p.k, p.n, raft::linalg::L2Norm, true, stream);
raft::distance::initialize<T, cub::KeyValuePair<int, T>, int>(
raft::distance::initialize<T, raft::KeyValuePair<int, T>, int>(
handle, out.data(), p.m, std::numeric_limits<T>::max(), op);
}

void run_benchmark(::benchmark::State& state) override
{
loop_on_state(state, [this]() {
// it is enough to only benchmark the L2-squared metric
raft::distance::fusedL2NN<T, cub::KeyValuePair<int, T>, int>(out.data(),
x.data(),
y.data(),
xn.data(),
yn.data(),
params.m,
params.n,
params.k,
(void*)workspace.data(),
op,
pairRedOp,
false,
false,
stream);
raft::distance::fusedL2NN<T, raft::KeyValuePair<int, T>, int>(out.data(),
x.data(),
y.data(),
xn.data(),
yn.data(),
params.m,
params.n,
params.k,
(void*)workspace.data(),
op,
pairRedOp,
false,
false,
stream);
});

// Num distance calculations
Expand All @@ -92,7 +92,7 @@ struct fused_l2_nn : public fixture {
state.counters["FLOP/s"] = benchmark::Counter(
num_flops, benchmark::Counter::kIsIterationInvariantRate, benchmark::Counter::OneK::kIs1000);

state.counters["BW Wr"] = benchmark::Counter(write_elts * sizeof(cub::KeyValuePair<int, T>),
state.counters["BW Wr"] = benchmark::Counter(write_elts * sizeof(raft::KeyValuePair<int, T>),
benchmark::Counter::kIsIterationInvariantRate,
benchmark::Counter::OneK::kIs1000);
state.counters["BW Rd"] = benchmark::Counter(read_elts * sizeof(float),
Expand All @@ -105,7 +105,7 @@ struct fused_l2_nn : public fixture {
private:
fused_l2_nn_inputs params;
rmm::device_uvector<T> x, y, xn, yn;
rmm::device_uvector<cub::KeyValuePair<int, T>> out;
rmm::device_uvector<raft::KeyValuePair<int, T>> out;
rmm::device_uvector<int> workspace;
raft::distance::KVPMinReduce<int, T> pairRedOp;
raft::distance::MinAndDistanceReduceOp<int, T> op;
Expand Down
45 changes: 23 additions & 22 deletions cpp/include/raft/cluster/detail/kmeans.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <raft/core/device_mdarray.hpp>
#include <raft/core/handle.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/kvp.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/mdarray.hpp>
#include <raft/distance/distance_types.hpp>
Expand Down Expand Up @@ -278,7 +279,7 @@ void kmeans_fit_main(const raft::handle_t& handle,
// - key is the index of nearest cluster
// - value is the distance to the nearest cluster
auto minClusterAndDistance =
raft::make_device_vector<cub::KeyValuePair<IndexT, DataT>, IndexT>(handle, n_samples);
raft::make_device_vector<raft::KeyValuePair<IndexT, DataT>, IndexT>(handle, n_samples);

// temporary buffer to store L2 norm of centroids or distance matrix,
// destructor releases the resource
Expand All @@ -292,7 +293,7 @@ void kmeans_fit_main(const raft::handle_t& handle,
// resource
auto wtInCluster = raft::make_device_vector<DataT, IndexT>(handle, n_clusters);

rmm::device_scalar<cub::KeyValuePair<IndexT, DataT>> clusterCostD(stream);
rmm::device_scalar<raft::KeyValuePair<IndexT, DataT>> clusterCostD(stream);

// L2 norm of X: ||x||^2
auto L2NormX = raft::make_device_vector<DataT, IndexT>(handle, n_samples);
Expand Down Expand Up @@ -337,12 +338,12 @@ void kmeans_fit_main(const raft::handle_t& handle,
workspace);

// Using TransformInputIteratorT to dereference an array of
// cub::KeyValuePair and converting them to just return the Key to be used
// raft::KeyValuePair and converting them to just return the Key to be used
// in reduce_rows_by_key prims
detail::KeyValueIndexOp<IndexT, DataT> conversion_op;
cub::TransformInputIterator<IndexT,
detail::KeyValueIndexOp<IndexT, DataT>,
cub::KeyValuePair<IndexT, DataT>*>
raft::KeyValuePair<IndexT, DataT>*>
itr(minClusterAndDistance.data_handle(), conversion_op);

workspace.resize(n_samples, stream);
Expand Down Expand Up @@ -400,14 +401,14 @@ void kmeans_fit_main(const raft::handle_t& handle,
itr_wt,
wtInCluster.size(),
newCentroids.data_handle(),
[=] __device__(cub::KeyValuePair<ptrdiff_t, DataT> map) { // predicate
[=] __device__(raft::KeyValuePair<ptrdiff_t, DataT> map) { // predicate
// copy when the # of samples in the cluster is 0
if (map.value == 0)
return true;
else
return false;
},
[=] __device__(cub::KeyValuePair<ptrdiff_t, DataT> map) { // map
[=] __device__(raft::KeyValuePair<ptrdiff_t, DataT> map) { // map
return map.key;
},
stream);
Expand Down Expand Up @@ -439,9 +440,9 @@ void kmeans_fit_main(const raft::handle_t& handle,
minClusterAndDistance.view(),
workspace,
raft::make_device_scalar_view(clusterCostD.data()),
[] __device__(const cub::KeyValuePair<IndexT, DataT>& a,
const cub::KeyValuePair<IndexT, DataT>& b) {
cub::KeyValuePair<IndexT, DataT> res;
[] __device__(const raft::KeyValuePair<IndexT, DataT>& a,
const raft::KeyValuePair<IndexT, DataT>& b) {
raft::KeyValuePair<IndexT, DataT> res;
res.key = 0;
res.value = a.value + b.value;
return res;
Expand Down Expand Up @@ -489,8 +490,8 @@ void kmeans_fit_main(const raft::handle_t& handle,
minClusterAndDistance.data_handle() + minClusterAndDistance.size(),
weight.data_handle(),
minClusterAndDistance.data_handle(),
[=] __device__(const cub::KeyValuePair<IndexT, DataT> kvp, DataT wt) {
cub::KeyValuePair<IndexT, DataT> res;
[=] __device__(const raft::KeyValuePair<IndexT, DataT> kvp, DataT wt) {
raft::KeyValuePair<IndexT, DataT> res;
res.value = kvp.value * wt;
res.key = kvp.key;
return res;
Expand All @@ -501,9 +502,9 @@ void kmeans_fit_main(const raft::handle_t& handle,
minClusterAndDistance.view(),
workspace,
raft::make_device_scalar_view(clusterCostD.data()),
[] __device__(const cub::KeyValuePair<IndexT, DataT>& a,
const cub::KeyValuePair<IndexT, DataT>& b) {
cub::KeyValuePair<IndexT, DataT> res;
[] __device__(const raft::KeyValuePair<IndexT, DataT>& a,
const raft::KeyValuePair<IndexT, DataT>& b) {
raft::KeyValuePair<IndexT, DataT> res;
res.key = 0;
res.value = a.value + b.value;
return res;
Expand Down Expand Up @@ -970,7 +971,7 @@ void kmeans_predict(handle_t const& handle,
if (normalize_weight) checkWeight(handle, weight.view(), workspace);

auto minClusterAndDistance =
raft::make_device_vector<cub::KeyValuePair<IndexT, DataT>, IndexT>(handle, n_samples);
raft::make_device_vector<raft::KeyValuePair<IndexT, DataT>, IndexT>(handle, n_samples);
rmm::device_uvector<DataT> L2NormBuf_OR_DistBuf(0, stream);

// L2 norm of X: ||x||^2
Expand Down Expand Up @@ -1001,15 +1002,15 @@ void kmeans_predict(handle_t const& handle,
workspace);

// calculate cluster cost phi_x(C)
rmm::device_scalar<cub::KeyValuePair<IndexT, DataT>> clusterCostD(stream);
rmm::device_scalar<raft::KeyValuePair<IndexT, DataT>> clusterCostD(stream);
// TODO: add different templates for InType of binaryOp to avoid thrust transform
thrust::transform(handle.get_thrust_policy(),
minClusterAndDistance.data_handle(),
minClusterAndDistance.data_handle() + minClusterAndDistance.size(),
weight.data_handle(),
minClusterAndDistance.data_handle(),
[=] __device__(const cub::KeyValuePair<IndexT, DataT> kvp, DataT wt) {
cub::KeyValuePair<IndexT, DataT> res;
[=] __device__(const raft::KeyValuePair<IndexT, DataT> kvp, DataT wt) {
raft::KeyValuePair<IndexT, DataT> res;
res.value = kvp.value * wt;
res.key = kvp.key;
return res;
Expand All @@ -1019,9 +1020,9 @@ void kmeans_predict(handle_t const& handle,
minClusterAndDistance.view(),
workspace,
raft::make_device_scalar_view(clusterCostD.data()),
[] __device__(const cub::KeyValuePair<IndexT, DataT>& a,
const cub::KeyValuePair<IndexT, DataT>& b) {
cub::KeyValuePair<IndexT, DataT> res;
[] __device__(const raft::KeyValuePair<IndexT, DataT>& a,
const raft::KeyValuePair<IndexT, DataT>& b) {
raft::KeyValuePair<IndexT, DataT> res;
res.key = 0;
res.value = a.value + b.value;
return res;
Expand All @@ -1033,7 +1034,7 @@ void kmeans_predict(handle_t const& handle,
minClusterAndDistance.data_handle(),
minClusterAndDistance.data_handle() + minClusterAndDistance.size(),
labels.data_handle(),
[=] __device__(cub::KeyValuePair<IndexT, DataT> pair) { return pair.key; });
[=] __device__(raft::KeyValuePair<IndexT, DataT> pair) { return pair.key; });
}

template <typename DataT, typename IndexT = int>
Expand Down
31 changes: 16 additions & 15 deletions cpp/include/raft/cluster/detail/kmeans_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <raft/core/cudart_utils.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/handle.hpp>
#include <raft/core/kvp.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/mdarray.hpp>
#include <raft/distance/distance.cuh>
Expand Down Expand Up @@ -66,7 +67,7 @@ struct SamplingOp {
}

__host__ __device__ __forceinline__ bool operator()(
const cub::KeyValuePair<ptrdiff_t, DataT>& a) const
const raft::KeyValuePair<ptrdiff_t, DataT>& a) const
{
DataT prob_threshold = (DataT)rnd[a.key];

Expand All @@ -79,7 +80,7 @@ struct SamplingOp {
template <typename IndexT, typename DataT>
struct KeyValueIndexOp {
__host__ __device__ __forceinline__ IndexT
operator()(const cub::KeyValuePair<IndexT, DataT>& a) const
operator()(const raft::KeyValuePair<IndexT, DataT>& a) const
{
return a.key;
}
Expand Down Expand Up @@ -224,7 +225,7 @@ void sampleCentroids(const raft::handle_t& handle,
auto nSelected = raft::make_device_scalar<IndexT>(handle, 0);
cub::ArgIndexInputIterator<DataT*> ip_itr(minClusterDistance.data_handle());
auto sampledMinClusterDistance =
raft::make_device_vector<cub::KeyValuePair<ptrdiff_t, DataT>, IndexT>(handle, n_local_samples);
raft::make_device_vector<raft::KeyValuePair<ptrdiff_t, DataT>, IndexT>(handle, n_local_samples);
size_t temp_storage_bytes = 0;
RAFT_CUDA_TRY(cub::DeviceSelect::If(nullptr,
temp_storage_bytes,
Expand Down Expand Up @@ -254,7 +255,7 @@ void sampleCentroids(const raft::handle_t& handle,
thrust::for_each_n(handle.get_thrust_policy(),
sampledMinClusterDistance.data_handle(),
nPtsSampledInRank,
[=] __device__(cub::KeyValuePair<ptrdiff_t, DataT> val) {
[=] __device__(raft::KeyValuePair<ptrdiff_t, DataT> val) {
rawPtr_isSampleCentroid[val.key] = 1;
});

Expand All @@ -266,7 +267,7 @@ void sampleCentroids(const raft::handle_t& handle,
sampledMinClusterDistance.data_handle(),
nPtsSampledInRank,
inRankCp.data(),
[=] __device__(cub::KeyValuePair<ptrdiff_t, DataT> val) { // MapTransformOp
[=] __device__(raft::KeyValuePair<ptrdiff_t, DataT> val) { // MapTransformOp
return val.key;
},
stream);
Expand Down Expand Up @@ -355,7 +356,7 @@ void minClusterAndDistanceCompute(
const KMeansParams& params,
const raft::device_matrix_view<const DataT, IndexT> X,
const raft::device_matrix_view<const DataT, IndexT> centroids,
const raft::device_vector_view<cub::KeyValuePair<IndexT, DataT>, IndexT> minClusterAndDistance,
const raft::device_vector_view<raft::KeyValuePair<IndexT, DataT>, IndexT> minClusterAndDistance,
const raft::device_vector_view<DataT, IndexT> L2NormX,
rmm::device_uvector<DataT>& L2NormBuf_OR_DistBuf,
rmm::device_uvector<char>& workspace)
Expand Down Expand Up @@ -390,7 +391,7 @@ void minClusterAndDistanceCompute(
auto pairwiseDistance = raft::make_device_matrix_view<DataT, IndexT>(
L2NormBuf_OR_DistBuf.data(), dataBatchSize, centroidsBatchSize);

cub::KeyValuePair<IndexT, DataT> initial_value(0, std::numeric_limits<DataT>::max());
raft::KeyValuePair<IndexT, DataT> initial_value(0, std::numeric_limits<DataT>::max());

thrust::fill(handle.get_thrust_policy(),
minClusterAndDistance.data_handle(),
Expand All @@ -409,7 +410,7 @@ void minClusterAndDistanceCompute(

// minClusterAndDistanceView [ns x n_clusters]
auto minClusterAndDistanceView =
raft::make_device_vector_view<cub::KeyValuePair<IndexT, DataT>, IndexT>(
raft::make_device_vector_view<raft::KeyValuePair<IndexT, DataT>, IndexT>(
minClusterAndDistance.data_handle() + dIdx, ns);

auto L2NormXView =
Expand All @@ -420,7 +421,7 @@ void minClusterAndDistanceCompute(
workspace.resize((sizeof(int)) * ns, stream);

// todo(lsugy): remove cIdx
raft::distance::fusedL2NNMinReduce<DataT, cub::KeyValuePair<IndexT, DataT>, IndexT>(
raft::distance::fusedL2NNMinReduce<DataT, raft::KeyValuePair<IndexT, DataT>, IndexT>(
minClusterAndDistanceView.data_handle(),
datasetView.data_handle(),
centroids.data_handle(),
Expand Down Expand Up @@ -466,15 +467,15 @@ void minClusterAndDistanceCompute(
stream,
true,
[=] __device__(const DataT val, const IndexT i) {
cub::KeyValuePair<IndexT, DataT> pair;
raft::KeyValuePair<IndexT, DataT> pair;
pair.key = cIdx + i;
pair.value = val;
return pair;
},
[=] __device__(cub::KeyValuePair<IndexT, DataT> a, cub::KeyValuePair<IndexT, DataT> b) {
[=] __device__(raft::KeyValuePair<IndexT, DataT> a, raft::KeyValuePair<IndexT, DataT> b) {
return (b.value < a.value) ? b : a;
},
[=] __device__(cub::KeyValuePair<IndexT, DataT> pair) { return pair; });
[=] __device__(raft::KeyValuePair<IndexT, DataT> pair) { return pair; });
}
}
}
Expand Down Expand Up @@ -623,7 +624,7 @@ void countSamplesInCluster(const raft::handle_t& handle,
// - key is the index of nearest cluster
// - value is the distance to the nearest cluster
auto minClusterAndDistance =
raft::make_device_vector<cub::KeyValuePair<IndexT, DataT>, IndexT>(handle, n_samples);
raft::make_device_vector<raft::KeyValuePair<IndexT, DataT>, IndexT>(handle, n_samples);

// temporary buffer to store distance matrix, destructor releases the resource
rmm::device_uvector<DataT> L2NormBuf_OR_DistBuf(0, stream);
Expand All @@ -642,13 +643,13 @@ void countSamplesInCluster(const raft::handle_t& handle,
L2NormBuf_OR_DistBuf,
workspace);

// Using TransformInputIteratorT to dereference an array of cub::KeyValuePair
// Using TransformInputIteratorT to dereference an array of raft::KeyValuePair
// and converting them to just return the Key to be used in reduce_rows_by_key
// prims
detail::KeyValueIndexOp<IndexT, DataT> conversion_op;
cub::TransformInputIterator<IndexT,
detail::KeyValueIndexOp<IndexT, DataT>,
cub::KeyValuePair<IndexT, DataT>*>
raft::KeyValuePair<IndexT, DataT>*>
itr(minClusterAndDistance.data_handle(), conversion_op);

// count # of samples in each cluster
Expand Down
3 changes: 2 additions & 1 deletion cpp/include/raft/cluster/kmeans.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <optional>
#include <raft/cluster/detail/kmeans.cuh>
#include <raft/cluster/kmeans_types.hpp>
#include <raft/core/kvp.hpp>
#include <raft/core/mdarray.hpp>

namespace raft::cluster {
Expand Down Expand Up @@ -353,7 +354,7 @@ void minClusterAndDistanceCompute(
const KMeansParams& params,
const raft::device_matrix_view<const DataT, IndexT> X,
const raft::device_matrix_view<const DataT, IndexT> centroids,
const raft::device_vector_view<cub::KeyValuePair<IndexT, DataT>, IndexT>& minClusterAndDistance,
const raft::device_vector_view<raft::KeyValuePair<IndexT, DataT>, IndexT>& minClusterAndDistance,
const raft::device_vector_view<DataT, IndexT>& L2NormX,
rmm::device_uvector<DataT>& L2NormBuf_OR_DistBuf,
rmm::device_uvector<char>& workspace)
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/core/detail/macros.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
#define _RAFT_HOST_DEVICE _RAFT_HOST _RAFT_DEVICE

#ifndef RAFT_INLINE_FUNCTION
#define RAFT_INLINE_FUNCTION _RAFT_FORCEINLINE _RAFT_HOST_DEVICE
#define RAFT_INLINE_FUNCTION _RAFT_HOST_DEVICE _RAFT_FORCEINLINE
#endif

/**
Expand Down
Loading

0 comments on commit 55953f3

Please sign in to comment.