Skip to content

Commit

Permalink
Patch and test for RF crash rapidsai#3107
Browse files: browse the repository at this point in the history
  • Loading branch information
JohnZed committed Nov 5, 2020
1 parent 63fc249 commit 0c7b7a6
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 1 deletion.
3 changes: 2 additions & 1 deletion cpp/src/decisiontree/batched-levelalgo/builder_base.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,8 @@ struct RegTraits {
static void computeSplit(Builder<RegTraits<DataT, IdxT>>& b, IdxT col,
IdxT batchSize, CRITERION splitType,
cudaStream_t s) {
auto n_col_blks = b.n_blks_for_cols;
auto n_col_blks = std::min(b.n_blks_for_cols, b.input.nSampledCols - col);

dim3 grid(b.n_blks_for_rows, n_col_blks, batchSize);
auto nbins = b.params.n_bins;
size_t smemSize = 7 * nbins * sizeof(DataT) + nbins * sizeof(int);
Expand Down
8 changes: 8 additions & 0 deletions cpp/src/decisiontree/batched-levelalgo/kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -354,10 +354,17 @@ __global__ void computeSplitRegressionKernel(
}
for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
scount[i] = 0;
if (col < 0) {
printf(
"indexing q with %d, col: %d, nbins: %d, i: %d, from colStart: %d + "
"blockIdx.y: %d\n",
col * nbins + i, col, nbins, i, colStart, blockIdx.y);
}
sbins[i] = input.quantiles[col * nbins + i];
}
__syncthreads();
auto coloffset = col * input.M;

// compute prediction averages for all bins in shared mem
for (auto i = range_start + tid; i < end; i += stride) {
auto row = input.rowids[i];
Expand Down Expand Up @@ -440,6 +447,7 @@ __global__ void computeSplitRegressionKernel(
last = MLCommon::signalDone(done_count + nid * gridDim.y + blockIdx.y,
gridDim.x, blockIdx.x == 0, sDone);
}

if (!last) return;
// last block computes the final gain
Split<DataT, IdxT> sp;
Expand Down
1 change: 1 addition & 0 deletions cpp/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ if(BUILD_CUML_TESTS)
add_dependencies(ml cutlass)

target_include_directories(ml PRIVATE ${CUML_TEST_INCLUDE_DIRS})
target_compile_options(ml PRIVATE $<$<COMPILE_LANGUAGE:CUDA>: --generate-line-info>)

target_link_libraries(ml ${CUML_TEST_LINK_LIBRARIES})

Expand Down
5 changes: 5 additions & 0 deletions cpp/test/sg/rf_batched_regression_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ class RFBatchedRegTest : public ::testing::TestWithParam<RfInputs> {

//-------------------------------------------------------------------------------------------------------------------------------------
const std::vector<RfInputs> inputs = {
// The first two cases will FAIL (small data, small ensemble), but with the old code they would crash as well
{100, 29, 1, 1.0f, 1.0f, 2, -1, false, false, 16, SPLIT_ALGO::GLOBAL_QUANTILE,
2, 0.0, 2, CRITERION::MAE},
{100, 57, 1, 1.0f, 1.0f, 2, -1, false, false, 16, SPLIT_ALGO::GLOBAL_QUANTILE,
2, 0.0, 2, CRITERION::MAE},
{1000, 10, 10, 1.0f, 1.0f, 12, -1, true, false, 10,
SPLIT_ALGO::GLOBAL_QUANTILE, 2, 0.0, 2, CRITERION::MAE},
{2000, 20, 20, 1.0f, 0.6f, 13, -1, true, false, 10,
Expand Down

0 comments on commit 0c7b7a6

Please sign in to comment.