diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index cff10c863..89a927b18 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -7,7 +7,7 @@ // except according to those terms. use std::ops::{Add, Div, Mul}; -use libnum::{self, One, Zero, Float, FromPrimitive}; +use libnum::{self, Zero, Float, FromPrimitive}; use itertools::free::enumerate; use imp_prelude::*; @@ -123,8 +123,9 @@ impl ArrayBase /// Return mean along `axis`. /// - /// **Panics** if `axis` is out of bounds or if the length of the axis is - /// zero and division by zero panics for type `A`. + /// **Panics** if `axis` is out of bounds, if the length of the axis is + /// zero and division by zero panics for type `A`, or if `A::from_usize()` + /// fails for the axis length. /// /// ``` /// use ndarray::{aview1, arr2, Axis}; @@ -137,16 +138,12 @@ impl ArrayBase /// ); /// ``` pub fn mean_axis(&self, axis: Axis) -> Array - where A: Clone + Zero + One + Add + Div, + where A: Clone + Zero + FromPrimitive + Add + Div, D: RemoveAxis, { - let n = self.len_of(axis); + let n = A::from_usize(self.len_of(axis)).expect("Converting axis length to `A` must not fail."); let sum = self.sum_axis(axis); - let mut cnt = A::zero(); - for _ in 0..n { - cnt = cnt + A::one(); - } - sum / &aview0(&cnt) + sum / &aview0(&n) } /// Return variance along `axis`.