Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix handling of NaNs in categorial splits of LightGBM models #304

Merged
merged 4 commits into from
Jan 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 4 additions & 10 deletions src/frontend/lightgbm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -554,25 +554,19 @@ inline std::unique_ptr<treelite::Model> ParseStream(std::istream& fi) {

tree.AddChilds(new_id);
if (GetDecisionType(lgb_tree.decision_type[old_id], kCategoricalMask)) {
// categorical
// categorical split
const int cat_idx = static_cast<int>(lgb_tree.threshold[old_id]);
const std::vector<uint32_t> left_categories
= BitsetToList(lgb_tree.cat_threshold.data()
+ lgb_tree.cat_boundaries[cat_idx],
lgb_tree.cat_boundaries[cat_idx + 1]
- lgb_tree.cat_boundaries[cat_idx]);
const bool missing_value_to_zero = missing_type != MissingType::kNaN;
// For categorical splits, we ignore the missing type field. NaNs always get mapped to
// the right child node.
bool default_left = false;
if (missing_value_to_zero) {
// If missing_value_to_zero flag is true, all missing values get mapped to 0.0, so
// we need to override the default_left flag
default_left
= (std::find(left_categories.begin(), left_categories.end(),
static_cast<uint32_t>(0)) != left_categories.end());
}
tree.SetCategoricalSplit(new_id, split_index, default_left, left_categories, false);
} else {
// numerical
// numerical split
const auto threshold = static_cast<double>(lgb_tree.threshold[old_id]);
bool default_left
= GetDecisionType(lgb_tree.decision_type[old_id], kDefaultLeftMask);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
-0.036716360109071977
-0.040451733276673375
-0.088410677452890038
0.21990291345117738
5 changes: 4 additions & 1 deletion tests/python/test_gtil.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,5 +340,8 @@ def test_lightgbm_sparse_categorical_model():
X, _ = load_svmlight_file(dataset_db[dataset].dtest, zero_based=True,
n_features=tl_model.num_feature)
expected_pred = load_txt(dataset_db[dataset].expected_margin)
out_pred = treelite.gtil.predict(tl_model, X.toarray(), pred_margin=True)
# GTIL doesn't yet support sparse matrix; so use NaN to represent missing values
Xa = X.toarray()
Xa[Xa == 0] = 'nan'
out_pred = treelite.gtil.predict(tl_model, Xa, pred_margin=True)
np.testing.assert_almost_equal(out_pred, expected_pred, decimal=5)
26 changes: 26 additions & 0 deletions tests/python/test_lightgbm_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,3 +290,29 @@ def test_constant_tree():
model_path = _qualify_path('lightgbm_constant_tree', 'model_with_constant_tree.txt')
model = treelite.Model.load(model_path, model_format='lightgbm')
assert model.num_tree == 2


@pytest.mark.parametrize('toolchain', os_compatible_toolchains())
def test_nan_handling_with_categorical_splits(tmpdir, toolchain):
"""Test that NaN inputs are handled correctly in categorical splits"""

# Test case taken from https://github.com/dmlc/treelite/issues/277
X = np.array(30 * [[1]] + 30 * [[2]] + 30 * [[0]])
y = np.array(60 * [5] + 30 * [10])
train_data = lightgbm.Dataset(X, label=y, categorical_feature=[0])
bst = lightgbm.train({}, train_data, 1)

model_path = os.path.join(tmpdir, 'dummy_categorical.txt')
libpath = os.path.join(tmpdir, 'dummy_categorical_lgb' + _libext())

input_with_nan = np.array([[np.NaN], [0.0]])

lgb_pred = bst.predict(input_with_nan)
bst.save_model(model_path)

model = treelite.Model.load(model_path, model_format='lightgbm')
model.export_lib(toolchain=toolchain, libpath=libpath)
predictor = treelite_runtime.Predictor(libpath)
dmat = treelite_runtime.DMatrix(input_with_nan)
tl_pred = predictor.predict(dmat)
np.testing.assert_almost_equal(tl_pred, lgb_pred)