From 10f4327760a62f9823237ba69e977a113c6ddc6d Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 10 Dec 2020 21:48:28 -0600 Subject: [PATCH 01/22] cpp build and tests working --- cpp/include/cuml/manifold/tsne.h | 4 +- cpp/src/tsne/bh_kernels.cuh | 8 +- cpp/src/tsne/distances.cuh | 45 ++++++++--- cpp/src/tsne/tsne.cu | 128 ++++++++++++++++++++++--------- cpp/src_prims/sparse/coo.cuh | 12 +-- 5 files changed, 137 insertions(+), 60 deletions(-) diff --git a/cpp/include/cuml/manifold/tsne.h b/cpp/include/cuml/manifold/tsne.h index e94d1dd4d7..4f734ceda5 100644 --- a/cpp/include/cuml/manifold/tsne.h +++ b/cpp/include/cuml/manifold/tsne.h @@ -71,8 +71,8 @@ namespace ML { * approach is available in their article t-SNE-CUDA: GPU-Accelerated t-SNE and * its Applications to Modern Data (https://arxiv.org/abs/1807.11824). */ -void TSNE_fit(const raft::handle_t &handle, const float *X, float *Y, - const int n, const int p, const int dim = 2, +void TSNE_fit(const raft::handle_t &handle, float *X, float *Y, + int n, int p, const int dim = 2, int n_neighbors = 1023, const float theta = 0.5f, const float epssq = 0.0025, float perplexity = 50.0f, const int perplexity_max_iter = 100, diff --git a/cpp/src/tsne/bh_kernels.cuh b/cpp/src/tsne/bh_kernels.cuh index 3125b8735c..d69d7e912c 100644 --- a/cpp/src/tsne/bh_kernels.cuh +++ b/cpp/src/tsne/bh_kernels.cuh @@ -75,7 +75,7 @@ __global__ void Find_Normalization(float *restrict Z_norm, const float N) { /** * Figures the bounding boxes for every point in the embedding. */ -__global__ __launch_bounds__(THREADS1, FACTOR1) void BoundingBoxKernel( +__global__ __launch_bounds__(THREADS1, 2) void BoundingBoxKernel( int *restrict startd, int *restrict childd, float *restrict massd, float *restrict posxd, float *restrict posyd, float *restrict maxxd, float *restrict maxyd, float *restrict minxd, float *restrict minyd, @@ -174,7 +174,7 @@ __global__ __launch_bounds__(1024, 1) void ClearKernel1(int *restrict childd, * See: https://iss.oden.utexas.edu/Publications/Papers/burtscher11.pdf */ __global__ __launch_bounds__( - THREADS2, FACTOR2) void TreeBuildingKernel(/* int *restrict errd, */ + THREADS2, 2) void TreeBuildingKernel(/* int *restrict errd, */ int *restrict childd, const float *restrict posxd, const float *restrict posyd, @@ -508,7 +508,7 @@ __global__ __launch_bounds__(THREADS4, FACTOR4) void SortKernel( */ __global__ __launch_bounds__( THREADS5, - FACTOR5) void RepulsionKernel(/* int *restrict errd, */ + 1) void RepulsionKernel(/* int *restrict errd, */ const float theta, const float epssqd, // correction for zero distance @@ -664,7 +664,7 @@ __global__ void attractive_kernel_bh( /** * Apply gradient updates. */ -__global__ __launch_bounds__(THREADS6, FACTOR6) void IntegrationKernel( +__global__ __launch_bounds__(THREADS6, 1) void IntegrationKernel( const float eta, const float momentum, const float exaggeration, float *restrict Y1, float *restrict Y2, const float *restrict attract1, const float *restrict attract2, const float *restrict repel1, diff --git a/cpp/src/tsne/distances.cuh b/cpp/src/tsne/distances.cuh index 922b361de9..c22413de14 100644 --- a/cpp/src/tsne/distances.cuh +++ b/cpp/src/tsne/distances.cuh @@ -19,31 +19,37 @@ #include #include #include +#include #include +#include namespace ML { namespace TSNE { /** * @brief Uses FAISS's KNN to find the top n_neighbors. This speeds up the attractive forces. - * @param[in] X: The GPU handle. - * @param[in] n: The number of rows in the data X. - * @param[in] p: The number of columns in the data X. 
+ * @param[in] handle: The GPU handle.
+ * @param[in] input: dense/sparse manifold input
 * @param[out] indices: The output indices from KNN.
 * @param[out] distances: The output sorted distances from KNN.
 * @param[in] n_neighbors: The number of nearest neighbors you want.
- * @param[in] d_alloc: device allocator
 * @param[in] stream: The GPU stream.
 */
-void get_distances(const float *X, const int n, const int p, long *indices,
+template <typename tsne_input, typename knn_value_idx, typename knn_value_t>
+void get_distances(const raft::handle_t &handle, tsne_input &input, knn_value_idx *indices,
+                   knn_value_t *distances, const int n_neighbors,
+                   cudaStream_t stream);
+
+// dense
+template <>
+void get_distances(const raft::handle_t &handle, manifold_dense_inputs_t<float> &input, long *indices,
                    float *distances, const int n_neighbors,
-                   std::shared_ptr<deviceAllocator> d_alloc,
                    cudaStream_t stream) {
   // TODO: for TSNE transform first fit some points then transform with 1/(1+d^2)
   // #861
-  std::vector<float *> input_vec = {const_cast<float *>(X)};
-  std::vector<int> sizes_vec = {n};
+  std::vector<float *> input_vec = {input.X};
+  std::vector<int> sizes_vec = {input.n};
 
   /**
    * std::vector<float *> &input, std::vector<int> &sizes,
 
@@ -53,9 +59,23 @@ void get_distances(const float *X, const int n, const int p, long *indices,
    * cudaStream_t userStream,
   */
 
-  MLCommon::Selection::brute_force_knn(input_vec, sizes_vec, p,
-                                       const_cast<float *>(X), n, indices,
-                                       distances, n_neighbors, d_alloc, stream);
+  MLCommon::Selection::brute_force_knn(input_vec, sizes_vec, input.d,
+                                       input.X, input.n, indices,
+                                       distances, n_neighbors, handle.get_device_allocator(), stream);
+}
+
+// sparse
+template <>
+void get_distances(const raft::handle_t &handle, manifold_sparse_inputs_t<int, float> &input, int *indices,
+                   float *distances, int n_neighbors,
+                   cudaStream_t stream) {
+  MLCommon::Sparse::Selection::brute_force_knn(
+    input.indptr, input.indices, input.data, input.nnz, input.n,
+    input.d, input.indptr, input.indices, input.data, input.nnz,
+    input.n, input.d, indices, distances, n_neighbors,
+    handle.get_cusparse_handle(), handle.get_device_allocator(), stream,
+    ML::Sparse::DEFAULT_BATCH_SIZE, ML::Sparse::DEFAULT_BATCH_SIZE,
+    ML::MetricType::METRIC_L2);
 }
 
 /**
@@ -91,7 +111,8 @@ void normalize_distances(const int n, float *distances, const int n_neighbors,
  * @param[in] stream: The GPU stream.
  * @param[in] handle: The GPU handle.
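+ *
+ * A sketch of the math, mirroring the comment in the body below: the
+ * symmetrized output is early_exaggeration * (P + P.T) / P_sum,
+ * assembled in COO form.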
 */
-void symmetrize_perplexity(float *P, long *indices, const int n, const int k,
+template <typename knn_value_idx>
+void symmetrize_perplexity(float *P, knn_value_idx *indices, const int n, const int k,
                            const float exaggeration,
                            MLCommon::Sparse::COO<float> *COO_Matrix,
                            cudaStream_t stream, const raft::handle_t &handle) {
@@ -100,7 +121,7 @@ void symmetrize_perplexity(float *P, long *indices, const int n, const int k,
   raft::linalg::scalarMultiply(P, P, div, n * k, stream);
 
   // Symmetrize to form P + P.T
-  MLCommon::Sparse::from_knn_symmetrize_matrix(
+  MLCommon::Sparse::from_knn_symmetrize_matrix<knn_value_idx, float>(
     indices, P, n, k, COO_Matrix, stream, handle.get_device_allocator());
 }
 
diff --git a/cpp/src/tsne/tsne.cu b/cpp/src/tsne/tsne.cu
index d1e02269c6..072c0575f5 100644
--- a/cpp/src/tsne/tsne.cu
+++ b/cpp/src/tsne/tsne.cu
@@ -15,9 +15,11 @@
  */
 
 #include
+#include <cuml/manifold/common.hpp>
 #include
 #include
 #include
+#include <rmm/device_uvector.hpp>
 #include "distances.cuh"
 #include "exact_kernels.cuh"
 #include "utils.cuh"
@@ -27,8 +29,8 @@
 
 namespace ML {
 
-void TSNE_fit(const raft::handle_t &handle, const float *X, float *Y,
-              const int n, const int p, const int dim, int n_neighbors,
+template <typename tsne_input, typename knn_value_idx, typename knn_value_t>
+void _fit(const raft::handle_t &handle, tsne_input &input, const int dim, int n_neighbors,
               const float theta, const float epssq, float perplexity,
               const int perplexity_max_iter, const float perplexity_tol,
               const float early_exaggeration, const int exaggeration_iter,
@@ -38,8 +40,11 @@ void TSNE_fit(const raft::handle_t &handle, const float *X, float *Y,
               const float post_momentum, const long long random_state,
               int verbosity, const bool initialize_embeddings,
               bool barnes_hut) {
-  ASSERT(n > 0 && p > 0 && dim > 0 && n_neighbors > 0 && X != NULL && Y != NULL,
-         "Wrong input args");
+
+  auto n = input.n;
+  auto p = input.d;
+  auto *Y = input.y;
+
   ML::Logger::get().setLevel(verbosity);
   if (dim > 2 and barnes_hut) {
     barnes_hut = false;
@@ -73,41 +78,52 @@ void TSNE_fit(const raft::handle_t &handle, const float *X, float *Y,
   //---------------------------------------------------
   // Get distances
   CUML_LOG_DEBUG("Getting distances.");
-  MLCommon::device_buffer<float> distances(d_alloc, stream, n * n_neighbors);
-  MLCommon::device_buffer<long> indices(d_alloc, stream, n * n_neighbors);
-  TSNE::get_distances(X, n, p, indices.data(), distances.data(), n_neighbors,
-                      d_alloc, stream);
-  //---------------------------------------------------
-  END_TIMER(DistancesTime);
-  START_TIMER;
-  //---------------------------------------------------
-  // Normalize distances
-  CUML_LOG_DEBUG("Now normalizing distances so exp(D) doesn't explode.");
-  TSNE::normalize_distances(n, distances.data(), n_neighbors, stream);
-  //---------------------------------------------------
-  END_TIMER(NormalizeTime);
+  MLCommon::Sparse::COO<float> COO_Matrix(d_alloc, stream);
 
-  START_TIMER;
-  //---------------------------------------------------
-  // Optimal perplexity
-  CUML_LOG_DEBUG("Searching for optimal perplexity via bisection search.");
-  MLCommon::device_buffer<float> P(d_alloc, stream, n * n_neighbors);
-  TSNE::perplexity_search(distances.data(), P.data(), perplexity,
-                          perplexity_max_iter, perplexity_tol, n, n_neighbors,
-                          handle);
-  distances.release(stream);
-  //---------------------------------------------------
-  END_TIMER(PerplexityTime);
+  // artificial scope for safe destruction of indices/distances buffers
+  {
+    rmm::device_uvector<knn_value_idx> indices(0, stream);
+    rmm::device_uvector<knn_value_t> distances(0, stream);
+
+    if (input.alloc_knn_graph()) {
+      indices = rmm::device_uvector<knn_value_idx>(n * n_neighbors, stream);
+      distances = rmm::device_uvector<knn_value_t>(n * n_neighbors, stream);
+    }
+
+    TSNE::get_distances(handle, input, indices.data(), distances.data(),
+                        n_neighbors, stream);
+    //---------------------------------------------------
+    END_TIMER(DistancesTime);
+
+    START_TIMER;
+    //---------------------------------------------------
+    // Normalize distances
+    CUML_LOG_DEBUG("Now normalizing distances so exp(D) doesn't explode.");
+    TSNE::normalize_distances(n, distances.data(), n_neighbors, stream);
+    //---------------------------------------------------
+    END_TIMER(NormalizeTime);
+
+    START_TIMER;
+    //---------------------------------------------------
+    // Optimal perplexity
+    CUML_LOG_DEBUG("Searching for optimal perplexity via bisection search.");
+    MLCommon::device_buffer<float> P(d_alloc, stream, n * n_neighbors);
+    TSNE::perplexity_search(distances.data(), P.data(), perplexity,
+                            perplexity_max_iter, perplexity_tol, n, n_neighbors,
+                            handle);
+
+    //---------------------------------------------------
+    END_TIMER(PerplexityTime);
+
+    START_TIMER;
+    //---------------------------------------------------
+    // Convert data to COO layout
+    TSNE::symmetrize_perplexity(P.data(), indices.data(), n, n_neighbors,
+                                early_exaggeration, &COO_Matrix, stream, handle);
+    P.release(stream);
+  }
 
-  START_TIMER;
-  //---------------------------------------------------
-  // Convert data to COO layout
-  MLCommon::Sparse::COO<float> COO_Matrix(d_alloc, stream);
-  TSNE::symmetrize_perplexity(P.data(), indices.data(), n, n_neighbors,
-                              early_exaggeration, &COO_Matrix, stream, handle);
-  P.release(stream);
-  indices.release(stream);
   const int NNZ = COO_Matrix.nnz;
   float *VAL = COO_Matrix.vals();
   const int *COL = COO_Matrix.cols();
   const int *ROW = COO_Matrix.rows();
@@ -129,4 +145,44 @@ void TSNE_fit(const raft::handle_t &handle, const float *X, float *Y,
   }
 }
 
+void TSNE_fit(const raft::handle_t &handle, float *X, float *Y,
+              int n, int p, const int dim, int n_neighbors,
+              const float theta, const float epssq, float perplexity,
+              const int perplexity_max_iter, const float perplexity_tol,
+              const float early_exaggeration, const int exaggeration_iter,
+              const float min_gain, const float pre_learning_rate,
+              const float post_learning_rate, const int max_iter,
+              const float min_grad_norm, const float pre_momentum,
+              const float post_momentum, const long long random_state,
+              int verbosity, const bool initialize_embeddings,
+              bool barnes_hut) {
+  ASSERT(n > 0 && p > 0 && dim > 0 && n_neighbors > 0 && X != NULL && Y != NULL,
+         "Wrong input args");
+
+  manifold_dense_inputs_t<float> input(X, Y, n, p);
+  _fit< manifold_dense_inputs_t<float>, knn_indices_dense_t, float >(handle, input, dim, n_neighbors, theta, epssq, perplexity, perplexity_max_iter, perplexity_tol,
+       early_exaggeration, exaggeration_iter, min_gain, pre_learning_rate, post_learning_rate,
+       max_iter, min_grad_norm, pre_momentum, post_momentum, random_state, verbosity, initialize_embeddings, barnes_hut);
+}
+
+void TSNE_fit_sparse(const raft::handle_t &handle, int *indptr, int *indices, float *data, float *Y,
+                     int nnz, int n, int p, const int dim, int n_neighbors,
+                     const float theta, const float epssq, float perplexity,
+                     const int perplexity_max_iter, const float perplexity_tol,
+                     const float early_exaggeration, const int exaggeration_iter,
+                     const float min_gain, const float pre_learning_rate,
+                     const float post_learning_rate, const int max_iter,
+                     const float min_grad_norm, const float pre_momentum,
+                     const float post_momentum, const long long random_state,
+                     int verbosity, const bool initialize_embeddings,
+                     bool barnes_hut) {
+  ASSERT(n > 0 && p > 0 && dim > 0 && n_neighbors > 0 && indptr != NULL &&
+           indices != NULL && data != NULL && Y != NULL,
+         "Wrong input args");
+
+  manifold_sparse_inputs_t<knn_indices_sparse_t, float> input(indptr, indices, data, Y, nnz, n, p);
+  _fit< manifold_sparse_inputs_t<knn_indices_sparse_t, float>, knn_indices_sparse_t, float >(handle, input, dim, n_neighbors, theta, epssq, perplexity, perplexity_max_iter, perplexity_tol,
+       early_exaggeration, exaggeration_iter, min_gain, pre_learning_rate, post_learning_rate,
+       max_iter, min_grad_norm, pre_momentum, post_momentum, random_state, verbosity, initialize_embeddings, barnes_hut);
+}
+
 } // namespace ML
diff --git a/cpp/src_prims/sparse/coo.cuh b/cpp/src_prims/sparse/coo.cuh
index abe3664e9a..e7b98c5491 100644
--- a/cpp/src_prims/sparse/coo.cuh
+++ b/cpp/src_prims/sparse/coo.cuh
@@ -831,9 +831,9 @@ void coo_symmetrize(COO<T> *in, COO<T> *out,
 * @param row_sizes: Input empty row sum 1 array(n)
 * @param row_sizes2: Input empty row sum 2 array(n) for faster reduction
 */
-template <typename math_t>
+template <typename math_t, typename value_idx>
 __global__ static void symmetric_find_size(const math_t *restrict data,
-                                           const long *restrict indices,
+                                           const value_idx *restrict indices,
                                            const int n, const int k,
                                            int *restrict row_sizes,
                                            int *restrict row_sizes2) {
@@ -879,10 +879,10 @@ __global__ static void reduce_find_size(const int n, const int k,
 * @param n: Number of rows
 * @param k: Number of n_neighbors
 */
-template <typename math_t>
+template <typename math_t, typename value_idx>
 __global__ static void symmetric_sum(int *restrict edges,
                                      const math_t *restrict data,
-                                     const long *restrict indices,
+                                     const value_idx *restrict indices,
                                      math_t *restrict VAL, int *restrict COL,
                                      int *restrict ROW, const int n,
                                      const int k) {
@@ -921,8 +921,8 @@ __global__ static void symmetric_sum(int *restrict edges,
 * @param stream: Input cuda stream
 * @param d_alloc device allocator for temporary buffers
 */
-template <typename math_t>
-void from_knn_symmetrize_matrix(const long *restrict knn_indices,
+template <typename math_t, typename value_idx = int64_t>
+void from_knn_symmetrize_matrix(const value_idx *restrict knn_indices,
                                 const math_t *restrict knn_dists, const int n,
                                 const int k, COO<math_t> *out,
                                 cudaStream_t stream,

From 47fd3032309048fcac80bf6802bad21803beda80 Mon Sep 17 00:00:00 2001
From: divyegala
Date: Thu, 10 Dec 2020 22:55:30 -0600
Subject: [PATCH 02/22] cython bind

---
 cpp/include/cuml/manifold/tsne.h |  68 +++++++++++++++
 python/cuml/manifold/t_sne.pyx   | 139 +++++++++++++++++++++++--------
 2 files changed, 172 insertions(+), 35 deletions(-)

diff --git a/cpp/include/cuml/manifold/tsne.h b/cpp/include/cuml/manifold/tsne.h
index 4f734ceda5..30c4caa38f 100644
--- a/cpp/include/cuml/manifold/tsne.h
+++ b/cpp/include/cuml/manifold/tsne.h
@@ -87,4 +87,72 @@ void TSNE_fit(const raft::handle_t &handle, float *X, float *Y,
               int verbosity = CUML_LEVEL_INFO,
               const bool initialize_embeddings = true,
               bool barnes_hut = true);
 
+/**
+ * @brief Dimensionality reduction via TSNE using either Barnes Hut O(NlogN)
+ *        or brute force O(N^2).
+ *
+ * @param[in] handle              The GPU handle.
+ * @param[in] indptr              indptr of CSR dataset
+ * @param[in] indices             indices of CSR dataset
+ * @param[in] data                data of CSR dataset
+ * @param[out] Y                  The final embedding.
+ * @param[in] n                   Number of rows in data X.
+ * @param[in] p                   Number of columns in data X.
+ * @param[in] dim                 Number of output dimensions for embeddings Y.
+ * @param[in] n_neighbors         Number of nearest neighbors used.
+ * @param[in] theta               Float between 0 and 1. Tradeoff for speed (0)
+ *                                vs accuracy (1) for Barnes Hut only.
+ * @param[in] epssq               A tiny jitter to promote numerical stability.
+ * @param[in] perplexity          How many nearest neighbors are used during
+ *                                construction of Pij.
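+ *                                (Intuitively, perplexity acts as a smooth
+ *                                measure of the effective number of
+ *                                neighbors each point considers.)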
+ * @param[in] perplexity_max_iter Number of iterations used to construct Pij.
+ * @param[in] perplexity_tol      The small tolerance used for Pij to ensure
+ *                                numerical stability.
+ * @param[in] early_exaggeration  How much pressure to apply to clusters to
+ *                                spread out more during the early
+ *                                exaggeration phase.
+ * @param[in] exaggeration_iter   Number of iterations the early pressure
+ *                                runs for.
+ * @param[in] min_gain            Rounds up small gradient updates.
+ * @param[in] pre_learning_rate   The learning rate during the exaggeration
+ *                                phase.
+ * @param[in] post_learning_rate  The learning rate after the exaggeration
+ *                                phase.
+ * @param[in] max_iter            The maximum number of iterations TSNE
+ *                                should run for.
+ * @param[in] min_grad_norm       The smallest gradient norm TSNE should
+ *                                terminate on.
+ * @param[in] pre_momentum        The momentum used during the exaggeration
+ *                                phase.
+ * @param[in] post_momentum       The momentum used after the exaggeration
+ *                                phase.
+ * @param[in] random_state        Set this to -1 for pure random
+ *                                initializations or >= 0 for reproducible
+ *                                outputs.
+ * @param[in] verbosity           verbosity level for logging messages during
+ *                                execution
+ * @param[in] initialize_embeddings Whether to overwrite the current Y vector
+ *                                with random noise.
+ * @param[in] barnes_hut          Whether to use the fast Barnes Hut or use
+ *                                the slower exact version.
+ *
+ * The CUDA implementation is derived from the excellent CannyLabs open source
+ * implementation here: https://github.com/CannyLab/tsne-cuda/. The CannyLabs
+ * code is licensed according to the conditions in
+ * cuml/cpp/src/tsne/cannylabs_tsne_license.txt. A full description of their
+ * approach is available in their article t-SNE-CUDA: GPU-Accelerated t-SNE and
+ * its Applications to Modern Data (https://arxiv.org/abs/1807.11824).
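+ *
+ * All optional parameters default to the same values as in TSNE_fit above,
+ * so the dense and sparse entry points behave identically apart from the
+ * input format.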
+ */ +void TSNE_fit_sparse(const raft::handle_t &handle, int *indptr, int *indices, float *data, float *Y, + int n, int p, const int dim = 2, + int n_neighbors = 1023, const float theta = 0.5f, + const float epssq = 0.0025, float perplexity = 50.0f, + const int perplexity_max_iter = 100, + const float perplexity_tol = 1e-5, + const float early_exaggeration = 12.0f, + const int exaggeration_iter = 250, const float min_gain = 0.01f, + const float pre_learning_rate = 200.0f, + const float post_learning_rate = 500.0f, + const int max_iter = 1000, const float min_grad_norm = 1e-7, + const float pre_momentum = 0.5, const float post_momentum = 0.8, + const long long random_state = -1, + int verbosity = CUML_LEVEL_INFO, + const bool initialize_embeddings = true, bool barnes_hut = true); + } // namespace ML diff --git a/python/cuml/manifold/t_sne.pyx b/python/cuml/manifold/t_sne.pyx index ee21c829f9..b6f8efb398 100644 --- a/python/cuml/manifold/t_sne.pyx +++ b/python/cuml/manifold/t_sne.pyx @@ -25,6 +25,7 @@ import numpy as np import inspect import pandas as pd import warnings +import cupy import cuml.internals from cuml.common.array_descriptor import CumlArrayDescriptor @@ -33,6 +34,8 @@ from cuml.raft.common.handle cimport handle_t import cuml.common.logger as logger from cuml.common.array import CumlArray +from cuml.common.array_sparse import SparseCumlArray +from cuml.common.sparse_utils import is_sparse from cuml.common.doc_utils import generate_docstring from cuml.common import input_to_cuml_array import rmm @@ -45,11 +48,40 @@ cimport cuml.common.cuda cdef extern from "cuml/manifold/tsne.h" namespace "ML" nogil: cdef void TSNE_fit( + handle_t &handle, + float *X, + float *Y, + int n, + int p, + const int dim, + int n_neighbors, + const float theta, + const float epssq, + float perplexity, + const int perplexity_max_iter, + const float perplexity_tol, + const float early_exaggeration, + const int exaggeration_iter, + const float min_gain, + const float pre_learning_rate, + const float post_learning_rate, + const int max_iter, + const float min_grad_norm, + const float pre_momentum, + const float post_momentum, + const long long random_state, + int verbosity, + const bool initialize_embeddings, + bool barnes_hut) except + + + cdef void TSNE_fit_sparse( const handle_t &handle, - const float *X, + int *indptr, + int *indices, + float *data, float *Y, - const int n, - const int p, + int n, + int p, const int dim, int n_neighbors, const float theta, @@ -339,12 +371,20 @@ class TSNE(Base): if len(X.shape) != 2: raise ValueError("data should be two dimensional") - cdef uintptr_t X_ptr - X_m, n, p, dtype = \ - input_to_cuml_array(X, order='C', check_dtype=np.float32, - convert_to_dtype=(np.float32 if convert_dtype - else None)) - X_ptr = X_m.ptr + if is_sparse(X): + + self.X_m = SparseCumlArray(X, convert_to_dtype=cupy.float32, + convert_format=False) + n, p = self.X_m.shape + self.sparse_fit = True + + # Handle dense inputs + else: + self.X_m, n, p, _ = \ + input_to_cuml_array(X, order='C', check_dtype=np.float32, + convert_to_dtype=(np.float32 + if convert_dtype + else None)) if n <= 1: raise ValueError("There needs to be more than 1 sample to build " @@ -390,32 +430,61 @@ class TSNE(Base): cdef long long seed = -1 if self.random_state is not None: seed = self.random_state - - TSNE_fit(handle_[0], - X_ptr, - embed_ptr, - n, - p, - self.n_components, - self.n_neighbors, - self.angle, - self.epssq, - self.perplexity, - self.perplexity_max_iter, - self.perplexity_tol, - self.early_exaggeration, - 
self.exaggeration_iter, - self.min_gain, - self.pre_learning_rate, - self.post_learning_rate, - self.n_iter, - self.min_grad_norm, - self.pre_momentum, - self.post_momentum, - seed, - self.verbose, - True, - (self.method == 'barnes_hut')) + + if self.sparse_fit: + TSNE_fit_sparse(handle_[0], + self.X_m.indptr.ptr, + self.X_m.indices.ptr, + self.X_m.data.ptr, + embed_ptr, + n, + p, + self.n_components, + self.n_neighbors, + self.angle, + self.epssq, + self.perplexity, + self.perplexity_max_iter, + self.perplexity_tol, + self.early_exaggeration, + self.exaggeration_iter, + self.min_gain, + self.pre_learning_rate, + self.post_learning_rate, + self.n_iter, + self.min_grad_norm, + self.pre_momentum, + self.post_momentum, + seed, + self.verbose, + True, + (self.method == 'barnes_hut')) + else: + TSNE_fit(handle_[0], + self.X_m.ptr, + embed_ptr, + n, + p, + self.n_components, + self.n_neighbors, + self.angle, + self.epssq, + self.perplexity, + self.perplexity_max_iter, + self.perplexity_tol, + self.early_exaggeration, + self.exaggeration_iter, + self.min_gain, + self.pre_learning_rate, + self.post_learning_rate, + self.n_iter, + self.min_grad_norm, + self.pre_momentum, + self.post_momentum, + seed, + self.verbose, + True, + (self.method == 'barnes_hut')) # Clean up memory self.embedding_ = Y From 194c4890d04d10be18de54abbb4fb2568822afb8 Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 10 Dec 2020 23:06:17 -0600 Subject: [PATCH 03/22] cython working --- python/cuml/manifold/t_sne.pyx | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/cuml/manifold/t_sne.pyx b/python/cuml/manifold/t_sne.pyx index b6f8efb398..96d9263290 100644 --- a/python/cuml/manifold/t_sne.pyx +++ b/python/cuml/manifold/t_sne.pyx @@ -357,6 +357,8 @@ class TSNE(Base): self.pre_learning_rate = learning_rate self.post_learning_rate = learning_rate * 2 + self.sparse_fit = False + @generate_docstring(convert_dtype_cast='np.float32') def fit(self, X, convert_dtype=True) -> "TSNE": """ From d21fa49783886b4b6cd809e576d493b665d583e8 Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 10 Dec 2020 23:07:02 -0600 Subject: [PATCH 04/22] correcting libcuml++ API --- cpp/include/cuml/manifold/tsne.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/cuml/manifold/tsne.h b/cpp/include/cuml/manifold/tsne.h index 30c4caa38f..ac4723c411 100644 --- a/cpp/include/cuml/manifold/tsne.h +++ b/cpp/include/cuml/manifold/tsne.h @@ -139,7 +139,7 @@ void TSNE_fit(const raft::handle_t &handle, float *X, float *Y, * approach is available in their article t-SNE-CUDA: GPU-Accelerated t-SNE and * its Applications to Modern Data (https://arxiv.org/abs/1807.11824). 
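+ *
+ * A minimal calling sketch (a hypothetical example: it assumes
+ * device-resident CSR buffers and an n x dim output buffer Y, and leaves
+ * every tuning parameter at its default):
+ *
+ *   ML::TSNE_fit_sparse(handle, indptr, indices, data, Y, nnz, n, p);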
*/ -void TSNE_fit_sparse(const raft::handle_t &handle, int *indptr, int *indices, float *data, float *Y, +void TSNE_fit_sparse(const raft::handle_t &handle, int *indptr, int *indices, float *data, float *Y, int nnz, int n, int p, const int dim = 2, int n_neighbors = 1023, const float theta = 0.5f, const float epssq = 0.0025, float perplexity = 50.0f, From 2d44291c2637749eddc2d09356ef94cd1c0a7311 Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 10 Dec 2020 23:07:47 -0600 Subject: [PATCH 05/22] style check --- cpp/include/cuml/manifold/tsne.h | 37 +++++----- cpp/src/tsne/bh_kernels.cuh | 40 +++++------ cpp/src/tsne/distances.cuh | 40 +++++------ cpp/src/tsne/tsne.cu | 112 +++++++++++++++++-------------- cpp/src_prims/sparse/coo.cuh | 3 +- 5 files changed, 118 insertions(+), 114 deletions(-) diff --git a/cpp/include/cuml/manifold/tsne.h b/cpp/include/cuml/manifold/tsne.h index ac4723c411..a708948485 100644 --- a/cpp/include/cuml/manifold/tsne.h +++ b/cpp/include/cuml/manifold/tsne.h @@ -71,11 +71,10 @@ namespace ML { * approach is available in their article t-SNE-CUDA: GPU-Accelerated t-SNE and * its Applications to Modern Data (https://arxiv.org/abs/1807.11824). */ -void TSNE_fit(const raft::handle_t &handle, float *X, float *Y, - int n, int p, const int dim = 2, - int n_neighbors = 1023, const float theta = 0.5f, - const float epssq = 0.0025, float perplexity = 50.0f, - const int perplexity_max_iter = 100, +void TSNE_fit(const raft::handle_t &handle, float *X, float *Y, int n, int p, + const int dim = 2, int n_neighbors = 1023, + const float theta = 0.5f, const float epssq = 0.0025, + float perplexity = 50.0f, const int perplexity_max_iter = 100, const float perplexity_tol = 1e-5, const float early_exaggeration = 12.0f, const int exaggeration_iter = 250, const float min_gain = 0.01f, @@ -139,20 +138,18 @@ void TSNE_fit(const raft::handle_t &handle, float *X, float *Y, * approach is available in their article t-SNE-CUDA: GPU-Accelerated t-SNE and * its Applications to Modern Data (https://arxiv.org/abs/1807.11824). 
*/ -void TSNE_fit_sparse(const raft::handle_t &handle, int *indptr, int *indices, float *data, float *Y, int nnz, - int n, int p, const int dim = 2, - int n_neighbors = 1023, const float theta = 0.5f, - const float epssq = 0.0025, float perplexity = 50.0f, - const int perplexity_max_iter = 100, - const float perplexity_tol = 1e-5, - const float early_exaggeration = 12.0f, - const int exaggeration_iter = 250, const float min_gain = 0.01f, - const float pre_learning_rate = 200.0f, - const float post_learning_rate = 500.0f, - const int max_iter = 1000, const float min_grad_norm = 1e-7, - const float pre_momentum = 0.5, const float post_momentum = 0.8, - const long long random_state = -1, - int verbosity = CUML_LEVEL_INFO, - const bool initialize_embeddings = true, bool barnes_hut = true); +void TSNE_fit_sparse( + const raft::handle_t &handle, int *indptr, int *indices, float *data, + float *Y, int nnz, int n, int p, const int dim = 2, int n_neighbors = 1023, + const float theta = 0.5f, const float epssq = 0.0025, + float perplexity = 50.0f, const int perplexity_max_iter = 100, + const float perplexity_tol = 1e-5, const float early_exaggeration = 12.0f, + const int exaggeration_iter = 250, const float min_gain = 0.01f, + const float pre_learning_rate = 200.0f, + const float post_learning_rate = 500.0f, const int max_iter = 1000, + const float min_grad_norm = 1e-7, const float pre_momentum = 0.5, + const float post_momentum = 0.8, const long long random_state = -1, + int verbosity = CUML_LEVEL_INFO, const bool initialize_embeddings = true, + bool barnes_hut = true); } // namespace ML diff --git a/cpp/src/tsne/bh_kernels.cuh b/cpp/src/tsne/bh_kernels.cuh index d69d7e912c..50271edc92 100644 --- a/cpp/src/tsne/bh_kernels.cuh +++ b/cpp/src/tsne/bh_kernels.cuh @@ -174,14 +174,13 @@ __global__ __launch_bounds__(1024, 1) void ClearKernel1(int *restrict childd, * See: https://iss.oden.utexas.edu/Publications/Papers/burtscher11.pdf */ __global__ __launch_bounds__( - THREADS2, 2) void TreeBuildingKernel(/* int *restrict errd, */ - int *restrict childd, - const float *restrict posxd, - const float *restrict posyd, - const int NNODES, const int N, - int *restrict maxdepthd, - int *restrict bottomd, - const float *restrict radiusd) { + THREADS2, + 2) void TreeBuildingKernel(/* int *restrict errd, */ + int *restrict childd, const float *restrict posxd, + const float *restrict posyd, const int NNODES, + const int N, int *restrict maxdepthd, + int *restrict bottomd, + const float *restrict radiusd) { int j, depth; float x, y, r; float px, py; @@ -509,20 +508,17 @@ __global__ __launch_bounds__(THREADS4, FACTOR4) void SortKernel( __global__ __launch_bounds__( THREADS5, 1) void RepulsionKernel(/* int *restrict errd, */ - const float theta, - const float - epssqd, // correction for zero distance - const int *restrict sortd, - const int *restrict childd, - const float *restrict massd, - const float *restrict posxd, - const float *restrict posyd, - float *restrict velxd, float *restrict velyd, - float *restrict Z_norm, - const float theta_squared, const int NNODES, - const int FOUR_NNODES, const int N, - const float *restrict radiusd_squared, - const int *restrict maxdepthd) { + const float theta, + const float epssqd, // correction for zero distance + const int *restrict sortd, const int *restrict childd, + const float *restrict massd, + const float *restrict posxd, + const float *restrict posyd, float *restrict velxd, + float *restrict velyd, float *restrict Z_norm, + const float theta_squared, const int NNODES, + 
const int FOUR_NNODES, const int N, + const float *restrict radiusd_squared, + const int *restrict maxdepthd) { // Return if max depth is too deep // Not possible since I limited it to 32 // if (maxdepthd[0] > 32) diff --git a/cpp/src/tsne/distances.cuh b/cpp/src/tsne/distances.cuh index c22413de14..b7d7ca39eb 100644 --- a/cpp/src/tsne/distances.cuh +++ b/cpp/src/tsne/distances.cuh @@ -17,11 +17,11 @@ #pragma once #include +#include #include #include -#include #include -#include +#include namespace ML { namespace TSNE { @@ -36,13 +36,14 @@ namespace TSNE { * @param[in] stream: The GPU stream. */ template -void get_distances(const raft::handle_t &handle, tsne_input &input, knn_value_idx *indices, - knn_value_t *distances, const int n_neighbors, - cudaStream_t stream); +void get_distances(const raft::handle_t &handle, tsne_input &input, + knn_value_idx *indices, knn_value_t *distances, + const int n_neighbors, cudaStream_t stream); // dense template <> -void get_distances(const raft::handle_t &handle, manifold_dense_inputs_t &input, long *indices, +void get_distances(const raft::handle_t &handle, + manifold_dense_inputs_t &input, long *indices, float *distances, const int n_neighbors, cudaStream_t stream) { // TODO: for TSNE transform first fit some points then transform with 1/(1+d^2) @@ -59,23 +60,22 @@ void get_distances(const raft::handle_t &handle, manifold_dense_inputs_t cudaStream_t userStream, */ - MLCommon::Selection::brute_force_knn(input_vec, sizes_vec, input.d, - input.X, input.n, indices, - distances, n_neighbors, handle.get_device_allocator(), stream); + MLCommon::Selection::brute_force_knn(input_vec, sizes_vec, input.d, input.X, + input.n, indices, distances, n_neighbors, + handle.get_device_allocator(), stream); } // sparse template <> -void get_distances(const raft::handle_t &handle, manifold_sparse_inputs_t &input, int *indices, - float *distances, int n_neighbors, - cudaStream_t stream) { +void get_distances(const raft::handle_t &handle, + manifold_sparse_inputs_t &input, int *indices, + float *distances, int n_neighbors, cudaStream_t stream) { MLCommon::Sparse::Selection::brute_force_knn( - input.indptr, input.indices, input.data, input.nnz, input.n, - input.d, input.indptr, input.indices, input.data, input.nnz, - input.n, input.d, indices, distances, n_neighbors, - handle.get_cusparse_handle(), handle.get_device_allocator(), stream, - ML::Sparse::DEFAULT_BATCH_SIZE, ML::Sparse::DEFAULT_BATCH_SIZE, - ML::MetricType::METRIC_L2); + input.indptr, input.indices, input.data, input.nnz, input.n, input.d, + input.indptr, input.indices, input.data, input.nnz, input.n, input.d, + indices, distances, n_neighbors, handle.get_cusparse_handle(), + handle.get_device_allocator(), stream, ML::Sparse::DEFAULT_BATCH_SIZE, + ML::Sparse::DEFAULT_BATCH_SIZE, ML::MetricType::METRIC_L2); } /** @@ -112,8 +112,8 @@ void normalize_distances(const int n, float *distances, const int n_neighbors, * @param[in] handle: The GPU handle. 
*/ template -void symmetrize_perplexity(float *P, knn_value_idx *indices, const int n, const int k, - const float exaggeration, +void symmetrize_perplexity(float *P, knn_value_idx *indices, const int n, + const int k, const float exaggeration, MLCommon::Sparse::COO *COO_Matrix, cudaStream_t stream, const raft::handle_t &handle) { // Perform (P + P.T) / P_sum * early_exaggeration diff --git a/cpp/src/tsne/tsne.cu b/cpp/src/tsne/tsne.cu index 072c0575f5..6cb3881150 100644 --- a/cpp/src/tsne/tsne.cu +++ b/cpp/src/tsne/tsne.cu @@ -15,10 +15,10 @@ */ #include -#include #include #include #include +#include #include #include "distances.cuh" #include "exact_kernels.cuh" @@ -30,17 +30,16 @@ namespace ML { template -void _fit(const raft::handle_t &handle, tsne_input &input, const int dim, int n_neighbors, - const float theta, const float epssq, float perplexity, - const int perplexity_max_iter, const float perplexity_tol, - const float early_exaggeration, const int exaggeration_iter, - const float min_gain, const float pre_learning_rate, - const float post_learning_rate, const int max_iter, - const float min_grad_norm, const float pre_momentum, - const float post_momentum, const long long random_state, - int verbosity, const bool initialize_embeddings, - bool barnes_hut) { - +void _fit(const raft::handle_t &handle, tsne_input &input, const int dim, + int n_neighbors, const float theta, const float epssq, + float perplexity, const int perplexity_max_iter, + const float perplexity_tol, const float early_exaggeration, + const int exaggeration_iter, const float min_gain, + const float pre_learning_rate, const float post_learning_rate, + const int max_iter, const float min_grad_norm, + const float pre_momentum, const float post_momentum, + const long long random_state, int verbosity, + const bool initialize_embeddings, bool barnes_hut) { auto n = input.n; auto p = input.d; auto *Y = input.y; @@ -91,8 +90,8 @@ void _fit(const raft::handle_t &handle, tsne_input &input, const int dim, int n_ distances = rmm::device_uvector(n * n_neighbors, stream); } - TSNE::get_distances(handle, input, indices.data(), distances.data(), n_neighbors, - stream); + TSNE::get_distances(handle, input, indices.data(), distances.data(), + n_neighbors, stream); //--------------------------------------------------- END_TIMER(DistancesTime); @@ -120,7 +119,8 @@ void _fit(const raft::handle_t &handle, tsne_input &input, const int dim, int n_ //--------------------------------------------------- // Convert data to COO layout TSNE::symmetrize_perplexity(P.data(), indices.data(), n, n_neighbors, - early_exaggeration, &COO_Matrix, stream, handle); + early_exaggeration, &COO_Matrix, stream, + handle); P.release(stream); } @@ -145,44 +145,54 @@ void _fit(const raft::handle_t &handle, tsne_input &input, const int dim, int n_ } } -void TSNE_fit(const raft::handle_t &handle, float *X, float *Y, - int n, int p, const int dim, int n_neighbors, - const float theta, const float epssq, float perplexity, - const int perplexity_max_iter, const float perplexity_tol, - const float early_exaggeration, const int exaggeration_iter, - const float min_gain, const float pre_learning_rate, - const float post_learning_rate, const int max_iter, - const float min_grad_norm, const float pre_momentum, - const float post_momentum, const long long random_state, - int verbosity, const bool initialize_embeddings, - bool barnes_hut) { - ASSERT(n > 0 && p > 0 && dim > 0 && n_neighbors > 0 && X != NULL && Y != NULL, - "Wrong input args"); - - manifold_dense_inputs_t 
input(X, Y, n, p); - _fit< manifold_dense_inputs_t, knn_indices_dense_t, float >(handle, input, dim, n_neighbors, theta, epssq, perplexity, perplexity_max_iter, perplexity_tol, - early_exaggeration, exaggeration_iter, min_gain, pre_learning_rate, post_learning_rate, - max_iter, min_grad_norm, pre_momentum, post_momentum, random_state, verbosity, initialize_embeddings, barnes_hut); +void TSNE_fit(const raft::handle_t &handle, float *X, float *Y, int n, int p, + const int dim, int n_neighbors, const float theta, + const float epssq, float perplexity, + const int perplexity_max_iter, const float perplexity_tol, + const float early_exaggeration, const int exaggeration_iter, + const float min_gain, const float pre_learning_rate, + const float post_learning_rate, const int max_iter, + const float min_grad_norm, const float pre_momentum, + const float post_momentum, const long long random_state, + int verbosity, const bool initialize_embeddings, + bool barnes_hut) { + ASSERT(n > 0 && p > 0 && dim > 0 && n_neighbors > 0 && X != NULL && Y != NULL, + "Wrong input args"); + + manifold_dense_inputs_t input(X, Y, n, p); + _fit, knn_indices_dense_t, float>( + handle, input, dim, n_neighbors, theta, epssq, perplexity, + perplexity_max_iter, perplexity_tol, early_exaggeration, exaggeration_iter, + min_gain, pre_learning_rate, post_learning_rate, max_iter, min_grad_norm, + pre_momentum, post_momentum, random_state, verbosity, initialize_embeddings, + barnes_hut); } -void TSNE_fit_sparse(const raft::handle_t &handle, int *indptr, int *indices, float *data, float *Y, - int nnz, int n, int p, const int dim, int n_neighbors, - const float theta, const float epssq, float perplexity, - const int perplexity_max_iter, const float perplexity_tol, - const float early_exaggeration, const int exaggeration_iter, - const float min_gain, const float pre_learning_rate, - const float post_learning_rate, const int max_iter, - const float min_grad_norm, const float pre_momentum, - const float post_momentum, const long long random_state, - int verbosity, const bool initialize_embeddings, - bool barnes_hut) { - ASSERT(n > 0 && p > 0 && dim > 0 && n_neighbors > 0 && indptr != NULL && indices != NULL && data != NULL && Y != NULL, - "Wrong input args"); - - manifold_sparse_inputs_t input(indptr, indices, data, Y, nnz, n, p); - _fit< manifold_sparse_inputs_t, knn_indices_sparse_t, float >(handle, input, dim, n_neighbors, theta, epssq, perplexity, perplexity_max_iter, perplexity_tol, - early_exaggeration, exaggeration_iter, min_gain, pre_learning_rate, post_learning_rate, - max_iter, min_grad_norm, pre_momentum, post_momentum, random_state, verbosity, initialize_embeddings, barnes_hut); +void TSNE_fit_sparse(const raft::handle_t &handle, int *indptr, int *indices, + float *data, float *Y, int nnz, int n, int p, + const int dim, int n_neighbors, const float theta, + const float epssq, float perplexity, + const int perplexity_max_iter, const float perplexity_tol, + const float early_exaggeration, + const int exaggeration_iter, const float min_gain, + const float pre_learning_rate, + const float post_learning_rate, const int max_iter, + const float min_grad_norm, const float pre_momentum, + const float post_momentum, const long long random_state, + int verbosity, const bool initialize_embeddings, + bool barnes_hut) { + ASSERT(n > 0 && p > 0 && dim > 0 && n_neighbors > 0 && indptr != NULL && + indices != NULL && data != NULL && Y != NULL, + "Wrong input args"); + + manifold_sparse_inputs_t input(indptr, indices, data, Y, nnz, n, + p); 
+  _fit<manifold_sparse_inputs_t<knn_indices_sparse_t, float>, knn_indices_sparse_t, float>(
+    handle, input, dim, n_neighbors, theta, epssq, perplexity,
+    perplexity_max_iter, perplexity_tol, early_exaggeration, exaggeration_iter,
+    min_gain, pre_learning_rate, post_learning_rate, max_iter, min_grad_norm,
+    pre_momentum, post_momentum, random_state, verbosity, initialize_embeddings,
+    barnes_hut);
 }
 
 } // namespace ML
diff --git a/cpp/src_prims/sparse/coo.cuh b/cpp/src_prims/sparse/coo.cuh
index e7b98c5491..d90bafeaaf 100644
--- a/cpp/src_prims/sparse/coo.cuh
+++ b/cpp/src_prims/sparse/coo.cuh
@@ -921,7 +921,8 @@ __global__ static void symmetric_sum(int *restrict edges,
 * @param stream: Input cuda stream
 * @param d_alloc device allocator for temporary buffers
 */
-template <typename math_t, typename value_idx = int64_t>
+template <typename math_t,
+          typename value_idx = int64_t>
 void from_knn_symmetrize_matrix(const value_idx *restrict knn_indices,
                                 const math_t *restrict knn_dists, const int n,
                                 const int k, COO<math_t> *out,
                                 cudaStream_t stream,

From 82273660eec90c948a2d553e33ecbade43e93f01 Mon Sep 17 00:00:00 2001
From: divyegala
Date: Fri, 11 Dec 2020 00:37:21 -0600
Subject: [PATCH 06/22] sparse test

---
 python/cuml/manifold/t_sne.pyx |  2 ++
 python/cuml/test/test_tsne.py  | 34 ++++++++++++++++++++++++++++++++++
 2 files changed, 36 insertions(+)

diff --git a/python/cuml/manifold/t_sne.pyx b/python/cuml/manifold/t_sne.pyx
index 96d9263290..7448f22bb5 100644
--- a/python/cuml/manifold/t_sne.pyx
+++ b/python/cuml/manifold/t_sne.pyx
@@ -80,6 +80,7 @@ cdef extern from "cuml/manifold/tsne.h" namespace "ML" nogil:
         int *indices,
         float *data,
         float *Y,
+        int nnz,
         int n,
         int p,
@@ -439,6 +440,7 @@ class TSNE(Base):
                             self.X_m.indices.ptr,
                             self.X_m.data.ptr,
                             embed_ptr,
+                            self.X_m.nnz,
                             n,
                             p,
diff --git a/python/cuml/test/test_tsne.py b/python/cuml/test/test_tsne.py
index 76192259b2..03e44298bc 100644
--- a/python/cuml/test/test_tsne.py
+++ b/python/cuml/test/test_tsne.py
@@ -15,6 +15,8 @@
 
 import numpy as np
 import pytest
+import scipy
+import cupyx
 
 from cuml.manifold import TSNE
 from cuml.test.utils import stress_param
@@ -112,3 +114,35 @@ def test_tsne_large(nrows, ncols):
 def test_components_exception():
     with pytest.raises(ValueError):
         TSNE(n_components=3)
+
+@pytest.mark.parametrize('input_type', ['cupy', 'scipy'])
+def test_tsne_transform_on_digits_sparse(input_type):
+
+    digits = datasets.load_digits()
+
+    digits_selection = np.random.RandomState(42).choice(
+        [True, False], 1797, replace=True, p=[0.60, 0.40])
+
+    if input_type == 'cupy':
+        sp_prefix = cupyx.scipy.sparse
+    else:
+        sp_prefix = scipy.sparse
+
+    data = sp_prefix.csr_matrix(
+        scipy.sparse.csr_matrix(digits.data[digits_selection]))
+
+    fitter = TSNE(2, n_neighbors=15,
+                  random_state=1,
+                  learning_rate=500,
+                  angle=0.8)
+
+    new_data = sp_prefix.csr_matrix(
+        scipy.sparse.csr_matrix(digits.data[~digits_selection]))
+
+    embedding = fitter.fit_transform(new_data, convert_dtype=True)
+
+    if input_type == 'cupy':
+        embedding = embedding.get()
+
+    trust = trustworthiness(digits.data[~digits_selection], embedding, 15)
+    assert trust >= 0.85

From 758cc13e0f0d6f07e063d29167ba6855858db47a Mon Sep 17 00:00:00 2001
From: divyegala
Date: Fri, 11 Dec 2020 00:40:36 -0600
Subject: [PATCH 07/22] python style check

---
 python/cuml/test/test_tsne.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/python/cuml/test/test_tsne.py b/python/cuml/test/test_tsne.py
index 03e44298bc..fc529b77d6 100644
--- a/python/cuml/test/test_tsne.py
+++ b/python/cuml/test/test_tsne.py
@@ -115,6 +115,7 @@ def test_components_exception():
     with pytest.raises(ValueError):
         TSNE(n_components=3)
 
+
 @pytest.mark.parametrize('input_type', ['cupy', 'scipy'])
 def test_tsne_transform_on_digits_sparse(input_type):
 
     digits = datasets.load_digits()
 
     digits_selection = np.random.RandomState(42).choice(
         [True, False], 1797, replace=True, p=[0.60, 0.40])
 
     if input_type == 'cupy':
         sp_prefix = cupyx.scipy.sparse
     else:
         sp_prefix = scipy.sparse
 
-    data = sp_prefix.csr_matrix(
-        scipy.sparse.csr_matrix(digits.data[digits_selection]))
-
     fitter = TSNE(2, n_neighbors=15,
                   random_state=1,
                   learning_rate=500,
                   angle=0.8)

From a666f3fc693925f64374acac8a23096464c0fae9 Mon Sep 17 00:00:00 2001
From: divyegala
Date: Fri, 11 Dec 2020 00:47:53 -0600
Subject: [PATCH 08/22] more python style check

---
 python/cuml/manifold/t_sne.pyx | 106 ++++++++++++++++-----------------
 1 file changed, 53 insertions(+), 53 deletions(-)

diff --git a/python/cuml/manifold/t_sne.pyx b/python/cuml/manifold/t_sne.pyx
index 7448f22bb5..ce3966fddf 100644
--- a/python/cuml/manifold/t_sne.pyx
+++ b/python/cuml/manifold/t_sne.pyx
@@ -73,7 +73,7 @@ cdef extern from "cuml/manifold/tsne.h" namespace "ML" nogil:
         int verbosity,
         const bool initialize_embeddings,
         bool barnes_hut) except +
-
+
     cdef void TSNE_fit_sparse(
@@ -433,62 +433,62 @@ class TSNE(Base):
         cdef long long seed = -1
         if self.random_state is not None:
             seed = self.random_state
-
+
         if self.sparse_fit:
             TSNE_fit_sparse(handle_[0],
-                           self.X_m.indptr.ptr,
-                           self.X_m.indices.ptr,
-                           self.X_m.data.ptr,
-                           embed_ptr,
-                           self.X_m.nnz,
-                           n,
-                           p,
-                           self.n_components,
-                           self.n_neighbors,
-                           self.angle,
-                           self.epssq,
-                           self.perplexity,
-                           self.perplexity_max_iter,
-                           self.perplexity_tol,
-                           self.early_exaggeration,
-                           self.exaggeration_iter,
-                           self.min_gain,
-                           self.pre_learning_rate,
-                           self.post_learning_rate,
-                           self.n_iter,
-                           self.min_grad_norm,
-                           self.pre_momentum,
-                           self.post_momentum,
-                           seed,
-                           self.verbose,
-                           True,
-                           (self.method == 'barnes_hut'))
+                            self.X_m.indptr.ptr,
+                            self.X_m.indices.ptr,
+                            self.X_m.data.ptr,
+                            embed_ptr,
+                            self.X_m.nnz,
+                            n,
+                            p,
+                            self.n_components,
+                            self.n_neighbors,
+                            self.angle,
+                            self.epssq,
+                            self.perplexity,
+                            self.perplexity_max_iter,
+                            self.perplexity_tol,
+                            self.early_exaggeration,
+                            self.exaggeration_iter,
+                            self.min_gain,
+                            self.pre_learning_rate,
+                            self.post_learning_rate,
+                            self.n_iter,
+                            self.min_grad_norm,
+                            self.pre_momentum,
+                            self.post_momentum,
+                            seed,
+                            self.verbose,
+                            True,
+                            (self.method == 'barnes_hut'))
         else:
             TSNE_fit(handle_[0],
-                    self.X_m.ptr,
-                    embed_ptr,
-                    n,
-                    p,
-                    self.n_components,
-                    self.n_neighbors,
-                    self.angle,
-                    self.epssq,
-                    self.perplexity,
-                    self.perplexity_max_iter,
-                    self.perplexity_tol,
-                    self.early_exaggeration,
-                    self.exaggeration_iter,
-                    self.min_gain,
-                    self.pre_learning_rate,
-                    self.post_learning_rate,
-                    self.n_iter,
-                    self.min_grad_norm,
-                    self.pre_momentum,
-                    self.post_momentum,
-                    seed,
-                    self.verbose,
-                    True,
-                    (self.method == 'barnes_hut'))
+                     self.X_m.ptr,
+                     embed_ptr,
+                     n,
+                     p,
+                     self.n_components,
+                     self.n_neighbors,
+                     self.angle,
+                     self.epssq,
+                     self.perplexity,
+                     self.perplexity_max_iter,
+                     self.perplexity_tol,
+                     self.early_exaggeration,
+                     self.exaggeration_iter,
+                     self.min_gain,
+                     self.pre_learning_rate,
+                     self.post_learning_rate,
+                     self.n_iter,
+                     self.min_grad_norm,
+                     self.pre_momentum,
+                     self.post_momentum,
+                     seed,
+                     self.verbose,
+                     True,
+                     (self.method == 'barnes_hut'))
 
         # Clean up memory
         self.embedding_ = Y

From 48f88f5e39517bf492cbc243a56b67a6c945da9a Mon Sep 17 00:00:00 2001
From: divyegala
Date: Fri, 11 Dec 2020 00:53:44 -0600
Subject: [PATCH 09/22] more style check...
--- python/cuml/manifold/t_sne.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/cuml/manifold/t_sne.pyx b/python/cuml/manifold/t_sne.pyx index ce3966fddf..6b11cbbb2e 100644 --- a/python/cuml/manifold/t_sne.pyx +++ b/python/cuml/manifold/t_sne.pyx @@ -440,7 +440,7 @@ class TSNE(Base): self.X_m.indices.ptr, self.X_m.data.ptr, embed_ptr, - self.X_m.nnz, + self.X_m.nnz, n, p, self.n_components, From af4aa64c03fa103099f3a5017eb6e60b42220c0e Mon Sep 17 00:00:00 2001 From: divyegala Date: Fri, 11 Dec 2020 13:22:07 -0600 Subject: [PATCH 10/22] adding class runner --- cpp/src/tsne/tsne.cu | 121 ++------------------- cpp/src/tsne/tsne_runner.cuh | 202 +++++++++++++++++++++++++++++++++++ 2 files changed, 210 insertions(+), 113 deletions(-) create mode 100644 cpp/src/tsne/tsne_runner.cuh diff --git a/cpp/src/tsne/tsne.cu b/cpp/src/tsne/tsne.cu index 6cb3881150..d7cfce7c41 100644 --- a/cpp/src/tsne/tsne.cu +++ b/cpp/src/tsne/tsne.cu @@ -15,17 +15,7 @@ */ #include -#include -#include -#include -#include -#include -#include "distances.cuh" -#include "exact_kernels.cuh" -#include "utils.cuh" - -#include "barnes_hut.cuh" -#include "exact_tsne.cuh" +#include "tsne_runner.cuh" namespace ML { @@ -40,109 +30,14 @@ void _fit(const raft::handle_t &handle, tsne_input &input, const int dim, const float pre_momentum, const float post_momentum, const long long random_state, int verbosity, const bool initialize_embeddings, bool barnes_hut) { - auto n = input.n; - auto p = input.d; - auto *Y = input.y; - - ML::Logger::get().setLevel(verbosity); - if (dim > 2 and barnes_hut) { - barnes_hut = false; - CUML_LOG_WARN( - "Barnes Hut only works for dim == 2. Switching to exact solution."); - } - if (n_neighbors > n) n_neighbors = n; - if (n_neighbors > 1023) { - CUML_LOG_WARN("FAISS only supports maximum n_neighbors = 1023."); - n_neighbors = 1023; - } - // Perplexity must be less than number of datapoints - // "How to Use t-SNE Effectively" https://distill.pub/2016/misread-tsne/ - if (perplexity > n) perplexity = n; - - CUML_LOG_DEBUG("Data size = (%d, %d) with dim = %d perplexity = %f", n, p, - dim, perplexity); - if (perplexity < 5 or perplexity > 50) - CUML_LOG_WARN( - "Perplexity should be within ranges (5, 50). Your results might be a" - " bit strange..."); - if (n_neighbors < perplexity * 3.0f) - CUML_LOG_WARN( - "# of Nearest Neighbors should be at least 3 * perplexity. 
Your results" - " might be a bit strange..."); - - auto d_alloc = handle.get_device_allocator(); - cudaStream_t stream = handle.get_stream(); - - START_TIMER; - //--------------------------------------------------- - // Get distances - CUML_LOG_DEBUG("Getting distances."); - - MLCommon::Sparse::COO COO_Matrix(d_alloc, stream); - - // artificial scope for safe destruction of indices/distances buffers - { - rmm::device_uvector indices(0, stream); - rmm::device_uvector distances(0, stream); - - if (input.alloc_knn_graph()) { - indices = rmm::device_uvector(n * n_neighbors, stream); - distances = rmm::device_uvector(n * n_neighbors, stream); - } - - TSNE::get_distances(handle, input, indices.data(), distances.data(), - n_neighbors, stream); - //--------------------------------------------------- - END_TIMER(DistancesTime); - - START_TIMER; - //--------------------------------------------------- - // Normalize distances - CUML_LOG_DEBUG("Now normalizing distances so exp(D) doesn't explode."); - TSNE::normalize_distances(n, distances.data(), n_neighbors, stream); - //--------------------------------------------------- - END_TIMER(NormalizeTime); - - START_TIMER; - //--------------------------------------------------- - // Optimal perplexity - CUML_LOG_DEBUG("Searching for optimal perplexity via bisection search."); - MLCommon::device_buffer P(d_alloc, stream, n * n_neighbors); - TSNE::perplexity_search(distances.data(), P.data(), perplexity, - perplexity_max_iter, perplexity_tol, n, n_neighbors, - handle); - - //--------------------------------------------------- - END_TIMER(PerplexityTime); - - START_TIMER; - //--------------------------------------------------- - // Convert data to COO layout - TSNE::symmetrize_perplexity(P.data(), indices.data(), n, n_neighbors, - early_exaggeration, &COO_Matrix, stream, - handle); - P.release(stream); - } - - const int NNZ = COO_Matrix.nnz; - float *VAL = COO_Matrix.vals(); - const int *COL = COO_Matrix.cols(); - const int *ROW = COO_Matrix.rows(); - //--------------------------------------------------- - END_TIMER(SymmetrizeTime); + TSNE_runner runner( + handle, input, dim, n_neighbors, theta, epssq, perplexity, + perplexity_max_iter, perplexity_tol, early_exaggeration, exaggeration_iter, + min_gain, pre_learning_rate, post_learning_rate, max_iter, min_grad_norm, + pre_momentum, post_momentum, random_state, verbosity, initialize_embeddings, + barnes_hut); - if (barnes_hut) { - TSNE::Barnes_Hut(VAL, COL, ROW, NNZ, handle, Y, n, theta, epssq, - early_exaggeration, exaggeration_iter, min_gain, - pre_learning_rate, post_learning_rate, max_iter, - min_grad_norm, pre_momentum, post_momentum, random_state, - initialize_embeddings); - } else { - TSNE::Exact_TSNE(VAL, COL, ROW, NNZ, handle, Y, n, dim, early_exaggeration, - exaggeration_iter, min_gain, pre_learning_rate, - post_learning_rate, max_iter, min_grad_norm, pre_momentum, - post_momentum, random_state, initialize_embeddings); - } + runner.run(); } void TSNE_fit(const raft::handle_t &handle, float *X, float *Y, int n, int p, diff --git a/cpp/src/tsne/tsne_runner.cuh b/cpp/src/tsne/tsne_runner.cuh new file mode 100644 index 0000000000..c0bcb42a34 --- /dev/null +++ b/cpp/src/tsne/tsne_runner.cuh @@ -0,0 +1,202 @@ +/* + * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include "distances.cuh" +#include "exact_kernels.cuh" +#include "utils.cuh" + +#include "barnes_hut.cuh" +#include "exact_tsne.cuh" + +namespace ML { + +template +class TSNE_runner { + public: + TSNE_runner(const raft::handle_t &handle_, tsne_input &input_, const int dim_, + int n_neighbors_, const float theta_, const float epssq_, + float perplexity_, const int perplexity_max_iter_, + const float perplexity_tol_, const float early_exaggeration_, + const int exaggeration_iter_, const float min_gain_, + const float pre_learning_rate_, const float post_learning_rate_, + const int max_iter_, const float min_grad_norm_, + const float pre_momentum_, const float post_momentum_, + const long long random_state_, int verbosity_, + const bool initialize_embeddings_, bool barnes_hut_) + : handle(handle_), + input(input_), + dim(dim_), + n_neighbors(n_neighbors_), + theta(theta_), + epssq(epssq_), + perplexity(perplexity_), + perplexity_max_iter(perplexity_max_iter_), + perplexity_tol(perplexity_tol_), + early_exaggeration(early_exaggeration_), + exaggeration_iter(exaggeration_iter_), + min_gain(min_gain_), + pre_learning_rate(pre_learning_rate_), + post_learning_rate(post_learning_rate_), + max_iter(max_iter_), + min_grad_norm(min_grad_norm_), + pre_momentum(pre_momentum_), + post_momentum(post_momentum_), + random_state(random_state_), + verbosity(verbosity_), + initialize_embeddings(initialize_embeddings_), + barnes_hut(barnes_hut_), + COO_Matrix(handle_.get_device_allocator(), handle_.get_stream()) { + this->n = input.n; + this->p = input.d; + this->Y = input.y; + + ML::Logger::get().setLevel(verbosity); + if (dim > 2 and barnes_hut) { + barnes_hut = false; + CUML_LOG_WARN( + "Barnes Hut only works for dim == 2. Switching to exact solution."); + } + if (n_neighbors > n) n_neighbors = n; + if (n_neighbors > 1023) { + CUML_LOG_WARN("FAISS only supports maximum n_neighbors = 1023."); + n_neighbors = 1023; + } + // Perplexity must be less than number of datapoints + // "How to Use t-SNE Effectively" https://distill.pub/2016/misread-tsne/ + if (perplexity > n) perplexity = n; + + CUML_LOG_DEBUG("Data size = (%d, %d) with dim = %d perplexity = %f", n, p, + dim, perplexity); + if (perplexity < 5 or perplexity > 50) + CUML_LOG_WARN( + "Perplexity should be within ranges (5, 50). Your results might be a" + " bit strange..."); + if (n_neighbors < perplexity * 3.0f) + CUML_LOG_WARN( + "# of Nearest Neighbors should be at least 3 * perplexity. 
Your results" + " might be a bit strange..."); + } + + void run() { + distance_and_perplexity(); + + const int NNZ = COO_Matrix.nnz; + float *VAL = COO_Matrix.vals(); + const int *COL = COO_Matrix.cols(); + const int *ROW = COO_Matrix.rows(); + //--------------------------------------------------- + + if (barnes_hut) { + TSNE::Barnes_Hut(VAL, COL, ROW, NNZ, handle, Y, n, theta, epssq, + early_exaggeration, exaggeration_iter, min_gain, + pre_learning_rate, post_learning_rate, max_iter, + min_grad_norm, pre_momentum, post_momentum, random_state, + initialize_embeddings); + } else { + TSNE::Exact_TSNE(VAL, COL, ROW, NNZ, handle, Y, n, dim, + early_exaggeration, exaggeration_iter, min_gain, + pre_learning_rate, post_learning_rate, max_iter, + min_grad_norm, pre_momentum, post_momentum, random_state, + initialize_embeddings); + } + } + + private: + void distance_and_perplexity() { + START_TIMER; + + //--------------------------------------------------- + // Get distances + CUML_LOG_DEBUG("Getting distances."); + + auto stream = handle.get_stream(); + + rmm::device_uvector indices(0, stream); + rmm::device_uvector distances(0, stream); + + if (input.alloc_knn_graph()) { + indices = rmm::device_uvector(n * n_neighbors, stream); + distances = rmm::device_uvector(n * n_neighbors, stream); + } + + TSNE::get_distances(handle, input, indices.data(), distances.data(), + n_neighbors, stream); + //--------------------------------------------------- + END_TIMER(DistancesTime); + + START_TIMER; + //--------------------------------------------------- + // Normalize distances + CUML_LOG_DEBUG("Now normalizing distances so exp(D) doesn't explode."); + TSNE::normalize_distances(n, distances.data(), n_neighbors, stream); + //--------------------------------------------------- + END_TIMER(NormalizeTime); + + START_TIMER; + //--------------------------------------------------- + // Optimal perplexity + CUML_LOG_DEBUG("Searching for optimal perplexity via bisection search."); + rmm::device_uvector P(n * n_neighbors, stream); + TSNE::perplexity_search(distances.data(), P.data(), perplexity, + perplexity_max_iter, perplexity_tol, n, n_neighbors, + handle); + + //--------------------------------------------------- + END_TIMER(PerplexityTime); + + START_TIMER; + //--------------------------------------------------- + // Convert data to COO layout + TSNE::symmetrize_perplexity(P.data(), indices.data(), n, n_neighbors, + early_exaggeration, &COO_Matrix, stream, + handle); + END_TIMER(SymmetrizeTime); + } + + const raft::handle_t &handle; + tsne_input &input; + const int dim; + int n_neighbors; + const float theta; + const float epssq; + float perplexity; + const int perplexity_max_iter; + const float perplexity_tol; + const float early_exaggeration; + const int exaggeration_iter; + const float min_gain; + const float pre_learning_rate; + const float post_learning_rate; + const int max_iter; + const float min_grad_norm; + const float pre_momentum; + const float post_momentum; + const long long random_state; + int verbosity; + const bool initialize_embeddings; + bool barnes_hut; + + MLCommon::Sparse::COO COO_Matrix; + int n, p; + float *Y; +}; + +} // namespace ML \ No newline at end of file From 24bb192e7825e587f3e662fbc6ca75d67e69eb43 Mon Sep 17 00:00:00 2001 From: divyegala Date: Fri, 11 Dec 2020 13:38:21 -0600 Subject: [PATCH 11/22] correcting doxygen --- cpp/include/cuml/manifold/tsne.h | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/cpp/include/cuml/manifold/tsne.h 
b/cpp/include/cuml/manifold/tsne.h index a708948485..7c9b015829 100644 --- a/cpp/include/cuml/manifold/tsne.h +++ b/cpp/include/cuml/manifold/tsne.h @@ -91,10 +91,11 @@ void TSNE_fit(const raft::handle_t &handle, float *X, float *Y, int n, int p, * or brute force O(N^2). * * @param[in] handle The GPU handle. - * @param[in] indptr indptr of CSR dataset - * @param[in] indices indices of CSR dataset - * @param[in] data data of CSR dataset + * @param[in] indptr indptr of CSR dataset. + * @param[in] indices indices of CSR dataset. + * @param[in] data data of CSR dataset. * @param[out] Y The final embedding. + * @param[in] nnz The number of non-zero entries in the CSR. * @param[in] n Number of rows in data X. * @param[in] p Number of columns in data X. * @param[in] dim Number of output dimensions for embeddings Y. From 5ddd4dbd14a55473f8615bbd35c672c3dd30f835 Mon Sep 17 00:00:00 2001 From: divyegala Date: Sat, 12 Dec 2020 14:31:59 -0600 Subject: [PATCH 12/22] addressing some review changes for math_t -> value_t --- cpp/src/tsne/distances.cuh | 2 +- cpp/src_prims/sparse/coo.cuh | 17 ++++++++--------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/cpp/src/tsne/distances.cuh b/cpp/src/tsne/distances.cuh index b7d7ca39eb..c96896a8b8 100644 --- a/cpp/src/tsne/distances.cuh +++ b/cpp/src/tsne/distances.cuh @@ -121,7 +121,7 @@ void symmetrize_perplexity(float *P, knn_value_idx *indices, const int n, raft::linalg::scalarMultiply(P, P, div, n * k, stream); // Symmetrize to form P + P.T - MLCommon::Sparse::from_knn_symmetrize_matrix( + MLCommon::Sparse::from_knn_symmetrize_matrix( indices, P, n, k, COO_Matrix, stream, handle.get_device_allocator()); } diff --git a/cpp/src_prims/sparse/coo.cuh b/cpp/src_prims/sparse/coo.cuh index d90bafeaaf..3de874a37c 100644 --- a/cpp/src_prims/sparse/coo.cuh +++ b/cpp/src_prims/sparse/coo.cuh @@ -831,8 +831,8 @@ void coo_symmetrize(COO *in, COO *out, * @param row_sizes: Input empty row sum 1 array(n) * @param row_sizes2: Input empty row sum 2 array(n) for faster reduction */ -template -__global__ static void symmetric_find_size(const math_t *restrict data, +template +__global__ static void symmetric_find_size(const value_t *restrict data, const value_idx *restrict indices, const int n, const int k, int *restrict row_sizes, @@ -879,11 +879,11 @@ __global__ static void reduce_find_size(const int n, const int k, * @param n: Number of rows * @param k: Number of n_neighbors */ -template +template __global__ static void symmetric_sum(int *restrict edges, - const math_t *restrict data, + const value_t *restrict data, const value_idx *restrict indices, - math_t *restrict VAL, int *restrict COL, + value_t *restrict VAL, int *restrict COL, int *restrict ROW, const int n, const int k) { const int row = blockIdx.x * blockDim.x + threadIdx.x; // for every row @@ -921,11 +921,10 @@ __global__ static void symmetric_sum(int *restrict edges, * @param stream: Input cuda stream * @param d_alloc device allocator for temporary buffers */ -template +template void from_knn_symmetrize_matrix(const value_idx *restrict knn_indices, - const math_t *restrict knn_dists, const int n, - const int k, COO *out, + const value_t *restrict knn_dists, const int n, + const int k, COO *out, cudaStream_t stream, std::shared_ptr d_alloc) { // (1) Find how much space needed in each row From 7f6291213456445a678bf4debdd4c4d4114673f0 Mon Sep 17 00:00:00 2001 From: divyegala Date: Sun, 13 Dec 2020 17:33:50 -0600 Subject: [PATCH 13/22] style check --- cpp/include/cuml/manifold/tsne.h | 21 +++++----- 
cpp/src/tsne/distances.cuh | 20 ++++++---- cpp/src/tsne/tsne.cu | 67 +++++++++++++++----------------- cpp/src/tsne/tsne_runner.cuh | 20 +++++----- cpp/test/sg/tsne_test.cu | 7 ++-- python/cuml/test/test_tsne.py | 3 +- 6 files changed, 70 insertions(+), 68 deletions(-) diff --git a/cpp/include/cuml/manifold/tsne.h b/cpp/include/cuml/manifold/tsne.h index c907ab2fb9..f2d085ceca 100644 --- a/cpp/include/cuml/manifold/tsne.h +++ b/cpp/include/cuml/manifold/tsne.h @@ -74,10 +74,10 @@ namespace ML { * its Applications to Modern Data (https://arxiv.org/abs/1807.11824). */ void TSNE_fit(const raft::handle_t &handle, float *X, float *Y, int n, int p, - int64_t *knn_indices, float *knn_dists, - const int dim = 2, int n_neighbors = 1023, - const float theta = 0.5f, const float epssq = 0.0025, - float perplexity = 50.0f, const int perplexity_max_iter = 100, + int64_t *knn_indices, float *knn_dists, const int dim = 2, + int n_neighbors = 1023, const float theta = 0.5f, + const float epssq = 0.0025, float perplexity = 50.0f, + const int perplexity_max_iter = 100, const float perplexity_tol = 1e-5, const float early_exaggeration = 12.0f, const int exaggeration_iter = 250, const float min_gain = 0.01f, @@ -144,13 +144,12 @@ void TSNE_fit(const raft::handle_t &handle, float *X, float *Y, int n, int p, */ void TSNE_fit_sparse( const raft::handle_t &handle, int *indptr, int *indices, float *data, - float *Y, int nnz, int n, int p, int *knn_indices, float *knn_dists, - const int dim = 2, int n_neighbors = 1023, - const float theta = 0.5f, const float epssq = 0.0025, - float perplexity = 50.0f, const int perplexity_max_iter = 100, - const float perplexity_tol = 1e-5, const float early_exaggeration = 12.0f, - const int exaggeration_iter = 250, const float min_gain = 0.01f, - const float pre_learning_rate = 200.0f, + float *Y, int nnz, int n, int p, int *knn_indices, float *knn_dists, + const int dim = 2, int n_neighbors = 1023, const float theta = 0.5f, + const float epssq = 0.0025, float perplexity = 50.0f, + const int perplexity_max_iter = 100, const float perplexity_tol = 1e-5, + const float early_exaggeration = 12.0f, const int exaggeration_iter = 250, + const float min_gain = 0.01f, const float pre_learning_rate = 200.0f, const float post_learning_rate = 500.0f, const int max_iter = 1000, const float min_grad_norm = 1e-7, const float pre_momentum = 0.5, const float post_momentum = 0.8, const long long random_state = -1, diff --git a/cpp/src/tsne/distances.cuh b/cpp/src/tsne/distances.cuh index 6990d5433d..912eed9274 100644 --- a/cpp/src/tsne/distances.cuh +++ b/cpp/src/tsne/distances.cuh @@ -45,7 +45,8 @@ namespace TSNE { // dense template void get_distances(const raft::handle_t &handle, - manifold_dense_inputs_t &input, knn_graph &k_graph, + manifold_dense_inputs_t &input, + knn_graph &k_graph, cudaStream_t stream) { // TODO: for TSNE transform first fit some points then transform with 1/(1+d^2) // #861 @@ -62,21 +63,24 @@ void get_distances(const raft::handle_t &handle, */ MLCommon::Selection::brute_force_knn(input_vec, sizes_vec, input.d, input.X, - input.n, k_graph.knn_indices, k_graph.knn_dists, k_graph.n_neighbors, + input.n, k_graph.knn_indices, + k_graph.knn_dists, k_graph.n_neighbors, handle.get_device_allocator(), stream); } // sparse template void get_distances(const raft::handle_t &handle, - manifold_sparse_inputs_t &input, knn_graph &k_graph, + manifold_sparse_inputs_t &input, + knn_graph &k_graph, cudaStream_t stream) { MLCommon::Sparse::Selection::brute_force_knn( input.indptr, 
input.indices, input.data, input.nnz, input.n, input.d, input.indptr, input.indices, input.data, input.nnz, input.n, input.d, - k_graph.knn_indices, k_graph.knn_dists, k_graph.n_neighbors, handle.get_cusparse_handle(), - handle.get_device_allocator(), stream, ML::Sparse::DEFAULT_BATCH_SIZE, - ML::Sparse::DEFAULT_BATCH_SIZE, ML::MetricType::METRIC_L2); + k_graph.knn_indices, k_graph.knn_dists, k_graph.n_neighbors, + handle.get_cusparse_handle(), handle.get_device_allocator(), stream, + ML::Sparse::DEFAULT_BATCH_SIZE, ML::Sparse::DEFAULT_BATCH_SIZE, + ML::MetricType::METRIC_L2); } /** @@ -113,8 +117,8 @@ void normalize_distances(const int n, float *distances, const int n_neighbors, * @param[in] handle: The GPU handle. */ template -void symmetrize_perplexity(float *P, value_idx *indices, const int n, const int k, - const float exaggeration, +void symmetrize_perplexity(float *P, value_idx *indices, const int n, + const int k, const float exaggeration, MLCommon::Sparse::COO *COO_Matrix, cudaStream_t stream, const raft::handle_t &handle) { // Perform (P + P.T) / P_sum * early_exaggeration diff --git a/cpp/src/tsne/tsne.cu b/cpp/src/tsne/tsne.cu index 8c769b6308..474fca507d 100644 --- a/cpp/src/tsne/tsne.cu +++ b/cpp/src/tsne/tsne.cu @@ -20,21 +20,20 @@ namespace ML { template -void _fit(const raft::handle_t &handle, tsne_input &input, +void _fit(const raft::handle_t &handle, tsne_input &input, knn_graph &k_graph, const int dim, - const float theta, const float epssq, - float perplexity, const int perplexity_max_iter, - const float perplexity_tol, const float early_exaggeration, - const int exaggeration_iter, const float min_gain, - const float pre_learning_rate, const float post_learning_rate, - const int max_iter, const float min_grad_norm, - const float pre_momentum, const float post_momentum, - const long long random_state, int verbosity, - const bool initialize_embeddings, bool barnes_hut) { + const float theta, const float epssq, float perplexity, + const int perplexity_max_iter, const float perplexity_tol, + const float early_exaggeration, const int exaggeration_iter, + const float min_gain, const float pre_learning_rate, + const float post_learning_rate, const int max_iter, + const float min_grad_norm, const float pre_momentum, + const float post_momentum, const long long random_state, + int verbosity, const bool initialize_embeddings, bool barnes_hut) { TSNE_runner runner( - handle, input, k_graph, dim, theta, epssq, perplexity, - perplexity_max_iter, perplexity_tol, early_exaggeration, exaggeration_iter, - min_gain, pre_learning_rate, post_learning_rate, max_iter, min_grad_norm, + handle, input, k_graph, dim, theta, epssq, perplexity, perplexity_max_iter, + perplexity_tol, early_exaggeration, exaggeration_iter, min_gain, + pre_learning_rate, post_learning_rate, max_iter, min_grad_norm, pre_momentum, post_momentum, random_state, verbosity, initialize_embeddings, barnes_hut); @@ -42,17 +41,16 @@ void _fit(const raft::handle_t &handle, tsne_input &input, } void TSNE_fit(const raft::handle_t &handle, float *X, float *Y, int n, int p, - int64_t *knn_indices, float *knn_dists, - const int dim, int n_neighbors, const float theta, - const float epssq, float perplexity, - const int perplexity_max_iter, const float perplexity_tol, - const float early_exaggeration, const int exaggeration_iter, - const float min_gain, const float pre_learning_rate, - const float post_learning_rate, const int max_iter, - const float min_grad_norm, const float pre_momentum, - const float post_momentum, const long 
long random_state, - int verbosity, const bool initialize_embeddings, - bool barnes_hut) { + int64_t *knn_indices, float *knn_dists, const int dim, + int n_neighbors, const float theta, const float epssq, + float perplexity, const int perplexity_max_iter, + const float perplexity_tol, const float early_exaggeration, + const int exaggeration_iter, const float min_gain, + const float pre_learning_rate, const float post_learning_rate, + const int max_iter, const float min_grad_norm, + const float pre_momentum, const float post_momentum, + const long long random_state, int verbosity, + const bool initialize_embeddings, bool barnes_hut) { ASSERT(n > 0 && p > 0 && dim > 0 && n_neighbors > 0 && X != NULL && Y != NULL, "Wrong input args"); @@ -60,20 +58,19 @@ void TSNE_fit(const raft::handle_t &handle, float *X, float *Y, int n, int p, knn_graph k_graph(n, n_neighbors, knn_indices, knn_dists); _fit, knn_indices_dense_t, float>( - handle, input, k_graph, dim, theta, epssq, perplexity, - perplexity_max_iter, perplexity_tol, early_exaggeration, exaggeration_iter, - min_gain, pre_learning_rate, post_learning_rate, max_iter, min_grad_norm, + handle, input, k_graph, dim, theta, epssq, perplexity, perplexity_max_iter, + perplexity_tol, early_exaggeration, exaggeration_iter, min_gain, + pre_learning_rate, post_learning_rate, max_iter, min_grad_norm, pre_momentum, post_momentum, random_state, verbosity, initialize_embeddings, barnes_hut); } void TSNE_fit_sparse(const raft::handle_t &handle, int *indptr, int *indices, float *data, float *Y, int nnz, int n, int p, - int *knn_indices, float *knn_dists, - const int dim, int n_neighbors, const float theta, - const float epssq, float perplexity, - const int perplexity_max_iter, const float perplexity_tol, - const float early_exaggeration, + int *knn_indices, float *knn_dists, const int dim, + int n_neighbors, const float theta, const float epssq, + float perplexity, const int perplexity_max_iter, + const float perplexity_tol, const float early_exaggeration, const int exaggeration_iter, const float min_gain, const float pre_learning_rate, const float post_learning_rate, const int max_iter, @@ -90,9 +87,9 @@ void TSNE_fit_sparse(const raft::handle_t &handle, int *indptr, int *indices, knn_graph k_graph(n, n_neighbors, knn_indices, knn_dists); _fit, knn_indices_sparse_t, float>( - handle, input, k_graph, dim, theta, epssq, perplexity, - perplexity_max_iter, perplexity_tol, early_exaggeration, exaggeration_iter, - min_gain, pre_learning_rate, post_learning_rate, max_iter, min_grad_norm, + handle, input, k_graph, dim, theta, epssq, perplexity, perplexity_max_iter, + perplexity_tol, early_exaggeration, exaggeration_iter, min_gain, + pre_learning_rate, post_learning_rate, max_iter, min_grad_norm, pre_momentum, post_momentum, random_state, verbosity, initialize_embeddings, barnes_hut); } diff --git a/cpp/src/tsne/tsne_runner.cuh b/cpp/src/tsne/tsne_runner.cuh index adbe754f24..79a5bd1589 100644 --- a/cpp/src/tsne/tsne_runner.cuh +++ b/cpp/src/tsne/tsne_runner.cuh @@ -33,15 +33,15 @@ class TSNE_runner { public: TSNE_runner(const raft::handle_t &handle_, tsne_input &input_, knn_graph &k_graph_, const int dim_, - const float theta_, const float epssq_, - float perplexity_, const int perplexity_max_iter_, - const float perplexity_tol_, const float early_exaggeration_, - const int exaggeration_iter_, const float min_gain_, - const float pre_learning_rate_, const float post_learning_rate_, - const int max_iter_, const float min_grad_norm_, - const float pre_momentum_, const 
float post_momentum_,
-              const long long random_state_, int verbosity_,
-              const bool initialize_embeddings_, bool barnes_hut_)
+              const float theta_, const float epssq_, float perplexity_,
+              const int perplexity_max_iter_, const float perplexity_tol_,
+              const float early_exaggeration_, const int exaggeration_iter_,
+              const float min_gain_, const float pre_learning_rate_,
+              const float post_learning_rate_, const int max_iter_,
+              const float min_grad_norm_, const float pre_momentum_,
+              const float post_momentum_, const long long random_state_,
+              int verbosity_, const bool initialize_embeddings_,
+              bool barnes_hut_)
     : handle(handle_),
       input(input_),
       k_graph(k_graph_),
diff --git a/cpp/test/sg/tsne_test.cu b/cpp/test/sg/tsne_test.cu
index 201ca483ac..fd9aac1331 100644
--- a/cpp/test/sg/tsne_test.cu
+++ b/cpp/test/sg/tsne_test.cu
@@ -135,13 +135,14 @@ class TSNETest : public ::testing::Test {
     CUDA_CHECK(cudaStreamSynchronize(handle.get_stream()));
 
     manifold_dense_inputs_t input(X_d.data(), Y_d.data(), n, p);
-    knn_graph k_graph(n, 90, knn_indices.data(), knn_dists.data());
+    knn_graph k_graph(n, 90, knn_indices.data(),
+                      knn_dists.data());
 
-    TSNE::get_distances(handle, input, k_graph, handle.get_stream());
+    TSNE::get_distances(handle, input, k_graph,
+                        handle.get_stream());
     CUDA_CHECK(cudaStreamSynchronize(handle.get_stream()));
 
-    // Test Barnes Hut
     TSNE_fit(handle, X_d.data(), Y_d.data(), n, p, knn_indices.data(),
              knn_dists.data(), 2, 90, 0.5, 0.0025, 50, 100, 1e-5, 12, 250, 0.01,
diff --git a/python/cuml/test/test_tsne.py b/python/cuml/test/test_tsne.py
index c6b30d4faa..730a218fb8 100644
--- a/python/cuml/test/test_tsne.py
+++ b/python/cuml/test/test_tsne.py
@@ -211,6 +211,7 @@ def test_tsne_transform_on_digits_sparse(input_type):
     trust = trustworthiness(digits.data[~digits_selection], embedding, 15)
     assert trust >= 0.85
 
+
 @pytest.mark.parametrize('type_knn_graph', ['sklearn', 'cuml'])
 @pytest.mark.parametrize('input_type', ['cupy', 'scipy'])
 def test_tsne_knn_parameters_sparse(type_knn_graph, input_type):
@@ -223,7 +224,7 @@ def test_tsne_knn_parameters_sparse(type_knn_graph, input_type):
 
     digits_selection = np.random.RandomState(42).choice(
         [True, False], 1797, replace=True, p=[0.60, 0.40])
-    
+
     selected_digits = digits.data[~digits_selection]
 
     neigh.fit(selected_digits)
From d58fb88acda32664c3cfe11d927a8980f80694c3 Mon Sep 17 00:00:00 2001
From: divyegala 
Date: Sun, 13 Dec 2020 17:40:18 -0600
Subject: [PATCH 14/22] doxygen

---
 cpp/include/cuml/manifold/tsne.h | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/cpp/include/cuml/manifold/tsne.h b/cpp/include/cuml/manifold/tsne.h
index f2d085ceca..3722a87ccb 100644
--- a/cpp/include/cuml/manifold/tsne.h
+++ b/cpp/include/cuml/manifold/tsne.h
@@ -101,6 +101,8 @@ void TSNE_fit(const raft::handle_t &handle, float *X, float *Y, int n, int p,
  * @param[in] nnz The number of non-zero entries in the CSR.
  * @param[in] n Number of rows in data X.
  * @param[in] p Number of columns in data X.
+ * @param[in] knn_indices Array containing nearest neighbors indices.
+ * @param[in] knn_dists Array containing nearest neighbors distances. 
* @param[in] dim Number of output dimensions for embeddings Y.
 * @param[in] n_neighbors Number of nearest neighbors used.
 * @param[in] theta Float between 0 and 1. Tradeoff for speed (0)
From 7755ce6be818ce48b7eb85c34093c018781e85ef Mon Sep 17 00:00:00 2001
From: divyegala 
Date: Mon, 14 Dec 2020 15:58:15 -0600
Subject: [PATCH 15/22] explicit templates for knn

---
 cpp/src/tsne/distances.cuh   | 39 +++++++++++++++++++++++++++---------
 cpp/src/tsne/tsne_runner.cuh |  2 +-
 cpp/test/sg/tsne_test.cu     |  2 +-
 3 files changed, 31 insertions(+), 12 deletions(-)

diff --git a/cpp/src/tsne/distances.cuh b/cpp/src/tsne/distances.cuh
index 912eed9274..4d84721298 100644
--- a/cpp/src/tsne/distances.cuh
+++ b/cpp/src/tsne/distances.cuh
@@ -25,6 +25,8 @@
 
 #include 
 
+#include 
+
 namespace ML {
 namespace TSNE {
 
@@ -37,16 +39,15 @@ namespace TSNE {
  * @param[in] d_alloc: device allocator
 * @param[in] stream: The GPU stream.
 */
-// template 
-// void get_distances(const raft::handle_t &handle, tsne_input &input,
-//                    knn_value_idx *indices, knn_value_t *distances,
-//                    const int n_neighbors, cudaStream_t stream);
+template 
+void get_distances(const raft::handle_t &handle, tsne_input &input,
+                   knn_graph &k_graph, cudaStream_t stream);

-// dense
-template 
+// dense, int64 indices
+template <>
 void get_distances(const raft::handle_t &handle,
                    manifold_dense_inputs_t &input,
-                   knn_graph &k_graph,
+                   knn_graph &k_graph,
                    cudaStream_t stream) {
   // TODO: for TSNE transform first fit some points then transform with 1/(1+d^2)
   // #861
@@ -68,11 +69,20 @@ void get_distances(const raft::handle_t &handle,
                    handle.get_device_allocator(), stream);
 }

-// sparse
-template 
+// dense, int32 indices
+template <>
+void get_distances(const raft::handle_t &handle,
+                   manifold_dense_inputs_t &input,
+                   knn_graph &k_graph,
+                   cudaStream_t stream) {
+  throw raft::exception("Dense TSNE does not support 32-bit integer indices yet.");
+}
+
+// sparse, int32
+template <>
 void get_distances(const raft::handle_t &handle,
                    manifold_sparse_inputs_t &input,
-                   knn_graph &k_graph,
+                   knn_graph &k_graph,
                    cudaStream_t stream) {
   MLCommon::Sparse::Selection::brute_force_knn(
     input.indptr, input.indices, input.data, input.nnz, input.n, input.d,
@@ -83,6 +93,15 @@ void get_distances(const raft::handle_t &handle,
     ML::MetricType::METRIC_L2);
 }

+// sparse, int64
+template <>
+void get_distances(const raft::handle_t &handle,
+                   manifold_sparse_inputs_t &input,
+                   knn_graph &k_graph,
+                   cudaStream_t stream) {
+  throw raft::exception("Sparse TSNE does not support 64-bit integer indices yet.");
+}
+
 /**
 * @brief Find the maximum element in the distances matrix, then divide all entries by this.
 * This promotes exp(distances) to not explode. 
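// A minimal sketch of how the explicit specializations above get selected at
// a call site (hypothetical device pointers; this mirrors the dense path
// exercised in tsne_test.cu, so only the buffer names are assumptions):
//
//   manifold_dense_inputs_t input(X, Y, n, p);
//   knn_graph k_graph(n, n_neighbors, knn_indices, knn_dists);
//   TSNE::get_distances(handle, input, k_graph, handle.get_stream());
//
// With int64 dense indices this dispatches to FAISS brute-force KNN; the
// unsupported dense/int32 and sparse/int64 combinations throw instead.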
diff --git a/cpp/src/tsne/tsne_runner.cuh b/cpp/src/tsne/tsne_runner.cuh index 79a5bd1589..9d67bda155 100644 --- a/cpp/src/tsne/tsne_runner.cuh +++ b/cpp/src/tsne/tsne_runner.cuh @@ -144,7 +144,7 @@ class TSNE_runner { k_graph.knn_indices = indices.data(); k_graph.knn_dists = distances.data(); - TSNE::get_distances(handle, input, k_graph, stream); + TSNE::get_distances(handle, input, k_graph, stream); } //--------------------------------------------------- diff --git a/cpp/test/sg/tsne_test.cu b/cpp/test/sg/tsne_test.cu index fd9aac1331..57dd7f94d8 100644 --- a/cpp/test/sg/tsne_test.cu +++ b/cpp/test/sg/tsne_test.cu @@ -138,7 +138,7 @@ class TSNETest : public ::testing::Test { knn_graph k_graph(n, 90, knn_indices.data(), knn_dists.data()); - TSNE::get_distances(handle, input, k_graph, + TSNE::get_distances(handle, input, k_graph, handle.get_stream()); CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); From c576c60a4956608fb95df8cd904af35981afe259 Mon Sep 17 00:00:00 2001 From: divyegala Date: Mon, 14 Dec 2020 16:31:03 -0600 Subject: [PATCH 16/22] runner and distances in templates --- cpp/src/tsne/distances.cuh | 19 +++++------ cpp/src/tsne/tsne.cu | 14 ++++---- cpp/src/tsne/tsne_runner.cuh | 62 ++++++++++++++++++------------------ cpp/src_prims/sparse/coo.cuh | 24 +++++++------- 4 files changed, 60 insertions(+), 59 deletions(-) diff --git a/cpp/src/tsne/distances.cuh b/cpp/src/tsne/distances.cuh index 4d84721298..5a33b6896f 100644 --- a/cpp/src/tsne/distances.cuh +++ b/cpp/src/tsne/distances.cuh @@ -110,16 +110,17 @@ void get_distances(const raft::handle_t &handle, * @param[in] n_neighbors: The number of nearest neighbors you want. * @param[in] stream: The GPU stream. */ -void normalize_distances(const int n, float *distances, const int n_neighbors, +template +void normalize_distances(const value_idx n, value_t *distances, const int n_neighbors, cudaStream_t stream) { // Now D / max(abs(D)) to allow exp(D) to not explode - thrust::device_ptr begin = thrust::device_pointer_cast(distances); - float maxNorm = *thrust::max_element(thrust::cuda::par.on(stream), begin, + thrust::device_ptr begin = thrust::device_pointer_cast(distances); + value_t maxNorm = *thrust::max_element(thrust::cuda::par.on(stream), begin, begin + n * n_neighbors); if (maxNorm == 0.0f) maxNorm = 1.0f; // Divide distances inplace by max - const float div = 1.0f / maxNorm; // Mult faster than div + const value_t div = 1.0f / maxNorm; // Mult faster than div raft::linalg::scalarMultiply(distances, distances, div, n * n_neighbors, stream); } @@ -135,13 +136,13 @@ void normalize_distances(const int n, float *distances, const int n_neighbors, * @param[in] stream: The GPU stream. * @param[in] handle: The GPU handle. 
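 *
 * A brief sketch of the math applied below (illustrative): P is first scaled
 * in place by exaggeration / (2 * n) and then symmetrized, so each entry of
 * the output COO matrix holds exaggeration * (P(i,j) + P(j,i)) / (2 * n).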
*/ -template -void symmetrize_perplexity(float *P, value_idx *indices, const int n, - const int k, const float exaggeration, - MLCommon::Sparse::COO *COO_Matrix, +template +void symmetrize_perplexity(float *P, value_idx *indices, const value_idx n, + const int k, const value_t exaggeration, + MLCommon::Sparse::COO *COO_Matrix, cudaStream_t stream, const raft::handle_t &handle) { // Perform (P + P.T) / P_sum * early_exaggeration - const float div = exaggeration / (2.0f * n); + const value_t div = exaggeration / (2.0f * n); raft::linalg::scalarMultiply(P, P, div, n * k, stream); // Symmetrize to form P + P.T diff --git a/cpp/src/tsne/tsne.cu b/cpp/src/tsne/tsne.cu index 474fca507d..05ce85b0dc 100644 --- a/cpp/src/tsne/tsne.cu +++ b/cpp/src/tsne/tsne.cu @@ -22,13 +22,13 @@ namespace ML { template void _fit(const raft::handle_t &handle, tsne_input &input, knn_graph &k_graph, const int dim, - const float theta, const float epssq, float perplexity, - const int perplexity_max_iter, const float perplexity_tol, - const float early_exaggeration, const int exaggeration_iter, - const float min_gain, const float pre_learning_rate, - const float post_learning_rate, const int max_iter, - const float min_grad_norm, const float pre_momentum, - const float post_momentum, const long long random_state, + const value_t theta, const value_t epssq, value_t perplexity, + const int perplexity_max_iter, const value_t perplexity_tol, + const value_t early_exaggeration, const int exaggeration_iter, + const value_t min_gain, const value_t pre_learning_rate, + const value_t post_learning_rate, const int max_iter, + const value_t min_grad_norm, const value_t pre_momentum, + const value_t post_momentum, const long long random_state, int verbosity, const bool initialize_embeddings, bool barnes_hut) { TSNE_runner runner( handle, input, k_graph, dim, theta, epssq, perplexity, perplexity_max_iter, diff --git a/cpp/src/tsne/tsne_runner.cuh b/cpp/src/tsne/tsne_runner.cuh index 9d67bda155..e0675e48e9 100644 --- a/cpp/src/tsne/tsne_runner.cuh +++ b/cpp/src/tsne/tsne_runner.cuh @@ -100,25 +100,25 @@ class TSNE_runner { void run() { distance_and_perplexity(); - const int NNZ = COO_Matrix.nnz; - float *VAL = COO_Matrix.vals(); - const int *COL = COO_Matrix.cols(); - const int *ROW = COO_Matrix.rows(); + // const auto NNZ = COO_Matrix.nnz; + // auto *VAL = COO_Matrix.vals(); + // const auto *COL = COO_Matrix.cols(); + // const auto *ROW = COO_Matrix.rows(); //--------------------------------------------------- - if (barnes_hut) { - TSNE::Barnes_Hut(VAL, COL, ROW, NNZ, handle, Y, n, theta, epssq, - early_exaggeration, exaggeration_iter, min_gain, - pre_learning_rate, post_learning_rate, max_iter, - min_grad_norm, pre_momentum, post_momentum, random_state, - initialize_embeddings); - } else { - TSNE::Exact_TSNE(VAL, COL, ROW, NNZ, handle, Y, n, dim, - early_exaggeration, exaggeration_iter, min_gain, - pre_learning_rate, post_learning_rate, max_iter, - min_grad_norm, pre_momentum, post_momentum, random_state, - initialize_embeddings); - } + // if (barnes_hut) { + // TSNE::Barnes_Hut(VAL, COL, ROW, NNZ, handle, Y, n, theta, epssq, + // early_exaggeration, exaggeration_iter, min_gain, + // pre_learning_rate, post_learning_rate, max_iter, + // min_grad_norm, pre_momentum, post_momentum, random_state, + // initialize_embeddings); + // } else { + // TSNE::Exact_TSNE(VAL, COL, ROW, NNZ, handle, Y, n, dim, + // early_exaggeration, exaggeration_iter, min_gain, + // pre_learning_rate, post_learning_rate, max_iter, + // min_grad_norm, 
pre_momentum, post_momentum, random_state, + // initialize_embeddings); + // } } private: @@ -184,28 +184,28 @@ class TSNE_runner { knn_graph &k_graph; const int dim; int n_neighbors; - const float theta; - const float epssq; - float perplexity; + const value_t theta; + const value_t epssq; + value_t perplexity; const int perplexity_max_iter; - const float perplexity_tol; - const float early_exaggeration; + const value_t perplexity_tol; + const value_t early_exaggeration; const int exaggeration_iter; - const float min_gain; - const float pre_learning_rate; - const float post_learning_rate; + const value_t min_gain; + const value_t pre_learning_rate; + const value_t post_learning_rate; const int max_iter; - const float min_grad_norm; - const float pre_momentum; - const float post_momentum; + const value_t min_grad_norm; + const value_t pre_momentum; + const value_t post_momentum; const long long random_state; int verbosity; const bool initialize_embeddings; bool barnes_hut; - MLCommon::Sparse::COO COO_Matrix; - int n, p; - float *Y; + MLCommon::Sparse::COO COO_Matrix; + value_idx n, p; + value_t *Y; }; } // namespace ML \ No newline at end of file diff --git a/cpp/src_prims/sparse/coo.cuh b/cpp/src_prims/sparse/coo.cuh index 3de874a37c..d21efe18f5 100644 --- a/cpp/src_prims/sparse/coo.cuh +++ b/cpp/src_prims/sparse/coo.cuh @@ -170,7 +170,7 @@ class COO { /** * @brief Send human-readable state information to output stream */ - friend std::ostream &operator<<(std::ostream &out, const COO &c) { + friend std::ostream &operator<<(std::ostream &out, const COO &c) { if (c.validate_size() && c.validate_mem()) { cudaStream_t stream; CUDA_CHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); @@ -837,11 +837,11 @@ __global__ static void symmetric_find_size(const value_t *restrict data, const int n, const int k, int *restrict row_sizes, int *restrict row_sizes2) { - const int row = blockIdx.x * blockDim.x + threadIdx.x; // for every row - const int j = blockIdx.y * blockDim.y + threadIdx.y; // for every item in row + const auto row = blockIdx.x * blockDim.x + threadIdx.x; // for every row + const auto j = blockIdx.y * blockDim.y + threadIdx.y; // for every item in row if (row >= n || j >= k) return; - const int col = indices[row * k + j]; + const auto col = indices[row * k + j]; if (j % 2) raft::myAtomicAdd(&row_sizes[col], 1); else @@ -883,16 +883,16 @@ template __global__ static void symmetric_sum(int *restrict edges, const value_t *restrict data, const value_idx *restrict indices, - value_t *restrict VAL, int *restrict COL, - int *restrict ROW, const int n, + value_t *restrict VAL, value_idx *restrict COL, + value_idx *restrict ROW, const int n, const int k) { - const int row = blockIdx.x * blockDim.x + threadIdx.x; // for every row - const int j = blockIdx.y * blockDim.y + threadIdx.y; // for every item in row + const auto row = blockIdx.x * blockDim.x + threadIdx.x; // for every row + const auto j = blockIdx.y * blockDim.y + threadIdx.y; // for every item in row if (row >= n || j >= k) return; - const int col = indices[row * k + j]; - const int original = atomicAdd(&edges[row], 1); - const int transpose = atomicAdd(&edges[col], 1); + const auto col = indices[row * k + j]; + const auto original = atomicAdd(&edges[row], 1); + const auto transpose = atomicAdd(&edges[col], 1); VAL[transpose] = VAL[original] = data[row * k + j]; // Notice swapped ROW, COL since transpose @@ -924,7 +924,7 @@ __global__ static void symmetric_sum(int *restrict edges, template void 
from_knn_symmetrize_matrix(const value_idx *restrict knn_indices,
                           const value_t *restrict knn_dists, const int n,
-                          const int k, COO *out,
+                          const int k, COO *out,
                           cudaStream_t stream,
                           std::shared_ptr d_alloc) {
   // (1) Find how much space needed in each row
From f878cfd99a269ec26637f8c44e6ba150f79cff21 Mon Sep 17 00:00:00 2001
From: divyegala 
Date: Mon, 14 Dec 2020 16:59:46 -0600
Subject: [PATCH 17/22] exact TSNE with template

---
 cpp/src/tsne/exact_kernels.cuh | 243 +++++++++++++++++----------------
 cpp/src/tsne/exact_tsne.cuh    |  51 +++----
 cpp/src/tsne/tsne.cu           |   2 +-
 cpp/src/tsne/tsne_runner.cuh   |  22 +--
 4 files changed, 164 insertions(+), 154 deletions(-)

diff --git a/cpp/src/tsne/exact_kernels.cuh b/cpp/src/tsne/exact_kernels.cuh
index e88e4f5e7d..3c5a609c12 100644
--- a/cpp/src/tsne/exact_kernels.cuh
+++ b/cpp/src/tsne/exact_kernels.cuh
@@ -29,20 +29,21 @@ namespace TSNE {
 
 /****************************************/
 /* Finds the best Gaussian bandwidth for each row in the dataset */
-__global__ void sigmas_kernel(const float *restrict distances,
-                              float *restrict P, const float perplexity,
-                              const float desired_entropy, const int epochs,
-                              const float tol, const int n, const int k) {
+template 
+__global__ void sigmas_kernel(const value_t *restrict distances,
+                              value_t *restrict P, const value_t perplexity,
+                              const value_t desired_entropy, const int epochs,
+                              const value_t tol, const value_idx n, const int k) {
   // For every item in row
-  const int i = (blockIdx.x * blockDim.x) + threadIdx.x;
+  const auto i = (blockIdx.x * blockDim.x) + threadIdx.x;
   if (i >= n) return;
 
-  float beta_min = -INFINITY, beta_max = INFINITY;
-  float beta = 1;
-  register const int ik = i * k;
+  value_t beta_min = -INFINITY, beta_max = INFINITY;
+  value_t beta = 1;
+  register const auto ik = i * k;
 
   for (int step = 0; step < epochs; step++) {
-    float sum_Pi = FLT_EPSILON;
+    value_t sum_Pi = FLT_EPSILON;
 
     // Exponentiate to get Gaussian
     for (int j = 0; j < k; j++) {
@@ -51,15 +52,15 @@ __global__ void sigmas_kernel(const float *restrict distances,
     }
 
     // Normalize
-    float sum_disti_Pi = 0;
-    const float div = __fdividef(1.0f, sum_Pi);
+    value_t sum_disti_Pi = 0;
+    const value_t div = __fdividef(1.0f, sum_Pi);
     for (int j = 0; j < k; j++) {
       P[ik + j] *= div;
       sum_disti_Pi += distances[ik + j] * P[ik + j];
     }
 
-    const float entropy = __logf(sum_Pi) + beta * sum_disti_Pi;
-    const float entropy_diff = entropy - desired_entropy;
+    const value_t entropy = __logf(sum_Pi) + beta * sum_disti_Pi;
+    const value_t entropy_diff = entropy - desired_entropy;
     if (fabs(entropy_diff) <= tol) break;
 
     // Bisection search
@@ -82,33 +83,34 @@ __global__ void sigmas_kernel(const float *restrict distances,
 
 /****************************************/
 /* Finds the best Gaussian bandwidth for each row in the dataset */
-__global__ void sigmas_kernel_2d(const float *restrict distances,
-                                 float *restrict P, const float perplexity,
-                                 const float desired_entropy, const int epochs,
-                                 const float tol, const int n) {
+template 
+__global__ void sigmas_kernel_2d(const value_t *restrict distances,
+                                 value_t *restrict P, const value_t perplexity,
+                                 const value_t desired_entropy, const int epochs,
+                                 const value_t tol, const value_idx n) {
   // For every item in row
-  const int i = (blockIdx.x * blockDim.x) + threadIdx.x;
+  const auto i = (blockIdx.x * blockDim.x) + threadIdx.x;
   if (i >= n) return;
 
-  float beta_min = -INFINITY, beta_max = INFINITY;
-  float beta = 1;
-  register const int ik = i * 2;
+  value_t beta_min = -INFINITY, beta_max = INFINITY;
+  value_t beta 
= 1; + register const auto ik = i * 2; for (int step = 0; step < epochs; step++) { // Exponentiate to get Gaussian P[ik] = __expf(-distances[ik] * beta); P[ik + 1] = __expf(-distances[ik + 1] * beta); - const float sum_Pi = FLT_EPSILON + P[ik] + P[ik + 1]; + const value_t sum_Pi = FLT_EPSILON + P[ik] + P[ik + 1]; // Normalize - const float div = __fdividef(1.0f, sum_Pi); + const value_t div = __fdividef(1.0f, sum_Pi); P[ik] *= div; P[ik + 1] *= div; - const float sum_disti_Pi = + const value_t sum_disti_Pi = distances[ik] * P[ik] + distances[ik + 1] * P[ik + 1]; - const float entropy = __logf(sum_Pi) + beta * sum_disti_Pi; - const float entropy_diff = entropy - desired_entropy; + const value_t entropy = __logf(sum_Pi) + beta * sum_disti_Pi; + const value_t entropy_diff = entropy - desired_entropy; if (fabs(entropy_diff) <= tol) break; // Bisection search @@ -129,19 +131,20 @@ __global__ void sigmas_kernel_2d(const float *restrict distances, } /****************************************/ -void perplexity_search(const float *restrict distances, float *restrict P, - const float perplexity, const int epochs, - const float tol, const int n, const int dim, +template +void perplexity_search(const value_t *restrict distances, value_t *restrict P, + const value_t perplexity, const int epochs, + const value_t tol, const value_idx n, const int dim, const raft::handle_t &handle) { - const float desired_entropy = logf(perplexity); + const value_t desired_entropy = logf(perplexity); auto d_alloc = handle.get_device_allocator(); cudaStream_t stream = handle.get_stream(); if (dim == 2) - sigmas_kernel_2d<<>>( + sigmas_kernel_2d<<>>( distances, P, perplexity, desired_entropy, epochs, tol, n); else - sigmas_kernel<<>>( + sigmas_kernel<<>>( distances, P, perplexity, desired_entropy, epochs, tol, n, dim); CUDA_CHECK(cudaPeekAtLastError()); cudaStreamSynchronize(stream); @@ -150,27 +153,28 @@ void perplexity_search(const float *restrict distances, float *restrict P, /****************************************/ /* Compute attractive forces in O(uN) time. 
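   (Here u is n_neighbors: the kernel below only visits the NNZ stored
   entries of the sparse P rather than all N^2 pairs.)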
Uses only nearest neighbors */
+template 
 __global__ void attractive_kernel(
-  const float *restrict VAL, const int *restrict COL, const int *restrict ROW,
-  const float *restrict Y, const float *restrict norm, float *restrict attract,
-  const int NNZ, const int n, const int dim,
-  const float df_power,  // -(df + 1)/2)
-  const float recp_df)  // 1 / df
+  const value_t *restrict VAL, const value_idx *restrict COL, const value_idx *restrict ROW,
+  const value_t *restrict Y, const value_t *restrict norm, value_t *restrict attract,
+  const value_idx NNZ, const value_idx n, const value_idx dim,
+  const value_t df_power,  // -(df + 1)/2)
+  const value_t recp_df)  // 1 / df
 {
-  const int index = (blockIdx.x * blockDim.x) + threadIdx.x;
+  const auto index = (blockIdx.x * blockDim.x) + threadIdx.x;
   if (index >= NNZ) return;
-  const int i = ROW[index], j = COL[index];
+  const auto i = ROW[index], j = COL[index];
 
   // Euclidean distances
   // TODO: can provide any distance ie cosine
   // #862
-  float d = 0;
+  value_t d = 0;
   for (int k = 0; k < dim; k++) d += Y[k * n + i] * Y[k * n + j];
-  const float euclidean_d = -2.0f * d + norm[i] + norm[j];
+  const value_t euclidean_d = -2.0f * d + norm[i] + norm[j];
 
   // TODO: Calculate Kullback-Leibler divergence
   // #863
-  const float PQ =
+  const value_t PQ =
     VAL[index] * __powf((1.0f + euclidean_d * recp_df), df_power);  // P*Q
 
   // Apply forces
@@ -181,24 +185,25 @@ __global__ void attractive_kernel(
 
 /****************************************/
 /* Special case when dim == 2. Speeds up many calculations. */
+template 
 __global__ void attractive_kernel_2d(
-  const float *restrict VAL, const int *restrict COL, const int *restrict ROW,
-  const float *restrict Y1, const float *restrict Y2,
-  const float *restrict norm, float *restrict attract1,
-  float *restrict attract2, const int NNZ) {
-  const int index = (blockIdx.x * blockDim.x) + threadIdx.x;
+  const value_t *restrict VAL, const value_idx *restrict COL, const value_idx *restrict ROW,
+  const value_t *restrict Y1, const value_t *restrict Y2,
+  const value_t *restrict norm, value_t *restrict attract1,
+  value_t *restrict attract2, const value_idx NNZ) {
+  const auto index = (blockIdx.x * blockDim.x) + threadIdx.x;
   if (index >= NNZ) return;
-  const int i = ROW[index], j = COL[index];
+  const auto i = ROW[index], j = COL[index];
 
   // Euclidean distances
   // TODO: can provide any distance ie cosine
   // #862
-  const float euclidean_d =
+  const value_t euclidean_d =
     norm[i] + norm[j] - 2.0f * (Y1[i] * Y1[j] + Y2[i] * Y2[j]);
 
   // TODO: Calculate Kullback-Leibler divergence
   // #863
-  const float PQ = __fdividef(VAL[index], (1.0f + euclidean_d));  // P*Q
+  const value_t PQ = __fdividef(VAL[index], (1.0f + euclidean_d));  // P*Q
 
   // Apply forces
   raft::myAtomicAdd(&attract1[i], PQ * (Y1[i] - Y1[j]));
 }
 
 /****************************************/
-void attractive_forces(const float *restrict VAL, const int *restrict COL,
-                       const int *restrict ROW, const float *restrict Y,
-                       const float *restrict norm, float *restrict attract,
-                       const int NNZ, const int n, const int dim,
-                       const float df_power,  // -(df + 1)/2)
-                       const float recp_df,  // 1 / df
+template 
+void attractive_forces(const value_t *restrict VAL, const value_idx *restrict COL,
+                       const value_idx *restrict ROW, const value_t *restrict Y,
+                       const value_t *restrict norm, value_t *restrict attract,
+                       const value_idx NNZ, const value_idx n, const value_idx dim,
+                       const value_t df_power,  // -(df + 1)/2)
+                       const value_t recp_df,  // 1 
/ df cudaStream_t stream) { - CUDA_CHECK(cudaMemsetAsync(attract, 0, sizeof(float) * n * dim, stream)); + CUDA_CHECK(cudaMemsetAsync(attract, 0, sizeof(value_t) * n * dim, stream)); // TODO: Calculate Kullback-Leibler divergence // #863 // For general embedding dimensions if (dim != 2) { - attractive_kernel<<>>( + attractive_kernel<<>>( VAL, COL, ROW, Y, norm, attract, NNZ, n, dim, df_power, recp_df); } // For special case dim == 2 else { - attractive_kernel_2d<<>>( + attractive_kernel_2d<<>>( VAL, COL, ROW, Y, Y + n, norm, attract, attract + n, NNZ); } CUDA_CHECK(cudaPeekAtLastError()); @@ -234,31 +240,32 @@ void attractive_forces(const float *restrict VAL, const int *restrict COL, /* Computes repulsive forces in pseudo-O(N^2) time where many of the math ops are made considerably faster. */ -__global__ void repulsive_kernel(const float *restrict Y, float *restrict repel, - const float *restrict norm, - float *restrict Z_sum1, float *restrict Z_sum2, - const int n, const int dim, - const float df_power, // -(df + 1)/2) - const float recp_df) // 1 / df +template +__global__ void repulsive_kernel(const value_t *restrict Y, value_t *restrict repel, + const value_t *restrict norm, + value_t *restrict Z_sum1, value_t *restrict Z_sum2, + const value_idx n, const value_idx dim, + const value_t df_power, // -(df + 1)/2) + const value_t recp_df) // 1 / df { - const int j = + const auto j = (blockIdx.x * blockDim.x) + threadIdx.x; // for every item in row - const int i = (blockIdx.y * blockDim.y) + threadIdx.y; // for every row + const auto i = (blockIdx.y * blockDim.y) + threadIdx.y; // for every row if (j >= i || i >= n || j >= n) return; // Euclidean distances // TODO: can provide any distance ie cosine - float d = 0; + value_t d = 0; for (int k = 0; k < dim; k++) d += Y[k * n + i] * Y[k * n + j]; - const float euclidean_d = -2.0f * d + norm[i] + norm[j]; + const value_t euclidean_d = -2.0f * d + norm[i] + norm[j]; // Q and Q^2 - const float Q = __powf((1.0f + euclidean_d * recp_df), df_power); - const float Q2 = Q * Q; + const value_t Q = __powf((1.0f + euclidean_d * recp_df), df_power); + const value_t Q2 = Q * Q; // Apply forces for (int k = 0; k < dim; k++) { - const float force = Q2 * (Y[k * n + j] - Y[k * n + i]); + const value_t force = Q2 * (Y[k * n + j] - Y[k * n + i]); raft::myAtomicAdd(&repel[k * n + i], force); raft::myAtomicAdd(&repel[k * n + j], force); } @@ -273,25 +280,26 @@ __global__ void repulsive_kernel(const float *restrict Y, float *restrict repel, /****************************************/ /* Special case when dim == 2. Much faster since calculations are streamlined. 
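   With dim == 2 the squared distance collapses to
   norm[i] + norm[j] - 2 * (Y1[i] * Y1[j] + Y2[i] * Y2[j]), and since the
   t-distribution then has a single degree of freedom, Q = 1 / (1 + d)
   replaces the __powf call of the general kernel.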
*/
+template 
 __global__ void repulsive_kernel_2d(
-  const float *restrict Y1, const float *restrict Y2, float *restrict repel1,
-  float *restrict repel2, const float *restrict norm, float *restrict Z_sum1,
-  float *restrict Z_sum2, const int n) {
-  const int j =
+  const value_t *restrict Y1, const value_t *restrict Y2, value_t *restrict repel1,
+  value_t *restrict repel2, const value_t *restrict norm, value_t *restrict Z_sum1,
+  value_t *restrict Z_sum2, const value_idx n) {
+  const auto j =
     (blockIdx.x * blockDim.x) + threadIdx.x;  // for every item in row
-  const int i = (blockIdx.y * blockDim.y) + threadIdx.y;  // for every row
+  const auto i = (blockIdx.y * blockDim.y) + threadIdx.y;  // for every row
   if (j >= i || i >= n || j >= n) return;
 
   // Euclidean distances
   // TODO: can provide any distance ie cosine
   // #862
-  const float euclidean_d =
+  const value_t euclidean_d =
     norm[i] + norm[j] - 2.0f * (Y1[i] * Y1[j] + Y2[i] * Y2[j]);
 
-  const float Q = __fdividef(1.0f, (1.0f + euclidean_d));
-  const float Q2 = Q * Q;
+  const value_t Q = __fdividef(1.0f, (1.0f + euclidean_d));
+  const value_t Q2 = Q * Q;
 
-  const float force1 = Q2 * (Y1[j] - Y1[i]);
-  const float force2 = Q2 * (Y2[j] - Y2[i]);
+  const value_t force1 = Q2 * (Y1[j] - Y1[i]);
+  const value_t force2 = Q2 * (Y2[j] - Y2[i]);
 
   // Add forces
   raft::myAtomicAdd(&repel1[i], force1);
@@ -308,17 +316,17 @@ __global__ void repulsive_kernel_2d(
 }
 
 /****************************************/
-template 
-float repulsive_forces(const float *restrict Y, float *restrict repel,
-                       const float *restrict norm, float *restrict Z_sum,
-                       const int n, const int dim,
-                       const float df_power,  // -(df + 1)/2)
-                       const float recp_df, cudaStream_t stream) {
-  CUDA_CHECK(cudaMemsetAsync(Z_sum, 0, sizeof(float) * 2 * n, stream));
-  CUDA_CHECK(cudaMemsetAsync(repel, 0, sizeof(float) * n * dim, stream));
+template 
+value_t repulsive_forces(const value_t *restrict Y, value_t *restrict repel,
+                         const value_t *restrict norm, value_t *restrict Z_sum,
+                         const value_idx n, const value_idx dim,
+                         const value_t df_power,  // -(df + 1)/2)
+                         const value_t recp_df, cudaStream_t stream) {
+  CUDA_CHECK(cudaMemsetAsync(Z_sum, 0, sizeof(value_t) * 2 * n, stream));
+  CUDA_CHECK(cudaMemsetAsync(repel, 0, sizeof(value_t) * n * dim, stream));
 
   const dim3 threadsPerBlock(TPB_X, TPB_Y);
-  const dim3 numBlocks(raft::ceildiv(n, TPB_X), raft::ceildiv(n, TPB_Y));
+  const dim3 numBlocks(raft::ceildiv(n, (value_idx) TPB_X), raft::ceildiv(n, (value_idx) TPB_Y));
 
   // For general embedding dimensions
   if (dim != 2) {
@@ -333,31 +341,32 @@ float repulsive_forces(const float *restrict Y, float *restrict repel,
   CUDA_CHECK(cudaPeekAtLastError());
 
   // Find sum(Z_sum)
-  thrust::device_ptr begin = thrust::device_pointer_cast(Z_sum);
-  float Z = thrust::reduce(thrust::cuda::par.on(stream), begin, begin + 2 * n);
+  thrust::device_ptr begin = thrust::device_pointer_cast(Z_sum);
+  value_t Z = thrust::reduce(thrust::cuda::par.on(stream), begin, begin + 2 * n);
   return 1.0f /
          (2.0f *
-          (Z + (float)n));  // Notice + n since diagonal of repulsion sums to n
+          (Z + (value_t)n));  // Notice + n since diagonal of repulsion sums to n
 }
 
 /****************************************/
 /* Applies or integrates all forces. Uses more gains and constrains the output
for output stability */
+template 
 __global__ void apply_kernel(
-  float *restrict Y, float *restrict velocity, const float *restrict attract,
-  const float *restrict repel, float *restrict means, float *restrict gains,
-  const float Z,  // sum(Q)
-  const float learning_rate,
-  const float C,  // constant from T-Dist Degrees of Freedom
-  const float momentum,
-  const int SIZE,  // SIZE = n*dim
-  const int n, const float min_gain, float *restrict gradient,
+  value_t *restrict Y, value_t *restrict velocity, const value_t *restrict attract,
+  const value_t *restrict repel, value_t *restrict means, value_t *restrict gains,
+  const value_t Z,  // sum(Q)
+  const value_t learning_rate,
+  const value_t C,  // constant from T-Dist Degrees of Freedom
+  const value_t momentum,
+  const value_idx SIZE,  // SIZE = n*dim
+  const value_idx n, const value_t min_gain, value_t *restrict gradient,
   const bool check_convergence) {
-  const int index = (blockIdx.x * blockDim.x) + threadIdx.x;
+  const auto index = (blockIdx.x * blockDim.x) + threadIdx.x;
   if (index >= SIZE) return;
 
-  const float dy = C * (attract[index] + Z * repel[index]);
+  const value_t dy = C * (attract[index] + Z * repel[index]);
   if (check_convergence) gradient[index] = dy * dy;
 
   // Find new gain
@@ -378,29 +387,29 @@ __global__ void apply_kernel(
 }
 
 /****************************************/
-template 
-float apply_forces(float *restrict Y, float *restrict velocity,
-                   const float *restrict attract, const float *restrict repel,
-                   float *restrict means, float *restrict gains,
-                   const float Z,  // sum(Q)
-                   const float learning_rate,
-                   const float C,  // constant from T-dist
-                   const float momentum, const int dim, const int n,
-                   const float min_gain, float *restrict gradient,
+template 
+value_t apply_forces(value_t *restrict Y, value_t *restrict velocity,
+                     const value_t *restrict attract, const value_t *restrict repel,
+                     value_t *restrict means, value_t *restrict gains,
+                     const value_t Z,  // sum(Q)
+                     const value_t learning_rate,
+                     const value_t C,  // constant from T-dist
+                     const value_t momentum, const value_idx dim, const value_idx n,
+                     const value_t min_gain, value_t *restrict gradient,
                      const bool check_convergence, cudaStream_t stream) {
   //cudaMemset(means, 0, sizeof(float) * dim);
   if (check_convergence)
-    CUDA_CHECK(cudaMemsetAsync(gradient, 0, sizeof(float) * n * dim, stream));
+    CUDA_CHECK(cudaMemsetAsync(gradient, 0, sizeof(value_t) * n * dim, stream));
 
-  apply_kernel<<>>(
+  apply_kernel<<>>(
     Y, velocity, attract, repel, means, gains, Z, learning_rate, C, momentum,
     n * dim, n, min_gain, gradient, check_convergence);
   CUDA_CHECK(cudaPeekAtLastError());
 
   // Find sum of gradient norms
-  float gradient_norm = INFINITY;
+  value_t gradient_norm = INFINITY;
   if (check_convergence) {
-    thrust::device_ptr begin = thrust::device_pointer_cast(gradient);
+    thrust::device_ptr begin = thrust::device_pointer_cast(gradient);
     gradient_norm = sqrtf(
       thrust::reduce(thrust::cuda::par.on(stream), begin, begin + n * dim));
   }
diff --git a/cpp/src/tsne/exact_tsne.cuh b/cpp/src/tsne/exact_tsne.cuh
index b4b8b7f826..665cd12503 100644
--- a/cpp/src/tsne/exact_tsne.cuh
+++ b/cpp/src/tsne/exact_tsne.cuh
@@ -46,14 +46,15 @@ namespace TSNE {
 * @param[in] random_state: Set this to -1 for pure random initializations or >= 0 for reproducible outputs.
 * @param[in] initialize_embeddings: Whether to overwrite the current Y vector with random noise. 
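 *
 * A minimal, hypothetical call sketch (illustrative only; VAL/COL/ROW hold the
 * symmetrized COO perplexities, Y is an n * dim device buffer, and all
 * remaining parameters keep their defaults):
 * @code
 *   TSNE::Exact_TSNE(VAL, COL, ROW, NNZ, handle, Y, n, dim);
 * @endcode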
*/ -void Exact_TSNE(float *VAL, const int *COL, const int *ROW, const int NNZ, - const raft::handle_t &handle, float *Y, const int n, - const int dim, const float early_exaggeration = 12.0f, - const int exaggeration_iter = 250, const float min_gain = 0.01f, - const float pre_learning_rate = 200.0f, - const float post_learning_rate = 500.0f, - const int max_iter = 1000, const float min_grad_norm = 1e-7, - const float pre_momentum = 0.5, const float post_momentum = 0.8, +template +void Exact_TSNE(value_t *VAL, const value_idx *COL, const value_idx *ROW, const value_idx NNZ, + const raft::handle_t &handle, value_t *Y, const value_idx n, + const value_idx dim, const value_t early_exaggeration = 12.0f, + const int exaggeration_iter = 250, const value_t min_gain = 0.01f, + const value_t pre_learning_rate = 200.0f, + const value_t post_learning_rate = 500.0f, + const int max_iter = 1000, const value_t min_grad_norm = 1e-7, + const value_t pre_momentum = 0.5, const value_t post_momentum = 0.8, const long long random_state = -1, const bool initialize_embeddings = true) { auto d_alloc = handle.get_device_allocator(); @@ -65,34 +66,34 @@ void Exact_TSNE(float *VAL, const int *COL, const int *ROW, const int NNZ, // Allocate space //--------------------------------------------------- CUML_LOG_DEBUG("Now allocating memory for TSNE."); - MLCommon::device_buffer norm(d_alloc, stream, n); - MLCommon::device_buffer Z_sum(d_alloc, stream, 2 * n); - MLCommon::device_buffer means(d_alloc, stream, dim); + MLCommon::device_buffer norm(d_alloc, stream, n); + MLCommon::device_buffer Z_sum(d_alloc, stream, 2 * n); + MLCommon::device_buffer means(d_alloc, stream, dim); - MLCommon::device_buffer attract(d_alloc, stream, n * dim); - MLCommon::device_buffer repel(d_alloc, stream, n * dim); + MLCommon::device_buffer attract(d_alloc, stream, n * dim); + MLCommon::device_buffer repel(d_alloc, stream, n * dim); - MLCommon::device_buffer velocity(d_alloc, stream, n * dim); + MLCommon::device_buffer velocity(d_alloc, stream, n * dim); CUDA_CHECK(cudaMemsetAsync( velocity.data(), 0, velocity.size() * sizeof(*velocity.data()), stream)); - MLCommon::device_buffer gains(d_alloc, stream, n * dim); - thrust::device_ptr begin = thrust::device_pointer_cast(gains.data()); + MLCommon::device_buffer gains(d_alloc, stream, n * dim); + thrust::device_ptr begin = thrust::device_pointer_cast(gains.data()); thrust::fill(thrust::cuda::par.on(stream), begin, begin + n * dim, 1.0f); - MLCommon::device_buffer gradient(d_alloc, stream, n * dim); + MLCommon::device_buffer gradient(d_alloc, stream, n * dim); //--------------------------------------------------- // Calculate degrees of freedom //--------------------------------------------------- - const float degrees_of_freedom = fmaxf(dim - 1, 1); - const float df_power = -(degrees_of_freedom + 1.0f) / 2.0f; - const float recp_df = 1.0f / degrees_of_freedom; - const float C = 2.0f * (degrees_of_freedom + 1.0f) / degrees_of_freedom; + const value_t degrees_of_freedom = fmaxf(dim - 1, 1); + const value_t df_power = -(degrees_of_freedom + 1.0f) / 2.0f; + const value_t recp_df = 1.0f / degrees_of_freedom; + const value_t C = 2.0f * (degrees_of_freedom + 1.0f) / degrees_of_freedom; CUML_LOG_DEBUG("Start gradient updates!"); - float momentum = pre_momentum; - float learning_rate = pre_learning_rate; + value_t momentum = pre_momentum; + value_t learning_rate = pre_learning_rate; bool check_convergence = false; for (int iter = 0; iter < max_iter; iter++) { @@ -101,7 +102,7 @@ void Exact_TSNE(float *VAL, 
const int *COL, const int *ROW, const int NNZ, if (iter == exaggeration_iter) { momentum = post_momentum; // Divide perplexities - const float div = 1.0f / early_exaggeration; + const value_t div = 1.0f / early_exaggeration; raft::linalg::scalarMultiply(VAL, VAL, div, NNZ, stream); learning_rate = post_learning_rate; } @@ -119,7 +120,7 @@ void Exact_TSNE(float *VAL, const int *COL, const int *ROW, const int NNZ, df_power, recp_df, stream); // Apply / integrate forces - const float gradient_norm = TSNE::apply_forces( + const value_t gradient_norm = TSNE::apply_forces( Y, velocity.data(), attract.data(), repel.data(), means.data(), gains.data(), Z, learning_rate, C, momentum, dim, n, min_gain, gradient.data(), check_convergence, stream); diff --git a/cpp/src/tsne/tsne.cu b/cpp/src/tsne/tsne.cu index 05ce85b0dc..488ab5065b 100644 --- a/cpp/src/tsne/tsne.cu +++ b/cpp/src/tsne/tsne.cu @@ -21,7 +21,7 @@ namespace ML { template void _fit(const raft::handle_t &handle, tsne_input &input, - knn_graph &k_graph, const int dim, + knn_graph &k_graph, const value_idx dim, const value_t theta, const value_t epssq, value_t perplexity, const int perplexity_max_iter, const value_t perplexity_tol, const value_t early_exaggeration, const int exaggeration_iter, diff --git a/cpp/src/tsne/tsne_runner.cuh b/cpp/src/tsne/tsne_runner.cuh index e0675e48e9..baf04a26c9 100644 --- a/cpp/src/tsne/tsne_runner.cuh +++ b/cpp/src/tsne/tsne_runner.cuh @@ -32,7 +32,7 @@ template class TSNE_runner { public: TSNE_runner(const raft::handle_t &handle_, tsne_input &input_, - knn_graph &k_graph_, const int dim_, + knn_graph &k_graph_, const value_idx dim_, const float theta_, const float epssq_, float perplexity_, const int perplexity_max_iter_, const float perplexity_tol_, const float early_exaggeration_, const int exaggeration_iter_, @@ -100,10 +100,10 @@ class TSNE_runner { void run() { distance_and_perplexity(); - // const auto NNZ = COO_Matrix.nnz; - // auto *VAL = COO_Matrix.vals(); - // const auto *COL = COO_Matrix.cols(); - // const auto *ROW = COO_Matrix.rows(); + const auto NNZ = COO_Matrix.nnz; + auto *VAL = COO_Matrix.vals(); + const auto *COL = COO_Matrix.cols(); + const auto *ROW = COO_Matrix.rows(); //--------------------------------------------------- // if (barnes_hut) { @@ -113,11 +113,11 @@ class TSNE_runner { // min_grad_norm, pre_momentum, post_momentum, random_state, // initialize_embeddings); // } else { - // TSNE::Exact_TSNE(VAL, COL, ROW, NNZ, handle, Y, n, dim, - // early_exaggeration, exaggeration_iter, min_gain, - // pre_learning_rate, post_learning_rate, max_iter, - // min_grad_norm, pre_momentum, post_momentum, random_state, - // initialize_embeddings); + TSNE::Exact_TSNE(VAL, COL, ROW, NNZ, handle, Y, n, dim, + early_exaggeration, exaggeration_iter, min_gain, + pre_learning_rate, post_learning_rate, max_iter, + min_grad_norm, pre_momentum, post_momentum, random_state, + initialize_embeddings); // } } @@ -182,7 +182,7 @@ class TSNE_runner { const raft::handle_t &handle; tsne_input &input; knn_graph &k_graph; - const int dim; + const value_idx dim; int n_neighbors; const value_t theta; const value_t epssq; From f5fbeb4853bc1b2109489b5165d4a56c866ebd87 Mon Sep 17 00:00:00 2001 From: divyegala Date: Mon, 14 Dec 2020 17:14:09 -0600 Subject: [PATCH 18/22] templates in coo symmetrize --- cpp/src_prims/sparse/coo.cuh | 50 +++++++++++++++++++----------------- 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/cpp/src_prims/sparse/coo.cuh b/cpp/src_prims/sparse/coo.cuh index d21efe18f5..5501cccd3c 
100644 --- a/cpp/src_prims/sparse/coo.cuh +++ b/cpp/src_prims/sparse/coo.cuh @@ -30,6 +30,7 @@ #include #include #include +#include #include #define restrict __restrict__ @@ -834,18 +835,18 @@ void coo_symmetrize(COO *in, COO *out, template __global__ static void symmetric_find_size(const value_t *restrict data, const value_idx *restrict indices, - const int n, const int k, - int *restrict row_sizes, - int *restrict row_sizes2) { + const value_idx n, const int k, + value_idx *restrict row_sizes, + value_idx *restrict row_sizes2) { const auto row = blockIdx.x * blockDim.x + threadIdx.x; // for every row const auto j = blockIdx.y * blockDim.y + threadIdx.y; // for every item in row if (row >= n || j >= k) return; const auto col = indices[row * k + j]; if (j % 2) - raft::myAtomicAdd(&row_sizes[col], 1); + atomicAdd(&row_sizes[col], (value_idx) 1); else - raft::myAtomicAdd(&row_sizes2[col], 1); + atomicAdd(&row_sizes2[col], (value_idx) 1); } /** @@ -857,10 +858,11 @@ __global__ static void symmetric_find_size(const value_t *restrict data, * @param row_sizes: Input row sum 1 array(n) * @param row_sizes2: Input row sum 2 array(n) for faster reduction */ -__global__ static void reduce_find_size(const int n, const int k, - int *restrict row_sizes, - const int *restrict row_sizes2) { - const int i = (blockIdx.x * blockDim.x) + threadIdx.x; +template +__global__ static void reduce_find_size(const value_idx n, const int k, + value_idx *restrict row_sizes, + const value_idx *restrict row_sizes2) { + const auto i = (blockIdx.x * blockDim.x) + threadIdx.x; if (i >= n) return; row_sizes[i] += (row_sizes2[i] + k); } @@ -880,19 +882,19 @@ __global__ static void reduce_find_size(const int n, const int k, * @param k: Number of n_neighbors */ template -__global__ static void symmetric_sum(int *restrict edges, +__global__ static void symmetric_sum(value_idx *restrict edges, const value_t *restrict data, const value_idx *restrict indices, value_t *restrict VAL, value_idx *restrict COL, - value_idx *restrict ROW, const int n, + value_idx *restrict ROW, const value_idx n, const int k) { const auto row = blockIdx.x * blockDim.x + threadIdx.x; // for every row const auto j = blockIdx.y * blockDim.y + threadIdx.y; // for every item in row if (row >= n || j >= k) return; const auto col = indices[row * k + j]; - const auto original = atomicAdd(&edges[row], 1); - const auto transpose = atomicAdd(&edges[col], 1); + const auto original = atomicAdd(&edges[row], (value_idx) 1); + const auto transpose = atomicAdd(&edges[col], (value_idx) 1); VAL[transpose] = VAL[original] = data[row * k + j]; // Notice swapped ROW, COL since transpose @@ -923,33 +925,33 @@ __global__ static void symmetric_sum(int *restrict edges, */ template void from_knn_symmetrize_matrix(const value_idx *restrict knn_indices, - const value_t *restrict knn_dists, const int n, + const value_t *restrict knn_dists, const value_idx n, const int k, COO *out, cudaStream_t stream, std::shared_ptr d_alloc) { // (1) Find how much space needed in each row // We look through all datapoints and increment the count for each row. 
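  // Note: the counting alternates between row_sizes and row_sizes2 (the
  // j % 2 split in symmetric_find_size above), spreading the atomicAdd
  // traffic over two arrays; reduce_find_size later folds them back
  // together along with the + k original edges per row.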
const dim3 threadsPerBlock(TPB_X, TPB_Y); - const dim3 numBlocks(raft::ceildiv(n, TPB_X), raft::ceildiv(k, TPB_Y)); + const dim3 numBlocks(raft::ceildiv(n, (value_idx) TPB_X), raft::ceildiv(k, TPB_Y)); // Notice n+1 since we can reuse these arrays for transpose_edges, original_edges in step (4) - device_buffer row_sizes(d_alloc, stream, n); - CUDA_CHECK(cudaMemsetAsync(row_sizes.data(), 0, sizeof(int) * n, stream)); + device_buffer row_sizes(d_alloc, stream, n); + CUDA_CHECK(cudaMemsetAsync(row_sizes.data(), 0, sizeof(value_idx) * n, stream)); - device_buffer row_sizes2(d_alloc, stream, n); - CUDA_CHECK(cudaMemsetAsync(row_sizes2.data(), 0, sizeof(int) * n, stream)); + device_buffer row_sizes2(d_alloc, stream, n); + CUDA_CHECK(cudaMemsetAsync(row_sizes2.data(), 0, sizeof(value_idx) * n, stream)); symmetric_find_size<<>>( knn_dists, knn_indices, n, k, row_sizes.data(), row_sizes2.data()); CUDA_CHECK(cudaPeekAtLastError()); - reduce_find_size<<>>( + reduce_find_size<<>>( n, k, row_sizes.data(), row_sizes2.data()); CUDA_CHECK(cudaPeekAtLastError()); // (2) Compute final space needed (n*k + sum(row_sizes)) == 2*n*k // Notice we don't do any merging and leave the result as 2*NNZ - const int NNZ = 2 * n * k; + const auto NNZ = 2 * n * k; // (3) Allocate new space out->allocate(NNZ, n, n, true, stream); @@ -958,9 +960,9 @@ void from_knn_symmetrize_matrix(const value_idx *restrict knn_indices, // This mirrors CSR matrix's row Pointer, were maximum bounds for each row // are calculated as the cumulative rolling sum of the previous rows. // Notice reusing old row_sizes2 memory - int *edges = row_sizes2.data(); - thrust::device_ptr __edges = thrust::device_pointer_cast(edges); - thrust::device_ptr __row_sizes = + value_idx *edges = row_sizes2.data(); + thrust::device_ptr __edges = thrust::device_pointer_cast(edges); + thrust::device_ptr __row_sizes = thrust::device_pointer_cast(row_sizes.data()); // Rolling cumulative sum From c7de8d239e36a9144da53a13d361ad3e5cd41bc6 Mon Sep 17 00:00:00 2001 From: divyegala Date: Mon, 14 Dec 2020 18:15:17 -0600 Subject: [PATCH 19/22] templates on barnes hut --- cpp/src/tsne/barnes_hut.cuh | 115 ++++++------ cpp/src/tsne/bh_kernels.cuh | 312 ++++++++++++++++++--------------- cpp/src/tsne/distances.cuh | 33 ++-- cpp/src/tsne/exact_kernels.cuh | 95 +++++----- cpp/src/tsne/exact_tsne.cuh | 20 +-- cpp/src/tsne/tsne_runner.cuh | 16 +- cpp/src_prims/sparse/coo.cuh | 36 ++-- cpp/test/sg/tsne_test.cu | 3 +- 8 files changed, 339 insertions(+), 291 deletions(-) diff --git a/cpp/src/tsne/barnes_hut.cuh b/cpp/src/tsne/barnes_hut.cuh index b0f6dfb2c9..edd44b35cb 100644 --- a/cpp/src/tsne/barnes_hut.cuh +++ b/cpp/src/tsne/barnes_hut.cuh @@ -48,17 +48,17 @@ namespace TSNE { * @param[in] random_state: Set this to -1 for pure random intializations or >= 0 for reproducible outputs. * @param[in] initialize_embeddings: Whether to overwrite the current Y vector with random noise. 
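 * Note: internally the quadtree is sized as nnodes = max(2 * n,
 * 1024 * #SMs) and then padded so that nnodes + 1 is a multiple of 32
 * before the tree buffers below are allocated.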
*/ -void Barnes_Hut(float *VAL, const int *COL, const int *ROW, const int NNZ, - const raft::handle_t &handle, float *Y, const int n, - const float theta = 0.5f, const float epssq = 0.0025, - const float early_exaggeration = 12.0f, - const int exaggeration_iter = 250, const float min_gain = 0.01f, - const float pre_learning_rate = 200.0f, - const float post_learning_rate = 500.0f, - const int max_iter = 1000, const float min_grad_norm = 1e-7, - const float pre_momentum = 0.5, const float post_momentum = 0.8, - const long long random_state = -1, - const bool initialize_embeddings = true) { +template +void Barnes_Hut( + value_t *VAL, const value_idx *COL, const value_idx *ROW, const value_idx NNZ, + const raft::handle_t &handle, value_t *Y, const value_idx n, + const value_t theta = 0.5f, const value_t epssq = 0.0025, + const value_t early_exaggeration = 12.0f, const int exaggeration_iter = 250, + const value_t min_gain = 0.01f, const value_t pre_learning_rate = 200.0f, + const value_t post_learning_rate = 500.0f, const int max_iter = 1000, + const value_t min_grad_norm = 1e-7, const value_t pre_momentum = 0.5, + const value_t post_momentum = 0.8, const long long random_state = -1, + const bool initialize_embeddings = true) { auto d_alloc = handle.get_device_allocator(); cudaStream_t stream = handle.get_stream(); @@ -66,7 +66,7 @@ void Barnes_Hut(float *VAL, const int *COL, const int *ROW, const int NNZ, //--------------------------------------------------- const int blocks = raft::getMultiProcessorCount(); - int nnodes = n * 2; + auto nnodes = n * 2; if (nnodes < 1024 * blocks) nnodes = 1024 * blocks; while ((nnodes & (32 - 1)) != 0) nnodes++; nnodes--; @@ -75,9 +75,9 @@ void Barnes_Hut(float *VAL, const int *COL, const int *ROW, const int NNZ, // Allocate more space // MLCommon::device_buffer errl(d_alloc, stream, 1); MLCommon::device_buffer limiter(d_alloc, stream, 1); - MLCommon::device_buffer maxdepthd(d_alloc, stream, 1); - MLCommon::device_buffer bottomd(d_alloc, stream, 1); - MLCommon::device_buffer radiusd(d_alloc, stream, 1); + MLCommon::device_buffer maxdepthd(d_alloc, stream, 1); + MLCommon::device_buffer bottomd(d_alloc, stream, 1); + MLCommon::device_buffer radiusd(d_alloc, stream, 1); TSNE::InitializationKernel<<<1, 1, 0, stream>>>(/*errl.data(),*/ limiter.data(), @@ -85,54 +85,55 @@ void Barnes_Hut(float *VAL, const int *COL, const int *ROW, const int NNZ, radiusd.data()); CUDA_CHECK(cudaPeekAtLastError()); - const int FOUR_NNODES = 4 * nnodes; - const int FOUR_N = 4 * n; - const float theta_squared = theta * theta; - const int NNODES = nnodes; + const value_idx FOUR_NNODES = 4 * nnodes; + const value_idx FOUR_N = 4 * n; + const value_t theta_squared = theta * theta; + const value_idx NNODES = nnodes; // Actual allocations - MLCommon::device_buffer startl(d_alloc, stream, nnodes + 1); - MLCommon::device_buffer childl(d_alloc, stream, (nnodes + 1) * 4); - MLCommon::device_buffer massl(d_alloc, stream, nnodes + 1); + MLCommon::device_buffer startl(d_alloc, stream, nnodes + 1); + MLCommon::device_buffer childl(d_alloc, stream, (nnodes + 1) * 4); + MLCommon::device_buffer massl(d_alloc, stream, nnodes + 1); - thrust::device_ptr begin_massl = + thrust::device_ptr begin_massl = thrust::device_pointer_cast(massl.data()); thrust::fill(thrust::cuda::par.on(stream), begin_massl, begin_massl + (nnodes + 1), 1.0f); - MLCommon::device_buffer maxxl(d_alloc, stream, blocks * FACTOR1); - MLCommon::device_buffer maxyl(d_alloc, stream, blocks * FACTOR1); - MLCommon::device_buffer 
minxl(d_alloc, stream, blocks * FACTOR1); - MLCommon::device_buffer minyl(d_alloc, stream, blocks * FACTOR1); + MLCommon::device_buffer maxxl(d_alloc, stream, blocks * FACTOR1); + MLCommon::device_buffer maxyl(d_alloc, stream, blocks * FACTOR1); + MLCommon::device_buffer minxl(d_alloc, stream, blocks * FACTOR1); + MLCommon::device_buffer minyl(d_alloc, stream, blocks * FACTOR1); // SummarizationKernel - MLCommon::device_buffer countl(d_alloc, stream, nnodes + 1); + MLCommon::device_buffer countl(d_alloc, stream, nnodes + 1); // SortKernel - MLCommon::device_buffer sortl(d_alloc, stream, nnodes + 1); + MLCommon::device_buffer sortl(d_alloc, stream, nnodes + 1); // RepulsionKernel - MLCommon::device_buffer rep_forces(d_alloc, stream, (nnodes + 1) * 2); - MLCommon::device_buffer attr_forces( + MLCommon::device_buffer rep_forces(d_alloc, stream, + (nnodes + 1) * 2); + MLCommon::device_buffer attr_forces( d_alloc, stream, n * 2); // n*2 double for reduction sum - MLCommon::device_buffer Z_norm(d_alloc, stream, 1); + MLCommon::device_buffer Z_norm(d_alloc, stream, 1); - MLCommon::device_buffer radiusd_squared(d_alloc, stream, 1); + MLCommon::device_buffer radiusd_squared(d_alloc, stream, 1); // Apply - MLCommon::device_buffer gains_bh(d_alloc, stream, n * 2); + MLCommon::device_buffer gains_bh(d_alloc, stream, n * 2); - thrust::device_ptr begin_gains_bh = + thrust::device_ptr begin_gains_bh = thrust::device_pointer_cast(gains_bh.data()); thrust::fill(thrust::cuda::par.on(stream), begin_gains_bh, begin_gains_bh + (n * 2), 1.0f); - MLCommon::device_buffer old_forces(d_alloc, stream, n * 2); + MLCommon::device_buffer old_forces(d_alloc, stream, n * 2); CUDA_CHECK( - cudaMemsetAsync(old_forces.data(), 0, sizeof(float) * n * 2, stream)); + cudaMemsetAsync(old_forces.data(), 0, sizeof(value_t) * n * 2, stream)); - MLCommon::device_buffer YY(d_alloc, stream, (nnodes + 1) * 2); + MLCommon::device_buffer YY(d_alloc, stream, (nnodes + 1) * 2); if (initialize_embeddings) { random_vector(YY.data(), -0.0001f, 0.0001f, (nnodes + 1) * 2, stream, random_state); @@ -143,27 +144,30 @@ void Barnes_Hut(float *VAL, const int *COL, const int *ROW, const int NNZ, // Set cache levels for faster algorithm execution //--------------------------------------------------- - CUDA_CHECK( - cudaFuncSetCacheConfig(TSNE::BoundingBoxKernel, cudaFuncCachePreferShared)); - CUDA_CHECK( - cudaFuncSetCacheConfig(TSNE::TreeBuildingKernel, cudaFuncCachePreferL1)); - CUDA_CHECK(cudaFuncSetCacheConfig(TSNE::ClearKernel1, cudaFuncCachePreferL1)); - CUDA_CHECK(cudaFuncSetCacheConfig(TSNE::ClearKernel2, cudaFuncCachePreferL1)); - CUDA_CHECK(cudaFuncSetCacheConfig(TSNE::SummarizationKernel, + CUDA_CHECK(cudaFuncSetCacheConfig(TSNE::BoundingBoxKernel, cudaFuncCachePreferShared)); - CUDA_CHECK(cudaFuncSetCacheConfig(TSNE::SortKernel, cudaFuncCachePreferL1)); - CUDA_CHECK( - cudaFuncSetCacheConfig(TSNE::RepulsionKernel, cudaFuncCachePreferL1)); - CUDA_CHECK( - cudaFuncSetCacheConfig(TSNE::attractive_kernel_bh, cudaFuncCachePreferL1)); + CUDA_CHECK(cudaFuncSetCacheConfig( + TSNE::TreeBuildingKernel, cudaFuncCachePreferL1)); + CUDA_CHECK(cudaFuncSetCacheConfig(TSNE::ClearKernel1, + cudaFuncCachePreferL1)); + CUDA_CHECK(cudaFuncSetCacheConfig(TSNE::ClearKernel2, + cudaFuncCachePreferL1)); + CUDA_CHECK(cudaFuncSetCacheConfig( + TSNE::SummarizationKernel, cudaFuncCachePreferShared)); CUDA_CHECK( - cudaFuncSetCacheConfig(TSNE::IntegrationKernel, cudaFuncCachePreferL1)); + cudaFuncSetCacheConfig(TSNE::SortKernel, cudaFuncCachePreferL1)); + 
CUDA_CHECK(cudaFuncSetCacheConfig(TSNE::RepulsionKernel,
+ cudaFuncCachePreferL1));
+ CUDA_CHECK(cudaFuncSetCacheConfig(
+ TSNE::attractive_kernel_bh, cudaFuncCachePreferL1));
+ CUDA_CHECK(cudaFuncSetCacheConfig(TSNE::IntegrationKernel,
+ cudaFuncCachePreferL1));

 // Do gradient updates
 //---------------------------------------------------
 CUML_LOG_DEBUG("Start gradient updates!");
- float momentum = pre_momentum;
- float learning_rate = pre_learning_rate;
+ value_t momentum = pre_momentum;
+ value_t learning_rate = pre_learning_rate;

 for (int iter = 0; iter < max_iter; iter++) {
 CUDA_CHECK(cudaMemsetAsync(static_cast(rep_forces.data()), 0,
@@ -181,7 +185,7 @@ void Barnes_Hut(float *VAL, const int *COL, const int *ROW, const int NNZ,
 if (iter == exaggeration_iter) {
 momentum = post_momentum;
 // Divide perplexities
- const float div = 1.0f / early_exaggeration;
+ const value_t div = 1.0f / early_exaggeration;
 raft::linalg::scalarMultiply(VAL, VAL, div, NNZ, stream);
 learning_rate = post_learning_rate;
 }
@@ -252,7 +256,8 @@ void Barnes_Hut(float *VAL, const int *COL, const int *ROW, const int NNZ,
 START_TIMER;
 // TODO: Calculate Kullback-Leibler divergence
 // For general embedding dimensions
- TSNE::attractive_kernel_bh<<>>(
+ TSNE::attractive_kernel_bh<<>>(
 VAL, COL, ROW, YY.data(), YY.data() + nnodes + 1, attr_forces.data(),
 attr_forces.data() + n, NNZ);
 CUDA_CHECK(cudaPeekAtLastError());
diff --git a/cpp/src/tsne/bh_kernels.cuh b/cpp/src/tsne/bh_kernels.cuh
index 50271edc92..6cf47550bb 100644
--- a/cpp/src/tsne/bh_kernels.cuh
+++ b/cpp/src/tsne/bh_kernels.cuh
@@ -35,6 +35,7 @@
 #include
 #include
+#include

 namespace ML {
 namespace TSNE {

@@ -42,10 +43,11 @@ namespace TSNE {
 /**
  * Initializes the states of objects. This speeds the overall kernel up.
  */
+template
 __global__ void InitializationKernel(/*int *restrict errd, */
                                      unsigned *restrict limiter,
-                                     int *restrict maxdepthd,
-                                     float *restrict radiusd) {
+                                     value_idx *restrict maxdepthd,
+                                     value_t *restrict radiusd) {
   // errd[0] = 0;
   maxdepthd[0] = 1;
   limiter[0] = 0;
@@ -55,10 +57,12 @@ __global__ void InitializationKernel(/*int *restrict errd, */
 /**
  * Reset normalization back to 0.
  */
-__global__ void Reset_Normalization(float *restrict Z_norm,
-                                    float *restrict radiusd_squared,
-                                    int *restrict bottomd, const int NNODES,
-                                    const float *restrict radiusd) {
+template
+__global__ void Reset_Normalization(value_t *restrict Z_norm,
+                                    value_t *restrict radiusd_squared,
+                                    value_idx *restrict bottomd,
+                                    const value_idx NNODES,
+                                    const value_t *restrict radiusd) {
   Z_norm[0] = 0.0f;
   radiusd_squared[0] = radiusd[0] * radiusd[0];
   // create root node
@@ -68,21 +72,24 @@ __global__ void Reset_Normalization(float *restrict Z_norm,
 /**
  * Find 1/Z
  */
-__global__ void Find_Normalization(float *restrict Z_norm, const float N) {
+template
+__global__ void Find_Normalization(value_t *restrict Z_norm,
+                                   const value_idx N) {
   Z_norm[0] = 1.0f / (Z_norm[0] - N);
 }

 /**
  * Computes the bounding box around all points in the embedding.
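 * Each block reduces its strided slice of points to a per-block min/max
 * in shared memory; the last block to arrive (tracked with atomicInc on
 * `limiter`) merges the per-block extrema, derives the root cell's centre
 * and radius, and clears the root's four children to -1.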
*/ +template __global__ __launch_bounds__(THREADS1, 2) void BoundingBoxKernel( - int *restrict startd, int *restrict childd, float *restrict massd, - float *restrict posxd, float *restrict posyd, float *restrict maxxd, - float *restrict maxyd, float *restrict minxd, float *restrict minyd, - const int FOUR_NNODES, const int NNODES, const int N, - unsigned *restrict limiter, float *restrict radiusd) { - float val, minx, maxx, miny, maxy; - __shared__ float sminx[THREADS1], smaxx[THREADS1], sminy[THREADS1], + value_idx *restrict startd, value_idx *restrict childd, + value_t *restrict massd, value_t *restrict posxd, value_t *restrict posyd, + value_t *restrict maxxd, value_t *restrict maxyd, value_t *restrict minxd, + value_t *restrict minyd, const value_idx FOUR_NNODES, const value_idx NNODES, + const value_idx N, unsigned *restrict limiter, value_t *restrict radiusd) { + value_t val, minx, maxx, miny, maxy; + __shared__ value_t sminx[THREADS1], smaxx[THREADS1], sminy[THREADS1], smaxy[THREADS1]; // initialize with valid data (in case #bodies < #threads) @@ -90,9 +97,9 @@ __global__ __launch_bounds__(THREADS1, 2) void BoundingBoxKernel( miny = maxy = posyd[0]; // scan all bodies - const int i = threadIdx.x; - const int inc = THREADS1 * gridDim.x; - for (int j = i + blockIdx.x * THREADS1; j < N; j += inc) { + const auto i = threadIdx.x; + const auto inc = THREADS1 * gridDim.x; + for (auto j = i + blockIdx.x * THREADS1; j < N; j += inc) { val = posxd[j]; if (val < minx) minx = val; @@ -112,9 +119,9 @@ __global__ __launch_bounds__(THREADS1, 2) void BoundingBoxKernel( sminy[i] = miny; smaxy[i] = maxy; - for (int j = THREADS1 / 2; j > i; j /= 2) { + for (auto j = THREADS1 / 2; j > i; j /= 2) { __syncthreads(); - const int k = i + j; + const auto k = i + j; sminx[i] = minx = fminf(minx, sminx[k]); smaxx[i] = maxx = fmaxf(maxx, smaxx[k]); sminy[i] = miny = fminf(miny, sminy[k]); @@ -123,18 +130,18 @@ __global__ __launch_bounds__(THREADS1, 2) void BoundingBoxKernel( if (i == 0) { // write block result to global memory - const int k = blockIdx.x; + const auto k = blockIdx.x; minxd[k] = minx; maxxd[k] = maxx; minyd[k] = miny; maxyd[k] = maxy; __threadfence(); - const int inc = gridDim.x - 1; + const auto inc = gridDim.x - 1; if (inc != atomicInc(limiter, inc)) return; // I'm the last block, so combine all block results - for (int j = 0; j <= inc; j++) { + for (auto j = 0; j <= inc; j++) { minx = fminf(minx, minxd[j]); maxx = fmaxf(maxx, maxxd[j]); miny = fminf(miny, minyd[j]); @@ -150,18 +157,20 @@ __global__ __launch_bounds__(THREADS1, 2) void BoundingBoxKernel( posyd[NNODES] = (miny + maxy) * 0.5f; #pragma unroll - for (int a = 0; a < 4; a++) childd[FOUR_NNODES + a] = -1; + for (auto a = 0; a < 4; a++) childd[FOUR_NNODES + a] = -1; } } /** * Clear some of the state vectors up. */ -__global__ __launch_bounds__(1024, 1) void ClearKernel1(int *restrict childd, - const int FOUR_NNODES, - const int FOUR_N) { - const int inc = blockDim.x * gridDim.x; - int k = (FOUR_N & -32) + threadIdx.x + blockIdx.x * blockDim.x; +template +__global__ __launch_bounds__(1024, + 1) void ClearKernel1(value_idx *restrict childd, + const value_idx FOUR_NNODES, + const value_idx FOUR_N) { + const auto inc = blockDim.x * gridDim.x; + value_idx k = (FOUR_N & -32) + threadIdx.x + blockIdx.x * blockDim.x; if (k < FOUR_N) k += inc; // iterate over all cells assigned to thread @@ -173,29 +182,32 @@ __global__ __launch_bounds__(1024, 1) void ClearKernel1(int *restrict childd, * Build the actual QuadTree. 
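 * Each thread takes one body, walks from the root toward the leaf that
 * should hold it, and claims the empty slot with atomicCAS; on a
 * collision the slot is locked with -2, a fresh internal cell is carved
 * off bottomd, and both bodies are re-inserted one level deeper.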
* See: https://iss.oden.utexas.edu/Publications/Papers/burtscher11.pdf */ +template __global__ __launch_bounds__( - THREADS2, - 2) void TreeBuildingKernel(/* int *restrict errd, */ - int *restrict childd, const float *restrict posxd, - const float *restrict posyd, const int NNODES, - const int N, int *restrict maxdepthd, - int *restrict bottomd, - const float *restrict radiusd) { - int j, depth; - float x, y, r; - float px, py; - int ch, n, locked, patch; + THREADS2, 2) void TreeBuildingKernel(/* int *restrict errd, */ + value_idx *restrict childd, + const value_t *restrict posxd, + const value_t *restrict posyd, + const value_idx NNODES, + const value_idx N, + value_idx *restrict maxdepthd, + value_idx *restrict bottomd, + const value_t *restrict radiusd) { + value_idx j, depth; + value_t x, y, r; + value_t px, py; + value_idx ch, n, locked, patch; // cache root data - const float radius = radiusd[0]; - const float rootx = posxd[NNODES]; - const float rooty = posyd[NNODES]; + const value_t radius = radiusd[0]; + const value_t rootx = posxd[NNODES]; + const value_t rooty = posyd[NNODES]; - int localmaxdepth = 1; - int skip = 1; + value_idx localmaxdepth = 1; + value_idx skip = 1; - const int inc = blockDim.x * gridDim.x; - int i = threadIdx.x + blockIdx.x * blockDim.x; + const auto inc = blockDim.x * gridDim.x; + value_idx i = threadIdx.x + blockIdx.x * blockDim.x; // iterate over all bodies assigned to thread while (i < N) { @@ -236,7 +248,7 @@ __global__ __launch_bounds__( if (ch == -1) { // Child is a nullptr ('-1'), so we write our body index to the leaf, and move on to the next body. - if (atomicCAS(&childd[locked], -1, i) == -1) { + if (atomicCAS(&childd[locked], (value_idx)-1, i) == -1) { if (depth > localmaxdepth) localmaxdepth = depth; i += inc; // move on to next body @@ -244,15 +256,16 @@ __global__ __launch_bounds__( } } else { // Child node isn't empty, so we store the current value of the child, lock the leaf, and patch in a new cell - if (ch == atomicCAS(&childd[locked], ch, -2)) { + if (ch == atomicCAS(&childd[locked], ch, (value_idx)-2)) { patch = -1; while (ch >= 0) { depth++; - const int cell = atomicSub(bottomd, 1) - 1; + const value_idx cell = atomicAdd(bottomd, (value_idx)-1) - 1; if (cell == N) { - atomicExch(bottomd, NNODES); + atomicExch(reinterpret_cast(bottomd), + (unsigned long long int)NNODES); } else if (cell < N) { depth--; continue; @@ -306,14 +319,13 @@ __global__ __launch_bounds__( /** * Clean more state vectors. 
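 * (Cells above bottomd get their start indices reset and their masses set
 * to -1.0f, the sentinel SummarizationKernel later uses to spot cells it
 * has not yet summarized.)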
*/ -__global__ __launch_bounds__(1024, - 1) void ClearKernel2(int *restrict startd, - float *restrict massd, - const int NNODES, - const int *restrict bottomd) { - const int bottom = bottomd[0]; - const int inc = blockDim.x * gridDim.x; - int k = (bottom & -32) + threadIdx.x + blockIdx.x * blockDim.x; +template +__global__ __launch_bounds__(1024, 1) void ClearKernel2( + value_idx *restrict startd, value_t *restrict massd, const value_idx NNODES, + const value_idx *restrict bottomd) { + const auto bottom = bottomd[0]; + const auto inc = blockDim.x * gridDim.x; + auto k = (bottom & -32) + threadIdx.x + blockIdx.x * blockDim.x; if (k < bottom) k += inc; // iterate over all cells assigned to thread @@ -327,21 +339,23 @@ __global__ __launch_bounds__(1024, /** * Summarize the KD Tree via cell gathering */ +template __global__ __launch_bounds__(THREADS3, FACTOR3) void SummarizationKernel( - int *restrict countd, const int *restrict childd, - volatile float *restrict massd, float *restrict posxd, float *restrict posyd, - const int NNODES, const int N, const int *restrict bottomd) { + value_idx *restrict countd, const value_idx *restrict childd, + volatile value_t *restrict massd, value_t *restrict posxd, + value_t *restrict posyd, const value_idx NNODES, const value_idx N, + const value_idx *restrict bottomd) { bool flag = 0; - float cm, px, py; - __shared__ int child[THREADS3 * 4]; - __shared__ float mass[THREADS3 * 4]; + value_t cm, px, py; + __shared__ value_idx child[THREADS3 * 4]; + __shared__ value_t mass[THREADS3 * 4]; - const int bottom = bottomd[0]; - const int inc = blockDim.x * gridDim.x; - int k = (bottom & -32) + threadIdx.x + blockIdx.x * blockDim.x; + const auto bottom = bottomd[0]; + const auto inc = blockDim.x * gridDim.x; + auto k = (bottom & -32) + threadIdx.x + blockIdx.x * blockDim.x; if (k < bottom) k += inc; - const int restart = k; + const auto restart = k; for (int j = 0; j < 5; j++) // wait-free pre-passes { @@ -349,7 +363,7 @@ __global__ __launch_bounds__(THREADS3, FACTOR3) void SummarizationKernel( while (k <= NNODES) { if (massd[k] < 0.0f) { for (int i = 0; i < 4; i++) { - const int ch = childd[k * 4 + i]; + const auto ch = childd[k * 4 + i]; child[i * THREADS3 + threadIdx.x] = ch; if ((ch >= N) and @@ -361,13 +375,13 @@ __global__ __launch_bounds__(THREADS3, FACTOR3) void SummarizationKernel( cm = 0.0f; px = 0.0f; py = 0.0f; - int cnt = 0; + auto cnt = 0; #pragma unroll for (int i = 0; i < 4; i++) { const int ch = child[i * THREADS3 + threadIdx.x]; if (ch >= 0) { - const float m = + const value_t m = (ch >= N) ? 
(cnt += countd[ch], mass[i * THREADS3 + threadIdx.x]) : (cnt++, massd[ch]); // add child's contribution @@ -378,7 +392,7 @@ __global__ __launch_bounds__(THREADS3, FACTOR3) void SummarizationKernel( } countd[k] = cnt; - const float m = 1.0f / cm; + const value_t m = 1.0f / cm; posxd[k] = px * m; posyd[k] = py * m; __threadfence(); // make sure data are visible before setting mass @@ -402,7 +416,7 @@ __global__ __launch_bounds__(THREADS3, FACTOR3) void SummarizationKernel( if (j == 0) { j = 4; for (int i = 0; i < 4; i++) { - const int ch = childd[k * 4 + i]; + const auto ch = childd[k * 4 + i]; child[i * THREADS3 + threadIdx.x] = ch; if ((ch < N) or ((mass[i * THREADS3 + threadIdx.x] = massd[ch]) >= 0)) @@ -411,7 +425,7 @@ __global__ __launch_bounds__(THREADS3, FACTOR3) void SummarizationKernel( } else { j = 4; for (int i = 0; i < 4; i++) { - const int ch = child[i * THREADS3 + threadIdx.x]; + const auto ch = child[i * THREADS3 + threadIdx.x]; if ((ch < N) or (mass[i * THREADS3 + threadIdx.x] >= 0) or ((mass[i * THREADS3 + threadIdx.x] = massd[ch]) >= 0)) @@ -424,13 +438,13 @@ __global__ __launch_bounds__(THREADS3, FACTOR3) void SummarizationKernel( cm = 0.0f; px = 0.0f; py = 0.0f; - int cnt = 0; + auto cnt = 0; #pragma unroll for (int i = 0; i < 4; i++) { - const int ch = child[i * THREADS3 + threadIdx.x]; + const auto ch = child[i * THREADS3 + threadIdx.x]; if (ch >= 0) { - const float m = + const auto m = (ch >= N) ? (cnt += countd[ch], mass[i * THREADS3 + threadIdx.x]) : (cnt++, massd[ch]); // add child's contribution @@ -441,7 +455,7 @@ __global__ __launch_bounds__(THREADS3, FACTOR3) void SummarizationKernel( } countd[k] = cnt; - const float m = 1.0f / cm; + const value_t m = 1.0f / cm; posxd[k] = px * m; posyd[k] = py * m; flag = 1; @@ -460,15 +474,17 @@ __global__ __launch_bounds__(THREADS3, FACTOR3) void SummarizationKernel( /** * Sort the cells */ +template __global__ __launch_bounds__(THREADS4, FACTOR4) void SortKernel( - int *restrict sortd, const int *restrict countd, - volatile int *restrict startd, int *restrict childd, const int NNODES, - const int N, const int *restrict bottomd) { - const int bottom = bottomd[0]; - const int dec = blockDim.x * gridDim.x; - int k = NNODES + 1 - dec + threadIdx.x + blockIdx.x * blockDim.x; - int start; - int limiter = 0; + value_idx *restrict sortd, const value_idx *restrict countd, + volatile value_idx *restrict startd, value_idx *restrict childd, + const value_idx NNODES, const value_idx N, + const value_idx *restrict bottomd) { + const value_idx bottom = bottomd[0]; + const value_idx dec = blockDim.x * gridDim.x; + value_idx k = NNODES + 1 - dec + threadIdx.x + blockIdx.x * blockDim.x; + value_idx start; + value_idx limiter = 0; // iterate over all cells assigned to thread while (k >= bottom) { @@ -480,7 +496,7 @@ __global__ __launch_bounds__(THREADS4, FACTOR4) void SortKernel( int j = 0; for (int i = 0; i < 4; i++) { - const int ch = childd[k * 4 + i]; + const auto ch = childd[k * 4 + i]; if (ch >= 0) { if (i != j) { // move children to front (needed later for speed) @@ -505,20 +521,23 @@ __global__ __launch_bounds__(THREADS4, FACTOR4) void SortKernel( /** * Calculate the repulsive forces using the KD Tree */ +template __global__ __launch_bounds__( THREADS5, 1) void RepulsionKernel(/* int *restrict errd, */ - const float theta, - const float epssqd, // correction for zero distance - const int *restrict sortd, const int *restrict childd, - const float *restrict massd, - const float *restrict posxd, - const float *restrict posyd, float 
*restrict velxd, - float *restrict velyd, float *restrict Z_norm, - const float theta_squared, const int NNODES, - const int FOUR_NNODES, const int N, - const float *restrict radiusd_squared, - const int *restrict maxdepthd) { + const value_t theta, + const value_t epssqd, // correction for zero distance + const value_idx *restrict sortd, + const value_idx *restrict childd, + const value_t *restrict massd, + const value_t *restrict posxd, + const value_t *restrict posyd, + value_t *restrict velxd, value_t *restrict velyd, + value_t *restrict Z_norm, const value_t theta_squared, + const value_idx NNODES, const value_idx FOUR_NNODES, + const value_idx N, + const value_t *restrict radiusd_squared, + const value_idx *restrict maxdepthd) { // Return if max depth is too deep // Not possible since I limited it to 32 // if (maxdepthd[0] > 32) @@ -526,23 +545,23 @@ __global__ __launch_bounds__( // atomicExch(errd, max_depth); // return; // } - const float EPS_PLUS_1 = epssqd + 1.0f; + const value_t EPS_PLUS_1 = epssqd + 1.0f; - __shared__ int pos[THREADS5], node[THREADS5]; - __shared__ float dq[THREADS5]; + __shared__ value_idx pos[THREADS5], node[THREADS5]; + __shared__ value_t dq[THREADS5]; if (threadIdx.x == 0) { - const int max_depth = maxdepthd[0]; + const auto max_depth = maxdepthd[0]; dq[0] = __fdividef(radiusd_squared[0], theta_squared); - for (int i = 1; i < max_depth; i++) { + for (auto i = 1; i < max_depth; i++) { dq[i] = dq[i - 1] * 0.25f; dq[i - 1] += epssqd; } dq[max_depth - 1] += epssqd; // Add one so EPS_PLUS_1 can be compared - for (int i = 0; i < max_depth; i++) dq[i] += 1.0f; + for (auto i = 0; i < max_depth; i++) dq[i] += 1.0f; } __syncthreads(); @@ -562,20 +581,20 @@ __global__ __launch_bounds__( __threadfence_block(); // iterate over all bodies assigned to thread - const int MAX_SIZE = FOUR_NNODES + 4; + const auto MAX_SIZE = FOUR_NNODES + 4; - for (int k = threadIdx.x + blockIdx.x * blockDim.x; k < N; + for (auto k = threadIdx.x + blockIdx.x * blockDim.x; k < N; k += blockDim.x * gridDim.x) { - const int i = sortd[k]; // get permuted/sorted index + const auto i = sortd[k]; // get permuted/sorted index // cache position info if (i < 0 or i >= MAX_SIZE) continue; - const float px = posxd[i]; - const float py = posyd[i]; + const value_t px = posxd[i]; + const value_t py = posyd[i]; - float vx = 0.0f; - float vy = 0.0f; - float normsum = 0.0f; + value_t vx = 0.0f; + value_t vy = 0.0f; + value_t normsum = 0.0f; // initialize iteration stack, i.e., push root node onto stack int depth = sbase; @@ -587,24 +606,24 @@ __global__ __launch_bounds__( do { // stack is not empty - int pd = pos[depth]; - int nd = node[depth]; + auto pd = pos[depth]; + auto nd = node[depth]; while (pd < 4) { - const int index = nd + pd++; + const auto index = nd + pd++; if (index < 0 or index >= MAX_SIZE) break; - const int n = childd[index]; // load child pointer + const auto n = childd[index]; // load child pointer // Non child if (n < 0 or n > NNODES) break; - const float dx = px - posxd[n]; - const float dy = py - posyd[n]; - const float dxy1 = dx * dx + dy * dy + EPS_PLUS_1; + const value_t dx = px - posxd[n]; + const value_t dy = py - posyd[n]; + const value_t dxy1 = dx * dx + dy * dy + EPS_PLUS_1; if ((n < N) or __all_sync(__activemask(), dxy1 >= dq[depth])) { - const float tdist_2 = __fdividef(massd[n], dxy1 * dxy1); + const value_t tdist_2 = __fdividef(massd[n], dxy1 * dxy1); normsum += tdist_2 * dxy1; vx += dx * tdist_2; vy += dy * tdist_2; @@ -625,57 +644,60 @@ __global__ __launch_bounds__( // 
update velocity velxd[i] += vx; velyd[i] += vy; - raft::myAtomicAdd(Z_norm, normsum); + atomicAdd(Z_norm, normsum); } } /** * Fast attractive kernel. Uses COO matrix. */ +template __global__ void attractive_kernel_bh( - const float *restrict VAL, const int *restrict COL, const int *restrict ROW, - const float *restrict Y1, const float *restrict Y2, float *restrict attract1, - float *restrict attract2, const int NNZ) { - const int index = (blockIdx.x * blockDim.x) + threadIdx.x; + const value_t *restrict VAL, const value_idx *restrict COL, + const value_idx *restrict ROW, const value_t *restrict Y1, + const value_t *restrict Y2, value_t *restrict attract1, + value_t *restrict attract2, const value_idx NNZ) { + const auto index = (blockIdx.x * blockDim.x) + threadIdx.x; if (index >= NNZ) return; - const int i = ROW[index]; - const int j = COL[index]; + const auto i = ROW[index]; + const auto j = COL[index]; - const float y1d = Y1[i] - Y1[j]; - const float y2d = Y2[i] - Y2[j]; - float squared_euclidean_dist = y1d * y1d + y2d * y2d; + const value_t y1d = Y1[i] - Y1[j]; + const value_t y2d = Y2[i] - Y2[j]; + value_t squared_euclidean_dist = y1d * y1d + y2d * y2d; // As a sum of squares, SED is mathematically >= 0. There might be a source of // NaNs upstream though, so until we find and fix them, enforce that trait. if (!(squared_euclidean_dist >= 0)) squared_euclidean_dist = 0.0f; - const float PQ = __fdividef(VAL[index], squared_euclidean_dist + 1.0f); + const value_t PQ = __fdividef(VAL[index], squared_euclidean_dist + 1.0f); // TODO: Calculate Kullback-Leibler divergence // TODO: Convert attractive forces to CSR format // Apply forces - raft::myAtomicAdd(&attract1[i], PQ * (Y1[i] - Y1[j])); - raft::myAtomicAdd(&attract2[i], PQ * (Y2[i] - Y2[j])); + atomicAdd(&attract1[i], PQ * (Y1[i] - Y1[j])); + atomicAdd(&attract2[i], PQ * (Y2[i] - Y2[j])); } /** * Apply gradient updates. 
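 * Per-coordinate gains follow the usual t-SNE delta-bar-delta heuristic:
 * a gain is bumped by +0.2 when its gradient flips sign against the
 * stored old force and decayed otherwise, before the momentum update
 * writes the new position.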
*/ +template __global__ __launch_bounds__(THREADS6, 1) void IntegrationKernel( - const float eta, const float momentum, const float exaggeration, - float *restrict Y1, float *restrict Y2, const float *restrict attract1, - const float *restrict attract2, const float *restrict repel1, - const float *restrict repel2, float *restrict gains1, float *restrict gains2, - float *restrict old_forces1, float *restrict old_forces2, - const float *restrict Z, const int N) { - float ux, uy, gx, gy; + const value_t eta, const value_t momentum, const value_t exaggeration, + value_t *restrict Y1, value_t *restrict Y2, const value_t *restrict attract1, + const value_t *restrict attract2, const value_t *restrict repel1, + const value_t *restrict repel2, value_t *restrict gains1, + value_t *restrict gains2, value_t *restrict old_forces1, + value_t *restrict old_forces2, const value_t *restrict Z, const value_idx N) { + value_t ux, uy, gx, gy; // iterate over all bodies assigned to thread - const int inc = blockDim.x * gridDim.x; - const float Z_norm = Z[0]; + const auto inc = blockDim.x * gridDim.x; + const value_t Z_norm = Z[0]; - for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < N; i += inc) { - const float dx = attract1[i] - Z_norm * repel1[i]; - const float dy = attract2[i] - Z_norm * repel2[i]; + for (auto i = threadIdx.x + blockIdx.x * blockDim.x; i < N; i += inc) { + const value_t dx = attract1[i] - Z_norm * repel1[i]; + const value_t dy = attract2[i] - Z_norm * repel2[i]; if (signbit(dx) != signbit(ux = old_forces1[i])) gx = gains1[i] + 0.2f; diff --git a/cpp/src/tsne/distances.cuh b/cpp/src/tsne/distances.cuh index 5a33b6896f..05dcfe2b52 100644 --- a/cpp/src/tsne/distances.cuh +++ b/cpp/src/tsne/distances.cuh @@ -47,8 +47,7 @@ void get_distances(const raft::handle_t &handle, tsne_input &input, template <> void get_distances(const raft::handle_t &handle, manifold_dense_inputs_t &input, - knn_graph &k_graph, - cudaStream_t stream) { + knn_graph &k_graph, cudaStream_t stream) { // TODO: for TSNE transform first fit some points then transform with 1/(1+d^2) // #861 @@ -73,17 +72,16 @@ void get_distances(const raft::handle_t &handle, template <> void get_distances(const raft::handle_t &handle, manifold_dense_inputs_t &input, - knn_graph &k_graph, - cudaStream_t stream) { - throw raft::exception("Dense TSNE does not support 32-bit integer indices yet."); + knn_graph &k_graph, cudaStream_t stream) { + throw raft::exception( + "Dense TSNE does not support 32-bit integer indices yet."); } // sparse, int32 template <> void get_distances(const raft::handle_t &handle, manifold_sparse_inputs_t &input, - knn_graph &k_graph, - cudaStream_t stream) { + knn_graph &k_graph, cudaStream_t stream) { MLCommon::Sparse::Selection::brute_force_knn( input.indptr, input.indices, input.data, input.nnz, input.n, input.d, input.indptr, input.indices, input.data, input.nnz, input.n, input.d, @@ -97,9 +95,9 @@ void get_distances(const raft::handle_t &handle, template <> void get_distances(const raft::handle_t &handle, manifold_sparse_inputs_t &input, - knn_graph &k_graph, - cudaStream_t stream) { - throw raft::exception("Sparse TSNE does not support 32-bit integer indices yet."); + knn_graph &k_graph, cudaStream_t stream) { + throw raft::exception( + "Sparse TSNE does not support 32-bit integer indices yet."); } /** @@ -111,12 +109,12 @@ void get_distances(const raft::handle_t &handle, * @param[in] stream: The GPU stream. 
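 * Dividing by max(abs(D)) keeps the exp() calls in the subsequent
 * perplexity search from overflowing; a zero maximum falls back to one so
 * the in-place divide leaves degenerate all-zero input untouched.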
*/ template -void normalize_distances(const value_idx n, value_t *distances, const int n_neighbors, - cudaStream_t stream) { +void normalize_distances(const value_idx n, value_t *distances, + const int n_neighbors, cudaStream_t stream) { // Now D / max(abs(D)) to allow exp(D) to not explode thrust::device_ptr begin = thrust::device_pointer_cast(distances); value_t maxNorm = *thrust::max_element(thrust::cuda::par.on(stream), begin, - begin + n * n_neighbors); + begin + n * n_neighbors); if (maxNorm == 0.0f) maxNorm = 1.0f; // Divide distances inplace by max @@ -137,10 +135,11 @@ void normalize_distances(const value_idx n, value_t *distances, const int n_neig * @param[in] handle: The GPU handle. */ template -void symmetrize_perplexity(float *P, value_idx *indices, const value_idx n, - const int k, const value_t exaggeration, - MLCommon::Sparse::COO *COO_Matrix, - cudaStream_t stream, const raft::handle_t &handle) { +void symmetrize_perplexity( + float *P, value_idx *indices, const value_idx n, const int k, + const value_t exaggeration, + MLCommon::Sparse::COO *COO_Matrix, cudaStream_t stream, + const raft::handle_t &handle) { // Perform (P + P.T) / P_sum * early_exaggeration const value_t div = exaggeration / (2.0f * n); raft::linalg::scalarMultiply(P, P, div, n * k, stream); diff --git a/cpp/src/tsne/exact_kernels.cuh b/cpp/src/tsne/exact_kernels.cuh index 3c5a609c12..2ffb97eca0 100644 --- a/cpp/src/tsne/exact_kernels.cuh +++ b/cpp/src/tsne/exact_kernels.cuh @@ -33,7 +33,8 @@ template __global__ void sigmas_kernel(const value_t *restrict distances, value_t *restrict P, const value_t perplexity, const value_t desired_entropy, const int epochs, - const value_t tol, const value_idx n, const int k) { + const value_t tol, const value_idx n, + const int k) { // For every item in row const auto i = (blockIdx.x * blockDim.x) + threadIdx.x; if (i >= n) return; @@ -86,8 +87,9 @@ __global__ void sigmas_kernel(const value_t *restrict distances, template __global__ void sigmas_kernel_2d(const value_t *restrict distances, value_t *restrict P, const value_t perplexity, - const value_t desired_entropy, const int epochs, - const value_t tol, const value_idx n) { + const value_t desired_entropy, + const int epochs, const value_t tol, + const value_idx n) { // For every item in row const auto i = (blockIdx.x * blockDim.x) + threadIdx.x; if (i >= n) return; @@ -141,10 +143,10 @@ void perplexity_search(const value_t *restrict distances, value_t *restrict P, cudaStream_t stream = handle.get_stream(); if (dim == 2) - sigmas_kernel_2d<<>>( + sigmas_kernel_2d<<>>( distances, P, perplexity, desired_entropy, epochs, tol, n); else - sigmas_kernel<<>>( + sigmas_kernel<<>>( distances, P, perplexity, desired_entropy, epochs, tol, n, dim); CUDA_CHECK(cudaPeekAtLastError()); cudaStreamSynchronize(stream); @@ -155,9 +157,10 @@ void perplexity_search(const value_t *restrict distances, value_t *restrict P, Uses only nearest neighbors */ template __global__ void attractive_kernel( - const value_t *restrict VAL, const value_idx *restrict COL, const value_idx *restrict ROW, - const value_t *restrict Y, const value_t *restrict norm, value_t *restrict attract, - const value_idx NNZ, const value_idx n, const value_idx dim, + const value_t *restrict VAL, const value_idx *restrict COL, + const value_idx *restrict ROW, const value_t *restrict Y, + const value_t *restrict norm, value_t *restrict attract, const value_idx NNZ, + const value_idx n, const value_idx dim, const value_t df_power, // -(df + 1)/2) const value_t recp_df) // 1 / 
df { @@ -187,10 +190,10 @@ __global__ void attractive_kernel( up many calculations up */ template __global__ void attractive_kernel_2d( - const value_t *restrict VAL, const value_idx *restrict COL, const value_idx *restrict ROW, - const value_t *restrict Y1, const value_t *restrict Y2, - const value_t *restrict norm, value_t *restrict attract1, - value_t *restrict attract2, const value_idx NNZ) { + const value_t *restrict VAL, const value_idx *restrict COL, + const value_idx *restrict ROW, const value_t *restrict Y1, + const value_t *restrict Y2, const value_t *restrict norm, + value_t *restrict attract1, value_t *restrict attract2, const value_idx NNZ) { const auto index = (blockIdx.x * blockDim.x) + threadIdx.x; if (index >= NNZ) return; const auto i = ROW[index], j = COL[index]; @@ -212,10 +215,12 @@ __global__ void attractive_kernel_2d( /****************************************/ template -void attractive_forces(const value_t *restrict VAL, const value_idx *restrict COL, +void attractive_forces(const value_t *restrict VAL, + const value_idx *restrict COL, const value_idx *restrict ROW, const value_t *restrict Y, const value_t *restrict norm, value_t *restrict attract, - const value_idx NNZ, const value_idx n, const value_idx dim, + const value_idx NNZ, const value_idx n, + const value_idx dim, const value_t df_power, // -(df + 1)/2) const value_t recp_df, // 1 / df cudaStream_t stream) { @@ -225,13 +230,14 @@ void attractive_forces(const value_t *restrict VAL, const value_idx *restrict CO // #863 // For general embedding dimensions if (dim != 2) { - attractive_kernel<<>>( + attractive_kernel<<>>( VAL, COL, ROW, Y, norm, attract, NNZ, n, dim, df_power, recp_df); } // For special case dim == 2 else { - attractive_kernel_2d<<>>( - VAL, COL, ROW, Y, Y + n, norm, attract, attract + n, NNZ); + attractive_kernel_2d<<>>(VAL, COL, ROW, Y, Y + n, norm, attract, + attract + n, NNZ); } CUDA_CHECK(cudaPeekAtLastError()); } @@ -241,10 +247,12 @@ void attractive_forces(const value_t *restrict VAL, const value_idx *restrict CO time where many of the math ops are made considerably faster. */ template -__global__ void repulsive_kernel(const value_t *restrict Y, value_t *restrict repel, +__global__ void repulsive_kernel(const value_t *restrict Y, + value_t *restrict repel, const value_t *restrict norm, - value_t *restrict Z_sum1, value_t *restrict Z_sum2, - const value_idx n, const value_idx dim, + value_t *restrict Z_sum1, + value_t *restrict Z_sum2, const value_idx n, + const value_idx dim, const value_t df_power, // -(df + 1)/2) const value_t recp_df) // 1 / df { @@ -282,8 +290,9 @@ __global__ void repulsive_kernel(const value_t *restrict Y, value_t *restrict re since calculations are streamlined. 
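 * (Both embedding components stay in registers, the inner loop over dim
 * disappears, and the two halves of the Z normalizer accumulate into
 * Z_sum1/Z_sum2 so the final reduction sees one contiguous 2 * n run.)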
*/ template __global__ void repulsive_kernel_2d( - const value_t *restrict Y1, const value_t *restrict Y2, value_t *restrict repel1, - value_t *restrict repel2, const value_t *restrict norm, value_t *restrict Z_sum1, + const value_t *restrict Y1, const value_t *restrict Y2, + value_t *restrict repel1, value_t *restrict repel2, + const value_t *restrict norm, value_t *restrict Z_sum1, value_t *restrict Z_sum2, const value_idx n) { const auto j = (blockIdx.x * blockDim.x) + threadIdx.x; // for every item in row @@ -318,15 +327,16 @@ __global__ void repulsive_kernel_2d( /****************************************/ template value_t repulsive_forces(const value_t *restrict Y, value_t *restrict repel, - const value_t *restrict norm, value_t *restrict Z_sum, - const value_idx n, const value_idx dim, - const value_t df_power, // -(df + 1)/2) - const value_t recp_df, cudaStream_t stream) { + const value_t *restrict norm, value_t *restrict Z_sum, + const value_idx n, const value_idx dim, + const value_t df_power, // -(df + 1)/2) + const value_t recp_df, cudaStream_t stream) { CUDA_CHECK(cudaMemsetAsync(Z_sum, 0, sizeof(value_t) * 2 * n, stream)); CUDA_CHECK(cudaMemsetAsync(repel, 0, sizeof(value_t) * n * dim, stream)); const dim3 threadsPerBlock(TPB_X, TPB_Y); - const dim3 numBlocks(raft::ceildiv(n, (value_idx) TPB_X), raft::ceildiv(n, (value_idx) TPB_Y)); + const dim3 numBlocks(raft::ceildiv(n, (value_idx)TPB_X), + raft::ceildiv(n, (value_idx)TPB_Y)); // For general embedding dimensions if (dim != 2) { @@ -342,10 +352,12 @@ value_t repulsive_forces(const value_t *restrict Y, value_t *restrict repel, // Find sum(Z_sum) thrust::device_ptr begin = thrust::device_pointer_cast(Z_sum); - value_t Z = thrust::reduce(thrust::cuda::par.on(stream), begin, begin + 2 * n); + value_t Z = + thrust::reduce(thrust::cuda::par.on(stream), begin, begin + 2 * n); return 1.0f / (2.0f * - (Z + (value_t)n)); // Notice + n since diagonal of repulsion sums to n + (Z + + (value_t)n)); // Notice + n since diagonal of repulsion sums to n } /****************************************/ @@ -354,8 +366,9 @@ value_t repulsive_forces(const value_t *restrict Y, value_t *restrict repel, for output stability */ template __global__ void apply_kernel( - value_t *restrict Y, value_t *restrict velocity, const value_t *restrict attract, - const value_t *restrict repel, value_t *restrict means, value_t *restrict gains, + value_t *restrict Y, value_t *restrict velocity, + const value_t *restrict attract, const value_t *restrict repel, + value_t *restrict means, value_t *restrict gains, const value_t Z, // sum(Q) const value_t learning_rate, const value_t C, // constant from T-Dist Degrees of Freedom @@ -389,19 +402,21 @@ __global__ void apply_kernel( /****************************************/ template value_t apply_forces(value_t *restrict Y, value_t *restrict velocity, - const value_t *restrict attract, const value_t *restrict repel, - value_t *restrict means, value_t *restrict gains, - const value_t Z, // sum(Q) - const value_t learning_rate, - const value_t C, // constant from T-dist - const value_t momentum, const value_idx dim, const value_idx n, - const value_t min_gain, value_t *restrict gradient, - const bool check_convergence, cudaStream_t stream) { + const value_t *restrict attract, + const value_t *restrict repel, value_t *restrict means, + value_t *restrict gains, + const value_t Z, // sum(Q) + const value_t learning_rate, + const value_t C, // constant from T-dist + const value_t momentum, const value_idx dim, + const value_idx n, 
const value_t min_gain, + value_t *restrict gradient, const bool check_convergence, + cudaStream_t stream) { //cudaMemset(means, 0, sizeof(float) * dim); if (check_convergence) CUDA_CHECK(cudaMemsetAsync(gradient, 0, sizeof(value_t) * n * dim, stream)); - apply_kernel<<>>( + apply_kernel<<>>( Y, velocity, attract, repel, means, gains, Z, learning_rate, C, momentum, n * dim, n, min_gain, gradient, check_convergence); CUDA_CHECK(cudaPeekAtLastError()); diff --git a/cpp/src/tsne/exact_tsne.cuh b/cpp/src/tsne/exact_tsne.cuh index 665cd12503..81fd4b207e 100644 --- a/cpp/src/tsne/exact_tsne.cuh +++ b/cpp/src/tsne/exact_tsne.cuh @@ -47,16 +47,16 @@ namespace TSNE { * @param[in] initialize_embeddings: Whether to overwrite the current Y vector with random noise. */ template -void Exact_TSNE(value_t *VAL, const value_idx *COL, const value_idx *ROW, const value_idx NNZ, - const raft::handle_t &handle, value_t *Y, const value_idx n, - const value_idx dim, const value_t early_exaggeration = 12.0f, - const int exaggeration_iter = 250, const value_t min_gain = 0.01f, - const value_t pre_learning_rate = 200.0f, - const value_t post_learning_rate = 500.0f, - const int max_iter = 1000, const value_t min_grad_norm = 1e-7, - const value_t pre_momentum = 0.5, const value_t post_momentum = 0.8, - const long long random_state = -1, - const bool initialize_embeddings = true) { +void Exact_TSNE( + value_t *VAL, const value_idx *COL, const value_idx *ROW, const value_idx NNZ, + const raft::handle_t &handle, value_t *Y, const value_idx n, + const value_idx dim, const value_t early_exaggeration = 12.0f, + const int exaggeration_iter = 250, const value_t min_gain = 0.01f, + const value_t pre_learning_rate = 200.0f, + const value_t post_learning_rate = 500.0f, const int max_iter = 1000, + const value_t min_grad_norm = 1e-7, const value_t pre_momentum = 0.5, + const value_t post_momentum = 0.8, const long long random_state = -1, + const bool initialize_embeddings = true) { auto d_alloc = handle.get_device_allocator(); cudaStream_t stream = handle.get_stream(); diff --git a/cpp/src/tsne/tsne_runner.cuh b/cpp/src/tsne/tsne_runner.cuh index baf04a26c9..63eecd8db9 100644 --- a/cpp/src/tsne/tsne_runner.cuh +++ b/cpp/src/tsne/tsne_runner.cuh @@ -106,19 +106,19 @@ class TSNE_runner { const auto *ROW = COO_Matrix.rows(); //--------------------------------------------------- - // if (barnes_hut) { - // TSNE::Barnes_Hut(VAL, COL, ROW, NNZ, handle, Y, n, theta, epssq, - // early_exaggeration, exaggeration_iter, min_gain, - // pre_learning_rate, post_learning_rate, max_iter, - // min_grad_norm, pre_momentum, post_momentum, random_state, - // initialize_embeddings); - // } else { + if (barnes_hut) { + TSNE::Barnes_Hut(VAL, COL, ROW, NNZ, handle, Y, n, theta, epssq, + early_exaggeration, exaggeration_iter, min_gain, + pre_learning_rate, post_learning_rate, max_iter, + min_grad_norm, pre_momentum, post_momentum, random_state, + initialize_embeddings); + } else { TSNE::Exact_TSNE(VAL, COL, ROW, NNZ, handle, Y, n, dim, early_exaggeration, exaggeration_iter, min_gain, pre_learning_rate, post_learning_rate, max_iter, min_grad_norm, pre_momentum, post_momentum, random_state, initialize_embeddings); - // } + } } private: diff --git a/cpp/src_prims/sparse/coo.cuh b/cpp/src_prims/sparse/coo.cuh index 5501cccd3c..79da8dc18d 100644 --- a/cpp/src_prims/sparse/coo.cuh +++ b/cpp/src_prims/sparse/coo.cuh @@ -171,7 +171,8 @@ class COO { /** * @brief Send human-readable state information to output stream */ - friend std::ostream 
&operator<<(std::ostream &out, const COO &c) { + friend std::ostream &operator<<(std::ostream &out, + const COO &c) { if (c.validate_size() && c.validate_mem()) { cudaStream_t stream; CUDA_CHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); @@ -839,14 +840,15 @@ __global__ static void symmetric_find_size(const value_t *restrict data, value_idx *restrict row_sizes, value_idx *restrict row_sizes2) { const auto row = blockIdx.x * blockDim.x + threadIdx.x; // for every row - const auto j = blockIdx.y * blockDim.y + threadIdx.y; // for every item in row + const auto j = + blockIdx.y * blockDim.y + threadIdx.y; // for every item in row if (row >= n || j >= k) return; const auto col = indices[row * k + j]; if (j % 2) - atomicAdd(&row_sizes[col], (value_idx) 1); + atomicAdd(&row_sizes[col], (value_idx)1); else - atomicAdd(&row_sizes2[col], (value_idx) 1); + atomicAdd(&row_sizes2[col], (value_idx)1); } /** @@ -885,16 +887,18 @@ template __global__ static void symmetric_sum(value_idx *restrict edges, const value_t *restrict data, const value_idx *restrict indices, - value_t *restrict VAL, value_idx *restrict COL, + value_t *restrict VAL, + value_idx *restrict COL, value_idx *restrict ROW, const value_idx n, const int k) { const auto row = blockIdx.x * blockDim.x + threadIdx.x; // for every row - const auto j = blockIdx.y * blockDim.y + threadIdx.y; // for every item in row + const auto j = + blockIdx.y * blockDim.y + threadIdx.y; // for every item in row if (row >= n || j >= k) return; const auto col = indices[row * k + j]; - const auto original = atomicAdd(&edges[row], (value_idx) 1); - const auto transpose = atomicAdd(&edges[col], (value_idx) 1); + const auto original = atomicAdd(&edges[row], (value_idx)1); + const auto transpose = atomicAdd(&edges[col], (value_idx)1); VAL[transpose] = VAL[original] = data[row * k + j]; // Notice swapped ROW, COL since transpose @@ -925,27 +929,31 @@ __global__ static void symmetric_sum(value_idx *restrict edges, */ template void from_knn_symmetrize_matrix(const value_idx *restrict knn_indices, - const value_t *restrict knn_dists, const value_idx n, - const int k, COO *out, + const value_t *restrict knn_dists, + const value_idx n, const int k, + COO *out, cudaStream_t stream, std::shared_ptr d_alloc) { // (1) Find how much space needed in each row // We look through all datapoints and increment the count for each row. 
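  // One portability caveat for the atomicAdd(&..., (value_idx)1) calls in
  // the kernels above: CUDA provides atomicAdd overloads for int, unsigned,
  // unsigned long long, float and double, but no signed 64-bit overload,
  // so instantiating value_idx as int64_t would need a shim along these
  // lines (a hypothetical helper, not part of this patch):
  //
  //   __device__ inline int64_t atomic_add64(int64_t *addr, int64_t val) {
  //     // two's-complement addition is bit-identical for signed and
  //     // unsigned operands, so routing through the unsigned long long
  //     // overload is safe
  //     return (int64_t)atomicAdd(
  //       reinterpret_cast<unsigned long long *>(addr),
  //       (unsigned long long)val);
  //   }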
const dim3 threadsPerBlock(TPB_X, TPB_Y); - const dim3 numBlocks(raft::ceildiv(n, (value_idx) TPB_X), raft::ceildiv(k, TPB_Y)); + const dim3 numBlocks(raft::ceildiv(n, (value_idx)TPB_X), + raft::ceildiv(k, TPB_Y)); // Notice n+1 since we can reuse these arrays for transpose_edges, original_edges in step (4) device_buffer row_sizes(d_alloc, stream, n); - CUDA_CHECK(cudaMemsetAsync(row_sizes.data(), 0, sizeof(value_idx) * n, stream)); + CUDA_CHECK( + cudaMemsetAsync(row_sizes.data(), 0, sizeof(value_idx) * n, stream)); device_buffer row_sizes2(d_alloc, stream, n); - CUDA_CHECK(cudaMemsetAsync(row_sizes2.data(), 0, sizeof(value_idx) * n, stream)); + CUDA_CHECK( + cudaMemsetAsync(row_sizes2.data(), 0, sizeof(value_idx) * n, stream)); symmetric_find_size<<>>( knn_dists, knn_indices, n, k, row_sizes.data(), row_sizes2.data()); CUDA_CHECK(cudaPeekAtLastError()); - reduce_find_size<<>>( + reduce_find_size<<>>( n, k, row_sizes.data(), row_sizes2.data()); CUDA_CHECK(cudaPeekAtLastError()); diff --git a/cpp/test/sg/tsne_test.cu b/cpp/test/sg/tsne_test.cu index 57dd7f94d8..13f0a249fc 100644 --- a/cpp/test/sg/tsne_test.cu +++ b/cpp/test/sg/tsne_test.cu @@ -138,8 +138,7 @@ class TSNETest : public ::testing::Test { knn_graph k_graph(n, 90, knn_indices.data(), knn_dists.data()); - TSNE::get_distances(handle, input, k_graph, - handle.get_stream()); + TSNE::get_distances(handle, input, k_graph, handle.get_stream()); CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); From 09dad7ac48bc759b90f7961d8c093847cfea8f0e Mon Sep 17 00:00:00 2001 From: divyegala Date: Mon, 21 Dec 2020 19:35:41 -0600 Subject: [PATCH 20/22] hyperparams to float32 --- cpp/src/tsne/barnes_hut.cuh | 14 +++++----- cpp/src/tsne/bh_kernels.cuh | 6 ++--- cpp/src/tsne/distances.cuh | 2 +- cpp/src/tsne/exact_kernels.cuh | 48 +++++++++++++++++----------------- cpp/src/tsne/exact_tsne.cuh | 28 ++++++++++---------- cpp/src/tsne/tsne.cu | 14 +++++----- cpp/src/tsne/tsne_runner.cuh | 22 ++++++++-------- 7 files changed, 67 insertions(+), 67 deletions(-) diff --git a/cpp/src/tsne/barnes_hut.cuh b/cpp/src/tsne/barnes_hut.cuh index edd44b35cb..27eadfb52d 100644 --- a/cpp/src/tsne/barnes_hut.cuh +++ b/cpp/src/tsne/barnes_hut.cuh @@ -52,12 +52,12 @@ template void Barnes_Hut( value_t *VAL, const value_idx *COL, const value_idx *ROW, const value_idx NNZ, const raft::handle_t &handle, value_t *Y, const value_idx n, - const value_t theta = 0.5f, const value_t epssq = 0.0025, - const value_t early_exaggeration = 12.0f, const int exaggeration_iter = 250, - const value_t min_gain = 0.01f, const value_t pre_learning_rate = 200.0f, - const value_t post_learning_rate = 500.0f, const int max_iter = 1000, - const value_t min_grad_norm = 1e-7, const value_t pre_momentum = 0.5, - const value_t post_momentum = 0.8, const long long random_state = -1, + const float theta = 0.5f, const float epssq = 0.0025, + const float early_exaggeration = 12.0f, const int exaggeration_iter = 250, + const float min_gain = 0.01f, const float pre_learning_rate = 200.0f, + const float post_learning_rate = 500.0f, const int max_iter = 1000, + const float min_grad_norm = 1e-7, const float pre_momentum = 0.5, + const float post_momentum = 0.8, const long long random_state = -1, const bool initialize_embeddings = true) { auto d_alloc = handle.get_device_allocator(); cudaStream_t stream = handle.get_stream(); @@ -87,7 +87,7 @@ void Barnes_Hut( const value_idx FOUR_NNODES = 4 * nnodes; const value_idx FOUR_N = 4 * n; - const value_t theta_squared = theta * theta; + const float theta_squared 
= theta * theta; const value_idx NNODES = nnodes; // Actual allocations diff --git a/cpp/src/tsne/bh_kernels.cuh b/cpp/src/tsne/bh_kernels.cuh index 6cf47550bb..241586ed0b 100644 --- a/cpp/src/tsne/bh_kernels.cuh +++ b/cpp/src/tsne/bh_kernels.cuh @@ -525,8 +525,8 @@ template __global__ __launch_bounds__( THREADS5, 1) void RepulsionKernel(/* int *restrict errd, */ - const value_t theta, - const value_t epssqd, // correction for zero distance + const float theta, + const float epssqd, // correction for zero distance const value_idx *restrict sortd, const value_idx *restrict childd, const value_t *restrict massd, @@ -683,7 +683,7 @@ __global__ void attractive_kernel_bh( */ template __global__ __launch_bounds__(THREADS6, 1) void IntegrationKernel( - const value_t eta, const value_t momentum, const value_t exaggeration, + const float eta, const float momentum, const float exaggeration, value_t *restrict Y1, value_t *restrict Y2, const value_t *restrict attract1, const value_t *restrict attract2, const value_t *restrict repel1, const value_t *restrict repel2, value_t *restrict gains1, diff --git a/cpp/src/tsne/distances.cuh b/cpp/src/tsne/distances.cuh index 05dcfe2b52..7dc31e20dc 100644 --- a/cpp/src/tsne/distances.cuh +++ b/cpp/src/tsne/distances.cuh @@ -97,7 +97,7 @@ void get_distances(const raft::handle_t &handle, manifold_sparse_inputs_t &input, knn_graph &k_graph, cudaStream_t stream) { throw raft::exception( - "Sparse TSNE does not support 32-bit integer indices yet."); + "Sparse TSNE does not support 64-bit integer indices yet."); } /** diff --git a/cpp/src/tsne/exact_kernels.cuh b/cpp/src/tsne/exact_kernels.cuh index 2ffb97eca0..c5abd39556 100644 --- a/cpp/src/tsne/exact_kernels.cuh +++ b/cpp/src/tsne/exact_kernels.cuh @@ -31,9 +31,9 @@ namespace TSNE { each row in the dataset */ template __global__ void sigmas_kernel(const value_t *restrict distances, - value_t *restrict P, const value_t perplexity, - const value_t desired_entropy, const int epochs, - const value_t tol, const value_idx n, + value_t *restrict P, const float perplexity, + const float desired_entropy, const int epochs, + const float tol, const value_idx n, const int k) { // For every item in row const auto i = (blockIdx.x * blockDim.x) + threadIdx.x; @@ -86,9 +86,9 @@ __global__ void sigmas_kernel(const value_t *restrict distances, each row in the dataset */ template __global__ void sigmas_kernel_2d(const value_t *restrict distances, - value_t *restrict P, const value_t perplexity, - const value_t desired_entropy, - const int epochs, const value_t tol, + value_t *restrict P, const float perplexity, + const float desired_entropy, + const int epochs, const float tol, const value_idx n) { // For every item in row const auto i = (blockIdx.x * blockDim.x) + threadIdx.x; @@ -135,10 +135,10 @@ __global__ void sigmas_kernel_2d(const value_t *restrict distances, /****************************************/ template void perplexity_search(const value_t *restrict distances, value_t *restrict P, - const value_t perplexity, const int epochs, - const value_t tol, const value_idx n, const int dim, + const float perplexity, const int epochs, + const float tol, const value_idx n, const int dim, const raft::handle_t &handle) { - const value_t desired_entropy = logf(perplexity); + const float desired_entropy = logf(perplexity); auto d_alloc = handle.get_device_allocator(); cudaStream_t stream = handle.get_stream(); @@ -161,8 +161,8 @@ __global__ void attractive_kernel( const value_idx *restrict ROW, const value_t *restrict Y, const value_t 
diff --git a/cpp/src/tsne/exact_kernels.cuh b/cpp/src/tsne/exact_kernels.cuh
index 2ffb97eca0..c5abd39556 100644
--- a/cpp/src/tsne/exact_kernels.cuh
+++ b/cpp/src/tsne/exact_kernels.cuh
@@ -31,9 +31,9 @@
    each row in the dataset */
 template <typename value_idx, typename value_t>
 __global__ void sigmas_kernel(const value_t *restrict distances,
-                              value_t *restrict P, const value_t perplexity,
-                              const value_t desired_entropy, const int epochs,
-                              const value_t tol, const value_idx n,
+                              value_t *restrict P, const float perplexity,
+                              const float desired_entropy, const int epochs,
+                              const float tol, const value_idx n,
                               const int k) {
   // For every item in row
   const auto i = (blockIdx.x * blockDim.x) + threadIdx.x;
@@ -86,9 +86,9 @@ __global__ void sigmas_kernel(const value_t *restrict distances,
    each row in the dataset */
 template <typename value_idx, typename value_t>
 __global__ void sigmas_kernel_2d(const value_t *restrict distances,
-                                 value_t *restrict P, const value_t perplexity,
-                                 const value_t desired_entropy,
-                                 const int epochs, const value_t tol,
+                                 value_t *restrict P, const float perplexity,
+                                 const float desired_entropy,
+                                 const int epochs, const float tol,
                                  const value_idx n) {
@@ -135,10 +135,10 @@
 /****************************************/
 template <typename value_idx, typename value_t>
 void perplexity_search(const value_t *restrict distances, value_t *restrict P,
-                       const value_t perplexity, const int epochs,
-                       const value_t tol, const value_idx n, const int dim,
+                       const float perplexity, const int epochs,
+                       const float tol, const value_idx n, const int dim,
                        const raft::handle_t &handle) {
-  const value_t desired_entropy = logf(perplexity);
+  const float desired_entropy = logf(perplexity);
   auto d_alloc = handle.get_device_allocator();
   cudaStream_t stream = handle.get_stream();
@@ -161,8 +161,8 @@ __global__ void attractive_kernel(
   const value_idx *restrict ROW, const value_t *restrict Y,
   const value_t *restrict norm, value_t *restrict attract,
   const value_idx NNZ, const value_idx n, const value_idx dim,
-  const value_t df_power,  // -(df + 1)/2
-  const value_t recp_df)   // 1 / df
+  const float df_power,  // -(df + 1)/2
+  const float recp_df)   // 1 / df
 {
   const auto index = (blockIdx.x * blockDim.x) + threadIdx.x;
   if (index >= NNZ) return;
@@ -221,8 +221,8 @@ void attractive_forces(const value_t *restrict VAL,
                        const value_t *restrict norm, value_t *restrict attract,
                        const value_idx NNZ, const value_idx n,
                        const value_idx dim,
-                       const value_t df_power,  // -(df + 1)/2
-                       const value_t recp_df,   // 1 / df
+                       const float df_power,  // -(df + 1)/2
+                       const float recp_df,   // 1 / df
                        cudaStream_t stream) {
   CUDA_CHECK(cudaMemsetAsync(attract, 0, sizeof(value_t) * n * dim, stream));
@@ -369,12 +369,12 @@ __global__ void apply_kernel(
   value_t *restrict Y, value_t *restrict velocity,
   const value_t *restrict attract, const value_t *restrict repel,
   value_t *restrict means, value_t *restrict gains,
-  const value_t Z,  // sum(Q)
-  const value_t learning_rate,
-  const value_t C,  // constant from T-Dist Degrees of Freedom
-  const value_t momentum,
+  const float Z,  // sum(Q)
+  const float learning_rate,
+  const float C,  // constant from T-Dist Degrees of Freedom
+  const float momentum,
   const value_idx SIZE,  // SIZE = n*dim
-  const value_idx n, const value_t min_gain, value_t *restrict gradient,
+  const value_idx n, const float min_gain, value_t *restrict gradient,
   const bool check_convergence) {
   const auto index = (blockIdx.x * blockDim.x) + threadIdx.x;
   if (index >= SIZE) return;
@@ -405,11 +405,11 @@ value_t apply_forces(value_t *restrict Y, value_t *restrict velocity,
                      const value_t *restrict attract,
                      const value_t *restrict repel, value_t *restrict means,
                      value_t *restrict gains,
-                     const value_t Z,  // sum(Q)
-                     const value_t learning_rate,
-                     const value_t C,  // constant from T-dist
-                     const value_t momentum, const value_idx dim,
-                     const value_idx n, const value_t min_gain,
+                     const float Z,  // sum(Q)
+                     const float learning_rate,
+                     const float C,  // constant from T-dist
+                     const float momentum, const value_idx dim,
+                     const value_idx n, const float min_gain,
                      value_t *restrict gradient, const bool check_convergence,
                      cudaStream_t stream) {
   //cudaMemset(means, 0, sizeof(float) * dim);
@@ -422,7 +422,7 @@ value_t apply_forces(value_t *restrict Y, value_t *restrict velocity,
   CUDA_CHECK(cudaPeekAtLastError());

   // Find sum of gradient norms
-  value_t gradient_norm = INFINITY;
+  float gradient_norm = INFINITY;
   if (check_convergence) {
     thrust::device_ptr<value_t> begin = thrust::device_pointer_cast(gradient);
     gradient_norm = sqrtf(
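For reference, the per-coordinate rule apply_kernel implements above is the usual t-SNE gradient descent with gains: the gain grows additively when gradient and velocity disagree in sign and shrinks multiplicatively otherwise, floored at min_gain. A scalar sketch (illustrative, not the kernel itself; the 0.2/0.8 increments follow the common reference implementations and are an assumption here):

    #include <algorithm>
    #include <cstdio>

    void update_coord(float &y, float &velocity, float &gain, float grad,
                      float momentum, float learning_rate, float min_gain) {
      // Assumed gain schedule: grow on sign disagreement, decay otherwise.
      gain = (grad * velocity < 0.f) ? gain + 0.2f : gain * 0.8f;
      gain = std::max(gain, min_gain);
      velocity = momentum * velocity - learning_rate * gain * grad;
      y += velocity;  // step against the gradient
    }

    int main() {
      float y = 0.f, v = 0.f, gain = 1.f;
      update_coord(y, v, gain, 0.1f, /*momentum=*/0.5f,
                   /*learning_rate=*/200.f, /*min_gain=*/0.01f);
      printf("y = %f\n", y);  // -16.000000
      return 0;
    }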
diff --git a/cpp/src/tsne/exact_tsne.cuh b/cpp/src/tsne/exact_tsne.cuh
index 81fd4b207e..c095900ba1 100644
--- a/cpp/src/tsne/exact_tsne.cuh
+++ b/cpp/src/tsne/exact_tsne.cuh
@@ -50,12 +50,12 @@ template <typename value_idx, typename value_t>
 void Exact_TSNE(
   value_t *VAL, const value_idx *COL, const value_idx *ROW, const value_idx NNZ,
   const raft::handle_t &handle, value_t *Y, const value_idx n,
-  const value_idx dim, const value_t early_exaggeration = 12.0f,
-  const int exaggeration_iter = 250, const value_t min_gain = 0.01f,
-  const value_t pre_learning_rate = 200.0f,
-  const value_t post_learning_rate = 500.0f, const int max_iter = 1000,
-  const value_t min_grad_norm = 1e-7, const value_t pre_momentum = 0.5,
-  const value_t post_momentum = 0.8, const long long random_state = -1,
+  const value_idx dim, const float early_exaggeration = 12.0f,
+  const int exaggeration_iter = 250, const float min_gain = 0.01f,
+  const float pre_learning_rate = 200.0f,
+  const float post_learning_rate = 500.0f, const int max_iter = 1000,
+  const float min_grad_norm = 1e-7, const float pre_momentum = 0.5,
+  const float post_momentum = 0.8, const long long random_state = -1,
   const bool initialize_embeddings = true) {
   auto d_alloc = handle.get_device_allocator();
   cudaStream_t stream = handle.get_stream();
@@ -86,14 +86,14 @@

   // Calculate degrees of freedom
   //---------------------------------------------------
-  const value_t degrees_of_freedom = fmaxf(dim - 1, 1);
-  const value_t df_power = -(degrees_of_freedom + 1.0f) / 2.0f;
-  const value_t recp_df = 1.0f / degrees_of_freedom;
-  const value_t C = 2.0f * (degrees_of_freedom + 1.0f) / degrees_of_freedom;
+  const float degrees_of_freedom = fmaxf(dim - 1, 1);
+  const float df_power = -(degrees_of_freedom + 1.0f) / 2.0f;
+  const float recp_df = 1.0f / degrees_of_freedom;
+  const float C = 2.0f * (degrees_of_freedom + 1.0f) / degrees_of_freedom;

   CUML_LOG_DEBUG("Start gradient updates!");
-  value_t momentum = pre_momentum;
-  value_t learning_rate = pre_learning_rate;
+  float momentum = pre_momentum;
+  float learning_rate = pre_learning_rate;
   bool check_convergence = false;

   for (int iter = 0; iter < max_iter; iter++) {
@@ -102,7 +102,7 @@
     if (iter == exaggeration_iter) {
       momentum = post_momentum;
       // Divide perplexities
-      const value_t div = 1.0f / early_exaggeration;
+      const float div = 1.0f / early_exaggeration;
       raft::linalg::scalarMultiply(VAL, VAL, div, NNZ, stream);
       learning_rate = post_learning_rate;
     }
@@ -120,7 +120,7 @@
       df_power, recp_df, stream);

     // Apply / integrate forces
-    const value_t gradient_norm = TSNE::apply_forces(
+    const float gradient_norm = TSNE::apply_forces(
       Y, velocity.data(), attract.data(), repel.data(), means.data(),
       gains.data(), Z, learning_rate, C, momentum, dim, n, min_gain,
       gradient.data(), check_convergence, stream);
diff --git a/cpp/src/tsne/tsne.cu b/cpp/src/tsne/tsne.cu
index 488ab5065b..3acff8ca9c 100644
--- a/cpp/src/tsne/tsne.cu
+++ b/cpp/src/tsne/tsne.cu
@@ -22,13 +22,13 @@ namespace ML {
 template <typename tsne_input, typename value_idx, typename value_t>
 void _fit(const raft::handle_t &handle, tsne_input &input,
           knn_graph<value_idx, value_t> &k_graph, const value_idx dim,
-          const value_t theta, const value_t epssq, value_t perplexity,
-          const int perplexity_max_iter, const value_t perplexity_tol,
-          const value_t early_exaggeration, const int exaggeration_iter,
-          const value_t min_gain, const value_t pre_learning_rate,
-          const value_t post_learning_rate, const int max_iter,
-          const value_t min_grad_norm, const value_t pre_momentum,
-          const value_t post_momentum, const long long random_state,
+          const float theta, const float epssq, float perplexity,
+          const int perplexity_max_iter, const float perplexity_tol,
+          const float early_exaggeration, const int exaggeration_iter,
+          const float min_gain, const float pre_learning_rate,
+          const float post_learning_rate, const int max_iter,
+          const float min_grad_norm, const float pre_momentum,
+          const float post_momentum, const long long random_state,
           int verbosity, const bool initialize_embeddings, bool barnes_hut) {
   TSNE_runner<tsne_input, value_idx, value_t> runner(
     handle, input, k_graph, dim, theta, epssq, perplexity, perplexity_max_iter,
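The exact_tsne.cuh hunk above keeps the exaggeration handoff intact: at iter == exaggeration_iter the affinities (pre-multiplied by early_exaggeration) are scaled back by 1/early_exaggeration, and momentum and learning rate switch to their post-exaggeration values. A sketch of just that control flow (illustrative, forces elided):

    #include <cstdio>

    int main() {
      const int max_iter = 1000, exaggeration_iter = 250;
      const float early_exaggeration = 12.0f;
      float momentum = 0.5f, learning_rate = 200.0f;  // pre-phase values
      float p_scale = early_exaggeration;             // P was pre-multiplied

      for (int iter = 0; iter < max_iter; iter++) {
        if (iter == exaggeration_iter) {
          momentum = 0.8f;                       // post_momentum
          learning_rate = 500.0f;                // post_learning_rate
          p_scale *= 1.0f / early_exaggeration;  // undo exaggeration on P
        }
        // ... compute attractive/repulsive forces and apply updates here ...
      }
      printf("momentum=%.1f lr=%.0f p_scale=%.0f\n", momentum, learning_rate,
             p_scale);  // 0.8 500 1
      return 0;
    }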
diff --git a/cpp/src/tsne/tsne_runner.cuh b/cpp/src/tsne/tsne_runner.cuh
index 63eecd8db9..ddcb66e5fb 100644
--- a/cpp/src/tsne/tsne_runner.cuh
+++ b/cpp/src/tsne/tsne_runner.cuh
@@ -184,20 +184,20 @@
   knn_graph<value_idx, value_t> &k_graph;
   const value_idx dim;
   int n_neighbors;
-  const value_t theta;
-  const value_t epssq;
-  value_t perplexity;
+  const float theta;
+  const float epssq;
+  float perplexity;
   const int perplexity_max_iter;
-  const value_t perplexity_tol;
-  const value_t early_exaggeration;
+  const float perplexity_tol;
+  const float early_exaggeration;
   const int exaggeration_iter;
-  const value_t min_gain;
-  const value_t pre_learning_rate;
-  const value_t post_learning_rate;
+  const float min_gain;
+  const float pre_learning_rate;
+  const float post_learning_rate;
   const int max_iter;
-  const value_t min_grad_norm;
-  const value_t pre_momentum;
-  const value_t post_momentum;
+  const float min_grad_norm;
+  const float pre_momentum;
+  const float post_momentum;
   const long long random_state;
   int verbosity;
   const bool initialize_embeddings;

From 76228b68b655a16e03669a5e86dda2c458d97477 Mon Sep 17 00:00:00 2001
From: divyegala
Date: Mon, 21 Dec 2020 19:40:40 -0600
Subject: [PATCH 21/22] removing constants from kernel launch bounds

---
 cpp/src/tsne/bh_kernels.cuh | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)
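Context for this change: the optional second argument of __launch_bounds__ (minBlocksPerMultiprocessor) asks the compiler to guarantee that many resident blocks per SM, which can force register usage down and cause spills; with one argument, only the maximum block size is constrained. Side-by-side sketch (illustrative kernels, not from the patch):

    #include <cstdio>

    constexpr int kMaxThreads = 512;

    // Caps threads per block; register allocation is otherwise unconstrained.
    __global__ __launch_bounds__(kMaxThreads) void capped(float *out) {
      out[threadIdx.x] = static_cast<float>(threadIdx.x);
    }

    // Additionally demands >= 2 resident blocks per SM.
    __global__ __launch_bounds__(kMaxThreads, 2) void pinned(float *out) {
      out[threadIdx.x] = static_cast<float>(threadIdx.x);
    }

    int main() {
      float *d;
      cudaMalloc(&d, kMaxThreads * sizeof(float));
      capped<<<1, kMaxThreads>>>(d);
      pinned<<<1, kMaxThreads>>>(d);
      cudaDeviceSynchronize();
      cudaFree(d);
      printf("launched both variants\n");
      return 0;
    }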
diff --git a/cpp/src/tsne/bh_kernels.cuh b/cpp/src/tsne/bh_kernels.cuh
index 241586ed0b..ffbd5ffc0b 100644
--- a/cpp/src/tsne/bh_kernels.cuh
+++ b/cpp/src/tsne/bh_kernels.cuh
@@ -82,7 +82,7 @@ __global__ void Find_Normalization(value_t *restrict Z_norm,
  * Figures the bounding boxes for every point in the embedding.
  */
 template <typename value_idx, typename value_t>
-__global__ __launch_bounds__(THREADS1, 2) void BoundingBoxKernel(
+__global__ __launch_bounds__(THREADS1) void BoundingBoxKernel(
   value_idx *restrict startd, value_idx *restrict childd,
   value_t *restrict massd, value_t *restrict posxd, value_t *restrict posyd,
   value_t *restrict maxxd, value_t *restrict maxyd, value_t *restrict minxd,
@@ -184,7 +184,7 @@ __global__ __launch_bounds__(1024,
  */
 template <typename value_idx, typename value_t>
 __global__ __launch_bounds__(
-  THREADS2, 2) void TreeBuildingKernel(/* int *restrict errd, */
+  THREADS2) void TreeBuildingKernel(/* int *restrict errd, */
                                        value_idx *restrict childd,
                                        const value_t *restrict posxd,
                                        const value_t *restrict posyd,

From 383eca420a22ecd96bbbee2ba5607d2773fb9986 Mon Sep 17 00:00:00 2001
From: divyegala
Date: Mon, 21 Dec 2020 19:42:10 -0600
Subject: [PATCH 22/22] style check

---
 cpp/src/tsne/barnes_hut.cuh    | 22 ++++++++++++----------
 cpp/src/tsne/bh_kernels.cuh    | 15 +++++++--------
 cpp/src/tsne/exact_kernels.cuh |  8 +++-----
 cpp/src/tsne/exact_tsne.cuh    | 21 +++++++++++----------
 4 files changed, 33 insertions(+), 33 deletions(-)

diff --git a/cpp/src/tsne/barnes_hut.cuh b/cpp/src/tsne/barnes_hut.cuh
index 27eadfb52d..be86affb90 100644
--- a/cpp/src/tsne/barnes_hut.cuh
+++ b/cpp/src/tsne/barnes_hut.cuh
@@ -49,16 +49,18 @@ namespace TSNE {
  * @param[in] initialize_embeddings: Whether to overwrite the current Y vector with random noise.
  */
 template <typename value_idx, typename value_t>
-void Barnes_Hut(
-  value_t *VAL, const value_idx *COL, const value_idx *ROW, const value_idx NNZ,
-  const raft::handle_t &handle, value_t *Y, const value_idx n,
-  const float theta = 0.5f, const float epssq = 0.0025,
-  const float early_exaggeration = 12.0f, const int exaggeration_iter = 250,
-  const float min_gain = 0.01f, const float pre_learning_rate = 200.0f,
-  const float post_learning_rate = 500.0f, const int max_iter = 1000,
-  const float min_grad_norm = 1e-7, const float pre_momentum = 0.5,
-  const float post_momentum = 0.8, const long long random_state = -1,
-  const bool initialize_embeddings = true) {
+void Barnes_Hut(value_t *VAL, const value_idx *COL, const value_idx *ROW,
+                const value_idx NNZ, const raft::handle_t &handle, value_t *Y,
+                const value_idx n, const float theta = 0.5f,
+                const float epssq = 0.0025,
+                const float early_exaggeration = 12.0f,
+                const int exaggeration_iter = 250, const float min_gain = 0.01f,
+                const float pre_learning_rate = 200.0f,
+                const float post_learning_rate = 500.0f,
+                const int max_iter = 1000, const float min_grad_norm = 1e-7,
+                const float pre_momentum = 0.5, const float post_momentum = 0.8,
+                const long long random_state = -1,
+                const bool initialize_embeddings = true) {
   auto d_alloc = handle.get_device_allocator();
   cudaStream_t stream = handle.get_stream();

diff --git a/cpp/src/tsne/bh_kernels.cuh b/cpp/src/tsne/bh_kernels.cuh
index ffbd5ffc0b..2ec840c76f 100644
--- a/cpp/src/tsne/bh_kernels.cuh
+++ b/cpp/src/tsne/bh_kernels.cuh
@@ -185,14 +185,13 @@ template <typename value_idx, typename value_t>
 __global__ __launch_bounds__(
   THREADS2) void TreeBuildingKernel(/* int *restrict errd, */
-                                       value_idx *restrict childd,
-                                       const value_t *restrict posxd,
-                                       const value_t *restrict posyd,
-                                       const value_idx NNODES,
-                                       const value_idx N,
-                                       value_idx *restrict maxdepthd,
-                                       value_idx *restrict bottomd,
-                                       const value_t *restrict radiusd) {
+                                    value_idx *restrict childd,
+                                    const value_t *restrict posxd,
+                                    const value_t *restrict posyd,
+                                    const value_idx NNODES, const value_idx N,
+                                    value_idx *restrict maxdepthd,
+                                    value_idx *restrict bottomd,
+                                    const value_t *restrict radiusd) {
   value_idx j, depth;
   value_t x, y, r;
   value_t px, py;

diff --git a/cpp/src/tsne/exact_kernels.cuh b/cpp/src/tsne/exact_kernels.cuh
index c5abd39556..95d18be827 100644
--- a/cpp/src/tsne/exact_kernels.cuh
+++ b/cpp/src/tsne/exact_kernels.cuh
@@ -33,8 +33,7 @@ template <typename value_idx, typename value_t>
 __global__ void sigmas_kernel(const value_t *restrict distances,
                               value_t *restrict P, const float perplexity,
                               const float desired_entropy, const int epochs,
-                              const float tol, const value_idx n,
-                              const int k) {
+                              const float tol, const value_idx n, const int k) {
   // For every item in row
   const auto i = (blockIdx.x * blockDim.x) + threadIdx.x;
   if (i >= n) return;
@@ -87,9 +86,8 @@ template <typename value_idx, typename value_t>
 __global__ void sigmas_kernel_2d(const value_t *restrict distances,
                                  value_t *restrict P, const float perplexity,
-                                 const float desired_entropy,
-                                 const int epochs, const float tol,
-                                 const value_idx n) {
+                                 const float desired_entropy, const int epochs,
+                                 const float tol, const value_idx n) {
   // For every item in row
   const auto i = (blockIdx.x * blockDim.x) + threadIdx.x;
   if (i >= n) return;
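For reference on two defaults in the reflowed Barnes_Hut signature above: theta is the opening criterion, meaning a tree cell is summarized as a single point mass when (cell_width / distance)^2 < theta^2, which is why the implementation precomputes theta_squared; and epssq keeps the repulsive 1/d^2 term finite when two points coincide. A scalar sketch (illustrative):

    #include <cstdio>

    // Barnes-Hut acceptance test for one cell (illustrative).
    bool can_summarize(float cell_width_sq, float dist_sq, float theta_sq) {
      return cell_width_sq < theta_sq * dist_sq;
    }

    int main() {
      const float theta = 0.5f, theta_sq = theta * theta;
      const float epssq = 0.0025f;
      const float dist_sq = 0.0f;  // coincident points
      printf("summarize far cell: %d\n", can_summarize(1.f, 100.f, theta_sq));
      printf("guarded repulsion denom: %f\n", 1.0f / (dist_sq + epssq));
      return 0;
    }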
diff --git a/cpp/src/tsne/exact_tsne.cuh b/cpp/src/tsne/exact_tsne.cuh
index c095900ba1..105da4b369 100644
--- a/cpp/src/tsne/exact_tsne.cuh
+++ b/cpp/src/tsne/exact_tsne.cuh
@@ -47,16 +47,17 @@ namespace TSNE {
  * @param[in] initialize_embeddings: Whether to overwrite the current Y vector with random noise.
  */
 template <typename value_idx, typename value_t>
-void Exact_TSNE(
-  value_t *VAL, const value_idx *COL, const value_idx *ROW, const value_idx NNZ,
-  const raft::handle_t &handle, value_t *Y, const value_idx n,
-  const value_idx dim, const float early_exaggeration = 12.0f,
-  const int exaggeration_iter = 250, const float min_gain = 0.01f,
-  const float pre_learning_rate = 200.0f,
-  const float post_learning_rate = 500.0f, const int max_iter = 1000,
-  const float min_grad_norm = 1e-7, const float pre_momentum = 0.5,
-  const float post_momentum = 0.8, const long long random_state = -1,
-  const bool initialize_embeddings = true) {
+void Exact_TSNE(value_t *VAL, const value_idx *COL, const value_idx *ROW,
+                const value_idx NNZ, const raft::handle_t &handle, value_t *Y,
+                const value_idx n, const value_idx dim,
+                const float early_exaggeration = 12.0f,
+                const int exaggeration_iter = 250, const float min_gain = 0.01f,
+                const float pre_learning_rate = 200.0f,
+                const float post_learning_rate = 500.0f,
+                const int max_iter = 1000, const float min_grad_norm = 1e-7,
+                const float pre_momentum = 0.5, const float post_momentum = 0.8,
+                const long long random_state = -1,
+                const bool initialize_embeddings = true) {
   auto d_alloc = handle.get_device_allocator();
   cudaStream_t stream = handle.get_stream();
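For reference, the degrees-of-freedom constants Exact_TSNE derives (see the patch-20 hunk earlier) reduce nicely for the default dim = 2: degrees_of_freedom = max(2 - 1, 1) = 1, df_power = -(1 + 1)/2 = -1, recp_df = 1, and C = 2 * (1 + 1)/1 = 4, so the Student-t similarity becomes the familiar 1/(1 + d^2). A quick check (illustrative):

    #include <cmath>
    #include <cstdio>

    int main() {
      const int dim = 2;
      const float df = fmaxf(dim - 1, 1);          // 1
      const float df_power = -(df + 1.0f) / 2.0f;  // -1
      const float recp_df = 1.0f / df;             // 1
      const float C = 2.0f * (df + 1.0f) / df;     // 4

      const float d_sq = 2.0f;  // squared embedding distance
      const float q = std::pow(1.0f + d_sq * recp_df, df_power);
      printf("q = %f, 1/(1+d^2) = %f, C = %.0f\n", q, 1.0f / (1.0f + d_sq), C);
      return 0;
    }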