Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify use of common functors #1049

Merged
merged 24 commits into from
Dec 14, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
f87c1e0
Unify use of common functors
Nyrio Nov 25, 2022
5ca9307
Undo unintentional changes to k-means test
Nyrio Nov 25, 2022
20d807a
Remove include of deleted file
Nyrio Nov 25, 2022
a744213
Clang format
Nyrio Nov 25, 2022
1b91e53
Add missing namespace specifiers
Nyrio Nov 28, 2022
630a64c
Merge remote-tracking branch 'origin/branch-23.02' into enh-common-fu…
Nyrio Nov 30, 2022
f74c3d4
New functors have automatic type inference, old ones are deprecated
Nyrio Nov 30, 2022
406f12e
Update copyright year
Nyrio Nov 30, 2022
2701ba4
Lift operator ambiguity using compose ops
Nyrio Nov 30, 2022
6051c7f
Fix conflicting template type
Nyrio Dec 1, 2022
062d875
Move new operators to their own header
Nyrio Dec 1, 2022
c4bfe37
Revert unwanted change on cmake file
Nyrio Dec 1, 2022
e1f0a0b
Support arbitrary number of arguments for inner_op in unary_compose_op
Nyrio Dec 1, 2022
fe0bff9
Operators improvements: perfect forwarding, references, better suppor…
Nyrio Dec 2, 2022
5c07ef3
Remove unused includes and bring back custom DivOp in silhouette_scor…
Nyrio Dec 2, 2022
66f52ed
More powerful compose ops
Nyrio Dec 5, 2022
d5f94dc
Add tuple header
Nyrio Dec 5, 2022
fee9a50
Change return types to auto
Nyrio Dec 6, 2022
f2f0f79
Suppress clang-tidy modernize-use-default-member-init warning
Nyrio Dec 6, 2022
05476d7
Merge remote-tracking branch 'origin/branch-23.02' into enh-common-fu…
Nyrio Dec 7, 2022
7113adc
Rename scalar_xxx_op to xxx_const_op
Nyrio Dec 7, 2022
384606f
Operators tests
Nyrio Dec 9, 2022
c6ee689
Typo explictly -> explicitly
Nyrio Dec 9, 2022
b734f7e
Replace raft::abs in tests with std::abs
Nyrio Dec 12, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 44 additions & 72 deletions cpp/include/raft/cluster/detail/kmeans.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -195,16 +195,15 @@ void kmeansPlusPlus(const raft::handle_t& handle,
// Outputs minDistanceBuf[n_trials x n_samples] where minDistance[i, :] contains updated
// minClusterDistance that includes candidate-i
auto minDistBuf = distBuffer.view();
raft::linalg::matrixVectorOp(
minDistBuf.data_handle(),
pwd.data_handle(),
minClusterDistance.data_handle(),
pwd.extent(1),
pwd.extent(0),
true,
true,
[=] __device__(DataT mat, DataT vec) { return vec <= mat ? vec : mat; },
stream);
raft::linalg::matrixVectorOp(minDistBuf.data_handle(),
pwd.data_handle(),
minClusterDistance.data_handle(),
pwd.extent(1),
pwd.extent(0),
true,
true,
raft::Min<DataT>{},
stream);

// Calculate costPerCandidate[n_trials] where costPerCandidate[i] is the cluster cost when using
// centroid candidate-i
Expand Down Expand Up @@ -321,21 +320,15 @@ void update_centroids(const raft::handle_t& handle,
// weight_per_cluster[n_clusters] - 1D array, weight_per_cluster[i] contains sum of weights in
// cluster-i.
// Note - when weight_per_cluster[i] is 0, new_centroids[i] is reset to 0
raft::linalg::matrixVectorOp(
new_centroids.data_handle(),
new_centroids.data_handle(),
weight_per_cluster.data_handle(),
new_centroids.extent(1),
new_centroids.extent(0),
true,
false,
[=] __device__(DataT mat, DataT vec) {
if (vec == 0)
return DataT(0);
else
return mat / vec;
},
handle.get_stream());
raft::linalg::matrixVectorOp(new_centroids.data_handle(),
new_centroids.data_handle(),
weight_per_cluster.data_handle(),
new_centroids.extent(1),
new_centroids.extent(0),
true,
false,
raft::DivideCheckZero<DataT>{},
handle.get_stream());

// copy centroids[i] to new_centroids[i] when weight_per_cluster[i] is 0
cub::ArgIndexInputIterator<DataT*> itr_wt(weight_per_cluster.data_handle());
Expand All @@ -351,9 +344,7 @@ void update_centroids(const raft::handle_t& handle,
// copy when the sum of weights in the cluster is 0
return map.value == 0;
},
[=] __device__(raft::KeyValuePair<ptrdiff_t, DataT> map) { // map
return map.key;
},
raft::KeyOp{},
handle.get_stream());
}

Expand Down Expand Up @@ -394,7 +385,7 @@ void kmeans_fit_main(const raft::handle_t& handle,
// resource
auto wtInCluster = raft::make_device_vector<DataT, IndexT>(handle, n_clusters);

rmm::device_scalar<raft::KeyValuePair<IndexT, DataT>> clusterCostD(stream);
rmm::device_scalar<DataT> clusterCostD(stream);

// L2 norm of X: ||x||^2
auto L2NormX = raft::make_device_vector<DataT, IndexT>(handle, n_samples);
Expand Down Expand Up @@ -465,16 +456,12 @@ void kmeans_fit_main(const raft::handle_t& handle,
// compute the squared norm between the newCentroids and the original
// centroids, destructor releases the resource
auto sqrdNorm = raft::make_device_scalar(handle, DataT(0));
raft::linalg::mapThenSumReduce(
sqrdNorm.data_handle(),
newCentroids.size(),
[=] __device__(const DataT a, const DataT b) {
DataT diff = a - b;
return diff * diff;
},
stream,
centroids.data_handle(),
newCentroids.data_handle());
raft::linalg::mapThenSumReduce(sqrdNorm.data_handle(),
newCentroids.size(),
raft::SqDiff<DataT>{},
stream,
centroids.data_handle(),
newCentroids.data_handle());

DataT sqrdNormError = 0;
raft::copy(&sqrdNormError, sqrdNorm.data_handle(), sqrdNorm.size(), stream);
Expand All @@ -489,18 +476,11 @@ void kmeans_fit_main(const raft::handle_t& handle,
minClusterAndDistance.view(),
workspace,
raft::make_device_scalar_view(clusterCostD.data()),
[] __device__(const raft::KeyValuePair<IndexT, DataT>& a,
const raft::KeyValuePair<IndexT, DataT>& b) {
raft::KeyValuePair<IndexT, DataT> res;
res.key = 0;
res.value = a.value + b.value;
return res;
});

DataT curClusteringCost = 0;
raft::copy(&curClusteringCost, &(clusterCostD.data()->value), 1, stream);

handle.sync_stream(stream);
raft::ValueOp{},
raft::Sum<DataT>{});

DataT curClusteringCost = clusterCostD.value(stream);

ASSERT(curClusteringCost != (DataT)0.0,
"Too few points and centroids being found is getting 0 cost from "
"centers");
Expand Down Expand Up @@ -553,15 +533,10 @@ void kmeans_fit_main(const raft::handle_t& handle,
minClusterAndDistance.view(),
workspace,
raft::make_device_scalar_view(clusterCostD.data()),
[] __device__(const raft::KeyValuePair<IndexT, DataT>& a,
const raft::KeyValuePair<IndexT, DataT>& b) {
raft::KeyValuePair<IndexT, DataT> res;
res.key = 0;
res.value = a.value + b.value;
return res;
});
raft::ValueOp{},
raft::Sum<DataT>{});

raft::copy(inertia.data_handle(), &(clusterCostD.data()->value), 1, stream);
inertia[0] = clusterCostD.value(stream);

RAFT_LOG_DEBUG("KMeans.fit: completed after %d iterations with %f inertia[0] ",
n_iter[0] > params.max_iter ? n_iter[0] - 1 : n_iter[0],
Expand Down Expand Up @@ -673,7 +648,8 @@ void initScalableKMeansPlusPlus(const raft::handle_t& handle,
minClusterDistanceVec.view(),
workspace,
raft::make_device_scalar_view(clusterCost.data()),
[] __device__(const DataT& a, const DataT& b) { return a + b; });
raft::Nop<DataT>{},
raft::Sum<DataT>{});

auto psi = clusterCost.value(stream);

Expand Down Expand Up @@ -705,7 +681,8 @@ void initScalableKMeansPlusPlus(const raft::handle_t& handle,
minClusterDistanceVec.view(),
workspace,
raft::make_device_scalar_view<DataT>(clusterCost.data()),
[] __device__(const DataT& a, const DataT& b) { return a + b; });
raft::Nop<DataT>{},
raft::Sum<DataT>{});

psi = clusterCost.value(stream);

Expand Down Expand Up @@ -1074,7 +1051,7 @@ void kmeans_predict(handle_t const& handle,
workspace);

// calculate cluster cost phi_x(C)
rmm::device_scalar<raft::KeyValuePair<IndexT, DataT>> clusterCostD(stream);
rmm::device_scalar<DataT> clusterCostD(stream);
// TODO: add different templates for InType of binaryOp to avoid thrust transform
thrust::transform(handle.get_thrust_policy(),
minClusterAndDistance.data_handle(),
Expand All @@ -1092,21 +1069,16 @@ void kmeans_predict(handle_t const& handle,
minClusterAndDistance.view(),
workspace,
raft::make_device_scalar_view(clusterCostD.data()),
[] __device__(const raft::KeyValuePair<IndexT, DataT>& a,
const raft::KeyValuePair<IndexT, DataT>& b) {
raft::KeyValuePair<IndexT, DataT> res;
res.key = 0;
res.value = a.value + b.value;
return res;
});

raft::copy(inertia.data_handle(), &(clusterCostD.data()->value), 1, stream);
raft::ValueOp{},
raft::Sum<DataT>{});

thrust::transform(handle.get_thrust_policy(),
minClusterAndDistance.data_handle(),
minClusterAndDistance.data_handle() + minClusterAndDistance.size(),
labels.data_handle(),
[=] __device__(raft::KeyValuePair<IndexT, DataT> pair) { return pair.key; });
raft::KeyOp{});

inertia[0] = clusterCostD.value(stream);
}

template <typename DataT, typename IndexT = int>
Expand Down
69 changes: 31 additions & 38 deletions cpp/include/raft/cluster/detail/kmeans_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,7 @@ void checkWeight(const raft::handle_t& handle,

auto scale = static_cast<DataT>(n_samples) / wt_sum;
raft::linalg::unaryOp(
weight.data_handle(),
weight.data_handle(),
n_samples,
[=] __device__(const DataT& wt) { return wt * scale; },
stream);
weight.data_handle(), weight.data_handle(), n_samples, raft::ScalarMul<DataT>{scale}, stream);
}
}

Expand All @@ -179,33 +175,42 @@ IndexT getCentroidsBatchSize(int batch_centroids, IndexT n_local_clusters)
return (minVal == 0) ? n_local_clusters : minVal;
}

template <typename DataT, typename ReductionOpT, typename IndexT = int>
template <typename InputT,
typename OutputT,
typename MainOpT,
typename ReductionOpT,
typename IndexT = int>
void computeClusterCost(const raft::handle_t& handle,
raft::device_vector_view<DataT, IndexT> minClusterDistance,
raft::device_vector_view<InputT, IndexT> minClusterDistance,
rmm::device_uvector<char>& workspace,
raft::device_scalar_view<DataT> clusterCost,
raft::device_scalar_view<OutputT> clusterCost,
MainOpT main_op,
ReductionOpT reduction_op)
{
cudaStream_t stream = handle.get_stream();
cudaStream_t stream = handle.get_stream();

cub::TransformInputIterator<OutputT, MainOpT, InputT*> itr(minClusterDistance.data_handle(),
main_op);

size_t temp_storage_bytes = 0;
RAFT_CUDA_TRY(cub::DeviceReduce::Reduce(nullptr,
temp_storage_bytes,
minClusterDistance.data_handle(),
itr,
clusterCost.data_handle(),
minClusterDistance.size(),
reduction_op,
DataT(),
OutputT(),
stream));

workspace.resize(temp_storage_bytes, stream);

RAFT_CUDA_TRY(cub::DeviceReduce::Reduce(workspace.data(),
temp_storage_bytes,
minClusterDistance.data_handle(),
itr,
clusterCost.data_handle(),
minClusterDistance.size(),
reduction_op,
DataT(),
OutputT(),
stream));
}

Expand Down Expand Up @@ -267,9 +272,7 @@ void sampleCentroids(const raft::handle_t& handle,
sampledMinClusterDistance.data_handle(),
nPtsSampledInRank,
inRankCp.data(),
[=] __device__(raft::KeyValuePair<ptrdiff_t, DataT> val) { // MapTransformOp
return val.key;
},
raft::KeyOp{},
stream);
}

Expand Down Expand Up @@ -464,10 +467,8 @@ void minClusterAndDistanceCompute(
pair.value = val;
return pair;
},
[=] __device__(raft::KeyValuePair<IndexT, DataT> a, raft::KeyValuePair<IndexT, DataT> b) {
return (b.value < a.value) ? b : a;
},
[=] __device__(raft::KeyValuePair<IndexT, DataT> pair) { return pair; });
raft::ArgMin{},
raft::Nop<raft::KeyValuePair<IndexT, DataT>, IndexT>{});
}
}
}
Expand Down Expand Up @@ -542,7 +543,6 @@ void minClusterDistanceCompute(const raft::handle_t& handle,
if (is_fused) {
workspace.resize((sizeof(IndexT)) * ns, stream);

// todo(lsugy): remove cIdx
raft::distance::fusedL2NNMinReduce<DataT, DataT, IndexT>(
minClusterDistanceView.data_handle(),
datasetView.data_handle(),
Expand Down Expand Up @@ -577,23 +577,16 @@ void minClusterDistanceCompute(const raft::handle_t& handle,
pairwise_distance_kmeans<DataT, IndexT>(
handle, datasetView, centroidsView, pairwiseDistanceView, workspace, metric);

raft::linalg::coalescedReduction(
minClusterDistanceView.data_handle(),
pairwiseDistanceView.data_handle(),
pairwiseDistanceView.extent(1),
pairwiseDistanceView.extent(0),
std::numeric_limits<DataT>::max(),
stream,
true,
[=] __device__(DataT val, IndexT i) { // MainLambda
return val;
},
[=] __device__(DataT a, DataT b) { // ReduceLambda
return (b < a) ? b : a;
},
[=] __device__(DataT val) { // FinalLambda
return val;
});
raft::linalg::coalescedReduction(minClusterDistanceView.data_handle(),
pairwiseDistanceView.data_handle(),
pairwiseDistanceView.extent(1),
pairwiseDistanceView.extent(0),
std::numeric_limits<DataT>::max(),
stream,
true,
raft::Nop<DataT, IndexT>{},
raft::Min<DataT>{},
raft::Nop<DataT, IndexT>{});
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion cpp/include/raft/cluster/kmeans.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,8 @@ void cluster_cost(const raft::handle_t& handle,
raft::device_scalar_view<DataT> clusterCost,
ReductionOpT reduction_op)
{
detail::computeClusterCost(handle, minClusterDistance, workspace, clusterCost, reduction_op);
detail::computeClusterCost(
handle, minClusterDistance, workspace, clusterCost, raft::Nop<DataT>{}, reduction_op);
}

/**
Expand Down
11 changes: 6 additions & 5 deletions cpp/include/raft/distance/detail/cosine.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,6 @@ void cosineAlgo1(Index_ m,
cudaStream_t stream,
bool isRowMajor)
{
auto norm_op = [] __device__(AccType in) { return raft::mySqrt(in); };

// raft distance support inputs as float/double and output as uint8_t/float/double.
static_assert(!((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))),
"OutType can be uint8_t, float, double,"
Expand All @@ -248,10 +246,13 @@ void cosineAlgo1(Index_ m,
InType* row_vec = workspace;
if (pA != pB) {
row_vec += m;
raft::linalg::rowNorm(col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, norm_op);
raft::linalg::rowNorm(row_vec, pB, k, n, raft::linalg::L2Norm, isRowMajor, stream, norm_op);
raft::linalg::rowNorm(
Nyrio marked this conversation as resolved.
Show resolved Hide resolved
col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::SqrtOp<AccType>{});
raft::linalg::rowNorm(
row_vec, pB, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::SqrtOp<AccType>{});
} else {
raft::linalg::rowNorm(col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, norm_op);
raft::linalg::rowNorm(
col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::SqrtOp<AccType>{});
}

if (isRowMajor) {
Expand Down
11 changes: 6 additions & 5 deletions cpp/include/raft/distance/detail/euclidean.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,6 @@ void euclideanAlgo1(Index_ m,
cudaStream_t stream,
bool isRowMajor)
{
auto norm_op = [] __device__(InType in) { return in; };

// raft distance support inputs as float/double and output as uint8_t/float/double.
static_assert(!((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))),
"OutType can be uint8_t, float, double,"
Expand All @@ -266,10 +264,13 @@ void euclideanAlgo1(Index_ m,
InType* row_vec = workspace;
if (pA != pB) {
row_vec += m;
raft::linalg::rowNorm(col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, norm_op);
raft::linalg::rowNorm(row_vec, pB, k, n, raft::linalg::L2Norm, isRowMajor, stream, norm_op);
raft::linalg::rowNorm(
col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::Nop<InType>{});
raft::linalg::rowNorm(
row_vec, pB, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::Nop<InType>{});
} else {
raft::linalg::rowNorm(col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, norm_op);
raft::linalg::rowNorm(
col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::Nop<InType>{});
}

if (isRowMajor) {
Expand Down
Loading