diff --git a/cpp/include/raft/linalg/normalize.cuh b/cpp/include/raft/linalg/normalize.cuh index 1f60860c8c..de5f4e62ce 100644 --- a/cpp/include/raft/linalg/normalize.cuh +++ b/cpp/include/raft/linalg/normalize.cuh @@ -18,9 +18,11 @@ #include "detail/normalize.cuh" +#include #include #include #include +#include namespace raft { namespace linalg { diff --git a/cpp/include/raft/spectral/detail/modularity_maximization.hpp b/cpp/include/raft/spectral/detail/modularity_maximization.hpp index 2a3b5cf36c..a4e504883a 100644 --- a/cpp/include/raft/spectral/detail/modularity_maximization.hpp +++ b/cpp/include/raft/spectral/detail/modularity_maximization.hpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -101,8 +102,9 @@ std::tuple modularity_maximization( // notice that at this point the matrix has already been transposed, so we are scaling // columns - scale_obs(nEigVecs, n, eigVecs); - RAFT_CHECK_CUDA(stream); + auto dataset_view = raft::make_device_matrix_view(eigVecs, nEigVecs, n); + raft::linalg::row_normalize( + handle, raft::make_const_mdspan(dataset_view), dataset_view, raft::linalg::L2Norm); // Find partition clustering auto pair_cluster = cluster_solver.solve(handle, n, nEigVecs, eigVecs, clusters); diff --git a/cpp/include/raft/spectral/detail/spectral_util.cuh b/cpp/include/raft/spectral/detail/spectral_util.cuh index 736936a1f1..002fad9680 100644 --- a/cpp/include/raft/spectral/detail/spectral_util.cuh +++ b/cpp/include/raft/spectral/detail/spectral_util.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -39,85 +39,6 @@ namespace raft { namespace spectral { -template -RAFT_KERNEL scale_obs_kernel(index_type_t m, index_type_t n, value_type_t* obs) -{ - index_type_t i, j, k, index, mm; - value_type_t alpha, v, last; - bool valid; - // ASSUMPTION: kernel is launched with either 2, 4, 8, 16 or 32 threads in x-dimension - - // compute alpha - mm = (((m + blockDim.x - 1) / blockDim.x) * blockDim.x); // m in multiple of blockDim.x - alpha = 0.0; - - for (j = threadIdx.y + blockIdx.y * blockDim.y; j < n; j += blockDim.y * gridDim.y) { - for (i = threadIdx.x; i < mm; i += blockDim.x) { - // check if the thread is valid - valid = i < m; - - // get the value of the last thread - last = __shfl_sync(warp_full_mask(), alpha, blockDim.x - 1, blockDim.x); - - // if you are valid read the value from memory, otherwise set your value to 0 - alpha = (valid) ? obs[i + j * m] : 0.0; - alpha = alpha * alpha; - - // do prefix sum (of size warpSize=blockDim.x =< 32) - for (k = 1; k < blockDim.x; k *= 2) { - v = __shfl_up_sync(warp_full_mask(), alpha, k, blockDim.x); - if (threadIdx.x >= k) alpha += v; - } - // shift by last - alpha += last; - } - } - - // scale by alpha - alpha = __shfl_sync(warp_full_mask(), alpha, blockDim.x - 1, blockDim.x); - alpha = raft::sqrt(alpha); - for (j = threadIdx.y + blockIdx.y * blockDim.y; j < n; j += blockDim.y * gridDim.y) { - for (i = threadIdx.x; i < m; i += blockDim.x) { // blockDim.x=32 - index = i + j * m; - obs[index] = obs[index] / alpha; - } - } -} - -template -index_type_t next_pow2(index_type_t n) -{ - index_type_t v; - // Reference: - // http://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2Float - v = n - 1; - v |= v >> 1; - v |= v >> 2; - v |= v >> 4; - v |= v >> 8; - v |= v >> 16; - return v + 1; -} - -template -cudaError_t scale_obs(index_type_t m, index_type_t n, value_type_t* obs) -{ - index_type_t p2m; - - // find next power of 2 - p2m = next_pow2(m); - // setup launch configuration - unsigned int xsize = std::max(2, std::min(p2m, 32)); - dim3 nthreads{xsize, 256 / xsize, 1}; - - dim3 nblocks{1, (n + nthreads.y - 1) / nthreads.y, 1}; - - // launch scaling kernel (scale each column of obs by its norm) - scale_obs_kernel<<>>(m, n, obs); - - return cudaSuccess; -} - template void transform_eigen_matrix(raft::resources const& handle, edge_t n,