Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[gpuCI] Forward-merge branch-21.10 to branch-21.12 [skip gpuci] #4349

Merged
merged 7 commits into from
Nov 17, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ci/cpu/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ gpuci_logger "Activate conda env"
conda activate rapids

# Remove rapidsai-nightly channel if it is stable build
if [ "${IS_STABLE_BUILD}" != "true" ]; then
if [ "${IS_STABLE_BUILD}" = "true" ]; then
conda config --system --remove channels rapidsai-nightly
fi

Expand Down
2 changes: 1 addition & 1 deletion cpp/cmake/modules/ConfigureCUDA.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ list(APPEND CUML_CUDA_FLAGS --expt-extended-lambda --expt-relaxed-constexpr)
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11.2.0)
list(APPEND CUML_CUDA_FLAGS -Werror=all-warnings)
endif()
list(APPEND CUML_CUDA_FLAGS -Xcompiler=-Wall,-Werror,-Wno-error=deprecated-declarations)
list(APPEND CUML_CUDA_FLAGS -Xcompiler=-Wall,-Werror,-Wno-error=deprecated-declarations,-Wno-error=sign-compare)

if(DISABLE_DEPRECATION_WARNING)
list(APPEND CUML_CXX_FLAGS -Wno-deprecated-declarations)
Expand Down
27 changes: 17 additions & 10 deletions cpp/src/fil/fil.cu
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,8 @@ inline std::size_t bit_pool_size(const tl::Tree<T, L>& tree, const categorical_s
int node_id = stack.top();
stack.pop();
while (!tree.IsLeaf(node_id)) {
if (tree.SplitType(node_id) == tl::SplitFeatureType::kCategorical) {
if (tree.SplitType(node_id) == tl::SplitFeatureType::kCategorical &&
tree.MatchingCategories(node_id).size() > 0) {
int fid = tree.SplitIndex(node_id);
size += cat_sets.sizeof_mask(fid);
}
Expand Down Expand Up @@ -802,8 +803,9 @@ conversion_state<fil_node_t> tl2fil_inner_node(int fil_left_child,
int tl_left = tree.LeftChild(tl_node_id), tl_right = tree.RightChild(tl_node_id);
val_t split = {.f = NAN}; // yes there's a default initializer already
int feature_id = tree.SplitIndex(tl_node_id);
bool is_categorical = tree.SplitType(tl_node_id) == tl::SplitFeatureType::kCategorical;
bool default_left = tree.DefaultLeft(tl_node_id);
bool is_categorical = tree.SplitType(tl_node_id) == tl::SplitFeatureType::kCategorical &&
tree.MatchingCategories(tl_node_id).size() > 0;
bool default_left = tree.DefaultLeft(tl_node_id);
if (tree.SplitType(tl_node_id) == tl::SplitFeatureType::kNumerical) {
split.f = static_cast<float>(tree.Threshold(tl_node_id));
adjust_threshold(&split.f, &tl_left, &tl_right, &default_left, tree.ComparisonOp(tl_node_id));
Expand All @@ -813,13 +815,18 @@ conversion_state<fil_node_t> tl2fil_inner_node(int fil_left_child,
std::swap(tl_left, tl_right);
default_left = !default_left;
}
int sizeof_mask = cat_sets->accessor().sizeof_mask(feature_id);
split.idx = *bit_pool_offset;
*bit_pool_offset += sizeof_mask;
// cat_sets->bits have been zero-initialized
uint8_t* bits = &cat_sets->bits[split.idx];
for (std::uint32_t category : tree.MatchingCategories(tl_node_id)) {
bits[category / BITS_PER_BYTE] |= 1 << (category % BITS_PER_BYTE);
if (tree.MatchingCategories(tl_node_id).size() > 0) {
int sizeof_mask = cat_sets->accessor().sizeof_mask(feature_id);
split.idx = *bit_pool_offset;
*bit_pool_offset += sizeof_mask;
// cat_sets->bits have been zero-initialized
uint8_t* bits = &cat_sets->bits[split.idx];
for (std::uint32_t category : tree.MatchingCategories(tl_node_id)) {
bits[category / BITS_PER_BYTE] |= 1 << (category % BITS_PER_BYTE);
}
} else {
// always branch left in FIL. Already accounted for Treelite branching direction above.
split.f = NAN;
}
} else {
ASSERT(false, "only numerical and categorical split nodes are supported");
Expand Down
20 changes: 15 additions & 5 deletions cpp/src/fil/internal.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -307,9 +307,10 @@ struct forest_params_t {
/// FIL_TPB is the number of threads per block to use with FIL kernels
const int FIL_TPB = 256;

constexpr std::int32_t MAX_PRECISE_INT_FLOAT = 1 << 24; // 16'777'216
// as far as FIL is concerned, 16'777'214 is the most we can do.
constexpr std::int32_t MAX_PRECISE_INT_FLOAT = (1 << 24) - 2;

__host__ __device__ __forceinline__ int fetch_bit(const uint8_t* array, int bit)
__host__ __device__ __forceinline__ int fetch_bit(const uint8_t* array, uint32_t bit)
{
return (array[bit / BITS_PER_BYTE] >> (bit % BITS_PER_BYTE)) & 1;
}
Expand Down Expand Up @@ -337,15 +338,24 @@ struct categorical_sets {

// set count is due to tree_idx + node_within_tree_idx are both ints, hence uint32_t result
template <typename node_t>
__host__ __device__ __forceinline__ int category_matches(node_t node, int category) const
__host__ __device__ __forceinline__ int category_matches(node_t node, float category) const
{
// standard boolean packing. This layout has better ILP
// node.set() is global across feature IDs and is an offset (as opposed
// to set number). If we run out of uint32_t and we have hundreds of
// features with similar categorical feature count, we may consider
// storing node ID within nodes with same feature ID and look up
// {.max_matching, .first_node_offset} = ...[feature_id]
return category <= max_matching[node.fid()] && fetch_bit(bits + node.set(), category);

/* category < 0.0f or category > INT_MAX is equivalent to out-of-dictionary category
(not matching, branch left). -0.0f represents category 0.
If (float)(int)category != category, we will discard the fractional part.
E.g. 3.8f represents category 3 regardless of max_matching value.
FIL will reject a model where an integer within [0, max_matching + 1] cannot be represented
precisely as a 32-bit float.
*/
return category < static_cast<float>(max_matching[node.fid()] + 1) && category >= 0.0f &&
fetch_bit(bits + node.set(), static_cast<int>(category));
}
static int sizeof_mask_from_max_matching(int max_matching)
{
Expand All @@ -372,7 +382,7 @@ struct tree_base {
if (isnan(val)) {
cond = !node.def_left();
} else if (CATS_SUPPORTED && node.is_categorical()) {
cond = cat_sets.category_matches(node, static_cast<int>(val));
cond = cat_sets.category_matches(node, val);
} else {
cond = val >= node.thresh();
}
Expand Down
63 changes: 56 additions & 7 deletions cpp/test/sg/fil_child_index_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -207,19 +207,56 @@ std::vector<ChildIndexTestParams> params = {
CHILD_INDEX_TEST_PARAMS(parent_node_idx = 4, input = NAN, correct = 10), // !def_left
CHILD_INDEX_TEST_PARAMS(
node = NODE(def_left = true), input = NAN, parent_node_idx = 4, correct = 9), // !def_left
// cannot match ( > max_matching)
// cannot match ( < 0 and realistic max_matching)
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
cso.bits = {},
cso.max_matching = {-1},
input = 0,
cso.max_matching = {10},
input = -5,
correct = 1),
// Skipping category < 0 and dummy categorical node: max_matching == -1. Prevented by FIL import.
// cannot match ( > INT_MAX)
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
cso.bits = {0b1111'1111},
cso.max_matching = {7},
input = (float)(1ll << 33ll),
correct = 1),
// cannot match ( > max_matching and integer)
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
cso.bits = {0b1111'1111},
cso.max_matching = {1},
input = 2,
correct = 1),
// matches ( > max_matching only due to fractional part)
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
cso.bits = {0b1111'1111},
cso.max_matching = {1},
input = 1.8f,
correct = 2),
// cannot match ( > max_matching not only due to fractional part)
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
cso.bits = {0b1111'1111},
cso.max_matching = {1},
input = 2.1f,
correct = 1),
// cannot match ( > max_matching not only due to fractional part)
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
cso.bits = {0b1111'1111},
cso.max_matching = {1},
input = 2.8f,
correct = 1),
// does not match (bits[category] == 0, category == 0)
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
cso.bits = {0b0000'0000},
cso.max_matching = {0},
input = 0,
correct = 1),
// matches
// matches (negative zero)
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
cso.bits = {0b0000'0001},
cso.max_matching = {0},
input = -0.0f,
correct = 2),
// matches (positive zero)
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
cso.bits = {0b0000'0001},
cso.max_matching = {0},
Expand All @@ -228,7 +265,7 @@ std::vector<ChildIndexTestParams> params = {
// matches
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
cso.bits = {0b0000'0101},
cso.max_matching = {2, -1},
cso.max_matching = {2, 0},
input = 2,
correct = 2),
// does not match (bits[category] == 0, category > 0)
Expand All @@ -237,13 +274,25 @@ std::vector<ChildIndexTestParams> params = {
cso.max_matching = {2},
input = 1,
correct = 1),
// cannot match (max_matching[fid=1] == -1)
// cannot match (max_matching[fid=1] < input)
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
node.fid = 1,
cso.bits = {0b0000'0101},
cso.max_matching = {2, -1},
cso.max_matching = {2, 0},
input = 2,
correct = 1),
// default left
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true, def_left = true),
cso.bits = {0b0000'0101},
cso.max_matching = {2},
input = NAN,
correct = 1),
// default right
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true, def_left = false),
cso.bits = {0b0000'0101},
cso.max_matching = {2},
input = NAN,
correct = 2),
};

TEST_P(ChildIndexTestDense, Predict) { check(); }
Expand Down
7 changes: 4 additions & 3 deletions cpp/test/sg/fil_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ struct replace_some_floating_with_categorical {
{
int max_matching_cat = max_matching_cat_d[data_idx % num_cols];
if (max_matching_cat == -1) return data;
return roundf((data * 0.5f + 0.5f) * max_matching_cat);
// also test invalid (negative) categories
return roundf((data * 0.5f + 0.5f) * max_matching_cat - 1.0);
}
};

Expand Down Expand Up @@ -305,8 +306,8 @@ class BaseFilTest : public testing::TestWithParam<FilTestParams> {
for (int fid = 0; fid < ps.num_cols; ++fid) {
feature_categorical[fid] = fc(gen);
if (feature_categorical[fid]) {
// even for some categorical features, we will have no matching categories
float mm = pow(10, mmc(gen)) - 1.0f;
// categorical features will never have max_matching == -1
float mm = pow(10, mmc(gen));
ASSERT(mm < INT_MAX,
"internal error: max_magnitude_of_matching_cat %f is too large",
ps.max_magnitude_of_matching_cat);
Expand Down
12 changes: 11 additions & 1 deletion python/cuml/fil/fil.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,17 @@ cdef class ForestInference_impl():
Parameters
----------
X : float32 array-like (device or host) shape = (n_samples, n_features)
For optimal performance, pass a device array with C-style layout
For optimal performance, pass a device array with C-style layout.
For categorical features: category < 0.0 or category > 16'777'214
is equivalent to out-of-dictionary category (not matching).
-0.0 represents category 0.
If float(int(category)) != category, we will discard the
fractional part. E.g. 3.8 represents category 3 regardless of
max_matching value. FIL will reject a model where an integer
within [0, max_matching + 1] cannot be represented precisely
as a float32.
NANs work the same between numerical and categorical inputs:
they are missing values and follow Treelite's DefaultLeft.
preds : float32 device array, shape = n_samples
predict_proba : bool, whether to output class probabilities(vs classes)
Supported only for binary classification. output format
Expand Down
38 changes: 27 additions & 11 deletions python/cuml/test/test_fil.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pytest
import os
import pandas as pd
from random import sample, seed
from math import ceil

from cuml import ForestInference
from cuml.test.utils import array_equal, unit_param, \
Expand Down Expand Up @@ -491,17 +491,33 @@ def test_output_args(small_classifier_and_preds):
assert array_equal(fil_preds, xgb_preds, 1e-3)


def to_categorical(features, n_categorical):
def to_categorical(features, n_categorical, invalid_pct, rng):
""" returns data in two formats: pandas (for LightGBM) and numpy (for FIL)
"""
# the main bottleneck (>80%) of to_categorical() is the pandas operations
n_features = features.shape[1]
df_cols = {}
# all categorical columns
cat_cols = features[:, :n_categorical]
cat_cols = cat_cols - cat_cols.min(axis=1, keepdims=True) # range [0, ?]
cat_cols /= cat_cols.max(axis=1, keepdims=True) # range [0, 1]
cat_cols = cat_cols - cat_cols.min(axis=0, keepdims=True) # range [0, ?]
cat_cols /= cat_cols.max(axis=0, keepdims=True) # range [0, 1]
rough_n_categories = 100
# round into rough_n_categories bins
cat_cols = (cat_cols * rough_n_categories).astype(int)
# randomly inject invalid categories
invalid_idx = rng.choice(
a=cat_cols.size,
size=ceil(cat_cols.size * invalid_pct / 100),
replace=False,
shuffle=False)
cat_cols.flat[invalid_idx] += rough_n_categories

new_features = features.copy()
new_features[:, :n_categorical] = cat_cols

# shuffle the columns around
new_idx = rng.choice(n_features, n_features, replace=False, shuffle=True)
new_matrix = new_features[:, new_idx]
for icol in range(n_categorical):
col = cat_cols[:, icol]
df_cols[icol] = pd.Series(pd.Categorical(col,
Expand All @@ -510,11 +526,9 @@ def to_categorical(features, n_categorical):
for icol in range(n_categorical, n_features):
df_cols[icol] = pd.Series(features[:, icol])
# shuffle the columns around
seed(42)
new_idx = sample(range(n_features), k=n_features)
df_cols = {i: df_cols[new_idx[i]] for i in range(n_features)}

return pd.DataFrame(df_cols)
return pd.DataFrame(df_cols), new_matrix


@pytest.mark.parametrize('num_classes', [2, 5])
Expand All @@ -532,14 +546,16 @@ def test_lightgbm(tmp_path, num_classes, n_categorical):
n_rows = 500
n_informative = 'auto'

state = np.random.RandomState(43210)
X, y = simulate_data(n_rows,
n_features,
num_classes,
n_informative=n_informative,
random_state=43210,
random_state=state,
classification=True)
rng = np.random.default_rng(hash(state))
if n_categorical > 0:
X_fit = to_categorical(X, n_categorical)
X_fit, X = to_categorical(X, n_categorical, 10, rng)
else:
X_fit = X

Expand All @@ -560,7 +576,7 @@ def test_lightgbm(tmp_path, num_classes, n_categorical):
# binary classification
gbm_proba = bst.predict(X)
fil_proba = fm.predict_proba(X)[:, 1]
gbm_preds = (gbm_proba > 0.5)
gbm_preds = (gbm_proba > 0.5).astype(int)
fil_preds = fm.predict(X)
assert array_equal(gbm_preds, fil_preds)
np.testing.assert_allclose(gbm_proba, fil_proba,
Expand All @@ -572,11 +588,11 @@ def test_lightgbm(tmp_path, num_classes, n_categorical):
n_estimators=num_round)
lgm.fit(X_fit, y)
lgm.booster_.save_model(model_path)
lgm_preds = lgm.predict(X).astype(int)
fm = ForestInference.load(model_path,
algo='TREE_REORG',
output_class=True,
model_type="lightgbm")
lgm_preds = lgm.predict(X)
assert array_equal(lgm.booster_.predict(X).argmax(axis=1), lgm_preds)
assert array_equal(lgm_preds, fm.predict(X))
# lightgbm uses float64 thresholds, while FIL uses float32
Expand Down