Skip to content

Commit

Permalink
Introduce multi targets and probability matrices (#88)
Browse files Browse the repository at this point in the history
 * rename `Targets` to `AsTargets` and introduce `AsTargetsMut` for mutable borrowing and `FromTargetArray` for counted labels construction
 * default to multi-target datasets in `pub type Dataset..`
 * re-name `observations` to `nsamples` and add `ntargets`
 * rename `frequencies_with_mask` to `label_frequencies_with_mask`
 * add `sample_iter`, `target_iter`, `feature_iter` iterators
 * rename `axis_chunks_iter` to `sample_chunks`
 * rename `bootstrap` to `bootstrap_samples` and add `bootstrap` (simultaneously sample features and samples) and `bootstrap_features`
 * add a `try_single_target -> Result<..>` which may return an `Error::MultipleTargets` for use in ROC/SVM/etc.
 * add `FromTarget` in order to accept `Array1` in `Dataset::new` and use `Targets == Array2<()>` for empty targets
 * add `PredictRef` which is used by a "proxy" trait `Predict` to derive all possible input/output combinations
 * add tests to `linfa-datasets`
 * add error types with `thiserror` where possible
  • Loading branch information
bytesnake authored Feb 23, 2021
1 parent 9d71f60 commit b231118
Show file tree
Hide file tree
Showing 55 changed files with 2,133 additions and 1,429 deletions.
23 changes: 23 additions & 0 deletions CONTRIBUTE.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,29 @@ fn div_capped<F: Float>(num: F) {
}
```

## Implement prediction traits

There are two different traits for predictions, `Predict` and `PredictRef`. `PredictRef` takes a reference to the records and produces a new set of targets. This should be implemented by a new algorithms, for example:
```rust
impl<F: Float, D: Data<Elem = F>> PredictRef<ArrayBase<D, Ix2>, Array1<F>> for Svm<F, F> {
fn predict_ref<'a>(&'a self, data: &ArrayBase<D; Ix2>) -> Array1<F> {
data.outer_iter().map(|data| {
self.normal.dot(&data) - self.rho
})
.collect()
}
}
```

This implementation is then used by `Predict` to provide the following `records` and `targets` combinations:

* `Dataset` -> `Dataset`
* `&Dataset` -> `Array1`
* `Array2` -> `Dataset`
* `&Array2` -> `Array1`

and should be imported by the user.

## Make serde optionally

If you want to implement `Serialize` and `Deserialize` for your parameters, please do that behind a feature flag. You can add to your cargo manifest
Expand Down
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ blas = ["ndarray/blas"]

[dependencies]
num-traits = "0.2"
thiserror = "1"
rand = { version = "0.7", features = ["small_rng"] }
ndarray = { version = "0.13", default-features = false, features = ["approx"] }
ndarray-linalg = { version = "0.12.1", optional = true }
Expand Down
3 changes: 3 additions & 0 deletions datasets/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ ndarray-csv = "0.4"
csv = "1.1"
flate2 = "1.0"

[dev-dependencies]
approx = { version = "0.3", default-features = false, features = ["std"] }

[features]
default = []
diabetes = []
Expand Down
109 changes: 109 additions & 0 deletions datasets/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,112 @@ pub fn winequality() -> Dataset<f64, usize> {
.map_targets(|x| *x as usize)
.with_feature_names(feature_names)
}

#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
use linfa::prelude::*;

#[cfg(feature = "iris")]
#[test]
fn test_iris() {
let ds = iris();

// check that we have the right amount of data
assert_eq!((ds.nsamples(), ds.nfeatures(), ds.ntargets()), (150, 4, 1));

// check for feature names
assert_eq!(
ds.feature_names(),
&["sepal length", "sepal width", "petal length", "petal width"]
);

// check label frequency
assert_eq!(
ds.label_frequencies()
.into_iter()
.map(|b| b.1)
.collect::<Vec<_>>(),
&[50., 50., 50.]
);

// perform correlation analysis and assert that petal length and width are correlated
let pcc = ds.pearson_correlation_with_p_value(100);
assert_abs_diff_eq!(pcc.get_p_values().unwrap()[5], 0.04, epsilon = 0.04);

// get the mean per feature
let mean_features = ds.records().mean_axis(Axis(0)).unwrap();
assert_abs_diff_eq!(
mean_features,
array![5.84, 3.05, 3.75, 1.20],
epsilon = 0.01
);
}

#[cfg(feature = "diabetes")]
#[test]
fn test_diabetes() {
let ds = diabetes();

// check that we have the right amount of data
assert_eq!((ds.nsamples(), ds.nfeatures(), ds.ntargets()), (441, 10, 1));

// perform correlation analysis and assert that T-Cells and low-density lipoproteins are
// correlated
let pcc = ds.pearson_correlation_with_p_value(100);
assert_abs_diff_eq!(pcc.get_p_values().unwrap()[30], 0.02, epsilon = 0.02);

// get the mean per feature, the data should be normalized
let mean_features = ds.records().mean_axis(Axis(0)).unwrap();
assert_abs_diff_eq!(mean_features, Array1::zeros(10), epsilon = 0.005);
}

#[cfg(feature = "winequality")]
#[test]
fn test_winequality() {
let ds = winequality();

// check that we have the right amount of data
assert_eq!(
(ds.nsamples(), ds.nfeatures(), ds.ntargets()),
(1599, 11, 1)
);

// check for feature names
let feature_names = vec![
"fixed acidity",
"volatile acidity",
"citric acid",
"residual sugar",
"chlorides",
"free sulfur dioxide",
"total sulfur dioxide",
"density",
"pH",
"sulphates",
"alcohol",
];
assert_eq!(ds.feature_names(), feature_names);

// check label frequency
let compare_to = vec![
(5, 681.0),
(7, 199.0),
(6, 638.0),
(8, 18.0),
(3, 10.0),
(4, 53.0),
];

let freqs = ds.label_frequencies();
assert!(compare_to
.into_iter()
.all(|(key, val)| { freqs.get(&key).map(|x| *x == val).unwrap_or(false) }));

// perform correlation analysis and assert that fixed acidity and citric acid are
// correlated
let pcc = ds.pearson_correlation_with_p_value(100);
assert_abs_diff_eq!(pcc.get_p_values().unwrap()[1], 0.05, epsilon = 0.05);
}
}
2 changes: 2 additions & 0 deletions linfa-bayes/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ categories = ["algorithms", "mathematics", "science"]
[dependencies]
ndarray = { version = "0.13" , features = ["blas", "approx"]}
ndarray-stats = "0.3"
thiserror = "1"

linfa = { version = "0.3.0", path = ".." }

[dev-dependencies]
Expand Down
10 changes: 4 additions & 6 deletions linfa-bayes/examples/winequality.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use std::error::Error;

use linfa::metrics::ToConfusionMatrix;
use linfa::traits::{Fit, Predict};
use linfa_bayes::GaussianNbParams;
use linfa_bayes::{GaussianNbParams, Result};

fn main() -> Result<(), Box<dyn Error>> {
fn main() -> Result<()> {
// Read in the dataset and convert continuous target into categorical
let (train, valid) = linfa_datasets::winequality()
.map_targets(|x| if *x > 6 { 1 } else { 0 })
Expand All @@ -14,10 +12,10 @@ fn main() -> Result<(), Box<dyn Error>> {
let model = GaussianNbParams::params().fit(&train.view())?;

// Predict the validation dataset
let pred = model.predict(valid.records.view());
let pred = model.predict(&valid);

// Construct confusion matrix
let cm = pred.confusion_matrix(&valid);
let cm = pred.confusion_matrix(&valid)?;

// classes | 1 | 0
// 1 | 10 | 12
Expand Down
26 changes: 6 additions & 20 deletions linfa-bayes/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,13 @@
use std::fmt;

use ndarray_stats::errors::MinMaxError;
use thiserror::Error;

pub type Result<T> = std::result::Result<T, BayesError>;

#[derive(Debug)]
#[derive(Error, Debug)]
pub enum BayesError {
/// Error when performing Max operation on data
Stats(MinMaxError),
}

impl fmt::Display for BayesError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::Stats(error) => write!(f, "Ndarray Stats Error: {}", error),
}
}
#[error("invalid statistical operation {0}")]
Stats(#[from] MinMaxError),
#[error(transparent)]
BaseCrate(#[from] linfa::Error),
}

impl From<MinMaxError> for BayesError {
fn from(error: MinMaxError) -> Self {
Self::Stats(error)
}
}

impl std::error::Error for BayesError {}
Loading

0 comments on commit b231118

Please sign in to comment.