From 782828c38edf05bfb0732ead2f6423d7818caa8e Mon Sep 17 00:00:00 2001
From: Huang Wenhuan
Date: Mon, 25 Nov 2024 07:16:49 +0000
Subject: [PATCH] + Add layernorm FP16 support; + Add FP16 UTs for layernorm and rmsnorm.

Signed-off-by: Huang Wenhuan
---
 src/kernels/layernorm_kernels.cpp | 37 ++++++++++++++++++++++---------
 tests/ut/layers_norm_test.cpp     |  7 ++++++
 tests/ut/rms_norm_test.cpp        |  7 ++++++
 3 files changed, 40 insertions(+), 11 deletions(-)

diff --git a/src/kernels/layernorm_kernels.cpp b/src/kernels/layernorm_kernels.cpp
index 623c887f..b8cad2f8 100644
--- a/src/kernels/layernorm_kernels.cpp
+++ b/src/kernels/layernorm_kernels.cpp
@@ -18,13 +18,15 @@
 #include "dtype.h"
 #include "float16.h"
 #include "intrinsic_ext.h"
+#include "intrinsics_util.h"
 #include "layernorm_kernels.h"
 #include "my_types.h"

 namespace xft {

-void invokeLayerNorm(float *output, const float *input, const float *gamma, const float *beta, int rows, int cols,
-        int iStride, int oStride, float epsilon) {
+template <typename T>
+void invokeLayerNorm(T *output, const T *input, const T *gamma, const T *beta, int rows, int cols, int iStride,
+        int oStride, float epsilon) {
     int size = cols;

     if (iStride == -1) iStride = size;
@@ -32,8 +34,8 @@ void invokeLayerNorm(float *output, const float *input, const float *gamma, cons

 #pragma omp parallel for
     for (int r = 0; r < rows; ++r) {
-        const float *px = input + r * iStride;
-        float *py = output + r * oStride;
+        const T *px = input + r * iStride;
+        T *py = output + r * oStride;

         float sum = 0;
         float squareSum = 0;
@@ -46,7 +48,7 @@ void invokeLayerNorm(float *output, const float *input, const float *gamma, cons
             __mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1);

             // SUM(x)
-            __m512 vx = _mm512_maskz_loadu_ps(mask, px + col);
+            __m512 vx = xft::load_avx512(mask, px + col);
             vsum = _mm512_add_ps(vsum, vx);

             // SUM(x*x)
@@ -69,11 +71,11 @@ void invokeLayerNorm(float *output, const float *input, const float *gamma, cons
             int remain = size - col;
             __mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1);

-            __m512 vx = _mm512_maskz_loadu_ps(mask, px + col);
-            __m512 vgamma = _mm512_maskz_loadu_ps(mask, gamma + col);
-            __m512 vbeta = _mm512_maskz_loadu_ps(mask, beta + col);
+            __m512 vx = xft::load_avx512(mask, px + col);
+            __m512 vgamma = xft::load_avx512(mask, gamma + col);
+            __m512 vbeta = xft::load_avx512(mask, beta + col);
             __m512 vy = (vx - vmean) * vgamma * vvar + vbeta;
-            _mm512_mask_storeu_ps(py + col, mask, vy);
+            xft::store_avx512(py + col, mask, vy);
         }
     }
 }
@@ -133,14 +135,27 @@ void invokeLayerNorm(bfloat16_t *output, const bfloat16_t *input, const bfloat16
     }
 }

+void invokeLayerNorm(float *output, const float *input, const float *gamma, const float *beta, int rows, int cols,
+        int iStride, int oStride, float epsilon) {
+    invokeLayerNorm<float>(output, input, gamma, beta, rows, cols, iStride, oStride, epsilon);
+}
+
+void invokeLayerNorm(float16_t *output, const float16_t *input, const float16_t *gamma, const float16_t *beta, int rows,
+        int cols, int iStride, int oStride, float epsilon) {
+    invokeLayerNorm<float16_t>(output, input, gamma, beta, rows, cols, iStride, oStride, epsilon);
+}
+
 void invokeLayerNorm(DataType dt, void *output, const void *input, const void *gamma, const void *beta, int rows,
         int cols, int iStride, int oStride, float epsilon) {
     if (dt == DataType::fp32) {
-        invokeLayerNorm((float *)output, (const float *)input, (const float *)gamma,
-                (const float *)beta, rows, cols, iStride, oStride, epsilon);
+        invokeLayerNorm((float *)output, (const float *)input, (const float *)gamma, (const float *)beta, rows, cols,
+                iStride, oStride, epsilon);
     } else if (dt == DataType::bf16) {
         invokeLayerNorm((bfloat16_t *)output, (const bfloat16_t *)input, (const bfloat16_t *)gamma,
                 (const bfloat16_t *)beta, rows, cols, iStride, oStride, epsilon);
+    } else if (dt == DataType::fp16) {
+        invokeLayerNorm((float16_t *)output, (const float16_t *)input, (const float16_t *)gamma,
+                (const float16_t *)beta, rows, cols, iStride, oStride, epsilon);
     }
 }

diff --git a/tests/ut/layers_norm_test.cpp b/tests/ut/layers_norm_test.cpp
index 1044150e..262b68ec 100644
--- a/tests/ut/layers_norm_test.cpp
+++ b/tests/ut/layers_norm_test.cpp
@@ -111,6 +111,13 @@ TEST(LayerNorm, bfloat16_t) {
     compareLayerNorm<bfloat16_t>(rand() % 100 + 100, rand() % 100 + 100);
 }

+TEST(LayerNorm, float16_t) {
+    compareLayerNorm<float16_t>(128, 128);
+    compareLayerNorm<float16_t>(5120, 5120);
+    compareLayerNorm<float16_t>(5120, 5120 * 3);
+    compareLayerNorm<float16_t>(rand() % 100 + 100, rand() % 100 + 100);
+}
+
 int main(int argc, char **argv) {
     ::testing::InitGoogleTest(&argc, argv);
     return RUN_ALL_TESTS();
diff --git a/tests/ut/rms_norm_test.cpp b/tests/ut/rms_norm_test.cpp
index 90c5ffce..8d078460 100644
--- a/tests/ut/rms_norm_test.cpp
+++ b/tests/ut/rms_norm_test.cpp
@@ -98,6 +98,13 @@ TEST(RMSNorm, bfloat16_t) {
     compareRMSNorm<bfloat16_t>(rand() % 100 + 100, rand() % 100 + 100);
 }

+TEST(RMSNorm, float16_t) {
+    compareRMSNorm<float16_t>(128, 128);
+    compareRMSNorm<float16_t>(5120, 5120);
+    compareRMSNorm<float16_t>(5120, 5120 * 3);
+    compareRMSNorm<float16_t>(rand() % 100 + 100, rand() % 100 + 100);
+}
+
 int main(int argc, char **argv) {
     ::testing::InitGoogleTest(&argc, argv);
     return RUN_ALL_TESTS();