diff --git a/src/number.rs b/src/number.rs
index c39c0334..04c24706 100644
--- a/src/number.rs
+++ b/src/number.rs
@@ -62,6 +62,39 @@ impl Identities for i32 {
     }
 }
 
+/// Test if a number is a float NaN ("Not a number") value.
+pub trait IsNaN {
+    /// Return true if the current value is a NaN. See [`f32::is_nan`].
+    ///
+    /// This is always false for integer types.
+    #[allow(clippy::wrong_self_convention)] // Match `f32::is_nan` etc.
+    fn is_nan(self) -> bool;
+}
+
+macro_rules! impl_isnan_float {
+    ($type:ty) => {
+        impl IsNaN for $type {
+            fn is_nan(self) -> bool {
+                <$type>::is_nan(self)
+            }
+        }
+    };
+}
+macro_rules! impl_isnan_int {
+    ($type:ty) => {
+        impl IsNaN for $type {
+            fn is_nan(self) -> bool {
+                false
+            }
+        }
+    };
+}
+
+impl_isnan_float!(f32);
+impl_isnan_int!(i32);
+impl_isnan_int!(i8);
+impl_isnan_int!(u8);
+
 /// Convert between a primitive type and an array of bytes in little-endian
 /// order.
 pub trait LeBytes {
diff --git a/src/ops/gather.rs b/src/ops/gather.rs
index 6e414429..aa25956d 100644
--- a/src/ops/gather.rs
+++ b/src/ops/gather.rs
@@ -2,6 +2,7 @@ use rten_tensor::prelude::*;
 use rten_tensor::{to_slice_items, NdTensorView, SliceItem, Tensor, TensorView, TensorViewMut};
 use smallvec::SmallVec;
 
+use crate::number::IsNaN;
 use crate::ops::reduce::{cmp_nan_greater, cmp_nan_less};
 use crate::ops::{
     resolve_axis, resolve_index, Input, InputList, IntoOpResult, OpError, Operator, OutputList,
@@ -392,7 +393,9 @@ pub enum ScatterReduction {
     Max,
 }
 
-fn scatter_reduce<T: Copy + PartialOrd + std::ops::Add<Output = T> + std::ops::Mul<Output = T>>(
+fn scatter_reduce<
+    T: Copy + PartialOrd + std::ops::Add<Output = T> + std::ops::Mul<Output = T> + IsNaN,
+>(
     current: T,
     update: T,
     reduction: Option<ScatterReduction>,
@@ -416,7 +419,7 @@
 }
 
 pub fn scatter_elements<
-    T: Copy + Default + PartialOrd + std::ops::Add<Output = T> + std::ops::Mul<Output = T>,
+    T: Copy + Default + PartialOrd + std::ops::Add<Output = T> + std::ops::Mul<Output = T> + IsNaN,
 >(
     pool: &TensorPool,
     data: TensorView<T>,
@@ -499,7 +502,7 @@ impl Operator for ScatterElements {
     }
 }
 
 pub fn scatter_nd<
-    T: Copy + Default + PartialOrd + std::ops::Add<Output = T> + std::ops::Mul<Output = T>,
+    T: Copy + Default + PartialOrd + std::ops::Add<Output = T> + std::ops::Mul<Output = T> + IsNaN,
 >(
     pool: &TensorPool,
     data: TensorView<T>,
diff --git a/src/ops/operators.rs b/src/ops/operators.rs
index 11ce19a1..2631076c 100644
--- a/src/ops/operators.rs
+++ b/src/ops/operators.rs
@@ -3,7 +3,7 @@ use std::fmt::Debug;
 use rten_tensor::prelude::*;
 use rten_tensor::{MutLayout, NdTensorView, Storage, Tensor, TensorBase, TensorView};
 
-use crate::number::{Identities, IsInt};
+use crate::number::{Identities, IsInt, IsNaN};
 use crate::ops::OpError;
 use crate::ops::{
     arg_max, div, matmul, mul, pad, reduce_l2, reduce_max, reduce_mean, reduce_min, reduce_sum,
@@ -22,7 +22,7 @@ pub trait Operators {
 
     fn arg_max(&self, axis: isize, keep_dims: bool) -> Result<Tensor<i32>, OpError>
     where
-        Self::Elem: Copy + PartialOrd;
+        Self::Elem: Copy + PartialOrd + IsNaN;
 
     fn div(&self, other: TensorView<Self::Elem>) -> Result<Tensor<Self::Elem>, OpError>
     where
@@ -44,7 +44,7 @@ pub trait Operators {
         keep_dims: bool,
     ) -> Result<Tensor<Self::Elem>, OpError>
     where
-        Self::Elem: Copy + PartialOrd;
+        Self::Elem: Copy + PartialOrd + IsNaN;
 
     fn reduce_min(
         &self,
@@ -52,7 +52,7 @@ pub trait Operators {
         keep_dims: bool,
     ) -> Result<Tensor<Self::Elem>, OpError>
     where
-        Self::Elem: Copy + PartialOrd;
+        Self::Elem: Copy + PartialOrd + IsNaN;
 
     fn reduce_sum(
         &self,
@@ -78,7 +78,7 @@ pub trait Operators {
         sorted: bool,
     ) -> Result<(Tensor<Self::Elem>, Tensor<i32>), OpError>
     where
-        Self::Elem: Copy + Default + PartialOrd;
+        Self::Elem: Copy + Default + PartialOrd + IsNaN;
 }
 
 /// Trait which exposes ONNX operators as methods of tensors.
@@ -112,7 +112,7 @@ impl<T, S: Storage<Elem = T>, L: MutLayout> Operators for TensorBase<S, L> {
 
     fn arg_max(&self, axis: isize, keep_dims: bool) -> Result<Tensor<i32>, OpError>
     where
-        T: Copy + PartialOrd,
+        T: Copy + PartialOrd + IsNaN,
     {
         let view = self.as_dyn();
         use_thread_pool(|| arg_max(&TensorPool::new(), view, axis, keep_dims))
@@ -142,7 +142,7 @@ impl<T, S: Storage<Elem = T>, L: MutLayout> Operators for TensorBase<S, L> {
 
     fn reduce_max(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor<T>, OpError>
     where
-        T: Copy + PartialOrd,
+        T: Copy + PartialOrd + IsNaN,
     {
         let view = self.as_dyn();
         use_thread_pool(|| reduce_max(&TensorPool::new(), view, axes, keep_dims))
@@ -150,7 +150,7 @@ impl<T, S: Storage<Elem = T>, L: MutLayout> Operators for TensorBase<S, L> {
 
     fn reduce_min(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor<T>, OpError>
     where
-        T: Copy + PartialOrd,
+        T: Copy + PartialOrd + IsNaN,
     {
         let view = self.as_dyn();
         use_thread_pool(|| reduce_min(&TensorPool::new(), view, axes, keep_dims))
@@ -184,7 +184,7 @@ impl<T, S: Storage<Elem = T>, L: MutLayout> Operators for TensorBase<S, L> {
         sorted: bool,
     ) -> Result<(Tensor<T>, Tensor<i32>), OpError>
     where
-        T: Copy + Default + PartialOrd,
+        T: Copy + Default + PartialOrd + IsNaN,
     {
         let view = self.as_dyn();
         use_thread_pool(|| topk(&TensorPool::new(), view, k, axis, largest, sorted))
diff --git a/src/ops/reduce.rs b/src/ops/reduce.rs
index a56ee825..6eafacf8 100644
--- a/src/ops/reduce.rs
+++ b/src/ops/reduce.rs
@@ -5,7 +5,7 @@ use rten_tensor;
 use rten_tensor::prelude::*;
 use rten_tensor::{DynIndices, NdTensor, NdTensorView, SliceItem, Tensor, TensorView};
 
-use crate::number::Identities;
+use crate::number::{Identities, IsNaN};
 use crate::ops::layout::squeeze_in_place;
 use crate::ops::{
     resolve_axes, resolve_axis, Input, InputList, IntoOpResult, OpError, Operator, OutputList,
@@ -46,15 +46,15 @@ fn select_max_index<T: Copy, Cmp: Fn(&T, &T) -> std::cmp::Ordering>(
     }
 
     if !input.is_empty() {
-        for lane in input.lanes(resolved_axis) {
+        reduced_data.extend(input.lanes(resolved_axis).map(|lane| {
             let index = if let Some(slice) = lane.as_slice() {
                 // Fast path for contiguous lanes.
                 max_position_by(slice.iter(), &compare)
             } else {
                 max_position_by(lane, &compare)
             };
-            reduced_data.push(index as i32);
-        }
+            index as i32
+        }));
     }
 
     let mut reduced = Tensor::<i32>::from_data(&reduced_shape, reduced_data);
@@ -109,7 +109,7 @@ macro_rules! dispatch_single_axis_reduce_op {
 /// Return the index of the maximum value along a given axis.
 ///
 /// NaN values are propagated by treating NaNs as greater than other values.
-pub fn arg_max<T: Copy + PartialOrd>(
+pub fn arg_max<T: Copy + PartialOrd + IsNaN>(
     pool: &TensorPool,
     input: TensorView<T>,
     axis: isize,
@@ -138,7 +138,7 @@ impl Operator for ArgMax {
 /// Return the index of the minimum value along a given axis.
 ///
 /// NaN values are propagated by treating NaNs as smaller than other values.
-pub fn arg_min<T: Copy + PartialOrd>(
+pub fn arg_min<T: Copy + PartialOrd + IsNaN>(
     pool: &TensorPool,
     input: TensorView<T>,
     axis: isize,
@@ -147,7 +147,7 @@ pub fn arg_min<T: Copy + PartialOrd>(
     select_max_index(pool, input, axis, keep_dims, |a, b| {
         match a.partial_cmp(b) {
             Some(ordering) => ordering.reverse(),
-            None => cmp_nan_greater(a, b),
+            None => cmp_nan_greater(*a, *b),
         }
     })
 }
@@ -506,16 +506,12 @@ impl Operator for ReduceL2 {
     }
 }
 
-fn is_nan<T: PartialOrd>(a: &T) -> bool {
-    a.partial_cmp(a).is_none()
-}
-
 /// Compare `a` and `b`, treating all NaN values as greater than non-NaN values.
-pub fn cmp_nan_greater<T: PartialOrd>(a: T, b: T) -> std::cmp::Ordering {
+pub fn cmp_nan_greater<T: PartialOrd + IsNaN>(a: T, b: T) -> std::cmp::Ordering {
     match a.partial_cmp(&b) {
         Some(ordering) => ordering,
         None => {
-            if is_nan(&a) {
+            if a.is_nan() {
                 std::cmp::Ordering::Greater
             } else {
                 std::cmp::Ordering::Less
@@ -525,11 +521,11 @@ pub fn cmp_nan_greater<T: PartialOrd>(a: T, b: T) -> std::cmp::Ordering {
 }
 
 /// Compare `a` and `b`, treating all NaN values as less than non-NaN values.
-pub fn cmp_nan_less<T: PartialOrd>(a: T, b: T) -> std::cmp::Ordering {
+pub fn cmp_nan_less<T: PartialOrd + IsNaN>(a: T, b: T) -> std::cmp::Ordering {
     match a.partial_cmp(&b) {
         Some(ordering) => ordering,
         None => {
-            if is_nan(&a) {
+            if a.is_nan() {
                 std::cmp::Ordering::Less
             } else {
                 std::cmp::Ordering::Greater
@@ -538,7 +534,7 @@ pub fn cmp_nan_less<T: PartialOrd>(a: T, b: T) -> std::cmp::Ordering {
     }
 }
 
-fn reduce_min_max<T: Copy + PartialOrd>(
+fn reduce_min_max<T: Copy + PartialOrd + IsNaN>(
     pool: &TensorPool,
     input: TensorView<T>,
     axes: Option<&[i32]>,
@@ -548,7 +544,7 @@ fn reduce_min_max<T: Copy + PartialOrd>(
     struct MinMaxReducer {
         max: bool,
     }
-    impl<T: Copy + PartialOrd> Reducer<T> for MinMaxReducer {
+    impl<T: Copy + PartialOrd + IsNaN> Reducer<T> for MinMaxReducer {
         fn reduce<I: Iterator<Item = T>>(&self, iter: I) -> T {
             let reduced = if self.max {
                 iter.max_by(|a, b| cmp_nan_greater(*a, *b))
@@ -576,7 +572,7 @@ fn get_axes<'a>(
     Ok(axes)
 }
 
-pub fn reduce_min<T: Copy + PartialOrd>(
+pub fn reduce_min<T: Copy + PartialOrd + IsNaN>(
     pool: &TensorPool,
     input: TensorView<T>,
     axes: Option<&[i32]>,
@@ -603,7 +599,7 @@ impl Operator for ReduceMin {
     }
 }
 
-pub fn reduce_max<T: Copy + PartialOrd>(
+pub fn reduce_max<T: Copy + PartialOrd + IsNaN>(
     pool: &TensorPool,
     input: TensorView<T>,
     axes: Option<&[i32]>,
@@ -729,7 +725,7 @@ impl Operator for ReduceSumSquare {
     }
 }
 
-pub fn topk<T: Copy + Default + PartialOrd>(
+pub fn topk<T: Copy + Default + PartialOrd + IsNaN>(
     pool: &TensorPool,
     values: TensorView<T>,
     k: usize,
diff --git a/src/ops/variadic_elementwise.rs b/src/ops/variadic_elementwise.rs
index dfb8e6e6..87af73ff 100644
--- a/src/ops/variadic_elementwise.rs
+++ b/src/ops/variadic_elementwise.rs
@@ -3,6 +3,7 @@ use std::cmp::Ordering;
 use rten_tensor::prelude::*;
 use rten_tensor::{Tensor, TensorView};
 
+use crate::number::IsNaN;
 use crate::ops::binary_elementwise::binary_op;
 use crate::ops::reduce::{cmp_nan_greater, cmp_nan_less};
 use crate::ops::{Input, InputList, IntoOpResult, OpError, Operator, OutputList};
@@ -42,7 +43,7 @@ where
     })
 }
 
-pub fn max<T: Copy + PartialOrd>(
+pub fn max<T: Copy + PartialOrd + IsNaN>(
     pool: &TensorPool,
     inputs: &[TensorView<T>],
 ) -> Result<Tensor<T>, OpError> {
@@ -102,7 +103,7 @@ impl Operator for Mean {
     }
 }
 
-pub fn min<T: Copy + PartialOrd>(
+pub fn min<T: Copy + PartialOrd + IsNaN>(
     pool: &TensorPool,
     inputs: &[TensorView<T>],
 ) -> Result<Tensor<T>, OpError> {
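Note, not part of the diff itself: the self-contained sketch below restates the new IsNaN trait from src/number.rs and the cmp_nan_greater helper from src/ops/reduce.rs, with the macro-generated impls reduced to hand-written ones for f32 and i32, so the behaviour the extra bound enables can be run in isolation. It is an illustrative approximation of the patch, not the crate's actual code: float NaNs get a deterministic position when sorting or reducing, while integer element types satisfy the same bound trivially.

use std::cmp::Ordering;

/// Simplified stand-in for the `IsNaN` trait added in src/number.rs.
trait IsNaN {
    /// Return true if the current value is a NaN. Always false for integers.
    fn is_nan(self) -> bool;
}

impl IsNaN for f32 {
    fn is_nan(self) -> bool {
        // Defer to the inherent float method, as the macro in the patch does.
        f32::is_nan(self)
    }
}

impl IsNaN for i32 {
    fn is_nan(self) -> bool {
        false
    }
}

/// Compare `a` and `b`, treating all NaN values as greater than non-NaN values
/// (same behaviour as `cmp_nan_greater` in src/ops/reduce.rs).
fn cmp_nan_greater<T: PartialOrd + IsNaN>(a: T, b: T) -> Ordering {
    match a.partial_cmp(&b) {
        Some(ordering) => ordering,
        // `partial_cmp` only returns `None` when one side is NaN.
        None => {
            if a.is_nan() {
                Ordering::Greater
            } else {
                Ordering::Less
            }
        }
    }
}

fn main() {
    // Floats: the NaN sorts after every ordinary value instead of ending up
    // in an arbitrary position.
    let mut floats = vec![2.0f32, f32::NAN, 1.0];
    floats.sort_by(|a, b| cmp_nan_greater(*a, *b));
    assert!(floats[2].is_nan());
    assert_eq!(&floats[..2], &[1.0, 2.0]);

    // Integers satisfy the same bound; `is_nan` is always false, so the
    // comparator reduces to the ordinary total order.
    let mut ints = vec![3i32, 1, 2];
    ints.sort_by(|a, b| cmp_nan_greater(*a, *b));
    assert_eq!(ints, [1, 2, 3]);
}

Keeping is_nan as a by-value method (rather than &self) mirrors f32::is_nan, which is why the trait in the patch carries the clippy::wrong_self_convention allow.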