Skip to content

Commit

Permalink
Revert "Revert recent fused l2 nn instantiations (rapidsai#899)"
Browse files Browse the repository at this point in the history
This reverts commit 11e00f7.
  • Loading branch information
cjnolet committed Oct 6, 2022
1 parent 1317774 commit 43a7bc8
Show file tree
Hide file tree
Showing 17 changed files with 204 additions and 166 deletions.
34 changes: 17 additions & 17 deletions cpp/bench/spatial/fused_l2_nn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -53,28 +53,28 @@ struct fused_l2_nn : public fixture {
uniform(handle, r, y.data(), p.n * p.k, T(-1.0), T(1.0));
raft::linalg::rowNorm(xn.data(), x.data(), p.k, p.m, raft::linalg::L2Norm, true, stream);
raft::linalg::rowNorm(yn.data(), y.data(), p.k, p.n, raft::linalg::L2Norm, true, stream);
raft::distance::initialize<T, cub::KeyValuePair<int, T>, int>(
raft::distance::initialize<T, raft::KeyValuePair<int, T>, int>(
handle, out.data(), p.m, std::numeric_limits<T>::max(), op);
}

void run_benchmark(::benchmark::State& state) override
{
loop_on_state(state, [this]() {
// it is enough to only benchmark the L2-squared metric
raft::distance::fusedL2NN<T, cub::KeyValuePair<int, T>, int>(out.data(),
x.data(),
y.data(),
xn.data(),
yn.data(),
params.m,
params.n,
params.k,
(void*)workspace.data(),
op,
pairRedOp,
false,
false,
stream);
raft::distance::fusedL2NN<T, raft::KeyValuePair<int, T>, int>(out.data(),
x.data(),
y.data(),
xn.data(),
yn.data(),
params.m,
params.n,
params.k,
(void*)workspace.data(),
op,
pairRedOp,
false,
false,
stream);
});

// Num distance calculations
Expand All @@ -92,7 +92,7 @@ struct fused_l2_nn : public fixture {
state.counters["FLOP/s"] = benchmark::Counter(
num_flops, benchmark::Counter::kIsIterationInvariantRate, benchmark::Counter::OneK::kIs1000);

state.counters["BW Wr"] = benchmark::Counter(write_elts * sizeof(cub::KeyValuePair<int, T>),
state.counters["BW Wr"] = benchmark::Counter(write_elts * sizeof(raft::KeyValuePair<int, T>),
benchmark::Counter::kIsIterationInvariantRate,
benchmark::Counter::OneK::kIs1000);
state.counters["BW Rd"] = benchmark::Counter(read_elts * sizeof(float),
Expand All @@ -105,7 +105,7 @@ struct fused_l2_nn : public fixture {
private:
fused_l2_nn_inputs params;
rmm::device_uvector<T> x, y, xn, yn;
rmm::device_uvector<cub::KeyValuePair<int, T>> out;
rmm::device_uvector<raft::KeyValuePair<int, T>> out;
rmm::device_uvector<int> workspace;
raft::distance::KVPMinReduce<int, T> pairRedOp;
raft::distance::MinAndDistanceReduceOp<int, T> op;
Expand Down
45 changes: 23 additions & 22 deletions cpp/include/raft/cluster/detail/kmeans.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <raft/core/device_mdarray.hpp>
#include <raft/core/handle.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/kvp.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/mdarray.hpp>
#include <raft/distance/distance_types.hpp>
Expand Down Expand Up @@ -278,7 +279,7 @@ void kmeans_fit_main(const raft::handle_t& handle,
// - key is the index of nearest cluster
// - value is the distance to the nearest cluster
auto minClusterAndDistance =
raft::make_device_vector<cub::KeyValuePair<IndexT, DataT>, IndexT>(handle, n_samples);
raft::make_device_vector<raft::KeyValuePair<IndexT, DataT>, IndexT>(handle, n_samples);

// temporary buffer to store L2 norm of centroids or distance matrix,
// destructor releases the resource
Expand All @@ -292,7 +293,7 @@ void kmeans_fit_main(const raft::handle_t& handle,
// resource
auto wtInCluster = raft::make_device_vector<DataT, IndexT>(handle, n_clusters);

rmm::device_scalar<cub::KeyValuePair<IndexT, DataT>> clusterCostD(stream);
rmm::device_scalar<raft::KeyValuePair<IndexT, DataT>> clusterCostD(stream);

// L2 norm of X: ||x||^2
auto L2NormX = raft::make_device_vector<DataT, IndexT>(handle, n_samples);
Expand Down Expand Up @@ -337,12 +338,12 @@ void kmeans_fit_main(const raft::handle_t& handle,
workspace);

// Using TransformInputIteratorT to dereference an array of
// cub::KeyValuePair and converting them to just return the Key to be used
// raft::KeyValuePair and converting them to just return the Key to be used
// in reduce_rows_by_key prims
detail::KeyValueIndexOp<IndexT, DataT> conversion_op;
cub::TransformInputIterator<IndexT,
detail::KeyValueIndexOp<IndexT, DataT>,
cub::KeyValuePair<IndexT, DataT>*>
raft::KeyValuePair<IndexT, DataT>*>
itr(minClusterAndDistance.data_handle(), conversion_op);

workspace.resize(n_samples, stream);
Expand Down Expand Up @@ -400,14 +401,14 @@ void kmeans_fit_main(const raft::handle_t& handle,
itr_wt,
wtInCluster.size(),
newCentroids.data_handle(),
[=] __device__(cub::KeyValuePair<ptrdiff_t, DataT> map) { // predicate
[=] __device__(raft::KeyValuePair<ptrdiff_t, DataT> map) { // predicate
// copy when the # of samples in the cluster is 0
if (map.value == 0)
return true;
else
return false;
},
[=] __device__(cub::KeyValuePair<ptrdiff_t, DataT> map) { // map
[=] __device__(raft::KeyValuePair<ptrdiff_t, DataT> map) { // map
return map.key;
},
stream);
Expand Down Expand Up @@ -439,9 +440,9 @@ void kmeans_fit_main(const raft::handle_t& handle,
minClusterAndDistance.view(),
workspace,
raft::make_device_scalar_view(clusterCostD.data()),
[] __device__(const cub::KeyValuePair<IndexT, DataT>& a,
const cub::KeyValuePair<IndexT, DataT>& b) {
cub::KeyValuePair<IndexT, DataT> res;
[] __device__(const raft::KeyValuePair<IndexT, DataT>& a,
const raft::KeyValuePair<IndexT, DataT>& b) {
raft::KeyValuePair<IndexT, DataT> res;
res.key = 0;
res.value = a.value + b.value;
return res;
Expand Down Expand Up @@ -489,8 +490,8 @@ void kmeans_fit_main(const raft::handle_t& handle,
minClusterAndDistance.data_handle() + minClusterAndDistance.size(),
weight.data_handle(),
minClusterAndDistance.data_handle(),
[=] __device__(const cub::KeyValuePair<IndexT, DataT> kvp, DataT wt) {
cub::KeyValuePair<IndexT, DataT> res;
[=] __device__(const raft::KeyValuePair<IndexT, DataT> kvp, DataT wt) {
raft::KeyValuePair<IndexT, DataT> res;
res.value = kvp.value * wt;
res.key = kvp.key;
return res;
Expand All @@ -501,9 +502,9 @@ void kmeans_fit_main(const raft::handle_t& handle,
minClusterAndDistance.view(),
workspace,
raft::make_device_scalar_view(clusterCostD.data()),
[] __device__(const cub::KeyValuePair<IndexT, DataT>& a,
const cub::KeyValuePair<IndexT, DataT>& b) {
cub::KeyValuePair<IndexT, DataT> res;
[] __device__(const raft::KeyValuePair<IndexT, DataT>& a,
const raft::KeyValuePair<IndexT, DataT>& b) {
raft::KeyValuePair<IndexT, DataT> res;
res.key = 0;
res.value = a.value + b.value;
return res;
Expand Down Expand Up @@ -970,7 +971,7 @@ void kmeans_predict(handle_t const& handle,
if (normalize_weight) checkWeight(handle, weight.view(), workspace);

auto minClusterAndDistance =
raft::make_device_vector<cub::KeyValuePair<IndexT, DataT>, IndexT>(handle, n_samples);
raft::make_device_vector<raft::KeyValuePair<IndexT, DataT>, IndexT>(handle, n_samples);
rmm::device_uvector<DataT> L2NormBuf_OR_DistBuf(0, stream);

// L2 norm of X: ||x||^2
Expand Down Expand Up @@ -1001,15 +1002,15 @@ void kmeans_predict(handle_t const& handle,
workspace);

// calculate cluster cost phi_x(C)
rmm::device_scalar<cub::KeyValuePair<IndexT, DataT>> clusterCostD(stream);
rmm::device_scalar<raft::KeyValuePair<IndexT, DataT>> clusterCostD(stream);
// TODO: add different templates for InType of binaryOp to avoid thrust transform
thrust::transform(handle.get_thrust_policy(),
minClusterAndDistance.data_handle(),
minClusterAndDistance.data_handle() + minClusterAndDistance.size(),
weight.data_handle(),
minClusterAndDistance.data_handle(),
[=] __device__(const cub::KeyValuePair<IndexT, DataT> kvp, DataT wt) {
cub::KeyValuePair<IndexT, DataT> res;
[=] __device__(const raft::KeyValuePair<IndexT, DataT> kvp, DataT wt) {
raft::KeyValuePair<IndexT, DataT> res;
res.value = kvp.value * wt;
res.key = kvp.key;
return res;
Expand All @@ -1019,9 +1020,9 @@ void kmeans_predict(handle_t const& handle,
minClusterAndDistance.view(),
workspace,
raft::make_device_scalar_view(clusterCostD.data()),
[] __device__(const cub::KeyValuePair<IndexT, DataT>& a,
const cub::KeyValuePair<IndexT, DataT>& b) {
cub::KeyValuePair<IndexT, DataT> res;
[] __device__(const raft::KeyValuePair<IndexT, DataT>& a,
const raft::KeyValuePair<IndexT, DataT>& b) {
raft::KeyValuePair<IndexT, DataT> res;
res.key = 0;
res.value = a.value + b.value;
return res;
Expand All @@ -1033,7 +1034,7 @@ void kmeans_predict(handle_t const& handle,
minClusterAndDistance.data_handle(),
minClusterAndDistance.data_handle() + minClusterAndDistance.size(),
labels.data_handle(),
[=] __device__(cub::KeyValuePair<IndexT, DataT> pair) { return pair.key; });
[=] __device__(raft::KeyValuePair<IndexT, DataT> pair) { return pair.key; });
}

template <typename DataT, typename IndexT = int>
Expand Down
31 changes: 16 additions & 15 deletions cpp/include/raft/cluster/detail/kmeans_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <raft/core/cudart_utils.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/handle.hpp>
#include <raft/core/kvp.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/mdarray.hpp>
#include <raft/distance/distance.cuh>
Expand Down Expand Up @@ -66,7 +67,7 @@ struct SamplingOp {
}

__host__ __device__ __forceinline__ bool operator()(
const cub::KeyValuePair<ptrdiff_t, DataT>& a) const
const raft::KeyValuePair<ptrdiff_t, DataT>& a) const
{
DataT prob_threshold = (DataT)rnd[a.key];

Expand All @@ -79,7 +80,7 @@ struct SamplingOp {
template <typename IndexT, typename DataT>
struct KeyValueIndexOp {
__host__ __device__ __forceinline__ IndexT
operator()(const cub::KeyValuePair<IndexT, DataT>& a) const
operator()(const raft::KeyValuePair<IndexT, DataT>& a) const
{
return a.key;
}
Expand Down Expand Up @@ -224,7 +225,7 @@ void sampleCentroids(const raft::handle_t& handle,
auto nSelected = raft::make_device_scalar<IndexT>(handle, 0);
cub::ArgIndexInputIterator<DataT*> ip_itr(minClusterDistance.data_handle());
auto sampledMinClusterDistance =
raft::make_device_vector<cub::KeyValuePair<ptrdiff_t, DataT>, IndexT>(handle, n_local_samples);
raft::make_device_vector<raft::KeyValuePair<ptrdiff_t, DataT>, IndexT>(handle, n_local_samples);
size_t temp_storage_bytes = 0;
RAFT_CUDA_TRY(cub::DeviceSelect::If(nullptr,
temp_storage_bytes,
Expand Down Expand Up @@ -254,7 +255,7 @@ void sampleCentroids(const raft::handle_t& handle,
thrust::for_each_n(handle.get_thrust_policy(),
sampledMinClusterDistance.data_handle(),
nPtsSampledInRank,
[=] __device__(cub::KeyValuePair<ptrdiff_t, DataT> val) {
[=] __device__(raft::KeyValuePair<ptrdiff_t, DataT> val) {
rawPtr_isSampleCentroid[val.key] = 1;
});

Expand All @@ -266,7 +267,7 @@ void sampleCentroids(const raft::handle_t& handle,
sampledMinClusterDistance.data_handle(),
nPtsSampledInRank,
inRankCp.data(),
[=] __device__(cub::KeyValuePair<ptrdiff_t, DataT> val) { // MapTransformOp
[=] __device__(raft::KeyValuePair<ptrdiff_t, DataT> val) { // MapTransformOp
return val.key;
},
stream);
Expand Down Expand Up @@ -355,7 +356,7 @@ void minClusterAndDistanceCompute(
const KMeansParams& params,
const raft::device_matrix_view<const DataT, IndexT> X,
const raft::device_matrix_view<const DataT, IndexT> centroids,
const raft::device_vector_view<cub::KeyValuePair<IndexT, DataT>, IndexT> minClusterAndDistance,
const raft::device_vector_view<raft::KeyValuePair<IndexT, DataT>, IndexT> minClusterAndDistance,
const raft::device_vector_view<DataT, IndexT> L2NormX,
rmm::device_uvector<DataT>& L2NormBuf_OR_DistBuf,
rmm::device_uvector<char>& workspace)
Expand Down Expand Up @@ -390,7 +391,7 @@ void minClusterAndDistanceCompute(
auto pairwiseDistance = raft::make_device_matrix_view<DataT, IndexT>(
L2NormBuf_OR_DistBuf.data(), dataBatchSize, centroidsBatchSize);

cub::KeyValuePair<IndexT, DataT> initial_value(0, std::numeric_limits<DataT>::max());
raft::KeyValuePair<IndexT, DataT> initial_value(0, std::numeric_limits<DataT>::max());

thrust::fill(handle.get_thrust_policy(),
minClusterAndDistance.data_handle(),
Expand All @@ -409,7 +410,7 @@ void minClusterAndDistanceCompute(

// minClusterAndDistanceView [ns x n_clusters]
auto minClusterAndDistanceView =
raft::make_device_vector_view<cub::KeyValuePair<IndexT, DataT>, IndexT>(
raft::make_device_vector_view<raft::KeyValuePair<IndexT, DataT>, IndexT>(
minClusterAndDistance.data_handle() + dIdx, ns);

auto L2NormXView =
Expand All @@ -420,7 +421,7 @@ void minClusterAndDistanceCompute(
workspace.resize((sizeof(int)) * ns, stream);

// todo(lsugy): remove cIdx
raft::distance::fusedL2NNMinReduce<DataT, cub::KeyValuePair<IndexT, DataT>, IndexT>(
raft::distance::fusedL2NNMinReduce<DataT, raft::KeyValuePair<IndexT, DataT>, IndexT>(
minClusterAndDistanceView.data_handle(),
datasetView.data_handle(),
centroids.data_handle(),
Expand Down Expand Up @@ -466,15 +467,15 @@ void minClusterAndDistanceCompute(
stream,
true,
[=] __device__(const DataT val, const IndexT i) {
cub::KeyValuePair<IndexT, DataT> pair;
raft::KeyValuePair<IndexT, DataT> pair;
pair.key = cIdx + i;
pair.value = val;
return pair;
},
[=] __device__(cub::KeyValuePair<IndexT, DataT> a, cub::KeyValuePair<IndexT, DataT> b) {
[=] __device__(raft::KeyValuePair<IndexT, DataT> a, raft::KeyValuePair<IndexT, DataT> b) {
return (b.value < a.value) ? b : a;
},
[=] __device__(cub::KeyValuePair<IndexT, DataT> pair) { return pair; });
[=] __device__(raft::KeyValuePair<IndexT, DataT> pair) { return pair; });
}
}
}
Expand Down Expand Up @@ -623,7 +624,7 @@ void countSamplesInCluster(const raft::handle_t& handle,
// - key is the index of nearest cluster
// - value is the distance to the nearest cluster
auto minClusterAndDistance =
raft::make_device_vector<cub::KeyValuePair<IndexT, DataT>, IndexT>(handle, n_samples);
raft::make_device_vector<raft::KeyValuePair<IndexT, DataT>, IndexT>(handle, n_samples);

// temporary buffer to store distance matrix, destructor releases the resource
rmm::device_uvector<DataT> L2NormBuf_OR_DistBuf(0, stream);
Expand All @@ -642,13 +643,13 @@ void countSamplesInCluster(const raft::handle_t& handle,
L2NormBuf_OR_DistBuf,
workspace);

// Using TransformInputIteratorT to dereference an array of cub::KeyValuePair
// Using TransformInputIteratorT to dereference an array of raft::KeyValuePair
// and converting them to just return the Key to be used in reduce_rows_by_key
// prims
detail::KeyValueIndexOp<IndexT, DataT> conversion_op;
cub::TransformInputIterator<IndexT,
detail::KeyValueIndexOp<IndexT, DataT>,
cub::KeyValuePair<IndexT, DataT>*>
raft::KeyValuePair<IndexT, DataT>*>
itr(minClusterAndDistance.data_handle(), conversion_op);

// count # of samples in each cluster
Expand Down
3 changes: 2 additions & 1 deletion cpp/include/raft/cluster/kmeans.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <optional>
#include <raft/cluster/detail/kmeans.cuh>
#include <raft/cluster/kmeans_types.hpp>
#include <raft/core/kvp.hpp>
#include <raft/core/mdarray.hpp>

namespace raft::cluster {
Expand Down Expand Up @@ -353,7 +354,7 @@ void minClusterAndDistanceCompute(
const KMeansParams& params,
const raft::device_matrix_view<const DataT, IndexT> X,
const raft::device_matrix_view<const DataT, IndexT> centroids,
const raft::device_vector_view<cub::KeyValuePair<IndexT, DataT>, IndexT>& minClusterAndDistance,
const raft::device_vector_view<raft::KeyValuePair<IndexT, DataT>, IndexT>& minClusterAndDistance,
const raft::device_vector_view<DataT, IndexT>& L2NormX,
rmm::device_uvector<DataT>& L2NormBuf_OR_DistBuf,
rmm::device_uvector<char>& workspace)
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/core/detail/macros.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
#define _RAFT_HOST_DEVICE _RAFT_HOST _RAFT_DEVICE

#ifndef RAFT_INLINE_FUNCTION
#define RAFT_INLINE_FUNCTION _RAFT_FORCEINLINE _RAFT_HOST_DEVICE
#define RAFT_INLINE_FUNCTION _RAFT_HOST_DEVICE _RAFT_FORCEINLINE
#endif

/**
Expand Down
Loading

0 comments on commit 43a7bc8

Please sign in to comment.