Commit 46d64f0
* avx: _mm256_testnzc_si256

* avx: _mm256_shuffle_ps

Eight levels of macro expansion take too long to compile.

* avx: remove useless 0 in tests

* avx: _mm256_shuffle_ps

Macro expansion can be reduced to four levels

* avx: _mm256_blend_ps

Copy/paste from avx2::_mm256_blend_epi32
gwenn authored and alexcrichton committed Nov 1, 2017
1 parent 5a4a1f4 commit 46d64f0
Showing 1 changed file with 166 additions and 19 deletions.

src/x86/avx.rs
@@ -119,6 +119,56 @@ pub unsafe fn _mm256_shuffle_pd(a: f64x4, b: f64x4, imm8: i32) -> f64x4 {
    }
}

/// Shuffle single-precision (32-bit) floating-point elements in `a` and `b`
/// within 128-bit lanes using the control in `imm8`.
#[inline(always)]
#[target_feature = "+avx"]
#[cfg_attr(test, assert_instr(vshufps, imm8 = 0x0))]
pub unsafe fn _mm256_shuffle_ps(a: f32x8, b: f32x8, imm8: i32) -> f32x8 {
    let imm8 = (imm8 & 0xFF) as u8;
    macro_rules! shuffle4 {
        ($a:expr, $b:expr, $c:expr, $d:expr, $e:expr, $f:expr, $g:expr, $h:expr) => {
            simd_shuffle8(a, b, [$a, $b, $c, $d, $e, $f, $g, $h])
        }
    }
    macro_rules! shuffle3 {
        ($a:expr, $b:expr, $c:expr, $e:expr, $f:expr, $g:expr) => {
            match (imm8 >> 6) & 0x3 {
                0 => shuffle4!($a, $b, $c, 8, $e, $f, $g, 12),
                1 => shuffle4!($a, $b, $c, 9, $e, $f, $g, 13),
                2 => shuffle4!($a, $b, $c, 10, $e, $f, $g, 14),
                _ => shuffle4!($a, $b, $c, 11, $e, $f, $g, 15),
            }
        }
    }
    macro_rules! shuffle2 {
        ($a:expr, $b:expr, $e:expr, $f:expr) => {
            match (imm8 >> 4) & 0x3 {
                0 => shuffle3!($a, $b, 8, $e, $f, 12),
                1 => shuffle3!($a, $b, 9, $e, $f, 13),
                2 => shuffle3!($a, $b, 10, $e, $f, 14),
                _ => shuffle3!($a, $b, 11, $e, $f, 15),
            }
        }
    }
    macro_rules! shuffle1 {
        ($a:expr, $e:expr) => {
            match (imm8 >> 2) & 0x3 {
                0 => shuffle2!($a, 0, $e, 4),
                1 => shuffle2!($a, 1, $e, 5),
                2 => shuffle2!($a, 2, $e, 6),
                _ => shuffle2!($a, 3, $e, 7),
            }
        }
    }
    match imm8 & 0x3 {
        0 => shuffle1!(0, 4),
        1 => shuffle1!(1, 5),
        2 => shuffle1!(2, 6),
        _ => shuffle1!(3, 7),
    }
}
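The `imm8` encoding here matches `vshufps`: `imm8` splits into four 2-bit fields, the low two fields pick elements from `a`, the high two pick from `b`, and the same selection is applied to each 128-bit lane. Since each nested macro above fixes one 2-bit field, four levels of expansion cover all 256 values of `imm8`, which is the depth reduction the commit message refers to. The following scalar model is a sketch of that selection for reference only (the helper name and loop are illustrative, not part of this commit):

// Illustrative scalar model of the vshufps lane selection (not in this commit).
fn shuffle_ps_model(a: [f32; 8], b: [f32; 8], imm8: u8) -> [f32; 8] {
    // field(i) is the i-th 2-bit selector in imm8.
    let field = |i: u8| ((imm8 >> (2 * i)) & 0b11) as usize;
    let mut dst = [0.0f32; 8];
    for lane in 0..2 {
        let base = 4 * lane;
        dst[base] = a[base + field(0)]; // low two results come from `a`
        dst[base + 1] = a[base + field(1)];
        dst[base + 2] = b[base + field(2)]; // high two come from `b`
        dst[base + 3] = b[base + field(3)];
    }
    dst
}

For `imm8 = 0x0F` (fields 3, 3, 0, 0), the low lane becomes `a[3], a[3], b[0], b[0]`, which is exactly what the test further down expects.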

/// Compute the bitwise NOT of packed double-precision (64-bit) floating-point
/// elements in `a`, and then AND with `b`.
@@ -393,6 +443,56 @@ pub unsafe fn _mm256_blend_pd(a: f64x4, b: f64x4, imm8: i32) -> f64x4 {
    }
}

/// Blend packed single-precision (32-bit) floating-point elements from
/// `a` and `b` using control mask `imm8`.
#[inline(always)]
#[target_feature = "+avx"]
#[cfg_attr(test, assert_instr(vblendps, imm8 = 9))]
pub unsafe fn _mm256_blend_ps(a: f32x8, b: f32x8, imm8: i32) -> f32x8 {
    let imm8 = (imm8 & 0xFF) as u8;
    macro_rules! blend4 {
        ($a:expr, $b:expr, $c:expr, $d:expr, $e:expr, $f:expr, $g:expr, $h:expr) => {
            simd_shuffle8(a, b, [$a, $b, $c, $d, $e, $f, $g, $h])
        }
    }
    macro_rules! blend3 {
        ($a:expr, $b:expr, $c:expr, $d:expr, $e:expr, $f:expr) => {
            match (imm8 >> 6) & 0b11 {
                0b00 => blend4!($a, $b, $c, $d, $e, $f, 6, 7),
                0b01 => blend4!($a, $b, $c, $d, $e, $f, 14, 7),
                0b10 => blend4!($a, $b, $c, $d, $e, $f, 6, 15),
                _ => blend4!($a, $b, $c, $d, $e, $f, 14, 15),
            }
        }
    }
    macro_rules! blend2 {
        ($a:expr, $b:expr, $c:expr, $d:expr) => {
            match (imm8 >> 4) & 0b11 {
                0b00 => blend3!($a, $b, $c, $d, 4, 5),
                0b01 => blend3!($a, $b, $c, $d, 12, 5),
                0b10 => blend3!($a, $b, $c, $d, 4, 13),
                _ => blend3!($a, $b, $c, $d, 12, 13),
            }
        }
    }
    macro_rules! blend1 {
        ($a:expr, $b:expr) => {
            match (imm8 >> 2) & 0b11 {
                0b00 => blend2!($a, $b, 2, 3),
                0b01 => blend2!($a, $b, 10, 3),
                0b10 => blend2!($a, $b, 2, 11),
                _ => blend2!($a, $b, 10, 11),
            }
        }
    }
    match imm8 & 0b11 {
        0b00 => blend1!(0, 1),
        0b01 => blend1!(8, 1),
        0b10 => blend1!(0, 9),
        _ => blend1!(8, 9),
    }
}
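In this encoding, bit `i` of `imm8` selects element `i` of the result from `b` (shuffle index `8 + i`) instead of `a`, and each macro level fixes two bits of the mask. A scalar sketch of the selection, for reference only (the helper is illustrative, not part of this commit):

// Illustrative scalar model of the vblendps selection (not in this commit).
fn blend_ps_model(a: [f32; 8], b: [f32; 8], imm8: u8) -> [f32; 8] {
    let mut dst = [0.0f32; 8];
    for i in 0..8 {
        // Bit i set: take element i from b; otherwise keep element i of a.
        dst[i] = if (imm8 >> i) & 1 != 0 { b[i] } else { a[i] };
    }
    dst
}

With `imm8 = 0x3`, only elements 0 and 1 come from `b`, matching the test further down.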

/// Blend packed double-precision (64-bit) floating-point elements from
/// `a` and `b` using `c` as a mask.
#[inline(always)]
@@ -1437,6 +1537,18 @@ pub unsafe fn _mm256_testc_si256(a: i64x4, b: i64x4) -> i32 {
    ptestc256(a, b)
}

/// Compute the bitwise AND of 256 bits (representing integer data) in `a` and
/// `b`, and set `ZF` to 1 if the result is zero, otherwise set `ZF` to 0.
/// Compute the bitwise NOT of `a` and then AND with `b`, and set `CF` to 1 if
/// the result is zero, otherwise set `CF` to 0. Return 1 if both the `ZF` and
/// `CF` values are zero, otherwise return 0.
#[inline(always)]
#[target_feature = "+avx"]
#[cfg_attr(test, assert_instr(vptest))]
pub unsafe fn _mm256_testnzc_si256(a: i64x4, b: i64x4) -> i32 {
    ptestnzc256(a, b)
}
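A scalar sketch of the `vptest` flag computation described in the doc comment above, for reference only (the helper and `u64` limbs are illustrative, not part of this commit):

// Illustrative scalar model of vptest's ZF/CF over 256 bits (not in this commit).
fn testnzc_model(a: [u64; 4], b: [u64; 4]) -> i32 {
    // ZF is set when a AND b is all zeros.
    let zf = a.iter().zip(b.iter()).all(|(x, y)| x & y == 0);
    // CF is set when (NOT a) AND b is all zeros.
    let cf = a.iter().zip(b.iter()).all(|(x, y)| !x & y == 0);
    // Return 1 only when both flags are zero.
    (!zf && !cf) as i32
}

For `a = (1, 2, 3, 4)` and `b = (5, 6, 7, 8)`, both `a AND b` and `(NOT a) AND b` are non-zero, so the result is 1; for all-zero inputs, `ZF` is set and the result is 0, matching the test further down.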

/// Compute the bitwise AND of 256 bits (representing double-precision (64-bit)
/// floating-point elements) in `a` and `b`, producing an intermediate 256-bit
/// value, and set `ZF` to 1 if the sign bit of each 64-bit element in the
@@ -2272,6 +2384,8 @@ extern "C" {
    fn ptestz256(a: i64x4, b: i64x4) -> i32;
    #[link_name = "llvm.x86.avx.ptestc.256"]
    fn ptestc256(a: i64x4, b: i64x4) -> i32;
    #[link_name = "llvm.x86.avx.ptestnzc.256"]
    fn ptestnzc256(a: i64x4, b: i64x4) -> i32;
    #[link_name = "llvm.x86.avx.vtestz.pd.256"]
    fn vtestzpd256(a: f64x4, b: f64x4) -> i32;
    #[link_name = "llvm.x86.avx.vtestc.pd.256"]
@@ -2375,6 +2489,15 @@ mod tests {
        assert_eq!(r, e);
    }

#[simd_test = "avx"]
unsafe fn _mm256_shuffle_ps() {
let a = f32x8::new(1., 4., 5., 8., 9., 12., 13., 16.);
let b = f32x8::new(2., 3., 6., 7., 10., 11., 14., 15.);
let r = avx::_mm256_shuffle_ps(a, b, 0x0F);
let e = f32x8::new(8., 8., 2., 2., 16., 16., 10., 10.);
assert_eq!(r, e);
}

#[simd_test = "avx"]
unsafe fn _mm256_andnot_pd() {
let a = f64x4::splat(0.);
@@ -2421,7 +2544,7 @@ mod tests {
#[simd_test = "avx"]
unsafe fn _mm256_min_ps() {
let a = f32x8::new(1., 4., 5., 8., 9., 12., 13., 16.);
let b = f32x8::new(2., 3., 6., 7., 10.0, 11., 14., 15.);
let b = f32x8::new(2., 3., 6., 7., 10., 11., 14., 15.);
let r = avx::_mm256_min_ps(a, b);
let e = f32x8::new(1., 3., 5., 7., 9., 11., 13., 15.);
assert_eq!(r, e);
#[simd_test = "avx"]
unsafe fn _mm256_mul_ps() {
let a = f32x8::new(1., 2., 3., 4., 5., 6., 7., 8.);
let b = f32x8::new(9., 10.0, 11., 12., 13., 14., 15., 16.);
let b = f32x8::new(9., 10., 11., 12., 13., 14., 15., 16.);
let r = avx::_mm256_mul_ps(a, b);
let e = f32x8::new(9., 20.0, 33., 48., 65., 84., 105., 128.);
let e = f32x8::new(9., 20., 33., 48., 65., 84., 105., 128.);
assert_eq!(r, e);
}

@@ -2560,7 +2683,7 @@ mod tests {
#[simd_test = "avx"]
unsafe fn _mm256_div_ps() {
let a = f32x8::new(4., 9., 16., 25., 4., 9., 16., 25.);
let b = f32x8::new(4., 3., 2., 5., 8., 9., 64., 50.0);
let b = f32x8::new(4., 3., 2., 5., 8., 9., 64., 50.);
let r = avx::_mm256_div_ps(a, b);
let e = f32x8::new(1., 3., 8., 5., 0.5, 1., 0.25, 0.5);
assert_eq!(r, e);
        assert_eq!(r, f64x4::new(4., 3., 2., 5.));
    }

#[simd_test = "avx"]
unsafe fn _mm256_blend_ps() {
let a = f32x8::new(1., 4., 5., 8., 9., 12., 13., 16.);
let b = f32x8::new(2., 3., 6., 7., 10., 11., 14., 15.);
let r = avx::_mm256_blend_ps(a, b, 0x0);
assert_eq!(r, f32x8::new(1., 4., 5., 8., 9., 12., 13., 16.));
let r = avx::_mm256_blend_ps(a, b, 0x3);
assert_eq!(r, f32x8::new(2., 3., 5., 8., 9., 12., 13., 16.));
let r = avx::_mm256_blend_ps(a, b, 0xF);
assert_eq!(r, f32x8::new(2., 3., 6., 7., 9., 12., 13., 16.));
}

#[simd_test = "avx"]
unsafe fn _mm256_blendv_pd() {
let a = f64x4::new(4., 9., 16., 25.);
#[simd_test = "avx"]
unsafe fn _mm256_blendv_ps() {
let a = f32x8::new(4., 9., 16., 25., 4., 9., 16., 25.);
let b = f32x8::new(4., 3., 2., 5., 8., 9., 64., 50.0);
let b = f32x8::new(4., 3., 2., 5., 8., 9., 64., 50.);
#[cfg_attr(rustfmt, rustfmt_skip)]
let c = f32x8::new(
0., 0., 0., 0., !0 as f32, !0 as f32, !0 as f32, !0 as f32,
);
let r = avx::_mm256_blendv_ps(a, b, c);
let e = f32x8::new(4., 9., 16., 25., 8., 9., 64., 50.0);
let e = f32x8::new(4., 9., 16., 25., 8., 9., 64., 50.);
assert_eq!(r, e);
}

#[simd_test = "avx"]
unsafe fn _mm256_dp_ps() {
let a = f32x8::new(4., 9., 16., 25., 4., 9., 16., 25.);
let b = f32x8::new(4., 3., 2., 5., 8., 9., 64., 50.0);
let b = f32x8::new(4., 3., 2., 5., 8., 9., 64., 50.);
let r = avx::_mm256_dp_ps(a, b, 0xFF);
let e =
f32x8::new(200.0, 200.0, 200.0, 200.0, 2387., 2387., 2387., 2387.);
f32x8::new(200., 200., 200., 200., 2387., 2387., 2387., 2387.);
assert_eq!(r, e);
}

#[simd_test = "avx"]
unsafe fn _mm256_hadd_ps() {
let a = f32x8::new(4., 9., 16., 25., 4., 9., 16., 25.);
let b = f32x8::new(4., 3., 2., 5., 8., 9., 64., 50.0);
let b = f32x8::new(4., 3., 2., 5., 8., 9., 64., 50.);
let r = avx::_mm256_hadd_ps(a, b);
let e = f32x8::new(13., 41., 7., 7., 13., 41., 17., 114.);
assert_eq!(r, e);
@@ -2668,7 +2803,7 @@ mod tests {
#[simd_test = "avx"]
unsafe fn _mm256_hsub_ps() {
let a = f32x8::new(4., 9., 16., 25., 4., 9., 16., 25.);
let b = f32x8::new(4., 3., 2., 5., 8., 9., 64., 50.0);
let b = f32x8::new(4., 3., 2., 5., 8., 9., 64., 50.);
let r = avx::_mm256_hsub_ps(a, b);
let e = f32x8::new(-5., -9., 1., -3., -5., -9., -1., 14.);
assert_eq!(r, e);
@@ -2821,7 +2956,7 @@ mod tests {

#[simd_test = "avx"]
unsafe fn _mm256_extractf128_ps() {
let a = f32x8::new(4., 3., 2., 5., 8., 9., 64., 50.0);
let a = f32x8::new(4., 3., 2., 5., 8., 9., 64., 50.);
let r = avx::_mm256_extractf128_ps(a, 0);
let e = f32x4::new(4., 3., 2., 5.);
assert_eq!(r, e);
@@ -2890,10 +3025,10 @@ mod tests {

#[simd_test = "avx"]
unsafe fn _mm256_permutevar_ps() {
let a = f32x8::new(4., 3., 2., 5., 8., 9., 64., 50.0);
let a = f32x8::new(4., 3., 2., 5., 8., 9., 64., 50.);
let b = i32x8::new(1, 2, 3, 4, 5, 6, 7, 8);
let r = avx::_mm256_permutevar_ps(a, b);
let e = f32x8::new(3., 2., 5., 4., 9., 64., 50.0, 8.);
let e = f32x8::new(3., 2., 5., 4., 9., 64., 50., 8.);
assert_eq!(r, e);
}


#[simd_test = "avx"]
unsafe fn _mm256_permute_ps() {
let a = f32x8::new(4., 3., 2., 5., 8., 9., 64., 50.0);
let a = f32x8::new(4., 3., 2., 5., 8., 9., 64., 50.);
let r = avx::_mm256_permute_ps(a, 0x1b);
let e = f32x8::new(5., 2., 3., 4., 50.0, 64., 9., 8.);
let e = f32x8::new(5., 2., 3., 4., 50., 64., 9., 8.);
assert_eq!(r, e);
}

@@ -3022,10 +3157,10 @@ mod tests {

#[simd_test = "avx"]
unsafe fn _mm256_insertf128_ps() {
let a = f32x8::new(4., 3., 2., 5., 8., 9., 64., 50.0);
let a = f32x8::new(4., 3., 2., 5., 8., 9., 64., 50.);
let b = f32x4::new(4., 9., 16., 25.);
let r = avx::_mm256_insertf128_ps(a, b, 0);
let e = f32x8::new(4., 9., 16., 25., 8., 9., 64., 50.0);
let e = f32x8::new(4., 9., 16., 25., 8., 9., 64., 50.);
assert_eq!(r, e);
}

@@ -3112,10 +3247,10 @@ mod tests {

#[simd_test = "avx"]
unsafe fn _mm256_loadu_ps() {
let a = &[4., 3., 2., 5., 8., 9., 64., 50.0];
let a = &[4., 3., 2., 5., 8., 9., 64., 50.];
let p = a.as_ptr();
let r = avx::_mm256_loadu_ps(black_box(p));
let e = f32x8::new(4., 3., 2., 5., 8., 9., 64., 50.0);
let e = f32x8::new(4., 3., 2., 5., 8., 9., 64., 50.);
assert_eq!(r, e);
}

@@ -3357,6 +3492,18 @@ mod tests {
        assert_eq!(r, 1);
    }

#[simd_test = "avx"]
unsafe fn _mm256_testnzc_si256() {
let a = i64x4::new(1, 2, 3, 4);
let b = i64x4::new(5, 6, 7, 8);
let r = avx::_mm256_testnzc_si256(a, b);
assert_eq!(r, 1);
let a = i64x4::new(0, 0, 0, 0);
let b = i64x4::new(0, 0, 0, 0);
let r = avx::_mm256_testnzc_si256(a, b);
assert_eq!(r, 0);
}

#[simd_test = "avx"]
unsafe fn _mm256_testz_pd() {
let a = f64x4::new(1., 2., 3., 4.);
