Skip to content

Commit

Permalink
+ Add layernorm FP16 support;
Browse files Browse the repository at this point in the history
+ Add FP16 UTs for layernorm and rmsnorm.

Signed-off-by: Huang Wenhuan <[email protected]>
  • Loading branch information
wenhuanh committed Nov 25, 2024
1 parent 12296de commit 782828c
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 11 deletions.
37 changes: 26 additions & 11 deletions src/kernels/layernorm_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,24 @@
#include "dtype.h"
#include "float16.h"
#include "intrinsic_ext.h"
#include "intrinsics_util.h"
#include "layernorm_kernels.h"
#include "my_types.h"

namespace xft {

void invokeLayerNorm(float *output, const float *input, const float *gamma, const float *beta, int rows, int cols,
int iStride, int oStride, float epsilon) {
template <typename T>
void invokeLayerNorm(T *output, const T *input, const T *gamma, const T *beta, int rows, int cols, int iStride,
int oStride, float epsilon) {

int size = cols;
if (iStride == -1) iStride = size;
if (oStride == -1) oStride = size;

#pragma omp parallel for
for (int r = 0; r < rows; ++r) {
const float *px = input + r * iStride;
float *py = output + r * oStride;
const T *px = input + r * iStride;
T *py = output + r * oStride;

float sum = 0;
float squareSum = 0;
Expand All @@ -46,7 +48,7 @@ void invokeLayerNorm(float *output, const float *input, const float *gamma, cons
__mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1);

// SUM(x)
__m512 vx = _mm512_maskz_loadu_ps(mask, px + col);
__m512 vx = xft::load_avx512(mask, px + col);
vsum = _mm512_add_ps(vsum, vx);

// SUM(x*x)
Expand All @@ -69,11 +71,11 @@ void invokeLayerNorm(float *output, const float *input, const float *gamma, cons
int remain = size - col;
__mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1);

__m512 vx = _mm512_maskz_loadu_ps(mask, px + col);
__m512 vgamma = _mm512_maskz_loadu_ps(mask, gamma + col);
__m512 vbeta = _mm512_maskz_loadu_ps(mask, beta + col);
__m512 vx = xft::load_avx512(mask, px + col);
__m512 vgamma = xft::load_avx512(mask, gamma + col);
__m512 vbeta = xft::load_avx512(mask, beta + col);
__m512 vy = (vx - vmean) * vgamma * vvar + vbeta;
_mm512_mask_storeu_ps(py + col, mask, vy);
xft::store_avx512(py + col, mask, vy);
}
}
}
Expand Down Expand Up @@ -133,14 +135,27 @@ void invokeLayerNorm(bfloat16_t *output, const bfloat16_t *input, const bfloat16
}
}

// LayerNorm over fp32 rows: thin forwarder to the generic scalar kernel.
// A stride of -1 is interpreted as "dense" (i.e. equal to cols) by the
// underlying implementation. The template argument is spelled out on purpose:
// a plain call would resolve back to this very overload and recurse.
void invokeLayerNorm(float *output, const float *input, const float *gamma, const float *beta, int rows, int cols,
        int iStride, int oStride, float epsilon) {
    invokeLayerNorm<float>(output, input, gamma, beta, rows, cols, iStride, oStride, epsilon);
}

// LayerNorm over fp16 rows: thin forwarder to the generic scalar kernel.
// A stride of -1 is interpreted as "dense" (i.e. equal to cols) by the
// underlying implementation. The template argument is spelled out on purpose:
// a plain call would resolve back to this very overload and recurse.
void invokeLayerNorm(float16_t *output, const float16_t *input, const float16_t *gamma, const float16_t *beta, int rows,
        int cols, int iStride, int oStride, float epsilon) {
    invokeLayerNorm<float16_t>(output, input, gamma, beta, rows, cols, iStride, oStride, epsilon);
}

// Runtime-dtype dispatcher for LayerNorm: routes the type-erased buffers to
// the typed overload selected by `dt`.
//
// @param dt       element type of all four buffers (fp32 / bf16 / fp16)
// @param output   [rows x cols] result buffer (row stride oStride)
// @param input    [rows x cols] input buffer (row stride iStride)
// @param gamma    per-column scale, length cols
// @param beta     per-column shift, length cols
// @param iStride  input row stride; -1 means dense (cols)
// @param oStride  output row stride; -1 means dense (cols)
// @param epsilon  numerical-stability term added to the variance
void invokeLayerNorm(DataType dt, void *output, const void *input, const void *gamma, const void *beta, int rows,
        int cols, int iStride, int oStride, float epsilon) {
    // static_cast (not C-style casts): void* -> T* is a checked, greppable conversion.
    if (dt == DataType::fp32) {
        invokeLayerNorm(static_cast<float *>(output), static_cast<const float *>(input),
                static_cast<const float *>(gamma), static_cast<const float *>(beta), rows, cols, iStride, oStride,
                epsilon);
    } else if (dt == DataType::bf16) {
        invokeLayerNorm(static_cast<bfloat16_t *>(output), static_cast<const bfloat16_t *>(input),
                static_cast<const bfloat16_t *>(gamma), static_cast<const bfloat16_t *>(beta), rows, cols, iStride,
                oStride, epsilon);
    } else if (dt == DataType::fp16) {
        invokeLayerNorm(static_cast<float16_t *>(output), static_cast<const float16_t *>(input),
                static_cast<const float16_t *>(gamma), static_cast<const float16_t *>(beta), rows, cols, iStride,
                oStride, epsilon);
    }
    // NOTE(review): any other DataType is silently ignored and `output` is left
    // untouched — consider logging/asserting here; TODO confirm intended behavior.
}

Expand Down
7 changes: 7 additions & 0 deletions tests/ut/layers_norm_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,13 @@ TEST(LayerNorm, bfloat16_t) {
compareLayerNorm<bfloat16_t>(rand() % 100 + 100, rand() % 100 + 100);
}

TEST(LayerNorm, float16_t) {
    // Deterministic shapes: square, large, and wide (cols != rows) cases.
    const int shapes[][2] = {{128, 128}, {5120, 5120}, {5120, 5120 * 3}};
    for (const auto &shape : shapes) {
        compareLayerNorm<float16_t>(shape[0], shape[1]);
    }
    // One randomized shape in [100, 199] x [100, 199] for extra coverage.
    compareLayerNorm<float16_t>(rand() % 100 + 100, rand() % 100 + 100);
}

int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
Expand Down
7 changes: 7 additions & 0 deletions tests/ut/rms_norm_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,13 @@ TEST(RMSNorm, bfloat16_t) {
compareRMSNorm<bfloat16_t>(rand() % 100 + 100, rand() % 100 + 100);
}

TEST(RMSNorm, float16_t) {
    // Deterministic shapes: square, large, and wide (cols != rows) cases.
    const int shapes[][2] = {{128, 128}, {5120, 5120}, {5120, 5120 * 3}};
    for (const auto &shape : shapes) {
        compareRMSNorm<float16_t>(shape[0], shape[1]);
    }
    // One randomized shape in [100, 199] x [100, 199] for extra coverage.
    compareRMSNorm<float16_t>(rand() % 100 + 100, rand() % 100 + 100);
}

int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
Expand Down

0 comments on commit 782828c

Please sign in to comment.