Improve linfa-trees documentation and examples (#86)
* added some docs and examples

* extend documentation, add examples

* extend documentation, run clippy

* Added feature names to decision trees for printing

* keep weights in split with ratio view and in view

* fix weights splitting for empty weights

* use legend & complete in tikz, remove max_classes

* Moved to SmallRng

* dark mode svg outline

* address review comments
Sauro98 authored Feb 16, 2021
1 parent f911065 commit 9d71f60
Showing 11 changed files with 548 additions and 159 deletions.
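
Taken together, the diffs below swap `Isaac64Rng` for `SmallRng` and simplify the prediction calls. The following is a condensed sketch of the updated Iris example, assembled from the hunks that follow (variable names and the abridged parameter list are ours, not the verbatim example):

```rust
use linfa::prelude::*;
use linfa_trees::{DecisionTree, SplitQuality};
use ndarray_rand::rand::SeedableRng;
use rand::rngs::SmallRng;

fn main() {
    // SmallRng replaces Isaac64Rng throughout this commit.
    let mut rng = SmallRng::seed_from_u64(42);
    let (train, test) = linfa_datasets::iris()
        .shuffle(&mut rng)
        .split_with_ratio(0.8);

    // Fit a Gini-criterion tree, as in examples/decision_tree.rs below.
    let model = DecisionTree::params()
        .split_quality(SplitQuality::Gini)
        .fit(&train);

    // predict() now takes the test dataset directly instead of a records view.
    let pred = model.predict(&test);
    let cm = pred.confusion_matrix(&test);
    println!("accuracy: {}%", 100.0 * cm.accuracy());
}
```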
2 changes: 1 addition & 1 deletion linfa-trees/Cargo.toml
@@ -30,7 +30,7 @@ ndarray-rand = "0.11"
linfa = { version = "0.3.0", path = ".." }

[dev-dependencies]
-rand_isaac = "0.2.0"
+rand = { version = "0.7", features = ["small_rng"] }
criterion = "0.3"
approx = "0.3"

7 changes: 4 additions & 3 deletions linfa-trees/README.md
@@ -14,13 +14,14 @@ Decision Trees (DTs) are a non-parametric supervised learning method used for cl

## Examples

-There is an example in the `examples/` directory how to use decision trees. To run, use:
+There is an example in the `examples/` directory showing how to use decision trees. To run, use:

```bash
$ cargo run --release --example decision_tree --features linfa/intel-mkl-system
```

This generates the following tree:

<p align="center">
<img src="./iris-decisiontree.svg">
</p>
18 changes: 8 additions & 10 deletions linfa-trees/benches/decision_tree.rs
@@ -1,14 +1,13 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use linfa::prelude::*;
use linfa_trees::DecisionTree;
-use ndarray::{stack, Array, Array2, Axis};
+use ndarray::{stack, Array, Array1, Array2, Axis};
use ndarray_rand::rand::SeedableRng;
use ndarray_rand::rand_distr::{StandardNormal, Uniform};
use ndarray_rand::RandomExt;
-use rand_isaac::Isaac64Rng;
-use std::iter::FromIterator;
+use rand::rngs::SmallRng;

-fn generate_blobs(means: &Array2<f64>, samples: usize, mut rng: &mut Isaac64Rng) -> Array2<f64> {
+fn generate_blobs(means: &Array2<f64>, samples: usize, mut rng: &mut SmallRng) -> Array2<f64> {
let out = means
.axis_iter(Axis(0))
.map(|mean| Array::random_using((samples, 4), StandardNormal, &mut rng) + mean)
@@ -19,7 +18,7 @@ fn generate_blobs(means: &Array2<f64>, samples: usize, mut rng: &mut Isaac64Rng)
}

fn decision_tree_bench(c: &mut Criterion) {
-let mut rng = Isaac64Rng::seed_from_u64(42);
+let mut rng = SmallRng::seed_from_u64(42);

// Controls how many samples for each class are generated
let training_set_sizes = vec![100, 1000, 10000, 100000];
@@ -39,11 +38,10 @@ fn decision_tree_bench(c: &mut Criterion) {
Array2::random_using((n_classes, n_features), Uniform::new(-30., 30.), &mut rng);

let train_x = generate_blobs(&centroids, *n, &mut rng);
-let train_y = Array::from_iter(
-    (0..n_classes)
-        .map(|x| std::iter::repeat(x).take(*n).collect::<Vec<usize>>())
-        .flatten(),
-);
+let train_y: Array1<usize> = (0..n_classes)
+    .map(|x| std::iter::repeat(x).take(*n).collect::<Vec<usize>>())
+    .flatten()
+    .collect();
let dataset = DatasetBase::new(train_x, train_y);

group.bench_with_input(BenchmarkId::from_parameter(n), &dataset, |b, d| {
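
The `train_y` rewrite above works because `ndarray`'s `Array1` implements `FromIterator`, so the labels can be collected directly and the removed `std::iter::FromIterator` import is no longer needed. A minimal sketch of the idiom, written equivalently with `flat_map` (the sizes here are illustrative):

```rust
use ndarray::Array1;

fn main() {
    let n_classes = 3;
    let samples = 4;
    // One label per generated sample: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2].
    // Array1 implements FromIterator, so collect() builds it directly.
    let train_y: Array1<usize> = (0..n_classes)
        .flat_map(|class| std::iter::repeat(class).take(samples))
        .collect();
    assert_eq!(train_y.len(), n_classes * samples);
}
```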
24 changes: 17 additions & 7 deletions linfa-trees/examples/decision_tree.rs
@@ -2,14 +2,15 @@ use std::fs::File;
use std::io::Write;

use ndarray_rand::rand::SeedableRng;
-use rand_isaac::Isaac64Rng;
+use rand::rngs::SmallRng;

use linfa::prelude::*;
use linfa_trees::{DecisionTree, SplitQuality};

fn main() {
// load Iris dataset
-let mut rng = Isaac64Rng::seed_from_u64(42);
+let mut rng = SmallRng::seed_from_u64(42);

let (train, test) = linfa_datasets::iris()
.shuffle(&mut rng)
.split_with_ratio(0.8);
@@ -22,7 +23,7 @@ fn main() {
.min_weight_leaf(1.0)
.fit(&train);

-let gini_pred_y = gini_model.predict(test.records().view());
+let gini_pred_y = gini_model.predict(&test);
let cm = gini_pred_y.confusion_matrix(&test);

println!("{:?}", cm);
@@ -32,6 +33,9 @@
100.0 * cm.accuracy()
);

+let feats = gini_model.features();
+println!("Features trained in this tree {:?}", feats);

println!("Training model with entropy criterion ...");
let entropy_model = DecisionTree::params()
.split_quality(SplitQuality::Entropy)
@@ -40,7 +44,7 @@
.min_weight_leaf(10.0)
.fit(&train);

-let entropy_pred_y = gini_model.predict(test.records().view());
+let entropy_pred_y = entropy_model.predict(&test);
let cm = entropy_pred_y.confusion_matrix(&test);

println!("{:?}", cm);
@@ -54,7 +58,13 @@
println!("Features trained in this tree {:?}", feats);

let mut tikz = File::create("decision_tree_example.tex").unwrap();
-tikz.write(gini_model.export_to_tikz().to_string().as_bytes())
-    .unwrap();
-println!(" => generate tree description with `latex decision_tree_example.tex`!");
+tikz.write_all(
+    gini_model
+        .export_to_tikz()
+        .with_legend()
+        .to_string()
+        .as_bytes(),
+)
+.unwrap();
+println!(" => generate Gini tree description with `latex decision_tree_example.tex`!");
}
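
Pulled out of the diff above for readability, the new TikZ export chain can be wrapped as below. This is a minimal sketch, assuming a fitted `DecisionTree<f64, usize>` such as `gini_model` from the example; the `export_tree` helper name is ours, while `export_to_tikz()` and the `with_legend()` builder method are the calls introduced in this commit:

```rust
use std::fs::File;
use std::io::Write;

use linfa_trees::DecisionTree;

// Write the tree as a standalone LaTeX document; with_legend() appends
// a legend explaining the node layout to the generated picture.
fn export_tree(tree: &DecisionTree<f64, usize>) {
    let mut tikz = File::create("decision_tree_example.tex").unwrap();
    tikz.write_all(tree.export_to_tikz().with_legend().to_string().as_bytes())
        .unwrap();
    // Render it afterwards with `latex decision_tree_example.tex`.
}
```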