Skip to content

Commit

Permalink
FEAT: In sgemm sse2 and fallback, use a 8x4 kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
bluss committed Dec 5, 2018
1 parent d7a74e3 commit aa5cb3f
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions src/sgemm_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,14 @@ pub(crate) fn detect<G>(selector: G) where G: GemmSelect<T> {
return selector.select(KernelFallback);
}

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
const MR: usize = 8;
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
const NR: usize = 8;

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
macro_rules! loop_m { ($i:ident, $e:expr) => { loop8!($i, $e) }; }
#[cfg(test)]
macro_rules! loop_n { ($j:ident, $e:expr) => { loop8!($j, $e) }; }

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
Expand Down Expand Up @@ -89,8 +93,8 @@ impl GemmKernel for KernelAvx {
impl GemmKernel for KernelSse2 {
type Elem = T;

const MR: usize = MR;
const NR: usize = NR;
const MR: usize = KernelFallback::MR;
const NR: usize = KernelFallback::NR;
#[inline(always)]
fn align_to() -> usize { 16 }

Expand All @@ -100,7 +104,7 @@ impl GemmKernel for KernelSse2 {
fn nr() -> usize { Self::NR }

#[inline(always)]
fn always_masked() -> bool { true }
fn always_masked() -> bool { KernelFallback::always_masked() }

#[inline(always)]
fn nc() -> usize { archparam::S_NC }
Expand All @@ -124,8 +128,8 @@ impl GemmKernel for KernelSse2 {
impl GemmKernel for KernelFallback {
type Elem = T;

const MR: usize = MR;
const NR: usize = NR;
const MR: usize = 8;
const NR: usize = 4;
#[inline(always)]
fn align_to() -> usize { 0 }

Expand Down Expand Up @@ -433,7 +437,7 @@ unsafe fn kernel_fallback_impl(k: usize, alpha: T, a: *const T, b: *const T,

// Compute A B into ab[i][j]
unroll_by!(4 => k, {
loop_m!(i, loop_n!(j, ab[i][j] += at(a, i) * at(b, j)));
loop8!(i, loop4!(j, ab[i][j] += at(a, i) * at(b, j)));

a = a.offset(MR as isize);
b = b.offset(NR as isize);
Expand All @@ -444,7 +448,7 @@ unsafe fn kernel_fallback_impl(k: usize, alpha: T, a: *const T, b: *const T,
}

// set C = A B
loop_n!(j, loop_m!(i, *c![i, j] = ab[i][j]));
loop4!(j, loop8!(i, *c![i, j] = ab[i][j]));
}

#[inline(always)]
Expand Down

0 comments on commit aa5cb3f

Please sign in to comment.