Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sparse TSNE #3293

Merged
merged 25 commits into from
Jan 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 74 additions & 5 deletions cpp/include/cuml/manifold/tsne.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ namespace ML {
* approach is available in their article t-SNE-CUDA: GPU-Accelerated t-SNE and
* its Applications to Modern Data (https://arxiv.org/abs/1807.11824).
*/
void TSNE_fit(const raft::handle_t &handle, const float *X, float *Y,
const int n, const int p, int64_t *knn_indices, float *knn_dists,
const int dim = 2, int n_neighbors = 1023,
const float theta = 0.5f, const float epssq = 0.0025,
float perplexity = 50.0f, const int perplexity_max_iter = 100,
void TSNE_fit(const raft::handle_t &handle, float *X, float *Y, int n, int p,
int64_t *knn_indices, float *knn_dists, const int dim = 2,
int n_neighbors = 1023, const float theta = 0.5f,
const float epssq = 0.0025, float perplexity = 50.0f,
const int perplexity_max_iter = 100,
const float perplexity_tol = 1e-5,
const float early_exaggeration = 12.0f,
const int exaggeration_iter = 250, const float min_gain = 0.01f,
Expand All @@ -89,4 +89,73 @@ void TSNE_fit(const raft::handle_t &handle, const float *X, float *Y,
int verbosity = CUML_LEVEL_INFO,
const bool initialize_embeddings = true, bool barnes_hut = true);

/**
* @brief Dimensionality reduction via TSNE using either Barnes Hut O(NlogN)
* or brute force O(N^2).
*
* @param[in] handle The GPU handle.
* @param[in] indptr indptr of CSR dataset.
* @param[in] indices indices of CSR dataset.
* @param[in] data data of CSR dataset.
* @param[out] Y The final embedding.
* @param[in] nnz The number of non-zero entries in the CSR.
* @param[in] n Number of rows in data X.
* @param[in] p Number of columns in data X.
 * @param[in] knn_indices Array containing nearest neighbor indices.
 * @param[in] knn_dists Array containing nearest neighbor distances.
* @param[in] dim Number of output dimensions for embeddings Y.
* @param[in] n_neighbors Number of nearest neighbors used.
* @param[in] theta Float between 0 and 1. Tradeoff for speed (0)
* vs accuracy (1) for Barnes Hut only.
* @param[in] epssq A tiny jitter to promote numerical stability.
* @param[in] perplexity How many nearest neighbors are used during
* construction of Pij.
* @param[in] perplexity_max_iter Number of iterations used to construct Pij.
* @param[in] perplexity_tol The small tolerance used for Pij to ensure
* numerical stability.
* @param[in] early_exaggeration How much early pressure you want the clusters
* in TSNE to spread out more.
* @param[in] exaggeration_iter How many iterations you want the early
* pressure to run for.
* @param[in] min_gain Rounds up small gradient updates.
* @param[in] pre_learning_rate The learning rate during exaggeration phase.
* @param[in] post_learning_rate The learning rate after exaggeration phase.
* @param[in] max_iter The maximum number of iterations TSNE should
* run for.
* @param[in] min_grad_norm The smallest gradient norm TSNE should
* terminate on.
* @param[in] pre_momentum The momentum used during the exaggeration
* phase.
* @param[in] post_momentum The momentum used after the exaggeration
* phase.
 * @param[in] random_state Set this to -1 for pure random initializations
 * or >= 0 for reproducible outputs.
* @param[in] verbosity verbosity level for logging messages during
* execution
* @param[in] initialize_embeddings Whether to overwrite the current Y vector
* with random noise.
* @param[in] barnes_hut Whether to use the fast Barnes Hut or use the
* slower exact version.
*
* The CUDA implementation is derived from the excellent CannyLabs open source
* implementation here: https://github.com/CannyLab/tsne-cuda/. The CannyLabs
* code is licensed according to the conditions in
* cuml/cpp/src/tsne/cannylabs_tsne_license.txt. A full description of their
* approach is available in their article t-SNE-CUDA: GPU-Accelerated t-SNE and
* its Applications to Modern Data (https://arxiv.org/abs/1807.11824).
*/
void TSNE_fit_sparse(
const raft::handle_t &handle, int *indptr, int *indices, float *data,
float *Y, int nnz, int n, int p, int *knn_indices, float *knn_dists,
const int dim = 2, int n_neighbors = 1023, const float theta = 0.5f,
const float epssq = 0.0025, float perplexity = 50.0f,
const int perplexity_max_iter = 100, const float perplexity_tol = 1e-5,
const float early_exaggeration = 12.0f, const int exaggeration_iter = 250,
const float min_gain = 0.01f, const float pre_learning_rate = 200.0f,
const float post_learning_rate = 500.0f, const int max_iter = 1000,
const float min_grad_norm = 1e-7, const float pre_momentum = 0.5,
const float post_momentum = 0.8, const long long random_state = -1,
int verbosity = CUML_LEVEL_INFO, const bool initialize_embeddings = true,
bool barnes_hut = true);

} // namespace ML
99 changes: 53 additions & 46 deletions cpp/src/tsne/barnes_hut.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,11 @@ namespace TSNE {
 * @param[in] random_state: Set this to -1 for pure random initializations or >= 0 for reproducible outputs.
* @param[in] initialize_embeddings: Whether to overwrite the current Y vector with random noise.
*/
void Barnes_Hut(float *VAL, const int *COL, const int *ROW, const int NNZ,
const raft::handle_t &handle, float *Y, const int n,
const float theta = 0.5f, const float epssq = 0.0025,
template <typename value_idx, typename value_t>
void Barnes_Hut(value_t *VAL, const value_idx *COL, const value_idx *ROW,
const value_idx NNZ, const raft::handle_t &handle, value_t *Y,
const value_idx n, const float theta = 0.5f,
const float epssq = 0.0025,
const float early_exaggeration = 12.0f,
const int exaggeration_iter = 250, const float min_gain = 0.01f,
const float pre_learning_rate = 200.0f,
Expand All @@ -66,7 +68,7 @@ void Barnes_Hut(float *VAL, const int *COL, const int *ROW, const int NNZ,
//---------------------------------------------------
const int blocks = raft::getMultiProcessorCount();

int nnodes = n * 2;
auto nnodes = n * 2;
if (nnodes < 1024 * blocks) nnodes = 1024 * blocks;
while ((nnodes & (32 - 1)) != 0) nnodes++;
nnodes--;
Expand All @@ -75,64 +77,65 @@ void Barnes_Hut(float *VAL, const int *COL, const int *ROW, const int NNZ,
// Allocate more space
// MLCommon::device_buffer<unsigned> errl(d_alloc, stream, 1);
MLCommon::device_buffer<unsigned> limiter(d_alloc, stream, 1);
MLCommon::device_buffer<int> maxdepthd(d_alloc, stream, 1);
MLCommon::device_buffer<int> bottomd(d_alloc, stream, 1);
MLCommon::device_buffer<float> radiusd(d_alloc, stream, 1);
MLCommon::device_buffer<value_idx> maxdepthd(d_alloc, stream, 1);
MLCommon::device_buffer<value_idx> bottomd(d_alloc, stream, 1);
MLCommon::device_buffer<value_t> radiusd(d_alloc, stream, 1);

TSNE::InitializationKernel<<<1, 1, 0, stream>>>(/*errl.data(),*/
limiter.data(),
maxdepthd.data(),
radiusd.data());
CUDA_CHECK(cudaPeekAtLastError());

const int FOUR_NNODES = 4 * nnodes;
const int FOUR_N = 4 * n;
const value_idx FOUR_NNODES = 4 * nnodes;
const value_idx FOUR_N = 4 * n;
const float theta_squared = theta * theta;
const int NNODES = nnodes;
const value_idx NNODES = nnodes;

// Actual allocations
MLCommon::device_buffer<int> startl(d_alloc, stream, nnodes + 1);
MLCommon::device_buffer<int> childl(d_alloc, stream, (nnodes + 1) * 4);
MLCommon::device_buffer<float> massl(d_alloc, stream, nnodes + 1);
MLCommon::device_buffer<value_idx> startl(d_alloc, stream, nnodes + 1);
MLCommon::device_buffer<value_idx> childl(d_alloc, stream, (nnodes + 1) * 4);
MLCommon::device_buffer<value_t> massl(d_alloc, stream, nnodes + 1);

thrust::device_ptr<float> begin_massl =
thrust::device_ptr<value_t> begin_massl =
thrust::device_pointer_cast(massl.data());
thrust::fill(thrust::cuda::par.on(stream), begin_massl,
begin_massl + (nnodes + 1), 1.0f);

MLCommon::device_buffer<float> maxxl(d_alloc, stream, blocks * FACTOR1);
MLCommon::device_buffer<float> maxyl(d_alloc, stream, blocks * FACTOR1);
MLCommon::device_buffer<float> minxl(d_alloc, stream, blocks * FACTOR1);
MLCommon::device_buffer<float> minyl(d_alloc, stream, blocks * FACTOR1);
MLCommon::device_buffer<value_t> maxxl(d_alloc, stream, blocks * FACTOR1);
MLCommon::device_buffer<value_t> maxyl(d_alloc, stream, blocks * FACTOR1);
MLCommon::device_buffer<value_t> minxl(d_alloc, stream, blocks * FACTOR1);
MLCommon::device_buffer<value_t> minyl(d_alloc, stream, blocks * FACTOR1);

// SummarizationKernel
MLCommon::device_buffer<int> countl(d_alloc, stream, nnodes + 1);
MLCommon::device_buffer<value_idx> countl(d_alloc, stream, nnodes + 1);

// SortKernel
MLCommon::device_buffer<int> sortl(d_alloc, stream, nnodes + 1);
MLCommon::device_buffer<value_idx> sortl(d_alloc, stream, nnodes + 1);

// RepulsionKernel
MLCommon::device_buffer<float> rep_forces(d_alloc, stream, (nnodes + 1) * 2);
MLCommon::device_buffer<float> attr_forces(
MLCommon::device_buffer<value_t> rep_forces(d_alloc, stream,
(nnodes + 1) * 2);
MLCommon::device_buffer<value_t> attr_forces(
d_alloc, stream, n * 2); // n*2 double for reduction sum

MLCommon::device_buffer<float> Z_norm(d_alloc, stream, 1);
MLCommon::device_buffer<value_t> Z_norm(d_alloc, stream, 1);

MLCommon::device_buffer<float> radiusd_squared(d_alloc, stream, 1);
MLCommon::device_buffer<value_t> radiusd_squared(d_alloc, stream, 1);

// Apply
MLCommon::device_buffer<float> gains_bh(d_alloc, stream, n * 2);
MLCommon::device_buffer<value_t> gains_bh(d_alloc, stream, n * 2);

thrust::device_ptr<float> begin_gains_bh =
thrust::device_ptr<value_t> begin_gains_bh =
thrust::device_pointer_cast(gains_bh.data());
thrust::fill(thrust::cuda::par.on(stream), begin_gains_bh,
begin_gains_bh + (n * 2), 1.0f);

MLCommon::device_buffer<float> old_forces(d_alloc, stream, n * 2);
MLCommon::device_buffer<value_t> old_forces(d_alloc, stream, n * 2);
CUDA_CHECK(
cudaMemsetAsync(old_forces.data(), 0, sizeof(float) * n * 2, stream));
cudaMemsetAsync(old_forces.data(), 0, sizeof(value_t) * n * 2, stream));

MLCommon::device_buffer<float> YY(d_alloc, stream, (nnodes + 1) * 2);
MLCommon::device_buffer<value_t> YY(d_alloc, stream, (nnodes + 1) * 2);
if (initialize_embeddings) {
random_vector(YY.data(), -0.0001f, 0.0001f, (nnodes + 1) * 2, stream,
random_state);
Expand All @@ -143,27 +146,30 @@ void Barnes_Hut(float *VAL, const int *COL, const int *ROW, const int NNZ,

// Set cache levels for faster algorithm execution
//---------------------------------------------------
CUDA_CHECK(
cudaFuncSetCacheConfig(TSNE::BoundingBoxKernel, cudaFuncCachePreferShared));
CUDA_CHECK(
cudaFuncSetCacheConfig(TSNE::TreeBuildingKernel, cudaFuncCachePreferL1));
CUDA_CHECK(cudaFuncSetCacheConfig(TSNE::ClearKernel1, cudaFuncCachePreferL1));
CUDA_CHECK(cudaFuncSetCacheConfig(TSNE::ClearKernel2, cudaFuncCachePreferL1));
CUDA_CHECK(cudaFuncSetCacheConfig(TSNE::SummarizationKernel,
CUDA_CHECK(cudaFuncSetCacheConfig(TSNE::BoundingBoxKernel<value_idx, value_t>,
cudaFuncCachePreferShared));
CUDA_CHECK(cudaFuncSetCacheConfig(TSNE::SortKernel, cudaFuncCachePreferL1));
CUDA_CHECK(
cudaFuncSetCacheConfig(TSNE::RepulsionKernel, cudaFuncCachePreferL1));
CUDA_CHECK(
cudaFuncSetCacheConfig(TSNE::attractive_kernel_bh, cudaFuncCachePreferL1));
CUDA_CHECK(cudaFuncSetCacheConfig(
TSNE::TreeBuildingKernel<value_idx, value_t>, cudaFuncCachePreferL1));
CUDA_CHECK(cudaFuncSetCacheConfig(TSNE::ClearKernel1<value_idx>,
cudaFuncCachePreferL1));
CUDA_CHECK(cudaFuncSetCacheConfig(TSNE::ClearKernel2<value_idx, value_t>,
cudaFuncCachePreferL1));
CUDA_CHECK(cudaFuncSetCacheConfig(
TSNE::SummarizationKernel<value_idx, value_t>, cudaFuncCachePreferShared));
CUDA_CHECK(
cudaFuncSetCacheConfig(TSNE::IntegrationKernel, cudaFuncCachePreferL1));
cudaFuncSetCacheConfig(TSNE::SortKernel<value_idx>, cudaFuncCachePreferL1));
CUDA_CHECK(cudaFuncSetCacheConfig(TSNE::RepulsionKernel<value_idx, value_t>,
cudaFuncCachePreferL1));
CUDA_CHECK(cudaFuncSetCacheConfig(
TSNE::attractive_kernel_bh<value_idx, value_t>, cudaFuncCachePreferL1));
CUDA_CHECK(cudaFuncSetCacheConfig(TSNE::IntegrationKernel<value_idx, value_t>,
cudaFuncCachePreferL1));
// Do gradient updates
//---------------------------------------------------
CUML_LOG_DEBUG("Start gradient updates!");

float momentum = pre_momentum;
float learning_rate = pre_learning_rate;
value_t momentum = pre_momentum;
value_t learning_rate = pre_learning_rate;

for (int iter = 0; iter < max_iter; iter++) {
CUDA_CHECK(cudaMemsetAsync(static_cast<void *>(rep_forces.data()), 0,
Expand All @@ -181,7 +187,7 @@ void Barnes_Hut(float *VAL, const int *COL, const int *ROW, const int NNZ,
if (iter == exaggeration_iter) {
momentum = post_momentum;
// Divide perplexities
const float div = 1.0f / early_exaggeration;
const value_t div = 1.0f / early_exaggeration;
raft::linalg::scalarMultiply(VAL, VAL, div, NNZ, stream);
learning_rate = post_learning_rate;
}
Expand Down Expand Up @@ -252,7 +258,8 @@ void Barnes_Hut(float *VAL, const int *COL, const int *ROW, const int NNZ,
START_TIMER;
// TODO: Calculate Kullback-Leibler divergence
// For general embedding dimensions
TSNE::attractive_kernel_bh<<<raft::ceildiv(NNZ, 1024), 1024, 0, stream>>>(
TSNE::attractive_kernel_bh<<<raft::ceildiv(NNZ, (value_idx)1024), 1024, 0,
stream>>>(
VAL, COL, ROW, YY.data(), YY.data() + nnodes + 1, attr_forces.data(),
attr_forces.data() + n, NNZ);
CUDA_CHECK(cudaPeekAtLastError());
Expand Down
Loading