From 2219a752d220841ded96aca6f96155d1614df68c Mon Sep 17 00:00:00 2001 From: John Zedlewski <904524+JohnZed@users.noreply.github.com> Date: Wed, 11 Nov 2020 16:40:35 -0800 Subject: [PATCH] [REVIEW] Fix experimental RF backend crashes and add tests (#3117) * Patch and test for RF crash #3107 * Cleanups of RF regression fixes * Add failing tests to RF regression * Expand experimental backend testing and align pointers * Expand python RF regression test * Updates based on review feedback * Update changelog * Add classification tests * Review comments and style fixes for RF --- CHANGELOG.md | 1 + .../batched-levelalgo/builder_base.cuh | 12 ++++++- .../batched-levelalgo/kernels.cuh | 32 +++++++++++++------ cpp/test/sg/rf_batched_classification_test.cu | 21 ++++++------ cpp/test/sg/rf_batched_regression_test.cu | 20 +++++++++--- python/cuml/test/test_random_forest.py | 19 ++++++++--- 6 files changed, 75 insertions(+), 30 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 46f0bd1ab0..5f7447a469 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,6 +43,7 @@ - PR #3084: Fix artifacts in t-SNE results - PR #3086: Reverting FIL Notebook Testing - PR #3114: Fixed a typo in SVC's predict_proba AttributeError +- PR #3117: Fix two crashes in experimental RF backend - PR #3119: Fix memset args for benchmark - PR #3130: Return Python string from `dump_as_json()` of RF diff --git a/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh b/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh index 40a3720915..d43636bb6d 100644 --- a/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh +++ b/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh @@ -394,6 +394,10 @@ struct ClsTraits { dim3 grid(b.n_blks_for_rows, colBlks, batchSize); size_t smemSize = sizeof(int) * binSize + sizeof(DataT) * nbins; smemSize += sizeof(int); + + // Extra room for alignment (see alignPointer in computeSplitClassificationKernel) + smemSize += 2 * sizeof(DataT) + 1 * sizeof(int); + CUDA_CHECK(cudaMemsetAsync(b.hist, 0, sizeof(int) * b.nHistBins, s)); computeSplitClassificationKernel <<>>( @@ -451,11 +455,17 @@ struct RegTraits { static void computeSplit(Builder>& b, IdxT col, IdxT batchSize, CRITERION splitType, cudaStream_t s) { - auto n_col_blks = b.n_blks_for_cols; + auto n_col_blks = std::min(b.n_blks_for_cols, b.input.nSampledCols - col); + dim3 grid(b.n_blks_for_rows, n_col_blks, batchSize); auto nbins = b.params.n_bins; size_t smemSize = 7 * nbins * sizeof(DataT) + nbins * sizeof(int); smemSize += sizeof(int); + + // Room for alignment in worst case (see alignPointer in + // computeSplitRegressionKernel) + smemSize += 5 * sizeof(DataT) + 2 * sizeof(int); + CUDA_CHECK( cudaMemsetAsync(b.pred, 0, sizeof(DataT) * b.nPredCounts * 2, s)); if (splitType == CRITERION::MAE) { diff --git a/cpp/src/decisiontree/batched-levelalgo/kernels.cuh b/cpp/src/decisiontree/batched-levelalgo/kernels.cuh index 7c8051ccdb..e78459c2fc 100644 --- a/cpp/src/decisiontree/batched-levelalgo/kernels.cuh +++ b/cpp/src/decisiontree/batched-levelalgo/kernels.cuh @@ -250,6 +250,13 @@ __global__ void nodeSplitKernel(IdxT max_depth, IdxT min_rows_per_node, total_nodes, (char*)smem); } +/* Returns 'input' rounded up to a correctly-aligned pointer of type OutT* */ +template +__device__ OutT* alignPointer(InT input) { + return reinterpret_cast( + raft::alignTo(reinterpret_cast(input), sizeof(OutT))); +} + template __global__ void computeSplitClassificationKernel( int* hist, IdxT nbins, IdxT max_depth, IdxT min_rows_per_node, @@ -269,9 +276,9 @@ __global__ void computeSplitClassificationKernel( auto end = range_start + range_len; auto nclasses = input.nclasses; auto len = nbins * 2 * nclasses; - auto* shist = reinterpret_cast(smem); - auto* sbins = reinterpret_cast(shist + len); - auto* sDone = reinterpret_cast(sbins + nbins); + auto* shist = alignPointer(smem); + auto* sbins = alignPointer(shist + len); + auto* sDone = alignPointer(sbins + nbins); IdxT stride = blockDim.x * gridDim.x; IdxT tid = threadIdx.x + blockIdx.x * blockDim.x; auto col = input.colids[colStart + blockIdx.y]; @@ -338,14 +345,15 @@ __global__ void computeSplitRegressionKernel( } auto end = range_start + range_len; auto len = nbins * 2; - auto* spred = reinterpret_cast(smem); - auto* scount = reinterpret_cast(spred + len); - auto* sbins = reinterpret_cast(scount + nbins); + auto* spred = alignPointer(smem); + auto* scount = alignPointer(spred + len); + auto* sbins = alignPointer(scount + nbins); + // used only for MAE criterion - auto* spred2 = reinterpret_cast(sbins + nbins); - auto* spred2P = reinterpret_cast(spred2 + len); - auto* spredP = reinterpret_cast(spred2P + nbins); - auto* sDone = reinterpret_cast(spredP + nbins); + auto* spred2 = alignPointer(sbins + nbins); + auto* spred2P = alignPointer(spred2 + len); + auto* spredP = alignPointer(spred2P + nbins); + auto* sDone = alignPointer(spredP + nbins); IdxT stride = blockDim.x * gridDim.x; IdxT tid = threadIdx.x + blockIdx.x * blockDim.x; auto col = input.colids[colStart + blockIdx.y]; @@ -354,10 +362,13 @@ __global__ void computeSplitRegressionKernel( } for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) { scount[i] = 0; + // printf("indexing from sbins: %p to %p, sizeof: %d (spred: %p)\n", sbins, + // &sbins[i], (int)sizeof(DataT*), spred); sbins[i] = input.quantiles[col * nbins + i]; } __syncthreads(); auto coloffset = col * input.M; + // compute prediction averages for all bins in shared mem for (auto i = range_start + tid; i < end; i += stride) { auto row = input.rowids[i]; @@ -440,6 +451,7 @@ __global__ void computeSplitRegressionKernel( last = MLCommon::signalDone(done_count + nid * gridDim.y + blockIdx.y, gridDim.x, blockIdx.x == 0, sDone); } + if (!last) return; // last block computes the final gain Split sp; diff --git a/cpp/test/sg/rf_batched_classification_test.cu b/cpp/test/sg/rf_batched_classification_test.cu index e811c5787a..0c6816ca37 100644 --- a/cpp/test/sg/rf_batched_classification_test.cu +++ b/cpp/test/sg/rf_batched_classification_test.cu @@ -42,6 +42,7 @@ struct RfInputs { float min_impurity_decrease; int n_streams; CRITERION split_criterion; + float min_expected_acc; }; template @@ -143,6 +144,14 @@ class RFBatchedClsTest : public ::testing::TestWithParam { //------------------------------------------------------------------------------------------------------------------------------------- const std::vector inputsf2_clf = { + // Simple non-crash tests with small datasets + {100, 59, 1, 1.0f, 0.4f, 16, -1, true, false, 10, SPLIT_ALGO::GLOBAL_QUANTILE, + 2, 0.0, 2, CRITERION::GINI, 0.0f}, + {101, 59, 2, 1.0f, 0.4f, 10, -1, true, false, 13, SPLIT_ALGO::GLOBAL_QUANTILE, + 2, 0.0, 2, CRITERION::GINI, 0.0f}, + {100, 1, 2, 1.0f, 0.4f, 10, -1, true, false, 15, SPLIT_ALGO::GLOBAL_QUANTILE, + 2, 0.0, 2, CRITERION::GINI, 0.0f}, + // Simple accuracy tests {20000, 10, 25, 1.0f, 0.4f, 16, -1, true, false, 10, SPLIT_ALGO::GLOBAL_QUANTILE, 2, 0.0, 2, CRITERION::GINI}, {20000, 10, 5, 1.0f, 0.4f, 14, -1, true, false, 10, @@ -150,11 +159,7 @@ const std::vector inputsf2_clf = { typedef RFBatchedClsTest RFBatchedClsTestF; TEST_P(RFBatchedClsTestF, Fit) { - if (!params.bootstrap && (params.max_features == 1.0f)) { - ASSERT_TRUE(accuracy == 1.0f); - } else { - ASSERT_TRUE(accuracy >= 0.75f); // Empirically derived accuracy range - } + ASSERT_TRUE(accuracy >= params.min_expected_acc); } INSTANTIATE_TEST_CASE_P(RFBatchedClsTests, RFBatchedClsTestF, @@ -162,11 +167,7 @@ INSTANTIATE_TEST_CASE_P(RFBatchedClsTests, RFBatchedClsTestF, typedef RFBatchedClsTest RFBatchedClsTestD; TEST_P(RFBatchedClsTestD, Fit) { - if (!params.bootstrap && (params.max_features == 1.0f)) { - ASSERT_TRUE(accuracy == 1.0f); - } else { - ASSERT_TRUE(accuracy >= 0.75f); // Empirically derived accuracy range - } + ASSERT_TRUE(accuracy >= params.min_expected_acc); } INSTANTIATE_TEST_CASE_P(RFBatchedClsTests, RFBatchedClsTestD, diff --git a/cpp/test/sg/rf_batched_regression_test.cu b/cpp/test/sg/rf_batched_regression_test.cu index a1dd3104bf..ec6dd7a879 100644 --- a/cpp/test/sg/rf_batched_regression_test.cu +++ b/cpp/test/sg/rf_batched_regression_test.cu @@ -44,6 +44,7 @@ struct RfInputs { float min_impurity_decrease; int n_streams; CRITERION split_criterion; + float min_expected_acc; }; template @@ -120,19 +121,30 @@ class RFBatchedRegTest : public ::testing::TestWithParam { //------------------------------------------------------------------------------------------------------------------------------------- const std::vector inputs = { + // Small datasets to repro corner cases as in #3107 (test for crash) + {100, 29, 1, 1.0f, 1.0f, 2, -1, false, false, 16, SPLIT_ALGO::GLOBAL_QUANTILE, + 2, 0.0, 2, CRITERION::MAE, -10.0}, + {100, 57, 2, 1.0f, 1.0f, 2, -1, false, false, 16, SPLIT_ALGO::GLOBAL_QUANTILE, + 2, 0.0, 2, CRITERION::MAE, -10.0}, + {101, 57, 2, 1.0f, 1.0f, 2, -1, false, false, 13, SPLIT_ALGO::GLOBAL_QUANTILE, + 2, 0.0, 2, CRITERION::MSE, -10.0}, + {100, 1, 2, 1.0f, 1.0f, 2, -1, false, false, 13, SPLIT_ALGO::GLOBAL_QUANTILE, + 2, 0.0, 2, CRITERION::MAE, -10.0}, + + // Larger datasets for accuracy {1000, 10, 10, 1.0f, 1.0f, 12, -1, true, false, 10, - SPLIT_ALGO::GLOBAL_QUANTILE, 2, 0.0, 2, CRITERION::MAE}, + SPLIT_ALGO::GLOBAL_QUANTILE, 2, 0.0, 2, CRITERION::MAE, 0.7f}, {2000, 20, 20, 1.0f, 0.6f, 13, -1, true, false, 10, - SPLIT_ALGO::GLOBAL_QUANTILE, 2, 0.0, 2, CRITERION::MSE}}; + SPLIT_ALGO::GLOBAL_QUANTILE, 2, 0.0, 2, CRITERION::MSE, 0.7f}}; typedef RFBatchedRegTest RFBatchedRegTestF; -TEST_P(RFBatchedRegTestF, Fit) { ASSERT_TRUE(accuracy >= 0.7f); } +TEST_P(RFBatchedRegTestF, Fit) { ASSERT_GT(accuracy, params.min_expected_acc); } INSTANTIATE_TEST_CASE_P(RFBatchedRegTests, RFBatchedRegTestF, ::testing::ValuesIn(inputs)); typedef RFBatchedRegTest RFBatchedRegTestD; -TEST_P(RFBatchedRegTestD, Fit) { ASSERT_TRUE(accuracy >= 0.7f); } +TEST_P(RFBatchedRegTestD, Fit) { ASSERT_GT(accuracy, params.min_expected_acc); } INSTANTIATE_TEST_CASE_P(RFBatchedRegTests, RFBatchedRegTestD, ::testing::ValuesIn(inputs)); diff --git a/python/cuml/test/test_random_forest.py b/python/cuml/test/test_random_forest.py index abadee7378..686a557149 100644 --- a/python/cuml/test/test_random_forest.py +++ b/python/cuml/test/test_random_forest.py @@ -212,10 +212,18 @@ def test_rf_classification(small_clf, datatype, split_algo, @pytest.mark.parametrize('rows_sample', [unit_param(1.0), quality_param(0.90), stress_param(0.95)]) @pytest.mark.parametrize('datatype', [np.float32]) -@pytest.mark.parametrize('split_algo', [0, 1]) -@pytest.mark.parametrize('max_features', [1.0, 'auto', 'log2', 'sqrt']) +@pytest.mark.parametrize( + 'split_algo,max_features,use_experimental_backend,n_bins', + [(0, 1.0, False, 16), + (1, 1.0, False, 11), + (0, 'auto', False, 128), + (1, 'log2', False, 100), + (1, 'sqrt', False, 100), + (1, 1.0, True, 17), + (1, 1.0, True, 32), + ]) def test_rf_regression(special_reg, datatype, split_algo, max_features, - rows_sample): + rows_sample, use_experimental_backend, n_bins): use_handle = True @@ -230,10 +238,11 @@ def test_rf_regression(special_reg, datatype, split_algo, max_features, # Initialize and fit using cuML's random forest regression model cuml_model = curfr(max_features=max_features, rows_sample=rows_sample, - n_bins=16, split_algo=split_algo, split_criterion=2, + n_bins=n_bins, split_algo=split_algo, split_criterion=2, min_rows_per_node=2, random_state=123, n_streams=1, n_estimators=50, handle=handle, max_leaves=-1, - max_depth=16, accuracy_metric='mse') + max_depth=16, accuracy_metric='mse', + use_experimental_backend=use_experimental_backend) cuml_model.fit(X_train, y_train) # predict using FIL fil_preds = cuml_model.predict(X_test, predict_model="GPU")