Skip to content

Commit

Permalink
Merge pull request #67 from jinlow/bug/early-stopping
Browse files Browse the repository at this point in the history
Bug/early stopping
  • Loading branch information
jinlow authored Sep 6, 2023
2 parents ecafdaa + 5841176 commit ddb32f8
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "forust-ml"
version = "0.2.21"
version = "0.2.22"
edition = "2021"
authors = ["James Inlow <[email protected]>"]
homepage = "https://github.com/jinlow/forust"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pip install forust

To use in a rust project add the following to your Cargo.toml file.
```toml
forust-ml = "0.2.21"
forust-ml = "0.2.22"
```

## Usage
Expand Down
4 changes: 2 additions & 2 deletions py-forust/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "py-forust"
version = "0.2.21"
version = "0.2.22"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand All @@ -10,7 +10,7 @@ crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.19.0", features = ["extension-module"] }
forust-ml = { version = "0.2.21", path = "../" }
forust-ml = { version = "0.2.22", path = "../" }
numpy = "0.19.0"
ndarray = "0.15.1"
serde_plain = { version = "1.0" }
Expand Down
17 changes: 17 additions & 0 deletions py-forust/tests/test_booster.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,7 @@ def test_early_stopping_rounds(X_y, tmp_path):
best_iteration = fmod.get_best_iteration()
assert best_iteration is not None
assert best_iteration < history.shape[0]

fmod.set_prediction_iteration(4)
new_preds = fmod.predict(X)
assert not np.allclose(new_preds, preds)
Expand All @@ -689,6 +690,22 @@ def test_early_stopping_rounds(X_y, tmp_path):
assert np.allclose(loaded.predict(X), new_preds)


def test_early_stopping_with_dev(X_y):
X, y = X_y

val = y.index.to_series().isin(y.sample(frac=0.25, random_state=0))

model = GradientBooster(log_iterations=1, early_stopping_rounds=4, iterations=100)
model.fit(
X.loc[~val, :], y.loc[~val], evaluation_data=[(X.loc[val, :], y.loc[val])]
)

# Did we actually stop?
n_trees = json.loads(model.json_dump())["trees"]
assert len(n_trees) == model.get_best_iteration() + 4
assert model.get_best_iteration() < 99


def test_goss_sampling_method(X_y):
X, y = X_y
X = X
Expand Down
2 changes: 1 addition & 1 deletion rs-example.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
To run this example, add the following code to your `Cargo.toml` file.
```toml
[dependencies]
forust-ml = "0.2.21"
forust-ml = "0.2.22"
polars = "0.28"
reqwest = { version = "0.11", features = ["blocking"] }
```
Expand Down
8 changes: 7 additions & 1 deletion src/gradientbooster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,9 @@ impl GradientBooster {

self.update_predictions_inplace(&mut yhat, &tree, data);

// This will always be false, unless early stopping rounds are used.
let mut stop_early = false;

// Update Evaluation data, if it's needed.
if let Some(eval_sets) = &mut evaluation_sets {
if self.evaluation_history.is_none() {
Expand Down Expand Up @@ -568,7 +571,7 @@ impl GradientBooster {
if self.log_iterations > 0 {
info!("Stopping early at iteration {} with metric value {}", i, m)
}
break;
stop_early = true;
}
}
Some(v)
Expand All @@ -589,6 +592,9 @@ impl GradientBooster {
if let Some(history) = &mut self.evaluation_history {
history.append_row(metrics);
}
if stop_early {
break;
}
}
self.trees.push(tree);
(grad, hess) = calc_grad_hess(y, &yhat, sample_weight);
Expand Down

0 comments on commit ddb32f8

Please sign in to comment.