diff --git a/src/i16x16_.rs b/src/i16x16_.rs
index 54126bc7..f582ce4a 100644
--- a/src/i16x16_.rs
+++ b/src/i16x16_.rs
@@ -485,6 +485,21 @@ impl i16x16 {
     }
   }
 
+  /// Calculates partial dot product.
+  /// Multiplies packed signed 16-bit integers, producing intermediate signed 32-bit integers. Horizontally add adjacent pairs of intermediate 32-bit integers.
+  pub fn dot(self, rhs: Self) -> i32x8 {
+    pick! {
+      if #[cfg(target_feature="avx2")] {
+        i32x8 { avx2: mul_i16_horizontal_add_m256i(self.avx2, rhs.avx2) }
+      } else {
+        i32x8 {
+          a : self.a.dot(rhs.a),
+          b : self.b.dot(rhs.b),
+        }
+      }
+    }
+  }
+
   /// Multiply and scale equivilent to ((self * rhs) + 0x4000) >> 15 on each
   /// lane, effectively multiplying by a 16 bit fixed point number between -1
   /// and 1. This corresponds to the following instructions:
diff --git a/src/i16x8_.rs b/src/i16x8_.rs
index 816db1e3..9fc278ff 100644
--- a/src/i16x8_.rs
+++ b/src/i16x8_.rs
@@ -780,6 +780,31 @@ impl i16x8 {
     }
   }
 
+  /// Calculates partial dot product.
+  /// Multiplies packed signed 16-bit integers, producing intermediate signed 32-bit integers. Horizontally add adjacent pairs of intermediate 32-bit integers.
+  pub fn dot(self, rhs: Self) -> i32x4 {
+    pick! {
+      if #[cfg(target_feature="sse2")] {
+        i32x4 { sse: mul_i16_horizontal_add_m128i(self.sse, rhs.sse) }
+      } else if #[cfg(target_feature="simd128")] {
+        i32x4 { simd: i32x4_dot_i16x8(self.simd, rhs.simd) }
+      } else if #[cfg(all(target_feature="neon",target_arch="aarch64"))]{
+        unsafe {
+          let pl = vmull_s16(vget_low_s16(self.neon), vget_low_s16(rhs.neon));
+          let ph = vmull_high_s16(self.neon, rhs.neon);
+          i32x4 { neon: vpaddq_s32(pl, ph) }
+        }
+      } else {
+        i32x4 { arr: [
+          (i32::from(self.arr[0]) * i32::from(rhs.arr[0])) + (i32::from(self.arr[1]) * i32::from(rhs.arr[1])),
+          (i32::from(self.arr[2]) * i32::from(rhs.arr[2])) + (i32::from(self.arr[3]) * i32::from(rhs.arr[3])),
+          (i32::from(self.arr[4]) * i32::from(rhs.arr[4])) + (i32::from(self.arr[5]) * i32::from(rhs.arr[5])),
+          (i32::from(self.arr[6]) * i32::from(rhs.arr[6])) + (i32::from(self.arr[7]) * i32::from(rhs.arr[7])),
+        ] }
+      }
+    }
+  }
+
   /// Multiply and scale equivilent to ((self * rhs) + 0x4000) >> 15 on each
   /// lane, effectively multiplying by a 16 bit fixed point number between -1
   /// and 1. This corresponds to the following instructions:
diff --git a/tests/all_tests/t_i16x16.rs b/tests/all_tests/t_i16x16.rs
index 27ad6c71..186e4963 100644
--- a/tests/all_tests/t_i16x16.rs
+++ b/tests/all_tests/t_i16x16.rs
@@ -648,6 +648,35 @@ fn impl_i16x16_reduce_add() {
   assert_eq!(p.reduce_add(), 407);
 }
 
+#[test]
+fn impl_dot_for_i16x16() {
+  let a = i16x16::from([
+    1,
+    2,
+    3,
+    4,
+    5,
+    6,
+    i16::MIN + 1,
+    i16::MIN,
+    10,
+    20,
+    30,
+    40,
+    50,
+    60,
+    i16::MAX - 1,
+    i16::MAX,
+  ]);
+  let b = i16x16::from([
+    17, -18, 190, -20, 21, -22, 3, 2, 170, -180, 1900, -200, 210, -220, 30, 20,
+  ]);
+  let expected =
+    i32x8::from([-19, 490, -27, -163837, -1900, 49000, -2700, 1638320]);
+  let actual = a.dot(b);
+  assert_eq!(expected, actual);
+}
+
 #[test]
 fn impl_i16x16_reduce_min() {
   for i in 0..8 {
diff --git a/tests/all_tests/t_i16x8.rs b/tests/all_tests/t_i16x8.rs
index c4222300..132cff11 100644
--- a/tests/all_tests/t_i16x8.rs
+++ b/tests/all_tests/t_i16x8.rs
@@ -317,6 +317,15 @@ fn impl_i16x8_reduce_add() {
   assert_eq!(p.reduce_add(), 37);
 }
 
+#[test]
+fn impl_dot_for_i16x8() {
+  let a = i16x8::from([1, 2, 3, 4, 5, 6, i16::MIN + 1, i16::MIN]);
+  let b = i16x8::from([17, -18, 190, -20, 21, -22, 3, 2]);
+  let expected = i32x4::from([-19, 490, -27, -163837]);
+  let actual = a.dot(b);
+  assert_eq!(expected, actual);
+}
+
 #[test]
 fn impl_i16x8_reduce_min() {
   for i in 0..8 {
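Note (not part of the patch): the plain-Rust sketch below spells out the pmaddwd-style pairing that dot performs, lane k of the result being a[2k]*b[2k] + a[2k+1]*b[2k+1] with each factor widened to i32 before multiplying. It has no dependency on the crate; the helper name dot_scalar and the main wrapper are purely illustrative, and the inputs are the same vectors as impl_dot_for_i16x8 above.

fn dot_scalar(a: [i16; 8], b: [i16; 8]) -> [i32; 4] {
  // Lane k of the result is a[2k]*b[2k] + a[2k+1]*b[2k+1]; each factor is
  // widened to i32 first, so the individual products cannot overflow.
  let mut out = [0i32; 4];
  for k in 0..4 {
    out[k] = i32::from(a[2 * k]) * i32::from(b[2 * k])
      + i32::from(a[2 * k + 1]) * i32::from(b[2 * k + 1]);
  }
  out
}

fn main() {
  // Same inputs and expected lanes as impl_dot_for_i16x8 above.
  let a = [1, 2, 3, 4, 5, 6, i16::MIN + 1, i16::MIN];
  let b = [17, -18, 190, -20, 21, -22, 3, 2];
  assert_eq!(dot_scalar(a, b), [-19, 490, -27, -163837]);
}

Each widened 16x16-bit product fits in an i32; the only pairing whose 32-bit sum can overflow is i16::MIN * i16::MIN in both slots, the usual pmaddwd corner case.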