diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9ac892de71..993fca8d56 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -105,6 +105,7 @@
 - PR #2535: Fix issue with incorrect docker image being used in local build script
 - PR #2542: Fix small memory leak in TSNE
 - PR #2552: Fixed the length argument of updateDevice calls in RF test
+- PR #2565: Fix cell allocation code to avoid loops in quad-tree. Prevent NaNs from causing infinite descent
 - PR #2563: Update scipy call for arima gradient test
 - PR #2569: Fix for cuDF update
 - PR #2508: Use keyword parameters in sklearn.datasets.make_* functions
diff --git a/cpp/src/tsne/barnes_hut.cuh b/cpp/src/tsne/barnes_hut.cuh
index 1471e3495f..642339ccf4 100644
--- a/cpp/src/tsne/barnes_hut.cuh
+++ b/cpp/src/tsne/barnes_hut.cuh
@@ -62,6 +62,7 @@ void Barnes_Hut(float *VAL, const int *COL, const int *ROW, const int NNZ,
   // Get device properites
   //---------------------------------------------------
   const int blocks = MLCommon::getMultiProcessorCount();
+  int h_flag;
 
   int nnodes = n * 2;
   if (nnodes < 1024 * blocks) nnodes = 1024 * blocks;
@@ -70,17 +71,19 @@ void Barnes_Hut(float *VAL, const int *COL, const int *ROW, const int NNZ,
   CUML_LOG_DEBUG("N_nodes = %d blocks = %d", nnodes, blocks);
 
   // Allocate more space
-  //---------------------------------------------------
   // MLCommon::device_buffer<int> errl(d_alloc, stream, 1);
   MLCommon::device_buffer<unsigned> limiter(d_alloc, stream, 1);
   MLCommon::device_buffer<int> maxdepthd(d_alloc, stream, 1);
   MLCommon::device_buffer<int> bottomd(d_alloc, stream, 1);
   MLCommon::device_buffer<float> radiusd(d_alloc, stream, 1);
+  MLCommon::device_buffer<int> flag_unstable_computation(d_alloc, stream, 1);
 
   TSNE::InitializationKernel<<<1, 1, 0, stream>>>(/*errl.data(),*/
                                                   limiter.data(),
                                                   maxdepthd.data(),
-                                                  radiusd.data());
+                                                  radiusd.data(),
+                                                  flag_unstable_computation
+                                                    .data());
   CUDA_CHECK(cudaPeekAtLastError());
 
   const int FOUR_NNODES = 4 * nnodes;
@@ -121,6 +124,7 @@ void Barnes_Hut(float *VAL, const int *COL, const int *ROW, const int NNZ,
 
   // Apply
   MLCommon::device_buffer<float> gains_bh(d_alloc, stream, n * 2);
+
   thrust::device_ptr<float> begin_gains_bh =
     thrust::device_pointer_cast(gains_bh.data());
   thrust::fill(thrust::cuda::par.on(stream), begin_gains_bh,
@@ -146,7 +150,6 @@ void Barnes_Hut(float *VAL, const int *COL, const int *ROW, const int NNZ,
   cudaFuncSetCacheConfig(TSNE::RepulsionKernel, cudaFuncCachePreferL1);
   cudaFuncSetCacheConfig(TSNE::attractive_kernel_bh, cudaFuncCachePreferL1);
   cudaFuncSetCacheConfig(TSNE::IntegrationKernel, cudaFuncCachePreferL1);
-
   // Do gradient updates
   //---------------------------------------------------
   CUML_LOG_DEBUG("Start gradient updates!");
@@ -161,6 +164,7 @@ void Barnes_Hut(float *VAL, const int *COL, const int *ROW, const int NNZ,
     CUDA_CHECK(cudaMemsetAsync(static_cast(attr_forces.data()), 0,
                                attr_forces.size() * sizeof(*attr_forces.data()),
                                stream));
+
     TSNE::Reset_Normalization<<<1, 1, 0, stream>>>(
       Z_norm.data(), radiusd_squared.data(), bottomd.data(), NNODES,
       radiusd.data());
@@ -247,10 +251,24 @@ void Barnes_Hut(float *VAL, const int *COL, const int *ROW, const int NNZ,
     TSNE::
      attractive_kernel_bh<<>>(
        VAL, COL, ROW, YY.data(), YY.data() + nnodes + 1, norm.data(),
-        attr_forces.data(), attr_forces.data() + n, NNZ);
+        attr_forces.data(), attr_forces.data() + n, NNZ,
+        flag_unstable_computation.data());
     CUDA_CHECK(cudaPeekAtLastError());
     END_TIMER(attractive_time);
 
+    MLCommon::copy(&h_flag, flag_unstable_computation.data(), 1, stream);
+    if (h_flag) {
+      CUML_LOG_ERROR(
+        "Detected zero divisor in attractive force kernel after '%d' "
+        "iterations;"
+        " returning early. Your final results may not be accurate. In some "
+        "cases"
+        " this error can be resolved by increasing perplexity and n_neighbors;"
+        " if the problem persists, please use 'method=exact'.",
+        iter);
+      break;
+    }
+
     START_TIMER;
     TSNE::IntegrationKernel<<>>(
       learning_rate, momentum, early_exaggeration, YY.data(),
diff --git a/cpp/src/tsne/bh_kernels.cuh b/cpp/src/tsne/bh_kernels.cuh
index 6f772bfd9c..c4ca3b657d 100644
--- a/cpp/src/tsne/bh_kernels.cuh
+++ b/cpp/src/tsne/bh_kernels.cuh
@@ -35,7 +35,6 @@
 #include 
 #include 
-#include 
 
 namespace ML {
 namespace TSNE {
@@ -46,11 +45,13 @@ namespace TSNE {
 __global__ void InitializationKernel(/*int *restrict errd, */
                                      unsigned *restrict limiter,
                                      int *restrict maxdepthd,
-                                     float *restrict radiusd) {
+                                     float *restrict radiusd,
+                                     int *restrict flag_unstable_computation) {
   // errd[0] = 0;
   maxdepthd[0] = 1;
   limiter[0] = 0;
   radiusd[0] = 0.0f;
+  flag_unstable_computation[0] = 0;
 }
 
@@ -171,7 +172,8 @@ __global__ __launch_bounds__(1024, 1) void ClearKernel1(int *restrict childd,
 }
 
 /**
- * Build the actual KD Tree.
+ * Build the actual QuadTree.
+ * See: https://iss.oden.utexas.edu/Publications/Papers/burtscher11.pdf
 */
 __global__ __launch_bounds__(
   THREADS2, FACTOR2) void TreeBuildingKernel(/* int *restrict errd, */
@@ -194,6 +196,7 @@ __global__ __launch_bounds__(
 
   int localmaxdepth = 1;
   int skip = 1;
+
   const int inc = blockDim.x * gridDim.x;
   int i = threadIdx.x + blockIdx.x * blockDim.x;
 
@@ -206,6 +209,11 @@ __global__ __launch_bounds__(
       depth = 1;
       r = radius * 0.5f;
 
+      /* Select child node 'j'
+                     rootx < px  rootx > px
+       * rooty < py   1 -> 3      0 -> 2
+       * rooty > py   1 -> 1      0 -> 0
+       */
       x = rootx + ((rootx < (px = posxd[i])) ? (j = 1, r) : (j = 0, -r));
      y = rooty + ((rooty < (py = posyd[i])) ? (j |= 2, r) : (-r));
 
@@ -217,17 +225,20 @@ __global__ __launch_bounds__(
      depth++;
      r *= 0.5f;
-      // determine which child to follow
      x += ((x < px) ? (j = 1, r) : (j = 0, -r));
      y += ((y < py) ? (j |= 2, r) : (-r));
    }
 
+    // (ch)ild will be '-1' (nullptr), '-2' (locked), or an Integer corresponding to a body offset
+    // in the lower [0, N) blocks of childd
    if (ch != -2) {
-      // skip if child pointer is locked and try again later
+      // skip if child pointer was locked when we examined it, and try again later.
      locked = n * 4 + j;
+      // store the locked position in case we need to patch in a cell later.
 
      if (ch == -1) {
+        // Child is a nullptr ('-1'), so we write our body index to the leaf, and move on to the next body.
        if (atomicCAS(&childd[locked], -1, i) == -1) {
          if (depth > localmaxdepth) localmaxdepth = depth;
 
@@ -235,23 +246,26 @@ __global__ __launch_bounds__(
          skip = 1;
        }
      } else {
+        // Child node isn't empty, so we store the current value of the child, lock the leaf, and patch in a new cell
        if (ch == atomicCAS(&childd[locked], ch, -2)) {
-          // try to lock
          patch = -1;
          while (ch >= 0) {
            depth++;
 
            const int cell = atomicSub(bottomd, 1) - 1;
-            if (cell <= N) {
-              // atomicExch(errd, 1);
+            if (cell == N) {
              atomicExch(bottomd, NNODES);
+            } else if (cell < N) {
+              depth--;
+              continue;
            }
 
            if (patch != -1) childd[n * 4 + j] = cell;
 
            if (cell > patch) patch = cell;
 
+            // Insert migrated child node
            j = (x < posxd[ch]) ? 1 : 0;
            if (y < posyd[ch]) j |= 2;
@@ -264,7 +278,9 @@ __global__ __launch_bounds__(
            x += ((x < px) ? (j = 1, r) : (j = 0, -r));
            y += ((y < py) ? (j |= 2, r) : (-r));
            ch = childd[n * 4 + j];
-            if (r <= 1e-10) break;
+            if (r <= 1e-10) {
+              break;
+            }
          }
 
          childd[n * 4 + j] = i;
@@ -276,6 +292,7 @@ __global__ __launch_bounds__(
        }
      }
    }
+
    __threadfence();
 
    if (skip == 2) childd[locked] = patch;
@@ -635,7 +652,8 @@ __global__ void attractive_kernel_bh(
   const float *restrict VAL, const int *restrict COL, const int *restrict ROW,
   const float *restrict Y1, const float *restrict Y2,
   const float *restrict norm, float *restrict attract1,
-  float *restrict attract2, const int NNZ) {
+  float *restrict attract2, const int NNZ,
+  int *restrict flag_unstable_computation) {
   const int index = (blockIdx.x * blockDim.x) + threadIdx.x;
   if (index >= NNZ) return;
   const int i = ROW[index];
@@ -643,9 +661,24 @@ __global__ void attractive_kernel_bh(
   const int j = COL[index];
   // TODO: Calculate Kullback-Leibler divergence
   // TODO: Convert attractive forces to CSR format
-  const float PQ = __fdividef(
-    VAL[index],
-    norm[i] + 1.0f + norm[j] - 2.0f * (Y1[i] * Y1[j] + Y2[i] * Y2[j]));  // P*Q
+  // Try single precision compute first
+  float denominator = __fmaf_rn(-2.0f, (Y1[i] * Y1[j]), norm[i] + 1.0f) +
+                      __fmaf_rn(-2.0f, (Y2[i] * Y2[j]), norm[j]);
+
+  if (__builtin_expect(denominator == 0, false)) {
+    double _Y1 = static_cast<double>(Y1[i]) * static_cast<double>(Y1[j]);
+    double _Y2 = static_cast<double>(Y2[i]) * static_cast<double>(Y2[j]);
+    double dbl_denominator =
+      __fma_rn(-2.0f, _Y1, norm[i] + 1.0f) + __fma_rn(-2.0f, _Y2, norm[j]);
+
+    if (__builtin_expect(dbl_denominator == 0, false)) {
+      dbl_denominator = 1.0f;
+      flag_unstable_computation[0] = 1;
+    }
+
+    denominator = dbl_denominator;
+  }
+  const float PQ = __fdividef(VAL[index], denominator);
 
   // Apply forces
   atomicAdd(&attract1[i], PQ * (Y1[i] - Y1[j]));
diff --git a/cpp/src/tsne/exact_kernels.cuh b/cpp/src/tsne/exact_kernels.cuh
index b16ade068f..df0117f334 100644
--- a/cpp/src/tsne/exact_kernels.cuh
+++ b/cpp/src/tsne/exact_kernels.cuh
@@ -26,7 +26,7 @@ namespace ML {
 namespace TSNE {
 
 /****************************************/
-/* Finds the best guassian bandwith for
+/* Finds the best Gaussian bandwidth for
     each row in the dataset */
 __global__ void sigmas_kernel(const float *restrict distances,
                               float *restrict P, const float perplexity,
@@ -45,7 +45,7 @@ __global__ void sigmas_kernel(const float *restrict distances,
   for (int step = 0; step < epochs; step++) {
     float sum_Pi = FLT_EPSILON;
-    // Exponentiate to get guassian
+    // Exponentiate to get Gaussian
     for (int j = 0; j < k; j++) {
       P[ik + j] = __expf(-distances[ik + j] * beta);
       sum_Pi += P[ik + j];
    }
@@ -84,7 +84,7 @@ __global__ void sigmas_kernel(const float *restrict distances,
 }
 
 /****************************************/
-/* Finds the best guassian bandwith for
+/* Finds the best Gaussian bandwidth for
     each row in the dataset */
 __global__ void sigmas_kernel_2d(const float *restrict distances,
                                  float *restrict P, const float perplexity,
@@ -101,7 +101,7 @@ __global__ void sigmas_kernel_2d(const float *restrict distances,
   register const int ik = i * 2;
 
   for (int step = 0; step < epochs; step++) {
-    // Exponentiate to get guassian
+    // Exponentiate to get Gaussian
     P[ik] = __expf(-distances[ik] * beta);
     P[ik + 1] = __expf(-distances[ik + 1] * beta);
     const float sum_Pi = FLT_EPSILON + P[ik] + P[ik + 1];
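
For context, the host/device handshake this patch adds can be reduced to a small standalone pattern: the kernel recomputes a denominator that cancelled to zero in single precision using double precision, substitutes 1 and raises a device-side flag if even that is zero, and the host copies the flag back once per iteration and stops early. The sketch below is illustrative only and is not the cuML implementation: the names (safe_attractive_demo, d_flag), the toy data in main, and the use of plain CUDA runtime calls on the default stream (where barnes_hut.cuh uses MLCommon::copy on its own stream) are all assumptions made for the example.

#include <cstdio>
#include <cuda_runtime.h>

__global__ void safe_attractive_demo(const float *val, const int *row,
                                     const int *col, const float *y1,
                                     const float *y2, const float *norm,
                                     float *pq, int nnz, int *flag) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index >= nnz) return;
  int i = row[index], j = col[index];

  // denominator = 1 + ||y_i - y_j||^2, expanded with precomputed squared norms
  float denom = fmaf(-2.0f, y1[i] * y1[j], norm[i] + 1.0f) +
                fmaf(-2.0f, y2[i] * y2[j], norm[j]);
  if (denom == 0.0f) {
    // retry in double precision before giving up
    double d = fma(-2.0, (double)y1[i] * (double)y1[j], (double)norm[i] + 1.0) +
               fma(-2.0, (double)y2[i] * (double)y2[j], (double)norm[j]);
    if (d == 0.0) {
      d = 1.0;    // keep the division finite
      *flag = 1;  // report the unstable computation to the host
    }
    denom = (float)d;
  }
  pq[index] = val[index] / denom;  // P * Q
}

int main() {
  const int n = 2, nnz = 2;
  // Two embedded points and the P values for the symmetric pair (0,1), (1,0).
  float h_y1[n] = {0.0f, 1.0f}, h_y2[n] = {0.0f, 1.0f};
  float h_norm[n], h_val[nnz] = {0.5f, 0.5f};
  int h_row[nnz] = {0, 1}, h_col[nnz] = {1, 0};
  for (int k = 0; k < n; ++k) h_norm[k] = h_y1[k] * h_y1[k] + h_y2[k] * h_y2[k];

  float *d_y1, *d_y2, *d_norm, *d_val, *d_pq;
  int *d_row, *d_col, *d_flag;
  cudaMalloc(&d_y1, n * sizeof(float));
  cudaMalloc(&d_y2, n * sizeof(float));
  cudaMalloc(&d_norm, n * sizeof(float));
  cudaMalloc(&d_val, nnz * sizeof(float));
  cudaMalloc(&d_pq, nnz * sizeof(float));
  cudaMalloc(&d_row, nnz * sizeof(int));
  cudaMalloc(&d_col, nnz * sizeof(int));
  cudaMalloc(&d_flag, sizeof(int));
  cudaMemcpy(d_y1, h_y1, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_y2, h_y2, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_norm, h_norm, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_val, h_val, nnz * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_row, h_row, nnz * sizeof(int), cudaMemcpyHostToDevice);
  cudaMemcpy(d_col, h_col, nnz * sizeof(int), cudaMemcpyHostToDevice);
  cudaMemset(d_flag, 0, sizeof(int));

  for (int iter = 0; iter < 5; ++iter) {
    safe_attractive_demo<<<1, 256>>>(d_val, d_row, d_col, d_y1, d_y2, d_norm,
                                     d_pq, nnz, d_flag);
    int h_flag = 0;
    // Blocking copy: the kernel has finished and the flag is valid once it returns.
    cudaMemcpy(&h_flag, d_flag, sizeof(int), cudaMemcpyDeviceToHost);
    if (h_flag) {
      printf("zero divisor detected after %d iterations; stopping early\n", iter);
      break;
    }
    // (a real implementation would update y1/y2 from the forces here)
  }

  float h_pq[nnz];
  cudaMemcpy(h_pq, d_pq, nnz * sizeof(float), cudaMemcpyDeviceToHost);
  printf("PQ[0] = %f, PQ[1] = %f\n", h_pq[0], h_pq[1]);
  return 0;
}

The per-iteration cost of this guard is a single 4-byte device-to-host copy, and checking it on the host lets the caller break out of the gradient loop instead of letting one bad division feed NaNs into the integration and repulsion kernels of every later iteration.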