Skip to content

Commit

Permalink
Add sve implementation for float matrix transpose (pytorch#3421)
Browse files Browse the repository at this point in the history
Summary:

X-link: facebookresearch/FBGEMM#509

Adding sve-based function for transposing float matrixes

Reviewed By: psaab

Differential Revision: D66528598
  • Loading branch information
Nicoshev authored and facebook-github-bot committed Nov 27, 2024
1 parent 6eb379a commit f7ce173
Show file tree
Hide file tree
Showing 6 changed files with 874 additions and 2 deletions.
4 changes: 2 additions & 2 deletions defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,10 @@ def get_fbgemm_inline_avx512_srcs(msvc = False, buck = False):
return asm_srcs if not msvc else intrinsics_srcs

def get_fbgemm_inline_sve_srcs(msvc = False, buck = False):
intrinsics_srcs = ["src/FbgemmFP16UKernelsSve128.cc"]
intrinsics_srcs = ["src/FbgemmFP16UKernelsSve128.cc", "src/UtilsSve.cc"]

#FP16 kernels contain inline assembly and inline assembly syntax for MSVC is different.
asm_srcs = ["src/FbgemmFP16UKernelsSve128.cc"]
asm_srcs = ["src/FbgemmFP16UKernelsSve128.cc", "src/UtilsSve.cc"]
if buck:
return select({
"DEFAULT": asm_srcs,
Expand Down
9 changes: 9 additions & 0 deletions include/fbgemm/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@
#include <string>
#include <type_traits>

#ifndef HAVE_SVE
#if defined(__aarch64__) && (__GNUC__ >= 8 || __clang_major__ >= 5) && \
__ARM_FEATURE_SVE
#define HAVE_SVE 1
#else
#define HAVE_SVE 0
#endif
#endif

namespace fbgemm {

/**
Expand Down
10 changes: 10 additions & 0 deletions src/TransposeUtils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ void transpose_simd(
}
return;
}

#if HAVE_SVE
if constexpr (std::is_same<T, float>::value) {
internal::transpose_sve<T>(M, N, src, ld_src, dst, ld_dst);
} else {
transpose_ref<T>(M, N, src, ld_src, dst, ld_dst);
}
#else
static const auto iset = fbgemmInstructionSet();
// Run time CPU detection
if (isZmm(iset)) {
Expand All @@ -55,6 +63,8 @@ void transpose_simd(
} else {
transpose_ref<T>(M, N, src, ld_src, dst, ld_dst);
}

#endif
}

template void transpose_ref<float>(
Expand Down
16 changes: 16 additions & 0 deletions src/TransposeUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,22 @@ void transpose_avx512(
T* dst,
int64_t ld_dst);

#ifdef __aarch64__
/**
* @brief Transpose a matrix using Intel AVX2.
*
* This is called if the code is running on a CPU with Intel AVX2 support.
*/
template <typename T>
void transpose_sve(
int64_t M,
int64_t N,
const T* src,
int64_t ld_src,
T* dst,
int64_t ld_dst);
#endif // __aarch64__

} // namespace internal

} // namespace fbgemm
Loading

0 comments on commit f7ce173

Please sign in to comment.