[BUG] tSNE Lock up #2565

Merged: 17 commits, Jul 29, 2020. Changes shown from 6 commits.
CHANGELOG.md (1 change: 1 addition & 0 deletions)

@@ -83,6 +83,7 @@
 - PR #2535: Fix issue with incorrect docker image being used in local build script
 - PR #2542: Fix small memory leak in TSNE
 - PR #2552: Fixed the length argument of updateDevice calls in RF test
+- PR #2565: Fix cell allocation code to avoid loops in quad-tree. Prevent NaNs causing infinite descent

# cuML 0.14.0 (03 Jun 2020)

cpp/src/tsne/barnes_hut.cuh (184 changes: 87 additions & 97 deletions)

@@ -16,6 +16,7 @@
#pragma once

#include <common/cudart_utils.h>
+#include <common/device_buffer.hpp>
#include <cuml/common/logger.hpp>
#include "bh_kernels.cuh"
#include "utils.cuh"
@@ -55,6 +56,7 @@ void Barnes_Hut(float *VAL, const int *COL, const int *ROW, const int NNZ,
const int max_iter = 1000, const float min_grad_norm = 1e-7,
const float pre_momentum = 0.5, const float post_momentum = 0.8,
const long long random_state = -1) {
+  using MLCommon::device_buffer;
auto d_alloc = handle.getDeviceAllocator();
cudaStream_t stream = handle.getStream();

@@ -71,71 +73,70 @@ void Barnes_Hut(float *VAL, const int *COL, const int *ROW, const int NNZ,
// Allocate more space
//---------------------------------------------------
//int *errl = (int *)d_alloc->allocate(sizeof(int), stream);
-unsigned *limiter = (unsigned *)d_alloc->allocate(sizeof(unsigned), stream);
-int *maxdepthd = (int *)d_alloc->allocate(sizeof(int), stream);
-int *bottomd = (int *)d_alloc->allocate(sizeof(int), stream);
-float *radiusd = (float *)d_alloc->allocate(sizeof(float), stream);
+device_buffer<unsigned> limiter(d_alloc, stream, 1);
+device_buffer<int> maxdepthd(d_alloc, stream, 1);
+device_buffer<int> bottomd(d_alloc, stream, 1);
+device_buffer<float> radiusd(d_alloc, stream, 1);
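The conversions above swap manual allocate/deallocate pairs for MLCommon::device_buffer, an RAII wrapper, which is why the explicit deallocation block at the end of Barnes_Hut is deleted later in this diff. A minimal sketch of the idiom, assuming the (allocator, stream, size) constructor used here and a destructor that returns the memory to the allocator; raii_example and use_device_memory are illustrative names, not cuML API:

#include <common/device_buffer.hpp>

// Hypothetical consumer, declared only to make the sketch self-contained.
void use_device_memory(float *ptr, int n, cudaStream_t stream);

// Sketch of the RAII idiom adopted in this refactor: the buffer allocates in
// its constructor and frees itself on scope exit, so no deallocate() calls
// are needed at the end of the function.
template <typename DeviceAllocator>
void raii_example(DeviceAllocator d_alloc, cudaStream_t stream, int n) {
  MLCommon::device_buffer<float> buf(d_alloc, stream, n);  // allocates n floats
  use_device_memory(buf.data(), n, stream);  // raw device pointer via .data()
}  // scope exit: memory is returned to d_alloc on `stream`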

-TSNE::InitializationKernel<<<1, 1, 0, stream>>>(/*errl,*/ limiter, maxdepthd,
-                                                radiusd);
+TSNE::InitializationKernel<<<1, 1, 0, stream>>>(
+  /*errl,*/ limiter.data(), maxdepthd.data(), radiusd.data());
CUDA_CHECK(cudaPeekAtLastError());

const int FOUR_NNODES = 4 * nnodes;
const int FOUR_N = 4 * n;
const float theta_squared = theta * theta;
const int NNODES = nnodes;
+bool h_chk_finite;

// Actual mallocs
-int *startl = (int *)d_alloc->allocate(sizeof(int) * (nnodes + 1), stream);
-int *childl =
-  (int *)d_alloc->allocate(sizeof(int) * (nnodes + 1) * 4, stream);
-float *massl =
-  (float *)d_alloc->allocate(sizeof(float) * (nnodes + 1), stream);
-thrust::device_ptr<float> begin_massl = thrust::device_pointer_cast(massl);
+device_buffer<int> startl(d_alloc, stream, nnodes + 1);
+device_buffer<int> childl(d_alloc, stream, (nnodes + 1) * 4);
+device_buffer<float> massl(d_alloc, stream, nnodes + 1);
+thrust::device_ptr<float> begin_massl =
+  thrust::device_pointer_cast(massl.data());
thrust::fill(thrust::cuda::par.on(stream), begin_massl,
begin_massl + (nnodes + 1), 1.0f);

-float *maxxl =
-  (float *)d_alloc->allocate(sizeof(float) * blocks * FACTOR1, stream);
-float *maxyl =
-  (float *)d_alloc->allocate(sizeof(float) * blocks * FACTOR1, stream);
-float *minxl =
-  (float *)d_alloc->allocate(sizeof(float) * blocks * FACTOR1, stream);
-float *minyl =
-  (float *)d_alloc->allocate(sizeof(float) * blocks * FACTOR1, stream);
+device_buffer<float> maxxl(d_alloc, stream, blocks * FACTOR1);
+device_buffer<float> maxyl(d_alloc, stream, blocks * FACTOR1);
+device_buffer<float> minxl(d_alloc, stream, blocks * FACTOR1);
+device_buffer<float> minyl(d_alloc, stream, blocks * FACTOR1);

// SummarizationKernel
-int *countl = (int *)d_alloc->allocate(sizeof(int) * (nnodes + 1), stream);
+device_buffer<int> countl(d_alloc, stream, nnodes + 1);

// SortKernel
-int *sortl = (int *)d_alloc->allocate(sizeof(int) * (nnodes + 1), stream);
+device_buffer<int> sortl(d_alloc, stream, nnodes + 1);

// RepulsionKernel
-float *rep_forces =
-  (float *)d_alloc->allocate(sizeof(float) * (nnodes + 1) * 2, stream);
-float *attr_forces = (float *)d_alloc->allocate(
-  sizeof(float) * n * 2, stream);  // n*2 double for reduction sum
+device_buffer<float> rep_forces(d_alloc, stream, (nnodes + 1) * 2);
+device_buffer<float> attr_forces(d_alloc, stream, n * 2);

-float *norm_add1 = (float *)d_alloc->allocate(sizeof(float) * n, stream);
-float *norm = (float *)d_alloc->allocate(sizeof(float) * n, stream);
-float *Z_norm = (float *)d_alloc->allocate(sizeof(float), stream);
+device_buffer<float> norm_add1(d_alloc, stream, n);
+device_buffer<float> norm(d_alloc, stream, n);
+device_buffer<float> Z_norm(d_alloc, stream, 1);

-float *radiusd_squared = (float *)d_alloc->allocate(sizeof(float), stream);
+device_buffer<float> radiusd_squared(d_alloc, stream, 1);

// Apply
-float *gains_bh = (float *)d_alloc->allocate(sizeof(float) * n * 2, stream);
+device_buffer<float> gains_bh(d_alloc, stream, n * 2);
thrust::device_ptr<float> begin_gains_bh =
-  thrust::device_pointer_cast(gains_bh);
+  thrust::device_pointer_cast(gains_bh.data());
thrust::fill(thrust::cuda::par.on(stream), begin_gains_bh,
begin_gains_bh + (n * 2), 1.0f);

-float *old_forces = (float *)d_alloc->allocate(sizeof(float) * n * 2, stream);
-CUDA_CHECK(cudaMemsetAsync(old_forces, 0, sizeof(float) * n * 2, stream));
+device_buffer<float> old_forces(d_alloc, stream, n * 2);
+CUDA_CHECK(
+  cudaMemsetAsync(old_forces.data(), 0, sizeof(float) * n * 2, stream));

-float *YY =
-  (float *)d_alloc->allocate(sizeof(float) * (nnodes + 1) * 2, stream);
-random_vector(YY, -0.0001f, 0.0001f, (nnodes + 1) * 2, stream, random_state);
-ASSERT(YY != NULL && rep_forces != NULL, "[ERROR] Possibly no more memory");
+device_buffer<float> YY(d_alloc, stream, (nnodes + 1) * 2);
+device_buffer<float> YY_prev(d_alloc, stream, (nnodes + 1) * 2);
Review comment on the extra YY_prev buffer (Member):

  Given this problem is so dataset dependent and doesn't seem to occur on most real-world datasets, it would be unfortunate to have to double the amount of embedding memory by default. While it's not nearly the same as duplicating the input data, training 50M vertices still requires 400mb of extra memory just for this feature.

  What do you think about making this feature optional and maybe mentioning the option in the warning? If the option is disabled, we just return the embedding the way it is w/ the NaN values. The option can be enabled and use a little extra memory if users still want the embeddings knowing that training wasn't able to complete successfully.

Reply (Contributor, author):

  Sounds reasonable. Are you thinking the option would be exposed at the Python level?

Reply (Member):

  Yeah, just a simple flag exposed through the TSNE constructor would be fine.
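A minimal sketch of what that opt-in could look like at the C++ level. Everything here is an illustrative assumption, not code from this PR: the flag name rollback_on_nan, the helper names, and the zero-size-buffer trick for skipping the allocation.

#include <common/cudart_utils.h>
#include <common/device_buffer.hpp>

// Stand-in for one pass of the Barnes-Hut kernels; returns true if the
// embedding contains non-finite values. Hypothetical, for illustration only.
bool run_one_iteration(float *YY, cudaStream_t stream);

template <typename DeviceAllocator>
void run_with_optional_rollback(DeviceAllocator d_alloc, cudaStream_t stream,
                                float *YY, const int nnodes,
                                const int max_iter,
                                const bool rollback_on_nan) {
  // Size the snapshot to zero when rollback is disabled, so the extra
  // (nnodes + 1) * 2 floats are only paid for when the user asks for them.
  MLCommon::device_buffer<float> YY_prev(
    d_alloc, stream, rollback_on_nan ? (nnodes + 1) * 2 : 0);

  for (int iter = 0; iter < max_iter; iter++) {
    if (rollback_on_nan) {
      // Snapshot the last known-good positions before this iteration.
      MLCommon::copy(YY_prev.data(), YY, (nnodes + 1) * 2, stream);
    }
    if (run_one_iteration(YY, stream)) {
      if (rollback_on_nan) {
        // Restore the pre-iteration embedding.
        MLCommon::copy(YY, YY_prev.data(), (nnodes + 1) * 2, stream);
      }
      // Otherwise the embedding is returned as-is, NaNs included,
      // as suggested in the review discussion.
      break;
    }
  }
}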

+random_vector(YY.data(), -0.0001f, 0.0001f, (nnodes + 1) * 2, stream,
+              random_state);
+ASSERT(
+  YY.data() != NULL && YY_prev.data() != NULL && rep_forces.data() != NULL,
+  "[ERROR] Possibly no more memory");

// Set cache levels for faster algorithm execution
//---------------------------------------------------
@@ -148,7 +149,6 @@ void Barnes_Hut(float *VAL, const int *COL, const int *ROW, const int NNZ,
cudaFuncSetCacheConfig(TSNE::RepulsionKernel, cudaFuncCachePreferL1);
cudaFuncSetCacheConfig(TSNE::attractive_kernel_bh, cudaFuncCachePreferL1);
cudaFuncSetCacheConfig(TSNE::IntegrationKernel, cudaFuncCachePreferL1);
-
// Do gradient updates
//---------------------------------------------------
CUML_LOG_DEBUG("Start gradient updates!");
@@ -157,11 +157,15 @@ void Barnes_Hut(float *VAL, const int *COL, const int *ROW, const int NNZ,
float learning_rate = pre_learning_rate;

for (int iter = 0; iter < max_iter; iter++) {
-CUDA_CHECK(
-  cudaMemsetAsync(rep_forces, 0, sizeof(float) * (nnodes + 1) * 2, stream));
-CUDA_CHECK(cudaMemsetAsync(attr_forces, 0, sizeof(float) * n * 2, stream));
-TSNE::Reset_Normalization<<<1, 1, 0, stream>>>(Z_norm, radiusd_squared,
-                                               bottomd, NNODES, radiusd);
+CUDA_CHECK(cudaMemsetAsync(rep_forces.data(), 0,
+                           sizeof(float) * (nnodes + 1) * 2, stream));
+CUDA_CHECK(
+  cudaMemsetAsync(attr_forces.data(), 0, sizeof(float) * n * 2, stream));
+
+MLCommon::copy(YY_prev.data(), YY.data(), (nnodes + 1) * 2, stream);
+TSNE::Reset_Normalization<<<1, 1, 0, stream>>>(
+  Z_norm.data(), radiusd_squared.data(), bottomd.data(), NNODES,
+  radiusd.data());
CUDA_CHECK(cudaPeekAtLastError());

if (iter == exaggeration_iter) {
@@ -173,128 +177,114 @@ void Barnes_Hut(float *VAL, const int *COL, const int *ROW, const int NNZ,

START_TIMER;
TSNE::BoundingBoxKernel<<<blocks * FACTOR1, THREADS1, 0, stream>>>(
-  startl, childl, massl, YY, YY + nnodes + 1, maxxl, maxyl, minxl, minyl,
-  FOUR_NNODES, NNODES, n, limiter, radiusd);
+  startl.data(), childl.data(), massl.data(), YY.data(),
+  YY.data() + nnodes + 1, maxxl.data(), maxyl.data(), minxl.data(),
+  minyl.data(), FOUR_NNODES, NNODES, n, limiter.data(), radiusd.data());
CUDA_CHECK(cudaPeekAtLastError());

END_TIMER(BoundingBoxKernel_time);

START_TIMER;
-TSNE::ClearKernel1<<<blocks, 1024, 0, stream>>>(childl, FOUR_NNODES,
-                                                FOUR_N);
+TSNE::ClearKernel1<<<blocks, 1024, 0, stream>>>(childl.data(), FOUR_NNODES,
+                                                FOUR_N);
CUDA_CHECK(cudaPeekAtLastError());

END_TIMER(ClearKernel1_time);

START_TIMER;
TSNE::TreeBuildingKernel<<<blocks * FACTOR2, THREADS2, 0, stream>>>(
-  /*errl,*/ childl, YY, YY + nnodes + 1, NNODES, n, maxdepthd, bottomd,
-  radiusd);
+  /*errl,*/ childl.data(), YY.data(), YY.data() + nnodes + 1, NNODES, n,
+  maxdepthd.data(), bottomd.data(), radiusd.data());
CUDA_CHECK(cudaPeekAtLastError());

END_TIMER(TreeBuildingKernel_time);

START_TIMER;
-TSNE::ClearKernel2<<<blocks * 1, 1024, 0, stream>>>(startl, massl, NNODES,
-                                                    bottomd);
+TSNE::ClearKernel2<<<blocks * 1, 1024, 0, stream>>>(
+  startl.data(), massl.data(), NNODES, bottomd.data());
CUDA_CHECK(cudaPeekAtLastError());

END_TIMER(ClearKernel2_time);

START_TIMER;
TSNE::SummarizationKernel<<<blocks * FACTOR3, THREADS3, 0, stream>>>(
-  countl, childl, massl, YY, YY + nnodes + 1, NNODES, n, bottomd);
+  countl.data(), childl.data(), massl.data(), YY.data(),
+  YY.data() + nnodes + 1, NNODES, n, bottomd.data());
CUDA_CHECK(cudaPeekAtLastError());

END_TIMER(SummarizationKernel_time);

START_TIMER;
TSNE::SortKernel<<<blocks * FACTOR4, THREADS4, 0, stream>>>(
-  sortl, countl, startl, childl, NNODES, n, bottomd);
+  sortl.data(), countl.data(), startl.data(), childl.data(), NNODES, n,
+  bottomd.data());
CUDA_CHECK(cudaPeekAtLastError());

END_TIMER(SortKernel_time);

START_TIMER;
TSNE::RepulsionKernel<<<blocks * FACTOR5, THREADS5, 0, stream>>>(
-  /*errl,*/ theta, epssq, sortl, childl, massl, YY, YY + nnodes + 1,
-  rep_forces, rep_forces + nnodes + 1, Z_norm, theta_squared, NNODES,
-  FOUR_NNODES, n, radiusd_squared, maxdepthd);
+  /*errl,*/ theta, epssq, sortl.data(), childl.data(), massl.data(),
+  YY.data(), YY.data() + nnodes + 1, rep_forces.data(),
+  rep_forces.data() + nnodes + 1, Z_norm.data(), theta_squared, NNODES,
+  FOUR_NNODES, n, radiusd_squared.data(), maxdepthd.data());
CUDA_CHECK(cudaPeekAtLastError());

END_TIMER(RepulsionTime);

START_TIMER;
-TSNE::Find_Normalization<<<1, 1, 0, stream>>>(Z_norm, n);
+TSNE::Find_Normalization<<<1, 1, 0, stream>>>(Z_norm.data(), n);
CUDA_CHECK(cudaPeekAtLastError());

END_TIMER(Reduction_time);

START_TIMER;
TSNE::get_norm<<<MLCommon::ceildiv(n, 1024), 1024, 0, stream>>>(
-  YY, YY + nnodes + 1, norm, norm_add1, n);
+  YY.data(), YY.data() + nnodes + 1, norm.data(), norm_add1.data(), n);
CUDA_CHECK(cudaPeekAtLastError());

// TODO: Calculate Kullback-Leibler divergence
// For general embedding dimensions
TSNE::
attractive_kernel_bh<<<MLCommon::ceildiv(NNZ, 1024), 1024, 0, stream>>>(
-    VAL, COL, ROW, YY, YY + nnodes + 1, norm, norm_add1, attr_forces,
-    attr_forces + n, NNZ);
+    VAL, COL, ROW, YY.data(), YY.data() + nnodes + 1, norm.data(),
+    norm_add1.data(), attr_forces.data(), attr_forces.data() + n, NNZ);
CUDA_CHECK(cudaPeekAtLastError());
END_TIMER(attractive_time);

START_TIMER;
TSNE::IntegrationKernel<<<blocks * FACTOR6, THREADS6, 0, stream>>>(
-  learning_rate, momentum, early_exaggeration, YY, YY + nnodes + 1,
-  attr_forces, attr_forces + n, rep_forces, rep_forces + nnodes + 1,
-  gains_bh, gains_bh + n, old_forces, old_forces + n, Z_norm, n);
+  learning_rate, momentum, early_exaggeration, YY.data(),
+  YY.data() + nnodes + 1, attr_forces.data(), attr_forces.data() + n,
+  rep_forces.data(), rep_forces.data() + nnodes + 1, gains_bh.data(),
+  gains_bh.data() + n, old_forces.data(), old_forces.data() + n,
+  Z_norm.data(), n);
CUDA_CHECK(cudaPeekAtLastError());

END_TIMER(IntegrationKernel_time);

+h_chk_finite = thrust::transform_reduce(
+  thrust::cuda::par.on(stream), YY.data(), YY.data() + (nnodes + 1) * 2,
+  FiniteTestUnary(), 0, thrust::plus<bool>());
+if (h_chk_finite) {
+  CUML_LOG_WARN(
+    "Non-finite result detected during Barnes Hut iteration: %d, returning last "
+    "known good positions.",
+    iter);
+  MLCommon::copy(YY.data(), YY_prev.data(), (nnodes + 1) * 2, stream);
Review comment on the rollback (Member):

  Have you been able to visually inspect the output embeddings of any of the datasets that are causing this failure? It would be nice to know if they have been reasonably embedded by the time the rollback occurs or if they are just garbage.
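FiniteTestUnary is referenced in the transform_reduce above but defined elsewhere in the PR. A plausible shape, inferred from the call site: a unary predicate that flags non-finite values, so the thrust::plus<bool> reduction evaluates to true when any embedding coordinate is NaN or infinite. The body below is an assumption, not the PR's actual definition.

// Assumed form of the functor: returns true for non-finite coordinates.
struct FiniteTestUnary {
  __device__ bool operator()(const float x) const { return !isfinite(x); }
};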

+  break;
+}
}
PRINT_TIMES;

-thrust::device_ptr<float> Y_begin = thrust::device_pointer_cast(Y);
-thrust::copy(thrust::cuda::par.on(stream), YY, YY + n, Y_begin);
+// Copy final YY into true output Y
+MLCommon::copy(Y, YY.data(), n, stream);
CUDA_CHECK(cudaPeekAtLastError());

-thrust::copy(thrust::cuda::par.on(stream), YY + nnodes + 1,
-             YY + nnodes + 1 + n, Y_begin + n);
+MLCommon::copy(Y + n, YY.data() + nnodes + 1, n, stream);
CUDA_CHECK(cudaPeekAtLastError());

-// Deallocate everything
-//d_alloc->deallocate(errl, sizeof(int), stream);
-d_alloc->deallocate(limiter, sizeof(unsigned), stream);
-d_alloc->deallocate(maxdepthd, sizeof(int), stream);
-d_alloc->deallocate(bottomd, sizeof(int), stream);
-d_alloc->deallocate(radiusd, sizeof(float), stream);
-
-d_alloc->deallocate(startl, sizeof(int) * (nnodes + 1), stream);
-d_alloc->deallocate(childl, sizeof(int) * (nnodes + 1) * 4, stream);
-d_alloc->deallocate(massl, sizeof(float) * (nnodes + 1), stream);
-
-d_alloc->deallocate(maxxl, sizeof(float) * blocks * FACTOR1, stream);
-d_alloc->deallocate(maxyl, sizeof(float) * blocks * FACTOR1, stream);
-d_alloc->deallocate(minxl, sizeof(float) * blocks * FACTOR1, stream);
-d_alloc->deallocate(minyl, sizeof(float) * blocks * FACTOR1, stream);
-
-d_alloc->deallocate(countl, sizeof(int) * (nnodes + 1), stream);
-d_alloc->deallocate(sortl, sizeof(int) * (nnodes + 1), stream);
-
-d_alloc->deallocate(rep_forces, sizeof(float) * (nnodes + 1) * 2, stream);
-d_alloc->deallocate(attr_forces, sizeof(float) * n * 2, stream);
-d_alloc->deallocate(norm, sizeof(float) * n, stream);
-d_alloc->deallocate(norm_add1, sizeof(float) * n, stream);
-
-d_alloc->deallocate(Z_norm, sizeof(float), stream);
-d_alloc->deallocate(radiusd_squared, sizeof(float), stream);
-
-d_alloc->deallocate(gains_bh, sizeof(float) * n * 2, stream);
-d_alloc->deallocate(old_forces, sizeof(float) * n * 2, stream);
-
-d_alloc->deallocate(YY, sizeof(float) * (nnodes + 1) * 2, stream);
}

} // namespace TSNE