Skip to content

Commit

Permalink
Merge pull request #3 from LaurentMazare/f16-vec-plus-wasm-simd
Browse files Browse the repository at this point in the history
F16 vec plus wasm simd
  • Loading branch information
LaurentMazare authored Jul 26, 2023
2 parents f577d01 + e91056e commit c03b453
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 10 deletions.
15 changes: 14 additions & 1 deletion gemm-common/src/gemm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,17 @@ macro_rules! gemm_def {
}
}

#[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")))]
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
{
simd128::gemm_basic
}

#[cfg(all(target_arch = "wasm32", not(target_feature = "simd128")))]
{
scalar::gemm_basic
}

#[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64", target_arch = "wasm32")))]
{
scalar::gemm_basic
}
Expand All @@ -805,6 +815,9 @@ macro_rules! gemm_def {

#[cfg(target_arch = "aarch64")]
$crate::__inject_mod!(neon, $ty, 2 * $multiplier, Scalar);

#[cfg(target_arch = "wasm32")]
$crate::__inject_mod!(simd128, $ty, 2 * $multiplier, Simd128);
};
}

Expand Down
18 changes: 18 additions & 0 deletions gemm-common/src/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,21 @@ mod x86 {

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub use x86::*;

#[cfg(target_arch = "wasm32")]
mod wasm32 {
use super::*;

#[derive(Copy, Clone)]
pub struct Simd128;

impl Simd for Simd128 {
#[inline]
#[target_feature(enable = "simd128")]
unsafe fn vectorize(f: impl FnOnce()) {
f()
}
}
}
#[cfg(target_arch = "wasm32")]
pub use wasm32::*;
77 changes: 68 additions & 9 deletions gemm-f16/src/gemm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use gemm_common::{
pack_operands::quick_zero,
Parallelism, Ptr,
};
use half::slice::HalfFloatSliceExt;
type T = half::f16;

#[inline(always)]
Expand All @@ -17,16 +18,74 @@ unsafe fn pack_generic_inner_loop<const N: usize, const DST_WIDTH: usize>(
src_width: usize,
k: usize,
) {
for _ in 0..k {
for j in 0..src_width {
*dst.add(j) = (*src.offset(j as isize * src_rs)).into();
if src_width == DST_WIDTH {
if src_rs == 1 {
for _ in 0..k {
let val = (src as *const [T; DST_WIDTH]).read();
val.convert_to_f32_slice(core::slice::from_raw_parts_mut(dst, DST_WIDTH));

src = src.wrapping_offset(src_cs);
dst = dst.add(DST_WIDTH);
}
} else {
for _ in 0..k {
for j in 0..DST_WIDTH {
*dst.add(j) = (*src.offset(j as isize * src_rs)).into();
}
src = src.wrapping_offset(src_cs);
dst = dst.add(DST_WIDTH);
}
}
} else if src_width == N {
if src_rs == 1 {
for _ in 0..k {
let val = (src as *const [T; N]).read();
val.convert_to_f32_slice(core::slice::from_raw_parts_mut(dst, N));

src = src.wrapping_offset(src_cs);
dst = dst.add(DST_WIDTH);
}
} else {
for _ in 0..k {
for j in 0..N {
*dst.add(j) = (*src.offset(j as isize * src_rs)).into();
}
src = src.wrapping_offset(src_cs);
dst = dst.add(DST_WIDTH);
}
}
} else if src_width == 2 * N {
if src_rs == 1 {
for _ in 0..k {
let val0 = (src as *const [T; N]).read();
let val1 = (src.add(N) as *const [T; N]).read();
val0.convert_to_f32_slice(core::slice::from_raw_parts_mut(dst, N));
val1.convert_to_f32_slice(core::slice::from_raw_parts_mut(dst.add(N), N));

src = src.wrapping_offset(src_cs);
dst = dst.add(DST_WIDTH);
}
} else {
for _ in 0..k {
for j in 0..2 * N {
*dst.add(j) = (*src.offset(j as isize * src_rs)).into();
}
src = src.wrapping_offset(src_cs);
dst = dst.add(DST_WIDTH);
}
}
} else {
for _ in 0..k {
for j in 0..src_width {
*dst.add(j) = (*src.offset(j as isize * src_rs)).into();
}
quick_zero(core::slice::from_raw_parts_mut(
dst.add(src_width),
DST_WIDTH - src_width,
));
src = src.wrapping_offset(src_cs);
dst = dst.add(DST_WIDTH);
}
quick_zero(core::slice::from_raw_parts_mut(
dst.add(src_width),
DST_WIDTH - src_width,
));
src = src.wrapping_offset(src_cs);
dst = dst.add(DST_WIDTH);
}
}

Expand Down
47 changes: 47 additions & 0 deletions gemm-f32/src/microkernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,53 @@ mod v128_common {
}
}

#[cfg(target_arch = "wasm32")]
pub mod simd128 {
pub mod f32 {
use core::arch::wasm32::*;
use core::mem::transmute;

type T = f32;
const N: usize = 4;
type Pack = [T; N];

#[inline(always)]
unsafe fn splat(value: T) -> Pack {
transmute(f32x4_splat(value))
}

#[inline(always)]
unsafe fn mul(lhs: Pack, rhs: Pack) -> Pack {
transmute(f32x4_mul(transmute(lhs), transmute(rhs)))
}

#[inline(always)]
unsafe fn add(lhs: Pack, rhs: Pack) -> Pack {
transmute(f32x4_add(transmute(lhs), transmute(rhs)))
}

#[inline(always)]
unsafe fn mul_add(a: Pack, b: Pack, c: Pack) -> Pack {
add(mul(a, b), c)
}

microkernel!(["simd128"], 2, x1x1, 1, 1);
microkernel!(["simd128"], 2, x1x2, 1, 2);
microkernel!(["simd128"], 2, x1x3, 1, 3);
microkernel!(["simd128"], 2, x1x4, 1, 4);

microkernel!(["simd128"], 2, x2x1, 2, 1);
microkernel!(["simd128"], 2, x2x2, 2, 2);
microkernel!(["simd128"], 2, x2x3, 2, 3);
microkernel!(["simd128"], 2, x2x4, 2, 4);

microkernel_fn_array! {
[x1x1, x1x2, x1x3, x1x4,],
[x2x1, x2x2, x2x3, x2x4,],
}
}
}

#[cfg(target_arch = "aarch64")]
pub mod neon {
pub mod f32 {
Expand Down
48 changes: 48 additions & 0 deletions gemm-f64/src/microkernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,54 @@ mod v128_common {
}
}

#[cfg(target_arch = "wasm32")]
pub mod simd128 {
pub mod f64 {
use core::arch::wasm32::*;
use core::mem::transmute;

type T = f64;
const N: usize = 2;
type Pack = [T; N];

#[inline(always)]
unsafe fn splat(value: T) -> Pack {
transmute(f64x2_splat(value))
}

#[inline(always)]
unsafe fn mul(lhs: Pack, rhs: Pack) -> Pack {
transmute(f64x2_mul(transmute(lhs), transmute(rhs)))
}

#[inline(always)]
unsafe fn add(lhs: Pack, rhs: Pack) -> Pack {
transmute(f64x2_add(transmute(lhs), transmute(rhs)))
}

#[inline(always)]
unsafe fn mul_add(a: Pack, b: Pack, c: Pack) -> Pack {
add(mul(a, b), c)
}

microkernel!(["simd128"], 2, x1x1, 1, 1);
microkernel!(["simd128"], 2, x1x2, 1, 2);
microkernel!(["simd128"], 2, x1x3, 1, 3);
microkernel!(["simd128"], 2, x1x4, 1, 4);

microkernel!(["simd128"], 2, x2x1, 2, 1);
microkernel!(["simd128"], 2, x2x2, 2, 2);
microkernel!(["simd128"], 2, x2x3, 2, 3);
microkernel!(["simd128"], 2, x2x4, 2, 4);

microkernel_fn_array! {
[x1x1, x1x2, x1x3, x1x4,],
[x2x1, x2x2, x2x3, x2x4,],
}
}
}


#[cfg(target_arch = "aarch64")]
pub mod neon {
pub mod f64 {
Expand Down

0 comments on commit c03b453

Please sign in to comment.