Store data frequencies in tree nodes of RF (#3647)
The RF model should store the number of data points associated with each tree node. This information is useful in many ways, including:
* Visualizing the trees
* Debugging performance problems
* Computing SHAP values using the TreeSHAP algorithm

To that end, this PR does the following:
* Add the `instance_count` field in the `SparseTreeNode` structure
* Expose the `instance_count` field in the JSON dump
* Add a unit test to ensure that the counts in the JSON dump are correct.

Note that this feature works with the new backend only. If the old backend is used, the `instance_count` field will be absent from the JSON dump.
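
As a rough illustration of the resulting dump format, the sketch below builds a tiny hand-written tree using the field names emitted by the JSON writer in this change and checks the count invariant that the new unit tests verify. The example values and the `check_counts` helper are made up for illustration; they are not part of this PR.

```python
import json

# Hand-written two-level tree in the dump format produced by get_json()
# (illustrative values only; a real dump comes from a fitted RF model).
example_tree = json.loads("""
{"nodeid": 0, "split_feature": 3, "split_threshold": 0.5, "gain": 0.12,
 "instance_count": 350, "yes": 1, "no": 2, "children": [
    {"nodeid": 1, "leaf_value": 0.0, "instance_count": 210},
    {"nodeid": 2, "leaf_value": 1.0, "instance_count": 140}]}
""")


def check_counts(node):
    # instance_count is absent when the model was trained with the old backend
    count = node.get('instance_count')
    if count is None or 'children' not in node:
        return
    left, right = node['children']
    assert count == left['instance_count'] + right['instance_count']
    check_counts(left)
    check_counts(right)


check_counts(example_tree)  # passes: 350 == 210 + 140
```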

Closes #3131

Authors:
  - Philip Hyunsu Cho (@hcho3)

Approvers:
  - John Zedlewski (@JohnZed)

URL: #3647
hcho3 authored Mar 25, 2021
1 parent 5a82fdb commit c48c081
Showing 5 changed files with 111 additions and 5 deletions.
1 change: 1 addition & 0 deletions cpp/include/cuml/tree/flatnode.h
@@ -32,6 +32,7 @@ struct SparseTreeNode {
  DataT best_metric_val;
  IdxT left_child_id = IdxT(-1);
  uint32_t unique_id = UINT32_MAX;
  uint32_t instance_count = UINT32_MAX;  // UINT32_MAX indicates n/a
};

template <typename T, typename L>
1 change: 1 addition & 0 deletions cpp/src/decisiontree/batched-levelalgo/builder.cuh
@@ -37,6 +37,7 @@ void convertToSparse(const Builder<Traits>& b,
  for (IdxT i = 0; i < b.h_total_nodes; ++i) {
    const auto& hnode = h_nodes[i].info;
    sparsetree[i + len] = hnode;
    sparsetree[i + len].instance_count = h_nodes[i].count;
    if (hnode.left_child_id != -1) sparsetree[i + len].left_child_id += len;
  }
}
14 changes: 10 additions & 4 deletions cpp/src/decisiontree/decisiontree_impl.cuh
@@ -118,8 +118,11 @@ std::string get_node_json(const std::string &prefix,
    oss << prefix << "{\"nodeid\": " << idx
        << ", \"split_feature\": " << node.colid
        << ", \"split_threshold\": " << to_string_high_precision(node.quesval)
        << ", \"gain\": " << to_string_high_precision(node.best_metric_val)
        << ", \"yes\": " << node.left_child_id
        << ", \"gain\": " << to_string_high_precision(node.best_metric_val);
    if (node.instance_count != UINT32_MAX) {
      oss << ", \"instance_count\": " << node.instance_count;
    }
    oss << ", \"yes\": " << node.left_child_id
        << ", \"no\": " << (node.left_child_id + 1) << ", \"children\": [\n";
    // enter the next tree level - left and right branch
    oss << get_node_json(prefix + " ", sparsetree, node.left_child_id) << ",\n"
@@ -128,8 +131,11 @@
        << prefix << "]}";
  } else {
    oss << prefix << "{\"nodeid\": " << idx
        << ", \"leaf_value\": " << to_string_high_precision(node.prediction)
        << "}";
        << ", \"leaf_value\": " << to_string_high_precision(node.prediction);
    if (node.instance_count != UINT32_MAX) {
      oss << ", \"instance_count\": " << node.instance_count;
    }
    oss << "}";
  }
  return oss.str();
}
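
The `UINT32_MAX` guard above is what makes `instance_count` an optional field in the dump: it is written only when the node carries a real count, i.e. when the new backend was used. A small, hypothetical consumer built on that contract, in the spirit of the visualization/debugging use cases from the description (the `node_coverage` helper and the coverage idea are illustrative, not part of this PR):

```python
def node_coverage(node, total):
    """Yield (nodeid, fraction of training rows reaching the node).

    Hypothetical helper: it relies only on fields emitted by get_node_json
    and silently skips nodes whose dump has no instance_count (old backend).
    """
    count = node.get('instance_count')
    if count is not None:
        yield node['nodeid'], count / total
    for child in node.get('children', []):
        yield from node_coverage(child, total)


# e.g., with a parsed tree and n_rows = X.shape[0]:
# for nid, frac in node_coverage(tree, n_rows):
#     print(nid, f"{frac:.1%}")
```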
43 changes: 43 additions & 0 deletions python/cuml/test/dask/test_random_forest.py
@@ -453,6 +453,49 @@ def predict_with_json_rf_regressor(rf, x):
    np.testing.assert_almost_equal(pred, expected_pred, decimal=6)


@pytest.mark.parametrize('max_depth', [1, 2, 3, 5, 10, 15, 20])
@pytest.mark.parametrize('n_estimators', [5, 10, 20])
def test_rf_instance_count(client, max_depth, n_estimators):
    n_workers = len(client.scheduler_info()['workers'])
    if n_estimators < n_workers:
        err_msg = "n_estimators cannot be lower than number of dask workers"
        pytest.xfail(err_msg)

    X, y = make_classification(n_samples=350, n_features=20,
                               n_clusters_per_class=1, n_informative=10,
                               random_state=123, n_classes=2)
    X = X.astype(np.float32)
    cu_rf_mg = cuRFC_mg(max_features=1.0, max_samples=1.0,
                        n_bins=16, split_algo=1, split_criterion=0,
                        min_samples_leaf=2, seed=23707, n_streams=1,
                        n_estimators=n_estimators, max_leaves=-1,
                        max_depth=max_depth, use_experimental_backend=True)
    y = y.astype(np.int32)

    X_dask, y_dask = _prep_training_data(client, X, y, partitions_per_worker=2)
    cu_rf_mg.fit(X_dask, y_dask)
    json_out = cu_rf_mg.get_json()
    json_obj = json.loads(json_out)

    # The instance count of each node must be equal to the sum of
    # the instance counts of its children
    def check_instance_count_for_non_leaf(tree):
        assert 'instance_count' in tree
        if 'children' not in tree:
            return
        assert 'instance_count' in tree['children'][0]
        assert 'instance_count' in tree['children'][1]
        assert (tree['instance_count'] == tree['children'][0]['instance_count']
                + tree['children'][1]['instance_count'])
        check_instance_count_for_non_leaf(tree['children'][0])
        check_instance_count_for_non_leaf(tree['children'][1])

    for tree in json_obj:
        check_instance_count_for_non_leaf(tree)
        # The root's count should be equal to the number of rows in the data
        assert tree['instance_count'] == X.shape[0]


@pytest.mark.parametrize('estimator_type', ['regression', 'classification'])
def test_rf_get_combined_model_right_aftter_fit(client, estimator_type):
    max_depth = 3
57 changes: 56 additions & 1 deletion python/cuml/test/test_random_forest.py
@@ -852,6 +852,58 @@ def predict_with_json_rf_regressor(rf, x):
    np.testing.assert_almost_equal(pred, expected_pred, decimal=6)


@pytest.mark.parametrize('max_depth', [1, 2, 3, 5, 10, 15, 20])
@pytest.mark.parametrize('n_estimators', [5, 10, 20])
@pytest.mark.parametrize('use_experimental_backend', [True, False])
def test_rf_instance_count(max_depth, n_estimators, use_experimental_backend):
    X, y = make_classification(n_samples=350, n_features=20,
                               n_clusters_per_class=1, n_informative=10,
                               random_state=123, n_classes=2)
    X = X.astype(np.float32)
    cuml_model = curfc(max_features=1.0, max_samples=1.0,
                       n_bins=16, split_algo=1, split_criterion=0,
                       min_samples_leaf=2, seed=23707, n_streams=1,
                       n_estimators=n_estimators, max_leaves=-1,
                       max_depth=max_depth,
                       use_experimental_backend=use_experimental_backend)
    y = y.astype(np.int32)

    # Train model on the data
    cuml_model.fit(X, y)

    json_out = cuml_model.get_json()
    json_obj = json.loads(json_out)

    # The instance count of each node must be equal to the sum of
    # the instance counts of its children. Note that the instance count
    # is only available with the new backend.
    if use_experimental_backend:
        def check_instance_count_for_non_leaf(tree):
            assert 'instance_count' in tree
            if 'children' not in tree:
                return
            assert 'instance_count' in tree['children'][0]
            assert 'instance_count' in tree['children'][1]
            assert (tree['instance_count']
                    == tree['children'][0]['instance_count']
                    + tree['children'][1]['instance_count'])
            check_instance_count_for_non_leaf(tree['children'][0])
            check_instance_count_for_non_leaf(tree['children'][1])
        for tree in json_obj:
            check_instance_count_for_non_leaf(tree)
            # The root's count must be equal to the number of rows in the data
            assert tree['instance_count'] == X.shape[0]
    else:
        def assert_instance_count_absent(tree):
            assert 'instance_count' not in tree
            if 'children' not in tree:
                return
            assert_instance_count_absent(tree['children'][0])
            assert_instance_count_absent(tree['children'][1])
        for tree in json_obj:
            assert_instance_count_absent(tree)


@pytest.mark.memleak
@pytest.mark.parametrize('estimator_type', ['classification'])
def test_rf_host_memory_leak(large_clf, estimator_type):
@@ -987,4 +1039,7 @@ def test_rf_regression_with_identical_labels(split_criterion,
    clf.fit(X, y)
    model_dump = json.loads(clf.get_json())
    assert len(model_dump) == 1
    assert model_dump[0] == {'nodeid': 0, 'leaf_value': 1.0}
    expected_dump = {'nodeid': 0, 'leaf_value': 1.0}
    if use_experimental_backend:
        expected_dump['instance_count'] = 5
    assert model_dump[0] == expected_dump
