Skip to content

Commit

Permalink
bf16->f32 avx512bf16 GEMM microkernels
Browse files Browse the repository at this point in the history
rsp is now always valid

PiperOrigin-RevId: 720524634
  • Loading branch information
alankelly authored and xnnpack-bot committed Jan 28, 2025
1 parent 0c3e255 commit d7f398e
Show file tree
Hide file tree
Showing 114 changed files with 11,321 additions and 2,344 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ jobs:
env:
CC: gcc-9
CXX: g++-9
BAZEL_DEFINES: --define=xnn_enable_avxvnni=false --define=xnn_enable_avx256vnni=false --define=xnn_enable_avxvnniint8=false --define=xnn_enable_avx512amx=false --define=xnn_enable_avx512fp16=false
BAZEL_DEFINES: --define=xnn_enable_avxvnni=false --define=xnn_enable_avx256vnni=false --define=xnn_enable_avxvnniint8=false --define=xnn_enable_avx512amx=false --define=xnn_enable_avx512fp16=false --define=xnn_enable_avx512bf16=false
steps:
- uses: actions/checkout@v4
- name: Update apt
Expand Down
28 changes: 28 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -1350,6 +1350,18 @@ config_setting(
define_values = {"xnn_enable_avx512fp16": "false"},
)

# Enables usage of Intel AVX512-BF16 (bf16 arithmetic) kernels.
# This setting only matches builds that pass
# --define=xnn_enable_avx512bf16=true; it does not turn anything on by itself.
config_setting(
name = "xnn_enable_avx512bf16_explicit_true",
define_values = {"xnn_enable_avx512bf16": "true"},
)

# Disables usage of Intel AVX512-BF16 (bf16 arithmetic) kernels.
# This setting only matches builds that pass
# --define=xnn_enable_avx512bf16=false (e.g. toolchains that cannot build
# these kernels); it does not turn anything off by itself.
config_setting(
name = "xnn_enable_avx512bf16_explicit_false",
define_values = {"xnn_enable_avx512bf16": "false"},
)

# Enables usage of Intel AVX-VNNI (integer dot product) kernels.
config_setting(
name = "xnn_enable_avxvnni_explicit_true",
Expand Down Expand Up @@ -1664,6 +1676,22 @@ selects.config_setting_group(
],
)

# Platforms on which AVX512-BF16 kernels are enabled when the build passes
# neither --define=xnn_enable_avx512bf16=true nor =false. Only x86_64 is
# listed.
selects.config_setting_group(
name = "avx512bf16_enabled_by_default",
match_any = [
"//build_config:x86_64",
],
)

# Single condition other rules can select on to ask "are AVX512-BF16 kernels
# enabled?". It forwards to one of the settings above:
#   * define == "true":  forwards to the explicit_true setting, which matches.
#   * define == "false": also forwards to the explicit_true setting. Because the
#     define is "false", that setting cannot match, so the alias evaluates as
#     not-enabled. NOTE(review): this looks like a copy-paste slip but appears
#     to be the intended way to express "disabled" -- do not "fix" it to point
#     at explicit_false, which would match and therefore enable the kernels.
#   * define unset (or any other value): forwards to the per-platform default.
alias(
name = "avx512bf16_enabled",
actual = select({
":xnn_enable_avx512bf16_explicit_true": ":xnn_enable_avx512bf16_explicit_true",
":xnn_enable_avx512bf16_explicit_false": ":xnn_enable_avx512bf16_explicit_true",
"//conditions:default": ":avx512bf16_enabled_by_default",
}),
)

selects.config_setting_group(
name = "arm_bf16_enabled_by_default",
match_any = [
Expand Down
19 changes: 19 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,18 @@ ELSEIF(CMAKE_C_COMPILER_ID STREQUAL "Clang")
ELSEIF(CMAKE_C_COMPILER_ID STREQUAL "MSVC")
SET(XNNPACK_ENABLE_AVX512FP16 OFF)
ENDIF()
# AVX512-BF16 micro-kernels are requested by default, then switched off for
# GCC older than 13, Clang older than 15, and MSVC. The override is a plain
# (non-cache) SET, so it shadows whatever value the cached option holds for
# the remainder of this scope.
OPTION(XNNPACK_ENABLE_AVX512BF16 "Build XNNPACK with AVX512-BF16 micro-kernels" ON)
IF(CMAKE_C_COMPILER_ID STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_LESS "13")
  SET(XNNPACK_ENABLE_AVX512BF16 OFF)
ELSEIF(CMAKE_C_COMPILER_ID STREQUAL "Clang" AND CMAKE_C_COMPILER_VERSION VERSION_LESS "15")
  SET(XNNPACK_ENABLE_AVX512BF16 OFF)
ELSEIF(CMAKE_C_COMPILER_ID STREQUAL "MSVC")
  SET(XNNPACK_ENABLE_AVX512BF16 OFF)
ENDIF()
OPTION(XNNPACK_ENABLE_HVX "Build XNNPACK with Hexagon HVX micro-kernels" ON)
OPTION(XNNPACK_ENABLE_KLEIDIAI "Use KleidiAI GEMM microkernels for Arm" ON)
IF(XNNPACK_TARGET_PROCESSOR STREQUAL "arm64" AND XNNPACK_ENABLE_ARM_I8MM AND NOT CMAKE_C_COMPILER_ID STREQUAL "MSVC")
Expand Down Expand Up @@ -307,6 +319,7 @@ ADD_COMPILE_DEFINITIONS("XNN_ENABLE_AVX512VNNI=$<BOOL:${XNNPACK_ENABLE_AVX512VNN
ADD_COMPILE_DEFINITIONS("XNN_ENABLE_AVX512VNNIGFNI=$<BOOL:${XNNPACK_ENABLE_AVX512VNNIGFNI}>")
ADD_COMPILE_DEFINITIONS("XNN_ENABLE_AVX512AMX=$<BOOL:${XNNPACK_ENABLE_AVX512AMX}>")
ADD_COMPILE_DEFINITIONS("XNN_ENABLE_AVX512FP16=$<BOOL:${XNNPACK_ENABLE_AVX512FP16}>")
ADD_COMPILE_DEFINITIONS("XNN_ENABLE_AVX512BF16=$<BOOL:${XNNPACK_ENABLE_AVX512BF16}>")
ADD_COMPILE_DEFINITIONS("XNN_ENABLE_VSX=$<BOOL:${XNNPACK_ENABLE_VSX}>")
ADD_COMPILE_DEFINITIONS("XNN_ENABLE_ASSEMBLY=$<BOOL:${XNNPACK_ENABLE_ASSEMBLY}>")
ADD_COMPILE_DEFINITIONS("XNN_ENABLE_MEMOPT=$<BOOL:${XNNPACK_ENABLE_MEMOPT}>")
Expand Down Expand Up @@ -677,6 +690,9 @@ IF(XNNPACK_TARGET_PROCESSOR MATCHES "^x86(_64)?$")
IF(XNNPACK_ENABLE_AVX512FP16)
LIST(APPEND PROD_MICROKERNEL_SRCS ${PROD_AVX512FP16_MICROKERNEL_SRCS})
ENDIF()
IF(XNNPACK_ENABLE_AVX512BF16)
LIST(APPEND PROD_MICROKERNEL_SRCS ${PROD_AVX512BF16_MICROKERNEL_SRCS})
ENDIF()
IF(XNNPACK_ENABLE_AVXVNNI)
LIST(APPEND PROD_MICROKERNEL_SRCS ${PROD_AVXVNNI_MICROKERNEL_SRCS})
ENDIF()
Expand Down Expand Up @@ -727,6 +743,9 @@ IF(XNNPACK_TARGET_PROCESSOR MATCHES "^x86(_64)?$")
IF(XNNPACK_ENABLE_AVX512FP16)
LIST(APPEND NON_PROD_MICROKERNEL_SRCS ${NON_PROD_AVX512FP16_MICROKERNEL_SRCS})
ENDIF()
IF(XNNPACK_ENABLE_AVX512BF16)
LIST(APPEND NON_PROD_MICROKERNEL_SRCS ${NON_PROD_AVX512BF16_MICROKERNEL_SRCS})
ENDIF()
IF(XNNPACK_ENABLE_AVXVNNI)
LIST(APPEND NON_PROD_MICROKERNEL_SRCS ${NON_PROD_AVXVNNI_MICROKERNEL_SRCS})
ENDIF()
Expand Down
23 changes: 23 additions & 0 deletions build_params.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ def xnnpack_configurable_defines():
":avx512fp16_enabled",
["XNN_ENABLE_AVX512FP16=1"],
["XNN_ENABLE_AVX512FP16=0"],
) + xnnpack_select_if(
":avx512bf16_enabled",
["XNN_ENABLE_AVX512BF16=1"],
["XNN_ENABLE_AVX512BF16=0"],
) + xnnpack_select_if(
":avxvnni_enabled",
["XNN_ENABLE_AVXVNNI=1"],
Expand Down Expand Up @@ -811,6 +815,25 @@ XNNPACK_PARAMS_FOR_ARCH = {
mingw_copts = ["-fno-asynchronous-unwind-tables"],
msys_copts = ["-fno-asynchronous-unwind-tables"],
),
"avx512bf16": _create_params(
cond = "//:avx512bf16_enabled",
gcc_x86_copts = [
"-mf16c",
"-mfma",
"-mavx512f",
"-mavx512cd",
"-mavx512bw",
"-mavx512dq",
"-mavx512vl",
"-mavx512vnni",
"-mgfni",
"-mavx512bf16",
],
msvc_x86_32_copts = ["/arch:AVX512"],
msvc_x86_64_copts = ["/arch:AVX512"],
mingw_copts = ["-fno-asynchronous-unwind-tables"],
msys_copts = ["-fno-asynchronous-unwind-tables"],
),

# RISC-V.
"rvv": _create_params(
Expand Down
27 changes: 27 additions & 0 deletions cmake/gen/amd64_microkernels.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,33 @@
SET(PROD_AMD64_ASM_MICROKERNEL_SRCS)

SET(NON_PROD_AMD64_ASM_MICROKERNEL_SRCS
src/bf16-f32-gemm/gen/bf16-f32-gemm-1x16c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/bf16-f32-gemm/gen/bf16-f32-gemm-1x32c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/bf16-f32-gemm/gen/bf16-f32-gemm-1x64c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/bf16-f32-gemm/gen/bf16-f32-gemm-2x16c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/bf16-f32-gemm/gen/bf16-f32-gemm-2x32c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/bf16-f32-gemm/gen/bf16-f32-gemm-2x64c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/bf16-f32-gemm/gen/bf16-f32-gemm-3x16c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/bf16-f32-gemm/gen/bf16-f32-gemm-3x32c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/bf16-f32-gemm/gen/bf16-f32-gemm-3x64c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/bf16-f32-gemm/gen/bf16-f32-gemm-4x16c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/bf16-f32-gemm/gen/bf16-f32-gemm-4x32c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/bf16-f32-gemm/gen/bf16-f32-gemm-4x64c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/bf16-f32-gemm/gen/bf16-f32-gemm-5x16c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/bf16-f32-gemm/gen/bf16-f32-gemm-5x32c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/bf16-f32-gemm/gen/bf16-f32-gemm-5x64c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/bf16-f32-gemm/gen/bf16-f32-gemm-6x16c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/bf16-f32-gemm/gen/bf16-f32-gemm-6x32c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/bf16-f32-gemm/gen/bf16-f32-gemm-7x16c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/bf16-f32-gemm/gen/bf16-f32-gemm-7x32c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/bf16-f32-gemm/gen/bf16-f32-gemm-8x16c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/bf16-f32-gemm/gen/bf16-f32-gemm-8x32c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/bf16-f32-gemm/gen/bf16-f32-gemm-9x16c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/bf16-f32-gemm/gen/bf16-f32-gemm-9x32c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/bf16-f32-gemm/gen/bf16-f32-gemm-10x16c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/bf16-f32-gemm/gen/bf16-f32-gemm-10x32c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/bf16-f32-gemm/gen/bf16-f32-gemm-11x16c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/bf16-f32-gemm/gen/bf16-f32-gemm-11x32c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-1x32-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-1x64-minmax-asm-amd64-avx512f-broadcast.S
Expand Down
1 change: 1 addition & 0 deletions cmake/gen/scalar_microkernels.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ SET(PROD_SCALAR_MICROKERNEL_SRCS
src/xx-transposev/xx-transposev-1x1-scalar-memcpy.c)

SET(NON_PROD_SCALAR_MICROKERNEL_SRCS
src/bf16-f32-gemm/bf16-f32-gemm-1x4c2-minmax-scalar.c
src/f16-f32-vcvt/gen/f16-f32-vcvt-scalar-u2.c
src/f16-f32-vcvt/gen/f16-f32-vcvt-scalar-u3.c
src/f16-qs8-vcvt/gen/f16-qs8-vcvt-scalar-fmagic-u1.c
Expand Down
2 changes: 2 additions & 0 deletions gemm_compiler/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ py_library(
name = "generate_gemm_microkernels",
srcs = [
"generate.py",
"generate_bf16_f32_gemm_microkernels.py",
"generate_f32_gemm_microkernels.py",
"generate_qd8_f32_qc4w_gemm_microkernels.py",
"generate_qd8_f32_qc8w_gemm_microkernels.py",
Expand All @@ -39,6 +40,7 @@ py_library(
py_library(
name = "x64_isa_templates",
srcs = [
"avx512bf16_template.py",
"avx512f_template.py",
"avx512vnni_template.py",
"fma3_template.py",
Expand Down
77 changes: 77 additions & 0 deletions gemm_compiler/avx512bf16_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#!/usr/bin/env python3
# Copyright 2025 Google LLC
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from gemm_compiler import avx512f_template as isa

"""All SIMD features for avx512bf16."""


class Avx512Bf16(isa.Avx512F):
  """AVX512-BF16 variant of the AVX512F GEMM assembly template.

  Overrides the parts of `isa.Avx512F` that differ for bf16 inputs: the
  compute instruction (`vdpbf16ps`), the generated kernel name, and the extra
  assembly that handles a trailing 2-byte remainder of k.

  NOTE(review): the checked-in `*-avx512bf16-broadcast.S` kernels are
  presumably generated from this template, so they need regenerating after
  any change to the assembly strings below.
  """

  def __init__(self):
    # No state is set up here and the base-class constructor is not invoked.
    pass  # Empty constructor

  def isa(self):
    """Returns the ISA tag used for this template."""
    return 'avx512bf16'

  def n_step(self):
    """Returns the N step of this template (16)."""
    return 16

  def compute_asm(self):
    """Returns the instruction templates for the main loop and the k tail.

    The `{ACC}`, `{A}` and `{W}` placeholders are left unformatted here;
    presumably the caller substitutes register names -- verify against the
    base template.
    """
    # Tail masking: shifting each 32-bit lane of {A} left by 16 and back right
    # by 16 keeps only its low 16 bits (one bf16 value) and must leave the
    # other half as zero. The right shift therefore has to be the logical
    # `vpsrld`. The arithmetic `vpsrad` would replicate the sign bit, so a
    # negative element would fill the discarded half with 0xFFFF, which is a
    # bf16 NaN that `vdpbf16ps` would then fold into the accumulator. The two
    # shifts give identical results for non-negative elements.
    c_asm = {
        'loop': ['vdpbf16ps z{ACC}, {A}, {W}\n'],
        'loop_tail': ["""vpslld {A}, {A}, 16
vpsrld {A}, {A}, 16
vdpbf16ps z{ACC}, {A}, {W}\n
"""],
    }
    return c_asm

  def function_name(self, M, N, isa):
    """Returns the kernel symbol name for an MxN tile, newline-terminated.

    Args:
      M: tile height embedded in the name.
      N: tile width embedded in the name.
      isa: ISA tag embedded in the name (shadows the module alias `isa`
        inside this method only).
    """
    return f'xnn_bf16_f32_gemm_minmax_ukernel_{M}x{N}c2__asm_amd64_{isa}_broadcast\n'

  def params_offset(self):
    """Returns the params offset used by this template (72)."""
    # NOTE(review): presumably a stack offset of the params argument in bytes;
    # confirm against the base template before changing.
    return 72

  def outer_loop_prepare(self, M, N):
    """Returns assembly that splits k into a 4-byte-multiple part and a tail.

    The emitted code copies rdx into the k register and masks it down to
    bit 1 (0x2), clears that same bit in the kc register, and stores the
    masked bit at `[rsp + M * 16]` for `inner_loop_tail` to test later.

    Args:
      M: tile height; determines the stack slot (`M * 16`).
      N: unused here.
    """
    k_register = self.k_register()
    kc_register = self.kc_register()
    # Stack slot just past M 16-byte entries. NOTE(review): presumably the
    # region below it holds per-row pointers -- confirm in the base template.
    offset = M * 16
    asm_string = f"""
# Copy k and flip bit.
mov {k_register}, rdx
and {k_register}, 0x2
and {kc_register}, 0xFFFFFFFFFFFFFFFD
mov [rsp + {offset}], {k_register}\n"""
    return asm_string

  def init_accumulators(self, M, N):
    """Returns the base accumulator setup plus an early jump to the tail.

    After the base-class initialisation, the emitted code jumps straight to
    the `inner_loop_tail` label when rdx is below 4.
    """
    asm_string = super().init_accumulators(M, N)
    asm_string += """
# Are there at least 4 bytes?
cmp rdx, 4
js inner_loop_tail\n"""

    return asm_string

  def inner_loop_tail(self, M, N):
    """Returns assembly for the optional k-remainder iteration.

    The emitted prologue borrows the nc register to load and test the
    remainder bit saved by `outer_loop_prepare`, restores nc, and skips to
    `inner_loop_end` when the bit is zero. The `inner_loop_tail` label and
    the tail body follow, produced by the base-class inner-loop helpers with
    `tail=True`.

    Args:
      M: tile height; selects the stack slots and the spilling variant.
      N: tile width, forwarded to the base-class helpers.
    """
    k_register = self.k_register()
    nc_register = self.nc_register()
    offset = M * 16
    # nc is parked in the next 8-byte slot while its register is borrowed.
    nc_offset = offset + 8
    # The restoring `mov` sits between `test` and `jz`; `mov` leaves the flags
    # untouched, so `jz` still sees the result of `test`.
    asm_string = f"""
mov [rsp + {nc_offset}], {nc_register}
mov {nc_register}, [rsp + {offset}]
test {nc_register}, {nc_register}
mov {nc_register}, [rsp + {nc_offset}]
jz inner_loop_end
inner_loop_tail:\n"""
    if M > self.max_M_before_spilling():
      asm_string += self.inner_loop_spill_gp(M=M, N=N, tail=True)
    else:
      asm_string += self.inner_loop_small_M_N(M=M, N=N, tail=True)
    return asm_string
Loading

0 comments on commit d7f398e

Please sign in to comment.