Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] RF: Add Gamma and Inverse Gaussian loss criteria #4216

Merged
merged 42 commits into from
Oct 12, 2021
Merged
Show file tree
Hide file tree
Changes from 41 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
ceee023
add poisson deviance loss
venkywonka Aug 11, 2021
a40c323
sign bug fix
venkywonka Aug 12, 2021
8cd1ce1
modify proxy impurity, refactor tests, clang fix
venkywonka Aug 19, 2021
c185c80
Merge branch 'branch-21.10' of https://github.com/rapidsai/cuml into …
venkywonka Aug 24, 2021
dca32f9
add tests for poisson & gini objectives, bug fixes and other refactors
venkywonka Aug 31, 2021
6039045
Merge branch 'branch-21.10' of https://github.com/rapidsai/cuml into …
venkywonka Aug 31, 2021
925116d
FIX clang format
venkywonka Aug 31, 2021
3142caf
FIX clang format
venkywonka Aug 31, 2021
9676818
remove debug code
venkywonka Aug 31, 2021
c52c29f
address review comments
venkywonka Sep 2, 2021
36615c3
Merge branch 'branch-21.10' of https://github.com/rapidsai/cuml into …
venkywonka Sep 3, 2021
c0c5948
Merge branch 'branch-21.10' of https://github.com/rapidsai/cuml into …
venkywonka Sep 7, 2021
79f00b8
add python level test
venkywonka Sep 11, 2021
13c3386
FIX clang format
venkywonka Sep 13, 2021
0332cc6
flake fix, reduce test load
venkywonka Sep 13, 2021
0a5d52a
fix tests, remove artifacts
venkywonka Sep 13, 2021
3255323
Merge branch 'branch-21.10' of https://github.com/rapidsai/cuml into …
venkywonka Sep 13, 2021
959ee2c
purge artifacts
venkywonka Sep 13, 2021
5a5410e
decrease tolerance
venkywonka Sep 13, 2021
59caf11
remove min_impurity_decrease member
venkywonka Sep 16, 2021
fd42fb7
fix accuracy bug and dask docstring duplication
venkywonka Sep 17, 2021
9247988
Merge branch 'branch-21.10' of https://github.com/rapidsai/cuml into …
venkywonka Sep 17, 2021
a31512d
fix doctring slip
venkywonka Sep 17, 2021
493f847
merge resolution
venkywonka Sep 17, 2021
aec9d26
merge with poisson branch
venkywonka Sep 20, 2021
db09e0f
add tweedie losses
venkywonka Sep 21, 2021
e63754a
refactor unit tests
venkywonka Sep 22, 2021
2e14991
Merge branch 'branch-21.10' of https://github.com/rapidsai/cuml into …
venkywonka Sep 23, 2021
78b0ffd
add tests for entropy and mse
venkywonka Sep 24, 2021
1fbff95
Merge branch 'branch-21.10' of https://github.com/rapidsai/cuml into …
venkywonka Sep 30, 2021
5f22047
Merge branch 'branch-21.10' of https://github.com/rapidsai/cuml into …
venkywonka Sep 30, 2021
43e5b71
Merge branch 'branch-21.12' of https://github.com/rapidsai/cuml into …
venkywonka Oct 4, 2021
11b2f4e
add python tests and refactor objectives
venkywonka Oct 4, 2021
2fa43d7
FIX clang format
venkywonka Oct 4, 2021
87395ff
reduce division operations
venkywonka Oct 5, 2021
8464628
flake fix and change criterion_dict
venkywonka Oct 5, 2021
d764562
make objective data members private
venkywonka Oct 5, 2021
68ecabb
refactor declaration
venkywonka Oct 6, 2021
6eeeac0
Merge branch 'branch-21.12' of https://github.com/rapidsai/cuml into …
venkywonka Oct 6, 2021
b1be698
fix improper merge
venkywonka Oct 6, 2021
57dddaf
update datapoints and args in pytest
venkywonka Oct 7, 2021
651d827
add documentation for other GainPerSplits
venkywonka Oct 11, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cpp/include/cuml/tree/algo_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ enum CRITERION {
MSE,
MAE,
POISSON,
GAMMA,
INVERSE_GAUSSIAN,
CRITERION_END,
};

Expand Down
241 changes: 201 additions & 40 deletions cpp/src/decisiontree/batched-levelalgo/metrics.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ class GiniObjectiveFunction {
using DataT = DataT_;
using LabelT = LabelT_;
using IdxT = IdxT_;

private:
IdxT nclasses;
IdxT min_samples_leaf;

Expand All @@ -100,9 +102,9 @@ class GiniObjectiveFunction {

HDI DataT GainPerSplit(BinT* hist, IdxT i, IdxT nbins, IdxT len, IdxT nLeft)
{
auto nRight = len - nLeft;
IdxT nRight = len - nLeft;
constexpr DataT One = DataT(1.0);
auto invlen = One / len;
auto invLen = One / len;
auto invLeft = One / nLeft;
auto invRight = One / nRight;
auto gain = DataT(0.0);
Expand All @@ -115,16 +117,16 @@ class GiniObjectiveFunction {
int val_i = 0;
auto lval_i = hist[nbins * j + i].x;
auto lval = DataT(lval_i);
gain += lval * invLeft * lval * invlen;
gain += lval * invLeft * lval * invLen;

val_i += lval_i;
auto total_sum = hist[nbins * j + nbins - 1].x;
auto rval_i = total_sum - lval_i;
auto rval = DataT(rval_i);
gain += rval * invRight * rval * invlen;
gain += rval * invRight * rval * invLen;

val_i += rval_i;
auto val = DataT(val_i) * invlen;
auto val = DataT(val_i) * invLen;
gain -= val * val;
}

Expand Down Expand Up @@ -162,6 +164,8 @@ class EntropyObjectiveFunction {
using DataT = DataT_;
using LabelT = LabelT_;
using IdxT = IdxT_;

private:
IdxT nclasses;
IdxT min_samples_leaf;

Expand All @@ -175,7 +179,7 @@ class EntropyObjectiveFunction {

HDI DataT GainPerSplit(BinT const* hist, IdxT i, IdxT nbins, IdxT len, IdxT nLeft)
{
auto nRight{len - nLeft};
IdxT nRight{len - nLeft};
auto gain{DataT(0.0)};
// if there aren't enough samples in this split, don't bother!
if (nLeft < min_samples_leaf || nRight < min_samples_leaf) {
Expand Down Expand Up @@ -236,25 +240,83 @@ class EntropyObjectiveFunction {
}
};

/**
 * @brief Objective for regression trees using the (proxy) mean-squared-error
 *        criterion, evaluated over per-column cumulative histograms.
 */
template <typename DataT_, typename LabelT_, typename IdxT_>
class MSEObjectiveFunction {
 public:
  using DataT  = DataT_;
  using LabelT = LabelT_;
  using IdxT   = IdxT_;
  using BinT   = AggregateBin;

 private:
  IdxT min_samples_leaf;  // minimum samples each child must retain for a valid split

 public:
  // `nclasses` is accepted for interface parity with the classification
  // objectives; regression has a single output, so it is unused here.
  HDI MSEObjectiveFunction(IdxT nclasses, IdxT min_samples_leaf)
    : min_samples_leaf(min_samples_leaf)
  {
  }

  /**
   * @brief Compute the proxy MSE impurity reduction (purity gain) for the
   *        candidate split at bin `i`.
   *
   * @param hist  cumulative histogram for this column; hist[nbins - 1] holds
   *              the totals over all samples in the node
   * @param i     candidate split bin index
   * @param nbins number of bins in the histogram
   * @param len   number of samples in the parent node
   * @param nLeft number of samples that would go to the left child
   * @return proxy gain, or -DataT max to reject the split outright
   */
  HDI DataT GainPerSplit(BinT const* hist, IdxT i, IdxT nbins, IdxT len, IdxT nLeft) const
  {
    auto gain{DataT(0)};
    IdxT nRight{len - nLeft};
    auto invLen = DataT(1.0) / len;
    // if there aren't enough samples in this split, don't bother!
    if (nLeft < min_samples_leaf || nRight < min_samples_leaf) {
      return -std::numeric_limits<DataT>::max();
    } else {
      auto label_sum   = hist[nbins - 1].label_sum;
      DataT parent_obj = -label_sum * label_sum * invLen;
      DataT left_obj   = -(hist[i].label_sum * hist[i].label_sum) / nLeft;
      // FIX: right partition sum is total - left; the previous `left - total`
      // form gave the negated value (harmless only because it is squared).
      DataT right_label_sum = label_sum - hist[i].label_sum;
      DataT right_obj       = -(right_label_sum * right_label_sum) / nRight;
      gain                  = parent_obj - (left_obj + right_obj);
      // normalize by sample count; the 1/2 factor matches the half-deviance
      // convention used by the other regression objectives in this file
      gain *= DataT(0.5) * invLen;

      return gain;
    }
  }

  DI Split<DataT, IdxT> Gain(
    BinT const* shist, DataT const* sbins, IdxT col, IdxT len, IdxT nbins) const
  {
    Split<DataT, IdxT> sp;
    // each thread scores a strided subset of the candidate bins
    for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
      auto nLeft = shist[i].count;
      sp.update({sbins[i], col, GainPerSplit(shist, i, nbins, len, nLeft), nLeft});
    }
    return sp;
  }

  DI IdxT NumClasses() const { return 1; }

  // Leaf prediction: mean label of the samples reaching the leaf.
  static DI void SetLeafVector(BinT const* shist, int nclasses, DataT* out)
  {
    for (int i = 0; i < nclasses; i++) {
      out[i] = shist[i].label_sum / shist[i].count;
    }
  }
};

template <typename DataT_, typename LabelT_, typename IdxT_>
class PoissonObjectiveFunction {
public:
using DataT = DataT_;
using LabelT = LabelT_;
using IdxT = IdxT_;
using BinT = AggregateBin;

private:
IdxT min_samples_leaf;

public:
using BinT = AggregateBin;
static constexpr auto eps_ = 10 * std::numeric_limits<DataT>::epsilon();

HDI PoissonObjectiveFunction(IdxT nclasses, IdxT min_samples_leaf)
: min_samples_leaf(min_samples_leaf)
{
}
DI IdxT NumClasses() const { return 1; }

/**
* @brief compute the poisson impurity reduction (or purity gain) for each split
Expand All @@ -267,10 +329,11 @@ class PoissonObjectiveFunction {
* The Gain is the difference in the proxy impurities of the parent and the
* weighted sum of impurities of its children.
*/
HDI DataT GainPerSplit(BinT const* hist, IdxT i, IdxT nbins, IdxT len, IdxT nLeft)
HDI DataT GainPerSplit(BinT const* hist, IdxT i, IdxT nbins, IdxT len, IdxT nLeft) const
{
// get the lens'
auto nRight = len - nLeft;
IdxT nRight = len - nLeft;
auto invLen = DataT(1) / len;

// if there aren't enough samples in this split, don't bother!
if (nLeft < min_samples_leaf || nRight < min_samples_leaf)
Expand All @@ -285,16 +348,17 @@ class PoissonObjectiveFunction {
return -std::numeric_limits<DataT>::max();

// compute the gain to be
DataT parent_obj = -label_sum * raft::myLog(label_sum / len);
DataT parent_obj = -label_sum * raft::myLog(label_sum * invLen);
DataT left_obj = -left_label_sum * raft::myLog(left_label_sum / nLeft);
DataT right_obj = -right_label_sum * raft::myLog(right_label_sum / nRight);
auto gain = parent_obj - (left_obj + right_obj);
gain = gain / len;
DataT gain = parent_obj - (left_obj + right_obj);
gain = gain * invLen;

return gain;
}

DI Split<DataT, IdxT> Gain(BinT const* shist, DataT const* sbins, IdxT col, IdxT len, IdxT nbins)
DI Split<DataT, IdxT> Gain(
BinT const* shist, DataT const* sbins, IdxT col, IdxT len, IdxT nbins) const
{
Split<DataT, IdxT> sp;
for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
Expand All @@ -304,53 +368,74 @@ class PoissonObjectiveFunction {
return sp;
}

DI IdxT NumClasses() const { return 1; }

static DI void SetLeafVector(BinT const* shist, int nclasses, DataT* out)
{
for (int i = 0; i < nclasses; i++) {
out[i] = shist[i].label_sum / shist[i].count;
}
}
};

template <typename DataT_, typename LabelT_, typename IdxT_>
class MSEObjectiveFunction {
class GammaObjectiveFunction {
public:
using DataT = DataT_;
using LabelT = LabelT_;
using IdxT = IdxT_;
using DataT = DataT_;
using LabelT = LabelT_;
using IdxT = IdxT_;
using BinT = AggregateBin;
static constexpr auto eps_ = 10 * std::numeric_limits<DataT>::epsilon();

private:
IdxT min_samples_leaf;

public:
using BinT = AggregateBin;
HDI MSEObjectiveFunction(IdxT nclasses, IdxT min_samples_leaf)
: min_samples_leaf(min_samples_leaf)
HDI GammaObjectiveFunction(IdxT nclasses, IdxT min_samples_leaf)
: min_samples_leaf{min_samples_leaf}
{
}
DI IdxT NumClasses() const { return 1; }

HDI DataT GainPerSplit(BinT const* hist, IdxT i, IdxT nbins, IdxT len, IdxT nLeft)
/**
* @brief compute the gamma impurity reduction (or purity gain) for each split
*
* @note This method is used to speed up the search for the best split
* by calculating the gain using a proxy gamma half deviance reduction.
* It is a proxy quantity such that the split that maximizes this value
* also maximizes the impurity improvement. It neglects all constant terms
* of the impurity decrease for a given split.
* The Gain is the difference in the proxy impurities of the parent and the
* weighted sum of impurities of its children.
*/
HDI DataT GainPerSplit(BinT const* hist, IdxT i, IdxT nbins, IdxT len, IdxT nLeft) const
{
auto gain{DataT(0)};
auto nRight{len - nLeft};
auto invLen{DataT(1.0) / len};
IdxT nRight = len - nLeft;
auto invLen = DataT(1) / len;

// if there aren't enough samples in this split, don't bother!
if (nLeft < min_samples_leaf || nRight < min_samples_leaf) {
if (nLeft < min_samples_leaf || nRight < min_samples_leaf)
return -std::numeric_limits<DataT>::max();
} else {
auto label_sum = hist[nbins - 1].label_sum;
DataT parent_obj = -label_sum * label_sum / len;
DataT left_obj = -(hist[i].label_sum * hist[i].label_sum) / nLeft;
DataT right_label_sum = hist[i].label_sum - label_sum;
DataT right_obj = -(right_label_sum * right_label_sum) / nRight;
gain = parent_obj - (left_obj + right_obj);
gain *= invLen;

return gain;
}
DataT label_sum = hist[nbins - 1].label_sum;
DataT left_label_sum = (hist[i].label_sum);
DataT right_label_sum = (hist[nbins - 1].label_sum - hist[i].label_sum);

// label sum cannot be non-positive
if (label_sum < eps_ || left_label_sum < eps_ || right_label_sum < eps_)
return -std::numeric_limits<DataT>::max();

// compute the gain to be
DataT parent_obj = len * raft::myLog(label_sum * invLen);
DataT left_obj = nLeft * raft::myLog(left_label_sum / nLeft);
DataT right_obj = nRight * raft::myLog(right_label_sum / nRight);
DataT gain = parent_obj - (left_obj + right_obj);
gain = gain * invLen;

return gain;
}

DI Split<DataT, IdxT> Gain(BinT const* shist, DataT const* sbins, IdxT col, IdxT len, IdxT nbins)
DI Split<DataT, IdxT> Gain(
BinT const* shist, DataT const* sbins, IdxT col, IdxT len, IdxT nbins) const
{
Split<DataT, IdxT> sp;
for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
Expand All @@ -359,6 +444,7 @@ class MSEObjectiveFunction {
}
return sp;
}
DI IdxT NumClasses() const { return 1; }

static DI void SetLeafVector(BinT const* shist, int nclasses, DataT* out)
{
Expand All @@ -368,5 +454,80 @@ class MSEObjectiveFunction {
}
};

} // namespace DT
} // namespace ML
/**
 * @brief Objective for regression trees using a proxy inverse Gaussian
 *        half-deviance criterion, evaluated over per-column cumulative
 *        histograms. Label sums must be strictly positive per partition
 *        (guarded via `eps_` in GainPerSplit).
 */
template <typename DataT_, typename LabelT_, typename IdxT_>
class InverseGaussianObjectiveFunction {
 public:
  using DataT  = DataT_;
  using LabelT = LabelT_;
  using IdxT   = IdxT_;
  using BinT   = AggregateBin;
  // smallest admissible label sum; anything below is treated as non-positive
  // so the divisions in GainPerSplit never see a (near-)zero denominator
  static constexpr auto eps_ = 10 * std::numeric_limits<DataT>::epsilon();

 private:
  IdxT min_samples_leaf;  // minimum samples each child must retain for a valid split

 public:
  // `nclasses` is accepted for interface parity with the classification
  // objectives; regression has a single output, so it is unused here.
  HDI InverseGaussianObjectiveFunction(IdxT nclasses, IdxT min_samples_leaf)
    : min_samples_leaf{min_samples_leaf}
  {
  }

  /**
   * @brief compute the inverse gaussian impurity reduction (or purity gain) for each split
   *
   * @note This method is used to speed up the search for the best split
   * by calculating the gain using a proxy inverse gaussian half deviance reduction.
   * It is a proxy quantity such that the split that maximizes this value
   * also maximizes the impurity improvement. It neglects all constant terms
   * of the impurity decrease for a given split.
   * The Gain is the difference in the proxy impurities of the parent and the
   * weighted sum of impurities of its children.
   *
   * @param hist  cumulative histogram for this column; hist[nbins - 1] holds
   *              the totals over all samples in the node
   * @param i     candidate split bin index
   * @param nbins number of bins in the histogram
   * @param len   number of samples in the parent node
   * @param nLeft number of samples that would go to the left child
   * @return proxy gain, or -DataT max to reject the split outright
   */
  HDI DataT GainPerSplit(BinT const* hist, IdxT i, IdxT nbins, IdxT len, IdxT nLeft) const
  {
    // samples going right = total - samples going left
    IdxT nRight = len - nLeft;

    // if there aren't enough samples in this split, don't bother!
    if (nLeft < min_samples_leaf || nRight < min_samples_leaf)
      return -std::numeric_limits<DataT>::max();

    auto label_sum       = hist[nbins - 1].label_sum;
    auto left_label_sum  = (hist[i].label_sum);
    auto right_label_sum = (hist[nbins - 1].label_sum - hist[i].label_sum);

    // every partition needs a strictly positive label sum, otherwise the
    // n^2 / label_sum terms below are ill-defined — reject the split
    if (label_sum < eps_ || left_label_sum < eps_ || right_label_sum < eps_)
      return -std::numeric_limits<DataT>::max();

    // proxy gain: parent objective minus the sum of the children's objectives
    DataT parent_obj = -DataT(len) * DataT(len) / label_sum;
    DataT left_obj   = -DataT(nLeft) * DataT(nLeft) / left_label_sum;
    DataT right_obj  = -DataT(nRight) * DataT(nRight) / right_label_sum;
    DataT gain       = parent_obj - (left_obj + right_obj);
    // normalize by sample count; the factor 2 matches the half-deviance form
    gain = gain / (2 * len);

    return gain;
  }

  DI Split<DataT, IdxT> Gain(
    BinT const* shist, DataT const* sbins, IdxT col, IdxT len, IdxT nbins) const
  {
    Split<DataT, IdxT> sp;
    // each thread scores a strided subset of the candidate bins
    for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
      auto nLeft = shist[i].count;
      sp.update({sbins[i], col, GainPerSplit(shist, i, nbins, len, nLeft), nLeft});
    }
    return sp;
  }
  DI IdxT NumClasses() const { return 1; }

  // Leaf prediction: mean label of the samples reaching the leaf.
  static DI void SetLeafVector(BinT const* shist, int nclasses, DataT* out)
  {
    for (int i = 0; i < nclasses; i++) {
      out[i] = shist[i].label_sum / shist[i].count;
    }
  }
};
} // end namespace DT
} // end namespace ML
Loading