
[BUG] tSNE Lock up #2565

Merged: 17 commits into rapidsai:branch-0.15 on Jul 29, 2020

Conversation

drobison00 (Contributor)

Prevents a race condition related to cell index selection in the Barnes-Hut algorithm. Failures in this case could introduce cycles into the tree structure, resulting in infinite looping/lockup behavior.

In certain cases, randomly generated test data can result in divide-by-zero / NaN errors when the attractive force is calculated. Check for them and exit early if detected.

Updated BH code to use device_buffer instead of direct allocation.

Closes #2358

@drobison00 drobison00 requested a review from a team as a code owner July 15, 2020 22:08
@GPUtester (Contributor)

Please update the changelog in order to start CI tests.

View the gpuCI docs here.

1 similar comment from @GPUtester.

@cjnolet (Member) left a comment

The device buffers definitely clean this PR up quite a bit! I really have just two little things. One is very minor.

cpp/src/tsne/barnes_hut.cuh (outdated; resolved)
random_vector(YY, -0.0001f, 0.0001f, (nnodes + 1) * 2, stream, random_state);
ASSERT(YY != NULL && rep_forces != NULL, "[ERROR] Possibly no more memory");
device_buffer<float> YY(d_alloc, stream, (nnodes + 1) * 2);
device_buffer<float> YY_prev(d_alloc, stream, (nnodes + 1) * 2);
Member:

Given this problem is so dataset-dependent and doesn't seem to occur on most real-world datasets, it would be unfortunate to have to double the amount of embedding memory by default. While it's not nearly the same as duplicating the input data, training 50M vertices would still require 400 MB of extra memory just for this feature.

What do you think about making this feature optional and maybe mentioning the option in the warning? If the option is disabled, we just return the embedding as-is, with the NaN values. The option can be enabled, at the cost of a little extra memory, if users still want the embeddings knowing that training wasn't able to complete successfully.

drobison00 (Contributor Author):

Sounds reasonable. Are you thinking the option would be exposed at the Python level?

Member:

Yeah, just a simple flag exposed through the TSNE constructor would be fine.

"Non-finite result detected during Barnes Hut iteration: %d, returning last "
"known good positions.",
iter);
MLCommon::copy(YY.data(), YY_prev.data(), (nnodes + 1) * 2, stream);
Member:

Have you been able to visually inspect the output embeddings of any of the datasets that are causing this failure? It would be nice to know if they have been reasonably embedded by the time the rollback occurs or if they are just garbage.

@zbjornson (Contributor)

I don't know if there are multiple ways to cause lockups, but I've been investigating one of them that is due to rounding error in the equation below when using 32-bit precision. That's the (or a?) source of inf/nan that this PR detects.

const float PQ = __fdividef(
VAL[index],
norm_add1[i] + norm[j] - 2.0f * (Y1[i] * Y1[j] + Y2[i] * Y2[j])); // P*Q

Example inputs from a lockup:

  const float PQ = __fdividef(
    /* doesn't matter */,
    10689151 + 10689150 - 2.0f * (500.782989501953125 * 500.782989501953125 + -3230.845947265625 * -3230.845947265625));

The correct denominator is ~2.72493 but it computes to 0 with float. I don't see a rearrangement (even considering FMA) that would universally avoid that problem. Rearranging those particular inputs, I can get it to compute to 2.0, which is sorta close.
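
A minimal host-side sketch of the cancellation described above, using the quoted inputs (plain IEEE float stands in for the device arithmetic; with FMA contraction the float result can differ slightly, but the true ~2.72 value is lost either way):

  #include <cstdio>

  int main() {
    // Inputs quoted from the lockup above: norm_add1[i], norm[j], Y1, Y2
    float norm_add1_i = 10689151.0f;
    float norm_j      = 10689150.0f;
    float y1          = 500.782989501953125f;
    float y2          = -3230.845947265625f;

    // Same arrangement as the kernel's denominator, in float and in double
    float denom_f = norm_add1_i + norm_j - 2.0f * (y1 * y1 + y2 * y2);
    double denom_d = static_cast<double>(norm_add1_i) + norm_j -
                     2.0 * (static_cast<double>(y1) * y1 + static_cast<double>(y2) * y2);

    // Every term is ~2.1e7, where a float ulp is 2.0; the true difference (~2.72)
    // is below that, so the float result collapses to 0 (or another small multiple
    // of the ulp, depending on rounding/contraction).
    printf("float: %f  double: %f\n", denom_f, denom_d);
    return 0;
  }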

Rather than storing and copying Y_prev, testing the entire embeddings array for non-finite numbers, and bailing, what do you think of these options:

  1. Use mixed/double-precision arithmetic always. Con: Roughly halves the performance for this kernel (much worse on T4 and P4 arch).
  2. Use mixed/double-precision if the denominator is 0. Con: There might be other scenarios besides denom==0 when FP error is significant here.
  3. Use an adaptive equation to keep the intermediate magnitudes lower. Con: In the above example this only gets the result to 2.0 vs. 2.7, but one would hope that's unimportant in later tSNE iterations. I think this can be done with predicates to preserve uniform thread execution, so might not impact perf as much as option 1.

@drobison00 (Contributor Author)

@zbjornson I'd thought the division error was because of NaNs being propagated down from a higher level, but this makes more sense now that I see your description. Allowing NaNs into the tree-builder code breaks the minimum-radius check for the quad tree and causes an infinite loop.

I like #2 as an intermediate solution. If we're willing to use extra memory, we could also save the values, do the computation, and roll back / repeat at double precision if NaNs are detected. @cjnolet

The other lockup was caused by a race condition in the code that decrements the bottom offset when allocating new cells after a collision. It also looks like I messed up my last merge from upstream; fixing now.

@zbjornson (Contributor) commented Jul 22, 2020

@drobison00 thanks for the fast response!

If we're willing to use extra memory, we could also save the values, do the computation, and roll back / repeat at double precision if NaNs are detected

I was thinking of doing this in the kernel on only the affected inputs, like:

float denominator = ...
if (denominator == 0) {
  double dbl_denominator = /* repeat with double precision */
  result = __fdividef(x, static_cast<float>(dbl_denominator));
} else {
  result = __fdividef(x, denominator);
}

so we wouldn't need any extra device memory, just regs. Do you think we would have to use double precision uniformly if any results are invalid, though?

But... this rearrangement might be pretty good, too:

float V = fma(-2.0f, C * D, A) + fma(-2.0f, E * F, B);
// i.e.
float V = fma(-2.0f, Y1[i] * Y1[j], norm_add1[i]) + fma(-2.0f, Y2[i] * Y2[j], norm[j]);

That computes to 2.0 in this particular case, which might be close enough.

@drobison00 (Contributor Author)

@zbjornson Good point; yes, I definitely like your idea more: only redoing affected inputs / (re)storing with registers. There might be some corner case where it makes sense to redo all the calculations, but without a motivating example it's probably not worth the performance hit.

I also have no problem with the fma formulation if it's more stable; we're already working with an approximate solution. I'd still probably do the zero-denominator check, though.

@zbjornson (Contributor)

@drobison00 sounds good. Are you planning to make that change in this PR, or should I open a PR?

@drobison00 (Contributor Author)

@zbjornson I can add it to this one if that works. I need to remove my other changes anyway.

@drobison00 (Contributor Author)

@zbjornson Looks like we can still run into a problem with double precision. Using doubles resolves quite a few cases, but not all.

Some examples:
__fma_rn(-2.0f, -2826.654785 * -2826.654785, 9650204.000000) + __fma_rn(-2.0f, 1288.497559 * 1288.497559, 9650203.000000)
__fma_rn(-2.0f, -2841.154053 * -2841.154053, 9749921.000000) + __fma_rn(-2.0f, 1295.285278 * 1295.285278, 9749920.000000)

We can do something like what's below to always avoid locking up, but we likely risk situations where we generate garbage data. What if we do this, and emit some kind of warning/error message in cases where double precision still doesn't work?

  float denominator = __fmaf_rn(-2.0f, (Y1[i] * Y1[j]), norm_add1[i]) + __fmaf_rn(-2.0f, (Y2[i] * Y2[j]), norm[j]);

  if (denominator == 0) {
    /* repeat with double precision */
    double dbl_denominator = __fma_rn(-2.0f, Y1[i] * Y1[j], norm_add1[i]) + __fma_rn(-2.0f, Y2[i] * Y2[j], norm[j]);
    denominator = (dbl_denominator != 0) ? static_cast<float>(dbl_denominator) : FLT_EPSILON;
  }
  PQ = __fdividef(VAL[index], denominator);

@zbjornson (Contributor) commented Jul 23, 2020

@drobison00 I was just working through more cases and slowly drafting another comment. There are definitely potential issues with:

  1. How we decide the value is wrong. Testing if it's zero might not catch all instances of significantly wrong denominators. Should we test if it's less than 0.1? 0.5?

  2. What we do when it's wrong. Using FLT_EPSILON as the denominator is going to yield some huge number that is likely wrong too, as you said. Should we instead just cap the value of PQ? It seems reasonable that the force in a single iteration isn't allowed to be too huge, but establishing that bound could be tricky.

Hope to have more thoughts on this later today.

(PS: If Y1[i] is always equal to Y1[j], we should just use one and save loads from global memory.)

@zbjornson (Contributor)

@drobison00 both of those examples work (reasonable-looking nonzero values) if you do the intermediate Y1[i] * Y1[j] with double precision though:

double dbl_denominator = __fma_rn(-2.0, static_cast<double>(Y1[i]) * Y1[j], norm_add1[i]) + __fma_rn(-2.0, static_cast<double>(Y2[i]) * Y2[j], norm[j]);

@drobison00 (Contributor Author)

@zbjornson I'd originally tried pre-computing Y1[i] * Y1[j] and Y2[i] * Y2[j], but was still seeing situations where that didn't work. I'll try to reproduce (will update shortly). Using either method, I don't think we can trust that dbl_denominator never comes out as 0.

As you say, clamping PQ could be tricky. Perhaps we could, instead, establish some simple order-of-magnitude bounds on the denominator in the double-precision case, and clamp to that, instead of FLT_EPSILON, if the direct computation fails? That seems fairly straightforward. I'll take a look this evening.

@zbjornson (Contributor) left a comment

I don't think we can trust that dbl_denominator never comes out as 0.

Agreed, it's essentially free to include that safeguard. nvcc reorders branches based on __builtin_expect, too.

I ran ~90,000 stress tests with the code below. The double-precision block kicked in 422 times (0.468%). None of those double-precision calcs resulted in zeros, so whatever we decide to do when dbl_denominator is 0, it doesn't seem likely to matter very often.

I also did ~10k tests with the original equation. While the rearrangement doesn't seem to reduce the number of tSNE invocations that encounter zeros, it does substantially decrease the number of attraction kernel invocations that encounter zeros (by like 3x).

(Caveat: my input values were all in the [0, 1] range, so this might not be representative. I logged the random seeds for every run, so we can go back to all of those cases if needed.)

Perhaps we could, instead, establish some simple order of magnitude bounds on the denominator in the double precision case, and clamp to that, instead of FLT_EPSILON, if the direct computation fails?

I was hoping to be able to measure if there was a correlation between FP error and the true denominator, but per above couldn't collect any datapoints for it. Short of that, I'd pick something sorta close to 1 (maybe 0.1 or 0.01) so that the change in this iteration isn't too big, and hope that subsequent iterations resolve it.

How we decide the value is wrong. Testing if it's zero might not catch all instances of significantly wrong denominators.

(I didn't get to do this tonight, and I'm not sure it will be an important difference either. I basically want to do the 32-bit and 64-bit calculation for every point for a bunch of inputs and measure the RMS or something.)

Code I used for stress test

  // Try single precision compute first
  float denominator = __fmaf_rn(-2.0f, Y1[i] * Y1[j], norm_add1[i]) + __fmaf_rn(-2.0f, Y2[i] * Y2[j], norm[j]);

  if (__builtin_expect(denominator == 0, false)) {
    // repeat with double precision
    double y1i_y1j = static_cast<double>(Y1[i]) * Y1[j];
    double y2i_y2j = static_cast<double>(Y2[i]) * Y2[j];
    double dbl_denominator = __fma_rn(-2.0, y1i_y1j, norm_add1[i]) + __fma_rn(-2.0, y2i_y2j, norm[j]);
    denominator = dbl_denominator == 0 ? FLT_EPSILON : static_cast<float>(dbl_denominator);
  }
  PQ = __fdividef(VAL[index], denominator);

  // Apply forces
  atomicAdd(&attract1[i], PQ * (Y1[i] - Y1[j]));
  atomicAdd(&attract2[i], PQ * (Y2[i] - Y2[j]));

Added a few tiny comments inline.

cpp/src/tsne/barnes_hut.cuh (two review threads; outdated, resolved)
@zbjornson (Contributor) commented Jul 24, 2020

Alright, this is mostly clear to me now...

I'd pick something sorta close to 1 (maybe 0.1 or 0.01) so that the change in this iteration isn't too big, and hope that subsequent iterations resolve it.

  • It looks like normal denominator values are actually ~0.99999 or greater (into the millions). I tried this with a larger input domain than my previous [0, 1] and it remains true. I think we should limit the denominator to fmaxf(denom, 1.0f); see the sketch after this list. Edit: then we can dispense entirely with the dbl_denominator stuff. If someone knows that this should be the case theoretically, confirmation would be helpful.
  • When this kernel produces a zero with 32-bit precision, the 64-bit result is often negative or <1. The result with arbitrary-precision arithmetic coincides very well with the double result; good. However, the results for adjacent points in the array look like they're always huge, e.g. they're ~10M but this point with 64-bit arithmetic yields 0.8 or -15. So I'm pretty confident that there's another source of FP error upstream. Limiting to 1.0 per above will help address this until we find those errors.
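
A minimal host-side sketch of the clamp-to-1 idea from the first bullet, reusing the names from the attraction-kernel snippets above (std::fma/std::fmax stand in for the CUDA intrinsics; this is an illustration, not the code that was merged):

  #include <cmath>

  // Denominator with the fma rearrangement, clamped so it can never drop below 1.
  inline float attractive_denominator(float y1i, float y1j, float y2i, float y2j,
                                      float norm_add1_i, float norm_j) {
    float denom = std::fma(-2.0f, y1i * y1j, norm_add1_i) +
                  std::fma(-2.0f, y2i * y2j, norm_j);
    // The exact value is 1 + ||y_i - y_j||^2, so anything below 1 is FP error;
    // clamping bounds the attractive force without a double-precision fallback.
    return std::fmax(denom, 1.0f);
  }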

Here's a case that encountered "bad" denominators (y axis is double-precision result):
[image]

(Each of those plateaus is a tSNE iteration, I think.)

Here's a case that didn't:
[image]

Notice:

  1. There are no outliers below 1 in the "good" case.
  2. The Y axis ranges are vastly different. If attractive forces can genuinely get that high, risk of FP error is significant. (No pun intended.)

How we decide the value is wrong. Testing if it's zero might not catch all instances of significantly wrong denominators.

I'm also pretty confident that there are cases when the f32 result is 1.0 or 2.0 (exactly) and is wrong, e.g. below. I think these cases will be addressed by finding the upstream FP errors later. #2605 fixes one such error (or "the" error): norm[i] + 1 can't always be stored in 32 bits.

      f32                         f64
 10041905.00000000000000000  10041905.52043326944112778
 19477112.00000000000000000  19477111.94677329063415527
 10623784.00000000000000000  10623784.48056359961628914
        2.00000000000000000         0.79013678431510925
 24529614.00000000000000000  24529614.29719012230634689
 21355498.00000000000000000  21355499.83539685606956482
694822976.00000000000000000 694823029.55961430072784424
  2988024.25000000000000000   2988024.44322569016367197
 11997983.00000000000000000  11997983.20281551778316498
  8853240.00000000000000000   8853239.85142864286899567
 28865950.00000000000000000  28865950.54173412919044495
  1653692.75000000000000000   1653692.54413871094584465
767872768.00000000000000000 767872821.30707406997680664
  2846091.25000000000000000   2846091.27287304634228349
        1.00000000000000000         1.00772753171622753
  8347519.00000000000000000   8347518.65371377766132355
  2927162.00000000000000000   2927161.94909193459898233
 24529614.00000000000000000  24529614.29719012230634689
 28865950.00000000000000000  28865950.54173412919044495
  3168804.75000000000000000   3168804.76349345082417130
 11997983.00000000000000000  11997983.20281551778316498
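
On the norm[i] + 1 point: a quick host-side check (the value is chosen to match the magnitudes in the table above, not taken from a run) shows that the +1 is simply lost once the norm exceeds 2^24:

  #include <cstdio>

  int main() {
    float norm_i = 19477112.0f;       // > 2^24 = 16777216, so the float ulp here is 2
    float norm_add1 = norm_i + 1.0f;  // rounds straight back to 19477112.0f
    printf("norm_i       = %.1f\n", norm_i);
    printf("norm_i + 1.0 = %.1f (unchanged: %d)\n", norm_add1, norm_add1 == norm_i);
    return 0;
  }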

@drobison00 (Contributor Author)

@zbjornson Since the upstream source of the instability isn't entirely clear, I'm going to add the changes discussed so far, which avoid the lockup behavior, notify users when we've encountered a problem, and return early. We can also open another bug to track down the source of the problems.

@drobison00 drobison00 requested a review from cjnolet July 28, 2020 01:23
@cjnolet added the labels 4 - Waiting on Reviewer, bug, and CUDA / C++ on Jul 28, 2020
@cjnolet (Member) left a comment

The code changes look great and this is almost ready to go. What remains is very minor.

cpp/src/tsne/barnes_hut.cuh (two review threads; outdated, resolved)
cpp/src/tsne/bh_kernels.cuh (three review threads; outdated, resolved)
double _Y1 = static_cast<double>(Y1[i] * Y1[j]);
double _Y2 = static_cast<double>(Y2[i] * Y2[j]);
double dbl_denominator =
__fma_rn(-2.0f, _Y1, norm[i] + 1.0f) + __fma_rn(-2.0f, _Y2, norm[j]);
Member:

It might be helpful to open a GitHub issue for further investigating this detail and reference it here for future eyes.

drobison00 (Contributor Author):

@zbjornson Can you link the PR you mention in your other comment?

@zbjornson (Contributor) left a comment

I've got most of a PR ready that eliminates the upstream source of error. So long as that's accepted, that would mean we can remove the logging and early bail, etc.

BTW Martin Burtscher's BH kernel is now on version 4.5 (this code is based on 3.1). Unfortunately there's no version control on it so it's hard to see what changed and why, but it might be worth looking at his changes for remedying the tree-building lockups too. I noticed he doesn't have a radius error factor anymore (this), so it looks like a source of error was eliminated.
https://userweb.cs.txstate.edu/~burtscher/research/ECL-BH/ECL-BH_45.cu
The license on v4.5 is also more permissive.


if (__builtin_expect(denominator == 0, false)) {
double _Y1 = static_cast<double>(Y1[i] * Y1[j]);
double _Y2 = static_cast<double>(Y2[i] * Y2[j]);
Contributor:

This still does the calculation in single precision. At least one of the operands needs to be cast to double before the multiplication.
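
A quick host-side illustration of this point (arbitrary value, not one from a failing dataset): casting the result of a float multiply recovers nothing; the cast has to happen before the multiplication.

  #include <cstdio>

  int main() {
    float y = 2826.654785f;
    double cast_after  = static_cast<double>(y * y);                      // float product, then widened
    double cast_before = static_cast<double>(y) * static_cast<double>(y); // true double product
    printf("cast after the multiply:  %.6f\n", cast_after);
    printf("cast before the multiply: %.6f\n", cast_before);
    return 0;
  }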

drobison00 (Contributor Author):

Will update to cast both variables before the multiply.

The second code link simply fails the entire process when the cell race condition is hit.

            cell = atomicSub((int*)&bottomd, 1) - 1;
            if (cell <= nbodiesd) {printf("ERROR: out of cell memory\n"); asm("trap;");}

This PR leverages the fallback path for closely packed points that exceed the minimum radius.

It also looks like they've moved to an oct-tree formulation instead of the current quad tree. Definitely worth looking at, but it will need to be a later update.

@cjnolet (Member) left a comment

LGTM

@cjnolet cjnolet merged commit 877cb66 into rapidsai:branch-0.15 Jul 29, 2020
zbjornson added a commit to zbjornson/cuml that referenced this pull request Jul 29, 2020
The denominator equation is just the squared Euclidean distance. Instead of computing the norms in one kernel and storing them, we can (a) save nRow * sizeof(float) bytes of memory, (b) save global loads/stores, and (c) eliminate a source of FP error that's causing lockups (see linked issues).

Per code comment, this still includes a guard in case there are other sources of NaNs upstream. This compiles to just one `setp.ltu` instruction so is essentially free.

Ref rapidsai#2358
Ref rapidsai#2565
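
For reference, a rough sketch of the direct squared-distance formulation this commit message describes (assumed names and a stand-in guard; not the actual diff): subtracting the coordinates first removes the large-norm cancellation entirely.

  #include <cmath>

  inline float attractive_denominator_direct(float y1i, float y1j, float y2i, float y2j) {
    float dy1 = y1i - y1j;
    float dy2 = y2i - y2j;
    float dist_sq = dy1 * dy1 + dy2 * dy2;  // squared Euclidean distance in the embedding
    // 1 + dist_sq is the t-SNE denominator; the fmax stands in for the cheap guard
    // the commit message mentions, in case NaNs arrive from elsewhere.
    return std::fmax(1.0f + dist_sq, 1.0f);
  }
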
@drobison00 drobison00 deleted the bug-tsne-lockup branch September 8, 2020 16:36
Labels: 4 - Waiting on Reviewer, bug, CUDA / C++
Merging this pull request may close: [BUG] T-SNE freezing in benchmarks
4 participants