From 45a5a5430685518371914dd0931ed00168bedbcc Mon Sep 17 00:00:00 2001
From: Venkat
Date: Thu, 23 Sep 2021 02:39:50 +0530
Subject: [PATCH] RF: Add Poisson deviance impurity criterion (#4156)

* Adds the Poisson impurity criterion to RF, in parity with scikit-learn's RF
  regressor [[here](https://scikit-learn.org/stable/modules/tree.html#regression-criteria)]

EDIT:

* Also adds C++-level testing of the RF objective-function gains for Poisson and Gini.

Authors:
  - Venkat (https://github.com/venkywonka)

Approvers:
  - Rory Mitchell (https://github.com/RAMitchell)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: https://github.com/rapidsai/cuml/pull/4156
---
 cpp/include/cuml/tree/algo_helper.h           |   1 +
 .../batched-levelalgo/builder.cuh             |   4 +-
 .../batched-levelalgo/metrics.cuh             | 356 +++++++++++-------
 cpp/src/decisiontree/decisiontree.cuh         |  13 +
 .../sg/decisiontree_batchedlevel_unittest.cu  |   2 +-
 cpp/test/sg/rf_test.cu                        | 229 ++++++++++-
 .../dask/ensemble/randomforestclassifier.py   |   5 +-
 .../dask/ensemble/randomforestregressor.py    |   6 +-
 python/cuml/ensemble/randomforest_common.pyx  |   3 +-
 python/cuml/ensemble/randomforest_shared.pxd  |   1 +
 .../cuml/ensemble/randomforestregressor.pyx   |   5 +-
 python/cuml/test/test_random_forest.py        |  40 +-
 12 files changed, 514 insertions(+), 151 deletions(-)
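Not part of the patch — a minimal usage sketch of the new criterion for reviewers. It mirrors the parameters exercised by `test_poisson_convergence` in this PR; the dataset and hyperparameters here are illustrative only.

```python
import numpy as np
from cuml.ensemble import RandomForestRegressor

# toy count-valued targets, as in the new convergence test
X = np.random.random((10000, 4)).astype(np.float32)
y = np.random.poisson(lam=0.5, size=10000).astype(np.float32)

# split_criterion accepts 4 or 'poisson' after this change
model = RandomForestRegressor(split_criterion='poisson',
                              max_depth=4,
                              n_estimators=10)
preds = model.fit(X, y).predict(X)
```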
diff --git a/cpp/include/cuml/tree/algo_helper.h b/cpp/include/cuml/tree/algo_helper.h
index 28b4ac0e5d..ae7aa9b9d1 100644
--- a/cpp/include/cuml/tree/algo_helper.h
+++ b/cpp/include/cuml/tree/algo_helper.h
@@ -22,6 +22,7 @@ enum CRITERION {
   ENTROPY,
   MSE,
   MAE,
+  POISSON,
   CRITERION_END,
 };

diff --git a/cpp/src/decisiontree/batched-levelalgo/builder.cuh b/cpp/src/decisiontree/batched-levelalgo/builder.cuh
index 2620f8cb5b..19bebcfa93 100644
--- a/cpp/src/decisiontree/batched-levelalgo/builder.cuh
+++ b/cpp/src/decisiontree/batched-levelalgo/builder.cuh
@@ -422,7 +422,7 @@ struct Builder {
     int nHistBins = large_blocks * nbins * colBlks * nclasses;
     CUDA_CHECK(cudaMemsetAsync(hist, 0, sizeof(BinT) * nHistBins, handle.get_stream()));
     ML::PUSH_RANGE("computeSplitClassificationKernel @builder_base.cuh [batched-levelalgo]");
-    ObjectiveT objective(input.numOutputs, params.min_impurity_decrease, params.min_samples_leaf);
+    ObjectiveT objective(input.numOutputs, params.min_samples_leaf);
     computeSplitKernel<<>>(hist,
                            params.n_bins,
@@ -456,7 +456,7 @@ struct Builder {
     rmm::device_uvector d_instance_ranges(max_batch_size, handle.get_stream());
     rmm::device_uvector d_leaves(max_batch_size * input.numOutputs, handle.get_stream());
-    ObjectiveT objective(input.numOutputs, params.min_impurity_decrease, params.min_samples_leaf);
+    ObjectiveT objective(input.numOutputs, params.min_samples_leaf);
     for (std::size_t batch_begin = 0; batch_begin < tree->sparsetree.size();
          batch_begin += max_batch_size) {
       std::size_t batch_end = min(batch_begin + max_batch_size, tree->sparsetree.size());

diff --git a/cpp/src/decisiontree/batched-levelalgo/metrics.cuh b/cpp/src/decisiontree/batched-levelalgo/metrics.cuh
index a9861f56fb..fdcf8c18df 100644
--- a/cpp/src/decisiontree/batched-levelalgo/metrics.cuh
+++ b/cpp/src/decisiontree/batched-levelalgo/metrics.cuh
@@ -26,21 +26,54 @@ namespace ML {
 namespace DT {

-struct IntBin {
+struct CountBin {
   int x;
+  CountBin(CountBin const&) = default;
+  HDI CountBin(int x_) : x(x_) {}
+  HDI CountBin() : x(0) {}

-  DI static void IncrementHistogram(IntBin* hist, int nbins, int b, int label)
+  DI static void IncrementHistogram(CountBin* hist, int nbins, int b, int label)
   {
     auto offset = label * nbins + b;
-    IntBin::AtomicAdd(hist + offset, {1});
+    CountBin::AtomicAdd(hist + offset, {1});
   }
-  DI static void AtomicAdd(IntBin* address, IntBin val) { atomicAdd(&address->x, val.x); }
-  DI IntBin& operator+=(const IntBin& b)
+  DI static void AtomicAdd(CountBin* address, CountBin val) { atomicAdd(&address->x, val.x); }
+  HDI CountBin& operator+=(const CountBin& b)
   {
     x += b.x;
     return *this;
   }
-  DI IntBin operator+(IntBin b) const
+  HDI CountBin operator+(CountBin b) const
   {
     b += *this;
     return b;
   }
 };
+
+struct AggregateBin {
+  double label_sum;
+  int count;
+
+  AggregateBin(AggregateBin const&) = default;
+  HDI AggregateBin() : label_sum(0.0), count(0) {}
+  HDI AggregateBin(double label_sum, int count) : label_sum(label_sum), count(count) {}
+
+  DI static void IncrementHistogram(AggregateBin* hist, int nbins, int b, double label)
+  {
+    AggregateBin::AtomicAdd(hist + b, {label, 1});
+  }
+  DI static void AtomicAdd(AggregateBin* address, AggregateBin val)
+  {
+    atomicAdd(&address->label_sum, val.label_sum);
+    atomicAdd(&address->count, val.count);
+  }
+  HDI AggregateBin& operator+=(const AggregateBin& b)
+  {
+    label_sum += b.label_sum;
+    count += b.count;
+    return *this;
+  }
+  HDI AggregateBin operator+(AggregateBin b) const
+  {
+    b += *this;
+    return b;
+  }
+};
@@ -54,59 +87,63 @@ class GiniObjectiveFunction {
   using LabelT = LabelT_;
   using IdxT   = IdxT_;
   IdxT nclasses;
-  DataT min_impurity_decrease;
   IdxT min_samples_leaf;

  public:
-  using BinT = IntBin;
-  GiniObjectiveFunction(IdxT nclasses, DataT min_impurity_decrease, IdxT min_samples_leaf)
-    : nclasses(nclasses),
-      min_impurity_decrease(min_impurity_decrease),
-      min_samples_leaf(min_samples_leaf)
+  using BinT = CountBin;
+  GiniObjectiveFunction(IdxT nclasses, IdxT min_samples_leaf)
+    : nclasses(nclasses), min_samples_leaf(min_samples_leaf)
   {
   }
   DI IdxT NumClasses() const { return nclasses; }
-  DI Split<DataT, IdxT> Gain(BinT* scdf_labels, DataT* sbins, IdxT col, IdxT len, IdxT nbins)
+
+  HDI DataT GainPerSplit(BinT* hist, IdxT i, IdxT nbins, IdxT len, IdxT nLeft)
   {
-    Split<DataT, IdxT> sp;
+    auto nRight = len - nLeft;
     constexpr DataT One = DataT(1.0);
-    DataT invlen = One / len;
+    auto invlen   = One / len;
+    auto invLeft  = One / nLeft;
+    auto invRight = One / nRight;
+    auto gain     = DataT(0.0);
+
+    // if there aren't enough samples in this split, don't bother!
+    if (nLeft < min_samples_leaf || nRight < min_samples_leaf)
+      return -std::numeric_limits<DataT>::max();
+
+    for (IdxT j = 0; j < nclasses; ++j) {
+      int val_i   = 0;
+      auto lval_i = hist[nbins * j + i].x;
+      auto lval   = DataT(lval_i);
+      gain += lval * invLeft * lval * invlen;
+
+      val_i += lval_i;
+      auto total_sum = hist[nbins * j + nbins - 1].x;
+      auto rval_i    = total_sum - lval_i;
+      auto rval      = DataT(rval_i);
+      gain += rval * invRight * rval * invlen;
+
+      val_i += rval_i;
+      auto val = DataT(val_i) * invlen;
+      gain -= val * val;
+    }
+
+    return gain;
+  }
+
+  DI Split<DataT, IdxT> Gain(BinT* shist, DataT* sbins, IdxT col, IdxT len, IdxT nbins)
+  {
+    Split<DataT, IdxT> sp;
     for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
       IdxT nLeft = 0;
       for (IdxT j = 0; j < nclasses; ++j) {
-        nLeft += scdf_labels[nbins * j + i].x;
-      }
-      auto nRight = len - nLeft;
-      auto gain   = DataT(0.0);
-      // if there aren't enough samples in this split, don't bother!
-      if (nLeft < min_samples_leaf || nRight < min_samples_leaf) {
-        gain = -std::numeric_limits<DataT>::max();
-      } else {
-        auto invLeft  = One / nLeft;
-        auto invRight = One / nRight;
-        for (IdxT j = 0; j < nclasses; ++j) {
-          int val_i   = 0;
-          auto lval_i = scdf_labels[nbins * j + i].x;
-          auto lval   = DataT(lval_i);
-          gain += lval * invLeft * lval * invlen;
-
-          val_i += lval_i;
-          auto total_sum = scdf_labels[nbins * j + nbins - 1].x;
-          auto rval_i    = total_sum - lval_i;
-          auto rval      = DataT(rval_i);
-          gain += rval * invRight * rval * invlen;
-
-          val_i += rval_i;
-          auto val = DataT(val_i) * invlen;
-          gain -= val * val;
-        }
+        nLeft += shist[nbins * j + i].x;
       }
-      sp.update({sbins[i], col, gain, nLeft});
+      sp.update({sbins[i], col, GainPerSplit(shist, i, nbins, len, nLeft), nLeft});
     }
     return sp;
   }

-  static DI void SetLeafVector(BinT* shist, int nclasses, DataT* out)
+  static DI void SetLeafVector(BinT const* shist, int nclasses, DataT* out)
   {
     // Output probability
     int total = 0;
@@ -126,64 +163,67 @@ class EntropyObjectiveFunction {
   using LabelT = LabelT_;
   using IdxT   = IdxT_;
   IdxT nclasses;
-  DataT min_impurity_decrease;
   IdxT min_samples_leaf;

  public:
-  using BinT = IntBin;
-  EntropyObjectiveFunction(IdxT nclasses, DataT min_impurity_decrease, IdxT min_samples_leaf)
-    : nclasses(nclasses),
-      min_impurity_decrease(min_impurity_decrease),
-      min_samples_leaf(min_samples_leaf)
+  using BinT = CountBin;
+  EntropyObjectiveFunction(IdxT nclasses, IdxT min_samples_leaf)
+    : nclasses(nclasses), min_samples_leaf(min_samples_leaf)
   {
   }
   DI IdxT NumClasses() const { return nclasses; }
+
+  HDI DataT GainPerSplit(BinT const* hist, IdxT i, IdxT nbins, IdxT len, IdxT nLeft)
+  {
+    auto nRight{len - nLeft};
+    auto gain{DataT(0.0)};
+    // if there aren't enough samples in this split, don't bother!
+    if (nLeft < min_samples_leaf || nRight < min_samples_leaf) {
+      return -std::numeric_limits<DataT>::max();
+    } else {
+      auto invLeft{DataT(1.0) / nLeft};
+      auto invRight{DataT(1.0) / nRight};
+      auto invLen{DataT(1.0) / len};
+      for (IdxT c = 0; c < nclasses; ++c) {
+        int val_i   = 0;
+        auto lval_i = hist[nbins * c + i].x;
+        if (lval_i != 0) {
+          auto lval = DataT(lval_i);
+          gain += raft::myLog(lval * invLeft) / raft::myLog(DataT(2)) * lval * invLen;
+        }
+
+        val_i += lval_i;
+        auto total_sum = hist[nbins * c + nbins - 1].x;
+        auto rval_i    = total_sum - lval_i;
+        if (rval_i != 0) {
+          auto rval = DataT(rval_i);
+          gain += raft::myLog(rval * invRight) / raft::myLog(DataT(2)) * rval * invLen;
+        }
+
+        val_i += rval_i;
+        if (val_i != 0) {
+          auto val = DataT(val_i) * invLen;
+          gain -= val * raft::myLog(val) / raft::myLog(DataT(2));
+        }
+      }
+
+      return gain;
+    }
+  }
+
   DI Split<DataT, IdxT> Gain(BinT* scdf_labels, DataT* sbins, IdxT col, IdxT len, IdxT nbins)
   {
     Split<DataT, IdxT> sp;
-    constexpr DataT One = DataT(1.0);
-    DataT invlen        = One / len;
     for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
       IdxT nLeft = 0;
       for (IdxT j = 0; j < nclasses; ++j) {
         nLeft += scdf_labels[nbins * j + i].x;
       }
-      auto nRight = len - nLeft;
-      auto gain   = DataT(0.0);
-      // if there aren't enough samples in this split, don't bother!
-      if (nLeft < min_samples_leaf || nRight < min_samples_leaf) {
-        gain = -std::numeric_limits<DataT>::max();
-      } else {
-        auto invLeft  = One / nLeft;
-        auto invRight = One / nRight;
-        for (IdxT j = 0; j < nclasses; ++j) {
-          int val_i   = 0;
-          auto lval_i = scdf_labels[nbins * j + i].x;
-          if (lval_i != 0) {
-            auto lval = DataT(lval_i);
-            gain += raft::myLog(lval * invLeft) / raft::myLog(DataT(2)) * lval * invlen;
-          }
-
-          val_i += lval_i;
-          auto total_sum = scdf_labels[nbins * j + nbins - 1].x;
-          auto rval_i    = total_sum - lval_i;
-          if (rval_i != 0) {
-            auto rval = DataT(rval_i);
-            gain += raft::myLog(rval * invRight) / raft::myLog(DataT(2)) * rval * invlen;
-          }
-
-          val_i += rval_i;
-          if (val_i != 0) {
-            auto val = DataT(val_i) * invlen;
-            gain -= val * raft::myLog(val) / raft::myLog(DataT(2));
-          }
-        }
-      }
-      sp.update({sbins[i], col, gain, nLeft});
+      sp.update({sbins[i], col, GainPerSplit(scdf_labels, i, nbins, len, nLeft), nLeft});
     }
     return sp;
   }

-  static DI void SetLeafVector(BinT* shist, int nclasses, DataT* out)
+  static DI void SetLeafVector(BinT const* shist, int nclasses, DataT* out)
   {
     // Output probability
     int total = 0;
@@ -197,74 +237,130 @@ class EntropyObjectiveFunction {
 };

 template <typename DataT_, typename LabelT_, typename IdxT_>
-class MSEObjectiveFunction {
+class PoissonObjectiveFunction {
  public:
   using DataT  = DataT_;
   using LabelT = LabelT_;
   using IdxT   = IdxT_;

  private:
-  DataT min_impurity_decrease;
   IdxT min_samples_leaf;

  public:
-  struct MSEBin {
-    double label_sum;
-    int count;
+  using BinT = AggregateBin;
+  static constexpr auto eps_ = 10 * std::numeric_limits<DataT>::epsilon();

-    DI static void IncrementHistogram(MSEBin* hist, int nbins, int b, double label)
-    {
-      MSEBin::AtomicAdd(hist + b, {label, 1});
-    }
-    DI static void AtomicAdd(MSEBin* address, MSEBin val)
-    {
-      atomicAdd(&address->label_sum, val.label_sum);
-      atomicAdd(&address->count, val.count);
-    }
-    DI MSEBin& operator+=(const MSEBin& b)
-    {
-      label_sum += b.label_sum;
-      count += b.count;
-      return *this;
-    }
-    DI MSEBin operator+(MSEBin b) const
-    {
-      b += *this;
-      return b;
-    }
-  };
+  HDI PoissonObjectiveFunction(IdxT nclasses, IdxT min_samples_leaf)
+    : min_samples_leaf(min_samples_leaf)
+  {
+  }
+  DI IdxT NumClasses() const { return 1; }
+
+  /**
+   * @brief Compute the Poisson impurity reduction (or purity gain) for each split.
+   *
+   * @note This method speeds up the search for the best split by computing the
+   *       gain as a proxy Poisson half-deviance reduction. It is a proxy
+   *       quantity such that the split that maximizes this value also
+   *       maximizes the impurity improvement; it neglects all constant terms
+   *       of the impurity decrease for a given split. The gain is the
+   *       difference between the proxy impurity of the parent and the
+   *       weighted sum of the proxy impurities of its children.
+   */
+  HDI DataT GainPerSplit(BinT const* hist, IdxT i, IdxT nbins, IdxT len, IdxT nLeft)
+  {
+    // number of samples in the right child
+    auto nRight = len - nLeft;
+
+    // if there aren't enough samples in this split, don't bother!
+    if (nLeft < min_samples_leaf || nRight < min_samples_leaf)
+      return -std::numeric_limits<DataT>::max();
+
+    auto label_sum       = hist[nbins - 1].label_sum;
+    auto left_label_sum  = (hist[i].label_sum);
+    auto right_label_sum = (hist[nbins - 1].label_sum - hist[i].label_sum);
+
+    // a non-positive label sum in any node makes the gain undefined
+    if (label_sum < eps_ || left_label_sum < eps_ || right_label_sum < eps_)
+      return -std::numeric_limits<DataT>::max();
+
+    // gain = (parent objective - sum of child objectives) / len
+    DataT parent_obj = -label_sum * raft::myLog(label_sum / len);
+    DataT left_obj   = -left_label_sum * raft::myLog(left_label_sum / nLeft);
+    DataT right_obj  = -right_label_sum * raft::myLog(right_label_sum / nRight);
+    auto gain        = parent_obj - (left_obj + right_obj);
+    gain             = gain / len;
+
+    return gain;
+  }
+
+  DI Split<DataT, IdxT> Gain(BinT const* shist, DataT const* sbins, IdxT col, IdxT len, IdxT nbins)
+  {
+    Split<DataT, IdxT> sp;
+    for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
+      auto nLeft = shist[i].count;
+      sp.update({sbins[i], col, GainPerSplit(shist, i, nbins, len, nLeft), nLeft});
+    }
+    return sp;
+  }
+
+  static DI void SetLeafVector(BinT const* shist, int nclasses, DataT* out)
+  {
+    for (int i = 0; i < nclasses; i++) {
+      out[i] = shist[i].label_sum / shist[i].count;
+    }
+  }
+};
+
+template <typename DataT_, typename LabelT_, typename IdxT_>
+class MSEObjectiveFunction {
+ public:
+  using DataT  = DataT_;
+  using LabelT = LabelT_;
+  using IdxT   = IdxT_;
+
+ private:
+  IdxT min_samples_leaf;
+
+ public:
+  using BinT = AggregateBin;
+  HDI MSEObjectiveFunction(IdxT nclasses, IdxT min_samples_leaf)
+    : min_samples_leaf(min_samples_leaf)
   {
   }
   DI IdxT NumClasses() const { return 1; }
-  DI Split<DataT, IdxT> Gain(BinT* shist, DataT* sbins, IdxT col, IdxT len, IdxT nbins)
+
+  HDI DataT GainPerSplit(BinT const* hist, IdxT i, IdxT nbins, IdxT len, IdxT nLeft)
+  {
+    auto gain{DataT(0)};
+    auto nRight{len - nLeft};
+    auto invLen{DataT(1.0) / len};
+    // if there aren't enough samples in this split, don't bother!
+    if (nLeft < min_samples_leaf || nRight < min_samples_leaf) {
+      return -std::numeric_limits<DataT>::max();
+    } else {
+      auto label_sum   = hist[nbins - 1].label_sum;
+      DataT parent_obj = -label_sum * label_sum / len;
+      DataT left_obj   = -(hist[i].label_sum * hist[i].label_sum) / nLeft;
+      // left sum minus total sum: the negative of the right-child label sum,
+      // which is immaterial since the term is squared below
+      DataT right_label_sum = hist[i].label_sum - label_sum;
+      DataT right_obj       = -(right_label_sum * right_label_sum) / nRight;
+      gain                  = parent_obj - (left_obj + right_obj);
+      gain *= invLen;
+
+      return gain;
+    }
+  }
+
+  DI Split<DataT, IdxT> Gain(BinT const* shist, DataT const* sbins, IdxT col, IdxT len, IdxT nbins)
   {
     Split<DataT, IdxT> sp;
-    auto invlen = DataT(1.0) / len;
     for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
-      auto nLeft  = shist[i].count;
-      auto nRight = len - nLeft;
-      DataT gain;
-      // if there aren't enough samples in this split, don't bother!
-      if (nLeft < min_samples_leaf || nRight < min_samples_leaf) {
-        gain = -std::numeric_limits<DataT>::max();
-      } else {
-        auto label_sum        = shist[nbins - 1].label_sum;
-        DataT parent_obj      = -label_sum * label_sum / len;
-        DataT left_obj        = -(shist[i].label_sum * shist[i].label_sum) / nLeft;
-        DataT right_label_sum = shist[i].label_sum - label_sum;
-        DataT right_obj       = -(right_label_sum * right_label_sum) / nRight;
-        gain                  = parent_obj - (left_obj + right_obj);
-        gain *= invlen;
-      }
-      sp.update({sbins[i], col, gain, nLeft});
+      auto nLeft = shist[i].count;
+      sp.update({sbins[i], col, GainPerSplit(shist, i, nbins, len, nLeft), nLeft});
     }
     return sp;
   }

-  static DI void SetLeafVector(BinT* shist, int nclasses, DataT* out)
+  static DI void SetLeafVector(BinT const* shist, int nclasses, DataT* out)
   {
     for (int i = 0; i < nclasses; i++) {
       out[i] = shist[i].label_sum / shist[i].count;

diff --git a/cpp/src/decisiontree/decisiontree.cuh b/cpp/src/decisiontree/decisiontree.cuh
index ea63f5251f..e2284cc14a 100644
--- a/cpp/src/decisiontree/decisiontree.cuh
+++ b/cpp/src/decisiontree/decisiontree.cuh
@@ -290,6 +290,19 @@ class DecisionTree {
                  unique_labels,
                  quantiles)
         .train();
+    } else if (params.split_criterion == CRITERION::POISSON) {
+      return Builder<PoissonObjectiveFunction<DataT, LabelT, IdxT>>(handle,
+                                                                    treeid,
+                                                                    seed,
+                                                                    params,
+                                                                    data,
+                                                                    labels,
+                                                                    nrows,
+                                                                    ncols,
+                                                                    rowids,
+                                                                    unique_labels,
+                                                                    quantiles)
+        .train();
     } else {
       ASSERT(false, "Unknown split criterion.");
     }

diff --git a/cpp/test/sg/decisiontree_batchedlevel_unittest.cu b/cpp/test/sg/decisiontree_batchedlevel_unittest.cu
index 9c402bec2c..37b9519b8c 100644
--- a/cpp/test/sg/decisiontree_batchedlevel_unittest.cu
+++ b/cpp/test/sg/decisiontree_batchedlevel_unittest.cu
@@ -279,7 +279,7 @@ TEST_P(TestMetric, RegressionMetricGain)
   CRITERION split_criterion = GetParam();

-  ObjectiveT obj(1, params.min_impurity_decrease, params.min_samples_leaf);
+  ObjectiveT obj(1, params.min_samples_leaf);
   size_t smemSize1 = n_bins * sizeof(ObjectiveT::BinT) +  // shist size
                      n_bins * sizeof(DataT) +             // sbins size
                      sizeof(int);                         // sDone size

diff --git a/cpp/test/sg/rf_test.cu b/cpp/test/sg/rf_test.cu
index ea12fc6969..4f20085a3d 100644
--- a/cpp/test/sg/rf_test.cu
+++ b/cpp/test/sg/rf_test.cu
@@ -332,6 +332,7 @@ class RfSpecialisedTest {
                 raft::ceildiv(params.n_rows, params.min_samples_leaf));
     }
   }
+
   void TestMinImpurity()
   {
     for (int i = 0u; i < forest->rf_params.n_trees; i++) {
     }
   }
-  void TestDeterminism()
+  void TestDeterminism()
   {
     // Regression models use floating point atomics, so are not bitwise reproducible
-    bool is_regression = params.split_criterion == MSE || params.split_criterion == MAE;
+    bool is_regression = params.split_criterion == MSE or params.split_criterion == MAE or
+                         params.split_criterion == POISSON;
     if (is_regression) return;

     // Repeat training
@@ -449,7 +451,8 @@ class RfTest : public ::testing::TestWithParam<RfTestParams> {
   void SetUp() override
   {
     RfTestParams params = ::testing::TestWithParam<RfTestParams>::GetParam();
-    bool is_regression  = params.split_criterion == MSE || params.split_criterion == MAE;
+    bool is_regression  = params.split_criterion == MSE or params.split_criterion == MAE or
+                          params.split_criterion == POISSON;
     if (params.double_precision) {
       if (is_regression) {
         RfSpecialisedTest test(params);
@@ -482,10 +485,11 @@
 std::vector<int> min_samples_leaf         = {1, 10, 30};
 std::vector<int> min_samples_split        = {2, 10};
 std::vector<float> min_impurity_decrease  = {0.0f, 1.0f, 10.0f};
 std::vector<int> n_streams                = {1, 2, 10};
-std::vector<CRITERION> split_criterion = {CRITERION::MSE, CRITERION::GINI, CRITERION::ENTROPY};
-std::vector<int> seed              = {0, 17};
-std::vector<int> n_labels          = {2, 10, 20};
-std::vector<bool> double_precision = {false, true};
+std::vector<CRITERION> split_criterion = {
+  CRITERION::POISSON, CRITERION::MSE, CRITERION::GINI, CRITERION::ENTROPY};
+std::vector<int> seed              = {0, 17};
+std::vector<int> n_labels          = {2, 10, 20};
+std::vector<bool> double_precision = {false, true};

 int n_tests = 100;

@@ -511,6 +515,7 @@ INSTANTIATE_TEST_CASE_P(RfTests,
                                             n_labels,
                                             double_precision)));

+//-------------------------------------------------------------------------------------------------------------------------------------
 struct QuantileTestParameters {
   int n_rows;
   int n_bins;
   uint64_t seed;
 };
@@ -582,7 +587,7 @@ class RFQuantileTest : public ::testing::TestWithParam {
     int min_items_per_bin = max_items_per_bin - 1;
     int total_items       = 0;
     for (int b = 0; b < params.n_bins; b++) {
-      ASSERT_TRUE(h_histogram[b] == max_items_per_bin || h_histogram[b] == min_items_per_bin)
+      ASSERT_TRUE(h_histogram[b] == max_items_per_bin or h_histogram[b] == min_items_per_bin)
         << "No. samples in bin[" << b << "] = " << h_histogram[b] << " Expected "
         << max_items_per_bin << " or " << min_items_per_bin << std::endl;
       total_items += h_histogram[b];
     }
   }
 };

-//-------------------------------------------------------------------------------------------------------------------------------------
 const std::vector<QuantileTestParameters> inputs = {{1000, 16, 6078587519764079670LLU},
                                                     {1130, 32, 4884670006177930266LLU},
                                                     {1752, 67, 9175325892580481371LLU},

 typedef RFQuantileBinsLowerBoundTest RFQuantileBinsLowerBoundTestD;
 TEST_P(RFQuantileBinsLowerBoundTestD, test) {}
 INSTANTIATE_TEST_CASE_P(RfTests, RFQuantileBinsLowerBoundTestD, ::testing::ValuesIn(inputs));

+//------------------------------------------------------------------------------------------------------
+
 TEST(RfTest, TextDump)
 {
   RF_params rf_params = set_rf_params(2, 2, 1.0, 2, 1, 2, 0.0, true, 1, 1.0, 0, GINI, 1, 128);
 Tree #0
   EXPECT_EQ(get_rf_json(forest_ptr), expected_json);
 }

+//-------------------------------------------------------------------------------------------------------------------------------------
+namespace DT {
+
+struct ObjectiveTestParameters {
+  uint64_t seed;
+  int n_bins;
+  int n_classes;
+  int min_samples_leaf;
+  double tolerance;
+};
+
+template <typename ObjectiveT>
+class ObjectiveTest : public ::testing::TestWithParam<ObjectiveTestParameters> {
+  typedef typename ObjectiveT::DataT DataT;
+  typedef typename ObjectiveT::LabelT LabelT;
+  typedef typename ObjectiveT::IdxT IdxT;
+  typedef typename ObjectiveT::BinT BinT;
+
+  ObjectiveTestParameters params;
+
+ public:
+  auto RandUnder(int const end = 10000) { return rand() % end; }
+
+  auto GenHist()
+  {
+    std::vector<BinT> cdf_hist, pdf_hist;
+
+    for (auto c = 0; c < params.n_classes; ++c) {
+      for (auto b = 0; b < params.n_bins; ++b) {
+        if constexpr (std::is_same<BinT, CountBin>::value)
+          pdf_hist.emplace_back(RandUnder());
+        else
+          pdf_hist.emplace_back(static_cast<DataT>(RandUnder()), RandUnder());
+
+        auto cumulative = b > 0 ? cdf_hist.back() : BinT();
+        cdf_hist.emplace_back(pdf_hist.empty() ? BinT() : pdf_hist.back());
+        cdf_hist.back() += cumulative;
+      }
+    }
+
+    return std::make_pair(cdf_hist, pdf_hist);
+  }
+
+  auto PoissonHalfDeviance(
+    std::vector<BinT> const& hist)  // 1/n * sum(y_true * log(y_true/y_pred) + y_pred - y_true)
+  {
+    BinT aggregate{BinT()};
+    aggregate = std::accumulate(hist.begin(), hist.end(), aggregate);
+    assert(aggregate.count > 0);
+    auto const y_mean = aggregate.label_sum / aggregate.count;
+    auto poisson_half_deviance{DataT(0.0)};
+
+    std::for_each(hist.begin(), hist.end(), [&](BinT const& h) {
+      auto log_y = raft::myLog(h.label_sum ? h.label_sum : DataT(1.0));  // avoid log(0) NaNs
+      poisson_half_deviance += h.label_sum * (log_y - raft::myLog(y_mean)) + y_mean - h.label_sum;
+    });
+
+    poisson_half_deviance /= aggregate.count;
+    return std::make_tuple(
+      poisson_half_deviance, aggregate.label_sum, static_cast<double>(aggregate.count));
+  }
+
+  auto PoissonGroundTruthGain(std::vector<BinT> const& pdf_hist, std::size_t split_bin_index)
+  {
+    std::vector<BinT> left_pdf_hist{pdf_hist.begin(), pdf_hist.begin() + split_bin_index + 1};
+    std::vector<BinT> right_pdf_hist{pdf_hist.begin() + split_bin_index + 1, pdf_hist.end()};
+
+    auto [parent_phd, label_sum, n]          = PoissonHalfDeviance(pdf_hist);
+    auto [left_phd, label_sum_left, n_left]  = PoissonHalfDeviance(left_pdf_hist);
+    auto [right_phd, label_sum_right, n_right] = PoissonHalfDeviance(right_pdf_hist);
+
+    auto gain = parent_phd - ((n_left / n) * left_phd +
+                              (n_right / n) * right_phd);  // gain in long form, without the proxy
+
+    // edge cases
+    if (n_left < params.min_samples_leaf or n_right < params.min_samples_leaf or
+        label_sum < ObjectiveT::eps_ or label_sum_right < ObjectiveT::eps_ or
+        label_sum_left < ObjectiveT::eps_)
+      return -std::numeric_limits<double>::max();
+    else
+      return gain;
+  }
+
+  auto GiniImpurity(std::vector<BinT> const& hist)
+  {  // sum((n_c/n_total) * (1 - n_c/n_total))
+    auto gini{double(0)};
+    auto n_bins      = hist.size() / params.n_classes;
+    auto n_instances = std::accumulate(hist.begin(), hist.end(), BinT()).x;  // total instances
+    for (auto c = 0; c < params.n_classes; ++c) {
+      auto begin_iter    = hist.begin() + c * n_bins;
+      auto end_iter      = hist.begin() + (c + 1) * n_bins;
+      double class_proba = std::accumulate(begin_iter, end_iter, BinT()).x;  // instances of class c
+      class_proba /= n_instances;               // probability of class c
+      gini += class_proba * (1 - class_proba);  // add this class's contribution
+    }
+    return std::make_pair(gini, double(n_instances));
+  }
+
+  auto GiniGroundTruthGain(std::vector<BinT> const& pdf_hist, std::size_t const split_bin_index)
+  {
+    std::vector<BinT> left_pdf_hist, right_pdf_hist;
+
+    for (auto c = 0; c < params.n_classes; ++c) {  // decompose the pdf_hist
+      auto start = pdf_hist.begin() + c * params.n_bins;
+      auto split = pdf_hist.begin() + c * params.n_bins + split_bin_index + 1;
+      auto end   = pdf_hist.begin() + (c + 1) * params.n_bins;
+
+      left_pdf_hist.insert(left_pdf_hist.end(), start, split);
+      right_pdf_hist.insert(right_pdf_hist.end(), split, end);
+    }
+
+    auto [parent_gini, n]      = GiniImpurity(pdf_hist);
+    auto [left_gini, left_n]   = GiniImpurity(left_pdf_hist);
+    auto [right_gini, right_n] = GiniImpurity(right_pdf_hist);
+
+    auto gain = parent_gini - ((left_n / n) * left_gini + (right_n / n) * right_gini);
+
+    // edge cases
+    if (left_n < params.min_samples_leaf or right_n < params.min_samples_leaf) {
+      return -std::numeric_limits<double>::max();
+    } else {
+      return gain;
+    }
+  }
+
+  auto GroundTruthGain(std::vector<BinT> const& pdf_hist, std::size_t const split_bin_index)
+  {
+    if constexpr (std::is_same<ObjectiveT,
+                               PoissonObjectiveFunction<DataT, LabelT, IdxT>>::value)  // poisson
+    {
+      return PoissonGroundTruthGain(pdf_hist, split_bin_index);
+    } else if constexpr (std::is_same<ObjectiveT,
+                                      GiniObjectiveFunction<DataT, LabelT, IdxT>>::value)  // gini
+    {
+      return GiniGroundTruthGain(pdf_hist, split_bin_index);
+    }
+    return double(0.0);
+  }
+
+  auto NumLeftOfBin(std::vector<BinT> const& cdf_hist, IdxT idx)
+  {
+    auto count{IdxT(0)};
+    for (auto c = 0; c < params.n_classes; ++c) {
+      if constexpr (std::is_same<BinT, CountBin>::value)  // CountBin
+      {
+        count += cdf_hist[params.n_bins * c + idx].x;
+      } else  // AggregateBin
+      {
+        count += cdf_hist[params.n_bins * c + idx].count;
+      }
+    }
+    return count;
+  }
+
+  void SetUp() override
+  {
+    // fetch the test parameters before seeding with them
+    params = ::testing::TestWithParam<ObjectiveTestParameters>::GetParam();
+    srand(params.seed);
+    ObjectiveT objective(params.n_classes, params.min_samples_leaf);
+
+    auto [cdf_hist, pdf_hist] = GenHist();
+
+    auto split_bin_index   = RandUnder(params.n_bins);
+    auto ground_truth_gain = GroundTruthGain(pdf_hist, split_bin_index);
+
+    auto hypothesis_gain = objective.GainPerSplit(&cdf_hist[0],
+                                                  split_bin_index,
+                                                  params.n_bins,
+                                                  NumLeftOfBin(cdf_hist, params.n_bins - 1),
+                                                  NumLeftOfBin(cdf_hist, split_bin_index));
+
+    ASSERT_NEAR(ground_truth_gain, hypothesis_gain, params.tolerance);
+  }
+};
+
+const std::vector<ObjectiveTestParameters> poisson_objective_test_parameters = {
+  {9507819643927052255LLU, 64, 1, 0, 0.00001},
+  {9507819643927052259LLU, 128, 1, 1, 0.00001},
+  {9507819643927052251LLU, 256, 1, 1, 0.00001},
+  {9507819643927052258LLU, 512, 1, 5, 0.00001},
+};
+const std::vector<ObjectiveTestParameters> gini_objective_test_parameters = {
+  {9507819643927052255LLU, 64, 2, 0, 0.00001},
+  {9507819643927052256LLU, 128, 10, 1, 0.00001},
+  {9507819643927052257LLU, 256, 100, 1, 0.00001},
+  {9507819643927052258LLU, 512, 100, 5, 0.00001},
+};
+
+// poisson objective test
+typedef ObjectiveTest<PoissonObjectiveFunction<double, double, int>> PoissonObjectiveTestD;
+TEST_P(PoissonObjectiveTestD, poissonObjectiveTest) {}
+INSTANTIATE_TEST_CASE_P(RfTests,
+                        PoissonObjectiveTestD,
+                        ::testing::ValuesIn(poisson_objective_test_parameters));
+
+// gini objective test
+typedef ObjectiveTest<GiniObjectiveFunction<double, int, int>> GiniObjectiveTestD;
+TEST_P(GiniObjectiveTestD, giniObjectiveTest) {}
+INSTANTIATE_TEST_CASE_P(RfTests,
+                        GiniObjectiveTestD,
+                        ::testing::ValuesIn(gini_objective_test_parameters));
+
+}  // end namespace DT
 }  // end namespace ML

diff --git a/python/cuml/dask/ensemble/randomforestclassifier.py b/python/cuml/dask/ensemble/randomforestclassifier.py
index e19fe39da7..5cb7da3c98 100755
--- a/python/cuml/dask/ensemble/randomforestclassifier.py
+++ b/python/cuml/dask/ensemble/randomforestclassifier.py
@@ -77,8 +77,9 @@ class RandomForestClassifier(BaseRandomForestModel, DelayedPredictionMixin,
     split_criterion : int or string (default = 0 ('gini'))
         The criterion used to split nodes.
         0 or 'gini' for GINI, 1 or 'entropy' for ENTROPY,
-        2 or 'mse' for MSE
-        2 or 'mse' not valid for classification
+        2 or 'mse' for MSE,
+        4 or 'poisson' for POISSON,
+        2, 'mse' and 4, 'poisson' are not valid for classification
     bootstrap : boolean (default = True)
         Control bootstrapping.
         If set, each tree in the forest is built

diff --git a/python/cuml/dask/ensemble/randomforestregressor.py b/python/cuml/dask/ensemble/randomforestregressor.py
index 4e28dda9f7..846e1cc344 100755
--- a/python/cuml/dask/ensemble/randomforestregressor.py
+++ b/python/cuml/dask/ensemble/randomforestregressor.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
 from cuml.dask.common.base import DelayedPredictionMixin
 from cuml.ensemble import RandomForestRegressor as cuRFR
 from cuml.dask.ensemble.base import \
@@ -71,8 +70,9 @@ class RandomForestRegressor(BaseRandomForestModel, DelayedPredictionMixin,
     split_criterion : int or string (default = 2 ('mse'))
         The criterion used to split nodes.
         0 or 'gini' for GINI, 1 or 'entropy' for ENTROPY,
-        2 or 'mse' for MSE
-        only 2 or 'mse' valid for regression
+        2 or 'mse' for MSE,
+        4 or 'poisson' for POISSON,
+        0, 'gini' and 1, 'entropy' are not valid for regression
     bootstrap : boolean (default = True)
         Control bootstrapping.
         If set, each tree in the forest is built

diff --git a/python/cuml/ensemble/randomforest_common.pyx b/python/cuml/ensemble/randomforest_common.pyx
index 2226d5329e..264aafa084 100644
--- a/python/cuml/ensemble/randomforest_common.pyx
+++ b/python/cuml/ensemble/randomforest_common.pyx
@@ -57,7 +57,8 @@ class BaseRandomForestModel(Base):
                       '1': ENTROPY, 'entropy': ENTROPY,
                       '2': MSE, 'mse': MSE,
                       '3': MAE, 'mae': MAE,
-                      '4': CRITERION_END}
+                      '4': POISSON, 'poisson': POISSON,
+                      '5': CRITERION_END}

     classes_ = CumlArrayDescriptor()

diff --git a/python/cuml/ensemble/randomforest_shared.pxd b/python/cuml/ensemble/randomforest_shared.pxd
index 6e5de318d4..389eec5a45 100644
--- a/python/cuml/ensemble/randomforest_shared.pxd
+++ b/python/cuml/ensemble/randomforest_shared.pxd
@@ -42,6 +42,7 @@ cdef extern from "cuml/ensemble/randomforest.hpp" namespace "ML":
         ENTROPY,
         MSE,
         MAE,
+        POISSON,
         CRITERION_END

 cdef extern from "cuml/ensemble/randomforest.hpp" namespace "ML":

diff --git a/python/cuml/ensemble/randomforestregressor.pyx b/python/cuml/ensemble/randomforestregressor.pyx
index 11d1517eb8..fdb4c9f369 100644
--- a/python/cuml/ensemble/randomforestregressor.pyx
+++ b/python/cuml/ensemble/randomforestregressor.pyx
@@ -164,8 +164,9 @@ class RandomForestRegressor(BaseRandomForestModel,
     split_criterion : int or string (default = 2 ('mse'))
         The criterion used to split nodes.
         0 or 'gini' for GINI, 1 or 'entropy' for ENTROPY,
-        2 or 'mse' for MSE
-        only 2 or 'mse' valid for regression
+        2 or 'mse' for MSE,
+        4 or 'poisson' for POISSON,
+        0, 'gini' and 1, 'entropy' are not valid for regression.
     bootstrap : boolean (default = True)
         Control bootstrapping.
         If True, each tree in the forest is built

diff --git a/python/cuml/test/test_random_forest.py b/python/cuml/test/test_random_forest.py
index db038466e6..9d1d7bb486 100644
--- a/python/cuml/test/test_random_forest.py
+++ b/python/cuml/test/test_random_forest.py
@@ -31,7 +31,8 @@
 from sklearn.ensemble import RandomForestClassifier as skrfc
 from sklearn.ensemble import RandomForestRegressor as skrfr
-from sklearn.metrics import accuracy_score, mean_squared_error
+from sklearn.metrics import accuracy_score, mean_squared_error, \
+    mean_poisson_deviance
 from sklearn.datasets import fetch_california_housing, \
     make_classification, make_regression, load_iris, load_breast_cancer, \
     load_boston
@@ -186,6 +187,43 @@ def special_reg(request):
     return X, y


+@pytest.mark.parametrize("lam", [0.01, 0.1])
+@pytest.mark.parametrize("max_depth", [2, 4])
+def test_poisson_convergence(lam, max_depth):
+    np.random.seed(33)
+    bootstrap = None
+    max_features = 1.0
+    n_estimators = 1
+    min_impurity_decrease = 1e-5
+    n_datapoints = 100000
+    # generate a random Poisson-distributed dataset
+    X = np.random.random((n_datapoints, 4)).astype(np.float32)
+    y = np.random.poisson(lam=lam, size=n_datapoints).astype(np.float32)
+
+    poisson_preds = curfr(
+        split_criterion=4,
+        max_depth=max_depth,
+        n_estimators=n_estimators,
+        bootstrap=bootstrap,
+        max_features=max_features,
+        min_impurity_decrease=min_impurity_decrease).fit(X, y).predict(X)
+    mse_preds = curfr(
+        split_criterion=2,
+        max_depth=max_depth,
+        n_estimators=n_estimators,
+        bootstrap=bootstrap,
+        max_features=max_features,
+        min_impurity_decrease=min_impurity_decrease).fit(X, y).predict(X)
+    # mean_poisson_deviance requires strictly positive predictions
+    mask = mse_preds > 0
+    mse_mpd = mean_poisson_deviance(y[mask], mse_preds[mask])
+    poisson_mpd = mean_poisson_deviance(y, poisson_preds)
+
+    # a model trained on Poisson data with the Poisson criterion
+    # must do at least as well on the Poisson deviance
+    assert mse_mpd >= poisson_mpd
+
+
 @pytest.mark.parametrize(
     "max_samples", [unit_param(1.0), quality_param(0.90), stress_param(0.95)]
 )
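Not part of the patch — the C++ `ObjectiveTest` above validates `GainPerSplit` against the full Poisson half deviance computed from histograms; the same identity can be checked per-sample in a few lines of NumPy. A minimal sketch, assuming node predictions are label means (the function names here are illustrative, not cuML API):

```python
import numpy as np

def poisson_half_deviance(y):
    """1/n * sum(y * log(y / y_hat) + y_hat - y), with y_hat = mean(y)."""
    y_hat = y.mean()  # the node's prediction; requires y.sum() > 0
    term = np.where(y > 0, y * np.log(np.where(y > 0, y, 1.0) / y_hat), 0.0)
    return np.mean(term + y_hat - y)

def proxy_gain(y, n_left):
    """Mirrors PoissonObjectiveFunction::GainPerSplit: each node's objective
    keeps only the -sum(y) * log(mean(y)) part of the deviance."""
    n, s = len(y), y.sum()
    s_l, s_r = y[:n_left].sum(), y[n_left:].sum()
    parent_obj = -s * np.log(s / n)
    left_obj = -s_l * np.log(s_l / n_left)
    right_obj = -s_r * np.log(s_r / (n - n_left))
    return (parent_obj - (left_obj + right_obj)) / n

rng = np.random.default_rng(42)
y = np.sort(rng.poisson(2.0, size=1000).astype(np.float64))  # labels ordered by a split feature
n_left, n = 400, 1000
full_gain = poisson_half_deviance(y) - (
    n_left / n * poisson_half_deviance(y[:n_left])
    + (n - n_left) / n * poisson_half_deviance(y[n_left:]))
assert np.isclose(full_gain, proxy_gain(y, n_left))  # equal up to rounding
```

The dropped terms are constant across candidate splits: the sum of `y * log(y)` is identical for the parent and the union of its children, and the sum of `y_hat - y` vanishes within each node because `y_hat` is that node's mean. Dropping them is what the in-code doc comment calls "neglecting all constant terms of the impurity decrease".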