diff --git a/src/gamma.rs b/src/gamma.rs index ea06539e..64dc1612 100644 --- a/src/gamma.rs +++ b/src/gamma.rs @@ -7,6 +7,9 @@ pub trait Gamma where Self: Sized, { + /// Compute the gamma function. + fn gamma(self) -> Self; + /// Compute the real-valued digamma function. /// /// The formula is as follows: @@ -35,8 +38,12 @@ where /// inference. University of London, 2003, pp. 265–266. fn digamma(self) -> Self; - /// Compute the gamma function. - fn gamma(self) -> Self; + /// Compute the trigamma function. + /// + /// The code is based on a [Julia implementation][1]. + /// + /// [1]: https://github.com/JuliaMath/SpecialFunctions.jl + fn trigamma(&self) -> Self; /// Compute the regularized lower incomplete gamma function. /// @@ -61,13 +68,6 @@ where /// Compute the natural logarithm of the gamma function. fn ln_gamma(self) -> (Self, i32); - - /// Compute the trigamma function. - /// - /// The code is based on a [Julia implementation][1]. - /// - /// [1]: https://github.com/JuliaMath/SpecialFunctions.jl - fn trigamma(&self) -> Self; } macro_rules! evaluate_polynomial( @@ -78,6 +78,11 @@ macro_rules! evaluate_polynomial( #[rustfmt::skip] macro_rules! implement { ($kind:ty) => { impl Gamma for $kind { + #[inline] + fn gamma(self) -> Self { + self.tgamma() + } + fn digamma(self) -> Self { let p = self; if p <= 8.0 { @@ -102,9 +107,40 @@ macro_rules! implement { ($kind:ty) => { impl Gamma for $kind { ) } - #[inline] - fn gamma(self) -> Self { - self.tgamma() + fn trigamma(&self) -> Self { + let mut x: $kind = *self; + if x <= 0.0 { + return (<$kind>::PI * (<$kind>::PI * x).sin().recip()).powi(2) + - (1.0 - x).trigamma(); + } + + let mut psi: $kind = 0.0; + if x < 8.0 { + let n = (8.0 - x.floor()) as usize; + psi += x.recip().powi(2); + for v in 1..n { + psi += (x + (v as $kind)).recip().powi(2); + } + x += n as $kind; + } + let t = x.recip(); + let w = t * t; + psi += t + 0.5 * w; + psi + t + * w + * evaluate_polynomial!( + w, + [ + 0.16666666666666666, + -0.03333333333333333, + 0.023809523809523808, + -0.03333333333333333, + 0.07575757575757576, + -0.2531135531135531, + 1.1666666666666667, + -7.092156862745098, + ] + ) } fn inc_gamma(self, p: Self) -> Self { @@ -205,42 +241,6 @@ macro_rules! implement { ($kind:ty) => { impl Gamma for $kind { fn ln_gamma(self) -> (Self, i32) { self.lgamma() } - - fn trigamma(&self) -> Self { - let mut x: $kind = *self; - if x <= 0.0 { - return (<$kind>::PI * (<$kind>::PI * x).sin().recip()).powi(2) - - (1.0 - x).trigamma(); - } - - let mut psi: $kind = 0.0; - if x < 8.0 { - let n = (8.0 - x.floor()) as usize; - psi += x.recip().powi(2); - for v in 1..n { - psi += (x + (v as $kind)).recip().powi(2); - } - x += n as $kind; - } - let t = x.recip(); - let w = t * t; - psi += t + 0.5 * w; - psi + t - * w - * evaluate_polynomial!( - w, - [ - 0.16666666666666666, - -0.03333333333333333, - 0.023809523809523808, - -0.03333333333333333, - 0.07575757575757576, - -0.2531135531135531, - 1.1666666666666667, - -7.092156862745098, - ] - ) - } }}} implement!(f32); @@ -263,6 +263,51 @@ mod tests { assert_eq!(-FRAC_PI_2 - 3.0 * LN_2 - EULER_MASCHERONI, 0.25.digamma()); } + #[test] + fn trigamma() { + #[cfg(feature = "no_std")] + use core::f64::consts::PI; + #[cfg(not(feature = "no_std"))] + use std::f64::consts::PI; + let x = vec![ + 0.1, + 0.5, + 1.0, + 2.0, + 3.0, + 4.0, + 5.0, + 6.0, + 7.0, + 8.0, + 9.0, + 10.0, + -PI, + -2.0 * PI, + -3.0 * PI, + ]; + let y = vec![ + 101.43329915079276, + 4.93480220054468, + 1.6449340668482262, + 0.6449340668482261, + 0.39493406684822613, + 0.28382295573711497, + 0.221322955737115, + 0.18132295573711496, + 0.1535451779593372, + 0.13313701469403108, + 0.11751201469403139, + 0.10516633568168575, + 53.030438740085536, + 16.206759250472963, + 10.341296000533267, + ]; + + let z = x.iter().map(|&x| x.trigamma()).collect::>(); + assert::close(&z, &y, 1e-12); + } + #[test] fn inc_gamma_small_p() { let p = 4.2; @@ -330,49 +375,4 @@ mod tests { let z = x.iter().map(|&x| x.inc_gamma(p)).collect::>(); assert::close(&z, &y, 1e-12); } - - #[test] - fn trigamma() { - #[cfg(feature = "no_std")] - use core::f64::consts::PI; - #[cfg(not(feature = "no_std"))] - use std::f64::consts::PI; - let x = vec![ - 0.1, - 0.5, - 1.0, - 2.0, - 3.0, - 4.0, - 5.0, - 6.0, - 7.0, - 8.0, - 9.0, - 10.0, - -PI, - -2.0 * PI, - -3.0 * PI, - ]; - let y = vec![ - 101.43329915079276, - 4.93480220054468, - 1.6449340668482262, - 0.6449340668482261, - 0.39493406684822613, - 0.28382295573711497, - 0.221322955737115, - 0.18132295573711496, - 0.1535451779593372, - 0.13313701469403108, - 0.11751201469403139, - 0.10516633568168575, - 53.030438740085536, - 16.206759250472963, - 10.341296000533267, - ]; - - let z = x.iter().map(|&x| x.trigamma()).collect::>(); - assert::close(&z, &y, 1e-12); - } }