Skip to content

Commit

Permalink
Fix RF classification irreproducibility (#3785)
Browse files Browse the repository at this point in the history
This PR fixes the RF irreproducibility issue with the following two changes:

- Make the `splits` argument to `computeSplitClassificationKernel / computeSplitRegressionKernel` `volatile`. This is necessary as `splits` is read and written by multiple threads in the same kernel.
- Change the update logic to include a tie-break based on `quesval` when `best_metric_val` and `colid` both match.

Note: The fact that this change fixes the irreproducibility in the classification kernel means the kernel was selecting a suboptimal split in a very rare scenario due to L1 caching.

Authors:
  - Vinay Deshpande (https://github.com/vinaydes)

Approvers:
  - Thejaswi. N. S (https://github.com/teju85)

URL: #3785
  • Loading branch information
vinaydes authored Apr 23, 2021
1 parent 8cccaf2 commit 2f0af34
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
7 changes: 4 additions & 3 deletions cpp/src/decisiontree/batched-levelalgo/kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,8 @@ __global__ void computeSplitClassificationKernel(
IdxT min_samples_leaf, DataT min_impurity_decrease, IdxT max_leaves,
Input<DataT, LabelT, IdxT> input, const Node<DataT, LabelT, IdxT>* nodes,
IdxT colStart, int* done_count, int* mutex, const IdxT* n_leaves,
Split<DataT, IdxT>* splits, CRITERION splitType, IdxT treeid, uint64_t seed) {
volatile Split<DataT, IdxT>* splits, CRITERION splitType, IdxT treeid,
uint64_t seed) {
extern __shared__ char smem[];
IdxT nid = blockIdx.z;
auto node = nodes[nid];
Expand Down Expand Up @@ -514,8 +515,8 @@ __global__ void computeSplitRegressionKernel(
DataT min_impurity_decrease, IdxT max_leaves,
Input<DataT, LabelT, IdxT> input, const Node<DataT, LabelT, IdxT>* nodes,
IdxT colStart, int* done_count, int* mutex, const IdxT* n_leaves,
Split<DataT, IdxT>* splits, void* workspace, CRITERION splitType, IdxT treeid,
uint64_t seed) {
volatile Split<DataT, IdxT>* splits, void* workspace, CRITERION splitType,
IdxT treeid, uint64_t seed) {
extern __shared__ char smem[];
IdxT nid = blockIdx.z;
auto node = nodes[nid];
Expand Down
18 changes: 12 additions & 6 deletions cpp/src/decisiontree/batched-levelalgo/split.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ struct Split {
*
* @return the reference to the copied object (typically useful for chaining)
*/
DI SplitT& operator=(const SplitT& other) {
DI volatile SplitT& operator=(const SplitT& other) volatile {
quesval = other.quesval;
colid = other.colid;
best_metric_val = other.best_metric_val;
Expand All @@ -73,11 +73,17 @@ struct Split {
/**
* @brief updates the current split if the input gain is better
*/
DI void update(const SplitT& other) {
if (other.best_metric_val == best_metric_val) {
if (other.colid < colid) *this = other;
} else if (other.best_metric_val > best_metric_val) {
DI void update(const SplitT& other) volatile {
if (other.best_metric_val > best_metric_val) {
*this = other;
} else if (other.best_metric_val == best_metric_val) {
if (other.colid > colid) {
*this = other;
} else if (other.colid == colid) {
if (other.quesval > quesval) {
*this = other;
}
}
}
}

Expand Down Expand Up @@ -107,7 +113,7 @@ struct Split {
* @note all threads in the block must enter this function together. At the
* end thread0 will contain the best split.
*/
DI void evalBestSplit(void* smem, SplitT* split, int* mutex) {
DI void evalBestSplit(void* smem, volatile SplitT* split, int* mutex) {
auto* sbest = reinterpret_cast<SplitT*>(smem);
warpReduce();
auto warp = threadIdx.x / raft::WarpSize;
Expand Down

0 comments on commit 2f0af34

Please sign in to comment.