Add missing slice length checks for vec_shift_scale_in_place, better docs for other vectorized functions #467

Merged: 2 commits, Dec 18, 2024
17 changes: 13 additions & 4 deletions rten-vecmath/src/shift_scale.rs
@@ -20,6 +20,11 @@ impl<'a> SimdOp for SimdShiftScale<'a> {
             const_scale,
         } = self;

+        assert_eq!(scale.len(), data.len());
+        if let Some(bias) = bias {
+            assert_eq!(bias.len(), data.len());
+        }
+
         let mut out_ptr = data.as_mut_ptr();
         let mut scale_ptr = scale.as_ptr();
         let mut bias_ptr = bias.map(|b| b.as_ptr());
@@ -54,16 +59,20 @@ impl<'a> SimdOp for SimdShiftScale<'a> {

 /// Shift and scale each element in the input.
 ///
-/// This scales and shifts each element using `y[i] = y[i] * const_scale *
-/// scale[i] + bias[i]`.
+/// This updates each element in `xs` according to the formula
+/// `xs[i] = xs[i] * const_scale * scale[i] + bias[i]`.
+///
+/// # Panics
+///
+/// Panics if the length of `scale` or `bias` does not match `xs`.
 pub fn vec_shift_scale_in_place(
-    data: &mut [f32],
+    xs: &mut [f32],
     const_scale: f32,
     scale: &[f32],
     bias: Option<&[f32]>,
 ) {
     let simd_op = SimdShiftScale {
-        data,
+        data: xs,
         bias,
         scale,
         const_scale,
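
The documented formula and panic conditions can be read as the scalar loop below. This is a minimal sketch, not the crate's SIMD implementation: `shift_scale_reference` is a hypothetical name, and treating a missing `bias` as zero is an assumption that the diff does not confirm.

```rust
/// Scalar reading of `xs[i] = xs[i] * const_scale * scale[i] + bias[i]`.
/// Hypothetical helper for illustration; not part of rten-vecmath.
fn shift_scale_reference(
    xs: &mut [f32],
    const_scale: f32,
    scale: &[f32],
    bias: Option<&[f32]>,
) {
    // Same length checks that the PR adds before the pointer-based loop.
    assert_eq!(scale.len(), xs.len());
    if let Some(bias) = bias {
        assert_eq!(bias.len(), xs.len());
    }

    for (i, x) in xs.iter_mut().enumerate() {
        // Assumption: an absent bias contributes nothing to the result.
        let b = bias.map_or(0.0, |b| b[i]);
        *x = *x * const_scale * scale[i] + b;
    }
}
```

With `xs = [1.0, 2.0, 3.0]`, `const_scale = 2.0`, `scale = [1.0, 0.5, 2.0]` and `bias = [10.0, 20.0, 30.0]`, this loop leaves `xs` as `[12.0, 22.0, 42.0]`. A vectorized version may round differently, for example if it uses fused multiply-add.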
14 changes: 12 additions & 2 deletions rten-vecmath/src/sum.rs
@@ -22,7 +22,12 @@ impl SimdOp for SimdSum<'_> {
     }
 }

-/// Return the sum of a slice of floats.
+/// Compute the sum of a slice of floats.
+///
+/// This is more efficient than `xs.iter().sum()` as it computes multiple
+/// partial sums in parallel using SIMD and then sums across the SIMD lanes at
+/// the end. This will produce very slightly different results because the
+/// additions are happening in a different order.
 pub fn vec_sum(xs: &[f32]) -> f32 {
     let op = SimdSum { input: xs };
     dispatch(op)
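
The reordering described in the new doc comment can be emulated in plain Rust with interleaved partial sums. This is a rough sketch under assumptions: `sum_with_lanes` is a hypothetical name, the lane count of 4 is illustrative, and the real implementation dispatches to actual SIMD instructions.

```rust
/// Sum `xs` using four interleaved partial sums, then add the partial
/// sums together. Emulates the shape of a 4-lane SIMD reduction.
/// Hypothetical helper for illustration; not part of rten-vecmath.
fn sum_with_lanes(xs: &[f32]) -> f32 {
    const LANES: usize = 4;
    let mut acc = [0.0f32; LANES];

    let chunks = xs.chunks_exact(LANES);
    // Elements left over when the length is not a multiple of LANES.
    let tail = chunks.remainder();

    for chunk in chunks {
        // Element `i` of every chunk accumulates into lane `i`.
        for (a, &x) in acc.iter_mut().zip(chunk) {
            *a += x;
        }
    }

    // Horizontal reduction across lanes, then the scalar tail.
    let lanes_total: f32 = acc.iter().sum();
    let tail_total: f32 = tail.iter().sum();
    lanes_total + tail_total
}
```

A contrived input makes the effect of the ordering visible: for `[1e8, 1.0, 0.0, 0.0, -1e8, 0.0, 0.0, 0.0]`, `xs.iter().sum::<f32>()` gives `0.0` because `1e8 + 1.0` rounds back to `1e8`, while `sum_with_lanes` gives `1.0` because the two large values cancel within one lane. For typical data the difference is far smaller.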
@@ -48,7 +53,12 @@ impl SimdOp for SimdSumSquare<'_> {
     }
 }

-/// Return the sum of the squares of elements in `xs`.
+/// Compute the sum of the squares of elements in `xs`.
+///
+/// Conceptually this is like `xs.iter().map(|&x| x * x).sum()` but more
+/// efficient as it computes multiple partial sums in parallel and then sums
+/// across SIMD lanes at the end. The results will also be slightly different
+/// because the additions are happening in a different order.
 pub fn vec_sum_square(xs: &[f32]) -> f32 {
     let op = SimdSumSquare { input: xs };
     dispatch(op)
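
Because the additions happen in a different order, a check against the scalar expression quoted in the doc comment usually needs a tolerance rather than exact equality. The sketch below shows one way to do that; all three function names, the lane count of 8, and the tolerance formula are assumptions for illustration, not part of the crate.

```rust
/// Scalar reference: the expression quoted in the doc comment.
fn sum_square_reference(xs: &[f32]) -> f32 {
    xs.iter().map(|&x| x * x).sum()
}

/// Reordered reduction that accumulates squares into 8 interleaved
/// partial sums. Hypothetical stand-in for a SIMD implementation.
fn sum_square_lanes(xs: &[f32]) -> f32 {
    const LANES: usize = 8;
    let mut acc = [0.0f32; LANES];
    for (i, &x) in xs.iter().enumerate() {
        acc[i % LANES] += x * x;
    }
    acc.iter().sum()
}

/// Compare two sums with a relative tolerance, falling back to an
/// absolute tolerance of `rel_tol` when both values are below 1.0.
fn approx_eq(a: f32, b: f32, rel_tol: f32) -> bool {
    (a - b).abs() <= rel_tol * a.abs().max(b.abs()).max(1.0)
}
```

A test would then assert `approx_eq(sum_square_lanes(&xs), sum_square_reference(&xs), 1e-3)` instead of comparing the two results with `==`.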
2 changes: 2 additions & 0 deletions rten-vecmath/src/tanh.rs
@@ -7,6 +7,7 @@ use rten_simd::SimdFloat;

 use crate::exp::simd_exp;

+/// Compute `x.tanh()` using the same algorithm as [`vec_tanh`].
 pub fn tanh(x: f32) -> f32 {
     unsafe { simd_tanh(x) }
 }
@@ -81,6 +82,7 @@ pub fn vec_tanh(xs: &[f32], out: &mut [MaybeUninit<f32>]) {
     dispatch_map_op(xs, out, SimdTanh {});
 }

+/// Variant of [`vec_tanh`] which modifies elements in-place.
 pub fn vec_tanh_in_place(xs: &mut [f32]) {
     dispatch_map_op_in_place(xs, SimdTanh {});
 }
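
The two signatures in this hunk follow a common pattern: one variant writes into caller-provided uninitialized storage, the other overwrites its input. The sketch below mirrors that shape using `f32::tanh` from the standard library as a stand-in. The names `tanh_into`, `tanh_in_place` and `tanh_to_vec` are hypothetical, and the length assertion is an assumption; the diff does not show how `vec_tanh` handles mismatched lengths.

```rust
use std::mem::MaybeUninit;

/// Out-of-place variant: writes `tanh(xs[i])` into `out[i]`.
/// Stand-in with the same signature shape as `vec_tanh`.
fn tanh_into(xs: &[f32], out: &mut [MaybeUninit<f32>]) {
    // Assumption: input and output must have equal lengths.
    assert_eq!(xs.len(), out.len());
    for (x, o) in xs.iter().zip(out.iter_mut()) {
        o.write(x.tanh());
    }
}

/// In-place variant: overwrites each element with its tanh.
fn tanh_in_place(xs: &mut [f32]) {
    for x in xs.iter_mut() {
        *x = x.tanh();
    }
}

/// Fill a fresh Vec through the out-of-place variant without
/// zero-initializing it first.
fn tanh_to_vec(xs: &[f32]) -> Vec<f32> {
    let mut out: Vec<f32> = Vec::with_capacity(xs.len());
    tanh_into(xs, &mut out.spare_capacity_mut()[..xs.len()]);
    // SAFETY: `tanh_into` initialized the first `xs.len()` elements.
    unsafe { out.set_len(xs.len()) };
    out
}
```

The `MaybeUninit` output lets a caller skip zero-filling a buffer that is about to be overwritten, while the in-place form avoids a second allocation when the input is no longer needed.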