Skip to content

Commit

Permalink
Merge pull request #70 from jinlow/bug/saved-stats
Browse files Browse the repository at this point in the history
Fix bug where missing leaf gain wasn't considered.
  • Loading branch information
jinlow authored Sep 12, 2023
2 parents f07d187 + b8b097e commit b82bcb0
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 23 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "forust-ml"
version = "0.2.23"
version = "0.2.24"
edition = "2021"
authors = ["James Inlow <[email protected]>"]
homepage = "https://github.com/jinlow/forust"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pip install forust

To use in a rust project add the following to your Cargo.toml file.
```toml
forust-ml = "0.2.23"
forust-ml = "0.2.24"
```

## Usage
Expand Down
4 changes: 2 additions & 2 deletions py-forust/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "py-forust"
version = "0.2.23"
version = "0.2.24"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand All @@ -10,7 +10,7 @@ crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.19.0", features = ["extension-module"] }
forust-ml = { version = "0.2.23", path = "../" }
forust-ml = { version = "0.2.24", path = "../" }
numpy = "0.19.0"
ndarray = "0.15.1"
serde_plain = { version = "1.0" }
Expand Down
2 changes: 1 addition & 1 deletion rs-example.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
To run this example, add the following code to your `Cargo.toml` file.
```toml
[dependencies]
forust-ml = "0.2.23"
forust-ml = "0.2.24"
polars = "0.28"
reqwest = { version = "0.11", features = ["blocking"] }
```
Expand Down
28 changes: 21 additions & 7 deletions src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,17 +173,31 @@ impl SplittableNode {
self.left_child = left_child;
self.right_child = right_child;
self.split_feature = split_info.split_feature;
let missing_split_gain = match &split_info.missing_node {
MissingInfo::Branch(ni) => ni.gain,
_ => 0.,
};
self.split_gain =
split_info.left_node.gain + split_info.right_node.gain + missing_split_gain
- self.gain_value;
self.split_gain = self.get_split_gain(
&split_info.left_node,
&split_info.right_node,
&split_info.missing_node,
0.0,
);
self.split_value = split_info.split_value;
self.missing_node = missing_child;
self.is_leaf = false;
}

pub fn get_split_gain(
&self,
left_node_info: &NodeInfo,
right_node_info: &NodeInfo,
missing_node_info: &MissingInfo,
gamma: f32,
) -> f32 {
let missing_split_gain = match &missing_node_info {
MissingInfo::Branch(ni) | MissingInfo::Leaf(ni) => ni.gain,
_ => 0.,
};
left_node_info.gain + right_node_info.gain + missing_split_gain - self.gain_value - gamma
}

pub fn as_node(&self, learning_rate: f32) -> Node {
Node {
num: self.num,
Expand Down
17 changes: 6 additions & 11 deletions src/splitter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,17 +133,12 @@ pub trait Splitter {
Some(v) => v,
};

// TODO!
// Should we be doing this?
// or should missing gain not factor in at
// all to the split gain?
let missing_gain = match &missing_info {
MissingInfo::Branch(v) | MissingInfo::Leaf(v) => v.gain,
_ => 0.0,
};
let split_gain = (left_node_info.gain + right_node_info.gain + missing_gain
- node.gain_value)
- self.get_gamma();
let split_gain = node.get_split_gain(
&left_node_info,
&right_node_info,
&missing_info,
self.get_gamma(),
);

// Check monotonicity holds
let split_gain = cull_gain(
Expand Down

0 comments on commit b82bcb0

Please sign in to comment.