Skip to content

Commit

Permalink
Make NaN testing more efficient in reduction ops
Browse files Browse the repository at this point in the history
Add an `IsNaN` trait which calls `f32::is_nan` to test for NaN-ness instead of
testing whether `self.partial_cmp(self)` is None.
  • Loading branch information
robertknight committed Nov 16, 2024
1 parent 4b97e9d commit 90a8604
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 34 deletions.
33 changes: 33 additions & 0 deletions src/number.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,39 @@ impl Identities for i32 {
}
}

/// Predicate for detecting floating-point NaN ("Not a Number") values.
///
/// Lets generic numeric code ask "is this value NaN?" through one interface,
/// whether the element type is a float or an integer.
pub trait IsNaN {
    /// Report whether this value is NaN, in the sense of [`f32::is_nan`].
    ///
    /// Implementations for integer types always return `false`.
    // The by-value receiver deliberately mirrors `f32::is_nan` and similar
    // std methods, hence the lint exemption below.
    #[allow(clippy::wrong_self_convention)]
    fn is_nan(self) -> bool;
}

/// Implement `IsNaN` for a floating-point type by forwarding to that type's
/// own `is_nan` method.
macro_rules! impl_isnan_float {
    ($float:ty) => {
        impl IsNaN for $float {
            fn is_nan(self) -> bool {
                // Type-qualified call so the float's own `is_nan` is named
                // explicitly.
                <$float>::is_nan(self)
            }
        }
    };
}
/// Implement `IsNaN` for an integer type. The generated `is_nan` ignores the
/// value and always returns `false`.
macro_rules! impl_isnan_int {
    ($int:ty) => {
        impl IsNaN for $int {
            fn is_nan(self) -> bool {
                false
            }
        }
    };
}

// Integer types: `is_nan` is always false.
impl_isnan_int!(i8);
impl_isnan_int!(u8);
impl_isnan_int!(i32);

// Float types: `is_nan` forwards to the type's own `is_nan`.
impl_isnan_float!(f32);

/// Convert between a primitive type and an array of bytes in little-endian
/// order.
pub trait LeBytes {
Expand Down
9 changes: 6 additions & 3 deletions src/ops/gather.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use rten_tensor::prelude::*;
use rten_tensor::{to_slice_items, NdTensorView, SliceItem, Tensor, TensorView, TensorViewMut};
use smallvec::SmallVec;

use crate::number::IsNaN;
use crate::ops::reduce::{cmp_nan_greater, cmp_nan_less};
use crate::ops::{
resolve_axis, resolve_index, Input, InputList, IntoOpResult, OpError, Operator, OutputList,
Expand Down Expand Up @@ -392,7 +393,9 @@ pub enum ScatterReduction {
Max,
}

fn scatter_reduce<T: Copy + PartialOrd + std::ops::Add<Output = T> + std::ops::Mul<Output = T>>(
fn scatter_reduce<
T: Copy + PartialOrd + std::ops::Add<Output = T> + std::ops::Mul<Output = T> + IsNaN,
>(
current: T,
update: T,
reduction: Option<ScatterReduction>,
Expand All @@ -416,7 +419,7 @@ fn scatter_reduce<T: Copy + PartialOrd + std::ops::Add<Output = T> + std::ops::M
}

pub fn scatter_elements<
T: Copy + Default + PartialOrd + std::ops::Add<Output = T> + std::ops::Mul<Output = T>,
T: Copy + Default + PartialOrd + std::ops::Add<Output = T> + std::ops::Mul<Output = T> + IsNaN,
>(
pool: &TensorPool,
data: TensorView<T>,
Expand Down Expand Up @@ -499,7 +502,7 @@ impl Operator for ScatterElements {
}

pub fn scatter_nd<
T: Copy + Default + PartialOrd + std::ops::Add<Output = T> + std::ops::Mul<Output = T>,
T: Copy + Default + PartialOrd + std::ops::Add<Output = T> + std::ops::Mul<Output = T> + IsNaN,
>(
pool: &TensorPool,
data: TensorView<T>,
Expand Down
18 changes: 9 additions & 9 deletions src/ops/operators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::fmt::Debug;
use rten_tensor::prelude::*;
use rten_tensor::{MutLayout, NdTensorView, Storage, Tensor, TensorBase, TensorView};

use crate::number::{Identities, IsInt};
use crate::number::{Identities, IsInt, IsNaN};
use crate::ops::OpError;
use crate::ops::{
arg_max, div, matmul, mul, pad, reduce_l2, reduce_max, reduce_mean, reduce_min, reduce_sum,
Expand All @@ -22,7 +22,7 @@ pub trait Operators {

fn arg_max(&self, axis: isize, keep_dims: bool) -> Result<Tensor<i32>, OpError>
where
Self::Elem: Copy + PartialOrd;
Self::Elem: Copy + PartialOrd + IsNaN;

fn div(&self, other: TensorView<Self::Elem>) -> Result<Tensor<Self::Elem>, OpError>
where
Expand All @@ -44,15 +44,15 @@ pub trait Operators {
keep_dims: bool,
) -> Result<Tensor<Self::Elem>, OpError>
where
Self::Elem: Copy + PartialOrd;
Self::Elem: Copy + PartialOrd + IsNaN;

fn reduce_min(
&self,
axes: Option<&[i32]>,
keep_dims: bool,
) -> Result<Tensor<Self::Elem>, OpError>
where
Self::Elem: Copy + PartialOrd;
Self::Elem: Copy + PartialOrd + IsNaN;

fn reduce_sum(
&self,
Expand All @@ -78,7 +78,7 @@ pub trait Operators {
sorted: bool,
) -> Result<(Tensor<Self::Elem>, Tensor<i32>), OpError>
where
Self::Elem: Copy + Default + PartialOrd;
Self::Elem: Copy + Default + PartialOrd + IsNaN;
}

/// Trait which exposes ONNX operators as methods of tensors.
Expand Down Expand Up @@ -112,7 +112,7 @@ impl<T: Send, S: Storage<Elem = T>, L: MutLayout> Operators for TensorBase<S, L>

fn arg_max(&self, axis: isize, keep_dims: bool) -> Result<Tensor<i32>, OpError>
where
T: Copy + PartialOrd,
T: Copy + PartialOrd + IsNaN,
{
let view = self.as_dyn();
use_thread_pool(|| arg_max(&TensorPool::new(), view, axis, keep_dims))
Expand Down Expand Up @@ -142,15 +142,15 @@ impl<T: Send, S: Storage<Elem = T>, L: MutLayout> Operators for TensorBase<S, L>

fn reduce_max(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor<T>, OpError>
where
T: Copy + PartialOrd,
T: Copy + PartialOrd + IsNaN,
{
let view = self.as_dyn();
use_thread_pool(|| reduce_max(&TensorPool::new(), view, axes, keep_dims))
}

fn reduce_min(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor<T>, OpError>
where
T: Copy + PartialOrd,
T: Copy + PartialOrd + IsNaN,
{
let view = self.as_dyn();
use_thread_pool(|| reduce_min(&TensorPool::new(), view, axes, keep_dims))
Expand Down Expand Up @@ -184,7 +184,7 @@ impl<T: Send, S: Storage<Elem = T>, L: MutLayout> Operators for TensorBase<S, L>
sorted: bool,
) -> Result<(Tensor<Self::Elem>, Tensor<i32>), OpError>
where
T: Copy + Default + PartialOrd,
T: Copy + Default + PartialOrd + IsNaN,
{
let view = self.as_dyn();
use_thread_pool(|| topk(&TensorPool::new(), view, k, axis, largest, sorted))
Expand Down
36 changes: 16 additions & 20 deletions src/ops/reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use rten_tensor;
use rten_tensor::prelude::*;
use rten_tensor::{DynIndices, NdTensor, NdTensorView, SliceItem, Tensor, TensorView};

use crate::number::Identities;
use crate::number::{Identities, IsNaN};
use crate::ops::layout::squeeze_in_place;
use crate::ops::{
resolve_axes, resolve_axis, Input, InputList, IntoOpResult, OpError, Operator, OutputList,
Expand Down Expand Up @@ -46,15 +46,15 @@ fn select_max_index<T, Cmp: Fn(&T, &T) -> std::cmp::Ordering>(
}

if !input.is_empty() {
for lane in input.lanes(resolved_axis) {
reduced_data.extend(input.lanes(resolved_axis).map(|lane| {
let index = if let Some(slice) = lane.as_slice() {
// Fast path for contiguous lanes.
max_position_by(slice.iter(), &compare)
} else {
max_position_by(lane, &compare)
};
reduced_data.push(index as i32);
}
index as i32
}));
}

let mut reduced = Tensor::<i32>::from_data(&reduced_shape, reduced_data);
Expand Down Expand Up @@ -109,7 +109,7 @@ macro_rules! dispatch_single_axis_reduce_op {
/// Return the index of the maximum value along a given axis.
///
/// NaN values are propagated by treating NaNs as greater than other values.
pub fn arg_max<T: Copy + PartialOrd>(
pub fn arg_max<T: Copy + PartialOrd + IsNaN>(
pool: &TensorPool,
input: TensorView<T>,
axis: isize,
Expand Down Expand Up @@ -138,7 +138,7 @@ impl Operator for ArgMax {
/// Return the index of the minimum value along a given axis.
///
/// NaN values are propagated by treating NaNs as smaller than other values.
pub fn arg_min<T: Copy + PartialOrd>(
pub fn arg_min<T: Copy + PartialOrd + IsNaN>(
pool: &TensorPool,
input: TensorView<T>,
axis: isize,
Expand All @@ -147,7 +147,7 @@ pub fn arg_min<T: Copy + PartialOrd>(
select_max_index(pool, input, axis, keep_dims, |a, b| {
match a.partial_cmp(b) {
Some(ordering) => ordering.reverse(),
None => cmp_nan_greater(a, b),
None => cmp_nan_greater(*a, *b),
}
})
}
Expand Down Expand Up @@ -506,16 +506,12 @@ impl Operator for ReduceL2 {
}
}

/// Return true if `a` has no ordering relative to itself, i.e. comparing it
/// with itself via `partial_cmp` yields `None`.
fn is_nan<T: PartialOrd>(a: &T) -> bool {
    let self_ordering = a.partial_cmp(a);
    self_ordering.is_none()
}

/// Compare `a` and `b`, treating all NaN values as greater than non-NaN values.
pub fn cmp_nan_greater<T: PartialOrd>(a: T, b: T) -> std::cmp::Ordering {
pub fn cmp_nan_greater<T: PartialOrd + IsNaN>(a: T, b: T) -> std::cmp::Ordering {
match a.partial_cmp(&b) {
Some(ordering) => ordering,
None => {
if is_nan(&a) {
if a.is_nan() {
std::cmp::Ordering::Greater
} else {
std::cmp::Ordering::Less
Expand All @@ -525,11 +521,11 @@ pub fn cmp_nan_greater<T: PartialOrd>(a: T, b: T) -> std::cmp::Ordering {
}

/// Compare `a` and `b`, treating all NaN values as less than non-NaN values.
pub fn cmp_nan_less<T: PartialOrd>(a: T, b: T) -> std::cmp::Ordering {
pub fn cmp_nan_less<T: PartialOrd + IsNaN>(a: T, b: T) -> std::cmp::Ordering {
match a.partial_cmp(&b) {
Some(ordering) => ordering,
None => {
if is_nan(&a) {
if a.is_nan() {
std::cmp::Ordering::Less
} else {
std::cmp::Ordering::Greater
Expand All @@ -538,7 +534,7 @@ pub fn cmp_nan_less<T: PartialOrd>(a: T, b: T) -> std::cmp::Ordering {
}
}

fn reduce_min_max<T: Copy + PartialOrd>(
fn reduce_min_max<T: Copy + PartialOrd + IsNaN>(
pool: &TensorPool,
input: TensorView<T>,
axes: Option<&[i32]>,
Expand All @@ -548,7 +544,7 @@ fn reduce_min_max<T: Copy + PartialOrd>(
struct MinMaxReducer {
max: bool,
}
impl<T: Copy + PartialOrd> Reducer<T> for MinMaxReducer {
impl<T: Copy + PartialOrd + IsNaN> Reducer<T> for MinMaxReducer {
fn reduce<I: ExactSizeIterator<Item = T>>(&self, iter: I) -> T {
let reduced = if self.max {
iter.max_by(|a, b| cmp_nan_greater(*a, *b))
Expand Down Expand Up @@ -576,7 +572,7 @@ fn get_axes<'a>(
Ok(axes)
}

pub fn reduce_min<T: Copy + PartialOrd>(
pub fn reduce_min<T: Copy + PartialOrd + IsNaN>(
pool: &TensorPool,
input: TensorView<T>,
axes: Option<&[i32]>,
Expand All @@ -603,7 +599,7 @@ impl Operator for ReduceMin {
}
}

pub fn reduce_max<T: Copy + PartialOrd>(
pub fn reduce_max<T: Copy + PartialOrd + IsNaN>(
pool: &TensorPool,
input: TensorView<T>,
axes: Option<&[i32]>,
Expand Down Expand Up @@ -729,7 +725,7 @@ impl Operator for ReduceSumSquare {
}
}

pub fn topk<T: Copy + Default + PartialOrd>(
pub fn topk<T: Copy + Default + PartialOrd + IsNaN>(
pool: &TensorPool,
values: TensorView<T>,
k: usize,
Expand Down
5 changes: 3 additions & 2 deletions src/ops/variadic_elementwise.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::cmp::Ordering;
use rten_tensor::prelude::*;
use rten_tensor::{Tensor, TensorView};

use crate::number::IsNaN;
use crate::ops::binary_elementwise::binary_op;
use crate::ops::reduce::{cmp_nan_greater, cmp_nan_less};
use crate::ops::{Input, InputList, IntoOpResult, OpError, Operator, OutputList};
Expand Down Expand Up @@ -42,7 +43,7 @@ where
})
}

pub fn max<T: Copy + PartialOrd>(
pub fn max<T: Copy + PartialOrd + IsNaN>(
pool: &TensorPool,
inputs: &[TensorView<T>],
) -> Result<Tensor<T>, OpError> {
Expand Down Expand Up @@ -102,7 +103,7 @@ impl Operator for Mean {
}
}

pub fn min<T: Copy + PartialOrd>(
pub fn min<T: Copy + PartialOrd + IsNaN>(
pool: &TensorPool,
inputs: &[TensorView<T>],
) -> Result<Tensor<T>, OpError> {
Expand Down

0 comments on commit 90a8604

Please sign in to comment.