diff --git a/cpp/include/cuml/manifold/tsne.h b/cpp/include/cuml/manifold/tsne.h index c8d12f2b8c..b9330e5215 100644 --- a/cpp/include/cuml/manifold/tsne.h +++ b/cpp/include/cuml/manifold/tsne.h @@ -28,6 +28,8 @@ namespace ML { enum TSNE_ALGORITHM { EXACT, BARNES_HUT, FFT }; +enum TSNE_INIT { RANDOM, PCA }; + struct TSNEParams { // Number of output dimensions for embeddings Y. int dim = 2; @@ -94,8 +96,8 @@ struct TSNEParams { // verbosity level for logging messages during execution int verbosity = CUML_LEVEL_INFO; - // Whether to overwrite the current Y vector with random noise. - bool initialize_embeddings = true; + // Embedding initializer algorithm + TSNE_INIT init = TSNE_INIT::RANDOM; // When this is set to true, the distances from the knn graph will // always be squared before computing conditional probabilities, even if diff --git a/cpp/src/tsne/barnes_hut_tsne.cuh b/cpp/src/tsne/barnes_hut_tsne.cuh index ef63473060..23bb233350 100644 --- a/cpp/src/tsne/barnes_hut_tsne.cuh +++ b/cpp/src/tsne/barnes_hut_tsne.cuh @@ -126,7 +126,8 @@ value_t Barnes_Hut(value_t* VAL, RAFT_CUDA_TRY(cudaMemsetAsync(old_forces.data(), 0, sizeof(value_t) * n * 2, stream)); rmm::device_uvector YY((nnodes + 1) * 2, stream); - if (params.initialize_embeddings) { + + if (params.init == TSNE_INIT::RANDOM) { random_vector(YY.data(), -0.0001f, 0.0001f, (nnodes + 1) * 2, stream, params.random_state); } else { raft::copy(YY.data(), Y, n, stream); diff --git a/cpp/src/tsne/distances.cuh b/cpp/src/tsne/distances.cuh index a221d70820..4963f42eee 100644 --- a/cpp/src/tsne/distances.cuh +++ b/cpp/src/tsne/distances.cuh @@ -91,8 +91,8 @@ void get_distances(const raft::handle_t& handle, k_graph.knn_indices, k_graph.knn_dists, k_graph.n_neighbors, - true, - true, + false, + false, nullptr, metric, p); diff --git a/cpp/src/tsne/exact_tsne.cuh b/cpp/src/tsne/exact_tsne.cuh index 2b236574b6..680c3200e8 100644 --- a/cpp/src/tsne/exact_tsne.cuh +++ b/cpp/src/tsne/exact_tsne.cuh @@ -54,9 +54,6 @@ value_t Exact_TSNE(value_t* VAL, value_t kl_div = 0; const value_idx dim = params.dim; - if (params.initialize_embeddings) - random_vector(Y, -0.0001f, 0.0001f, n * dim, stream, params.random_state); - // Allocate space //--------------------------------------------------- CUML_LOG_DEBUG("Now allocating memory for TSNE."); diff --git a/cpp/src/tsne/fft_tsne.cuh b/cpp/src/tsne/fft_tsne.cuh index bd9e9e73b8..cb5dedf932 100644 --- a/cpp/src/tsne/fft_tsne.cuh +++ b/cpp/src/tsne/fft_tsne.cuh @@ -340,10 +340,6 @@ value_t FFT_TSNE(value_t* VAL, value_t learning_rate = params.pre_learning_rate; value_t exaggeration = params.early_exaggeration; - if (params.initialize_embeddings) { - random_vector(Y, 0.0000f, 0.0001f, n * 2, stream, params.random_state); - } - value_t kl_div = 0; for (int iter = 0; iter < params.max_iter; iter++) { // Compute charges Q_ij diff --git a/cpp/src/tsne/tsne_runner.cuh b/cpp/src/tsne/tsne_runner.cuh index 8e9fdb0df5..96df2a12ca 100644 --- a/cpp/src/tsne/tsne_runner.cuh +++ b/cpp/src/tsne/tsne_runner.cuh @@ -27,14 +27,25 @@ #include #include +#include +#include +#include #include #include #include +#include + namespace ML { +template class U> +inline constexpr bool is_instance_of = std::false_type{}; + +template