Skip to content

Commit

Permalink
Add dot product for i16. (#136)
Browse files Browse the repository at this point in the history
* add reduce_add

* fix unit test

* added comments to reduce_add

* add i16 dot product instructions

* fix comment

* fix comment

* fix other cpus

* fix avx

* temporarily disable nightly MIPS tests and record issue

* fix test from merge
  • Loading branch information
mcroomp authored Oct 1, 2023
1 parent 1fd7657 commit aba67bf
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 0 deletions.
15 changes: 15 additions & 0 deletions src/i16x16_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,21 @@ impl i16x16 {
}
}

/// Computes a partial dot product of two `i16x16` vectors.
///
/// Corresponding signed 16-bit lanes are multiplied into signed 32-bit
/// intermediates, and each adjacent pair of intermediates is then added
/// together, so the sixteen input lanes collapse into the eight lanes of the
/// returned `i32x8`.
pub fn dot(self, rhs: Self) -> i32x8 {
  pick! {
    if #[cfg(target_feature="avx2")] {
      // With AVX2 a single helper call processes the whole vector.
      let summed = mul_i16_horizontal_add_m256i(self.avx2, rhs.avx2);
      i32x8 { avx2: summed }
    } else {
      // Otherwise delegate to `dot` on each of the two halves (`a`, then `b`)
      // and reassemble the results in the same order.
      let first = self.a.dot(rhs.a);
      let second = self.b.dot(rhs.b);
      i32x8 { a: first, b: second }
    }
  }
}

/// Multiply and scale equivalent to ((self * rhs) + 0x4000) >> 15 on each
/// lane, effectively multiplying by a 16 bit fixed point number between -1
/// and 1. This corresponds to the following instructions:
Expand Down
25 changes: 25 additions & 0 deletions src/i16x8_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,31 @@ impl i16x8 {
}
}

/// Calculates a partial dot product.
///
/// Multiplies the packed signed 16-bit lanes of `self` and `rhs`, producing
/// intermediate signed 32-bit products, then horizontally adds each adjacent
/// pair of products. Lane `k` of the result is
/// `self[2k] * rhs[2k] + self[2k + 1] * rhs[2k + 1]`.
///
/// The individual products always fit in an `i32`. The pairwise sum can only
/// overflow when all four inputs of a pair are `i16::MIN`
/// (`2^30 + 2^30 = 2^31`); in that case the lane wraps around to `i32::MIN`
/// instead of panicking.
pub fn dot(self, rhs: Self) -> i32x4 {
  pick! {
    if #[cfg(target_feature="sse2")] {
      i32x4 { sse: mul_i16_horizontal_add_m128i(self.sse, rhs.sse) }
    } else if #[cfg(target_feature="simd128")] {
      i32x4 { simd: i32x4_dot_i16x8(self.simd, rhs.simd) }
    } else if #[cfg(all(target_feature="neon",target_arch="aarch64"))]{
      // SAFETY: this branch is only compiled when the `neon` target feature is
      // enabled on aarch64, which is what these intrinsics require; they
      // operate purely on register values.
      unsafe {
        // Widening multiplies of the low and high four lanes, followed by a
        // pairwise add that folds adjacent products together.
        let pl = vmull_s16(vget_low_s16(self.neon), vget_low_s16(rhs.neon));
        let ph = vmull_high_s16(self.neon, rhs.neon);
        i32x4 { neon: vpaddq_s32(pl, ph) }
      }
    } else {
      // Widen to i32 before multiplying so the product itself cannot overflow.
      let product = |i: usize| i32::from(self.arr[i]) * i32::from(rhs.arr[i]);
      // `wrapping_add` keeps the all-`i16::MIN` pair from panicking in debug
      // builds. NOTE(review): the SIMD branches above are expected to wrap the
      // same way (x86 `pmaddwd` is documented to) -- confirm for each helper.
      i32x4 { arr: [
        product(0).wrapping_add(product(1)),
        product(2).wrapping_add(product(3)),
        product(4).wrapping_add(product(5)),
        product(6).wrapping_add(product(7)),
      ] }
    }
  }
}

/// Multiply and scale equivalent to ((self * rhs) + 0x4000) >> 15 on each
/// lane, effectively multiplying by a 16 bit fixed point number between -1
/// and 1. This corresponds to the following instructions:
Expand Down
29 changes: 29 additions & 0 deletions tests/all_tests/t_i16x16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,35 @@ fn impl_i16x16_reduce_add() {
assert_eq!(p.reduce_add(), 407);
}

#[test]
fn impl_dot_for_i16x16() {
  // Left operand: small values plus the extreme i16 values, to exercise the
  // widening to 32 bits.
  #[rustfmt::skip]
  let lhs = i16x16::from([
    1, 2, 3, 4, 5, 6, i16::MIN + 1, i16::MIN,
    10, 20, 30, 40, 50, 60, i16::MAX - 1, i16::MAX,
  ]);
  #[rustfmt::skip]
  let rhs = i16x16::from([
    17, -18, 190, -20, 21, -22, 3, 2,
    170, -180, 1900, -200, 210, -220, 30, 20,
  ]);

  // Each expected lane is lhs[2k] * rhs[2k] + lhs[2k + 1] * rhs[2k + 1].
  #[rustfmt::skip]
  let expected = i32x8::from([
    -19, 490, -27, -163837,
    -1900, 49000, -2700, 1638320,
  ]);

  let actual = lhs.dot(rhs);
  assert_eq!(expected, actual);
}

#[test]
fn impl_i16x16_reduce_min() {
for i in 0..8 {
Expand Down
9 changes: 9 additions & 0 deletions tests/all_tests/t_i16x8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,15 @@ fn impl_i16x8_reduce_add() {
assert_eq!(p.reduce_add(), 37);
}

#[test]
fn impl_dot_for_i16x8() {
  // Includes the most negative i16 values so the products need 32 bits.
  let lhs = i16x8::from([1, 2, 3, 4, 5, 6, i16::MIN + 1, i16::MIN]);
  let rhs = i16x8::from([17, -18, 190, -20, 21, -22, 3, 2]);

  // Each expected lane is lhs[2k] * rhs[2k] + lhs[2k + 1] * rhs[2k + 1].
  let expected = i32x4::from([-19, 490, -27, -163837]);

  let actual = lhs.dot(rhs);
  assert_eq!(expected, actual);
}

#[test]
fn impl_i16x8_reduce_min() {
for i in 0..8 {
Expand Down

0 comments on commit aba67bf

Please sign in to comment.