From b0af80546f898bc95b746e10dbc1761a5268decb Mon Sep 17 00:00:00 2001 From: Ivano Donadi Date: Mon, 1 Feb 2021 22:09:56 +0100 Subject: [PATCH] Iter fold (#77) * fold_fit poc removed fit dataset lifetime, finished fold_fit tentative trait impl for kernel adjusted svm for iter_fold * updated elasticnet to new interface * fmt * removed fittable type from iter_fold definition * fmt * removed commented parameters --- linfa-bayes/src/gaussian_nb.rs | 2 +- linfa-elasticnet/src/algorithm.rs | 2 +- linfa-hierarchical/examples/irisflower.rs | 2 +- linfa-hierarchical/src/lib.rs | 23 +- linfa-kernel/Cargo.toml | 2 +- linfa-kernel/src/inner.rs | 143 +++++++ linfa-kernel/src/lib.rs | 402 ++++++++++-------- linfa-linear/src/ols.rs | 2 +- linfa-reduction/examples/diffusion_map.rs | 2 +- .../src/diffusion_map/algorithms.rs | 8 +- linfa-svm/Cargo.toml | 2 +- linfa-svm/examples/winequality.rs | 5 +- linfa-svm/src/classification.rs | 250 +++++++---- linfa-svm/src/lib.rs | 95 ++++- linfa-svm/src/permutable_kernel.rs | 94 ++-- linfa-svm/src/regression.rs | 136 ++++-- linfa-svm/src/solver_smo.rs | 226 +++++----- src/dataset/impl_dataset.rs | 136 ++++++ src/dataset/mod.rs | 100 +++++ src/traits.rs | 2 +- 20 files changed, 1148 insertions(+), 486 deletions(-) create mode 100644 linfa-kernel/src/inner.rs diff --git a/linfa-bayes/src/gaussian_nb.rs b/linfa-bayes/src/gaussian_nb.rs index 7190b04a8..e86714eac 100644 --- a/linfa-bayes/src/gaussian_nb.rs +++ b/linfa-bayes/src/gaussian_nb.rs @@ -78,7 +78,7 @@ where /// # Ok(()) /// # } /// ``` - fn fit(&self, dataset: &'a DatasetBase, L>) -> Self::Object { + fn fit(&self, dataset: &DatasetBase, L>) -> Self::Object { // We extract the unique classes in sorted order let mut unique_classes = dataset.targets.labels(); unique_classes.sort_unstable(); diff --git a/linfa-elasticnet/src/algorithm.rs b/linfa-elasticnet/src/algorithm.rs index 0bd3c454a..a061072c5 100644 --- a/linfa-elasticnet/src/algorithm.rs +++ b/linfa-elasticnet/src/algorithm.rs @@ -24,7 +24,7 @@ impl<'a, F: Float + AbsDiffEq + Lapack, D: Data, D2: Data> /// for new feature values. 
fn fit( &self, - dataset: &'a DatasetBase, ArrayBase>, + dataset: &DatasetBase, ArrayBase>, ) -> Result> { self.validate_params()?; diff --git a/linfa-hierarchical/examples/irisflower.rs b/linfa-hierarchical/examples/irisflower.rs index 2a4bfa7b2..517d65b9d 100644 --- a/linfa-hierarchical/examples/irisflower.rs +++ b/linfa-hierarchical/examples/irisflower.rs @@ -10,7 +10,7 @@ fn main() -> Result<(), Box> { let kernel = Kernel::params() .method(KernelMethod::Gaussian(1.0)) - .transform(dataset.records()); + .transform(dataset.records().view()); let kernel = HierarchicalCluster::default() .num_clusters(3) diff --git a/linfa-hierarchical/src/lib.rs b/linfa-hierarchical/src/lib.rs index 5ca24eaf2..3ed79be24 100644 --- a/linfa-hierarchical/src/lib.rs +++ b/linfa-hierarchical/src/lib.rs @@ -2,7 +2,6 @@ use std::collections::HashMap; use kodama::linkage; pub use kodama::Method; -use ndarray::ArrayView2; use linfa::dataset::{DatasetBase, Targets}; use linfa::traits::Transformer; @@ -58,17 +57,13 @@ impl HierarchicalCluster { } } -impl<'b: 'a, 'a, F: Float> - Transformer>, DatasetBase>, Vec>> +impl<'b: 'a, 'a, F: Float> Transformer, DatasetBase, Vec>> for HierarchicalCluster { /// Perform hierarchical clustering of a similarity matrix /// /// Returns the class id for each data point - fn transform( - &self, - kernel: Kernel>, - ) -> DatasetBase>, Vec> { + fn transform(&self, kernel: Kernel<'a, F>) -> DatasetBase, Vec> { // ignore all similarities below this value let threshold = F::from(1e-6).unwrap(); @@ -135,18 +130,16 @@ impl<'b: 'a, 'a, F: Float> } impl<'a, F: Float, T: Targets> - Transformer< - DatasetBase>, T>, - DatasetBase>, Vec>, - > for HierarchicalCluster + Transformer, T>, DatasetBase, Vec>> + for HierarchicalCluster { /// Perform hierarchical clustering of a similarity matrix /// /// Returns the class id for each data point fn transform( &self, - dataset: DatasetBase>, T>, - ) -> DatasetBase>, Vec> { + dataset: DatasetBase, T>, + ) -> DatasetBase, Vec> { //let Dataset { records, .. 
} = dataset; self.transform(dataset.records) } @@ -188,7 +181,7 @@ mod tests { let kernel = Kernel::params() .method(KernelMethod::Gaussian(5.0)) - .transform(&entries); + .transform(entries.view()); let kernel = HierarchicalCluster::default() .max_distance(0.1) @@ -243,7 +236,7 @@ mod tests { let kernel = Kernel::params() .method(KernelMethod::Linear) - .transform(&data); + .transform(data.view()); dbg!(&kernel.to_upper_triangle()); let predictions = HierarchicalCluster::default() diff --git a/linfa-kernel/Cargo.toml b/linfa-kernel/Cargo.toml index c287ae959..2d2c5a83f 100644 --- a/linfa-kernel/Cargo.toml +++ b/linfa-kernel/Cargo.toml @@ -25,7 +25,7 @@ features = ["std", "derive"] [dependencies] ndarray = "0.13" -sprs = { version = "0.9", default-features = false } +sprs = { version = "0.9.3", default-features = false } hnsw = "0.6" space = "0.10" diff --git a/linfa-kernel/src/inner.rs b/linfa-kernel/src/inner.rs new file mode 100644 index 000000000..fc94dfe2a --- /dev/null +++ b/linfa-kernel/src/inner.rs @@ -0,0 +1,143 @@ +use linfa::Float; +use ndarray::prelude::*; +use ndarray::Data; +#[cfg(feature = "serde")] +use serde_crate::{Deserialize, Serialize}; +use sprs::{CsMat, CsMatView}; +use std::ops::Mul; + +pub trait Inner { + type Elem: Float; + + fn dot(&self, rhs: &ArrayView2) -> Array2; + fn sum(&self) -> Array1; + fn size(&self) -> usize; + fn column(&self, i: usize) -> Vec; + fn to_upper_triangle(&self) -> Vec; + fn is_dense(&self) -> bool; + fn diagonal(&self) -> Array1; +} + +pub enum KernelInner { + Dense(K1), + Sparse(K2), +} + +impl> Inner for ArrayBase { + type Elem = F; + + fn dot(&self, rhs: &ArrayView2) -> Array2 { + self.dot(rhs) + } + fn sum(&self) -> Array1 { + self.sum_axis(Axis(1)) + } + fn size(&self) -> usize { + self.ncols() + } + fn column(&self, i: usize) -> Vec { + self.column(i).to_vec() + } + fn to_upper_triangle(&self) -> Vec { + self.indexed_iter() + .filter(|((row, col), _)| col > row) + .map(|(_, val)| *val) + .collect() + } + + fn diagonal(&self) -> Array1 { + self.diag().to_owned() + } + + fn is_dense(&self) -> bool { + true + } +} + +impl Inner for CsMat { + type Elem = F; + + fn dot(&self, rhs: &ArrayView2) -> Array2 { + self.mul(rhs) + } + fn sum(&self) -> Array1 { + let mut sum = Array1::zeros(self.cols()); + for (val, i) in self.iter() { + let (_, col) = i; + sum[col] += *val; + } + + sum + } + fn size(&self) -> usize { + self.cols() + } + fn column(&self, i: usize) -> Vec { + (0..self.size()) + .map(|j| *self.get(j, i).unwrap_or(&F::neg_zero())) + .collect::>() + } + fn to_upper_triangle(&self) -> Vec { + let mat = self.to_dense(); + mat.indexed_iter() + .filter(|((row, col), _)| col > row) + .map(|(_, val)| *val) + .collect() + } + + fn diagonal(&self) -> Array1 { + let diag_sprs = self.diag(); + let mut diag = Array1::zeros(diag_sprs.dim()); + for (sparse_i, sparse_elem) in diag_sprs.iter() { + diag[sparse_i] = *sparse_elem; + } + diag + } + + fn is_dense(&self) -> bool { + false + } +} + +impl<'a, F: Float> Inner for CsMatView<'a, F> { + type Elem = F; + + fn dot(&self, rhs: &ArrayView2) -> Array2 { + self.mul(rhs) + } + fn sum(&self) -> Array1 { + let mut sum = Array1::zeros(self.cols()); + for (val, i) in self.iter() { + let (_, col) = i; + sum[col] += *val; + } + + sum + } + fn size(&self) -> usize { + self.cols() + } + fn column(&self, i: usize) -> Vec { + (0..self.size()) + .map(|j| *self.get(j, i).unwrap_or(&F::neg_zero())) + .collect::>() + } + fn to_upper_triangle(&self) -> Vec { + let mat = self.to_dense(); + mat.indexed_iter() + 
.filter(|((row, col), _)| col > row)
+            .map(|(_, val)| *val)
+            .collect()
+    }
+    fn diagonal(&self) -> Array1<F> {
+        let diag_sprs = self.diag();
+        let mut diag = Array1::zeros(diag_sprs.dim());
+        for (sparse_i, sparse_elem) in diag_sprs.iter() {
+            diag[sparse_i] = *sparse_elem;
+        }
+        diag
+    }
+    fn is_dense(&self) -> bool {
+        false
+    }
+}
diff --git a/linfa-kernel/src/lib.rs b/linfa-kernel/src/lib.rs
index aac0561ce..dc713e2a1 100644
--- a/linfa-kernel/src/lib.rs
+++ b/linfa-kernel/src/lib.rs
@@ -1,12 +1,14 @@
 //! Kernel methods
 //!
+pub mod inner;
 mod sparse;
+pub use inner::{Inner, KernelInner};
 use ndarray::prelude::*;
 use ndarray::Data;
 #[cfg(feature = "serde")]
 use serde_crate::{Deserialize, Serialize};
-use sprs::CsMat;
+use sprs::{CsMat, CsMatView};
 use std::ops::Mul;
 
 use linfa::{dataset::DatasetBase, dataset::Records, dataset::Targets, traits::Transformer, Float};
@@ -20,18 +22,6 @@ pub enum KernelType {
     Sparse(usize),
 }
 
-/// Storage for the kernel matrix
-#[cfg_attr(
-    feature = "serde",
-    derive(Serialize, Deserialize),
-    serde(crate = "serde_crate")
-)]
-#[derive(Debug)]
-pub enum KernelInner<F: Float> {
-    Dense(Array2<F>),
-    Sparse(CsMat<F>),
-}
-
 /// A generic kernel
 ///
 ///
@@ -40,18 +30,20 @@ pub enum KernelInner<F: Float> {
     derive(Serialize, Deserialize),
     serde(crate = "serde_crate")
 )]
-pub struct Kernel<R: Records>
+pub struct KernelBase<R: Records, K1: Inner, K2: Inner>
 where
     R::Elem: Float,
+    K1::Elem: Float,
+    K2::Elem: Float,
 {
     #[cfg_attr(
         feature = "serde",
         serde(bound(
-            serialize = "KernelInner<R::Elem>: Serialize",
-            deserialize = "KernelInner<R::Elem>: Deserialize<'de>"
+            serialize = "KernelInner<K1, K2>: Serialize",
+            deserialize = "KernelInner<K1, K2>: Deserialize<'de>"
         ))
     )]
-    pub inner: KernelInner<R::Elem>,
+    pub inner: KernelInner<K1, K2>,
     #[cfg_attr(
         feature = "serde",
         serde(bound(
@@ -63,21 +55,46 @@ where
     pub dataset: R,
 }
 
-impl<'a, F: Float> Kernel<ArrayView2<'a, F>> {
-    pub fn new(
-        dataset: ArrayView2<'a, F>,
-        method: KernelMethod<F>,
-        kind: KernelType,
-    ) -> Kernel<ArrayView2<'a, F>> {
-        let inner = match kind {
-            KernelType::Dense => KernelInner::Dense(dense_from_fn(&dataset, &method)),
-            KernelType::Sparse(k) => KernelInner::Sparse(sparse_from_fn(&dataset, k, &method)),
-        };
+pub type KernelArrayBase<F, D> = KernelBase<ArrayBase<D, Ix2>, Array2<F>, CsMat<F>>;
+pub type KernelOwned<F> = KernelBase<Array2<F>, Array2<F>, CsMat<F>>;
+pub type Kernel<'a, F> = KernelBase<ArrayView2<'a, F>, Array2<F>, CsMat<F>>;
+pub type KernelView<'a, F> = KernelBase<ArrayView2<'a, F>, ArrayView2<'a, F>, CsMatView<'a, F>>;
 
-        Kernel {
-            inner,
-            method,
-            dataset,
+impl<F: Float, R: Records<Elem = F>, K1: Inner<Elem = F>, K2: Inner<Elem = F>>
+    KernelBase<R, K1, K2>
+{
+    /// Whether the kernel is a linear kernel
+    ///
+    /// ## Returns
+    ///
+    /// - `true`: if the kernel is linear
+    /// - `false`: otherwise
+    pub fn is_linear(&self) -> bool {
+        self.method.is_linear()
+    }
+
+    /// Generates the default set of parameters for building a kernel.
+    /// Use this to initialize a set of parameters to be customized using `KernelParams`'s methods
+    ///
+    /// ## Example
+    ///
+    /// ```rust
+    ///
+    /// use linfa_kernel::Kernel;
+    /// use linfa::traits::Transformer;
+    /// use ndarray::Array2;
+    ///
+    /// let data = Array2::from_shape_vec((3,2), vec![1., 2., 3., 4., 5., 6.,]).unwrap();
+    ///
+    /// // Build a kernel from `data` with the default parameters
+    /// let params = Kernel::params();
+    /// let kernel = params.transform(data);
+    ///
+    /// ```
+    pub fn params() -> KernelParams<F> {
+        KernelParams {
+            kind: KernelType::Dense,
+            method: KernelMethod::Gaussian(F::from(0.5).unwrap()),
         }
     }
@@ -98,8 +115,8 @@ impl<'a, F: Float> Kernel<ArrayView2<'a, F>> {
     /// If the shapes of kernel and `rhs` are not compatible for multiplication
     pub fn dot(&self, rhs: &ArrayView2<F>) -> Array2<F> {
         match &self.inner {
-            KernelInner::Dense(mat) => mat.dot(rhs),
-            KernelInner::Sparse(mat) => mat.mul(rhs),
+            KernelInner::Dense(inn) => inn.dot(rhs),
+            KernelInner::Sparse(inn) => inn.dot(rhs),
         }
     }
@@ -110,24 +127,36 @@ impl<'a, F: Float> Kernel<ArrayView2<'a, F>> {
     /// A new array with the sum of all the elements in each row
     pub fn sum(&self) -> Array1<F> {
         match &self.inner {
-            KernelInner::Dense(mat) => mat.sum_axis(Axis(1)),
-            KernelInner::Sparse(mat) => {
-                let mut sum = Array1::zeros(mat.cols());
-                for (val, i) in mat.iter() {
-                    let (_, col) = i;
-                    sum[col] += *val;
-                }
-
-                sum
-            }
+            KernelInner::Dense(inn) => inn.sum(),
+            KernelInner::Sparse(inn) => inn.sum(),
         }
     }
 
     /// Gives the size of the side of the square kernel matrix
     pub fn size(&self) -> usize {
         match &self.inner {
-            KernelInner::Dense(mat) => mat.ncols(),
-            KernelInner::Sparse(mat) => mat.cols(),
+            KernelInner::Dense(inn) => inn.size(),
+            KernelInner::Sparse(inn) => inn.size(),
+        }
+    }
+
+    /// Getter for a column of the kernel matrix
+    ///
+    /// ## Params
+    ///
+    /// - `i`: the index of the column
+    ///
+    /// ## Returns
+    ///
+    /// The i-th column of the kernel matrix, stored as a `Vec`
+    ///
+    /// ## Panics
+    ///
+    /// If `i` is out of bounds
+    pub fn column(&self, i: usize) -> Vec<F> {
+        match &self.inner {
+            KernelInner::Dense(inn) => inn.column(i),
+            KernelInner::Sparse(inn) => inn.column(i),
         }
     }
@@ -140,18 +169,8 @@ impl<'a, F: Float> Kernel<ArrayView2<'a, F>> {
     /// matrix, stored in a `Vec`
     pub fn to_upper_triangle(&self) -> Vec<F> {
         match &self.inner {
-            KernelInner::Dense(mat) => mat
-                .indexed_iter()
-                .filter(|((row, col), _)| col > row)
-                .map(|(_, val)| *val)
-                .collect(),
-            KernelInner::Sparse(mat) => {
-                let mat = mat.to_dense();
-                mat.indexed_iter()
-                    .filter(|((row, col), _)| col > row)
-                    .map(|(_, val)| *val)
-                    .collect()
-            }
+            KernelInner::Dense(inn) => inn.to_upper_triangle(),
+            KernelInner::Sparse(inn) => inn.to_upper_triangle(),
         }
     }
@@ -163,37 +182,13 @@ impl<'a, F: Float> Kernel<ArrayView2<'a, F>> {
     /// the kernel matrix
     pub fn diagonal(&self) -> Array1<F> {
         match &self.inner {
-            KernelInner::Dense(mat) => mat.diag().to_owned(),
-            KernelInner::Sparse(_) => self
-                .dataset
-                .outer_iter()
-                .map(|x| self.method.distance(x.view(), x.view()))
-                .collect(),
-        }
-    }
-
-    /// Getter for a column of the kernel matrix
-    ///
-    /// ## Params
-    ///
-    /// - `i`: the index of the column
-    ///
-    /// ## Returns
-    ///
-    /// The i-th column of the kernel matrix, stored as a `Vec`
-    ///
-    /// ## Panics
-    ///
-    /// If `i` is out of bounds
-    pub fn column(&self, i: usize) -> Vec<F> {
-        match &self.inner {
-            KernelInner::Dense(mat) => mat.column(i).to_vec(),
-            KernelInner::Sparse(mat) => (0..self.size())
-                .map(|j| *mat.get(j, i).unwrap_or(&F::neg_zero()))
-                .collect::<Vec<F>>(),
+            KernelInner::Dense(inn) => inn.diagonal(),
+            KernelInner::Sparse(inn) => inn.diagonal(),
         }
     }
+}
 
+impl<F: Float, D: Data<Elem = F>> KernelArrayBase<F, D> {
     /// Sums the inner product of `sample` and every one of the samples
     /// used to generate the kernel
     ///
     /// ## Parameters
@@ -219,43 +214,100 @@ impl<'a, F: Float> Kernel<ArrayView2<'a, F>> {
             .sum()
     }
 
-    /// Wheter the kernel is a linear kernel
-    ///
-    /// ## Returns
-    ///
-    /// - `true`: if the kernel is linear
-    /// - `false`: otherwise
-    pub fn is_linear(&self) -> bool {
-        self.method.is_linear()
+    /// Gives a KernelView which has a view on the original kernel's inner matrix and dataset
+    pub fn view<'a>(&'a self) -> KernelView<'a, F> {
+        KernelView {
+            inner: match &self.inner {
+                KernelInner::Dense(inn) => KernelInner::Dense(inn.view()),
+                KernelInner::Sparse(inn) => KernelInner::Sparse(inn.view()),
+            },
+            method: self.method.clone(),
+            dataset: self.dataset.view(),
+        }
     }
+}
 
-    /// Generates the default set of parameters for building a kernel.
-    /// Use this to initialize a set of parameters to be customized using `KernelParams`'s methods
-    ///
-    /// ## Example
-    ///
-    /// ```rust
-    ///
-    /// use linfa_kernel::Kernel;
-    /// use linfa::traits::Transformer;
-    /// use ndarray::Array2;
-    ///
-    /// let data = Array2::from_shape_vec((3,2), vec![1., 2., 3., 4., 5., 6.,]).unwrap();
-    ///
-    /// // Build a kernel from `data` with the defaul parameters
-    /// let params = Kernel::params();
-    /// let kernel = params.transform(&data);
-    ///
-    /// ```
-    pub fn params() -> KernelParams<F> {
-        KernelParams {
-            kind: KernelType::Dense,
-            method: KernelMethod::Gaussian(F::from(0.5).unwrap()),
+// This particular implementation is created with the idea of using it for training models
+// when k-folding; for all other uses it would be best to use the `ArrayView` impl.
+// For this reason all other kernel methods are not implemented, because for that limited
+// scope it would be better to call them on `kernel.view()`
+impl<F: Float> KernelOwned<F> {
+    /// Generates a new kernel which owns its dataset
+    pub fn new(dataset: Array2<F>, method: KernelMethod<F>, kind: KernelType) -> KernelOwned<F> {
+        match kind {
+            KernelType::Dense => KernelOwned {
+                inner: KernelInner::Dense(dense_from_fn(&dataset, &method)),
+                method: method,
+                dataset: dataset,
+            },
+            KernelType::Sparse(k) => KernelOwned {
+                inner: KernelInner::Sparse(sparse_from_fn(&dataset, k, &method)),
+                method: method,
+                dataset: dataset,
+            },
+        }
+    }
+}
+
+impl<F: Float> std::clone::Clone for KernelOwned<F> {
+    fn clone(&self) -> KernelOwned<F> {
+        KernelOwned {
+            inner: match &self.inner {
+                KernelInner::Dense(inn) => KernelInner::Dense(inn.clone()),
+                KernelInner::Sparse(inn) => KernelInner::Sparse(inn.clone()),
+            },
+            method: self.method.clone(),
+            dataset: self.dataset.clone(),
+        }
+    }
+}
+
+impl<'a, F: Float> Kernel<'a, F> {
+    pub fn new(
+        dataset: ArrayView2<'a, F>,
+        method: KernelMethod<F>,
+        kind: KernelType,
+    ) -> Kernel<'a, F> {
+        let inner = match kind {
+            KernelType::Dense => KernelInner::Dense(dense_from_fn(&dataset, &method)),
+            KernelType::Sparse(k) => KernelInner::Sparse(sparse_from_fn(&dataset, k, &method)),
+        };
+
+        Kernel {
+            inner,
+            method,
+            dataset,
+        }
+    }
+
+    pub fn to_owned(&self) -> KernelOwned<F> {
+        KernelOwned {
+            inner: match &self.inner {
+                KernelInner::Dense(inn) => KernelInner::Dense(inn.clone()),
+                KernelInner::Sparse(inn) => KernelInner::Sparse(inn.clone()),
+            },
+            method: self.method.clone(),
+            dataset: self.dataset.to_owned(),
        }
    }
}
 
-impl<'a, F: Float> Records for Kernel<ArrayView2<'a, F>> {
+impl<'a, F: Float> KernelView<'a, F> {
+    pub fn to_owned(&self) -> KernelOwned<F> {
+        KernelOwned {
+            inner: match &self.inner {
+                KernelInner::Dense(inn) => KernelInner::Dense(inn.to_owned()),
+                KernelInner::Sparse(inn) => KernelInner::Sparse(inn.to_owned()),
+            },
+            method: self.method.clone(),
+            dataset: self.dataset.to_owned(),
+        }
+    }
+}
+
+impl<F: Float, R: Records<Elem = F>, K1: Inner<Elem = F>, K2: Inner<Elem = F>> Records
+    for KernelBase<R, K1, K2>
+{
     type Elem = F;
 
     fn observations(&self) -> usize {
@@ -342,7 +394,7 @@ impl<F: Float> KernelParams<F> {
     /// // Build a kernel from `data` with the default parameters
     /// // and then set the preferred method
     /// let params = Kernel::params().method(KernelMethod::Linear);
-    /// let kernel = params.transform(&data);
+    /// let kernel = params.transform(data);
     /// ```
     pub fn method(mut self, method: KernelMethod<F>) -> KernelParams<F> {
         self.method = method;
@@ -372,7 +424,7 @@ impl<F: Float> KernelParams<F> {
     /// // Build a kernel from `data` with the default parameters
     /// // and then set the preferred kind
     /// let params = Kernel::params().kind(KernelType::Dense);
-    /// let kernel = params.transform(&data);
+    /// let kernel = params.transform(data);
     /// ```
     pub fn kind(mut self, kind: KernelType) -> KernelParams<F> {
         self.kind = kind;
         self
     }
 }
 
-impl<'a, F: Float> Transformer<&'a Array2<F>, Kernel<ArrayView2<'a, F>>> for KernelParams<F> {
-    /// Builds a kernel from the input data without copying it.
+impl<'a, F: Float> Transformer<Array2<F>, KernelOwned<F>> for KernelParams<F> {
+    /// Builds a kernel from the input data and takes ownership of it.
     ///
-    /// A reference to the input data will be kept by the kernel
-    /// through an `ArrayView`
     ///
     /// ## Parameters
     ///
-    /// - `x`: matrix of records (##records, ##features) in input
+    /// - `x`: matrix of records (#records, #features) in input
     ///
     /// ## Returns
     ///
     /// A kernel built from `x` according to the parameters on which
     /// this method is called
     ///
     /// ## Panics
     ///
     /// If the kernel type is `Sparse` and the number of neighbors specified is
-    /// not between 1 and ##records-1
-    fn transform(&self, x: &'a Array2<F>) -> Kernel<ArrayView2<'a, F>> {
-        Kernel::new(x.view(), self.method.clone(), self.kind.clone())
+    /// not between 1 and #records-1
+    fn transform(&self, x: Array2<F>) -> KernelOwned<F> {
+        KernelOwned::new(x, self.method.clone(), self.kind.clone())
     }
 }
 
-impl<'a, F: Float> Transformer<ArrayView2<'a, F>, Kernel<ArrayView2<'a, F>>> for KernelParams<F> {
+impl<'a, F: Float> Transformer<ArrayView2<'a, F>, Kernel<'a, F>> for KernelParams<F> {
     /// Builds a kernel from a view of the input data.
     ///
     /// A reference to the input data will be kept by the kernel
     /// through an `ArrayView`
     ///
     /// ## Parameters
     ///
-    /// - `x`: view of a matrix of records (##records, ##features)
+    /// - `x`: view of a matrix of records (#records, #features)
     ///
     /// A kernel built from `x` according to the parameters on which
     /// this method is called
     ///
     /// ## Panics
     ///
     /// If the kernel type is `Sparse` and the number of neighbors specified is
-    /// not between 1 and ##records-1
-    fn transform(&self, x: ArrayView2<'a, F>) -> Kernel<ArrayView2<'a, F>> {
+    /// not between 1 and #records-1
+    fn transform(&self, x: ArrayView2<'a, F>) -> Kernel<'a, F> {
         Kernel::new(x, self.method.clone(), self.kind.clone())
     }
 }
 
 impl<'a, F: Float, T: Targets>
-    Transformer<&'a DatasetBase<Array2<F>, T>, DatasetBase<Kernel<ArrayView2<'a, F>>, &'a T>>
-    for KernelParams<F>
+    Transformer<DatasetBase<Array2<F>, T>, DatasetBase<KernelOwned<F>, T>> for KernelParams<F>
 {
     /// Builds a new Dataset with the kernel as the records and the same targets as the input one.
     ///
-    /// A reference to the input records will be kept by the kernel
-    /// through an `ArrayView`
+    /// It takes ownership of the original dataset.
/// /// ## Parameters /// - /// - `x`: A dataset with a matrix of records (##records, ##features) and any targets + /// - `x`: A dataset with a matrix of records (#records, #features) and any targets /// /// ## Returns /// @@ -449,22 +497,19 @@ impl<'a, F: Float, T: Targets> /// ## Panics /// /// If the kernel type is `Sparse` and the number of neighbors specified is - /// not between 1 and ##records-1 - fn transform( - &self, - x: &'a DatasetBase, T>, - ) -> DatasetBase>, &'a T> { - let kernel = Kernel::new(x.records.view(), self.method.clone(), self.kind.clone()); + /// not between 1 and #records-1 + fn transform(&self, x: DatasetBase, T>) -> DatasetBase, T> { + let kernel = KernelOwned::new(x.records, self.method.clone(), self.kind.clone()); - DatasetBase::new(kernel, &x.targets) + DatasetBase::new(kernel, x.targets) } } -impl<'a, F: Float, T: Targets> - Transformer< - &'a DatasetBase, T>, - DatasetBase>, &'a [T::Elem]>, - > for KernelParams +// lifetime 'b allows the kernel to borrow the underlying data +// for a possibly shorter time than 'a, useful in fold_fit +impl<'a, 'b, F: Float, T: Targets> + Transformer<&'b DatasetBase, T>, DatasetBase, &'b [T::Elem]>> + for KernelParams { /// Builds a new Dataset with the kernel as the records and the same targets as the input one. /// @@ -488,9 +533,9 @@ impl<'a, F: Float, T: Targets> /// not between 1 and ##records-1 fn transform( &self, - x: &'a DatasetBase, T>, - ) -> DatasetBase>, &'a [T::Elem]> { - let kernel = Kernel::new(x.records, self.method.clone(), self.kind.clone()); + x: &'b DatasetBase, T>, + ) -> DatasetBase, &'b [T::Elem]> { + let kernel = Kernel::new(x.records.view(), self.method.clone(), self.kind.clone()); DatasetBase::new(kernel, x.targets.as_slice()) } @@ -696,19 +741,19 @@ mod tests { // dense kernel dot let mul_mat = dense_from_fn(&input_arr, &KernelMethod::Linear).dot(&to_multiply); - let kernel = Kernel::params() + let kernel = KernelView::params() .kind(KernelType::Dense) .method(KernelMethod::Linear) - .transform(&input_arr); + .transform(input_arr.view()); let mul_ker = kernel.dot(&to_multiply.view()); assert!(kernels_almost_equal(mul_mat.view(), mul_ker.view())); // sparse kernel dot let mul_mat = sparse_from_fn(&input_arr, 3, &KernelMethod::Linear).mul(&to_multiply.view()); - let kernel = Kernel::params() + let kernel = KernelView::params() .kind(KernelType::Sparse(3)) .method(KernelMethod::Linear) - .transform(&input_arr); + .transform(input_arr.view()); let mul_ker = kernel.dot(&to_multiply.view()); assert!(kernels_almost_equal(mul_mat.view(), mul_ker.view())); } @@ -723,12 +768,12 @@ mod tests { let input_arr = ndarray::stack(Axis(0), &[input_arr_1.view(), input_arr_2.view()]).unwrap(); for kind in vec![KernelType::Dense, KernelType::Sparse(1)] { - let kernel = Kernel::params() + let kernel = KernelView::params() .kind(kind) // Such a value for eps brings to zero the inner product // between any two points that are not equal .method(KernelMethod::Gaussian(1e-5)) - .transform(&input_arr); + .transform(input_arr.view()); let mut kernel_upper_triang = kernel.to_upper_triangle(); assert_eq!(kernel_upper_triang.len(), 45); //so that i can use pop() @@ -752,12 +797,12 @@ mod tests { let input_arr = Array2::from_shape_vec((10, 10), input_vec).unwrap(); let weights = [1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]; for kind in vec![KernelType::Dense, KernelType::Sparse(1)] { - let kernel = Kernel::params() + let kernel = KernelView::params() .kind(kind) // Such a value for eps brings to zero the inner product // between any 
two points that are not equal .method(KernelMethod::Gaussian(1e-5)) - .transform(&input_arr); + .transform(input_arr.view()); for (sample, w) in input_arr.outer_iter().zip(&weights) { // with that kernel, only the input samples have non // zero inner product with the samples used to generate the matrix. @@ -778,10 +823,10 @@ mod tests { // dense kernel sum let cols_sum = dense_from_fn(&input_arr, &method).sum_axis(Axis(1)); - let kernel = Kernel::params() + let kernel = KernelView::params() .kind(KernelType::Dense) .method(method.clone()) - .transform(&input_arr); + .transform(input_arr.view()); let kers_sum = kernel.sum(); assert!(arrays_almost_equal(cols_sum.view(), kers_sum.view())); @@ -789,10 +834,10 @@ mod tests { let cols_sum = sparse_from_fn(&input_arr, 3, &method) .to_dense() .sum_axis(Axis(1)); - let kernel = Kernel::params() + let kernel = KernelView::params() .kind(KernelType::Sparse(3)) .method(method) - .transform(&input_arr); + .transform(input_arr.view()); let kers_sum = kernel.sum(); assert!(arrays_almost_equal(cols_sum.view(), kers_sum.view())); } @@ -804,29 +849,29 @@ mod tests { let method = KernelMethod::Linear; - // dense kernel sum + // dense kernel diag let input_diagonal = dense_from_fn(&input_arr, &method).diag().into_owned(); - let kernel = Kernel::params() + let kernel = KernelView::params() .kind(KernelType::Dense) .method(method.clone()) - .transform(&input_arr); + .transform(input_arr.view()); let kers_diagonal = kernel.diagonal(); assert!(arrays_almost_equal( input_diagonal.view(), kers_diagonal.view() )); - // sparse kernel sum + // sparse kernel diag let input_diagonal: Vec<_> = sparse_from_fn(&input_arr, 3, &method) .outer_iterator() .enumerate() .map(|(i, row)| *row.get(i).unwrap()) .collect(); let input_diagonal = Array1::from_shape_vec(10, input_diagonal).unwrap(); - let kernel = Kernel::params() + let kernel = KernelView::params() .kind(KernelType::Sparse(3)) .method(method) - .transform(&input_arr); + .transform(input_arr.view()); let kers_diagonal = kernel.diagonal(); assert!(arrays_almost_equal( input_diagonal.view(), @@ -876,14 +921,16 @@ mod tests { KernelMethod::Polynomial(1., 2.), ]; for method in methods { - let kernel_ref = Kernel::new(input.records().view(), method.clone(), k_type.clone()); - let kernel_tr = Kernel::params() + let cloned_dataset = (input.records.clone(), input.targets().clone()).into(); + let kernel_ref = + KernelOwned::new(input.records().clone(), method.clone(), k_type.clone()); + let kernel_tr: DatasetBase, _> = Kernel::params() .kind(k_type.clone()) .method(method.clone()) - .transform(input); + .transform(cloned_dataset); assert!(kernels_almost_equal( - kernel_ref.dataset, - kernel_tr.records.dataset + kernel_ref.dataset.view(), + kernel_tr.records.dataset.view() )); } } @@ -917,12 +964,15 @@ mod tests { KernelMethod::Polynomial(1., 2.), ]; for method in methods { - let kernel_ref = Kernel::new(input.view(), method.clone(), k_type.clone()); + let kernel_ref = KernelOwned::new(input.clone(), method.clone(), k_type.clone()); let kernel_tr = Kernel::params() .kind(k_type.clone()) .method(method.clone()) - .transform(input); - assert!(kernels_almost_equal(kernel_ref.dataset, kernel_tr.dataset)); + .transform(input.clone()); + assert!(kernels_almost_equal( + kernel_ref.dataset.view(), + kernel_tr.dataset.view() + )); } } diff --git a/linfa-linear/src/ols.rs b/linfa-linear/src/ols.rs index f4488c44d..9c175f1fd 100644 --- a/linfa-linear/src/ols.rs +++ b/linfa-linear/src/ols.rs @@ -143,7 +143,7 @@ impl<'a, F: Float, D: Data, 
D2: Data> /// for new feature values. fn fit( &self, - dataset: &'a DatasetBase, ArrayBase>, + dataset: &DatasetBase, ArrayBase>, ) -> Result, String> { let X = dataset.records(); let y = dataset.targets(); diff --git a/linfa-reduction/examples/diffusion_map.rs b/linfa-reduction/examples/diffusion_map.rs index cd6fe1513..ee0c53c05 100644 --- a/linfa-reduction/examples/diffusion_map.rs +++ b/linfa-reduction/examples/diffusion_map.rs @@ -23,7 +23,7 @@ fn main() { .kind(KernelType::Sparse(15)) .method(KernelMethod::Gaussian(2.0)) //.kind(KernelType::Dense) - .transform(&dataset); + .transform(dataset.view()); let embedding = DiffusionMap::::params(2) .steps(1) diff --git a/linfa-reduction/src/diffusion_map/algorithms.rs b/linfa-reduction/src/diffusion_map/algorithms.rs index 8ec5eaa27..f06acd7af 100644 --- a/linfa-reduction/src/diffusion_map/algorithms.rs +++ b/linfa-reduction/src/diffusion_map/algorithms.rs @@ -1,4 +1,4 @@ -use ndarray::{Array1, Array2, ArrayView2}; +use ndarray::{Array1, Array2}; use ndarray_linalg::{ eigh::EighInto, lobpcg, lobpcg::LobpcgResult, Lapack, Scalar, TruncatedOrder, UPLO, }; @@ -16,10 +16,10 @@ pub struct DiffusionMap { eigvals: Array1, } -impl<'a, F: Float + Lapack> Transformer<&'a Kernel>, DiffusionMap> +impl<'a, F: Float + Lapack> Transformer<&'a Kernel<'a, F>, DiffusionMap> for DiffusionMapHyperParams { - fn transform(&self, kernel: &'a Kernel>) -> DiffusionMap { + fn transform(&self, kernel: &'a Kernel<'a, F>) -> DiffusionMap { // compute spectral embedding with diffusion map let (embedding, eigvals) = compute_diffusion_map(kernel, self.steps(), 0.0, self.embedding_size(), None); @@ -49,7 +49,7 @@ impl DiffusionMap { } fn compute_diffusion_map<'b, F: Float + Lapack>( - kernel: &'b Kernel>, + kernel: &'b Kernel<'b, F>, steps: usize, alpha: f32, embedding_size: usize, diff --git a/linfa-svm/Cargo.toml b/linfa-svm/Cargo.toml index b9df426ef..4b35dde25 100644 --- a/linfa-svm/Cargo.toml +++ b/linfa-svm/Cargo.toml @@ -32,5 +32,5 @@ linfa = { version = "0.3.0", path = ".." 
} linfa-kernel = { version = "0.3.0", path = "../linfa-kernel" } [dev-dependencies] -linfa-datasets = { version = "0.3.0", path = "../datasets", features = ["winequality"] } +linfa-datasets = { version = "0.3.0", path = "../datasets", features = ["winequality", "diabetes"] } rand_isaac = "0.2" diff --git a/linfa-svm/examples/winequality.rs b/linfa-svm/examples/winequality.rs index 588e827dc..a8b7f82ad 100644 --- a/linfa-svm/examples/winequality.rs +++ b/linfa-svm/examples/winequality.rs @@ -7,11 +7,12 @@ fn main() { let (train, valid) = linfa_datasets::winequality() .map_targets(|x| *x > 6) .split_with_ratio(0.9); + let train_view = train.view(); // transform with RBF kernel let train_kernel = Kernel::params() - .method(KernelMethod::Gaussian(30.0)) - .transform(&train); + .method(KernelMethod::Gaussian(80.0)) + .transform(&train_view); println!( "Fit SVM classifier with #{} training points", diff --git a/linfa-svm/src/classification.rs b/linfa-svm/src/classification.rs index d2a6e46a8..bb39d5d79 100644 --- a/linfa-svm/src/classification.rs +++ b/linfa-svm/src/classification.rs @@ -1,12 +1,13 @@ use linfa::{dataset::DatasetBase, dataset::Pr, dataset::Targets, traits::Fit, traits::Predict}; -use ndarray::{Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Data, Ix2}; +use ndarray::{Array1, ArrayBase, ArrayView1, ArrayView2, Data, Ix2}; use std::cmp::Ordering; use std::ops::Mul; -use super::permutable_kernel::{Kernel, PermutableKernel, PermutableKernelOneClass}; +use super::permutable_kernel::{PermutableKernel, PermutableKernelOneClass}; use super::solver_smo::SolverState; use super::SolverParams; use super::{Float, Svm, SvmParams}; +use linfa_kernel::{Kernel, KernelOwned, KernelView}; /// Support Vector Classification with C-penalizing parameter /// @@ -24,13 +25,13 @@ use super::{Float, Svm, SvmParams}; /// * `targets` - the ground truth targets `y_i` /// * `cpos` - C for positive targets /// * `cneg` - C for negative targets -pub fn fit_c<'a, A: Float>( - params: SolverParams, - kernel: &'a Kernel<'a, A>, - targets: &'a [bool], - cpos: A, - cneg: A, -) -> Svm<'a, A, Pr> { +pub fn fit_c( + params: SolverParams, + kernel: KernelOwned, + targets: &[bool], + cpos: F, + cneg: F, +) -> Svm { let bounds = targets .iter() .map(|x| if *x { cpos } else { cneg }) @@ -39,8 +40,8 @@ pub fn fit_c<'a, A: Float>( let kernel = PermutableKernel::new(kernel, targets.to_vec()); let solver = SolverState::new( - vec![A::zero(); targets.len()], - vec![-A::one(); targets.len()], + vec![F::zero(); targets.len()], + vec![-F::one(); targets.len()], targets.to_vec(), kernel, bounds, @@ -75,23 +76,23 @@ pub fn fit_c<'a, A: Float>( /// * `kernel` - the kernel matrix `Q` /// * `targets` - the ground truth targets `y_i` /// * `nu` - Nu penalizing term -pub fn fit_nu<'a, A: Float>( - params: SolverParams, - kernel: &'a Kernel<'a, A>, - targets: &'a [bool], - nu: A, -) -> Svm<'a, A, Pr> { - let mut sum_pos = nu * A::from(targets.len()).unwrap() / A::from(2.0).unwrap(); - let mut sum_neg = nu * A::from(targets.len()).unwrap() / A::from(2.0).unwrap(); +pub fn fit_nu( + params: SolverParams, + kernel: KernelOwned, + targets: &[bool], + nu: F, +) -> Svm { + let mut sum_pos = nu * F::from(targets.len()).unwrap() / F::from(2.0).unwrap(); + let mut sum_neg = nu * F::from(targets.len()).unwrap() / F::from(2.0).unwrap(); let init_alpha = targets .iter() .map(|x| { if *x { - let val = A::min(A::one(), sum_pos); + let val = F::min(F::one(), sum_pos); sum_pos -= val; val } else { - let val = A::min(A::one(), sum_neg); + let val = 
F::min(F::one(), sum_neg);
                 sum_neg -= val;
                 val
             }
         })
         .collect::<Vec<_>>();
 
@@ -102,10 +103,10 @@ pub fn fit_nu<'a, A: Float>(
     let kernel = PermutableKernel::new(kernel, targets.to_vec());
 
     let solver = SolverState::new(
         init_alpha,
-        vec![A::zero(); targets.len()],
+        vec![F::zero(); targets.len()],
         targets.to_vec(),
         kernel,
-        vec![A::one(); targets.len()],
+        vec![F::one(); targets.len()],
         params,
         true,
     );
@@ -137,19 +138,19 @@
 /// * `params` - Solver parameters (threshold etc.)
 /// * `kernel` - the kernel matrix `Q`
 /// * `nu` - Nu penalizing term
-pub fn fit_one_class<'a, A: Float + num_traits::ToPrimitive>(
-    params: SolverParams<A>,
-    kernel: &'a Kernel<'a, A>,
-    nu: A,
-) -> Svm<'a, A, Pr> {
+pub fn fit_one_class<F: Float + num_traits::ToPrimitive>(
+    params: SolverParams<F>,
+    kernel: KernelOwned<F>,
+    nu: F,
+) -> Svm<F, Pr> {
     let size = kernel.size();
-    let n = (nu * A::from(size).unwrap()).to_usize().unwrap();
+    let n = (nu * F::from(size).unwrap()).to_usize().unwrap();
 
     let init_alpha = (0..size)
         .map(|x| match x.cmp(&n) {
-            Ordering::Less => A::one(),
-            Ordering::Greater => A::zero(),
-            Ordering::Equal => nu * A::from(size).unwrap() - A::from(x).unwrap(),
+            Ordering::Less => F::one(),
+            Ordering::Greater => F::zero(),
+            Ordering::Equal => nu * F::from(size).unwrap() - F::from(x).unwrap(),
         })
         .collect::<Vec<_>>();
 
     let kernel = PermutableKernelOneClass::new(kernel);
 
     let solver = SolverState::new(
         init_alpha,
-        vec![A::zero(); size],
+        vec![F::zero(); size],
         vec![true; size],
         kernel,
-        vec![A::one(); size],
+        vec![F::one(); size],
         params,
         false,
     );
@@ -175,21 +176,21 @@
 /// For a given dataset with kernel matrix as records and two class problem as targets this fits
 /// an optimal hyperplane to the problem and returns the solution as a model. The model predicts
 /// probabilities for whether a sample belongs to the first or second class.
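As a usage sketch of this classification entry point (the parameter values and the `winequality` helper are illustrative; they mirror the example and the `iter_fold` test elsewhere in this patch):

```rust
use linfa::traits::{Fit, Predict, Transformer};
use linfa_kernel::{Kernel, KernelMethod};
use linfa_svm::Svm;

// Binarize the wine quality scores, move the records into an owned kernel
// dataset and fit a C-SVM on it; 80.0 is an example Gaussian bandwidth.
let dataset = linfa_datasets::winequality().map_targets(|x| *x > 6);
let kernel_dataset = Kernel::params()
    .method(KernelMethod::Gaussian(80.0))
    .transform(dataset);
let model = Svm::params().pos_neg_weights(7., 0.6).fit(&kernel_dataset);
// Predictions are `Pr` wrappers around the decision value; the tests in
// this patch treat values > 0.0 as the positive class.
let in_sample = model.predict(kernel_dataset.records().dataset.view());
```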
-impl<'a, F: Float> Fit<'a, Kernel<'a, F>, &Array1> for SvmParams { - type Object = Svm<'a, F, Pr>; +impl<'a, F: Float> Fit<'a, KernelOwned, &Array1> for SvmParams { + type Object = Svm; - fn fit(&self, dataset: &'a DatasetBase, &Array1>) -> Self::Object { + fn fit(&self, dataset: &DatasetBase, &Array1>) -> Self::Object { match (self.c, self.nu) { (Some((c_p, c_n)), _) => fit_c( self.solver_params.clone(), - &dataset.records, + dataset.records.clone(), dataset.targets().as_slice(), c_p, c_n, ), (None, Some((nu, _))) => fit_nu( self.solver_params.clone(), - &dataset.records, + dataset.records.clone(), dataset.targets().as_slice(), nu, ), @@ -198,21 +199,67 @@ impl<'a, F: Float> Fit<'a, Kernel<'a, F>, &Array1> for SvmParams { } } +impl<'a, F: Float> Fit<'a, KernelOwned, Array1> for SvmParams { + type Object = Svm; + + fn fit(&self, dataset: &DatasetBase, Array1>) -> Self::Object { + match (self.c, self.nu) { + (Some((c_p, c_n)), _) => fit_c( + self.solver_params.clone(), + dataset.records.clone(), + dataset.targets().as_slice().unwrap(), + c_p, + c_n, + ), + (None, Some((nu, _))) => fit_nu( + self.solver_params.clone(), + dataset.records.clone(), + dataset.targets().as_slice().unwrap(), + nu, + ), + _ => panic!("Set either C value or Nu value"), + } + } +} + +impl<'a, F: Float> Fit<'a, KernelOwned, ArrayView1<'a, bool>> for SvmParams { + type Object = Svm; + + fn fit(&self, dataset: &DatasetBase, ArrayView1<'a, bool>>) -> Self::Object { + match (self.c, self.nu) { + (Some((c_p, c_n)), _) => fit_c( + self.solver_params.clone(), + dataset.records.clone(), + dataset.targets().as_slice().unwrap(), + c_p, + c_n, + ), + (None, Some((nu, _))) => fit_nu( + self.solver_params.clone(), + dataset.records.clone(), + dataset.targets().as_slice().unwrap(), + nu, + ), + _ => panic!("Set either C value or Nu value"), + } + } +} + impl<'a, F: Float> Fit<'a, Kernel<'a, F>, ArrayView1<'a, bool>> for SvmParams { - type Object = Svm<'a, F, Pr>; + type Object = Svm; - fn fit(&self, dataset: &'a DatasetBase, ArrayView1<'a, bool>>) -> Self::Object { + fn fit(&self, dataset: &DatasetBase, ArrayView1<'a, bool>>) -> Self::Object { match (self.c, self.nu) { (Some((c_p, c_n)), _) => fit_c( self.solver_params.clone(), - &dataset.records, + dataset.records.to_owned(), dataset.targets().as_slice().unwrap(), c_p, c_n, ), (None, Some((nu, _))) => fit_nu( self.solver_params.clone(), - &dataset.records, + dataset.records.to_owned(), dataset.targets().as_slice().unwrap(), nu, ), @@ -222,20 +269,20 @@ impl<'a, F: Float> Fit<'a, Kernel<'a, F>, ArrayView1<'a, bool>> for SvmParams Fit<'a, Kernel<'a, F>, &[bool]> for SvmParams { - type Object = Svm<'a, F, Pr>; + type Object = Svm; - fn fit(&self, dataset: &'a DatasetBase, &[bool]>) -> Self::Object { + fn fit(&self, dataset: &DatasetBase, &[bool]>) -> Self::Object { match (self.c, self.nu) { (Some((c_p, c_n)), _) => fit_c( self.solver_params.clone(), - &dataset.records, + dataset.records.to_owned(), dataset.targets(), c_p, c_n, ), (None, Some((nu, _))) => fit_nu( self.solver_params.clone(), - &dataset.records, + dataset.records.to_owned(), dataset.targets(), nu, ), @@ -249,18 +296,59 @@ impl<'a, F: Float> Fit<'a, Kernel<'a, F>, &[bool]> for SvmParams { /// This fits a SVM model to a dataset with only positive samples and uses the one-class /// implementation of SVM. 
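A corresponding sketch for the one-class path (the record matrix is a stand-in; the flow follows `test_reject_classification` further below):

```rust
use linfa::dataset::DatasetBase;
use linfa::traits::{Fit, Transformer};
use linfa_kernel::{KernelMethod, KernelView};
use linfa_svm::Svm;
use ndarray::Array2;

// Unlabeled "normal" samples only; the targets are the unit type.
let records: Array2<f64> = Array2::from_elem((100, 2), 0.5);
let dataset = DatasetBase::new(records.view(), ());
let kernel_dataset = KernelView::params()
    .method(KernelMethod::Gaussian(100.0))
    .transform(&dataset);
// `nu` bounds the fraction of training samples treated as outliers.
let model = Svm::params().nu_weight(0.1).fit(&kernel_dataset);
```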
impl<'a, F: Float> Fit<'a, Kernel<'a, F>, &()> for SvmParams { - type Object = Svm<'a, F, Pr>; + type Object = Svm; - fn fit(&self, dataset: &'a DatasetBase, &()>) -> Self::Object { + fn fit(&self, dataset: &DatasetBase, &()>) -> Self::Object { match self.nu { - Some((nu, _)) => fit_one_class(self.solver_params.clone(), &dataset.records, nu), + Some((nu, _)) => { + fit_one_class(self.solver_params.clone(), dataset.records.to_owned(), nu) + } + None => panic!("One class needs Nu value"), + } + } +} + +impl<'a, F: Float> Fit<'a, Kernel<'a, F>, &[()]> for SvmParams { + type Object = Svm; + + fn fit(&self, dataset: &DatasetBase, &[()]>) -> Self::Object { + match self.nu { + Some((nu, _)) => { + fit_one_class(self.solver_params.clone(), dataset.records.to_owned(), nu) + } + None => panic!("One class needs Nu value"), + } + } +} + +impl<'a, F: Float> Fit<'a, KernelView<'a, F>, &()> for SvmParams { + type Object = Svm; + + fn fit(&self, dataset: &DatasetBase, &()>) -> Self::Object { + match self.nu { + Some((nu, _)) => { + fit_one_class(self.solver_params.clone(), dataset.records.to_owned(), nu) + } + None => panic!("One class needs Nu value"), + } + } +} + +impl<'a, F: Float> Fit<'a, KernelView<'a, F>, &[()]> for SvmParams { + type Object = Svm; + + fn fit(&self, dataset: &DatasetBase, &[()]>) -> Self::Object { + match self.nu { + Some((nu, _)) => { + fit_one_class(self.solver_params.clone(), dataset.records.to_owned(), nu) + } None => panic!("One class needs Nu value"), } } } /// Predict a probability with a feature vector -impl<'a, F: Float> Predict, Pr> for Svm<'a, F, Pr> { +impl<'a, F: Float> Predict, Pr> for Svm { fn predict(&self, data: Array1) -> Pr { let val = match self.linear_decision { Some(ref x) => x.mul(&data).sum() - self.rho, @@ -272,11 +360,24 @@ impl<'a, F: Float> Predict, Pr> for Svm<'a, F, Pr> { } } +/// Predict a probability with a feature vector +impl<'a, F: Float> Predict, Pr> for Svm { + fn predict(&self, data: ArrayView1<'a, F>) -> Pr { + let val = match self.linear_decision { + Some(ref x) => x.mul(&data).sum() - self.rho, + None => self.kernel.weighted_sum(&self.alpha, data) - self.rho, + }; + + // this is safe because `F` is only implemented for `f32` and `f64` + Pr(val.to_f32().unwrap()) + } +} + /// Classify observations /// /// This function takes a number of features and predicts target probabilities that they belong to /// the positive class. -impl<'a, F: Float, D: Data> Predict, Array1> for Svm<'a, F, Pr> { +impl<'a, F: Float, D: Data> Predict, Array1> for Svm { fn predict(&self, data: ArrayBase) -> Array1 { data.outer_iter() .map(|data| { @@ -292,10 +393,14 @@ impl<'a, F: Float, D: Data> Predict, Array1> for } } -impl<'a, F: Float, T: Targets> - Predict, T>, DatasetBase, Array1>> for Svm<'a, F, Pr> +impl<'a, F: Float, D: Data, T: Targets> + Predict, T>, DatasetBase, Array1>> + for Svm { - fn predict(&self, data: DatasetBase, T>) -> DatasetBase, Array1> { + fn predict( + &self, + data: DatasetBase, T>, + ) -> DatasetBase, Array1> { let DatasetBase { records, .. 
} = data; let predicted = self.predict(records.view()); @@ -305,7 +410,7 @@ impl<'a, F: Float, T: Targets> impl<'a, F: Float, T: Targets, D: Data> Predict<&'a DatasetBase, T>, DatasetBase, Array1>> - for Svm<'a, F, Pr> + for Svm { fn predict( &self, @@ -320,10 +425,10 @@ impl<'a, F: Float, T: Targets, D: Data> #[cfg(test)] mod tests { use super::Svm; - use linfa::dataset::DatasetBase; + use linfa::dataset::{Dataset, DatasetBase}; use linfa::prelude::ToConfusionMatrix; use linfa::traits::{Fit, Predict, Transformer}; - use linfa_kernel::{Kernel, KernelMethod}; + use linfa_kernel::{Kernel, KernelMethod, KernelView}; use ndarray::{Array, Array1, Array2, Axis}; use ndarray_rand::rand::SeedableRng; @@ -352,7 +457,7 @@ mod tests { #[test] fn test_linear_classification() { - let entries = ndarray::stack( + let entries: Array2 = ndarray::stack( Axis(0), &[ Array::random((10, 2), Uniform::new(-1., -0.5)).view(), @@ -361,11 +466,12 @@ mod tests { ) .unwrap(); let targets = (0..20).map(|x| x < 10).collect::>(); - let dataset = DatasetBase::new(entries.clone(), targets); + let dataset = Dataset::new(entries.clone(), targets); + let dataset_view = dataset.view(); let dataset = Kernel::params() .method(KernelMethod::Linear) - .transform(&dataset); + .transform(&dataset_view); // train model with positive and negative weight let model = Svm::params().pos_neg_weights(1.0, 1.0).fit(&dataset); @@ -380,7 +486,7 @@ mod tests { // train model with Nu parameter let model = Svm::params().nu_weight(0.05).fit(&dataset); - let valid = model.predict(valid).map_targets(|x| **x > 0.0); + let valid = model.predict(&valid).map_targets(|x| **x > 0.0); let cm = valid.confusion_matrix(&dataset); assert_eq!(cm.accuracy(), 1.0); @@ -392,11 +498,12 @@ mod tests { // construct parabolica and classify middle area as positive and borders as negative let records = Array::random_using((40, 1), Uniform::new(-2f64, 2.), &mut rng); let targets = records.map_axis(Axis(1), |x| x[0] * x[0] < 0.5); - let dataset = DatasetBase::new(records.clone(), targets); + let dataset = Dataset::new(records.clone(), targets); + let dataset_view = dataset.view(); - let dataset = Kernel::params() + let dataset = KernelView::params() .method(KernelMethod::Polynomial(0.0, 2.0)) - .transform(&dataset); + .transform(&dataset_view); // train model with positive and negative weight let model = Svm::params().pos_neg_weights(1.0, 1.0).fit(&dataset); @@ -415,11 +522,12 @@ mod tests { fn test_convoluted_rings_classification() { let records = generate_convoluted_rings(10); let targets = (0..20).map(|x| x < 10).collect::>(); - let dataset = DatasetBase::new(records.clone(), targets); + let dataset = Dataset::new(records.clone(), targets.clone()); + let dataset_view = dataset.view(); - let dataset = Kernel::params() + let dataset = KernelView::params() .method(KernelMethod::Gaussian(50.0)) - .transform(&dataset); + .transform(&dataset_view); // train model with positive and negative weight let model = Svm::params().pos_neg_weights(1.0, 1.0).fit(&dataset); @@ -434,7 +542,7 @@ mod tests { // train model with Nu parameter let model = Svm::params().nu_weight(0.01).fit(&dataset); - let valid = model.predict(valid).map_targets(|x| **x > 0.0); + let valid = model.predict(&valid).map_targets(|x| **x > 0.0); let cm = valid.confusion_matrix(&dataset); assert!(cm.accuracy() > 0.9); @@ -444,9 +552,9 @@ mod tests { fn test_reject_classification() { // generate two clusters with 100 samples each let entries = Array::random((100, 2), Uniform::new(-4., 4.)); - let dataset = 
DatasetBase::new(entries.clone(), ()); + let dataset = DatasetBase::new(entries.view(), ()); - let dataset = Kernel::params() + let dataset = KernelView::params() .method(KernelMethod::Gaussian(100.0)) .transform(&dataset); diff --git a/linfa-svm/src/lib.rs b/linfa-svm/src/lib.rs index 0a9d7197d..921321f18 100644 --- a/linfa-svm/src/lib.rs +++ b/linfa-svm/src/lib.rs @@ -81,7 +81,7 @@ mod permutable_kernel; mod regression; pub mod solver_smo; -use permutable_kernel::Kernel; +use linfa_kernel::KernelOwned; pub use solver_smo::SolverParams; /// SVM Hyperparameters @@ -191,38 +191,38 @@ pub enum ExitReason { derive(Serialize, Deserialize), serde(crate = "serde_crate") )] -pub struct Svm<'a, A: Float, T> { - pub alpha: Vec, - pub rho: A, - r: Option, +pub struct Svm { + pub alpha: Vec, + pub rho: F, + r: Option, exit_reason: ExitReason, iterations: usize, - obj: A, + obj: F, #[cfg_attr( feature = "serde", serde(bound( - serialize = "&'a Kernel<'a, A>: Serialize", - deserialize = "&'a Kernel<'a, A>: Deserialize<'de>" + serialize = "&'a Kernel<'a, F>: Serialize", + deserialize = "&'a Kernel<'a, F>: Deserialize<'de>" )) )] - kernel: &'a Kernel<'a, A>, - linear_decision: Option>, + kernel: KernelOwned, + linear_decision: Option>, phantom: PhantomData, } -impl<'a, A: Float, T> Svm<'a, A, T> { - /// Create hyper parameter set - /// - /// This creates a `SvmParams` and sets it to the default values: - /// * C values of (1, 1) - /// * Eps of 1e-7 - /// * No shrinking - pub fn params() -> SvmParams { +/// Create hyper parameter set +/// +/// This creates a `SvmParams` and sets it to the default values: +/// * C values of (1, 1) +/// * Eps of 1e-7 +/// * No shrinking +impl Svm { + pub fn params() -> SvmParams { SvmParams { - c: Some((A::one(), A::one())), + c: Some((F::one(), F::one())), nu: None, solver_params: SolverParams { - eps: A::from(1e-7).unwrap(), + eps: F::from(1e-7).unwrap(), shrinking: false, }, phantom: PhantomData, @@ -236,10 +236,10 @@ impl<'a, A: Float, T> Svm<'a, A, T> { pub fn nsupport(&self) -> usize { self.alpha .iter() - .filter(|x| x.abs() > A::from(1e-5).unwrap()) + .filter(|x| x.abs() > F::from(1e-5).unwrap()) .count() } - pub(crate) fn with_phantom(self) -> Svm<'a, A, S> { + pub(crate) fn with_phantom(self) -> Svm { Svm { alpha: self.alpha, rho: self.rho, @@ -258,7 +258,7 @@ impl<'a, A: Float, T> Svm<'a, A, T> { /// /// In order to understand the solution of the SMO solver the objective, number of iterations and /// required support vectors are printed here. 
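For instance (a hypothetical continuation of the fitting sketches above):

```rust
// The Display impl summarizes the solver outcome: exit reason, objective,
// iteration count and the number of support vectors.
let model = Svm::params().nu_weight(0.05).fit(&kernel_dataset);
println!("{}", model);
```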
-impl<'a, A: Float, T> fmt::Display for Svm<'a, A, T> { +impl<'a, F: Float, T> fmt::Display for Svm { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self.exit_reason { ExitReason::ReachedThreshold => write!( @@ -278,3 +278,52 @@ impl<'a, A: Float, T> fmt::Display for Svm<'a, A, T> { } } } + +#[cfg(test)] +mod tests { + use crate::Svm; + use linfa::dataset::Dataset; + use linfa::prelude::*; + use linfa_kernel::{Kernel, KernelMethod}; + use ndarray::Array1; + #[test] + fn test_iter_folding_for_classification() { + let mut dataset = linfa_datasets::winequality().map_targets(|x| *x > 6); + let params = Svm::params().pos_neg_weights(7., 0.6); + + let avg_acc = dataset + .iter_fold(4, |training_set| { + let train_kernel = Kernel::params() + .method(KernelMethod::Gaussian(80.0)) + .transform(&training_set); + params.fit(&train_kernel) + }) + .map(|(model, valid)| { + model + .predict(&valid) + .map_targets(|x| **x > 0.0) + .confusion_matrix(&valid) + .accuracy() + }) + .sum::() + / 4_f32; + assert!(avg_acc >= 0.5) + } + + #[test] + fn test_iter_folding_for_regression() { + let mut dataset: Dataset = linfa_datasets::diabetes(); + let params = Svm::params().c_eps(10., 0.01); + + let _avg_acc = dataset + .iter_fold(4, |training_set| { + let train_kernel = Kernel::params() + .method(KernelMethod::Linear) + .transform(&training_set); + params.fit(&train_kernel) + }) + .map(|(model, valid)| Array1::from(model.predict(valid.records())).r2(valid.targets())) + .sum::() + / 4_f64; + } +} diff --git a/linfa-svm/src/permutable_kernel.rs b/linfa-svm/src/permutable_kernel.rs index 22ae04ce3..67c7f0215 100644 --- a/linfa-svm/src/permutable_kernel.rs +++ b/linfa-svm/src/permutable_kernel.rs @@ -1,29 +1,28 @@ use crate::Float; -use linfa_kernel::Kernel as LinfaKernel; -use ndarray::{Array1, ArrayView2}; +use linfa_kernel::KernelOwned; +use ndarray::Array1; -pub type Kernel<'a, A> = LinfaKernel>; - -pub trait Permutable<'a, A: Float> { +pub trait Permutable { fn swap_indices(&mut self, i: usize, j: usize); - fn distances(&self, idx: usize, length: usize) -> Vec; - fn self_distance(&self, idx: usize) -> A; - fn inner(&self) -> &'a Kernel<'a, A>; + fn distances(&self, idx: usize, length: usize) -> Vec; + fn self_distance(&self, idx: usize) -> F; + fn inner(&self) -> &KernelOwned; + fn to_inner(self) -> KernelOwned; } -/// Kernel matrix with permutable columns +/// KernelView matrix with permutable columns /// /// This struct wraps a kernel matrix with access indices. The working set can shrink during the /// optimization and it is therefore necessary to reorder entries. 
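A toy sketch of that reordering idea, under the same index-table scheme the wrapper below uses:

```rust
// Swap entries of an index table instead of physically permuting kernel
// columns; every column access is then resolved through the table.
let mut kernel_indices: Vec<usize> = (0..4).collect(); // [0, 1, 2, 3]
kernel_indices.swap(0, 2); // a shrinking/reordering step
// Asking for "column 0" now fetches the data of original column 2.
assert_eq!(kernel_indices[0], 2);
```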
-pub struct PermutableKernel<'a, A: Float> { - kernel: &'a Kernel<'a, A>, - kernel_diag: Array1, +pub struct PermutableKernel { + kernel: KernelOwned, + kernel_diag: Array1, kernel_indices: Vec, targets: Vec, } -impl<'a, A: Float> PermutableKernel<'a, A> { - pub fn new(kernel: &'a Kernel<'a, A>, targets: Vec) -> PermutableKernel<'a, A> { +impl PermutableKernel { + pub fn new(kernel: KernelOwned, targets: Vec) -> PermutableKernel { let kernel_diag = kernel.diagonal(); let kernel_indices = (0..kernel.size()).collect::>(); @@ -36,14 +35,14 @@ impl<'a, A: Float> PermutableKernel<'a, A> { } } -impl<'a, A: Float> Permutable<'a, A> for PermutableKernel<'a, A> { +impl Permutable for PermutableKernel { /// Swap two indices fn swap_indices(&mut self, i: usize, j: usize) { self.kernel_indices.swap(i, j); } /// Return distances from node `idx` to all other nodes - fn distances(&self, idx: usize, length: usize) -> Vec { + fn distances(&self, idx: usize, length: usize) -> Vec { let idx = self.kernel_indices[idx]; let kernel = self.kernel.column(idx); @@ -65,26 +64,31 @@ impl<'a, A: Float> Permutable<'a, A> for PermutableKernel<'a, A> { } /// Return internal kernel - fn inner(&self) -> &'a Kernel<'a, A> { + fn inner(&self) -> &KernelOwned { + &self.kernel + } + + /// Return internal kernel + fn to_inner(self) -> KernelOwned { self.kernel } /// Return distance to itself - fn self_distance(&self, idx: usize) -> A { + fn self_distance(&self, idx: usize) -> F { let idx = self.kernel_indices[idx]; self.kernel_diag[idx] } } -pub struct PermutableKernelOneClass<'a, A: Float> { - kernel: &'a Kernel<'a, A>, - kernel_diag: Array1, +pub struct PermutableKernelOneClass { + kernel: KernelOwned, + kernel_diag: Array1, kernel_indices: Vec, } -impl<'a, A: Float> PermutableKernelOneClass<'a, A> { - pub fn new(kernel: &'a Kernel<'a, A>) -> PermutableKernelOneClass<'a, A> { +impl PermutableKernelOneClass { + pub fn new(kernel: KernelOwned) -> PermutableKernelOneClass { let kernel_diag = kernel.diagonal(); let kernel_indices = (0..kernel.size()).collect::>(); @@ -96,14 +100,14 @@ impl<'a, A: Float> PermutableKernelOneClass<'a, A> { } } -impl<'a, A: Float> Permutable<'a, A> for PermutableKernelOneClass<'a, A> { +impl Permutable for PermutableKernelOneClass { /// Swap two indices fn swap_indices(&mut self, i: usize, j: usize) { self.kernel_indices.swap(i, j); } /// Return distances from node `idx` to all other nodes - fn distances(&self, idx: usize, length: usize) -> Vec { + fn distances(&self, idx: usize, length: usize) -> Vec { let idx = self.kernel_indices[idx]; let kernel = self.kernel.column(idx); @@ -115,27 +119,32 @@ impl<'a, A: Float> Permutable<'a, A> for PermutableKernelOneClass<'a, A> { } /// Return internal kernel - fn inner(&self) -> &'a Kernel<'a, A> { + fn inner(&self) -> &KernelOwned { + &self.kernel + } + + /// Return internal kernel + fn to_inner(self) -> KernelOwned { self.kernel } /// Return distance to itself - fn self_distance(&self, idx: usize) -> A { + fn self_distance(&self, idx: usize) -> F { let idx = self.kernel_indices[idx]; self.kernel_diag[idx] } } -pub struct PermutableKernelRegression<'a, A: Float> { - kernel: &'a Kernel<'a, A>, - kernel_diag: Array1, +pub struct PermutableKernelRegression { + kernel: KernelOwned, + kernel_diag: Array1, kernel_indices: Vec, signs: Vec, } -impl<'a, A: Float> PermutableKernelRegression<'a, A> { - pub fn new(kernel: &'a Kernel<'a, A>) -> PermutableKernelRegression<'a, A> { +impl<'a, F: Float> PermutableKernelRegression { + pub fn new(kernel: KernelOwned) -> 
PermutableKernelRegression { let kernel_diag = kernel.diagonal(); let kernel_indices = (0..2 * kernel.size()) .map(|x| { @@ -159,7 +168,7 @@ impl<'a, A: Float> PermutableKernelRegression<'a, A> { } } -impl<'a, A: Float> Permutable<'a, A> for PermutableKernelRegression<'a, A> { +impl<'a, F: Float> Permutable for PermutableKernelRegression { /// Swap two indices fn swap_indices(&mut self, i: usize, j: usize) { self.kernel_indices.swap(i, j); @@ -167,7 +176,7 @@ impl<'a, A: Float> Permutable<'a, A> for PermutableKernelRegression<'a, A> { } /// Return distances from node `idx` to all other nodes - fn distances(&self, idx: usize, length: usize) -> Vec { + fn distances(&self, idx: usize, length: usize) -> Vec { let kernel = self.kernel.column(self.kernel_indices[idx]); // reorder entries @@ -187,12 +196,17 @@ impl<'a, A: Float> Permutable<'a, A> for PermutableKernelRegression<'a, A> { } /// Return internal kernel - fn inner(&self) -> &'a Kernel<'a, A> { + fn inner(&self) -> &KernelOwned { + &self.kernel + } + + /// Return internal kernel + fn to_inner(self) -> KernelOwned { self.kernel } /// Return distance to itself - fn self_distance(&self, idx: usize) -> A { + fn self_distance(&self, idx: usize) -> F { let idx = self.kernel_indices[idx]; self.kernel_diag[idx] @@ -202,20 +216,20 @@ impl<'a, A: Float> Permutable<'a, A> for PermutableKernelRegression<'a, A> { #[cfg(test)] mod tests { use super::{Permutable, PermutableKernel}; - use linfa_kernel::{Kernel, KernelInner, KernelMethod}; + use linfa_kernel::{KernelInner, KernelMethod, KernelOwned}; use ndarray::array; #[test] fn test_permutable_kernel() { let dist = array![[1.0, 0.3, 0.1], [0.3, 1.0, 0.5], [0.1, 0.5, 1.0]]; let targets = vec![true, true, true]; - let dist = Kernel { + let dist = KernelOwned { inner: KernelInner::Dense(dist.clone()), method: KernelMethod::Linear, - dataset: dist.view(), + dataset: dist, }; - let mut kernel = PermutableKernel::new(&dist, targets); + let mut kernel = PermutableKernel::new(dist, targets); assert_eq!(kernel.distances(0, 3), &[1.0, 0.3, 0.1]); assert_eq!(kernel.distances(1, 3), &[0.3, 1.0, 0.5]); diff --git a/linfa-svm/src/regression.rs b/linfa-svm/src/regression.rs index d136b5a1c..63ada73b2 100644 --- a/linfa-svm/src/regression.rs +++ b/linfa-svm/src/regression.rs @@ -1,9 +1,10 @@ //! 
Support Vector Regression
 use linfa::{dataset::DatasetBase, traits::Fit, traits::Predict};
+use linfa_kernel::{Kernel, KernelOwned, KernelView};
 use ndarray::{Array1, ArrayBase, ArrayView1, Data, Ix2};
 use std::ops::Mul;
 
-use super::permutable_kernel::{Kernel, PermutableKernelRegression};
+use super::permutable_kernel::PermutableKernelRegression;
 use super::solver_smo::SolverState;
 use super::SolverParams;
 use super::{Float, Svm, SvmParams};
@@ -19,14 +20,14 @@
 /// * `targets` - the continuous targets `y_i`
 /// * `c` - C value for all targets
 /// * `p` - epsilon value for all targets
-pub fn fit_epsilon<'a, A: Float>(
-    params: SolverParams<A>,
-    kernel: &'a Kernel<'a, A>,
-    target: &'a [A],
-    c: A,
-    p: A,
-) -> Svm<'a, A, A> {
-    let mut linear_term = vec![A::zero(); 2 * target.len()];
+pub fn fit_epsilon<F: Float>(
+    params: SolverParams<F>,
+    kernel: KernelOwned<F>,
+    target: &[F],
+    c: F,
+    p: F,
+) -> Svm<F, F> {
+    let mut linear_term = vec![F::zero(); 2 * target.len()];
     let mut targets = vec![true; 2 * target.len()];
 
     for i in 0..target.len() {
@@ -39,7 +40,7 @@ pub fn fit_epsilon<'a, A: Float>(
     let kernel = PermutableKernelRegression::new(kernel);
 
     let solver = SolverState::new(
-        vec![A::zero(); 2 * target.len()],
+        vec![F::zero(); 2 * target.len()],
         linear_term,
         targets.to_vec(),
         kernel,
@@ -70,21 +71,21 @@
 /// * `targets` - the continuous targets `y_i`
 /// * `c` - C value for all targets
 /// * `nu` - nu value for all targets
-pub fn fit_nu<'a, A: Float>(
-    params: SolverParams<A>,
-    kernel: &'a Kernel<'a, A>,
-    target: &'a [A],
-    c: A,
-    nu: A,
-) -> Svm<'a, A, A> {
-    let mut alpha = vec![A::zero(); 2 * target.len()];
-    let mut linear_term = vec![A::zero(); 2 * target.len()];
+pub fn fit_nu<F: Float>(
+    params: SolverParams<F>,
+    kernel: KernelOwned<F>,
+    target: &[F],
+    c: F,
+    nu: F,
+) -> Svm<F, F> {
+    let mut alpha = vec![F::zero(); 2 * target.len()];
+    let mut linear_term = vec![F::zero(); 2 * target.len()];
     let mut targets = vec![true; 2 * target.len()];
 
-    let mut sum = c * nu * A::from(target.len()).unwrap() / A::from(2.0).unwrap();
+    let mut sum = c * nu * F::from(target.len()).unwrap() / F::from(2.0).unwrap();
     for i in 0..target.len() {
-        alpha[i] = A::min(sum, c);
-        alpha[i + target.len()] = A::min(sum, c);
+        alpha[i] = F::min(sum, c);
+        alpha[i + target.len()] = F::min(sum, c);
         sum -= alpha[i];
 
         linear_term[i] = -target[i];
@@ -119,21 +120,21 @@ pub fn fit_nu<'a, A: Float>(
 /// Regress observations
 ///
 /// Take a number of observations and project them to optimal continuous targets.
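A usage sketch for this regression path (it mirrors the `iter_fold` regression test added in `lib.rs`; the `diabetes` helper is the one enabled in `Cargo.toml` above):

```rust
use linfa::traits::{Fit, Predict, Transformer};
use linfa_kernel::{Kernel, KernelMethod};
use linfa_svm::Svm;

// The borrowing transform keeps the dataset intact and yields slice
// targets, which the slice-based `Fit` impl below accepts.
let dataset = linfa_datasets::diabetes();
let kernel_dataset = Kernel::params()
    .method(KernelMethod::Linear)
    .transform(&dataset);
let model = Svm::params().c_eps(10., 0.01).fit(&kernel_dataset);
let predictions: Vec<f64> = model.predict(dataset.records());
```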
@@ -119,21 +120,21 @@
 /// Regress observations
 ///
 /// Take a number of observations and project them to optimal continuous targets.
-impl<'a, F: Float> Fit<'a, Kernel<'a, F>, &Array1<F>> for SvmParams<F, F> {
-    type Object = Svm<'a, F, F>;
+impl<'a, F: Float> Fit<'a, KernelOwned<F>, &Array1<F>> for SvmParams<F, F> {
+    type Object = Svm<F, F>;

-    fn fit(&self, dataset: &'a DatasetBase<Kernel<'a, F>, &Array1<F>>) -> Self::Object {
+    fn fit(&self, dataset: &DatasetBase<KernelOwned<F>, &Array1<F>>) -> Self::Object {
         match (self.c, self.nu) {
             (Some((c, eps)), _) => fit_epsilon(
                 self.solver_params.clone(),
-                &dataset.records,
+                dataset.records.clone(),
                 dataset.targets().as_slice().unwrap(),
                 c,
                 eps,
             ),
             (None, Some((nu, eps))) => fit_nu(
                 self.solver_params.clone(),
-                &dataset.records,
+                dataset.records.clone(),
                 dataset.targets().as_slice().unwrap(),
                 nu,
                 eps,
@@ -144,20 +145,20 @@
 }

 impl<'a, F: Float> Fit<'a, Kernel<'a, F>, ArrayView1<'a, F>> for SvmParams<F, F> {
-    type Object = Svm<'a, F, F>;
+    type Object = Svm<F, F>;

-    fn fit(&self, dataset: &'a DatasetBase<Kernel<'a, F>, ArrayView1<'a, F>>) -> Self::Object {
+    fn fit(&self, dataset: &DatasetBase<Kernel<'a, F>, ArrayView1<'a, F>>) -> Self::Object {
         match (self.c, self.nu) {
             (Some((c, eps)), _) => fit_epsilon(
                 self.solver_params.clone(),
-                &dataset.records,
+                dataset.records.to_owned(),
                 dataset.targets().as_slice().unwrap(),
                 c,
                 eps,
             ),
             (None, Some((nu, eps))) => fit_nu(
                 self.solver_params.clone(),
-                &dataset.records,
+                dataset.records.to_owned(),
                 dataset.targets().as_slice().unwrap(),
                 nu,
                 eps,
@@ -167,7 +168,56 @@
     }
 }

-impl<'a, D: Data<Elem = f64>> Predict<ArrayBase<D, Ix2>, Vec<f64>> for Svm<'a, f64, f64> {
+impl<'a, F: Float> Fit<'a, Kernel<'a, F>, &'a [F]> for SvmParams<F, F> {
+    type Object = Svm<F, F>;
+
+    fn fit(&self, dataset: &DatasetBase<Kernel<'a, F>, &'a [F]>) -> Self::Object {
+        match (self.c, self.nu) {
+            (Some((c, eps)), _) => fit_epsilon(
+                self.solver_params.clone(),
+                dataset.records.to_owned(),
+                dataset.targets(),
+                c,
+                eps,
+            ),
+            (None, Some((nu, eps))) => fit_nu(
+                self.solver_params.clone(),
+                dataset.records.to_owned(),
+                dataset.targets(),
+                nu,
+                eps,
+            ),
+            _ => panic!("Set either C value or Nu value"),
+        }
+    }
+}
+
+impl<'a, F: Float> Fit<'a, KernelView<'a, F>, ArrayView1<'a, F>> for SvmParams<F, F> {
+    type Object = Svm<F, F>;
+
+    fn fit(&self, dataset: &DatasetBase<KernelView<'a, F>, ArrayView1<'a, F>>) -> Self::Object {
+        match (self.c, self.nu) {
+            (Some((c, eps)), _) => fit_epsilon(
+                self.solver_params.clone(),
+                dataset.records.to_owned(),
+                dataset.targets().as_slice().unwrap(),
+                c,
+                eps,
+            ),
+            (None, Some((nu, eps))) => fit_nu(
+                self.solver_params.clone(),
+                dataset.records.to_owned(),
+                dataset.targets().as_slice().unwrap(),
+                nu,
+                eps,
+            ),
+            _ => panic!("Set either C value or Nu value"),
+        }
+    }
+}
+
+/// Predict values for a set of observations
+impl<F: Float, D: Data<Elem = F>> Predict<ArrayBase<D, Ix2>, Vec<F>> for Svm<F, F> {
     fn predict(&self, data: ArrayBase<D, Ix2>) -> Vec<F> {
         data.outer_iter()
             .map(|data| {
                 let val = match self.linear_decision {
                     Some(ref x) => x.mul(&data).sum() - self.rho,
                     None => self.kernel.weighted_sum(&self.alpha, data.view()) - self.rho,
                 };
@@ -182,6 +232,24 @@
             .collect()
     }
 }
+
+/// Predict values for a set of observations
+impl<F: Float, D: Data<Elem = F>> Predict<&ArrayBase<D, Ix2>, Vec<F>> for Svm<F, F> {
+    fn predict(&self, data: &ArrayBase<D, Ix2>) -> Vec<F> {
+        data.outer_iter()
+            .map(|data| {
+                let val = match self.linear_decision {
+                    Some(ref x) => x.mul(&data).sum() - self.rho,
+                    None => self.kernel.weighted_sum(&self.alpha, data.view()) - self.rho,
+                };
+
+                // this is safe because `F` is only implemented for `f32` and `f64`
+                val
+            })
+            .collect()
+    }
+}
+
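Both `Predict` implementations share a single decision function, f(x) = sum_i alpha_i K(x_i, x) - rho, which collapses to one dot product when a linear kernel allowed the `linear_decision` weights to be precomputed. A standalone sketch with illustrative names (not the crate's API):

```rust
// Sketch of the SVM decision function with an optional linear shortcut.
fn decision(
    alpha: &[f64],
    support: &[Vec<f64>],     // support vectors
    linear_w: Option<&[f64]>, // precomputed weights, linear kernels only
    rho: f64,
    x: &[f64],
    kernel: impl Fn(&[f64], &[f64]) -> f64,
) -> f64 {
    match linear_w {
        // linear case: w was accumulated once, so predict with w . x - rho
        Some(w) => w.iter().zip(x).map(|(wi, xi)| wi * xi).sum::<f64>() - rho,
        // general case: weighted sum over all support vectors
        None => alpha
            .iter()
            .zip(support)
            .map(|(a, sv)| a * kernel(sv, x))
            .sum::<f64>()
            - rho,
    }
}

fn main() {
    let dot = |a: &[f64], b: &[f64]| a.iter().zip(b).map(|(x, y)| x * y).sum::<f64>();
    let f = decision(
        &[0.5, -0.25],
        &[vec![1.0, 0.0], vec![0.0, 1.0]],
        None,
        0.1,
        &[2.0, 4.0],
        dot,
    );
    // 0.5 * 2.0 - 0.25 * 4.0 - 0.1
    assert!((f + 0.1).abs() < 1e-12);
}
```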
 #[cfg(test)]
 pub mod tests {
     use super::Svm;
@@ -202,7 +270,7 @@
         let kernel = Kernel::params()
             .method(KernelMethod::Gaussian(50.))
-            .transform(&sin_curve);
+            .transform(sin_curve.view());

         let dataset = DatasetBase::new(kernel, target.view());

@@ -224,7 +292,7 @@
         let kernel = Kernel::params()
             .method(KernelMethod::Gaussian(50.))
-            .transform(&sin_curve);
+            .transform(sin_curve.view());

         let dataset = DatasetBase::new(kernel, target.view());

diff --git a/linfa-svm/src/solver_smo.rs b/linfa-svm/src/solver_smo.rs
index c29611257..425798d61 100644
--- a/linfa-svm/src/solver_smo.rs
+++ b/linfa-svm/src/solver_smo.rs
@@ -6,22 +6,22 @@ use std::marker::PhantomData;

 /// Parameters of the solver routine
 #[derive(Clone)]
-pub struct SolverParams<A: Float> {
+pub struct SolverParams<F: Float> {
     /// Stopping condition
-    pub eps: A,
+    pub eps: F,
     /// Should we shrink, e.g. ignore bounded alphas
     pub shrinking: bool,
 }

 /// Status of alpha variables of the solver
 #[derive(Debug)]
-struct Alpha<A: Float> {
-    value: A,
-    upper_bound: A,
+struct Alpha<F: Float> {
+    value: F,
+    upper_bound: F,
 }

-impl<A: Float> Alpha<A> {
-    pub fn from(value: A, upper_bound: A) -> Alpha<A> {
+impl<F: Float> Alpha<F> {
+    pub fn from(value: F, upper_bound: F) -> Alpha<F> {
         Alpha { value, upper_bound }
     }

@@ -30,14 +30,14 @@ impl<A: Float> Alpha<A> {
     }

     pub fn free_floating(&self) -> bool {
-        self.value < self.upper_bound && self.value > A::zero()
+        self.value < self.upper_bound && self.value > F::zero()
     }

     pub fn reached_lower(&self) -> bool {
-        self.value == A::zero()
+        self.value == F::zero()
     }

-    pub fn val(&self) -> A {
+    pub fn val(&self) -> F {
         self.value
     }
 }
@@ -47,50 +47,50 @@ impl<A: Float> Alpha<A> {

 /// We are solving the dual problem with linear constraints
 /// min_a f(a), s.t. y^Ta = d, 0 <= a_t < C, t = 1, ..., l
 /// where f(a) = a^T Q a / 2 + p^T a
-pub struct SolverState<'a, A: Float, K: Permutable<'a, A>> {
+pub struct SolverState<'a, F: Float, K: Permutable<F>> {
     /// Gradient of each variable
-    gradient: Vec<A>,
+    gradient: Vec<F>,
     /// Cached gradient because most of the variables are constant
-    gradient_fixed: Vec<A>,
+    gradient_fixed: Vec<F>,
     /// Current value of each variable and in respect to bounds
-    alpha: Vec<Alpha<A>>,
+    alpha: Vec<Alpha<F>>,
     /// Active set of variables
     active_set: Vec<usize>,
     /// Number of active variables
     nactive: usize,
     unshrink: bool,
     nu_constraint: bool,
-    r: A,
+    r: F,
     /// Quadratic term of the problem
     kernel: K,
     /// Linear term of the problem
-    p: Vec<A>,
+    p: Vec<F>,
     /// Targets we want to predict
     targets: Vec<bool>,
     /// Bounds per alpha
-    bounds: Vec<A>,
+    bounds: Vec<F>,
     /// Parameters, e.g. stopping condition etc.
-    params: SolverParams<A>,
+    params: SolverParams<F>,
     phantom: PhantomData<&'a K>,
 }
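For f(a) = a^T Q a / 2 + p^T a the gradient is grad f(a) = Q a + p, so the constructor in the `impl` block that follows starts from g = p and only accumulates kernel rows for variables whose alpha is non-zero. A simplified sketch of that initialisation (plain `f64`, dense Q, ignoring the `gradient_fixed` cache):

```rust
// Sketch: initialize the gradient g = p + Q a, touching only non-zero alphas.
fn init_gradient(p: &[f64], alpha: &[f64], q: &[Vec<f64>]) -> Vec<f64> {
    let mut g = p.to_vec();
    for (i, &a_i) in alpha.iter().enumerate() {
        if a_i > 0.0 {
            // Q is symmetric, so row i doubles as column i
            for (g_j, q_ij) in g.iter_mut().zip(&q[i]) {
                *g_j += a_i * q_ij;
            }
        }
    }
    g
}

fn main() {
    let p = vec![1.0, 2.0];
    let q = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
    let g = init_gradient(&p, &[0.0, 1.0], &q);
    assert_eq!(g, vec![1.5, 3.0]); // only alpha_1 contributes its row
}
```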
 #[allow(clippy::needless_range_loop)]
-impl<'a, A: Float, K: 'a + Permutable<'a, A>> SolverState<'a, A, K> {
+impl<'a, F: Float, K: 'a + Permutable<F>> SolverState<'a, F, K> {
     /// Initialize a solver state
     ///
     /// This is bounded by the lifetime of the kernel matrix, because it can be quite large
     pub fn new(
-        alpha: Vec<A>,
-        p: Vec<A>,
+        alpha: Vec<F>,
+        p: Vec<F>,
         targets: Vec<bool>,
         kernel: K,
-        bounds: Vec<A>,
-        params: SolverParams<A>,
+        bounds: Vec<F>,
+        params: SolverParams<F>,
         nu_constraint: bool,
-    ) -> SolverState<'a, A, K> {
+    ) -> SolverState<'a, F, K> {
         // initialize alpha status according to bound
         let alpha = alpha
             .into_iter()
@@ -103,10 +103,10 @@ impl<'a, A: Float, K: 'a + Permutable<'a, A>> SolverState<'a, A, K> {

         // initialize gradient
         let mut gradient = p.clone();
-        let mut gradient_fixed = vec![A::zero(); alpha.len()];
+        let mut gradient_fixed = vec![F::zero(); alpha.len()];

         for i in 0..alpha.len() {
-            // when we have reached alpha = A::zero(), then d(a) = p
+            // when we have reached alpha = F::zero(), then d(a) = p
             if !alpha[i].reached_lower() {
                 let dist_i = kernel.distances(i, alpha.len());
                 let alpha_i = alpha[i].val();
@@ -138,7 +138,7 @@ impl<'a, A: Float, K: 'a + Permutable<'a, A>> SolverState<'a, A, K> {
             bounds,
             params,
             nu_constraint,
-            r: A::zero(),
+            r: F::zero(),
             phantom: PhantomData,
         }
     }
@@ -154,16 +154,16 @@ impl<'a, A: Float, K: 'a + Permutable<'a, A>> SolverState<'a, A, K> {
     }

     /// Return target as positive/negative indicator
-    pub fn target(&self, idx: usize) -> A {
+    pub fn target(&self, idx: usize) -> F {
         if self.targets[idx] {
-            A::one()
+            F::one()
         } else {
-            -A::one()
+            -F::one()
         }
     }

     /// Return the k-th bound
-    pub fn bound(&self, idx: usize) -> A {
+    pub fn bound(&self, idx: usize) -> F {
         self.bounds[idx]
     }

@@ -234,9 +234,9 @@ impl<'a, A: Float, K: 'a + Permutable<'a, A>> SolverState<'a, A, K> {
         if self.targets[i] != self.targets[j] {
             let mut quad_coef = self.kernel.self_distance(i)
                 + self.kernel.self_distance(j)
-                + (A::one() + A::one()) * dist_i[j];
-            if quad_coef <= A::zero() {
-                quad_coef = A::from(1e-10).unwrap();
+                + (F::one() + F::one()) * dist_i[j];
+            if quad_coef <= F::zero() {
+                quad_coef = F::from(1e-10).unwrap();
             }

             let delta = -(self.gradient[i] + self.gradient[j]) / quad_coef;
@@ -247,13 +247,13 @@ impl<'a, A: Float, K: 'a + Permutable<'a, A>> SolverState<'a, A, K> {
             self.alpha[j].value += delta;

             // bound to feasible solution
-            if diff > A::zero() {
-                if self.alpha[j].val() < A::zero() {
-                    self.alpha[j].value = A::zero();
+            if diff > F::zero() {
+                if self.alpha[j].val() < F::zero() {
+                    self.alpha[j].value = F::zero();
                     self.alpha[i].value = diff;
                 }
-            } else if self.alpha[i].val() < A::zero() {
-                self.alpha[i].value = A::zero();
+            } else if self.alpha[i].val() < F::zero() {
+                self.alpha[i].value = F::zero();
                 self.alpha[j].value = -diff;
             }

@@ -267,11 +267,11 @@ impl<'a, A: Float, K: 'a + Permutable<'a, A>> SolverState<'a, A, K> {
                 self.alpha[i].value = bound_j + diff;
             }
         } else {
-            //dbg!(self.kernel.self_distance(i), self.kernel.self_distance(j), A::from(2.0).unwrap() * dist_i[j]);
+            //dbg!(self.kernel.self_distance(i), self.kernel.self_distance(j), F::from(2.0).unwrap() * dist_i[j]);
             let mut quad_coef = self.kernel.self_distance(i)
                 + self.kernel.self_distance(j)
-                - A::from(2.0).unwrap() * dist_i[j];
-            if quad_coef <= A::zero() {
-                quad_coef = A::from(1e-10).unwrap();
+                - F::from(2.0).unwrap() * dist_i[j];
+            if quad_coef <= F::zero() {
+                quad_coef = F::from(1e-10).unwrap();
             }

             let delta = (self.gradient[i] - self.gradient[j]) / quad_coef;
@@ -287,8 +287,8 @@ impl<'a, A: Float, K: 'a + Permutable<'a, A>> SolverState<'a, A, K> {
                     self.alpha[i].value = bound_i;
                     self.alpha[j].value = sum - bound_i;
                 }
-            } else if self.alpha[j].val() < A::zero() {
-                self.alpha[j].value = A::zero();
+            } else if self.alpha[j].val() < F::zero() {
+                self.alpha[j].value = F::zero();
                 self.alpha[i].value = sum;
             }
             if sum > bound_j {
@@ -296,20 +296,20 @@ impl<'a, A: Float, K: 'a + Permutable<'a, A>> SolverState<'a, A, K> {
                     self.alpha[j].value = bound_j;
                     self.alpha[i].value = sum - bound_j;
                 }
-            } else if self.alpha[i].val() < A::zero() {
-                self.alpha[i].value = A::zero();
+            } else if self.alpha[i].val() < F::zero() {
+                self.alpha[i].value = F::zero();
                 self.alpha[j].value = sum;
             }

             /*if self.alpha[i].val() > bound_i {
                 self.alpha[i].value = bound_i;
-            } else if self.alpha[i].val() < A::zero() {
-                self.alpha[i].value = A::zero();
+            } else if self.alpha[i].val() < F::zero() {
+                self.alpha[i].value = F::zero();
             }

             if self.alpha[j].val() > bound_j {
                 self.alpha[j].value = bound_j;
-            } else if self.alpha[j].val() < A::zero() {
-                self.alpha[j].value = A::zero();
+            } else if self.alpha[j].val() < F::zero() {
+                self.alpha[j].value = F::zero();
             }*/
         }

@@ -360,11 +360,11 @@ impl<'a, A: Float, K: 'a + Permutable<'a, A>> SolverState<'a, A, K> {
     }

     /// Return max and min gradients of free variables
-    pub fn max_violating_pair(&self) -> ((A, isize), (A, isize)) {
+    pub fn max_violating_pair(&self) -> ((F, isize), (F, isize)) {
         // max { -y_i * grad(f)_i \i in I_up(\alpha) }
-        let mut gmax1 = (-A::infinity(), -1);
+        let mut gmax1 = (-F::infinity(), -1);
         // max { y_i * grad(f)_i \i in U_low(\alpha) }
-        let mut gmax2 = (-A::infinity(), -1);
+        let mut gmax2 = (-F::infinity(), -1);

         for i in 0..self.nactive() {
             if self.targets[i] {
@@ -388,11 +388,11 @@ impl<'a, A: Float, K: 'a + Permutable<'a, A>> SolverState<'a, A, K> {
     }

     #[allow(clippy::type_complexity)]
-    pub fn max_violating_pair_nu(&self) -> ((A, isize), (A, isize), (A, isize), (A, isize)) {
-        let mut gmax1 = (-A::infinity(), -1);
-        let mut gmax2 = (-A::infinity(), -1);
-        let mut gmax3 = (-A::infinity(), -1);
-        let mut gmax4 = (-A::infinity(), -1);
+    pub fn max_violating_pair_nu(&self) -> ((F, isize), (F, isize), (F, isize), (F, isize)) {
+        let mut gmax1 = (-F::infinity(), -1);
+        let mut gmax2 = (-F::infinity(), -1);
+        let mut gmax3 = (-F::infinity(), -1);
+        let mut gmax4 = (-F::infinity(), -1);

         for i in 0..self.nactive() {
             if self.targets[i] {
@@ -428,7 +428,7 @@ impl<'a, A: Float, K: 'a + Permutable<'a, A>> SolverState<'a, A, K> {

         let (gmax, gmax2) = self.max_violating_pair();

-        let mut obj_diff_min = (A::infinity(), -1);
+        let mut obj_diff_min = (F::infinity(), -1);

         if gmax.1 != -1 {
             let dist_i = self.kernel.distances(gmax.1 as usize, self.ntotal());
@@ -437,18 +437,18 @@ impl<'a, A: Float, K: 'a + Permutable<'a, A>> SolverState<'a, A, K> {
             if self.targets[j] {
                 if !self.alpha[j].reached_lower() {
                     let grad_diff = gmax.0 + self.gradient[j];
-                    if grad_diff > A::zero() {
+                    if grad_diff > F::zero() {
                         // this is possible, because op_i is some
                         let i = gmax.1 as usize;
                         let quad_coef = self.kernel.self_distance(i)
                             + self.kernel.self_distance(j)
-                            - A::from(2.0).unwrap() * self.target(i) * dist_ij;
+                            - F::from(2.0).unwrap() * self.target(i) * dist_ij;

-                        let obj_diff = if quad_coef > A::zero() {
+                        let obj_diff = if quad_coef > F::zero() {
                             -(grad_diff * grad_diff) / quad_coef
                         } else {
-                            -(grad_diff * grad_diff) / A::from(1e-10).unwrap()
+                            -(grad_diff * grad_diff) / F::from(1e-10).unwrap()
                         };

                         if obj_diff <= obj_diff_min.0 {
@@ -458,18 +458,18 @@ impl<'a, A: Float, K: 'a + Permutable<'a, A>> SolverState<'a, A, K> {
                     }
                 } else if !self.alpha[j].reached_upper() {
                     let grad_diff = gmax.0 - self.gradient[j];
-                    if grad_diff > A::zero() {
+                    if grad_diff > F::zero() {
                         // this is possible, because op_i is `Some`
                         let i = gmax.1 as usize;
                         let quad_coef = self.kernel.self_distance(i)
                             + self.kernel.self_distance(j)
-                            + A::from(2.0).unwrap() * self.target(i) * dist_ij;
+                            + F::from(2.0).unwrap() * self.target(i) * dist_ij;

-                        let obj_diff = if quad_coef > A::zero() {
+                        let obj_diff = if quad_coef > F::zero() {
                             -(grad_diff * grad_diff) / quad_coef
                         } else {
-                            -(grad_diff * grad_diff) / A::from(1e-10).unwrap()
+                            -(grad_diff * grad_diff) / F::from(1e-10).unwrap()
                         };
                         if obj_diff <= obj_diff_min.0 {
                             obj_diff_min = (obj_diff, j as isize);
@@ -495,7 +495,7 @@ impl<'a, A: Float, K: 'a + Permutable<'a, A>> SolverState<'a, A, K> {
     pub fn select_working_set_nu(&self) -> (usize, usize, bool) {
         let (gmaxp1, gmaxn1, gmaxp2, gmaxn2) = self.max_violating_pair_nu();

-        let mut obj_diff_min = (A::infinity(), -1);
+        let mut obj_diff_min = (F::infinity(), -1);

         let dist_i_p = if gmaxp1.1 != -1 {
             Some(self.kernel.distances(gmaxp1.1 as usize, self.ntotal()))
@@ -513,7 +513,7 @@ impl<'a, A: Float, K: 'a + Permutable<'a, A>> SolverState<'a, A, K> {
             if self.targets[j] {
                 if !self.alpha[j].reached_lower() {
                     let grad_diff = gmaxp1.0 + self.gradient[j];
-                    if grad_diff > A::zero() {
+                    if grad_diff > F::zero() {
                         let dist_i_p = match dist_i_p {
                             Some(ref x) => x,
                             None => continue,
@@ -523,12 +523,12 @@ impl<'a, A: Float, K: 'a + Permutable<'a, A>> SolverState<'a, A, K> {
                         let i = gmaxp1.1 as usize;
                         let quad_coef = self.kernel.self_distance(i)
                             + self.kernel.self_distance(j)
-                            - A::from(2.0).unwrap() * dist_i_p[j];
+                            - F::from(2.0).unwrap() * dist_i_p[j];

-                        let obj_diff = if quad_coef > A::zero() {
+                        let obj_diff = if quad_coef > F::zero() {
                             -(grad_diff * grad_diff) / quad_coef
                         } else {
-                            -(grad_diff * grad_diff) / A::from(1e-10).unwrap()
+                            -(grad_diff * grad_diff) / F::from(1e-10).unwrap()
                         };

                         if obj_diff <= obj_diff_min.0 {
@@ -538,7 +538,7 @@ impl<'a, A: Float, K: 'a + Permutable<'a, A>> SolverState<'a, A, K> {
                     }
                 } else if !self.alpha[j].reached_upper() {
                     let grad_diff = gmaxn1.0 - self.gradient[j];
-                    if grad_diff > A::zero() {
+                    if grad_diff > F::zero() {
                         let dist_i_n = match dist_i_n {
                             Some(ref x) => x,
                             None => continue,
@@ -548,12 +548,12 @@ impl<'a, A: Float, K: 'a + Permutable<'a, A>> SolverState<'a, A, K> {
                         let i = gmaxn1.1 as usize;
                         let quad_coef = self.kernel.self_distance(i)
                             + self.kernel.self_distance(j)
-                            - A::from(2.0).unwrap() * dist_i_n[j];
+                            - F::from(2.0).unwrap() * dist_i_n[j];

-                        let obj_diff = if quad_coef > A::zero() {
+                        let obj_diff = if quad_coef > F::zero() {
                             -(grad_diff * grad_diff) / quad_coef
                         } else {
-                            -(grad_diff * grad_diff) / A::from(1e-10).unwrap()
+                            -(grad_diff * grad_diff) / F::from(1e-10).unwrap()
                         };
                         if obj_diff <= obj_diff_min.0 {
                             obj_diff_min = (obj_diff, j as isize);
@@ -562,7 +562,7 @@ impl<'a, A: Float, K: 'a + Permutable<'a, A>> SolverState<'a, A, K> {
             }
         }

-        if A::max(gmaxp1.0 + gmaxp2.0, gmaxn1.0 + gmaxn2.0) < self.params.eps
+        if F::max(gmaxp1.0 + gmaxp2.0, gmaxn1.0 + gmaxn2.0) < self.params.eps
             || obj_diff_min.1 == -1
         {
             (0, 0, true)
@@ -578,7 +578,7 @@ impl<'a, A: Float, K: 'a + Permutable<'a, A>> SolverState<'a, A, K> {
         }
     }

-    pub fn should_shrunk(&self, i: usize, gmax1: A, gmax2: A) -> bool {
+    pub fn should_shrunk(&self, i: usize, gmax1: F, gmax2: F) -> bool {
         if self.alpha[i].reached_upper() {
             if self.targets[i] {
                 -self.gradient[i] > gmax1
@@ -596,7 +596,7 @@ impl<'a, A: Float, K: 'a + Permutable<'a, A>> SolverState<'a, A, K> {
         }
     }

-    pub fn should_shrunk_nu(&self, i: usize, gmax1: A, gmax2: A, gmax3: A, gmax4: A) -> bool {
+    pub fn should_shrunk_nu(&self, i: usize, gmax1: F, gmax2: F, gmax3: F, gmax4: F) -> bool {
         if self.alpha[i].reached_upper() {
             if self.targets[i] {
                 -self.gradient[i] > gmax1
@@ -624,7 +624,7 @@ impl<'a, A: Float, K: 'a + Permutable<'a, A>> SolverState<'a, A, K> {
         let (gmax1, gmax2) = (gmax1.0, gmax2.0);

         // work on all variables when 10*eps is reached
-        if !self.unshrink && gmax1 + gmax2 <= self.params.eps * A::from(10.0).unwrap() {
+        if !self.unshrink && gmax1 + gmax2 <= self.params.eps * F::from(10.0).unwrap() {
             self.unshrink = true;
             self.reconstruct_gradient();
             self.nactive = self.ntotal();
@@ -652,7 +652,7 @@ impl<'a, A: Float, K: 'a + Permutable<'a, A>> SolverState<'a, A, K> {

         // work on all variables when 10*eps is reached
         if !self.unshrink
-            && A::max(gmax1 + gmax2, gmax3 + gmax4) <= self.params.eps * A::from(10.0).unwrap()
+            && F::max(gmax1 + gmax2, gmax3 + gmax4) <= self.params.eps * F::from(10.0).unwrap()
         {
             self.unshrink = true;
             self.reconstruct_gradient();
@@ -675,31 +675,31 @@ impl<'a, A: Float, K: 'a + Permutable<'a, A>> SolverState<'a, A, K> {
         }
     }

-    pub fn calculate_rho(&mut self) -> A {
+    pub fn calculate_rho(&mut self) -> F {
         // with additional constraint call the other function
         if self.nu_constraint {
             return self.calculate_rho_nu();
         }

         let mut nfree = 0;
-        let mut sum_free = A::zero();
-        let mut ub = A::infinity();
-        let mut lb = -A::infinity();
+        let mut sum_free = F::zero();
+        let mut ub = F::infinity();
+        let mut lb = -F::infinity();

         for i in 0..self.nactive() {
             let yg = self.target(i) * self.gradient[i];

             if self.alpha[i].reached_upper() {
                 if self.targets[i] {
-                    lb = A::max(lb, yg);
+                    lb = F::max(lb, yg);
                 } else {
-                    ub = A::min(ub, yg);
+                    ub = F::min(ub, yg);
                 }
             } else if self.alpha[i].reached_lower() {
                 if self.targets[i] {
-                    ub = A::min(ub, yg);
+                    ub = F::min(ub, yg);
                 } else {
-                    lb = A::max(lb, yg);
+                    lb = F::max(lb, yg);
                 }
             } else {
                 nfree += 1;
@@ -708,24 +708,24 @@ impl<'a, A: Float, K: 'a + Permutable<'a, A>> SolverState<'a, A, K> {
         }

         if nfree > 0 {
-            sum_free / A::from(nfree).unwrap()
+            sum_free / F::from(nfree).unwrap()
         } else {
-            (ub + lb) / A::from(2.0).unwrap()
+            (ub + lb) / F::from(2.0).unwrap()
         }
     }

-    pub fn calculate_rho_nu(&mut self) -> A {
+    pub fn calculate_rho_nu(&mut self) -> F {
         let (mut nfree1, mut nfree2) = (0, 0);
-        let (mut sum_free1, mut sum_free2) = (A::zero(), A::zero());
-        let (mut ub1, mut ub2) = (A::infinity(), A::infinity());
-        let (mut lb1, mut lb2) = (-A::infinity(), -A::infinity());
+        let (mut sum_free1, mut sum_free2) = (F::zero(), F::zero());
+        let (mut ub1, mut ub2) = (F::infinity(), F::infinity());
+        let (mut lb1, mut lb2) = (-F::infinity(), -F::infinity());

         for i in 0..self.nactive() {
             if self.targets[i] {
                 if self.alpha[i].reached_upper() {
-                    lb1 = A::max(lb1, self.gradient[i]);
+                    lb1 = F::max(lb1, self.gradient[i]);
                 } else if self.alpha[i].reached_lower() {
-                    ub1 = A::max(ub1, self.gradient[i]);
+                    ub1 = F::max(ub1, self.gradient[i]);
                 } else {
                     nfree1 += 1;
                     sum_free1 += self.gradient[i];
@@ -734,9 +734,9 @@ impl<'a, A: Float, K: 'a + Permutable<'a, A>> SolverState<'a, A, K> {

             if !self.targets[i] {
                 if self.alpha[i].reached_upper() {
-                    lb2 = A::max(lb2, self.gradient[i]);
+                    lb2 = F::max(lb2, self.gradient[i]);
                 } else if self.alpha[i].reached_lower() {
-                    ub2 = A::max(ub2, self.gradient[i]);
+                    ub2 = F::max(ub2, self.gradient[i]);
                 } else {
                     nfree2 += 1;
                     sum_free2 += self.gradient[i];
@@ -745,22 +745,22 @@ impl<'a, A: Float, K: 'a + Permutable<'a, A>> SolverState<'a, A, K> {
         }

         let r1 = if nfree1 > 0 {
-            sum_free1 / A::from(nfree1).unwrap()
+            sum_free1 / F::from(nfree1).unwrap()
         } else {
-            (ub1 + lb1) / A::from(2.0).unwrap()
+            (ub1 + lb1) / F::from(2.0).unwrap()
         };

         let r2 = if nfree2 > 0 {
-            sum_free2 / A::from(nfree2).unwrap()
+            sum_free2 / F::from(nfree2).unwrap()
         } else {
-            (ub2 + lb2) / A::from(2.0).unwrap()
+            (ub2 + lb2) / F::from(2.0).unwrap()
         };

-        self.r = (r1 + r2) / A::from(2.0).unwrap();
+        self.r = (r1 + r2) / F::from(2.0).unwrap();

-        (r1 - r2) / A::from(2.0).unwrap()
+        (r1 - r2) / F::from(2.0).unwrap()
     }
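`calculate_rho` recovers the decision offset from the KKT conditions: every free support vector pins rho exactly (its y_i * grad_i equals the offset), so those values are averaged; if no variable is free, the bound variables only bracket rho and the midpoint of the bracket is taken. The nu variant above applies the same logic per class and also stores `r`. A minimal sketch of the free/bracketed logic with plain `f64` and illustrative names:

```rust
// Sketch: offset from free support vectors, or the bracket midpoint.
fn rho_from_gradients(yg: &[f64], free: &[bool], ub: f64, lb: f64) -> f64 {
    let mut sum = 0.0;
    let mut nfree = 0;
    for (v, &is_free) in yg.iter().zip(free) {
        if is_free {
            sum += v;
            nfree += 1;
        }
    }
    if nfree > 0 {
        sum / nfree as f64 // free variables determine rho exactly
    } else {
        (ub + lb) / 2.0 // otherwise rho is only bracketed
    }
}

fn main() {
    assert_eq!(rho_from_gradients(&[1.5, 0.5, 9.0], &[true, true, false], 2.0, 0.0), 1.0);
    assert_eq!(rho_from_gradients(&[9.0], &[false], 2.0, 0.0), 1.0);
}
```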
-    pub fn solve(mut self) -> Svm<'a, A, A> {
+    pub fn solve(mut self) -> Svm<F, F> {
         let mut iter = 0;
         let max_iter = if self.targets.len() > std::usize::MAX / 100 {
             std::usize::MAX
@@ -812,11 +812,11 @@ impl<'a, A: Float, K: 'a + Permutable<'a, A>> SolverState<'a, A, K> {
         };

         // calculate object function
-        let mut v = A::zero();
+        let mut v = F::zero();
         for i in 0..self.targets.len() {
             v += self.alpha[i].val() * (self.gradient[i] + self.p[i]);
         }
-        let obj = v / A::from(2.0).unwrap();
+        let obj = v / F::from(2.0).unwrap();

         let exit_reason = if max_iter == iter {
             ExitReason::ReachedIterations
@@ -825,7 +825,7 @@ impl<'a, A: Float, K: 'a + Permutable<'a, A>> SolverState<'a, A, K> {
         };

         // put back the solution
-        let alpha: Vec<A> = (0..self.ntotal())
+        let alpha: Vec<F> = (0..self.ntotal())
             .map(|i| self.alpha[self.active_set[i]].val())
             .collect();
@@ -848,7 +848,7 @@ impl<'a, A: Float, K: 'a + Permutable<'a, A>> SolverState<'a, A, K> {
             exit_reason,
             obj,
             iterations: iter,
-            kernel: self.kernel.inner(),
+            kernel: self.kernel.to_inner(),
             linear_decision,
             phantom: PhantomData,
         }
     }
@@ -859,7 +859,7 @@
 #[cfg(test)]
 mod tests {
     use crate::permutable_kernel::PermutableKernel;
-    use super::{SolverState, SolverParams, Svm};
+    use super::{SolverState, SolverParams, SvmBase};
     use ndarray::array;
     use linfa_kernel::{Kernel, KernelInner};

@@ -883,7 +883,7 @@ mod tests {
         let solver = SolverState::new(vec![1.0, 1.0], p, targets, kernel, vec![1000.0; 2], &params, false);

-        let res: Svm = solver.solve();
+        let res: SvmBase = solver.solve();

         println!("{:?}", res.alpha);
         println!("{}", res);
diff --git a/src/dataset/impl_dataset.rs b/src/dataset/impl_dataset.rs
index a43e5cd59..fa1fa4614 100644
--- a/src/dataset/impl_dataset.rs
+++ b/src/dataset/impl_dataset.rs
@@ -307,6 +307,138 @@ impl<F: Float, E: Copy> Dataset<F, E> {
     pub fn fold(&self, k: usize) -> Vec<(Dataset<F, E>, Dataset<F, E>)> {
         self.view().fold(k)
     }
+
+    pub fn axis_chunks_iter(
+        &self,
+        axis: Axis,
+        chunk_size: usize,
+    ) -> impl Iterator<Item = DatasetView<'_, F, E>> {
+        self.records()
+            .axis_chunks_iter(axis, chunk_size)
+            .zip(self.targets().axis_chunks_iter(axis, chunk_size))
+            .map(|(rec, tar)| (rec, tar).into())
+    }
+
+    /// Performs k-fold cross-validation on fittable algorithms.
+    ///
+    /// Given a dataset, a value of `k` and the desired parameters for the
+    /// fittable algorithm, returns an iterator over the `k` trained models
+    /// and their associated validation sets.
+    ///
+    /// The models are trained by a closure provided as input.
+    ///
+    /// ## Parameters
+    ///
+    /// - `k`: the number of folds to apply to the dataset
+    /// - `params`: the desired parameters for the fittable algorithm at hand
+    /// - `fit_closure`: a closure of the type `(params, training_data) -> fitted_model`
+    ///   that is used to produce the trained model for each fold. The training data
+    ///   passed to the closure does not outlive it.
+    ///
+    /// ## Returns
+    ///
+    /// An iterator over pairs `(trained_model, validation_set)`.
+    ///
+    /// ## Panics
+    ///
+    /// This method panics for any of the following three reasons:
+    ///
+    /// - the value of `k` provided is not positive;
+    /// - the value of `k` provided is greater than the total number of samples in the dataset;
+    /// - the dataset's data is not stored contiguously and in standard order.
+    ///
+    /// ## Example
+    /// ```rust
+    /// use linfa::traits::Fit;
+    /// use linfa::dataset::{Dataset, DatasetView};
+    /// use ndarray::{array, ArrayView1, ArrayView2};
+    ///
+    /// struct MockFittable {}
+    ///
+    /// struct MockFittableResult {
+    ///     mock_var: usize,
+    /// }
+    ///
+    /// impl<'a> Fit<'a, ArrayView2<'a, f64>, ArrayView1<'a, f64>> for MockFittable {
+    ///     type Object = MockFittableResult;
+    ///
+    ///     fn fit(&self, training_data: &DatasetView<f64, f64>) -> Self::Object {
+    ///         MockFittableResult { mock_var: training_data.targets().dim() }
+    ///     }
+    /// }
+    ///
+    /// let records = array![[1.,1.], [2.,2.], [3.,3.], [4.,4.], [5.,5.]];
+    /// let targets = array![1.,2.,3.,4.,5.];
+    /// let mut dataset: Dataset<f64, f64> = (records, targets).into();
+    /// let params = MockFittable {};
+    ///
+    /// for (model, validation_set) in dataset.iter_fold(5, |v| params.fit(&v)) {
+    ///     // Here you can use `model` and `validation_set` to
+    ///     // assert the performance of the chosen algorithm
+    /// }
+    /// ```
+    pub fn iter_fold<'a, O: 'a, C: 'a + Fn(DatasetView<F, E>) -> O>(
+        &'a mut self,
+        k: usize,
+        fit_closure: C,
+    ) -> impl Iterator<Item = (O, DatasetView<'a, F, E>)> + 'a {
+        assert!(k > 0);
+        assert!(k <= self.targets.len());
+        let samples_count = self.targets().len();
+        let fold_size = samples_count / k;
+
+        let features = self.records.dim().1;
+
+        let mut records_sl = self.records.as_slice_mut().unwrap();
+        let mut targets_sl = self.targets.as_slice_mut().unwrap();
+
+        let mut objs: Vec<O> = Vec::new();
+
+        for i in 0..k {
+            assist_swap_array2(&mut records_sl, i, fold_size, features);
+            assist_swap_array1(&mut targets_sl, i, fold_size);
+
+            let train = DatasetView::new(
+                ArrayView2::from_shape(
+                    (samples_count - fold_size, features),
+                    records_sl.split_at(fold_size * features).1,
+                )
+                .unwrap(),
+                ArrayView1::from_shape(samples_count - fold_size, targets_sl.split_at(fold_size).1)
+                    .unwrap(),
+            );
+
+            let obj = fit_closure(train);
+            objs.push(obj);
+
+            assist_swap_array2(&mut records_sl, i, fold_size, features);
+            assist_swap_array1(&mut targets_sl, i, fold_size);
+        }
+        objs.into_iter()
+            .zip(self.axis_chunks_iter(Axis(0), fold_size))
+    }
+}
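The two `assist_swap_*` helpers below are what keep `iter_fold` allocation-free: fold `i` is swapped to the front of the contiguous buffer, the model is trained on the remaining tail, and the identical swap afterwards restores the original order. A small round-trip demonstration on a plain slice (`swap_fold` is a hypothetical mirror of `assist_swap_array1`):

```rust
// Swap fold `index` (of size `fold_size`) with the front of the buffer,
// so the tail holds the training data; calling it twice is a no-op.
fn swap_fold(slice: &mut [i32], index: usize, fold_size: usize) {
    if index == 0 {
        return;
    }
    let start = fold_size * index;
    let (front, rest) = slice.split_at_mut(start);
    let (fold, _) = rest.split_at_mut(fold_size);
    front[..fold_size].swap_with_slice(fold);
}

fn main() {
    let mut data = [1, 2, 3, 4, 5];
    swap_fold(&mut data, 1, 2); // fold 1 = [3, 4] moves to the front
    assert_eq!(data, [3, 4, 1, 2, 5]);
    // the training view would be the tail: [1, 2, 5]
    swap_fold(&mut data, 1, 2); // swapping again restores the order
    assert_eq!(data, [1, 2, 3, 4, 5]);
}
```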
+
+fn assist_swap_array1<E>(slice: &mut [E], index: usize, fold_size: usize) {
+    if index == 0 {
+        return;
+    }
+    let start = fold_size * index;
+    let (first_s, second_s) = slice.split_at_mut(start);
+    let (mut fold, _) = second_s.split_at_mut(fold_size);
+    first_s[..fold_size].swap_with_slice(&mut fold);
+}
+
+fn assist_swap_array2<F>(slice: &mut [F], index: usize, fold_size: usize, features: usize) {
+    if index == 0 {
+        return;
+    }
+    let adj_fold_size = fold_size * features;
+    let start = adj_fold_size * index;
+    let (first_s, second_s) = slice.split_at_mut(start);
+    let (mut fold, _) = second_s.split_at_mut(adj_fold_size);
+    first_s[..fold_size * features].swap_with_slice(&mut fold);
+}

 impl<'a, F: Float, E: Copy> DatasetView<'a, F, E> {
@@ -422,4 +554,8 @@ impl<'a, F: Float, E: Copy> DatasetView<'a, F, E> {
         }
         res
     }
+
+    pub fn to_owned(&self) -> Dataset<F, E> {
+        (self.records().to_owned(), self.targets.to_owned()).into()
+    }
 }
diff --git a/src/dataset/mod.rs b/src/dataset/mod.rs
index 8a6765be1..c2db2aa48 100644
--- a/src/dataset/mod.rs
+++ b/src/dataset/mod.rs
@@ -292,4 +292,104 @@ mod tests {
             }
         }
     }
+
+    struct MockFittable {}
+
+    struct MockFittableResult {
+        mock_var: usize,
+    }
+
+    use crate::traits::Fit;
+    use ndarray::{ArrayView1, ArrayView2};
+
+    impl<'a> Fit<'a, ArrayView2<'a, f64>, ArrayView1<'a, f64>> for MockFittable {
+        type Object = MockFittableResult;
+
+        fn fit(&self, training_data: &DatasetView<f64, f64>) -> Self::Object {
+            MockFittableResult {
+                mock_var: training_data.targets().dim(),
+            }
+        }
+    }
+
+    #[test]
+    fn test_iter_fold() {
+        let records =
+            Array2::from_shape_vec((5, 2), vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]).unwrap();
+        let targets = Array1::from_shape_vec(5, vec![1., 2., 3., 4., 5.]).unwrap();
+        let mut dataset: Dataset<f64, f64> = (records, targets).into();
+        let params = MockFittable {};
+
+        for (i, (model, validation_set)) in dataset.iter_fold(5, |v| params.fit(&v)).enumerate() {
+            assert_eq!(model.mock_var, 4);
+            assert_eq!(validation_set.records().row(0)[0] as usize, i + 1);
+            assert_eq!(validation_set.records().row(0)[1] as usize, i + 1);
+            assert_eq!(validation_set.targets()[0] as usize, i + 1);
+            assert_eq!(validation_set.records().dim(), (1, 2));
+            assert_eq!(validation_set.targets().dim(), 1);
+        }
+    }
+
+    #[test]
+    fn test_iter_fold_uneven_folds() {
+        let records =
+            Array2::from_shape_vec((5, 2), vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]).unwrap();
+        let targets = Array1::from_shape_vec(5, vec![1., 2., 3., 4., 5.]).unwrap();
+        let mut dataset: Dataset<f64, f64> = (records, targets).into();
+        let params = MockFittable {};
+
+        // If we request three folds from a dataset with 5 samples it will cut the
+        // last two samples from the folds and always add them as a tail of the training
+        // data
+        for (i, (model, validation_set)) in dataset.iter_fold(3, |v| params.fit(&v)).enumerate() {
+            assert_eq!(model.mock_var, 4);
+            assert_eq!(validation_set.records().row(0)[0] as usize, i + 1);
+            assert_eq!(validation_set.records().row(0)[1] as usize, i + 1);
+            assert_eq!(validation_set.targets()[0] as usize, i + 1);
+            assert_eq!(validation_set.records().dim(), (1, 2));
+            assert_eq!(validation_set.targets().dim(), 1);
+            assert!(i < 3);
+        }
+
+        // the same goes for the last sample if we choose 4 folds
+        for (i, (model, validation_set)) in dataset.iter_fold(4, |v| params.fit(&v)).enumerate() {
+            assert_eq!(model.mock_var, 4);
+            assert_eq!(validation_set.records().row(0)[0] as usize, i + 1);
+            assert_eq!(validation_set.records().row(0)[1] as usize, i + 1);
+            assert_eq!(validation_set.targets()[0] as usize, i + 1);
+            assert_eq!(validation_set.records().dim(), (1, 2));
+            assert_eq!(validation_set.targets().dim(), 1);
+            assert!(i < 4);
+        }
+
+        // if we choose 2 folds then again the last sample will only be
+        // used for training
+        for (i, (model, validation_set)) in dataset.iter_fold(2, |v| params.fit(&v)).enumerate() {
+            assert_eq!(model.mock_var, 3);
+            assert_eq!(validation_set.targets().dim(), 2);
+            assert!(i < 2);
+        }
+    }
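The uneven-fold behaviour asserted above reduces to integer division: with `n` samples and `k` folds each validation chunk holds `n / k` samples, and the `n % k` leftover samples never leave the training side. A quick check of that arithmetic (illustrative helper):

```rust
// (validation, training) sizes per fold iteration.
fn fold_layout(n: usize, k: usize) -> (usize, usize) {
    let fold_size = n / k;
    (fold_size, n - fold_size)
}

fn main() {
    assert_eq!(fold_layout(5, 3), (1, 4)); // two leftovers stay in training
    assert_eq!(fold_layout(5, 4), (1, 4)); // one leftover stays in training
    assert_eq!(fold_layout(5, 2), (2, 3));
}
```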
+
+    #[test]
+    #[should_panic]
+    fn iter_fold_panics_k_0() {
+        let records =
+            Array2::from_shape_vec((5, 2), vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]).unwrap();
+        let targets = Array1::from_shape_vec(5, vec![1., 2., 3., 4., 5.]).unwrap();
+        let mut dataset: Dataset<f64, f64> = (records, targets).into();
+        let params = MockFittable {};
+        let _ = dataset.iter_fold(0, |v| params.fit(&v)).enumerate();
+    }
+
+    #[test]
+    #[should_panic]
+    fn iter_fold_panics_k_more_than_samples() {
+        let records =
+            Array2::from_shape_vec((5, 2), vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]).unwrap();
+        let targets = Array1::from_shape_vec(5, vec![1., 2., 3., 4., 5.]).unwrap();
+        let mut dataset: Dataset<f64, f64> = (records, targets).into();
+        let params = MockFittable {};
+        let _ = dataset.iter_fold(6, |v| params.fit(&v)).enumerate();
+    }
 }
diff --git a/src/traits.rs b/src/traits.rs
index 9a1a6ff39..196519967 100644
--- a/src/traits.rs
+++ b/src/traits.rs
@@ -23,7 +23,7 @@ pub trait Transformer<R, T> {
 pub trait Fit<'a, R: Records, T: Targets> {
     type Object: 'a;

-    fn fit(&self, dataset: &'a DatasetBase<R, T>) -> Self::Object;
+    fn fit(&self, dataset: &DatasetBase<R, T>) -> Self::Object;
 }

 /// Incremental algorithms