diff --git a/cpp/src/decisiontree/levelalgo/levelhelper_classifier.cuh b/cpp/src/decisiontree/levelalgo/levelhelper_classifier.cuh index 038ccbdf29..4f0cde99f6 100644 --- a/cpp/src/decisiontree/levelalgo/levelhelper_classifier.cuh +++ b/cpp/src/decisiontree/levelalgo/levelhelper_classifier.cuh @@ -297,7 +297,10 @@ void leaf_eval_classification( sparse_nodelist.clear(); int non_leaf_counter = 0; - bool condition_global = (curr_depth == max_depth); + // XXX: This line fixes the inaccuracy in max_depth==1 tree leaf predictions + // I still don't fully understand its role, but it seems interesting + bool condition_global = + curr_depth >= max_depth - 1; // XXX removed: (curr_depth == max_depth); if (max_leaves != -1) condition_global = condition_global || (tree_leaf_cnt >= max_leaves);