diff --git a/CHANGELOG.md b/CHANGELOG.md index f80313f0b2..68090085f6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,8 @@ # cuML 0.17.0 (Date TBD) ## New Features +- PR #2659: Add initial max inner product sparse knn +- PR #2836: Refactor UMAP to accept sparse inputs ## Improvements - PR #3077: Improve runtime for test_kmeans diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 8a9a991639..b2462888ab 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -401,6 +401,7 @@ if(BUILD_CUML_CPP_LIBRARY) src/holtwinters/holtwinters.cu src/kmeans/kmeans.cu src/knn/knn.cu + src/knn/knn_sparse.cu src/metrics/accuracy_score.cu src/metrics/adjusted_rand_index.cu src/metrics/completeness_score.cu diff --git a/cpp/bench/sg/umap.cu b/cpp/bench/sg/umap.cu index d7ddb31552..a28064cd5b 100644 --- a/cpp/bench/sg/umap.cu +++ b/cpp/bench/sg/umap.cu @@ -111,8 +111,8 @@ class UmapSupervised : public UmapBase { protected: void coreBenchmarkMethod() { - fit(*this->handle, this->data.X, yFloat, this->params.nrows, - this->params.ncols, nullptr, nullptr, &uParams, embeddings); + UMAP::fit(*this->handle, this->data.X, yFloat, this->params.nrows, + this->params.ncols, nullptr, nullptr, &uParams, embeddings); } }; ML_BENCH_REGISTER(Params, UmapSupervised, "blobs", getInputs()); @@ -124,8 +124,8 @@ class UmapUnsupervised : public UmapBase { protected: void coreBenchmarkMethod() { - fit(*this->handle, this->data.X, this->params.nrows, this->params.ncols, - nullptr, nullptr, &uParams, embeddings); + UMAP::fit(*this->handle, this->data.X, nullptr, this->params.nrows, + this->params.ncols, nullptr, nullptr, &uParams, embeddings); } }; ML_BENCH_REGISTER(Params, UmapUnsupervised, "blobs", getInputs()); @@ -136,17 +136,17 @@ class UmapTransform : public UmapBase { protected: void coreBenchmarkMethod() { - transform(*this->handle, this->data.X, this->params.nrows, - this->params.ncols, nullptr, nullptr, this->data.X, - this->params.nrows, embeddings, this->params.nrows, &uParams, - transformed); + UMAP::transform(*this->handle, this->data.X, this->params.nrows, + this->params.ncols, nullptr, nullptr, this->data.X, + this->params.nrows, embeddings, this->params.nrows, + &uParams, transformed); } void allocateBuffers(const ::benchmark::State& state) { UmapBase::allocateBuffers(state); auto& handle = *this->handle; alloc(transformed, this->params.nrows * uParams.n_components); - fit(handle, this->data.X, yFloat, this->params.nrows, this->params.ncols, - nullptr, nullptr, &uParams, embeddings); + UMAP::fit(handle, this->data.X, yFloat, this->params.nrows, + this->params.ncols, nullptr, nullptr, &uParams, embeddings); } void deallocateBuffers(const ::benchmark::State& state) { dealloc(transformed, this->params.nrows * uParams.n_components); diff --git a/cpp/cmake/Dependencies.cmake b/cpp/cmake/Dependencies.cmake index ffa014641b..c5175b659e 100644 --- a/cpp/cmake/Dependencies.cmake +++ b/cpp/cmake/Dependencies.cmake @@ -39,7 +39,7 @@ else(DEFINED ENV{RAFT_PATH}) ExternalProject_Add(raft GIT_REPOSITORY https://github.com/rapidsai/raft.git - GIT_TAG 9b3afe67895fbea397fb2c72375157aadfc132d8 + GIT_TAG eebd0e306624b419168b2cd5cd7aa44ebaec51f1 PREFIX ${RAFT_DIR} CONFIGURE_COMMAND "" BUILD_COMMAND "" diff --git a/cpp/include/cuml/distance/distance_type.h b/cpp/include/cuml/distance/distance_type.h new file mode 100644 index 0000000000..ebb40f10f9 --- /dev/null +++ b/cpp/include/cuml/distance/distance_type.h @@ -0,0 +1,25 @@ +#pragma once + +namespace ML { +namespace Distance { + +/** enum to tell how to compute euclidean distance */ +enum DistanceType : unsigned short { + /** evaluate as dist_ij = sum(x_ik^2) + sum(y_ij)^2 - 2*sum(x_ik * y_jk) */ + EucExpandedL2 = 0, + /** same as above, but inside the epilogue, perform square root operation */ + EucExpandedL2Sqrt = 1, + /** cosine distance */ + EucExpandedCosine = 2, + /** L1 distance */ + EucUnexpandedL1 = 3, + /** evaluate as dist_ij += (x_ik - y-jk)^2 */ + EucUnexpandedL2 = 4, + /** same as above, but inside the epilogue, perform square root operation */ + EucUnexpandedL2Sqrt = 5, + /** simple inner product */ + InnerProduct = 6 +}; + +}; // end namespace Distance +}; // end namespace ML diff --git a/cpp/include/cuml/manifold/common.hpp b/cpp/include/cuml/manifold/common.hpp new file mode 100644 index 0000000000..79caf0d58d --- /dev/null +++ b/cpp/include/cuml/manifold/common.hpp @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace ML { + +// Dense input uses int64_t until FAISS is updated +typedef int64_t knn_indices_dense_t; + +typedef int knn_indices_sparse_t; + +/** + * Simple container for KNN graph properties + * @tparam value_idx + * @tparam value_t + */ +template +struct knn_graph { + knn_graph(value_idx n_rows_, int n_neighbors_) + : n_rows(n_rows_), n_neighbors(n_neighbors_) {} + + knn_graph(value_idx n_rows_, int n_neighbors_, value_idx *knn_indices_, + value_t *knn_dists_) + : n_rows(n_rows_), + n_neighbors(n_neighbors_), + knn_indices(knn_indices_), + knn_dists(knn_dists_) {} + + value_idx *knn_indices; + value_t *knn_dists; + + value_idx n_rows; + int n_neighbors; +}; + +/** + * Base struct for representing inputs to manifold learning + * algorithms. + * @tparam T + */ +template +struct manifold_inputs_t { + T *y; + int n; + int d; + + manifold_inputs_t(T *y_, int n_, int d_) : y(y_), n(n_), d(d_) {} + + virtual bool alloc_knn_graph() const; +}; + +/** + * Dense input to manifold learning algorithms + * @tparam T + */ +template +struct manifold_dense_inputs_t : public manifold_inputs_t { + T *X; + + manifold_dense_inputs_t(T *x_, T *y_, int n_, int d_) + : manifold_inputs_t(y_, n_, d_), X(x_) {} + + bool alloc_knn_graph() const { return true; } +}; + +/** + * Sparse CSR input to manifold learning algorithms + * @tparam value_idx + * @tparam T + */ +template +struct manifold_sparse_inputs_t : public manifold_inputs_t { + value_idx *indptr; + value_idx *indices; + T *data; + + size_t nnz; + + manifold_sparse_inputs_t(value_idx *indptr_, value_idx *indices_, T *data_, + T *y_, size_t nnz_, int n_, int d_) + : manifold_inputs_t(y_, n_, d_), + indptr(indptr_), + indices(indices_), + data(data_), + nnz(nnz_) {} + + bool alloc_knn_graph() const { return true; } +}; + +/** + * Precomputed KNN graph input to manifold learning algorithms + * @tparam value_idx + * @tparam value_t + */ +template +struct manifold_precomputed_knn_inputs_t + : public manifold_dense_inputs_t { + manifold_precomputed_knn_inputs_t( + value_idx *knn_indices_, value_t *knn_dists_, value_t *X_, value_t *y_, + int n_, int d_, int n_neighbors_) + : manifold_dense_inputs_t(X_, y_, n_, d_), + knn_graph(n_, n_neighbors_, knn_indices_, knn_dists_) {} + + knn_graph knn_graph; + + bool alloc_knn_graph() const { return false; } +}; + +}; // end namespace ML \ No newline at end of file diff --git a/cpp/include/cuml/manifold/umap.hpp b/cpp/include/cuml/manifold/umap.hpp index a794385805..d9cd589c27 100644 --- a/cpp/include/cuml/manifold/umap.hpp +++ b/cpp/include/cuml/manifold/umap.hpp @@ -20,12 +20,19 @@ #include "umapparams.h" namespace ML { +namespace UMAP { void transform(const raft::handle_t &handle, float *X, int n, int d, int64_t *knn_indices, float *knn_dists, float *orig_X, int orig_n, float *embedding, int embedding_n, UMAPParams *params, float *transformed); +void transform_sparse(const raft::handle_t &handle, int *indptr, int *indices, + float *data, size_t nnz, int n, int d, int *orig_x_indptr, + int *orig_x_indices, float *orig_x_data, size_t orig_nnz, + int orig_n, float *embedding, int embedding_n, + UMAPParams *params, float *transformed); + void find_ab(const raft::handle_t &handle, UMAPParams *params); void fit(const raft::handle_t &handle, @@ -34,86 +41,11 @@ void fit(const raft::handle_t &handle, int n, int d, int64_t *knn_indices, float *knn_dists, UMAPParams *params, float *embeddings); -void fit(const raft::handle_t &handle, - float *X, // input matrix - int n, // rows - int d, // cols - int64_t *knn_indices, float *knn_dists, UMAPParams *params, - float *embeddings); - -class UMAP_API { - float *orig_X; - int orig_n; - raft::handle_t *handle; - UMAPParams *params; - - public: - UMAP_API(const raft::handle_t &handle, UMAPParams *params); - ~UMAP_API(); - - /** - * Fits an unsupervised UMAP model - * @param X - * pointer to an array in row-major format (note: this will be col-major soon) - * @param n - * n_samples in X - * @param d - * d_features in X - * @param knn_indices - * an array containing the n_neighbors nearest neighors indices for each sample - * @param knn_dists - * an array containing the n_neighbors nearest neighors distances for each sample - * @param embeddings - * an array to return the output embeddings of size (n_samples, n_components) - */ - void fit(float *X, int n, int d, int64_t *knn_indices, float *knn_dists, - float *embeddings); - - /** - * Fits a supervised UMAP model - * @param X - * pointer to an array in row-major format (note: this will be col-major soon) - * @param y - * pointer to an array of labels, shape=n_samples - * @param n - * n_samples in X - * @param d - * d_features in X - * @param knn_indices - * an array containing the n_neighbors nearest neighors indices for each sample - * @param knn_dists - * an array containing the n_neighbors nearest neighors distances for each sample - * @param embeddings - * an array to return the output embeddings of size (n_samples, n_components) - */ - void fit(float *X, float *y, int n, int d, int64_t *knn_indices, - float *knn_dists, float *embeddings); - - /** - * Project a set of X vectors into the embedding space. - * @param X - * pointer to an array in row-major format (note: this will be col-major soon) - * @param n - * n_samples in X - * @param d - * d_features in X - * @param knn_indices - * an array containing the n_neighbors nearest neighors indices for each sample - * @param knn_dists - * an array containing the n_neighbors nearest neighors distances for each sample - * @param embedding - * pointer to embedding array of size (embedding_n, n_components) that has been created with fit() - * @param embedding_n - * n_samples in embedding array - * @param out - * pointer to array for storing output embeddings (n, n_components) - */ - void transform(float *X, int n, int d, int64_t *knn_indices, float *knn_dists, - float *embedding, int embedding_n, float *out); - - /** - * Get the UMAPParams instance - */ - UMAPParams *get_params(); -}; +void fit_sparse(const raft::handle_t &handle, + int *indptr, // input matrix + int *indices, float *data, size_t nnz, float *y, + int n, // rows + int d, // cols + UMAPParams *params, float *embeddings); +} // namespace UMAP } // namespace ML diff --git a/cpp/include/cuml/neighbors/knn_sparse.hpp b/cpp/include/cuml/neighbors/knn_sparse.hpp new file mode 100644 index 0000000000..ccca8011b8 --- /dev/null +++ b/cpp/include/cuml/neighbors/knn_sparse.hpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include + +#include + +namespace ML { +namespace Sparse { + +constexpr int DEFAULT_BATCH_SIZE = 2 << 16; + +void brute_force_knn(raft::handle_t &handle, const int *idx_indptr, + const int *idx_indices, const float *idx_data, + size_t idx_nnz, int n_idx_rows, int n_idx_cols, + const int *query_indptr, const int *query_indices, + const float *query_data, size_t query_nnz, + int n_query_rows, int n_query_cols, int *output_indices, + float *output_dists, int k, + size_t batch_size_index = DEFAULT_BATCH_SIZE, + size_t batch_size_query = DEFAULT_BATCH_SIZE, + ML::MetricType metric = ML::MetricType::METRIC_L2, + float metricArg = 0, bool expanded_form = false); +}; // end namespace Sparse +}; // end namespace ML diff --git a/cpp/src/knn/knn_sparse.cu b/cpp/src/knn/knn_sparse.cu new file mode 100644 index 0000000000..0a73f54211 --- /dev/null +++ b/cpp/src/knn/knn_sparse.cu @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include + +#include + +#include + +namespace ML { +namespace Sparse { + +void brute_force_knn(raft::handle_t &handle, const int *idx_indptr, + const int *idx_indices, const float *idx_data, + size_t idx_nnz, int n_idx_rows, int n_idx_cols, + const int *query_indptr, const int *query_indices, + const float *query_data, size_t query_nnz, + int n_query_rows, int n_query_cols, int *output_indices, + float *output_dists, int k, + size_t batch_size_index, // approx 1M + size_t batch_size_query, ML::MetricType metric, + float metricArg, bool expanded_form) { + auto d_alloc = handle.get_device_allocator(); + cusparseHandle_t cusparse_handle = handle.get_cusparse_handle(); + cudaStream_t stream = handle.get_stream(); + + MLCommon::Sparse::Selection::brute_force_knn( + idx_indptr, idx_indices, idx_data, idx_nnz, n_idx_rows, n_idx_cols, + query_indptr, query_indices, query_data, query_nnz, n_query_rows, + n_query_cols, output_indices, output_dists, k, cusparse_handle, d_alloc, + stream, batch_size_index, batch_size_query, metric, metricArg, + expanded_form); +} +}; // namespace Sparse +}; // namespace ML diff --git a/cpp/src/umap/fuzzy_simpl_set/naive.cuh b/cpp/src/umap/fuzzy_simpl_set/naive.cuh index 8028dc80b6..d40f127524 100644 --- a/cpp/src/umap/fuzzy_simpl_set/naive.cuh +++ b/cpp/src/umap/fuzzy_simpl_set/naive.cuh @@ -76,10 +76,10 @@ static const float MIN_K_DIST_SCALE = 1e-3; * Descriptions adapted from: https://github.com/lmcinnes/umap/blob/master/umap/umap_.py * */ -template +template __global__ void smooth_knn_dist_kernel( - const T *knn_dists, int n, float mean_dist, T *sigmas, - T *rhos, // Size of n, iniitalized to zeros + const value_t *knn_dists, int n, float mean_dist, value_t *sigmas, + value_t *rhos, // Size of n, iniitalized to zeros int n_neighbors, float local_connectivity = 1.0, int n_iter = 64, float bandwidth = 1.0) { // row-based matrix 1 thread per row @@ -179,21 +179,22 @@ __global__ void smooth_knn_dist_kernel( * @param knn_dists: the knn distance matrix of size (n, k) * @param sigmas: array of size n representing distance to kth nearest neighbor * @param rhos: array of size n representing distance to the first nearest neighbor - * @param vals: T array of size n*k - * @param rows: int64_t array of size n - * @param cols: int64_t array of size k + * @param vals: value_t array of size n*k + * @param rows: value_idx array of size n + * @param cols: value_idx array of size k * @param n Number of samples (rows in knn indices/distances) * @param n_neighbors number of columns in knn indices/distances * * Descriptions adapted from: https://github.com/lmcinnes/umap/blob/master/umap/umap_.py */ -template +template __global__ void compute_membership_strength_kernel( - const int64_t *knn_indices, - const float *knn_dists, // nn outputs - const T *sigmas, const T *rhos, // continuous dists to nearest neighbors - T *vals, int *rows, int *cols, // result coo - int n, int n_neighbors) { // model params + const value_idx *knn_indices, + const float *knn_dists, // nn outputs + const value_t *sigmas, + const value_t *rhos, // continuous dists to nearest neighbors + value_t *vals, int *rows, int *cols, // result coo + int n, int n_neighbors) { // model params // row-based matrix is best int idx = (blockIdx.x * TPB_X) + threadIdx.x; @@ -204,7 +205,7 @@ __global__ void compute_membership_strength_kernel( double cur_rho = rhos[row]; double cur_sigma = sigmas[row]; - int64_t cur_knn_ind = knn_indices[idx]; + value_idx cur_knn_ind = knn_indices[idx]; double cur_knn_dist = knn_dists[idx]; if (cur_knn_ind != -1) { @@ -229,22 +230,23 @@ __global__ void compute_membership_strength_kernel( /* * Sets up and runs the knn dist smoothing */ -template -void smooth_knn_dist(int n, const int64_t *knn_indices, const float *knn_dists, - T *rhos, T *sigmas, UMAPParams *params, int n_neighbors, +template +void smooth_knn_dist(int n, const value_idx *knn_indices, + const float *knn_dists, value_t *rhos, value_t *sigmas, + UMAPParams *params, int n_neighbors, float local_connectivity, std::shared_ptr d_alloc, cudaStream_t stream) { dim3 grid(raft::ceildiv(n, TPB_X), 1, 1); dim3 blk(TPB_X, 1, 1); - MLCommon::device_buffer dist_means_dev(d_alloc, stream, n_neighbors); + MLCommon::device_buffer dist_means_dev(d_alloc, stream, n_neighbors); raft::stats::mean(dist_means_dev.data(), knn_dists, 1, n_neighbors * n, false, false, stream); CUDA_CHECK(cudaPeekAtLastError()); - T mean_dist = 0.0; + value_t mean_dist = 0.0; raft::update_host(&mean_dist, dist_means_dev.data(), 1, stream); CUDA_CHECK(cudaStreamSynchronize(stream)); @@ -273,24 +275,24 @@ void smooth_knn_dist(int n, const int64_t *knn_indices, const float *knn_dists, * @param d_alloc the device allocator to use for temp memory * @param stream cuda stream to use for device operations */ -template -void launcher(int n, const int64_t *knn_indices, const float *knn_dists, - int n_neighbors, MLCommon::Sparse::COO *out, +template +void launcher(int n, const value_idx *knn_indices, const float *knn_dists, + int n_neighbors, MLCommon::Sparse::COO *out, UMAPParams *params, std::shared_ptr d_alloc, cudaStream_t stream) { /** * Calculate mean distance through a parallel reduction */ - MLCommon::device_buffer sigmas(d_alloc, stream, n); - MLCommon::device_buffer rhos(d_alloc, stream, n); - CUDA_CHECK(cudaMemsetAsync(sigmas.data(), 0, n * sizeof(T), stream)); - CUDA_CHECK(cudaMemsetAsync(rhos.data(), 0, n * sizeof(T), stream)); + MLCommon::device_buffer sigmas(d_alloc, stream, n); + MLCommon::device_buffer rhos(d_alloc, stream, n); + CUDA_CHECK(cudaMemsetAsync(sigmas.data(), 0, n * sizeof(value_t), stream)); + CUDA_CHECK(cudaMemsetAsync(rhos.data(), 0, n * sizeof(value_t), stream)); - smooth_knn_dist(n, knn_indices, knn_dists, rhos.data(), - sigmas.data(), params, n_neighbors, - params->local_connectivity, d_alloc, stream); + smooth_knn_dist( + n, knn_indices, knn_dists, rhos.data(), sigmas.data(), params, n_neighbors, + params->local_connectivity, d_alloc, stream); - MLCommon::Sparse::COO in(d_alloc, stream, n * n_neighbors, n, n); + MLCommon::Sparse::COO in(d_alloc, stream, n * n_neighbors, n, n); // check for logging in order to avoid the potentially costly `arr2Str` call! if (ML::Logger::get().shouldLogFor(CUML_LEVEL_DEBUG)) { @@ -327,17 +329,18 @@ void launcher(int n, const int64_t *knn_indices, const float *knn_dists, * one via a fuzzy union. (Symmetrize knn graph). */ float set_op_mix_ratio = params->set_op_mix_ratio; - MLCommon::Sparse::coo_symmetrize( + MLCommon::Sparse::coo_symmetrize( &in, out, - [set_op_mix_ratio] __device__(int row, int col, T result, T transpose) { - T prod_matrix = result * transpose; - T res = set_op_mix_ratio * (result + transpose - prod_matrix) + - (1.0 - set_op_mix_ratio) * prod_matrix; + [set_op_mix_ratio] __device__(int row, int col, value_t result, + value_t transpose) { + value_t prod_matrix = result * transpose; + value_t res = set_op_mix_ratio * (result + transpose - prod_matrix) + + (1.0 - set_op_mix_ratio) * prod_matrix; return res; }, d_alloc, stream); - MLCommon::Sparse::coo_sort(out, d_alloc, stream); + MLCommon::Sparse::coo_sort(out, d_alloc, stream); } } // namespace Naive } // namespace FuzzySimplSet diff --git a/cpp/src/umap/fuzzy_simpl_set/runner.cuh b/cpp/src/umap/fuzzy_simpl_set/runner.cuh index 005630abaf..aba1bbf883 100644 --- a/cpp/src/umap/fuzzy_simpl_set/runner.cuh +++ b/cpp/src/umap/fuzzy_simpl_set/runner.cuh @@ -39,15 +39,15 @@ using namespace ML; * @param stream cuda stream * @param algorithm algo type to choose */ -template -void run(int n, const int64_t *knn_indices, const T *knn_dists, int n_neighbors, - MLCommon::Sparse::COO *coo, UMAPParams *params, +template +void run(int n, const value_idx *knn_indices, const T *knn_dists, + int n_neighbors, MLCommon::Sparse::COO *coo, UMAPParams *params, std::shared_ptr alloc, cudaStream_t stream, int algorithm = 0) { switch (algorithm) { case 0: - Naive::launcher(n, knn_indices, knn_dists, n_neighbors, coo, - params, alloc, stream); + Naive::launcher( + n, knn_indices, knn_dists, n_neighbors, coo, params, alloc, stream); break; } } diff --git a/cpp/src/umap/init_embed/random_algo.cuh b/cpp/src/umap/init_embed/random_algo.cuh index b555dda35b..81fdc0dc58 100644 --- a/cpp/src/umap/init_embed/random_algo.cuh +++ b/cpp/src/umap/init_embed/random_algo.cuh @@ -27,10 +27,9 @@ namespace RandomInit { using namespace ML; -template -void launcher(const T *X, int n, int d, const int64_t *knn_indices, - const T *knn_dists, UMAPParams *params, T *embedding, - cudaStream_t stream) { +template +void launcher(int n, int d, const value_idx *knn_indices, const T *knn_dists, + UMAPParams *params, T *embedding, cudaStream_t stream) { uint64_t seed = params->random_state; raft::random::Rng r(seed); diff --git a/cpp/src/umap/init_embed/runner.cuh b/cpp/src/umap/init_embed/runner.cuh index 6d18592e78..5045c8c8af 100644 --- a/cpp/src/umap/init_embed/runner.cuh +++ b/cpp/src/umap/init_embed/runner.cuh @@ -29,9 +29,9 @@ namespace InitEmbed { using namespace ML; -template -void run(const raft::handle_t &handle, const T *X, int n, int d, - const int64_t *knn_indices, const T *knn_dists, +template +void run(const raft::handle_t &handle, int n, int d, + const value_idx *knn_indices, const T *knn_dists, MLCommon::Sparse::COO *coo, UMAPParams *params, T *embedding, cudaStream_t stream, int algo = 0) { switch (algo) { @@ -39,13 +39,13 @@ void run(const raft::handle_t &handle, const T *X, int n, int d, * Initial algo uses FAISS indices */ case 0: - RandomInit::launcher(X, n, d, knn_indices, knn_dists, params, embedding, + RandomInit::launcher(n, d, knn_indices, knn_dists, params, embedding, handle.get_stream()); break; case 1: - SpectralInit::launcher(handle, X, n, d, knn_indices, knn_dists, coo, - params, embedding); + SpectralInit::launcher(handle, n, d, knn_indices, knn_dists, coo, params, + embedding); break; } } diff --git a/cpp/src/umap/init_embed/spectral_algo.cuh b/cpp/src/umap/init_embed/spectral_algo.cuh index 86adcc1e3e..5f6175ee95 100644 --- a/cpp/src/umap/init_embed/spectral_algo.cuh +++ b/cpp/src/umap/init_embed/spectral_algo.cuh @@ -40,9 +40,9 @@ using namespace ML; /** * Performs a spectral layout initialization */ -template -void launcher(const raft::handle_t &handle, const T *X, int n, int d, - const int64_t *knn_indices, const T *knn_dists, +template +void launcher(const raft::handle_t &handle, int n, int d, + const value_idx *knn_indices, const T *knn_dists, MLCommon::Sparse::COO *coo, UMAPParams *params, T *embedding) { cudaStream_t stream = handle.get_stream(); diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh index 53b74eae8d..1fd7c15356 100644 --- a/cpp/src/umap/knn_graph/algo.cuh +++ b/cpp/src/umap/knn_graph/algo.cuh @@ -15,44 +15,121 @@ */ #include +#include +#include #include #include #include +#include + +#include + +#include +#include #pragma once namespace UMAPAlgo { - namespace kNNGraph { - namespace Algo { -using namespace ML; - /** * Initial implementation calls out to FAISS to do its work. - * TODO: cuML kNN implementation should support FAISS' approx NN variants (e.g. IVFPQ GPU). */ -/** - * void brute_force_knn(float **input, int *sizes, int n_params, IntType D, - float *search_items, IntType n, int64_t *res_I, float *res_D, - IntType k, cudaStream_t s) - */ -template -void launcher(float *X, int x_n, float *X_query, int x_q_n, int d, - int64_t *knn_indices, T *knn_dists, int n_neighbors, - UMAPParams *params, std::shared_ptr d_alloc, +template +void launcher(const raft::handle_t &handle, const umap_inputs &inputsA, + const umap_inputs &inputsB, + ML::knn_graph &out, int n_neighbors, + const ML::UMAPParams *params, + std::shared_ptr d_alloc, + cudaStream_t stream); + +// Instantiation for dense inputs, int64_t indices +template <> +void launcher(const raft::handle_t &handle, + const ML::manifold_dense_inputs_t &inputsA, + const ML::manifold_dense_inputs_t &inputsB, + ML::knn_graph &out, int n_neighbors, + const ML::UMAPParams *params, + std::shared_ptr d_alloc, cudaStream_t stream) { std::vector ptrs(1); std::vector sizes(1); - ptrs[0] = X; - sizes[0] = x_n; + ptrs[0] = inputsA.X; + sizes[0] = inputsA.n; + + MLCommon::Selection::brute_force_knn( + ptrs, sizes, inputsA.d, inputsB.X, inputsB.n, out.knn_indices, + out.knn_dists, n_neighbors, d_alloc, stream); +} - MLCommon::Selection::brute_force_knn(ptrs, sizes, d, X_query, x_q_n, - knn_indices, knn_dists, n_neighbors, - d_alloc, stream); +// Instantiation for dense inputs, int indices +template <> +void launcher(const raft::handle_t &handle, + const ML::manifold_dense_inputs_t &inputsA, + const ML::manifold_dense_inputs_t &inputsB, + ML::knn_graph &out, int n_neighbors, + const ML::UMAPParams *params, + std::shared_ptr d_alloc, + cudaStream_t stream) { + throw raft::exception("Dense KNN doesn't yet support 32-bit integer indices"); } + +template <> +void launcher(const raft::handle_t &handle, + const ML::manifold_sparse_inputs_t &inputsA, + const ML::manifold_sparse_inputs_t &inputsB, + ML::knn_graph &out, int n_neighbors, + const ML::UMAPParams *params, + std::shared_ptr d_alloc, + cudaStream_t stream) { + MLCommon::Sparse::Selection::brute_force_knn( + inputsA.indptr, inputsA.indices, inputsA.data, inputsA.nnz, inputsA.n, + inputsA.d, inputsB.indptr, inputsB.indices, inputsB.data, inputsB.nnz, + inputsB.n, inputsB.d, out.knn_indices, out.knn_dists, n_neighbors, + handle.get_cusparse_handle(), d_alloc, stream, + ML::Sparse::DEFAULT_BATCH_SIZE, ML::Sparse::DEFAULT_BATCH_SIZE, + ML::MetricType::METRIC_L2); +} + +template <> +void launcher(const raft::handle_t &handle, + const ML::manifold_sparse_inputs_t &inputsA, + const ML::manifold_sparse_inputs_t &inputsB, + ML::knn_graph &out, int n_neighbors, + const ML::UMAPParams *params, + std::shared_ptr d_alloc, + cudaStream_t stream) { + throw raft::exception("Sparse KNN doesn't support 64-bit integer indices"); +} + +template <> +void launcher( + const raft::handle_t &handle, + const ML::manifold_precomputed_knn_inputs_t &inputsA, + const ML::manifold_precomputed_knn_inputs_t &inputsB, + ML::knn_graph &out, int n_neighbors, + const ML::UMAPParams *params, std::shared_ptr d_alloc, + cudaStream_t stream) { + out.knn_indices = inputsA.knn_graph.knn_indices; + out.knn_dists = inputsA.knn_graph.knn_dists; +} + +// Instantiation for precomputed inputs, int indices +template <> +void launcher(const raft::handle_t &handle, + const ML::manifold_precomputed_knn_inputs_t &inputsA, + const ML::manifold_precomputed_knn_inputs_t &inputsB, + ML::knn_graph &out, int n_neighbors, + const ML::UMAPParams *params, + std::shared_ptr d_alloc, + cudaStream_t stream) { + out.knn_indices = inputsA.knn_graph.knn_indices; + out.knn_dists = inputsA.knn_graph.knn_dists; +} + } // namespace Algo } // namespace kNNGraph }; // namespace UMAPAlgo diff --git a/cpp/src/umap/knn_graph/runner.cuh b/cpp/src/umap/knn_graph/runner.cuh index 5c529d3a61..4c0005c68e 100644 --- a/cpp/src/umap/knn_graph/runner.cuh +++ b/cpp/src/umap/knn_graph/runner.cuh @@ -14,7 +14,8 @@ * limitations under the License. */ -#include "algo.cuh" +#include +#include #pragma once @@ -28,24 +29,26 @@ using namespace ML; * @brief This function performs a k-nearest neighbors against * the input algorithm using the specified knn algorithm. * Only algorithm supported at the moment is brute force - * knn primitive. - * @tparam T: Type of input, query, and dist matrices. Usually float - * @param X: Matrix to query (size n x d) in row-major format - * @param n: Number of rows in X - * @param query: Search matrix in row-major format - * @param q_n: Number of rows in query matrix - * @param d: Number of columns in X and query matrices - * @param knn_indices: Return indices matrix (size n*k) - * @param knn_dists: Return dists matrix (size n*k) - * @param n_neighbors: Number of closest neighbors, k, to query - * @param params: Instance of UMAPParam settings - * @param d_alloc: device allocator - * @param stream: cuda stream to use - * @param algo: Algorithm to use. Currently only brute force is supported + * knn primitive. + * @tparam value_idx: Type of knn indices matrix. Usually an integral type. + * @tparam value_t: Type of input, query, and dist matrices. Usually float + * @param[in] X: Matrix to query (size n x d) in row-major format + * @param[in] n: Number of rows in X + * @param[in] query: Search matrix in row-major format + * @param[in] q_n: Number of rows in query matrix + * @param[in] d: Number of columns in X and query matrices + * @param[out] knn_graph : output knn_indices and knn_dists (size n*k) + * @param[in] n_neighbors: Number of closest neighbors, k, to query + * @param[in] params: Instance of UMAPParam settings + * @param[in] d_alloc: device allocator + * @param[in] stream: cuda stream to use + * @param[in] algo: Algorithm to use. Currently only brute force is supported */ -template -void run(T *X, int n, T *query, int q_n, int d, int64_t *knn_indices, - T *knn_dists, int n_neighbors, UMAPParams *params, +template +void run(const raft::handle_t &handle, const umap_inputs &inputsA, + const umap_inputs &inputsB, knn_graph &out, + int n_neighbors, const UMAPParams *params, std::shared_ptr d_alloc, cudaStream_t stream, int algo = 0) { switch (algo) { @@ -53,10 +56,11 @@ void run(T *X, int n, T *query, int q_n, int d, int64_t *knn_indices, * Initial algo uses FAISS indices */ case 0: - Algo::launcher(X, n, query, q_n, d, knn_indices, knn_dists, n_neighbors, - params, d_alloc, stream); + Algo::launcher( + handle, inputsA, inputsB, out, n_neighbors, params, d_alloc, stream); break; } } + } // namespace kNNGraph }; // namespace UMAPAlgo diff --git a/cpp/src/umap/runner.cuh b/cpp/src/umap/runner.cuh index d9c8f6733a..6c48521e9d 100644 --- a/cpp/src/umap/runner.cuh +++ b/cpp/src/umap/runner.cuh @@ -18,6 +18,7 @@ #include #include +#include #include "optimize.cuh" #include "supervised.cuh" @@ -26,6 +27,8 @@ #include "knn_graph/runner.cuh" #include "simpl_set_embed/runner.cuh" +#include + #include #include #include @@ -79,13 +82,9 @@ void find_ab(UMAPParams *params, std::shared_ptr d_alloc, Optimize::find_params_ab(params, d_alloc, stream); } -template -void _fit(const raft::handle_t &handle, - T *X, // input matrix - int n, // rows - int d, // cols - int64_t *knn_indices, T *knn_dists, UMAPParams *params, - T *embeddings) { +template +void _fit(const raft::handle_t &handle, const umap_inputs &inputs, + UMAPParams *params, value_t *embeddings) { ML::PUSH_RANGE("umap::unsupervised::fit"); cudaStream_t stream = handle.get_stream(); auto d_alloc = handle.get_device_allocator(); @@ -97,74 +96,77 @@ void _fit(const raft::handle_t &handle, CUML_LOG_DEBUG("n_neighbors=%d", params->n_neighbors); ML::PUSH_RANGE("umap::knnGraph"); - MLCommon::device_buffer *knn_indices_b = nullptr; - MLCommon::device_buffer *knn_dists_b = nullptr; + std::unique_ptr> knn_indices_b = nullptr; + std::unique_ptr> knn_dists_b = nullptr; - if (!knn_indices || !knn_dists) { - ASSERT(!knn_indices && !knn_dists, - "Either both or none of the KNN parameters should be provided"); + knn_graph knn_graph(inputs.n, k); + /** + * If not given precomputed knn graph, compute it + */ + if (inputs.alloc_knn_graph()) { /** * Allocate workspace for kNN graph */ - knn_indices_b = - new MLCommon::device_buffer(d_alloc, stream, n * k); - knn_dists_b = new MLCommon::device_buffer(d_alloc, stream, n * k); - - knn_indices = knn_indices_b->data(); - knn_dists = knn_dists_b->data(); + knn_indices_b = std::make_unique>( + d_alloc, stream, inputs.n * k); + knn_dists_b = std::make_unique>( + d_alloc, stream, inputs.n * k); - kNNGraph::run(X, n, X, n, d, knn_indices, knn_dists, k, params, d_alloc, - stream); - CUDA_CHECK(cudaPeekAtLastError()); + knn_graph.knn_indices = knn_indices_b->data(); + knn_graph.knn_dists = knn_dists_b->data(); } + + CUML_LOG_DEBUG("Calling knn graph run"); + + kNNGraph::run( + handle, inputs, inputs, knn_graph, k, params, d_alloc, stream); ML::POP_RANGE(); + CUML_LOG_DEBUG("Done. Calling fuzzy simplicial set"); + ML::PUSH_RANGE("umap::simplicial_set"); - COO rgraph_coo(d_alloc, stream); - FuzzySimplSet::run(n, knn_indices, knn_dists, k, &rgraph_coo, - params, d_alloc, stream); + COO rgraph_coo(d_alloc, stream); + FuzzySimplSet::run( + inputs.n, knn_graph.knn_indices, knn_graph.knn_dists, k, &rgraph_coo, + params, d_alloc, stream); + CUML_LOG_DEBUG("Done. Calling remove zeros"); /** * Remove zeros from simplicial set */ - COO cgraph_coo(d_alloc, stream); - MLCommon::Sparse::coo_remove_zeros(&rgraph_coo, &cgraph_coo, - d_alloc, stream); + COO cgraph_coo(d_alloc, stream); + MLCommon::Sparse::coo_remove_zeros(&rgraph_coo, &cgraph_coo, + d_alloc, stream); ML::POP_RANGE(); /** * Run initialization method */ ML::PUSH_RANGE("umap::embedding"); - InitEmbed::run(handle, X, n, d, knn_indices, knn_dists, &cgraph_coo, params, - embeddings, stream, params->init); - - if (knn_indices_b) delete knn_indices_b; - if (knn_dists_b) delete knn_dists_b; + InitEmbed::run(handle, inputs.n, inputs.d, knn_graph.knn_indices, + knn_graph.knn_dists, &cgraph_coo, params, embeddings, stream, + params->init); if (params->callback) { - params->callback->setup(n, params->n_components); + params->callback->setup(inputs.n, params->n_components); params->callback->on_preprocess_end(embeddings); } /** * Run simplicial set embedding to approximate low-dimensional representation */ - SimplSetEmbed::run(X, n, d, &cgraph_coo, params, embeddings, - d_alloc, stream); + SimplSetEmbed::run(inputs.n, inputs.d, &cgraph_coo, params, + embeddings, d_alloc, stream); ML::POP_RANGE(); if (params->callback) params->callback->on_train_end(embeddings); ML::POP_RANGE(); } -template -void _fit(const raft::handle_t &handle, - T *X, // input matrix - T *y, // labels - int n, int d, int64_t *knn_indices, T *knn_dists, UMAPParams *params, - T *embeddings) { +template +void _fit_supervised(const raft::handle_t &handle, const umap_inputs &inputs, + UMAPParams *params, value_t *embeddings) { ML::PUSH_RANGE("umap::supervised::fit"); auto d_alloc = handle.get_device_allocator(); cudaStream_t stream = handle.get_stream(); @@ -177,48 +179,52 @@ void _fit(const raft::handle_t &handle, params->target_n_neighbors = params->n_neighbors; ML::PUSH_RANGE("umap::knnGraph"); - MLCommon::device_buffer *knn_indices_b = nullptr; - MLCommon::device_buffer *knn_dists_b = nullptr; + std::unique_ptr> knn_indices_b = nullptr; + std::unique_ptr> knn_dists_b = nullptr; - if (!knn_indices || !knn_dists) { - ASSERT(!knn_indices && !knn_dists, - "Either both or none of the KNN parameters should be provided"); + knn_graph knn_graph(inputs.n, k); + /** + * If not given precomputed knn graph, compute it + */ + if (inputs.alloc_knn_graph()) { /** * Allocate workspace for kNN graph */ - knn_indices_b = - new MLCommon::device_buffer(d_alloc, stream, n * k); - knn_dists_b = new MLCommon::device_buffer(d_alloc, stream, n * k); + knn_indices_b = std::make_unique>( + d_alloc, stream, inputs.n * k); + knn_dists_b = std::make_unique>( + d_alloc, stream, inputs.n * k); - knn_indices = knn_indices_b->data(); - knn_dists = knn_dists_b->data(); - - kNNGraph::run(X, n, X, n, d, knn_indices, knn_dists, k, params, d_alloc, - stream); - CUDA_CHECK(cudaPeekAtLastError()); + knn_graph.knn_indices = knn_indices_b->data(); + knn_graph.knn_dists = knn_dists_b->data(); } + + kNNGraph::run( + handle, inputs, inputs, knn_graph, k, params, d_alloc, stream); + ML::POP_RANGE(); /** * Allocate workspace for fuzzy simplicial set. */ ML::PUSH_RANGE("umap::simplicial_set"); - COO rgraph_coo(d_alloc, stream); - COO tmp_coo(d_alloc, stream); + COO rgraph_coo(d_alloc, stream); + COO tmp_coo(d_alloc, stream); /** * Run Fuzzy simplicial set */ //int nnz = n*k*2; - FuzzySimplSet::run(n, knn_indices, knn_dists, params->n_neighbors, - &tmp_coo, params, d_alloc, stream); + FuzzySimplSet::run( + inputs.n, knn_graph.knn_indices, knn_graph.knn_dists, params->n_neighbors, + &tmp_coo, params, d_alloc, stream); CUDA_CHECK(cudaPeekAtLastError()); - MLCommon::Sparse::coo_remove_zeros(&tmp_coo, &rgraph_coo, d_alloc, - stream); + MLCommon::Sparse::coo_remove_zeros(&tmp_coo, &rgraph_coo, + d_alloc, stream); - COO final_coo(d_alloc, stream); + COO final_coo(d_alloc, stream); /** * If target metric is 'categorical', perform @@ -226,48 +232,46 @@ void _fit(const raft::handle_t &handle, */ if (params->target_metric == ML::UMAPParams::MetricType::CATEGORICAL) { CUML_LOG_DEBUG("Performing categorical intersection"); - Supervised::perform_categorical_intersection( - y, &rgraph_coo, &final_coo, params, d_alloc, stream); + Supervised::perform_categorical_intersection( + inputs.y, &rgraph_coo, &final_coo, params, d_alloc, stream); /** * Otherwise, perform general simplicial set intersection */ } else { CUML_LOG_DEBUG("Performing general intersection"); - Supervised::perform_general_intersection( - handle, y, &rgraph_coo, &final_coo, params, stream); + Supervised::perform_general_intersection( + handle, inputs.y, &rgraph_coo, &final_coo, params, stream); } /** * Remove zeros */ - MLCommon::Sparse::coo_sort(&final_coo, d_alloc, stream); + MLCommon::Sparse::coo_sort(&final_coo, d_alloc, stream); - COO ocoo(d_alloc, stream); - MLCommon::Sparse::coo_remove_zeros(&final_coo, &ocoo, d_alloc, - stream); + COO ocoo(d_alloc, stream); + MLCommon::Sparse::coo_remove_zeros(&final_coo, &ocoo, d_alloc, + stream); ML::POP_RANGE(); /** * Initialize embeddings */ ML::PUSH_RANGE("umap::supervised::fit"); - InitEmbed::run(handle, X, n, d, knn_indices, knn_dists, &ocoo, params, - embeddings, stream, params->init); - - if (knn_indices_b) delete knn_indices_b; - if (knn_dists_b) delete knn_dists_b; + InitEmbed::run(handle, inputs.n, inputs.d, knn_graph.knn_indices, + knn_graph.knn_dists, &ocoo, params, embeddings, stream, + params->init); if (params->callback) { - params->callback->setup(n, params->n_components); + params->callback->setup(inputs.n, params->n_components); params->callback->on_preprocess_end(embeddings); } /** * Run simplicial set embedding to approximate low-dimensional representation */ - SimplSetEmbed::run(X, n, d, &ocoo, params, embeddings, d_alloc, - stream); + SimplSetEmbed::run(inputs.n, inputs.d, &ocoo, params, + embeddings, d_alloc, stream); ML::POP_RANGE(); if (params->callback) params->callback->on_train_end(embeddings); @@ -279,11 +283,10 @@ void _fit(const raft::handle_t &handle, /** * */ -template -void _transform(const raft::handle_t &handle, T *X, int n, int d, - int64_t *knn_indices, float *knn_dists, T *orig_X, int orig_n, - T *embedding, int embedding_n, UMAPParams *params, - T *transformed) { +template +void _transform(const raft::handle_t &handle, const umap_inputs &inputs, + umap_inputs &orig_x_inputs, value_t *embedding, int embedding_n, + UMAPParams *params, value_t *transformed) { ML::PUSH_RANGE("umap::transform"); auto d_alloc = handle.get_device_allocator(); cudaStream_t stream = handle.get_stream(); @@ -295,30 +298,33 @@ void _transform(const raft::handle_t &handle, T *X, int n, int d, CUML_LOG_DEBUG("Building KNN Graph"); ML::PUSH_RANGE("umap::knnGraph"); - MLCommon::device_buffer *knn_indices_b = nullptr; - MLCommon::device_buffer *knn_dists_b = nullptr; + std::unique_ptr> knn_indices_b = nullptr; + std::unique_ptr> knn_dists_b = nullptr; + + int k = params->n_neighbors; - if (!knn_indices || !knn_dists) { - ASSERT(!knn_indices && !knn_dists, - "Either both or none of the KNN parameters should be provided"); + knn_graph knn_graph(inputs.n, k); + /** + * If not given precomputed knn graph, compute it + */ + + if (inputs.alloc_knn_graph()) { /** * Allocate workspace for kNN graph */ + knn_indices_b = std::make_unique>( + d_alloc, stream, inputs.n * k); + knn_dists_b = std::make_unique>( + d_alloc, stream, inputs.n * k); - int k = params->n_neighbors; - - knn_indices_b = - new MLCommon::device_buffer(d_alloc, stream, n * k); - knn_dists_b = new MLCommon::device_buffer(d_alloc, stream, n * k); + knn_graph.knn_indices = knn_indices_b->data(); + knn_graph.knn_dists = knn_dists_b->data(); + } - knn_indices = knn_indices_b->data(); - knn_dists = knn_dists_b->data(); + kNNGraph::run( + handle, orig_x_inputs, inputs, knn_graph, k, params, d_alloc, stream); - kNNGraph::run(orig_X, orig_n, X, n, d, knn_indices, knn_dists, - params->n_neighbors, params, d_alloc, stream); - CUDA_CHECK(cudaPeekAtLastError()); - } ML::POP_RANGE(); ML::PUSH_RANGE("umap::smooth_knn"); @@ -330,24 +336,27 @@ void _transform(const raft::handle_t &handle, T *X, int n, int d, /** * Perform smooth_knn_dist */ - MLCommon::device_buffer sigmas(d_alloc, stream, n); - MLCommon::device_buffer rhos(d_alloc, stream, n); - CUDA_CHECK(cudaMemsetAsync(sigmas.data(), 0, n * sizeof(T), stream)); - CUDA_CHECK(cudaMemsetAsync(rhos.data(), 0, n * sizeof(T), stream)); + MLCommon::device_buffer sigmas(d_alloc, stream, inputs.n); + MLCommon::device_buffer rhos(d_alloc, stream, inputs.n); + CUDA_CHECK( + cudaMemsetAsync(sigmas.data(), 0, inputs.n * sizeof(value_t), stream)); + CUDA_CHECK( + cudaMemsetAsync(rhos.data(), 0, inputs.n * sizeof(value_t), stream)); - dim3 grid_n(raft::ceildiv(n, TPB_X), 1, 1); + dim3 grid_n(raft::ceildiv(inputs.n, TPB_X), 1, 1); dim3 blk(TPB_X, 1, 1); - FuzzySimplSetImpl::smooth_knn_dist( - n, knn_indices, knn_dists, rhos.data(), sigmas.data(), params, - params->n_neighbors, adjusted_local_connectivity, d_alloc, stream); + FuzzySimplSetImpl::smooth_knn_dist( + inputs.n, knn_graph.knn_indices, knn_graph.knn_dists, rhos.data(), + sigmas.data(), params, params->n_neighbors, adjusted_local_connectivity, + d_alloc, stream); ML::POP_RANGE(); /** * Compute graph of membership strengths */ - int nnz = n * params->n_neighbors; + int nnz = inputs.n * params->n_neighbors; dim3 grid_nnz(raft::ceildiv(nnz, TPB_X), 1, 1); @@ -357,36 +366,33 @@ void _transform(const raft::handle_t &handle, T *X, int n, int d, * Allocate workspace for fuzzy simplicial set. */ - COO graph_coo(d_alloc, stream, nnz, n, n); + COO graph_coo(d_alloc, stream, nnz, inputs.n, inputs.n); FuzzySimplSetImpl::compute_membership_strength_kernel - <<>>(knn_indices, knn_dists, sigmas.data(), - rhos.data(), graph_coo.vals(), + <<>>(knn_graph.knn_indices, knn_graph.knn_dists, + sigmas.data(), rhos.data(), graph_coo.vals(), graph_coo.rows(), graph_coo.cols(), graph_coo.n_rows, params->n_neighbors); CUDA_CHECK(cudaPeekAtLastError()); - if (knn_indices_b) delete knn_indices_b; - if (knn_dists_b) delete knn_dists_b; - - MLCommon::device_buffer row_ind(d_alloc, stream, n); - MLCommon::device_buffer ia(d_alloc, stream, n); + MLCommon::device_buffer row_ind(d_alloc, stream, inputs.n); + MLCommon::device_buffer ia(d_alloc, stream, inputs.n); MLCommon::Sparse::sorted_coo_to_csr(&graph_coo, row_ind.data(), d_alloc, stream); MLCommon::Sparse::coo_row_count(&graph_coo, ia.data(), stream); - MLCommon::device_buffer vals_normed(d_alloc, stream, graph_coo.nnz); - CUDA_CHECK( - cudaMemsetAsync(vals_normed.data(), 0, graph_coo.nnz * sizeof(T), stream)); + MLCommon::device_buffer vals_normed(d_alloc, stream, graph_coo.nnz); + CUDA_CHECK(cudaMemsetAsync(vals_normed.data(), 0, + graph_coo.nnz * sizeof(value_t), stream)); CUML_LOG_DEBUG("Performing L1 normalization"); - MLCommon::Sparse::csr_row_normalize_l1( + MLCommon::Sparse::csr_row_normalize_l1( row_ind.data(), graph_coo.vals(), graph_coo.nnz, graph_coo.n_rows, vals_normed.data(), stream); - init_transform<<>>( + init_transform<<>>( graph_coo.cols(), vals_normed.data(), graph_coo.n_rows, embedding, embedding_n, params->n_components, transformed, params->n_neighbors); CUDA_CHECK(cudaPeekAtLastError()); @@ -399,13 +405,14 @@ void _transform(const raft::handle_t &handle, T *X, int n, int d, * Go through COO values and set everything that's less than * vals.max() / params->n_epochs to 0.0 */ - thrust::device_ptr d_ptr = thrust::device_pointer_cast(graph_coo.vals()); - T max = + thrust::device_ptr d_ptr = + thrust::device_pointer_cast(graph_coo.vals()); + value_t max = *(thrust::max_element(thrust::cuda::par.on(stream), d_ptr, d_ptr + nnz)); int n_epochs = params->n_epochs; if (n_epochs <= 0) { - if (n <= 10000) + if (inputs.n <= 10000) n_epochs = 100; else n_epochs = 30; @@ -415,9 +422,9 @@ void _transform(const raft::handle_t &handle, T *X, int n, int d, CUML_LOG_DEBUG("n_epochs=%d", n_epochs); - raft::linalg::unaryOp( + raft::linalg::unaryOp( graph_coo.vals(), graph_coo.vals(), graph_coo.nnz, - [=] __device__(T input) { + [=] __device__(value_t input) { if (input < (max / float(n_epochs))) return 0.0f; else @@ -430,14 +437,14 @@ void _transform(const raft::handle_t &handle, T *X, int n, int d, /** * Remove zeros */ - MLCommon::Sparse::COO comp_coo(d_alloc, stream); - MLCommon::Sparse::coo_remove_zeros(&graph_coo, &comp_coo, d_alloc, - stream); + MLCommon::Sparse::COO comp_coo(d_alloc, stream); + MLCommon::Sparse::coo_remove_zeros(&graph_coo, &comp_coo, + d_alloc, stream); ML::PUSH_RANGE("umap::optimization"); CUML_LOG_DEBUG("Computing # of epochs for training each sample"); - MLCommon::device_buffer epochs_per_sample(d_alloc, stream, nnz); + MLCommon::device_buffer epochs_per_sample(d_alloc, stream, nnz); SimplSetEmbedImpl::make_epochs_per_sample( comp_coo.vals(), comp_coo.nnz, n_epochs, epochs_per_sample.data(), stream); @@ -445,17 +452,17 @@ void _transform(const raft::handle_t &handle, T *X, int n, int d, CUML_LOG_DEBUG("Performing optimization"); if (params->callback) { - params->callback->setup(n, params->n_components); + params->callback->setup(inputs.n, params->n_components); params->callback->on_preprocess_end(transformed); } params->initial_alpha /= 4.0; // TODO: This value should be passed into "optimize layout" directly to avoid side-effects. - SimplSetEmbedImpl::optimize_layout( - transformed, n, embedding, embedding_n, comp_coo.rows(), comp_coo.cols(), - comp_coo.nnz, epochs_per_sample.data(), n, params->repulsion_strength, - params, n_epochs, d_alloc, stream); + SimplSetEmbedImpl::optimize_layout( + transformed, inputs.n, embedding, embedding_n, comp_coo.rows(), + comp_coo.cols(), comp_coo.nnz, epochs_per_sample.data(), inputs.n, + params->repulsion_strength, params, n_epochs, d_alloc, stream); ML::POP_RANGE(); if (params->callback) params->callback->on_train_end(transformed); diff --git a/cpp/src/umap/simpl_set_embed/runner.cuh b/cpp/src/umap/simpl_set_embed/runner.cuh index 4fad6955ce..c8b95b0842 100644 --- a/cpp/src/umap/simpl_set_embed/runner.cuh +++ b/cpp/src/umap/simpl_set_embed/runner.cuh @@ -28,10 +28,9 @@ namespace SimplSetEmbed { using namespace ML; template -void run(const T *X, int m, int n, MLCommon::Sparse::COO *coo, - UMAPParams *params, T *embedding, - std::shared_ptr alloc, cudaStream_t stream, - int algorithm = 0) { +void run(int m, int n, MLCommon::Sparse::COO *coo, UMAPParams *params, + T *embedding, std::shared_ptr alloc, + cudaStream_t stream, int algorithm = 0) { switch (algorithm) { case 0: SimplSetEmbed::Algo::launcher(m, n, coo, params, embedding, diff --git a/cpp/src/umap/supervised.cuh b/cpp/src/umap/supervised.cuh index 2e455ba528..5fe50c8cf5 100644 --- a/cpp/src/umap/supervised.cuh +++ b/cpp/src/umap/supervised.cuh @@ -97,25 +97,25 @@ void reset_local_connectivity(COO *in_coo, COO *out_coo, * and this will update the fuzzy simplicial set to respect that label * data. */ -template -void categorical_simplicial_set_intersection(COO *graph_coo, T *target, +template +void categorical_simplicial_set_intersection(COO *graph_coo, + value_t *target, cudaStream_t stream, float far_dist = 5.0, float unknown_dist = 1.0) { dim3 grid(raft::ceildiv(graph_coo->nnz, TPB_X), 1, 1); dim3 blk(TPB_X, 1, 1); - fast_intersection_kernel<<>>( + fast_intersection_kernel<<>>( graph_coo->rows(), graph_coo->cols(), graph_coo->vals(), graph_coo->nnz, target, unknown_dist, far_dist); } -template -__global__ void sset_intersection_kernel(int *row_ind1, int *cols1, T *vals1, - int nnz1, int *row_ind2, int *cols2, - T *vals2, int nnz2, int *result_ind, - int *result_cols, T *result_vals, - int nnz, T left_min, T right_min, - int m, float mix_weight = 0.5) { +template +__global__ void sset_intersection_kernel( + int *row_ind1, int *cols1, value_t *vals1, int nnz1, int *row_ind2, + int *cols2, value_t *vals2, int nnz2, int *result_ind, int *result_cols, + value_t *result_vals, int nnz, value_t left_min, value_t right_min, int m, + float mix_weight = 0.5) { int row = (blockIdx.x * TPB_X) + threadIdx.x; if (row < m) { @@ -131,14 +131,14 @@ __global__ void sset_intersection_kernel(int *row_ind1, int *cols1, T *vals1, for (int j = start_idx_res; j < stop_idx_res; j++) { int col = result_cols[j]; - T left_val = left_min; + value_t left_val = left_min; for (int k = start_idx1; k < stop_idx1; k++) { if (cols1[k] == col) { left_val = vals1[k]; } } - T right_val = right_min; + value_t right_val = right_min; for (int k = start_idx2; k < stop_idx2; k++) { if (cols2[k] == col) { right_val = vals2[k]; @@ -185,8 +185,8 @@ void general_simplicial_set_intersection( result->vals(), stream); //@todo: Write a wrapper function for this - MLCommon::Sparse::csr_to_coo(result_ind.data(), result->n_rows, - result->rows(), result->nnz, stream); + MLCommon::Sparse::csr_to_coo(result_ind.data(), result->n_rows, + result->rows(), result->nnz, stream); thrust::device_ptr d_ptr1 = thrust::device_pointer_cast(in1->vals()); T min1 = *(thrust::min_element(thrust::cuda::par.on(stream), d_ptr1, @@ -231,22 +231,29 @@ void perform_categorical_intersection(T *y, COO *rgraph_coo, CUDA_CHECK(cudaPeekAtLastError()); } -template -void perform_general_intersection(const raft::handle_t &handle, T *y, - COO *rgraph_coo, COO *final_coo, - UMAPParams *params, cudaStream_t stream) { +template +void perform_general_intersection(const raft::handle_t &handle, value_t *y, + COO *rgraph_coo, + COO *final_coo, UMAPParams *params, + cudaStream_t stream) { auto d_alloc = handle.get_device_allocator(); /** * Calculate kNN for Y */ int knn_dims = rgraph_coo->n_rows * params->target_n_neighbors; - MLCommon::device_buffer y_knn_indices(d_alloc, stream, knn_dims); - MLCommon::device_buffer y_knn_dists(d_alloc, stream, knn_dims); + MLCommon::device_buffer y_knn_indices(d_alloc, stream, knn_dims); + MLCommon::device_buffer y_knn_dists(d_alloc, stream, knn_dims); - kNNGraph::run(y, rgraph_coo->n_rows, y, rgraph_coo->n_rows, 1, - y_knn_indices.data(), y_knn_dists.data(), - params->target_n_neighbors, params, d_alloc, stream); + knn_graph knn_graph(rgraph_coo->n_rows, + params->target_n_neighbors); + knn_graph.knn_indices = y_knn_indices.data(); + knn_graph.knn_dists = y_knn_dists.data(); + + manifold_dense_inputs_t y_inputs(y, nullptr, rgraph_coo->n_rows, 1); + kNNGraph::run>( + handle, y_inputs, y_inputs, knn_graph, params->target_n_neighbors, params, + d_alloc, stream); CUDA_CHECK(cudaPeekAtLastError()); if (ML::Logger::get().shouldLogFor(CUML_LEVEL_DEBUG)) { @@ -265,11 +272,11 @@ void perform_general_intersection(const raft::handle_t &handle, T *y, /** * Compute fuzzy simplicial set */ - COO ygraph_coo(d_alloc, stream); + COO ygraph_coo(d_alloc, stream); - FuzzySimplSet::run(rgraph_coo->n_rows, y_knn_indices.data(), - y_knn_dists.data(), params->target_n_neighbors, - &ygraph_coo, params, d_alloc, stream); + FuzzySimplSet::run( + rgraph_coo->n_rows, y_knn_indices.data(), y_knn_dists.data(), + params->target_n_neighbors, &ygraph_coo, params, d_alloc, stream); CUDA_CHECK(cudaPeekAtLastError()); if (ML::Logger::get().shouldLogFor(CUML_LEVEL_DEBUG)) { @@ -290,26 +297,26 @@ void perform_general_intersection(const raft::handle_t &handle, T *y, CUDA_CHECK(cudaMemsetAsync(yrow_ind.data(), 0, ygraph_coo.n_rows * sizeof(int), stream)); - COO cygraph_coo(d_alloc, stream); - coo_remove_zeros(&ygraph_coo, &cygraph_coo, d_alloc, stream); + COO cygraph_coo(d_alloc, stream); + coo_remove_zeros(&ygraph_coo, &cygraph_coo, d_alloc, stream); MLCommon::Sparse::sorted_coo_to_csr(&cygraph_coo, yrow_ind.data(), d_alloc, stream); MLCommon::Sparse::sorted_coo_to_csr(rgraph_coo, xrow_ind.data(), d_alloc, stream); - COO result_coo(d_alloc, stream); - general_simplicial_set_intersection( + COO result_coo(d_alloc, stream); + general_simplicial_set_intersection( xrow_ind.data(), rgraph_coo, yrow_ind.data(), &cygraph_coo, &result_coo, params->target_weights, d_alloc, stream); /** * Remove zeros */ - COO out(d_alloc, stream); - coo_remove_zeros(&result_coo, &out, d_alloc, stream); + COO out(d_alloc, stream); + coo_remove_zeros(&result_coo, &out, d_alloc, stream); - reset_local_connectivity(&out, final_coo, d_alloc, stream); + reset_local_connectivity(&out, final_coo, d_alloc, stream); CUDA_CHECK(cudaPeekAtLastError()); } diff --git a/cpp/src/umap/umap.cu b/cpp/src/umap/umap.cu index 425ca93585..06ab72e8bb 100644 --- a/cpp/src/umap/umap.cu +++ b/cpp/src/umap/umap.cu @@ -15,40 +15,110 @@ */ #include +#include #include #include "runner.cuh" +#include + #include namespace ML { +namespace UMAP { -static const int TPB_X = 256; +static const int TPB_X = raft::WarpSize; +// Dense transform void transform(const raft::handle_t &handle, float *X, int n, int d, - int64_t *knn_indices, float *knn_dists, float *orig_X, - int orig_n, float *embedding, int embedding_n, + knn_indices_dense_t *knn_indices, float *knn_dists, + float *orig_X, int orig_n, float *embedding, int embedding_n, UMAPParams *params, float *transformed) { - UMAPAlgo::_transform(handle, X, n, d, knn_indices, knn_dists, - orig_X, orig_n, embedding, embedding_n, - params, transformed); + if (knn_indices != nullptr && knn_dists != nullptr) { + manifold_precomputed_knn_inputs_t inputs( + knn_indices, knn_dists, X, nullptr, n, d, params->n_neighbors); + UMAPAlgo::_transform< + knn_indices_dense_t, float, + manifold_precomputed_knn_inputs_t, TPB_X>( + handle, inputs, inputs, embedding, embedding_n, params, transformed); + } else { + manifold_dense_inputs_t inputs(X, nullptr, n, d); + manifold_dense_inputs_t orig_inputs(orig_X, nullptr, orig_n, d); + UMAPAlgo::_transform, TPB_X>( + handle, inputs, orig_inputs, embedding, embedding_n, params, transformed); + } } + +// Sparse transform +void transform_sparse(const raft::handle_t &handle, int *indptr, int *indices, + float *data, size_t nnz, int n, int d, int *orig_x_indptr, + int *orig_x_indices, float *orig_x_data, size_t orig_nnz, + int orig_n, float *embedding, int embedding_n, + UMAPParams *params, float *transformed) { + manifold_sparse_inputs_t inputs( + indptr, indices, data, nullptr, nnz, n, d); + manifold_sparse_inputs_t orig_x_inputs( + orig_x_indptr, orig_x_indices, orig_x_data, nullptr, orig_nnz, orig_n, d); + + UMAPAlgo::_transform, TPB_X>( + handle, inputs, orig_x_inputs, embedding, embedding_n, params, transformed); +} + +// Dense fit void fit(const raft::handle_t &handle, float *X, // input matrix float *y, // labels - int n, int d, int64_t *knn_indices, float *knn_dists, + int n, int d, knn_indices_dense_t *knn_indices, float *knn_dists, UMAPParams *params, float *embeddings) { - UMAPAlgo::_fit(handle, X, y, n, d, knn_indices, knn_dists, - params, embeddings); + if (knn_indices != nullptr && knn_dists != nullptr) { + CUML_LOG_DEBUG("Calling UMAP::fit() with precomputed KNN"); + + manifold_precomputed_knn_inputs_t inputs( + knn_indices, knn_dists, X, y, n, d, params->n_neighbors); + if (y != nullptr) { + UMAPAlgo::_fit_supervised< + knn_indices_dense_t, float, + manifold_precomputed_knn_inputs_t, TPB_X>( + handle, inputs, params, embeddings); + } else { + UMAPAlgo::_fit< + knn_indices_dense_t, float, + manifold_precomputed_knn_inputs_t, TPB_X>( + handle, inputs, params, embeddings); + } + + } else { + manifold_dense_inputs_t inputs(X, y, n, d); + if (y != nullptr) { + UMAPAlgo::_fit_supervised, TPB_X>( + handle, inputs, params, embeddings); + } else { + UMAPAlgo::_fit, + TPB_X>(handle, inputs, params, embeddings); + } + } } -void fit(const raft::handle_t &handle, - float *X, // input matrix - int n, // rows - int d, // cols - int64_t *knn_indices, float *knn_dists, UMAPParams *params, - float *embeddings) { - UMAPAlgo::_fit(handle, X, n, d, knn_indices, knn_dists, params, - embeddings); +// Sparse fit +void fit_sparse(const raft::handle_t &handle, + int *indptr, // input matrix + int *indices, float *data, size_t nnz, float *y, + int n, // rows + int d, // cols + UMAPParams *params, float *embeddings) { + manifold_sparse_inputs_t inputs(indptr, indices, data, y, nnz, n, + d); + if (y != nullptr) { + UMAPAlgo::_fit_supervised, TPB_X>( + handle, inputs, params, embeddings); + } else { + UMAPAlgo::_fit, TPB_X>( + handle, inputs, params, embeddings); + } } void find_ab(const raft::handle_t &handle, UMAPParams *params) { @@ -56,76 +126,6 @@ void find_ab(const raft::handle_t &handle, UMAPParams *params) { auto d_alloc = handle.get_device_allocator(); UMAPAlgo::find_ab(params, d_alloc, stream); } -UMAP_API::UMAP_API(const raft::handle_t &handle, UMAPParams *params) - : params(params) { - this->handle = const_cast(&handle); - orig_X = nullptr; - orig_n = 0; -}; - -UMAP_API::~UMAP_API() {} - -/** - * Fits a UMAP model - * @param X - * pointer to an array in row-major format (note: this will be col-major soon) - * @param n - * n_samples in X - * @param d - * d_features in X - * @param knn_indices - * an array containing the n_neighbors nearest neighors indices for each sample - * @param knn_dists - * an array containing the n_neighbors nearest neighors distances for each sample - * @param embeddings - * an array to return the output embeddings of size (n_samples, n_components) - */ -void UMAP_API::fit(float *X, int n, int d, int64_t *knn_indices, - float *knn_dists, float *embeddings) { - this->orig_X = X; - this->orig_n = n; - UMAPAlgo::_fit(*this->handle, X, n, d, knn_indices, knn_dists, - get_params(), embeddings); -} - -void UMAP_API::fit(float *X, float *y, int n, int d, int64_t *knn_indices, - float *knn_dists, float *embeddings) { - this->orig_X = X; - this->orig_n = n; - - UMAPAlgo::_fit(*this->handle, X, y, n, d, knn_indices, - knn_dists, get_params(), embeddings); -} -/** - * Project a set of X vectors into the embedding space. - * @param X - * pointer to an array in row-major format (note: this will be col-major soon) - * @param n - * n_samples in X - * @param d - * d_features in X -* @param knn_indices - * an array containing the n_neighbors nearest neighors indices for each sample - * @param knn_dists - * an array containing the n_neighbors nearest neighors distances for each sample - * @param embedding - * pointer to embedding array of size (embedding_n, n_components) that has been created with fit() - * @param embedding_n - * n_samples in embedding array - * @param out - * pointer to array for storing output embeddings (n, n_components) - */ -void UMAP_API::transform(float *X, int n, int d, int64_t *knn_indices, - float *knn_dists, float *embedding, int embedding_n, - float *out) { - UMAPAlgo::_transform(*this->handle, X, n, d, knn_indices, - knn_dists, this->orig_X, this->orig_n, - embedding, embedding_n, get_params(), out); -} - -/** - * Get the UMAPParams instance - */ -UMAPParams *UMAP_API::get_params() { return this->params; } +} // namespace UMAP } // namespace ML diff --git a/cpp/src_prims/distance/distance.cuh b/cpp/src_prims/distance/distance.cuh index bc254ca1b6..0e9504ee73 100644 --- a/cpp/src_prims/distance/distance.cuh +++ b/cpp/src_prims/distance/distance.cuh @@ -302,7 +302,7 @@ void pairwise_distance(const Type *x, const Type *y, Type *dist, Index_ m, x, y, dist, m, n, k, workspace, stream, isRowMajor); break; default: - THROW("Unknown distance metric '%d'!", (int)metric); + THROW("Unknown or unsupported distance metric '%d'!", (int)metric); }; } /** @} */ diff --git a/cpp/src_prims/selection/columnWiseSort.cuh b/cpp/src_prims/selection/columnWiseSort.cuh index b6f62820aa..9497f6c864 100644 --- a/cpp/src_prims/selection/columnWiseSort.cuh +++ b/cpp/src_prims/selection/columnWiseSort.cuh @@ -175,7 +175,7 @@ template void sortColumnsPerRow(const InType *in, OutType *out, int n_rows, int n_columns, bool &bAllocWorkspace, void *workspacePtr, size_t &workspaceSize, cudaStream_t stream, - InType *sortedKeys = nullptr) { + InType *sortedKeys = nullptr, bool ascending = true) { // assume non-square row-major matrices // current use-case: KNN, trustworthiness scores // output : either sorted indices or sorted indices and input values @@ -231,10 +231,17 @@ void sortColumnsPerRow(const InType *in, OutType *out, int n_rows, OutType *tmpValIn = nullptr; int *tmpOffsetBuffer = nullptr; - // first call is to get size of workspace - CUDA_CHECK(cub::DeviceSegmentedRadixSort::SortPairs( - workspacePtr, workspaceSize, in, sortedKeys, tmpValIn, out, - totalElements, numSegments, tmpOffsetBuffer, tmpOffsetBuffer + 1)); + if (ascending) { + // first call is to get size of workspace + CUDA_CHECK(cub::DeviceSegmentedRadixSort::SortPairs( + workspacePtr, workspaceSize, in, sortedKeys, tmpValIn, out, + totalElements, numSegments, tmpOffsetBuffer, tmpOffsetBuffer + 1)); + } else { + // first call is to get size of workspace + CUDA_CHECK(cub::DeviceSegmentedRadixSort::SortPairsDescending( + workspacePtr, workspaceSize, in, sortedKeys, tmpValIn, out, + totalElements, numSegments, tmpOffsetBuffer, tmpOffsetBuffer + 1)); + } bAllocWorkspace = true; // more staging space for temp output of keys if (!sortedKeys) @@ -275,10 +282,17 @@ void sortColumnsPerRow(const InType *in, OutType *out, int n_rows, CUDA_CHECK( layoutSortOffset(dSegmentOffsets, n_columns, numSegments, stream)); - CUDA_CHECK(cub::DeviceSegmentedRadixSort::SortPairs( - workspacePtr, workspaceSize, in, sortedKeys, dValuesIn, out, - totalElements, numSegments, dSegmentOffsets, dSegmentOffsets + 1, 0, - sizeof(InType) * 8, stream)); + if (ascending) { + CUDA_CHECK(cub::DeviceSegmentedRadixSort::SortPairs( + workspacePtr, workspaceSize, in, sortedKeys, dValuesIn, out, + totalElements, numSegments, dSegmentOffsets, dSegmentOffsets + 1, 0, + sizeof(InType) * 8, stream)); + } else { + CUDA_CHECK(cub::DeviceSegmentedRadixSort::SortPairsDescending( + workspacePtr, workspaceSize, in, sortedKeys, dValuesIn, out, + totalElements, numSegments, dSegmentOffsets, dSegmentOffsets + 1, 0, + sizeof(InType) * 8, stream)); + } } } else { // batched per row device wide sort @@ -322,9 +336,15 @@ void sortColumnsPerRow(const InType *in, OutType *out, int n_rows, OutType *rowOut = reinterpret_cast( (size_t)out + (i * sizeof(OutType) * (size_t)n_columns)); - CUDA_CHECK(cub::DeviceRadixSort::SortPairs(workspacePtr, workspaceSize, - rowIn, sortedKeys, dValuesIn, - rowOut, n_columns)); + if (ascending) { + CUDA_CHECK(cub::DeviceRadixSort::SortPairs( + workspacePtr, workspaceSize, rowIn, sortedKeys, dValuesIn, rowOut, + n_columns)); + } else { + CUDA_CHECK(cub::DeviceRadixSort::SortPairsDescending( + workspacePtr, workspaceSize, rowIn, sortedKeys, dValuesIn, rowOut, + n_columns)); + } if (userKeyOutputBuffer) sortedKeys = reinterpret_cast( diff --git a/cpp/src_prims/selection/knn.cuh b/cpp/src_prims/selection/knn.cuh index 1cad85c045..c7ea416911 100644 --- a/cpp/src_prims/selection/knn.cuh +++ b/cpp/src_prims/selection/knn.cuh @@ -29,6 +29,7 @@ #include #include #include +#include #include #include @@ -55,21 +56,24 @@ inline __device__ T get_lbls(const T *labels, const int64_t *knn_indices, } } -template -__global__ void knn_merge_parts_kernel(float *inK, int64_t *inV, float *outK, - int64_t *outV, size_t n_samples, - int n_parts, float initK, int64_t initV, - int k, int64_t *translations) { +template +__global__ void knn_merge_parts_kernel(value_t *inK, value_idx *inV, + value_t *outK, value_idx *outV, + size_t n_samples, int n_parts, + value_t initK, value_idx initV, int k, + value_idx *translations) { constexpr int kNumWarps = tpb / faiss::gpu::kWarpSize; - __shared__ float smemK[kNumWarps * warp_q]; - __shared__ int64_t smemV[kNumWarps * warp_q]; + __shared__ value_t smemK[kNumWarps * warp_q]; + __shared__ value_idx smemV[kNumWarps * warp_q]; /** * Uses shared memory */ - faiss::gpu::BlockSelect, - warp_q, thread_q, tpb> + faiss::gpu::BlockSelect, warp_q, thread_q, + tpb> heap(initK, initV, smemK, smemV, k); // Grid is exactly sized to rows available @@ -84,11 +88,11 @@ __global__ void knn_merge_parts_kernel(float *inK, int64_t *inV, float *outK, int col = i % k; - float *inKStart = inK + (row_idx + col); - int64_t *inVStart = inV + (row_idx + col); + value_t *inKStart = inK + (row_idx + col); + value_idx *inVStart = inV + (row_idx + col); int limit = faiss::gpu::utils::roundDown(total_k, faiss::gpu::kWarpSize); - int64_t translation = 0; + value_idx translation = 0; for (; i < limit; i += tpb) { translation = translations[part]; @@ -117,19 +121,20 @@ __global__ void knn_merge_parts_kernel(float *inK, int64_t *inV, float *outK, } } -template -inline void knn_merge_parts_impl(float *inK, int64_t *inV, float *outK, - int64_t *outV, size_t n_samples, int n_parts, +template +inline void knn_merge_parts_impl(value_t *inK, value_idx *inV, value_t *outK, + value_idx *outV, size_t n_samples, int n_parts, int k, cudaStream_t stream, - int64_t *translations) { + value_idx *translations) { auto grid = dim3(n_samples); constexpr int n_threads = (warp_q <= 1024) ? 128 : 64; auto block = dim3(n_threads); - auto kInit = faiss::gpu::Limits::getMax(); + auto kInit = faiss::gpu::Limits::getMax(); auto vInit = -1; - knn_merge_parts_kernel + knn_merge_parts_kernel <<>>(inK, inV, outK, outV, n_samples, n_parts, kInit, vInit, k, translations); CUDA_CHECK(cudaPeekAtLastError()); @@ -149,30 +154,32 @@ inline void knn_merge_parts_impl(float *inK, int64_t *inV, float *outK, * @param stream CUDA stream to use * @param translations mapping of index offsets for each partition */ -inline void knn_merge_parts(float *inK, int64_t *inV, float *outK, - int64_t *outV, size_t n_samples, int n_parts, int k, - cudaStream_t stream, int64_t *translations) { +template +inline void knn_merge_parts(value_t *inK, value_idx *inV, value_t *outK, + value_idx *outV, size_t n_samples, int n_parts, + int k, cudaStream_t stream, + value_idx *translations) { if (k == 1) - knn_merge_parts_impl<1, 1>(inK, inV, outK, outV, n_samples, n_parts, k, - stream, translations); + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); else if (k <= 32) - knn_merge_parts_impl<32, 2>(inK, inV, outK, outV, n_samples, n_parts, k, - stream, translations); + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); else if (k <= 64) - knn_merge_parts_impl<64, 3>(inK, inV, outK, outV, n_samples, n_parts, k, - stream, translations); + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); else if (k <= 128) - knn_merge_parts_impl<128, 3>(inK, inV, outK, outV, n_samples, n_parts, k, - stream, translations); + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); else if (k <= 256) - knn_merge_parts_impl<256, 4>(inK, inV, outK, outV, n_samples, n_parts, k, - stream, translations); + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); else if (k <= 512) - knn_merge_parts_impl<512, 8>(inK, inV, outK, outV, n_samples, n_parts, k, - stream, translations); + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); else if (k <= 1024) - knn_merge_parts_impl<1024, 8>(inK, inV, outK, outV, n_samples, n_parts, k, - stream, translations); + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); } inline faiss::MetricType build_faiss_metric(ML::MetricType metric) { diff --git a/cpp/src_prims/sparse/csr.cuh b/cpp/src_prims/sparse/csr.cuh index 5751afdf9c..e43bbd850d 100644 --- a/cpp/src_prims/sparse/csr.cuh +++ b/cpp/src_prims/sparse/csr.cuh @@ -16,7 +16,11 @@ #pragma once +#include + +#include #include +#include #include #include