Skip to content

Commit

Permalink
[FEA] PCA Initialization for TSNE (rapidsai#5897)
Browse files Browse the repository at this point in the history
Closes rapidsai#3458

Add PCA embedding initialization to the C++ layer and expose it in the Python API.
```python

from cuml.manifold import TSNE

tsne = TSNE(
    ...
    init="pca" # ("random" or "pca")
)
```

Authors:
  - Anupam (https://github.com/aamijar)
  - Corey J. Nolet (https://github.com/cjnolet)
  - Dante Gama Dessavre (https://github.com/dantegd)
  - Micka (https://github.com/lowener)

Approvers:
  - Micka (https://github.com/lowener)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai#5897
  • Loading branch information
aamijar authored Jul 10, 2024
1 parent 5679372 commit 1e7de60
Show file tree
Hide file tree
Showing 9 changed files with 139 additions and 29 deletions.
6 changes: 4 additions & 2 deletions cpp/include/cuml/manifold/tsne.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ namespace ML {

// TSNE solver variants; the corresponding implementations appear to live in
// exact_tsne.cuh, barnes_hut_tsne.cuh and fft_tsne.cuh respectively.
enum TSNE_ALGORITHM { EXACT, BARNES_HUT, FFT };

// Strategy used to fill the initial embedding Y before optimization starts:
//   RANDOM - Y is filled by random_vector() called with -0.0001f and 0.0001f.
//   PCA    - Y is produced by pcaFitTransform() with n_components = dim, then
//            divided by the first entry of its computed standard deviations
//            and multiplied by 1e-4.
// The Cython declaration mirrors these values as RANDOM = 0 and PCA = 1, so the
// enumerator order must stay in sync with it.
enum TSNE_INIT { RANDOM, PCA };

struct TSNEParams {
// Number of output dimensions for embeddings Y.
int dim = 2;
Expand Down Expand Up @@ -94,8 +96,8 @@ struct TSNEParams {
// verbosity level for logging messages during execution
int verbosity = CUML_LEVEL_INFO;

// Whether to overwrite the current Y vector with random noise.
bool initialize_embeddings = true;
// Embedding initializer algorithm
TSNE_INIT init = TSNE_INIT::RANDOM;

// When this is set to true, the distances from the knn graph will
// always be squared before computing conditional probabilities, even if
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/tsne/barnes_hut_tsne.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ value_t Barnes_Hut(value_t* VAL,
RAFT_CUDA_TRY(cudaMemsetAsync(old_forces.data(), 0, sizeof(value_t) * n * 2, stream));

rmm::device_uvector<value_t> YY((nnodes + 1) * 2, stream);
if (params.initialize_embeddings) {

if (params.init == TSNE_INIT::RANDOM) {
random_vector(YY.data(), -0.0001f, 0.0001f, (nnodes + 1) * 2, stream, params.random_state);
} else {
raft::copy(YY.data(), Y, n, stream);
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/tsne/distances.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ void get_distances(const raft::handle_t& handle,
k_graph.knn_indices,
k_graph.knn_dists,
k_graph.n_neighbors,
true,
true,
false,
false,
nullptr,
metric,
p);
Expand Down
3 changes: 0 additions & 3 deletions cpp/src/tsne/exact_tsne.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,6 @@ value_t Exact_TSNE(value_t* VAL,
value_t kl_div = 0;
const value_idx dim = params.dim;

if (params.initialize_embeddings)
random_vector(Y, -0.0001f, 0.0001f, n * dim, stream, params.random_state);

// Allocate space
//---------------------------------------------------
CUML_LOG_DEBUG("Now allocating memory for TSNE.");
Expand Down
4 changes: 0 additions & 4 deletions cpp/src/tsne/fft_tsne.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -340,10 +340,6 @@ value_t FFT_TSNE(value_t* VAL,
value_t learning_rate = params.pre_learning_rate;
value_t exaggeration = params.early_exaggeration;

if (params.initialize_embeddings) {
random_vector(Y, 0.0000f, 0.0001f, n * 2, stream, params.random_state);
}

value_t kl_div = 0;
for (int iter = 0; iter < params.max_iter; iter++) {
// Compute charges Q_ij
Expand Down
83 changes: 83 additions & 0 deletions cpp/src/tsne/tsne_runner.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,25 @@

#include <raft/core/handle.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/linalg/divide.cuh>
#include <raft/linalg/multiply.cuh>
#include <raft/linalg/unary_op.cuh>
#include <raft/util/cudart_utils.hpp>

#include <rmm/device_uvector.hpp>

#include <thrust/transform.h>

#include <pca/pca.cuh>

namespace ML {

/// Compile-time predicate: `is_instance_of<T, U>` is `true` exactly when `T`
/// is a specialization `U<V>` of the single-type-parameter class template `U`.
///
/// Primary template: the general case, where `T` is not of the form `U<V>`.
template <class T, template <class> class U>
inline constexpr bool is_instance_of = false;

/// Partial specialization selected when the first argument is `U<V>` for the
/// same template `U` passed as the second argument.
template <template <class> class U, class V>
inline constexpr bool is_instance_of<U<V>, U> = true;

template <typename tsne_input, typename value_idx, typename value_t>
class TSNE_runner {
public:
Expand Down Expand Up @@ -78,6 +89,78 @@ class TSNE_runner {
CUML_LOG_WARN(
"# of Nearest Neighbors should be at least 3 * perplexity. Your results"
" might be a bit strange...");

auto stream = handle.get_stream();
const value_idx dim = params.dim;

if (params.init == TSNE_INIT::RANDOM) {
random_vector(Y, -0.0001f, 0.0001f, n * dim, stream, params.random_state);
} else if (params.init == TSNE_INIT::PCA) {
auto components = raft::make_device_matrix<float>(handle, p, dim);
auto explained_var = raft::make_device_vector<float>(handle, dim);
auto explained_var_ratio = raft::make_device_vector<float>(handle, dim);
auto singular_vals = raft::make_device_vector<float>(handle, dim);
auto mu = raft::make_device_vector<float>(handle, p);
auto noise_vars = raft::make_device_scalar<float>(handle, 0);

paramsPCA prms;
prms.n_cols = p;
prms.n_rows = n;
prms.n_components = dim;
prms.whiten = false;
prms.algorithm = solver::COV_EIG_DQ;

if constexpr (!is_instance_of<tsne_input, manifold_dense_inputs_t>) {
throw std::runtime_error("The tsne_input must be of type manifold_dense_inputs_t");
} else {
pcaFitTransform(handle,
input.X,
Y,
components.data_handle(),
explained_var.data_handle(),
explained_var_ratio.data_handle(),
singular_vals.data_handle(),
mu.data_handle(),
noise_vars.data_handle(),
prms,
stream);

auto mean_result = raft::make_device_vector<float, int>(handle, dim);
auto stddev_result = raft::make_device_vector<float, int>(handle, dim);
const float multiplier = 1e-4;

auto Y_view = raft::make_device_matrix_view<float, int, raft::col_major>(Y, n, dim);
auto Y_view_const =
raft::make_device_matrix_view<const float, int, raft::col_major>(Y, n, dim);

auto mean_result_view = mean_result.view();
auto mean_result_view_const = raft::make_const_mdspan(mean_result.view());

auto stddev_result_view = stddev_result.view();

auto h_multiplier_view_const = raft::make_host_scalar_view<const float>(&multiplier);

raft::stats::mean(handle, Y_view_const, mean_result_view, false);
raft::stats::stddev(
handle, Y_view_const, mean_result_view_const, stddev_result_view, false);

divide_scalar_device(Y_view, Y_view_const, stddev_result_view);
raft::linalg::multiply_scalar(handle, Y_view_const, Y_view, h_multiplier_view_const);
}
}
}

/**
 * @brief Divide every element of the embedding by a scalar that lives in
 *        device memory.
 *
 * The divisor is the first element of `stddev_result_view`; it is read on the
 * device inside the lambda, so no device-to-host copy is issued here.
 *
 * @param Y_view             Writable view that receives the divided values.
 * @param Y_view_const       Read-only view of the values to divide. The caller
 *                           in this class passes a view over the same buffer
 *                           as `Y_view`, making the operation in-place.
 * @param stddev_result_view Device vector whose first entry is the divisor.
 *
 * NOTE(review): presumably kept as a named member function because CUDA
 * extended lambdas may not be defined inside a constructor - confirm.
 */
void divide_scalar_device(
  raft::device_matrix_view<float, int, raft::col_major>& Y_view,
  raft::device_matrix_view<const float, int, raft::col_major>& Y_view_const,
  raft::device_vector_view<float, int>& stddev_result_view)
{
  // Only the raw pointer is captured (by value); it is dereferenced on the device.
  const float* divisor_ptr = stddev_result_view.data_handle();

  auto divide_by_device_scalar = [divisor_ptr] __device__(auto elem) {
    return elem / *divisor_ptr;
  };

  raft::linalg::unary_op(handle, Y_view_const, Y_view, divide_by_device_scalar);
}

value_t run()
Expand Down
31 changes: 27 additions & 4 deletions cpp/test/sg/tsne_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <raft/core/handle.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/linalg/map.cuh>
#include <raft/linalg/transpose.cuh>
#include <raft/util/cudart_utils.hpp>

#include <thrust/reduce.h>
Expand All @@ -47,6 +48,7 @@ using namespace ML::Metrics;
struct TSNEInput {
int n, p;
std::vector<float> dataset;
TSNE_INIT init;
double trustworthiness_threshold;
};

Expand Down Expand Up @@ -128,6 +130,11 @@ class TSNETest : public ::testing::TestWithParam<TSNEInput> {
// Allocate memory
rmm::device_uvector<float> X_d(n * p, stream);
raft::update_device(X_d.data(), dataset.data(), n * p, stream);

rmm::device_uvector<float> Xtranspose(n * p, stream);
raft::copy_async(Xtranspose.data(), X_d.data(), n * p, stream);
raft::linalg::transpose(handle, Xtranspose.data(), X_d.data(), p, n, stream);

rmm::device_uvector<float> Y_d(n * model_params.dim, stream);
rmm::device_uvector<int64_t> input_indices(0, stream);
rmm::device_uvector<float> input_dists(0, stream);
Expand Down Expand Up @@ -183,6 +190,9 @@ class TSNETest : public ::testing::TestWithParam<TSNEInput> {
handle.sync_stream(stream);
free(embeddings_h);

raft::copy_async(Xtranspose.data(), X_d.data(), n * p, stream);
raft::linalg::transpose(handle, Xtranspose.data(), X_d.data(), n, p, stream);

// Produce trustworthiness score
results.trustworthiness =
trustworthiness_score<float, raft::distance::DistanceType::L2SqrtUnexpanded>(
Expand Down Expand Up @@ -215,6 +225,7 @@ class TSNETest : public ::testing::TestWithParam<TSNEInput> {
p = params.p;
dataset = params.dataset;
trustworthiness_threshold = params.trustworthiness_threshold;
model_params.init = params.init;
basicTest();
}

Expand Down Expand Up @@ -242,10 +253,22 @@ class TSNETest : public ::testing::TestWithParam<TSNEInput> {
};

const std::vector<TSNEInput> inputs = {
{Digits::n_samples, Digits::n_features, Digits::digits, 0.98},
{Boston::n_samples, Boston::n_features, Boston::boston, 0.98},
{BreastCancer::n_samples, BreastCancer::n_features, BreastCancer::breast_cancer, 0.98},
{Diabetes::n_samples, Diabetes::n_features, Diabetes::diabetes, 0.90}};
{Digits::n_samples, Digits::n_features, Digits::digits, TSNE_INIT::RANDOM, 0.98},
{Boston::n_samples, Boston::n_features, Boston::boston, TSNE_INIT::RANDOM, 0.98},
{BreastCancer::n_samples,
BreastCancer::n_features,
BreastCancer::breast_cancer,
TSNE_INIT::RANDOM,
0.98},
{Diabetes::n_samples, Diabetes::n_features, Diabetes::diabetes, TSNE_INIT::RANDOM, 0.90},
{Digits::n_samples, Digits::n_features, Digits::digits, TSNE_INIT::PCA, 0.98},
{Boston::n_samples, Boston::n_features, Boston::boston, TSNE_INIT::PCA, 0.98},
{BreastCancer::n_samples,
BreastCancer::n_features,
BreastCancer::breast_cancer,
TSNE_INIT::PCA,
0.98},
{Diabetes::n_samples, Diabetes::n_features, Diabetes::diabetes, TSNE_INIT::PCA, 0.90}};

typedef TSNETest TSNETestF;
TEST_P(TSNETestF, Result)
Expand Down
28 changes: 17 additions & 11 deletions python/cuml/manifold/t_sne.pyx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2023, NVIDIA CORPORATION.
# Copyright (c) 2019-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -56,6 +56,10 @@ cdef extern from "cuml/manifold/tsne.h" namespace "ML":
BARNES_HUT = 1,
FFT = 2

enum TSNE_INIT:
RANDOM = 0,
PCA = 1

cdef cppclass TSNEParams:
int dim,
int n_neighbors,
Expand All @@ -76,7 +80,7 @@ cdef extern from "cuml/manifold/tsne.h" namespace "ML":
float post_momentum,
long long random_state,
int verbosity,
bool initialize_embeddings,
TSNE_INIT init,
bool square_distances,
DistanceType metric,
float p,
Expand Down Expand Up @@ -156,8 +160,8 @@ class TSNE(Base,
Distance metric to use. Supported distances are ['l1, 'cityblock',
'manhattan', 'euclidean', 'l2', 'sqeuclidean', 'minkowski',
'chebyshev', 'cosine', 'correlation']
init : str 'random' (default 'random')
Currently supports random initialization.
init : str 'random' or 'pca' (default 'random')
Currently supports random or pca initialization.
verbose : int or boolean, default=False
Sets logging level. It must be one of `cuml.common.logger.level_*`.
See :ref:`verbosity-levels` for more info.
Expand Down Expand Up @@ -317,11 +321,9 @@ class TSNE(Base,
if n_iter <= 100:
warnings.warn("n_iter = {} might cause TSNE to output wrong "
"results. Set it higher.".format(n_iter))
if init.lower() != 'random':
# TODO https://github.com/rapidsai/cuml/issues/3458
warnings.warn("TSNE does not support {} but only random "
"initialization.".format(init))
init = 'random'
if init.lower() != 'random' and init.lower() != 'pca':
raise ValueError("TSNE does not support {} but only random and pca "
"initialization.".format(init))
if angle < 0 or angle > 1:
raise ValueError("angle = {} should be ≥ 0 and ≤ 1".format(angle))
if n_neighbors < 0:
Expand Down Expand Up @@ -437,7 +439,7 @@ class TSNE(Base,
# Handle dense inputs
else:
self.X_m, n, p, _ = \
input_to_cuml_array(X, order='C', check_dtype=np.float32,
input_to_cuml_array(X, order='F', check_dtype=np.float32,
convert_to_dtype=(np.float32
if convert_dtype
else None))
Expand Down Expand Up @@ -599,10 +601,14 @@ class TSNE(Base,
params.post_momentum = <float> self.post_momentum
params.random_state = <long long> seed
params.verbosity = <int> self.verbose
params.initialize_embeddings = <bool> True
params.square_distances = <bool> self.square_distances
params.algorithm = algo

if self.init.lower() == 'random':
params.init = TSNE_INIT.RANDOM
elif self.init.lower() == 'pca':
params.init = TSNE_INIT.PCA

# metric
metric_parsing = {
"l2": DistanceType.L2SqrtExpanded,
Expand Down
6 changes: 4 additions & 2 deletions python/cuml/tests/test_tsne.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2023, NVIDIA CORPORATION.
# Copyright (c) 2019-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -184,8 +184,9 @@ def test_tsne_precomputed_knn(precomputed_type, sparse_input):
assert trust >= 0.92


@pytest.mark.parametrize("init", ["random", "pca"])
@pytest.mark.parametrize("method", ["fft", "barnes_hut"])
def test_tsne(test_datasets, method):
def test_tsne(test_datasets, method, init):
"""
This tests how TSNE handles a lot of input data across time.
(1) Numpy arrays are passed in
Expand All @@ -205,6 +206,7 @@ def test_tsne(test_datasets, method):
method=method,
min_grad_norm=1e-12,
perplexity=DEFAULT_PERPLEXITY,
init=init,
)

Y = tsne.fit_transform(X)
Expand Down

0 comments on commit 1e7de60

Please sign in to comment.