From 5668104689e0cfabe8dea3866846ea34d94bfa13 Mon Sep 17 00:00:00 2001
From: jinlow
Date: Sat, 22 Apr 2023 20:27:01 -0500
Subject: [PATCH 1/4] Sped up contribution calculation

---
 py-forust/forust/__init__.py    |  40 +++++
 py-forust/src/lib.rs            |  16 ++
 py-forust/tests/test_booster.py |  46 ++++++
 src/data.rs                     |  15 ++
 src/gradientbooster.rs          |  25 ++++
 src/node.rs                     | 107 +++++++-------
 src/partial_dependence.rs       |  70 ++++-----
 src/splitter.rs                 |  34 +----
 src/tree.rs                     | 251 ++++++++++++++++++++------------
 src/utils.rs                    |  11 ++
 10 files changed, 405 insertions(+), 210 deletions(-)

diff --git a/py-forust/forust/__init__.py b/py-forust/forust/__init__.py
index bdf39ee..b15cc38 100644
--- a/py-forust/forust/__init__.py
+++ b/py-forust/forust/__init__.py
@@ -35,6 +35,15 @@ def predict(
     ) -> np.ndarray:
         raise NotImplementedError()
 
+    def predict_contributions(
+        self,
+        flat_data: np.ndarray,
+        rows: int,
+        cols: int,
+        parallel: bool = True,
+    ) -> np.ndarray:
+        raise NotImplementedError()
+
     def value_partial_dependence(
         self,
         feature: int,
@@ -238,6 +247,37 @@ def predict(self, X: FrameLike, parallel: Union[bool, None] = None) -> np.ndarray:
             parallel=parallel_,
         )
 
+    def predict_contributions(
+        self, X: FrameLike, parallel: Union[bool, None] = None
+    ) -> np.ndarray:
+        """Predict with the fitted booster on new data, returning the feature
+        contribution matrix. The last column is the bias term.
+
+        Args:
+            X (FrameLike): Either a pandas DataFrame, or a 2 dimensional numpy array.
+            parallel (Union[bool, None], optional): Optionally specify if the predict
+                function should run in parallel on multiple threads. If `None` is
+                passed, the `parallel` attribute of the booster will be used.
+                Defaults to `None`.
+
+        Returns:
+            np.ndarray: Returns a numpy array of the feature contributions, with
+                one column per feature, plus a final column for the bias term.
+        """
+        X_ = X.to_numpy() if isinstance(X, pd.DataFrame) else X
+        if not np.issubdtype(X_.dtype, "float64"):
+            X_ = X_.astype(dtype="float64", copy=False)
+
+        parallel_ = self.parallel if parallel is None else parallel
+        flat_data = X_.ravel(order="F")
+        rows, cols = X_.shape
+        contributions = self.booster.predict_contributions(
+            flat_data=flat_data,
+            rows=rows,
+            cols=cols,
+            parallel=parallel_,
+        )
+        return np.reshape(contributions, (X_.shape[0], X_.shape[1] + 1))
+
     def partial_dependence(self, X: FrameLike, feature: Union[str, int]) -> np.ndarray:
         """Calculate the partial dependence values of a feature.
        For each unique value of the feature, this gives the estimate of the predicted value for that

diff --git a/py-forust/src/lib.rs b/py-forust/src/lib.rs
index e4ea916..b48ecda 100644
--- a/py-forust/src/lib.rs
+++ b/py-forust/src/lib.rs
@@ -112,6 +112,22 @@ impl GradientBooster {
         let parallel = parallel.unwrap_or(true);
         Ok(self.booster.predict(&data, parallel).into_pyarray(py))
     }
+    pub fn predict_contributions<'py>(
+        &self,
+        py: Python<'py>,
+        flat_data: PyReadonlyArray1<f64>,
+        rows: usize,
+        cols: usize,
+        parallel: Option<bool>,
+    ) -> PyResult<&'py PyArray1<f64>> {
+        let flat_data = flat_data.as_slice()?;
+        let data = Matrix::new(flat_data, rows, cols);
+        let parallel = parallel.unwrap_or(true);
+        Ok(self
+            .booster
+            .predict_contributions(&data, parallel)
+            .into_pyarray(py))
+    }
 
     pub fn value_partial_dependence(&self, feature: usize, value: f64) -> PyResult<f64> {
         Ok(self.booster.value_partial_dependence(feature, value))
diff --git a/py-forust/tests/test_booster.py b/py-forust/tests/test_booster.py
index ee4a9fa..50162cb 100644
--- a/py-forust/tests/test_booster.py
+++ b/py-forust/tests/test_booster.py
@@ -263,3 +263,49 @@ def test_monotone_constraints(X_y):
             assert np.all(p_d[0:-1, 1] >= p_d[1:, 1])
         else:
             assert np.all(p_d[0:-1, 1] <= p_d[1:, 1])
+
+
+def test_booster_to_xgboosts_with_contributions(X_y):
+    X, y = X_y
+    fmod = GradientBooster(
+        iterations=100,
+        learning_rate=0.3,
+        max_depth=5,
+        l2=1,
+        min_leaf_weight=1,
+        gamma=1,
+        objective_type="LogLoss",
+        nbins=500,
+        parallel=False,
+        base_score=0.0,
+    )
+    fmod.fit(X, y=y)
+    fmod_preds = fmod.predict(X)
+    fmod_contribs = fmod.predict_contributions(X)
+    assert fmod_contribs.shape[1] == X.shape[1] + 1
+    assert np.allclose(fmod_contribs.sum(1), fmod_preds)
+
+    xmod = XGBClassifier(
+        n_estimators=100,
+        learning_rate=0.3,
+        max_depth=5,
+        reg_lambda=1,
+        min_child_weight=1,
+        gamma=1,
+        objective="binary:logitraw",
+        eval_metric="auc",
+        tree_method="hist",
+        max_bin=10000,
+        base_score=0.0,
+    )
+    xmod.fit(X, y)
+    xmod_preds = xmod.predict(X, output_margin=True)
+    import xgboost as xgb
+
+    xmod_contribs = xmod.get_booster().predict(
+        xgb.DMatrix(X), approx_contribs=True, pred_contribs=True
+    )
+    assert np.allclose(fmod_contribs, xmod_contribs, atol=0.000001)
diff --git a/src/data.rs b/src/data.rs
index 9a1c5ba..4f59fd3 100644
--- a/src/data.rs
+++ b/src/data.rs
@@ -93,6 +93,7 @@ pub struct Matrix<'a, T> {
 }
 
 impl<'a, T> Matrix<'a, T> {
+    // Defaults to column major.
     pub fn new(data: &'a [T], rows: usize, cols: usize) -> Self {
         Matrix {
             data,
@@ -137,6 +138,20 @@ impl<'a, T> Matrix<'a, T> {
     }
 }
 
+/// A lightweight row-major matrix, primarily for returning
+/// data to the user. It is especially suited for appending
+/// rows, such as when building up a matrix of contributions
+/// to return to the user. An added benefit is that it will
+/// be even faster to return to numpy.
+// pub struct RowMajorMatrix<T> {
+//     pub data: Vec<T>,
+//     pub rows: usize,
+//     pub cols: usize,
+//     stride1: usize,
+//     stride2: usize,
+// }
+
 impl<'a, T> fmt::Display for Matrix<'a, T>
 where
     T: FromStr + std::fmt::Display,
diff --git a/src/gradientbooster.rs b/src/gradientbooster.rs
index 12038d9..b00404b 100644
--- a/src/gradientbooster.rs
+++ b/src/gradientbooster.rs
@@ -5,6 +5,7 @@ use crate::errors::ForustError;
 use crate::objective::{gradient_hessian_callables, ObjectiveType};
 use crate::splitter::MissingImputerSplitter;
 use crate::tree::Tree;
+use rayon::prelude::*;
 use serde::{Deserialize, Serialize};
 use std::fs;
@@ -215,6 +216,28 @@ impl GradientBooster {
         init_preds
     }
 
+    /// Generate the feature contributions on data using the gradient
+    /// booster. The last column of the output is the bias term.
+    ///
+    /// * `data` - The matrix of data to calculate the contributions for.
+    pub fn predict_contributions(&self, data: &Matrix<f64>, parallel: bool) -> Vec<f64> {
+        let weights: Vec<Vec<f64>> = if parallel {
+            self.trees
+                .par_iter()
+                .map(|t| t.distribute_leaf_weights())
+                .collect()
+        } else {
+            self.trees
+                .iter()
+                .map(|t| t.distribute_leaf_weights())
+                .collect()
+        };
+        let mut contribs = vec![0.; (data.cols + 1) * data.rows];
+        self.trees.iter().zip(weights.iter()).for_each(|(t, w)| {
+            t.predict_contributions(data, &mut contribs, w);
+        });
+        contribs
+    }
+
     /// Given a value, return the partial dependence value of that value for that
     /// feature in the model.
     ///
@@ -390,6 +413,8 @@ mod tests {
         let sample_weight = vec![1.; y.len()];
         booster.fit(&data, &y, &sample_weight).unwrap();
         let preds = booster.predict(&data, false);
+        let contribs = booster.predict_contributions(&data, false);
+        assert_eq!(contribs.len(), (data.cols + 1) * data.rows);
         println!("{}", booster.trees[0]);
         println!("{}", booster.trees[0].nodes.len());
         println!("{}", booster.trees.last().unwrap().nodes.len());
diff --git a/src/node.rs b/src/node.rs
index 0fec207..4c60b8d 100644
--- a/src/node.rs
+++ b/src/node.rs
@@ -23,10 +23,11 @@ pub struct SplittableNode {
     pub stop_idx: usize,
     pub lower_bound: f32,
     pub upper_bound: f32,
+    pub is_leaf: bool,
 }
 
 #[derive(Deserialize, Serialize)]
-pub struct ParentNode {
+pub struct Node {
     pub num: usize,
     pub weight_value: f32,
     pub hessian_sum: f32,
@@ -37,6 +38,32 @@ pub struct ParentNode {
     pub missing_node: usize,
     pub left_child: usize,
     pub right_child: usize,
+    pub is_leaf: bool,
+}
+
+impl Node {
+    /// Update all the info that is needed if this node is a
+    /// parent node. This consumes the SplittableNode.
+    pub fn make_parent_node(&mut self, split_node: SplittableNode) {
+        self.is_leaf = false;
+        self.missing_node = split_node.missing_node;
+        self.split_value = split_node.split_value;
+        self.split_feature = split_node.split_feature;
+        self.split_gain = split_node.split_gain;
+        self.left_child = split_node.left_child;
+        self.right_child = split_node.right_child;
+    }
+    /// Get the child node that should be traveled to, given a value.
+    pub fn get_child_idx(&self, v: &f64) -> usize {
+        if v.is_nan() {
+            self.missing_node
+        } else if v < &self.split_value {
+            self.left_child
+        } else {
+            // if v >= &self.split_value
+            self.right_child
+        }
+    }
 }
 
 #[derive(Deserialize, Serialize)]
@@ -47,13 +74,6 @@ pub struct LeafNode {
     pub depth: usize,
 }
 
-#[derive(Deserialize, Serialize)]
-pub enum TreeNode {
-    Parent(ParentNode),
-    Leaf(LeafNode),
-    Splittable(SplittableNode),
-}
-
 impl SplittableNode {
     pub fn from_node_info(
         num: usize,
@@ -81,9 +101,12 @@ impl SplittableNode {
             stop_idx,
             lower_bound: node_info.bounds.0,
             upper_bound: node_info.bounds.1,
+            is_leaf: true,
         }
     }
 
+    /// Create a default splittable node;
+    /// we default to the node being a leaf.
     #[allow(clippy::too_many_arguments)]
     pub fn new(
         num: usize,
@@ -116,6 +139,7 @@ impl SplittableNode {
             stop_idx,
             lower_bound,
             upper_bound,
+            is_leaf: true,
         }
     }
 
@@ -136,17 +160,10 @@ impl SplittableNode {
             MissingInfo::Branch(_) => todo!(),
             MissingInfo::EmptyBranch => todo!(),
         };
+        self.is_leaf = false;
     }
-    pub fn as_leaf_node(&self) -> TreeNode {
-        TreeNode::Leaf(LeafNode {
-            num: self.num,
-            weight_value: self.weight_value,
-            hessian_sum: self.hessian_sum,
-            depth: self.depth,
-        })
-    }
-    pub fn as_parent_node(&self) -> TreeNode {
-        TreeNode::Parent(ParentNode {
+    pub fn as_node(&self) -> Node {
+        Node {
             num: self.num,
             weight_value: self.weight_value,
             hessian_sum: self.hessian_sum,
@@ -157,47 +174,33 @@ impl SplittableNode {
             split_gain: self.split_gain,
             left_child: self.left_child,
             right_child: self.right_child,
-        })
+            is_leaf: self.is_leaf,
+        }
     }
 }
 
-impl fmt::Display for TreeNode {
+impl fmt::Display for Node {
     // This trait requires `fmt` with this exact signature.
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
-        match self {
-            TreeNode::Leaf(leaf) => write!(
+        if self.is_leaf {
+            write!(
                 f,
                 "{}:leaf={},cover={}",
-                leaf.num, leaf.weight_value, leaf.hessian_sum
-            ),
-            TreeNode::Parent(parent) => {
-                write!(
-                    f,
-                    "{}:[{} < {}] yes={},no={},missing={},gain={},cover={}",
-                    parent.num,
-                    parent.split_feature,
-                    parent.split_value,
-                    parent.left_child,
-                    parent.right_child,
-                    parent.missing_node,
-                    parent.split_gain,
-                    parent.hessian_sum
-                )
-            }
-            TreeNode::Splittable(node) => {
-                write!(
-                    f,
-                    "SPLITTABLE - {}:[{} < {}] yes={},no={},missing={},gain={},cover={}",
-                    node.num,
-                    node.split_feature,
-                    node.split_value,
-                    node.missing_node,
-                    node.left_child,
-                    node.right_child,
-                    node.split_gain,
-                    node.hessian_sum
-                )
-            }
+                self.num, self.weight_value, self.hessian_sum
+            )
+        } else {
+            write!(
+                f,
+                "{}:[{} < {}] yes={},no={},missing={},gain={},cover={}",
+                self.num,
+                self.split_feature,
+                self.split_value,
+                self.left_child,
+                self.right_child,
+                self.missing_node,
+                self.split_gain,
+                self.hessian_sum
+            )
         }
     }
 }
diff --git a/src/partial_dependence.rs b/src/partial_dependence.rs
index e731cbc..3fc0158 100644
--- a/src/partial_dependence.rs
+++ b/src/partial_dependence.rs
@@ -1,4 +1,3 @@
-use crate::node::TreeNode;
 use crate::tree::Tree;
 
 /// Partial Dependence Calculator
@@ -10,11 +9,7 @@ use crate::tree::Tree;
 // }
 
 fn get_node_cover(tree: &Tree, node_idx: usize) -> f32 {
-    match &tree.nodes[node_idx] {
-        TreeNode::Leaf(n) => n.hessian_sum,
-        TreeNode::Parent(n) => n.hessian_sum,
-        TreeNode::Splittable(_) => unreachable!(),
-    }
+    tree.nodes[node_idx].hessian_sum
 }
 
 pub fn tree_partial_dependence(
@@ -24,40 +19,35 @@ pub fn tree_partial_dependence(
     value: f64,
     proportion: f32,
 ) -> f64 {
-    let node = &tree.nodes[node_idx];
-    match node {
-        TreeNode::Leaf(n) => f64::from(proportion * n.weight_value),
-        TreeNode::Parent(n) => {
-            if n.split_feature == feature {
-                let child = if value.is_nan() {
-                    n.missing_node
-                } else if value < n.split_value {
-                    n.left_child
-                } else {
-                    n.right_child
-                };
-                tree_partial_dependence(tree, child, feature, value, proportion)
-            } else {
-                let left_cover = get_node_cover(tree, n.left_child);
-                let right_cover = get_node_cover(tree, n.right_child);
-                let total_cover = left_cover + right_cover;
-
-                tree_partial_dependence(
-                    tree,
-                    n.left_child,
-                    feature,
-                    value,
-                    proportion * (left_cover / total_cover),
-                ) + tree_partial_dependence(
-                    tree,
-                    n.right_child,
-                    feature,
-                    value,
-                    proportion * (right_cover / total_cover),
-                )
-            }
-        }
-        TreeNode::Splittable(_) => unreachable!(),
+    let n = &tree.nodes[node_idx];
+    if n.is_leaf {
+        f64::from(proportion * n.weight_value)
+    } else if n.split_feature == feature {
+        let child = if value.is_nan() {
+            n.missing_node
+        } else if value < n.split_value {
+            n.left_child
+        } else {
+            n.right_child
+        };
+        tree_partial_dependence(tree, child, feature, value, proportion)
+    } else {
+        let left_cover = get_node_cover(tree, n.left_child);
+        let right_cover = get_node_cover(tree, n.right_child);
+        let total_cover = left_cover + right_cover;
+        tree_partial_dependence(
+            tree,
+            n.left_child,
+            feature,
+            value,
+            proportion * (left_cover / total_cover),
+        ) + tree_partial_dependence(
+            tree,
+            n.right_child,
+            feature,
+            value,
+            proportion * (right_cover / total_cover),
+        )
     }
 }
diff --git a/src/splitter.rs b/src/splitter.rs
index 36d8f90..5562805 100644
--- a/src/splitter.rs
+++ b/src/splitter.rs
@@ -1,10 +1,7 @@
-use std::collections::VecDeque;
-
 use crate::constraints::{Constraint, ConstraintMap};
 use crate::data::{JaggedMatrix, Matrix};
 use crate::histogram::HistogramMatrix;
 use crate::node::SplittableNode;
-use crate::node::TreeNode;
 use crate::utils::{constrained_weight, cull_gain, gain_given_weight, pivot_on_split, weight};
 
 #[derive(Debug)]
@@ -186,8 +183,7 @@ pub trait Splitter {
         grad: &[f32],
         hess: &[f32],
         parallel: bool,
-        growable_buffer: &mut VecDeque<usize>,
-    ) -> Vec<TreeNode>;
+    ) -> Vec<SplittableNode>;
 
     #[allow(clippy::too_many_arguments)]
     fn split_node(
@@ -200,20 +196,10 @@ pub trait Splitter {
         grad: &[f32],
         hess: &[f32],
         parallel: bool,
-        growable_buffer: &mut VecDeque<usize>,
-    ) -> Vec<TreeNode> {
+    ) -> Vec<SplittableNode> {
         match self.best_split(node) {
             Some(split_info) => self.handle_split_info(
-                split_info,
-                n_nodes,
-                node,
-                index,
-                data,
-                cuts,
-                grad,
-                hess,
-                parallel,
-                growable_buffer,
+                split_info, n_nodes, node, index, data, cuts, grad, hess, parallel,
             ),
             None => Vec::new(),
         }
@@ -354,8 +340,7 @@ impl Splitter for MissingBranchSplitter {
         grad: &[f32],
         hess: &[f32],
         parallel: bool,
-        growable_buffer: &mut VecDeque<usize>,
-    ) -> Vec<TreeNode> {
+    ) -> Vec<SplittableNode> {
         todo!()
     }
 }
@@ -560,8 +545,7 @@ impl Splitter for MissingImputerSplitter {
         grad: &[f32],
         hess: &[f32],
         parallel: bool,
-        growable_buffer: &mut VecDeque<usize>,
-    ) -> Vec<TreeNode> {
+    ) -> Vec<SplittableNode> {
         let left_idx = *n_nodes;
         let right_idx = left_idx + 1;
@@ -641,13 +625,7 @@ impl Splitter for MissingImputerSplitter {
             node.stop_idx,
             split_info.right_node,
         );
-        growable_buffer.push_front(left_idx);
-        growable_buffer.push_front(right_idx);
-        // It has children, so we know it's going to be a parent node
-        vec![
-            TreeNode::Splittable(left_node),
-            TreeNode::Splittable(right_node),
-        ]
+        vec![left_node, right_node]
     }
 
     fn get_constraint(&self, feature: &usize) -> Option<&Constraint> {
diff --git a/src/tree.rs b/src/tree.rs
index 25bf3b3..40b1ae7 100644
--- a/src/tree.rs
+++ b/src/tree.rs
@@ -1,6 +1,6 @@
 use crate::data::{JaggedMatrix, Matrix};
 use crate::histogram::HistogramMatrix;
-use crate::node::{SplittableNode, TreeNode};
+use crate::node::{Node, SplittableNode};
 use crate::partial_dependence::tree_partial_dependence;
 use crate::splitter::Splitter;
 use crate::utils::fast_f64_sum;
@@ -12,7 +12,7 @@ use std::fmt::{self, Display};
 
 #[derive(Deserialize, Serialize)]
 pub struct Tree {
-    pub nodes: Vec<TreeNode>,
+    pub nodes: Vec<Node>,
 }
 
 impl Default for Tree {
@@ -63,94 +63,119 @@ impl Tree {
             f32::INFINITY,
         );
         // Add the first node to the tree nodes.
-        self.nodes.push(TreeNode::Splittable(root_node));
+        self.nodes.push(root_node.as_node());
         let mut n_leaves = 1;
-        let mut growable = VecDeque::new();
-        growable.push_front(0);
+        let mut growable: VecDeque<SplittableNode> = VecDeque::new();
+        growable.push_front(root_node);
         while !growable.is_empty() {
             if n_leaves >= max_leaves {
-                // Clear the rest of the node idxs that
-                // are not needed.
-                for i in growable.iter() {
-                    let n = &self.nodes[*i];
-                    if let TreeNode::Splittable(node) = n {
-                        self.nodes[*i] = node.as_leaf_node();
-                    }
-                }
                 break;
             }
             // We know there is a value here, because of how the
             // while loop is set up.
-            let n_idx = growable
+            // Grab a splittable node from the stack.
+            // If we can split it, update the corresponding
+            // tree node's children.
+            let mut node = growable
                 .pop_back()
                 .expect("Growable buffer should not be empty.");
-
-            let n = self.nodes.get_mut(n_idx);
+            let n_idx = node.num;
 
-            // This will only be splittable nodes
-            if let Some(TreeNode::Splittable(node)) = n {
-                let depth = node.depth + 1;
-
-                // If we have hit max depth, skip this node
-                // but keep going, because there may be other
-                // valid shallower nodes.
-                if depth > max_depth {
-                    self.nodes[n_idx] = node.as_leaf_node();
-                    continue;
-                }
-                // For max_leaves, subtract 1 from the n_leaves
-                // every time we pop from the growable stack
-                // then, if we can add two children, add two to
-                // n_leaves. If we can't split the node any
-                // more, then just add 1 back to n_leaves
-                n_leaves -= 1;
-
-                let new_nodes = splitter.split_node(
-                    &n_nodes,
-                    node,
-                    &mut index,
-                    data,
-                    cuts,
-                    grad,
-                    hess,
-                    parallel,
-                    &mut growable,
-                );
-
-                let n_new_nodes = new_nodes.len();
-                if n_new_nodes == 0 {
-                    n_leaves += 1;
-                    self.nodes[n_idx] = node.as_leaf_node();
-                } else {
-                    self.nodes[n_idx] = node.as_parent_node();
-                    n_leaves += n_new_nodes;
-                    n_nodes += n_new_nodes;
-                    self.nodes.extend(new_nodes);
+            let depth = node.depth + 1;
+
+            // If we have hit max depth, skip this node
+            // but keep going, because there may be other
+            // valid shallower nodes.
+            if depth > max_depth {
+                continue;
+            }
+
+            // For max_leaves, subtract 1 from the n_leaves
+            // every time we pop from the growable stack
+            // then, if we can add two children, add two to
+            // n_leaves. If we can't split the node any
+            // more, then just add 1 back to n_leaves
+            n_leaves -= 1;
+
+            let new_nodes = splitter.split_node(
+                &n_nodes, &mut node, &mut index, data, cuts, grad, hess, parallel,
+            );
+
+            let n_new_nodes = new_nodes.len();
+            if n_new_nodes == 0 {
+                n_leaves += 1;
+            } else {
+                self.nodes[n_idx].make_parent_node(node);
+                n_leaves += n_new_nodes;
+                n_nodes += n_new_nodes;
+                for n in new_nodes {
+                    self.nodes.push(n.as_node());
+                    growable.push_front(n)
                 }
             }
         }
     }
+
+    pub fn predict_contributions_row(
+        &self,
+        data: &Matrix<f64>,
+        row: usize,
+        contribs: &mut [f64],
+        weights: &[f64],
+    ) {
+        // Add the bias term first...
+        contribs[data.cols] += weights[0];
+        let mut node_idx = 0;
+        loop {
+            let node = &self.nodes[node_idx];
+            if node.is_leaf {
+                break;
+            }
+            // Get the change in weight, given the child's weight.
+            let child_idx = node.get_child_idx(data.get(row, node.split_feature));
+            let node_weight = weights[node_idx];
+            let child_weight = weights[child_idx];
+            let delta = child_weight - node_weight;
+            contribs[node.split_feature] += delta;
+            node_idx = child_idx;
+        }
+    }
 
-    pub fn predict_row(&self, data: &Matrix<f64>, row: usize) -> f64 {
+    fn predict_contributions_single_threaded(
+        &self,
+        data: &Matrix<f64>,
+        contribs: &mut [f64],
+        weights: &[f64],
+    ) {
+        // Fill in the contributions one row at a time, chunking
+        // the contribution buffer by row.
+        data.index
+            .iter()
+            .zip(contribs.chunks_mut(data.cols + 1))
+            .for_each(|(row, contribs)| {
+                self.predict_contributions_row(data, *row, contribs, weights)
+            })
+    }
+    // fn predict_contributions_parallel(&self, data: &Matrix<f64>,
+    //     contribs: &mut [f64], weights: &[f64]) {
+    //     data.index.par_iter().zip(contribs.chunks_mut(data.cols + 1)).for_each(
+    //         |(row, contribs)| self.predict_contributions_row(data, *row, contribs, weights)
+    //     )
+    // }
+
+    pub fn predict_contributions(&self, data: &Matrix<f64>, contribs: &mut [f64], weights: &[f64]) {
+        self.predict_contributions_single_threaded(data, contribs, weights)
+    }
+
+    fn predict_row(&self, data: &Matrix<f64>, row: usize) -> f64 {
         let mut node_idx = 0;
         loop {
-            let n = &self.nodes[node_idx];
-            match n {
-                TreeNode::Leaf(node) => {
-                    return node.weight_value as f64;
-                }
-                TreeNode::Parent(node) => {
-                    let v = data.get(row, node.split_feature);
-                    if v.is_nan() {
-                        node_idx = node.missing_node;
-                    } else if v < &node.split_value {
-                        node_idx = node.left_child;
-                    } else if v >= &node.split_value {
-                        node_idx = node.right_child;
-                    }
-                }
-                _ => unreachable!(),
+            let node = &self.nodes[node_idx];
+            if node.is_leaf {
+                return node.weight_value as f64;
+            } else {
+                node_idx = node.get_child_idx(data.get(row, node.split_feature));
             }
         }
     }
@@ -180,6 +205,26 @@ impl Tree {
     pub fn value_partial_dependence(&self, feature: usize, value: f64) -> f64 {
         tree_partial_dependence(self, 0, feature, value, 1.0)
     }
+    fn distribute_node_leaf_weights(&self, i: usize, weights: &mut [f64]) -> f64 {
+        let node = &self.nodes[i];
+        let mut w = node.weight_value as f64;
+        if !node.is_leaf {
+            let left_node = &self.nodes[node.left_child];
+            let right_node = &self.nodes[node.right_child];
+            w = ((left_node.hessian_sum as f64
+                * self.distribute_node_leaf_weights(node.left_child, weights))
+                + (right_node.hessian_sum as f64
+                    * self.distribute_node_leaf_weights(node.right_child, weights)))
+                / (node.hessian_sum as f64);
+        }
+        weights[i] = w;
+        w
+    }
+    pub fn distribute_leaf_weights(&self) -> Vec<f64> {
+        let mut weights = vec![0.; self.nodes.len()];
+        self.distribute_node_leaf_weights(0, &mut weights);
+        weights
+    }
 }
 
 impl Display for Tree {
@@ -191,21 +236,13 @@ impl Display for Tree {
             // This will always be populated, because we confirm
             // that the buffer is not empty.
             let idx = print_buffer.pop().unwrap();
-            let n = &self.nodes[idx];
-            match n {
-                TreeNode::Leaf(node) => {
-                    r += format!("{}{}\n", " ".repeat(node.depth).as_str(), n).as_str();
-                }
-                TreeNode::Parent(node) => {
-                    r += format!("{}{}\n", " ".repeat(node.depth).as_str(), n).as_str();
-                    print_buffer.push(node.right_child);
-                    print_buffer.push(node.left_child);
-                }
-                TreeNode::Splittable(node) => {
-                    r += format!("{}{}\n", " ".repeat(node.depth).as_str(), n).as_str();
-                    print_buffer.push(node.right_child);
-                    print_buffer.push(node.left_child);
-                }
+            let node = &self.nodes[idx];
+            if node.is_leaf {
+                r += format!("{}{}\n", " ".repeat(node.depth).as_str(), node).as_str();
+            } else {
+                r += format!("{}{}\n", " ".repeat(node.depth).as_str(), node).as_str();
+                print_buffer.push(node.right_child);
+                print_buffer.push(node.left_child);
             }
         }
         write!(f, "{}", r)
@@ -219,6 +256,7 @@ mod tests {
     use crate::constraints::{Constraint, ConstraintMap};
     use crate::objective::{LogLoss, ObjectiveFunction};
     use crate::splitter::MissingImputerSplitter;
+    use crate::utils::precision_round;
    use std::fs;
     #[test]
     fn test_tree_fit() {
@@ -249,10 +287,28 @@ mod tests {
 
         tree.fit(&bdata, &b.cuts, &g, &h, &splitter, usize::MAX, 5, true);
 
-        println!("{}", tree);
-        let preds = tree.predict(&data, false);
-        println!("{:?}", &preds[0..10]);
-        assert_eq!(25, tree.nodes.len())
+        // println!("{}", tree);
+        // let preds = tree.predict(&data, false);
+        // println!("{:?}", &preds[0..10]);
+        assert_eq!(25, tree.nodes.len());
+        // Test contributions prediction...
+        let weights = tree.distribute_leaf_weights();
+        let mut contribs = vec![0.; (data.cols + 1) * data.rows];
+        tree.predict_contributions(&data, &mut contribs, &weights);
+        let full_preds = tree.predict(&data, true);
+        assert_eq!(contribs.len(), (data.cols + 1) * data.rows);
+
+        let contribs_preds: Vec<f64> = contribs
+            .chunks(data.cols + 1)
+            .map(|i| i.iter().sum())
+            .collect();
+        println!("{:?}", &contribs[0..10]);
+        println!("{:?}", &contribs_preds[0..10]);
+
+        assert_eq!(contribs_preds.len(), full_preds.len());
+        for (i, j) in full_preds.iter().zip(contribs_preds) {
+            assert_eq!(precision_round(*i, 7), precision_round(j, 7));
+        }
     }
 
     #[test]
@@ -295,5 +351,20 @@ mod tests {
         let preds = tree.predict(&pred_data, false);
         let increasing = preds.windows(2).all(|a| a[0] >= a[1]);
         assert!(increasing);
+
+        let weights = tree.distribute_leaf_weights();
+
+        let mut contribs = vec![0.; (data.cols + 1) * data.rows];
+        tree.predict_contributions(&data, &mut contribs, &weights);
+        let full_preds = tree.predict(&data, true);
+        assert_eq!(contribs.len(), (data.cols + 1) * data.rows);
+        let contribs_preds: Vec<f64> = contribs
+            .chunks(data.cols + 1)
+            .map(|i| i.iter().sum())
+            .collect();
+        assert_eq!(contribs_preds.len(), full_preds.len());
+        for (i, j) in full_preds.iter().zip(contribs_preds) {
+            assert_eq!(precision_round(*i, 7), precision_round(j, 7));
+        }
     }
 }
diff --git a/src/utils.rs b/src/utils.rs
index fd66371..afcd340 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -355,6 +355,12 @@ pub fn missing_compare(split_value: &u16, cmp_value: u16, missing_right: bool) -> bool {
     }
 }
 
+#[inline]
+pub fn precision_round(n: f64, precision: i32) -> f64 {
+    let p = (10.0_f64).powi(precision);
+    (n * p).round() / p
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -363,6 +369,11 @@ mod tests {
     use rand::Rng;
     use rand::SeedableRng;
     #[test]
+    fn test_round() {
+        assert_eq!(0.3, precision_round(0.3333, 1));
+        assert_eq!(0.2343, precision_round(0.2343123123123, 4));
+    }
+    #[test]
     fn test_percentiles() {
         let v = vec![4., 5., 6., 1., 2., 3., 7., 8., 9., 10.];
        let w = vec![1.; v.len()];

From 2ddd926b8d153979921f1f841a51ee10a5d89674 Mon Sep 17 00:00:00 2001
From: jinlow
Date: Sat, 22 Apr 2023 21:02:47 -0500
Subject: [PATCH 2/4] Made it even faster

---
 src/gradientbooster.rs |  2 +-
 src/tree.rs            | 40 +++++++++++++++++++++++++++++-----------
 2 files changed, 30 insertions(+), 12 deletions(-)

diff --git a/src/gradientbooster.rs b/src/gradientbooster.rs
index b00404b..3057bfe 100644
--- a/src/gradientbooster.rs
+++ b/src/gradientbooster.rs
@@ -233,7 +233,7 @@ impl GradientBooster {
         };
         let mut contribs = vec![0.; (data.cols + 1) * data.rows];
         self.trees.iter().zip(weights.iter()).for_each(|(t, w)| {
-            t.predict_contributions(data, &mut contribs, w);
+            t.predict_contributions(data, &mut contribs, w, parallel);
         });
         contribs
     }
diff --git a/src/tree.rs b/src/tree.rs
index 40b1ae7..3c07097 100644
--- a/src/tree.rs
+++ b/src/tree.rs
@@ -157,15 +157,33 @@ impl Tree {
                 self.predict_contributions_row(data, *row, contribs, weights)
             })
     }
-    // fn predict_contributions_parallel(&self, data: &Matrix<f64>,
-    //     contribs: &mut [f64], weights: &[f64]) {
-    //     data.index.par_iter().zip(contribs.chunks_mut(data.cols + 1)).for_each(
-    //         |(row, contribs)| self.predict_contributions_row(data, *row, contribs, weights)
-    //     )
-    // }
-
-    pub fn predict_contributions(&self, data: &Matrix<f64>, contribs: &mut [f64], weights: &[f64]) {
-        self.predict_contributions_single_threaded(data, contribs, weights)
+    fn predict_contributions_parallel(
+        &self,
+        data: &Matrix<f64>,
+        contribs: &mut [f64],
+        weights: &[f64],
+    ) {
+        // Fill in each row's contributions in parallel.
+        data.index
+            .par_iter()
+            .zip(contribs.par_chunks_mut(data.cols + 1))
+            .for_each(|(row, contribs)| {
+                self.predict_contributions_row(data, *row, contribs, weights)
+            })
+    }
+
+    pub fn predict_contributions(
+        &self,
+        data: &Matrix<f64>,
+        contribs: &mut [f64],
+        weights: &[f64],
+        parallel: bool,
+    ) {
+        if parallel {
+            self.predict_contributions_parallel(data, contribs, weights)
+        } else {
+            self.predict_contributions_single_threaded(data, contribs, weights)
+        }
     }
 
     fn predict_row(&self, data: &Matrix<f64>, row: usize) -> f64 {
@@ -294,7 +312,7 @@ mod tests {
         // Test contributions prediction...
         let weights = tree.distribute_leaf_weights();
         let mut contribs = vec![0.; (data.cols + 1) * data.rows];
-        tree.predict_contributions(&data, &mut contribs, &weights);
+        tree.predict_contributions(&data, &mut contribs, &weights, false);
         let full_preds = tree.predict(&data, true);
         assert_eq!(contribs.len(), (data.cols + 1) * data.rows);
@@ -355,7 +373,7 @@ mod tests {
 
         let weights = tree.distribute_leaf_weights();
 
         let mut contribs = vec![0.; (data.cols + 1) * data.rows];
-        tree.predict_contributions(&data, &mut contribs, &weights);
+        tree.predict_contributions(&data, &mut contribs, &weights, false);
         let full_preds = tree.predict(&data, true);
         assert_eq!(contribs.len(), (data.cols + 1) * data.rows);
         let contribs_preds: Vec<f64> = contribs

From 3a68b300fc46cc2f0bc73a3facfb548c775a171d Mon Sep 17 00:00:00 2001
From: jinlow
Date: Sat, 22 Apr 2023 21:05:39 -0500
Subject: [PATCH 3/4] update readme

---
 README.md | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/README.md b/README.md
index 690ca72..7d23fc6 100644
--- a/README.md
+++ b/README.md
@@ -84,6 +84,13 @@ model.fit(X, y)
 # Predict on data
 model.predict(X.head())
 # array([-1.94919663,  2.25863229,  0.32963671,  2.48732194, -3.00371813])
+
+# Predict contributions
+model.predict_contributions(X.head())
+# array([[-0.63014213,  0.33880048, -0.16520798, -0.07798772, -0.85083578,
+#         -1.07720813],
+#        [ 1.05406709,  0.08825999,  0.21662544, -0.12083538,  0.35209258,
+#         -1.07720813],
 ```
 
 The `fit` method accepts the following arguments.
@@ -102,6 +109,13 @@ The predict method accepts the following arguments.
     passed, the `parallel` attribute of the booster will be used.
     Defaults to `None`.
 
+The `predict_contributions` method will predict with the fitted booster on new data, returning the feature contribution matrix. The last column is the bias term.
+ - `X` ***(FrameLike)***: Either a pandas DataFrame, or a 2 dimensional numpy array, with numeric data.
+ - `parallel` ***(Optional[bool], optional)***: Optionally specify if the predict
+   function should run in parallel on multiple threads. If `None` is
+   passed, the `parallel` attribute of the booster will be used.
+   Defaults to `None`.
+
 ### Inspecting the Model
 
 Once the booster has been fit, each individual tree structure can be retrieved in text form, using the `text_dump` method. This method returns a list, the same length as the number of trees in the model.

From 38268bbd0cec0c3b74b60eed495d06f2f442da5c Mon Sep 17 00:00:00 2001
From: jinlow
Date: Sat, 22 Apr 2023 21:06:28 -0500
Subject: [PATCH 4/4] Update version

---
 Cargo.toml           | 2 +-
 README.md            | 2 +-
 py-forust/Cargo.toml | 4 ++--
 rs-example.md        | 2 +-
 4 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index ed6fa07..70cde8d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "forust-ml"
-version = "0.2.0"
+version = "0.2.1"
 edition = "2021"
 authors = ["James Inlow "]
 homepage = "https://github.com/jinlow/forust"
diff --git a/README.md b/README.md
index 7d23fc6..b9181bf 100644
--- a/README.md
+++ b/README.md
@@ -26,7 +26,7 @@ pip install forust
 
 To use in a Rust project, add the following to your Cargo.toml file.
```toml -forust-ml = "0.2.0" +forust-ml = "0.2.1" ``` ## Usage diff --git a/py-forust/Cargo.toml b/py-forust/Cargo.toml index 3dec2d3..7356c4c 100644 --- a/py-forust/Cargo.toml +++ b/py-forust/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "py-forust" -version = "0.2.0" +version = "0.2.1" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -10,6 +10,6 @@ crate-type = ["cdylib"] [dependencies] pyo3 = { version = "0.17", features = ["extension-module"] } -forust-ml = { version="0.2.0", path="../" } +forust-ml = { version="0.2.1", path="../" } numpy = "0.17.2" ndarray = "0.15.1" diff --git a/rs-example.md b/rs-example.md index dfc6aff..c527ab5 100644 --- a/rs-example.md +++ b/rs-example.md @@ -3,7 +3,7 @@ To run this example, add the following code to your `Cargo.toml` file. ```toml [dependencies] -forust-ml = "0.2.0" +forust-ml = "0.2.1" polars = "0.24" reqwest = { version = "0.11", features = ["blocking"] } ```
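
A note on what patch 1 computes: `distribute_leaf_weights` assigns every node the cover-weighted average of the leaf weights beneath it, and `predict_contributions_row` then walks a row down the tree, crediting each change in that average to the feature that was split on, with the root's expected value as the bias term. The Python sketch below mirrors that two-step scheme for illustration only; the list-of-dicts node layout and helper names are hypothetical, not part of the forust API.

```python
def distribute_leaf_weights(nodes, i=0, weights=None):
    """Cover-weighted average of the leaf weights below each node."""
    if weights is None:
        weights = [0.0] * len(nodes)
    node = nodes[i]
    if node["is_leaf"]:
        weights[i] = node["weight_value"]
    else:
        l, r = node["left_child"], node["right_child"]
        distribute_leaf_weights(nodes, l, weights)
        distribute_leaf_weights(nodes, r, weights)
        weights[i] = (
            nodes[l]["hessian_sum"] * weights[l]
            + nodes[r]["hessian_sum"] * weights[r]
        ) / node["hessian_sum"]
    return weights


def predict_contributions_row(nodes, weights, row, n_features):
    """Walk one row down the tree, crediting each weight change
    to the split feature; the last slot holds the bias term."""
    contribs = [0.0] * (n_features + 1)
    contribs[-1] = weights[0]  # bias: the root's expected value
    i = 0
    while not nodes[i]["is_leaf"]:
        node = nodes[i]
        # NaN/missing handling (missing_node) is omitted for brevity.
        child = (
            node["left_child"]
            if row[node["split_feature"]] < node["split_value"]
            else node["right_child"]
        )
        contribs[node["split_feature"]] += weights[child] - weights[i]
        i = child
    return contribs


# A stump: the root splits feature 0 at 0.5, with covers 3.0 and 1.0.
nodes = [
    {"is_leaf": False, "split_feature": 0, "split_value": 0.5,
     "left_child": 1, "right_child": 2, "hessian_sum": 4.0},
    {"is_leaf": True, "weight_value": -1.0, "hessian_sum": 3.0},
    {"is_leaf": True, "weight_value": 2.0, "hessian_sum": 1.0},
]
weights = distribute_leaf_weights(nodes)  # [-0.25, -1.0, 2.0]
print(predict_contributions_row(nodes, weights, [0.7], n_features=1))
# [2.25, -0.25] -- and 2.25 + (-0.25) == 2.0, the leaf's prediction
```

Because the per-node weights are precomputed once per tree, each contribution row costs only a single root-to-leaf traversal, which is where the speedup in these patches comes from.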
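For completeness, a minimal end-to-end sketch of the Python API added in patch 1, assuming `X` (a pandas DataFrame) and `y` are loaded as in the README example; the additivity check at the end is the same invariant the new tests assert.

```python
import numpy as np

from forust import GradientBooster

model = GradientBooster(objective_type="LogLoss")
model.fit(X, y)

contribs = model.predict_contributions(X)
# One column per feature, plus a trailing column for the bias term.
assert contribs.shape == (X.shape[0], X.shape[1] + 1)
# Row sums of the contribution matrix recover the margin predictions.
assert np.allclose(contribs.sum(axis=1), model.predict(X))
```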