Skip to content

Commit

Permalink
RF: fix the bug in pdf_to_cdf device function that causes hang when…
Browse files Browse the repository at this point in the history
… `n_bins > TPB && n_bins % TPB != 0` (#3921)

* This mini-(but important)-PR fixes the bug in `pdf_to_cdf` device function that causes hang when `n_bins > TPB && n_bins % TPB != 0`
* This closes #3919

Authors:
  - Venkat (https://github.com/venkywonka)

Approvers:
  - Philip Hyunsu Cho (https://github.com/hcho3)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #3921
  • Loading branch information
venkywonka authored Jun 1, 2021
1 parent 64fb1f4 commit 6cb800a
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions cpp/src/decisiontree/batched-levelalgo/kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,8 @@ DI DataT pdf_to_cdf(DataT* pdf_shist, DataT* cdf_shist, IdxT nbins) {
// variable to accumulate aggregate of sumscans of previous iterations
DataT total_aggregate = DataT(0);

for (IdxT tix = threadIdx.x; tix < max(TPB, nbins); tix += blockDim.x) {
for (IdxT tix = threadIdx.x; tix < raft::ceildiv(nbins, TPB) * TPB;
tix += blockDim.x) {
DataT result;
DataT block_aggregate;
// getting the scanning element from pdf shist only
Expand All @@ -380,8 +381,8 @@ DI DataT pdf_to_cdf(DataT* pdf_shist, DataT* cdf_shist, IdxT nbins) {
// store the result in cdf shist
if (tix < nbins) {
cdf_shist[tix] = result + total_aggregate;
total_aggregate += block_aggregate;
}
total_aggregate += block_aggregate;
}
// return the total sum
return total_aggregate;
Expand Down

0 comments on commit 6cb800a

Please sign in to comment.