From 6652c16395ce46d99de153f85b64e430ed042d6e Mon Sep 17 00:00:00 2001 From: Chen Meng Date: Fri, 1 Dec 2023 16:22:12 +0800 Subject: [PATCH] Accelerate first token gen with BF16-gemm MHA and concat-Silu MLP --- .gitignore | 1 + CMakeLists.txt | 5 +- cmake/mkl.cmake | 29 +++++ requirements.txt | 2 +- src/common/transformer_ctx.h | 3 +- src/layers/attention.h | 115 +++++++++++++------- src/layers/mlp_chatglm2.h | 81 +++++++++----- src/layers/mlp_llama.h | 159 +++++++++++++++++++++++---- src/utils/decoder_util.h | 205 ++++++++++++++++++++--------------- 9 files changed, 424 insertions(+), 176 deletions(-) create mode 100644 cmake/mkl.cmake diff --git a/.gitignore b/.gitignore index be868c74..a5656a27 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,7 @@ dist/ # 3drparty /3rdparty/ig /3rdparty/mklml +/3rdparty/mkl /3rdparty/oneCCL /3rdparty/oneccl /3rdparty/jsoncpp diff --git a/CMakeLists.txt b/CMakeLists.txt index e7b274ab..008b150a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -55,6 +55,7 @@ endif() include("cmake/mklml.cmake") include("cmake/onednn.cmake") include("cmake/xdnn.cmake") +include("cmake/mkl.cmake") include(GNUInstallDirs) set(DEPEND_LIST "onednn" "xdnn_lib") @@ -63,6 +64,7 @@ include_directories(${CMAKE_SOURCE_DIR}/3rdparty/) include_directories(${CMAKE_SOURCE_DIR}/3rdparty/onednn/include) include_directories(${CMAKE_SOURCE_DIR}/3rdparty/onednn/build/include) include_directories(${CMAKE_SOURCE_DIR}/3rdparty/xdnn) +include_directories(${CMAKE_SOURCE_DIR}/3rdparty/mkl/include) include_directories(${CMAKE_SOURCE_DIR}/include) include_directories(${CMAKE_SOURCE_DIR}/src/kernels) include_directories(${CMAKE_SOURCE_DIR}/src/layers) @@ -74,6 +76,7 @@ include_directories(${CMAKE_SOURCE_DIR}/src/common) link_directories(${CMAKE_SOURCE_DIR}/src/kernels) link_directories(${CMAKE_SOURCE_DIR}/3rdparty/onednn/build/src) link_directories(${CMAKE_SOURCE_DIR}/3rdparty/xdnn) +link_directories(${CMAKE_SOURCE_DIR}/3rdparty/mkl/lib) find_package(oneCCL REQUIRED) @@ -93,7 +96,7 @@ else() link_directories(${CMAKE_SOURCE_DIR}/3rdparty/oneccl/build/_install/lib/prov) endif() -set(3RDPART_LIB_LIST ${MPI_LIBS} "ccl" "dnnl" "numa") +set(3RDPART_LIB_LIST ${MPI_LIBS} "ccl" "dnnl" "numa" "mkl_rt") option(BUILD_WITH_SHARED_LIBS "Build with shared libraries" OFF) if(BUILD_WITH_SHARED_LIBS) message(STATUS "Notice: Building with shared libraries.") diff --git a/cmake/mkl.cmake b/cmake/mkl.cmake new file mode 100644 index 00000000..778920ec --- /dev/null +++ b/cmake/mkl.cmake @@ -0,0 +1,29 @@ +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +cmake_minimum_required(VERSION 3.18) + +# Avoid warning about DOWNLOAD_EXTRACT_TIMESTAMP in CMake 3.24: +if(${CMAKE_VERSION} VERSION_GREATER_EQUAL "3.24.0") + cmake_policy(SET CMP0135 NEW) +endif() + +find_package (Python COMPONENTS Interpreter Development) +execute_process(COMMAND ${Python_EXECUTABLE} -m pip install --prefix=${CMAKE_SOURCE_DIR}/3rdparty/mkl mkl mkl-include + RESULT_VARIABLE EXIT_CODE + OUTPUT_QUIET) +execute_process(COMMAND ln -sf ${CMAKE_SOURCE_DIR}/3rdparty/mkl/lib/libmkl_rt.so.2 ${CMAKE_SOURCE_DIR}/3rdparty/mkl/lib/libmkl_rt.so + RESULT_VARIABLE EXIT_CODE + OUTPUT_QUIET) diff --git a/requirements.txt b/requirements.txt index 3f00acfd..2dd72dd9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,4 @@ sentencepiece==0.1.99 tokenizers==0.13.3 torch==2.0.1+cpu transformers==4.30.0 -accelerate==0.23.0 \ No newline at end of file +accelerate==0.23.0 diff --git a/src/common/transformer_ctx.h b/src/common/transformer_ctx.h index d9742c4a..18974082 100644 --- a/src/common/transformer_ctx.h +++ b/src/common/transformer_ctx.h @@ -158,13 +158,14 @@ struct DecoderContext { int vCols = kCols; int qkvCols = qCols + kCols + vCols; int qkvStride = (qkvCols % 512 == 0 ? qkvCols + pad : qkvCols); // stride for the concated QKV + int mlpFactor = (this->actType == SILU || this->actType == SWIGLU) ? 2 : 1; int imCols = splitIdx < (intermediateSize % numSplit) ? (intermediateSize / numSplit + 1) : (intermediateSize / numSplit); int imStride = (imCols % 512 == 0 ? imCols + pad : imCols); // stride for intermediate output int normSize = batchSize * inputSeqLen * hiddenStride; int qkvSize = batchSize * inputSeqLen * qkvStride; - int imOutSize = batchSize * inputSeqLen * imStride; + int imOutSize = batchSize * inputSeqLen * imStride * mlpFactor; int presentSeqLen = preSeqLen + 1; int paddedSize = (presentSeqLen + 15) / 16 * 16; diff --git a/src/layers/attention.h b/src/layers/attention.h index 3fa6de5b..b746f177 100644 --- a/src/layers/attention.h +++ b/src/layers/attention.h @@ -274,8 +274,9 @@ class Attention { resultBuffer1.Assign(inputBuffer.Data(), inputBuffer.Rows(), inputBuffer.Cols(), inputBuffer.Stride()); inputBuffer.Assign(presult, rows, cols, stride); } + int enable = (getenv("ENABLE_FLASH_ATTN") ? 
atoi(getenv("ENABLE_FLASH_ATTN")) : 1); // TODO: support large inputSeqLen when pastSeqLen > 0 - if (ctx->inputSeqLen > 256 && pastSeqLen == 0) + if (enable && ctx->inputSeqLen >= 1024 && pastSeqLen == 0) flashAttention( ctx, qkvGroupMatMul, resultBuffer2, resultBuffer1, presentKey, presentValue, attnMask, pastSeqLen); else @@ -702,8 +703,8 @@ class Attention { } // end for b } - template - void flashAttention(DecoderContext *ctx, hpj::Matrix &qkvMatMul, hpj::Matrix &tmpRes, + template + void flashAttention(DecoderContext *ctx, hpj::Matrix &qkvMatMul, hpj::Matrix &tmpBuf, hpj::Matrix &result, KVCacheTensor &presentKey, KVCacheTensor &presentValue, const float *attnMask, int pastSeqLen) { @@ -711,25 +712,40 @@ class Attention { int batchSize = ctx->batchSize; int respQHeads = this->endQHead - this->startQHead; int respKVHeads = this->endKVHead - this->startKVHead; - int qkvCols = respQHeads + respKVHeads * 2; int headSize = ctx->attHeadSize; + int qCols = respQHeads * headSize; + int kvCols = respKVHeads * headSize; + int qkvCols = qCols + kvCols * 2; float scale = ctx->attFactor; int srcLen = ctx->inputSeqLen; int tgtLen = pastSeqLen + srcLen; - float *transQKV = (float *)malloc(sizeof(float) * batchSize * qkvCols * srcLen * headSize); - - DecoderUtil::transposeQKV(qkvMatMul.Data(), transQKV, batchSize, srcLen, respQHeads, respKVHeads, headSize); + // TODO: kv dtype conversion for prefixSharing + AttnT *k, *v; + if constexpr (std::is_same_v) { +#pragma omp parallel for collapse(3) + for (int b = 0; b < batchSize; ++b) + for (int seq = 0; seq < srcLen; ++seq) + for (int i = qCols; i < qkvCols; i+=headSize) { + const float *srcPtr = qkvMatMul.Data() + b * srcLen * qkvCols + seq * qkvCols + i; + bfloat16_t *dstPtr = (bfloat16_t *)tmpBuf.Data() + b * srcLen * kvCols * 2 + seq * kvCols * 2 + i - qCols; + bfloat16_t::cvt_float_to_bfloat16(srcPtr, dstPtr, headSize); + } - float *query = transQKV; - float *key = transQKV + batchSize * respQHeads * srcLen * headSize; - float *value = transQKV + batchSize * (respQHeads + respKVHeads) * srcLen * headSize; + k = (AttnT*)tmpBuf.Data(); + v = (AttnT*)tmpBuf.Data() + kvCols; + } else { + k = qkvMatMul.Data() + respQHeads * headSize; + v = qkvMatMul.Data() + (respQHeads + respKVHeads) * headSize; + } - scaledDpAttention(query, key, value, attnMask, scale, batchSize, srcLen, tgtLen, respQHeads, respKVHeads, - headSize, tmpRes.Data()); - DecoderUtil::transposeAttnResult( - tmpRes.Data(), result.Data(), batchSize, srcLen, respQHeads, headSize, result.Stride()); + float *query = qkvMatMul.Data(); + // [batch, src, head, headsize] + scaledDpAttention(query, k, v, attnMask, scale, batchSize, srcLen, tgtLen, respQHeads, + respKVHeads, headSize, result.Data(), qkvCols, kvCols * 2, ctx->hiddenSize); + float *key = qkvMatMul.Data() + respQHeads * headSize; + float *value = qkvMatMul.Data() + (respQHeads + respKVHeads) * headSize; // For group attention, as #kvHeads != #qHeads, need to copy current key/values to cache seperately // When M dimension is split, also multiple tasks per copy, so do copy seperately #pragma omp parallel for collapse(3) @@ -739,10 +755,10 @@ class Attention { // Re-layout is needed: (bs, seq=1, hidden_size) -> (seq=1, bs, hidden_size) // Be noted: for group attention, the key/value is less than query for (int seq = 0; seq < tgtLen; ++seq) { - auto srcK = key + b * respKVHeads * tgtLen * headSize + i * tgtLen * headSize + seq * headSize; + auto srcK = key + b * tgtLen * qkvCols + seq * qkvCols + i * headSize; auto dstK = 
presentKey.getSequence(pastSeqLen + seq, b, i); - auto srcV = value + b * respKVHeads * tgtLen * headSize + i * tgtLen * headSize + seq * headSize; + auto srcV = value + b * tgtLen * qkvCols + seq * qkvCols + i * headSize; auto dstV = presentValue.getSequence(pastSeqLen + seq, b, i); if constexpr (std::is_same_v) { @@ -755,31 +771,43 @@ class Attention { } } } - free(transQKV); } // scaled dot-product attention: bmm1 + softmax + bmm2 - void scaledDpAttention(const float *query, const float *key, const float *value, const float *attnMask, float scale, - int batchSize, int srcLen, int tgtLen, int numQHead, int numKVHead, int headSize, float *output) { + template + void scaledDpAttention(const float *query, const AttnT *key, const AttnT *value, const float *attnMask, + float scale, int batchSize, int srcLen, int tgtLen, int numQHead, int numKVHead, int headSize, + float *output, int qStride, int kvStride, int stride) { // output = trans(softmax(query * trans(key)) * value) int nth = omp_get_max_threads(); - int minBlk = (nth >= batchSize * numQHead ? 256 : 512); - int srcBlk = std::min(minBlk, srcLen); - int tgtBlk = std::min(minBlk, tgtLen); + int minBlk = (int)std::pow(2, int(std::log2(srcLen / 2))); + int srcBlk = std::min(256, minBlk); + int tgtBlk = std::min(512, tgtLen); float refac = scale; int numGroup = numQHead / numKVHead; - int numArr = 6; - int arrStride = (4 + tgtBlk + headSize) * srcBlk; - float *thrBuf = (float *)malloc(sizeof(float) * nth * arrStride); - float **thrPtrBuf = (float **)malloc(sizeof(float *) * nth * numArr); - + int numArr = 7; + int arrStride = (4 + tgtBlk + 2 * headSize) * srcBlk; + float *thrBuf; + float **thrPtrBuf; float **preSum = thrPtrBuf; float **sum = thrPtrBuf + nth; float **preMax = thrPtrBuf + nth * 2; float **max = thrPtrBuf + nth * 3; float **qkArr = thrPtrBuf + nth * 4; float **expQkvArr = thrPtrBuf + nth * 5; + float **qArr = thrPtrBuf + nth * 6; + + thrBuf = (float *)malloc(sizeof(float) * nth * arrStride); + thrPtrBuf = (float **)malloc(sizeof(float *) * nth * numArr); + + preSum = thrPtrBuf; + sum = thrPtrBuf + nth; + preMax = thrPtrBuf + nth * 2; + max = thrPtrBuf + nth * 3; + qkArr = thrPtrBuf + nth * 4; + expQkvArr = thrPtrBuf + nth * 5; + qArr = thrPtrBuf + nth * 6; for (int i = 0; i < nth; ++i) { preSum[i] = thrBuf + srcBlk * i; sum[i] = thrBuf + srcBlk * nth + srcBlk * i; @@ -787,6 +815,7 @@ class Attention { max[i] = thrBuf + srcBlk * nth * 3 + srcBlk * i; qkArr[i] = thrBuf + srcBlk * nth * 4 + srcBlk * tgtBlk * i; expQkvArr[i] = thrBuf + srcBlk * nth * (4 + tgtBlk) + srcBlk * headSize * i; + qArr[i] = thrBuf + srcBlk * nth * (4 + tgtBlk + headSize) + srcBlk * headSize * i; } #pragma omp parallel for collapse(3) @@ -794,21 +823,22 @@ class Attention { for (int j = 0; j < numQHead; ++j) { for (int m = 0; m < srcLen; m += srcBlk) { int tid = omp_get_thread_num(); - int tgtOff = i * numKVHead * tgtLen * headSize + (j / numGroup) * tgtLen * headSize; - const float *k = key + tgtOff; - const float *v = value + tgtOff; - const float *attnMsk = getMask(attnMask, i, j, srcLen, tgtLen) + m * tgtLen; int qRealBlk = std::min(srcBlk, srcLen - m); - int srcOff = i * numQHead * tgtLen * headSize + j * tgtLen * headSize; - const float *q = query + srcOff + m * headSize; - float *out = output + srcOff + m * headSize; + int srcOff = + i * srcLen * qStride + j * headSize; + int outOff = + i * srcLen * stride + j * headSize; + const float *qbuf = query + srcOff + m * qStride; + AttnT* q = (AttnT*)qArr[tid]; + float *out = output + outOff + m * stride; 
// reset out for (int ii = 0; ii < qRealBlk; ++ii) { #pragma omp simd for (int jj = 0; jj < headSize; ++jj) { - out[ii * headSize + jj] = 0; // reset output + out[ii * stride + jj] = 0; // reset output + q[ii * headSize + jj] = (AttnT)(qbuf[ii * qStride + jj]); // reset output } } // reset sum @@ -819,16 +849,22 @@ class Attention { preMax[tid][ii] = std::numeric_limits::lowest(); max[tid][ii] = std::numeric_limits::lowest(); } + + int tgtOff = + i * tgtLen * kvStride + (j / numGroup) * headSize; + const float *attnMsk = getMask(attnMask, i, j, srcLen, tgtLen) + m * tgtLen; + const AttnT* k = key + tgtOff; + const AttnT* v = value + tgtOff; // split the target len dimension for (int b = 0; b < tgtLen; b += tgtBlk) { int kvRealBlk = std::min(tgtBlk, tgtLen - b); // TODO: mask out - const float *kBlk = k + b * headSize; - const float *vBlk = v + b * headSize; + const AttnT* kBlk = k + b * kvStride; + const AttnT* vBlk = v + b * kvStride; DecoderUtil::incrementalTileAttention(q, kBlk, vBlk, attnMsk + b, qRealBlk, headSize, kvRealBlk, tgtLen, preSum[tid], sum[tid], preMax[tid], max[tid], refac, qkArr[tid], expQkvArr[tid], - out); + out, headSize, kvStride, kvStride, stride); } } } @@ -878,7 +914,8 @@ class Attention { } virtual const float *getMask(const float *attnMask, int bId, int hId, int srcLen, int tgtLen) { - return attnMask + bId * srcLen * tgtLen; + return attnMask; + //return attnMask + bId * srcLen * tgtLen; } // query, key, value weighs diff --git a/src/layers/mlp_chatglm2.h b/src/layers/mlp_chatglm2.h index e384a28f..ec0bea9a 100644 --- a/src/layers/mlp_chatglm2.h +++ b/src/layers/mlp_chatglm2.h @@ -37,37 +37,68 @@ class ChatGLM2MLP : public LlamaMLP { auto range = SplitUtil::getTaskRange(intermediateSize, ctx->numSplit, ctx->splitIdx); int colSplit = range.second - range.first; - float *gateW = (float *)malloc(hiddenSize * colSplit * sizeof(float)); - float *upW = (float *)malloc(hiddenSize * colSplit * sizeof(float)); - if (trans) { - int blockSize = colSplit * hiddenSize; - memcpy(gateW, gate_upW + ctx->splitIdx * blockSize, blockSize * sizeof(float)); - memcpy(upW, gate_upW + intermediateSize * hiddenSize + ctx->splitIdx * blockSize, - blockSize * sizeof(float)); - } else { - const float *weightPTR = gate_upW; - for (int i = 0; i < hiddenSize; i++) { - memcpy(gateW + i * colSplit, weightPTR + ctx->splitIdx * colSplit, colSplit * sizeof(float)); - weightPTR += intermediateSize; - memcpy(upW + i * colSplit, weightPTR + ctx->splitIdx * colSplit, colSplit * sizeof(float)); - weightPTR += intermediateSize; + + int enable = (getenv("ENABLE_CAT_MLP") ? 
atoi(getenv("ENABLE_CAT_MLP")) : 1); + if (enable == 0) { + float *gateW = (float *)malloc(hiddenSize * colSplit * sizeof(float)); + float *upW = (float *)malloc(hiddenSize * colSplit * sizeof(float)); + if (trans) { + int blockSize = colSplit * hiddenSize; + memcpy(gateW, gate_upW + ctx->splitIdx * blockSize, blockSize * sizeof(float)); + memcpy(upW, gate_upW + intermediateSize * hiddenSize + ctx->splitIdx * blockSize, + blockSize * sizeof(float)); + } else { + const float *weightPTR = gate_upW; + for (int i = 0; i < hiddenSize; i++) { + memcpy(gateW + i * colSplit, weightPTR + ctx->splitIdx * colSplit, colSplit * sizeof(float)); + weightPTR += intermediateSize; + memcpy(upW + i * colSplit, weightPTR + ctx->splitIdx * colSplit, colSplit * sizeof(float)); + weightPTR += intermediateSize; + } } - } - MMHelper::convertWeight( - trans, hiddenSize, colSplit, gateW, convertedGateWeight, this->gateWeightScale, this->gateWeightZero); - MMHelper::packWeight(trans, convertedGateWeight, this->gateWeight); + MMHelper::convertWeight( + trans, hiddenSize, colSplit, gateW, convertedGateWeight, this->gateWeightScale, this->gateWeightZero); + MMHelper::packWeight(trans, convertedGateWeight, this->gateWeight); - MMHelper::convertWeight(trans, hiddenSize, colSplit, upW, convertedUpWeight, this->upWeightScale, this->upWeightZero); - MMHelper::packWeight(trans, convertedUpWeight, this->upWeight); + MMHelper::convertWeight(trans, hiddenSize, colSplit, upW, convertedUpWeight, this->upWeightScale, this->upWeightZero); + MMHelper::packWeight(trans, convertedUpWeight, this->upWeight); - free(gateW); - free(upW); + free(gateW); + free(upW); + } else { + if (trans) { + printf("Trans GateUpW Not supported yet.\n"); + exit(-1); + } else { + int colSplitStride = colSplit * 2; + float *gateUpW = (float *)malloc(hiddenSize * colSplitStride * sizeof(float)); + const float *weightPTR = gate_upW; + for (int i = 0; i < hiddenSize; i++) { + memcpy(gateUpW + i * colSplitStride, weightPTR + ctx->splitIdx * colSplit, colSplit * sizeof(float)); + weightPTR += intermediateSize; + memcpy(gateUpW + colSplit + i * colSplitStride, weightPTR + ctx->splitIdx * colSplit, colSplit * sizeof(float)); + weightPTR += intermediateSize; + } + hpj::Matrix quantizedCatWeights; + MMHelper::convertWeight( + trans, hiddenSize, colSplitStride, gateUpW, quantizedCatWeights, this->catWeightsScale, this->catWeightsZero); + this->catWeights.Resize(quantizedCatWeights.Rows(), quantizedCatWeights.Cols()); + MMHelper::packWeight(trans, quantizedCatWeights, this->catWeights); + free(gateUpW); + } + } // Horizontally split the down weight - MMHelper::convertWeight(ctx, trans, intermediateSize, hiddenSize, downW, false, convertedDownWeight, - this->downWeightScale, this->downWeightZero); - MMHelper::packWeight(trans, convertedDownWeight, this->downWeight); + enable = (getenv("ENABLE_CBLAS_MLP") ? 
atoi(getenv("ENABLE_CBLAS_MLP")) : 0); + if (enable && std::is_same_v) { + MMHelper::convertWeight(ctx, trans, intermediateSize, hiddenSize, downW, false, this->downWeight, + this->downWeightScale, this->downWeightZero); + } else { + MMHelper::convertWeight(ctx, trans, intermediateSize, hiddenSize, downW, false, convertedDownWeight, + this->downWeightScale, this->downWeightZero); + MMHelper::packWeight(trans, convertedDownWeight, this->downWeight); + } #ifdef DEBUG this->dbg.debugPrint("convertedGateWeight [%d, %d](%d):\n", convertedGateWeight.Rows(), convertedGateWeight.Cols(), convertedGateWeight.Stride()); diff --git a/src/layers/mlp_llama.h b/src/layers/mlp_llama.h index b0cc5055..3d77aca5 100644 --- a/src/layers/mlp_llama.h +++ b/src/layers/mlp_llama.h @@ -54,22 +54,39 @@ class LlamaMLP : public SingletonBase> { hpj::Matrix quantizedGateWeight, quantizedUpWeight, quantizedDownWeight; auto it = SplitUtil::getTaskRange(imSize, ctx->numSplit, ctx->splitIdx); - gateWeight.Resize(hiddenSize, it.second - it.first); - upWeight.Resize(hiddenSize, it.second - it.first); downWeight.Resize(it.second - it.first, hiddenSize); MMHelper::convertWeight( ctx, trans, hiddenSize, imSize, gateW, true, quantizedGateWeight, gateWeightScale, gateWeightZero); - MMHelper::packWeight(trans, quantizedGateWeight, gateWeight); - MMHelper::convertWeight( ctx, trans, hiddenSize, imSize, upW, true, quantizedUpWeight, upWeightScale, upWeightZero); - MMHelper::packWeight(trans, quantizedUpWeight, upWeight); + + int enable = (getenv("ENABLE_CAT_MLP") ? atoi(getenv("ENABLE_CAT_MLP")) : 1); + if (enable == 0) { + gateWeight.Resize(hiddenSize, it.second - it.first); + upWeight.Resize(hiddenSize, it.second - it.first); + MMHelper::packWeight(trans, quantizedGateWeight, gateWeight); + MMHelper::packWeight(trans, quantizedUpWeight, upWeight); + } else { + hpj::Matrix quantizedCatWeights; + catGateUpWeights(quantizedGateWeight, quantizedUpWeight, gateWeightScale, gateWeightZero, + upWeightScale, upWeightZero, quantizedCatWeights, catWeightsScale, catWeightsZero); + quantizedGateWeight.Release(); + quantizedUpWeight.Release(); + catWeights.Resize(quantizedCatWeights.Rows(), quantizedCatWeights.Cols()); + MMHelper::packWeight(trans, quantizedCatWeights, catWeights); + } // Horizontally split the down weight - MMHelper::convertWeight( + enable = (getenv("ENABLE_CBLAS_MLP") ? atoi(getenv("ENABLE_CBLAS_MLP")) : 0); + if (enable && std::is_same_v) { + MMHelper::convertWeight( + ctx, trans, imSize, hiddenSize, downW, false, downWeight, downWeightScale, downWeightZero); + } else { + MMHelper::convertWeight( ctx, trans, imSize, hiddenSize, downW, false, quantizedDownWeight, downWeightScale, downWeightZero); - MMHelper::packWeight(trans, quantizedDownWeight, downWeight); + MMHelper::packWeight(trans, quantizedDownWeight, downWeight); + } #ifdef DEBUG dbg.debugPrint("quantizedGateWeight:\n"); @@ -103,7 +120,6 @@ class LlamaMLP : public SingletonBase> { hpj::Matrix inBuffer(input, M, hiddenSize, iStride); hpj::Matrix outBuffer(output, M, hiddenSize, oStride); auto &normBuffer = ctx->normBuf; - auto &imBuffer = ctx->imOut; if (doLnBefore == true) { DecoderUtil::rmsNorm(inBuffer, normBuffer, normWeight, 1e-6); @@ -114,25 +130,41 @@ class LlamaMLP : public SingletonBase> { dbg.dumpMatrix(normBuffer); #endif - gateProj(doLnBefore ? normBuffer : inBuffer, imBuffer); + int enable = (getenv("ENABLE_CAT_MLP") ? atoi(getenv("ENABLE_CAT_MLP")) : 1); + if (enable == 0) { + auto &imBuffer = ctx->imOut; + gateProj(doLnBefore ? 
normBuffer : inBuffer, imBuffer); #ifdef DEBUG - dbg.debugPrint("gateWeight:\n"); - dbg.dumpMatrix(gateWeight); - dbg.debugPrint("gate output:\n"); - dbg.dumpMatrix(imBuffer); + dbg.debugPrint("gateWeight:\n"); + dbg.dumpMatrix(gateWeight); + dbg.debugPrint("gate output:\n"); + dbg.dumpMatrix(imBuffer); #endif - upProj(doLnBefore ? normBuffer : inBuffer, imBuffer); + upProj(doLnBefore ? normBuffer : inBuffer, imBuffer); #ifdef DEBUG - dbg.debugPrint("upWeight:\n"); - dbg.dumpMatrix(upWeight); - dbg.debugPrint("up output:\n"); - dbg.dumpMatrix(imBuffer); + dbg.debugPrint("upWeight:\n"); + dbg.dumpMatrix(upWeight); + dbg.debugPrint("up output:\n"); + dbg.dumpMatrix(imBuffer); #endif + downProj(imBuffer, outBuffer, inBuffer, ctx->splitIdx == 0); + + } else { + hpj::Matrix imBuffer(ctx->imOut.Data(), normBuffer.Rows(), catWeights.Cols(), + catWeights.Cols()); + catGateUpProj(doLnBefore ? normBuffer : inBuffer, imBuffer); - downProj(imBuffer, outBuffer, inBuffer, ctx->splitIdx == 0); +#ifdef DEBUG + dbg.debugPrint("catWeights:\n"); + dbg.dumpMatrix(catWeights); + dbg.debugPrint("gateUp output:\n"); + dbg.dumpMatrix(imBuffer); +#endif + downProj(imBuffer, outBuffer, inBuffer, ctx->splitIdx == 0); + } #ifdef DEBUG dbg.debugPrint("downWeight:\n"); @@ -191,7 +223,7 @@ class LlamaMLP : public SingletonBase> { assert(input.Cols() == downWeight.Rows()); assert(downWeight.Cols() == output.Cols()); - int M = input.Rows(), N = output.Cols(), K = input.Cols(); + int M = input.Rows(), N = output.Cols(), K = downWeight.Rows(); int lda = input.Stride(), ldc = output.Stride(), ldr = residential.Stride(); const float *A = input.Data(); @@ -202,10 +234,90 @@ class LlamaMLP : public SingletonBase> { const float *R = residential.Data(); if (isMaster) { - MMHelper::compute_residential(false, M, N, K, 1.0f, A, lda, B, scaleB, zeroB, 0.0f, C, ldc, NULL, R, ldr); + int enable = (getenv("ENABLE_CBLAS_MLP") ? atoi(getenv("ENABLE_CBLAS_MLP")) : 0); + if (enable && std::is_same_v) { + compute_proj_bf16(A, B, C, M, N, K, lda, ldc, ldc, R, ldr); + } else { + MMHelper::compute_residential(false, M, N, K, 1.0f, A, lda, B, scaleB, zeroB, 0.0f, C, ldc, NULL, R, ldr); + } } else { - MMHelper::compute(false, M, N, K, 1.0f, A, lda, B, scaleB, zeroB, 0.0f, C, ldc); + int enable = (getenv("ENABLE_CBLAS_MLP") ? 
atoi(getenv("ENABLE_CBLAS_MLP")) : 0); + if (enable && std::is_same_v) { + compute_proj_bf16(A, B, C, M, N, K, lda, ldc, ldc, nullptr, 0); + } else { + MMHelper::compute(false, M, N, K, 1.0f, A, lda, B, scaleB, zeroB, 0.0f, C, ldc); + } + } + } + + void compute_proj_bf16(const float *A, const WeiT *B, float *C, int M, int N, int K, int lda, int ldb, int ldc, + const float *R, int ldr) { + int alpha = 1.0; + int beta = 0.0; + if (R != nullptr) { +#pragma omp parallel for + for (int i = 0; i < M; ++i) { + memcpy(C + i * ldc, R + i * ldr, N * sizeof(float)); + } + beta = 1.0; + } + int ldaH = lda * 2; +#pragma omp parallel for + for (int i = 0; i < M; ++i) { + bfloat16_t::cvt_float_to_bfloat16(A + i * lda, (bfloat16_t *)A + i * ldaH, K); } + cblas_gemm_bf16bf16f32( + CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, alpha, (const MKL_BF16 *)(A), ldaH, + (const MKL_BF16 *)(B), ldb, beta, C, ldc); + } + + void catGateUpProj(hpj::Matrix &input, hpj::Matrix &output) { + TimeLine t("catGateUpProj"); + + assert(input.Rows() == output.Rows()); + assert(input.Cols() == catWeights.Rows()); + assert(catWeights.Cols() == output.Cols()); + + int M = input.Rows(), N = output.Cols(), K = input.Cols(); + int lda = input.Stride(), ldc = output.Stride(); + + const float *A = input.Data(); + const WeiT *B = catWeights.Data(); + const float *scaleB = catWeightsScale.Data(); + const float *zeroB = catWeightsZero.Data(); + float *C = output.Data(); + + MMHelper::compute(false, M, N, K, 1.0f, A, lda, B, scaleB, zeroB, 0.0f, C, ldc); + // compute silu on the left half and then add it with the right half + DecoderUtil::siluSum(output); + } + + void catGateUpWeights(hpj::Matrix &gateWeight, hpj::Matrix &upWeight, + hpj::Vector &gateWeightScale, hpj::Vector &gateWeightZero, + hpj::Vector &upWeightScale, hpj::Vector &upWeightZero, + hpj::Matrix &catWeights, + hpj::Vector &catWeightsScale, hpj::Vector &catWeightsZero) { + catWeights.Resize(gateWeight.Rows(), gateWeight.Cols() + upWeight.Cols()); + catWeightsScale.Resize(gateWeightScale.Size() + upWeightScale.Size()); + catWeightsZero.Resize(gateWeightZero.Size() + upWeightZero.Size()); + + int M = catWeights.Rows(); + int Stride = catWeights.Cols(); + int N = gateWeight.Cols(); +#pragma omp parallel for + for (int i = 0; i < M; ++i) { + memcpy(catWeights.Data() + i * Stride, gateWeight.Data() + i * N, + N * sizeof(WeiT)); + memcpy(catWeights.Data() + i * Stride + N, upWeight.Data() + i * N, + N * sizeof(WeiT)); + } + + M = gateWeightScale.Size(); + N = upWeightScale.Size(); + memcpy(catWeightsScale.Data(), gateWeightScale.Data(), M * sizeof(float)); + memcpy(catWeightsScale.Data() + M, upWeightScale.Data(), N * sizeof(float)); + memcpy(catWeightsZero.Data(), gateWeightZero.Data(), M * sizeof(float)); + memcpy(catWeightsZero.Data() + M, upWeightZero.Data(), N * sizeof(float)); } protected: @@ -215,6 +327,9 @@ class LlamaMLP : public SingletonBase> { hpj::Matrix upWeight; hpj::Vector upWeightScale; // For int8_t weight hpj::Vector upWeightZero; // For int8_t weight + hpj::Matrix catWeights; + hpj::Vector catWeightsScale; // For int8_t weight + hpj::Vector catWeightsZero; // For int8_t weight hpj::Matrix downWeight; hpj::Vector downWeightScale; // For int8_t weight hpj::Vector downWeightZero; // For int8_t weight diff --git a/src/utils/decoder_util.h b/src/utils/decoder_util.h index 98300796..764b2f8c 100644 --- a/src/utils/decoder_util.h +++ b/src/utils/decoder_util.h @@ -24,6 +24,8 @@ #include "my_types.h" #include "timeline.h" #include "transformer_ctx.h" +#include 
"xdnn.h" +#include class DecoderUtil { public: @@ -500,79 +502,90 @@ class DecoderUtil { } } - // batchs x seqlen x 3 x head x heads -> 3 x batchs x head x seqlen x heads (2 - // 0 3 1 4) - template - static void transposeQKV(const T *qkvBuffer, Tt *qkvTransBuffer, int batchSize, int seqLen, int headQNum, - int headKVNum, int headSize) { - int hiddenQSize = headQNum * headSize; - int hiddenKVSize = headKVNum * headSize; - int hiddenQKVSize = hiddenQSize + hiddenKVSize * 2; - - int blockSize = hiddenQKVSize * seqLen; - - const T *qBuffer = qkvBuffer; - const T *kBuffer = qkvBuffer + hiddenQSize; - const T *vBuffer = qkvBuffer + hiddenQSize + hiddenKVSize; - - Tt *qTransBuffer = qkvTransBuffer; - Tt *kTransBuffer = qkvTransBuffer + batchSize * hiddenQSize * seqLen; - Tt *vTransBuffer = qkvTransBuffer + batchSize * (hiddenQSize + hiddenKVSize) * seqLen; - -#pragma omp parallel for collapse(3) - for (int i = 0; i < batchSize; i++) { - for (int k = 0; k < headQNum; k++) { // assume headQNum >= headKVNum - for (int j = 0; j < seqLen; j++) { - const float *qSrcEachBatch = reinterpret_cast(qBuffer) + blockSize * i; - const float *kSrcEachBatch = reinterpret_cast(kBuffer) + blockSize * i; - const float *vSrcEachBatch = reinterpret_cast(vBuffer) + blockSize * i; - - int dstOffEachHead = k * seqLen * headSize; - int srcOffEachLine = k * headSize; - - int dstOffEachLine = j * headSize; - int srcOffEachHead = j * hiddenQKVSize; - - Tt *qDstEachLine = qTransBuffer + i * hiddenQSize * seqLen + dstOffEachHead + dstOffEachLine; - const T *qSrcEachLine = qSrcEachBatch + srcOffEachHead + srcOffEachLine; - - Tt *kDstEachLine = kTransBuffer + i * hiddenKVSize * seqLen + dstOffEachHead + dstOffEachLine; - const T *kSrcEachLine = kSrcEachBatch + srcOffEachHead + srcOffEachLine; - - Tt *vDstEachLine = vTransBuffer + i * hiddenKVSize * seqLen + dstOffEachHead + dstOffEachLine; - const T *vSrcEachLine = vSrcEachBatch + srcOffEachHead + srcOffEachLine; - arrayCpy(qDstEachLine, qSrcEachLine, headSize); - if (k < headKVNum) { - arrayCpy(kDstEachLine, kSrcEachLine, headSize); - arrayCpy(vDstEachLine, vSrcEachLine, headSize); - } - } - } - } + template + static void single_thread_cvt2bf16_inplace(T *buf, int m, int n, int stride) { + if (!std::is_same_v) + for (int i = 0; i < m; ++i) + bfloat16_t::cvt_float_to_bfloat16(buf + i * stride, (bfloat16_t *)buf + i * stride, n); } - // batchs x head x seqlen x heads -> batchs x seqlen x head x heads (0 2 1 3) - template - static void transposeAttnResult( - T *Buffer, Tt *TransBuffer, int batchSize, int seqLen, int headNum, int headSize, int dstStride) { - int hiddenSize = headNum * headSize; - int blockSize = seqLen * hiddenSize; // dst buffer stride in each batch - -#pragma omp parallel for collapse(2) - for (int i = 0; i < batchSize; i++) { - for (int k = 0; k < seqLen; k++) { - int srcOffEachHead = k * headSize; - int dstOffEachLine = k * dstStride; - - for (int j = 0; j < headNum; j++) { - int srcOffEachLine = j * seqLen * headSize; - int dstOffEachHead = j * headSize; + static inline __m512 dilExpKernel(__m512 vecSrc) { + static __m512 vecFac1 = _mm512_set1_ps(0.999999701f); // 1/factorial(1) + static __m512 vecFac2 = _mm512_set1_ps(0.499991506f); // 1/factorial(2) + static __m512 vecFac3 = _mm512_set1_ps(0.166676521f); // 1/factorial(3) + static __m512 vecFac4 = _mm512_set1_ps(0.0418978221f); // 1/factorial(4) + static __m512 vecFac5 = _mm512_set1_ps(0.00828929059f); // 1/factorial(5) + static __m512 vecExpLog2ef = (__m512)_mm512_set1_epi32(0x3fb8aa3b); // log2(e) + 
static __m512 vecHalf = _mm512_set1_ps(0.5f); + static __m512 vecOne = _mm512_set1_ps(1.f); + static __m512 vecZero = _mm512_set1_ps(0.f); + static __m512 vecTwo = _mm512_set1_ps(2.f); + static __m512 vecLn2f = (__m512)_mm512_set1_epi32(0x3f317218); // ln(2) + static __m512 vecLnFltMin = (__m512)_mm512_set1_epi32(0xc2aeac50); + static __m512 vecLnFltMax = (__m512)_mm512_set1_epi32(0x42b17218); + static __m512i vec127 = _mm512_set1_epi32(0x0000007f); + static int nMantissaBits = 23; + + // exp(x) = + // = exp(n * ln(2) + r) // divide x by ln(2) and get quot and rem + // = 2^n * exp(r) // simplify the exp(n*ln(2)) expression + + auto lessLnFltMinMask = + _mm512_cmp_ps_mask(vecSrc, vecLnFltMin, 1 /*_CMP_LT_OS*/); + vecSrc = _mm512_min_ps(vecSrc, vecLnFltMax); + vecSrc = _mm512_max_ps(vecSrc, vecLnFltMin); + + // fx = floorf(x * log2ef + 0.5) + auto vecFx = _mm512_fmadd_ps(vecSrc, vecExpLog2ef, vecHalf); + auto vecFx_i = _mm512_cvt_roundps_epi32(vecFx, _MM_FROUND_TO_NEG_INF | + _MM_FROUND_NO_EXC); + vecFx = _mm512_cvtepi32_ps(vecFx_i); + + // x = x - fx * ln2 + auto vecExpPoly = _mm512_fnmadd_ps(vecFx, vecLn2f, vecSrc); + + // compute polynomial + auto vecRes = + _mm512_fmadd_ps(vecExpPoly, vecFac5, vecFac4); + vecRes = _mm512_fmadd_ps(vecExpPoly, vecRes, vecFac3); + vecRes = _mm512_fmadd_ps(vecExpPoly, vecRes, vecFac2); + vecRes = _mm512_fmadd_ps(vecExpPoly, vecRes, vecFac1); + vecRes = _mm512_fmadd_ps(vecExpPoly, vecRes, vecOne); + // compute 2^(n-1) + auto vecExpNumber = _mm512_sub_ps(vecFx, vecOne); + auto vecExpNumber_i = _mm512_cvtps_epi32(vecExpNumber); + auto vecTwoPowN_i = _mm512_add_epi32(vecExpNumber_i, vec127); + vecTwoPowN_i = _mm512_slli_epi32(vecTwoPowN_i, nMantissaBits); + auto vecTwoPowN = (__m512)vecTwoPowN_i; + vecTwoPowN = + _mm512_mask_blend_ps(lessLnFltMinMask, vecTwoPowN, vecZero); + + // y = y * 2^n + vecRes = _mm512_mul_ps(vecRes, vecTwoPowN); + vecRes = _mm512_mul_ps(vecRes, vecTwo); + return vecRes; + } - Tt *qDstEachLine = TransBuffer + dstOffEachHead + dstOffEachLine + i * seqLen * dstStride; - const T *qSrcEachLine = Buffer + srcOffEachLine + srcOffEachHead + i * blockSize; + // compute silu on the left half and then add it with the right half + static void siluSum(hpj::Matrix &src) { + __m512 one = _mm512_set1_ps(1.f); + __m512 negOne = _mm512_set1_ps(-1.f); + int M = src.Rows(); + int stride = src.Cols(); + int N = stride / 2; - arrayCpy(qDstEachLine, qSrcEachLine, headSize); - } +#pragma omp parallel for collapse(2) + for (int64_t i = 0; i < M; ++i) { + for (int64_t j = 0; j < N; j += 16) { + int remain = N - j; + __mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1); + auto left = _mm512_maskz_loadu_ps(mask, src.Data() + i * stride + j); + auto right = _mm512_maskz_loadu_ps(mask, src.Data() + i * stride + j + N); + auto x0 = dilExpKernel(_mm512_mul_ps(left, negOne)); + auto x1 = _mm512_add_ps(one, x0); + auto x2 = _mm512_div_ps(left, x1); + auto res = _mm512_mul_ps(right, x2); + _mm512_mask_storeu_ps(src.Data() + i * stride + j, mask, res); } } } @@ -580,18 +593,32 @@ class DecoderUtil { // C = A * B // bTranspose: B need to be transposed or not // xdnn_sgemm_single_thread(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); - static void sgemm(const float *A, const float *B, float *C, int m, int n, int k, bool transa, bool transb) { - int lda = (transa ? m : k); - int ldb = (transb ? 
k : n); - int ldc = n; + template + static void sgemm(const T *A, const T *B, float *C, int m, int n, int k, + int lda, int ldb, int ldc, bool transa, bool transb) { float alpha = 1; float beta = 0; - char ta[] = "N"; - char tb[] = "N"; - if (transa) ta[0] = 'T'; - if (transb) tb[0] = 'T'; - dnnl_sgemm(ta[0], tb[0], m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + if constexpr (std::is_same_v) { + char ta[] = "N"; + char tb[] = "N"; + if (transa) ta[0] = 'T'; + if (transb) tb[0] = 'T'; + + dnnl_sgemm(ta[0], tb[0], m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + + } else if (std::is_same_v) { + CBLAS_TRANSPOSE ta, tb; + ta = transa? CblasTrans : CblasNoTrans; + tb = transb? CblasTrans : CblasNoTrans; + + cblas_gemm_bf16bf16f32( + CblasRowMajor, ta, tb, m, n, k, alpha, (const MKL_BF16 *)(A), lda, + (const MKL_BF16 *)(B), ldb, beta, C, ldc); + } else { + printf("Datatype Not supported yet\n"); + exit(-1); + } } // need to do for res. @@ -653,11 +680,11 @@ class DecoderUtil { } } - static void updateOutTile( - float *output, const float *expABC, float *preSum, float *sum, float *preMax, float *max, int m, int n) { + static void updateOutTile(float *output, const float *expABC, float *preSum, float *sum, float *preMax, + float *max, int m, int n, int stride) { for (int i = 0; i < m; ++i) { const float *buf = expABC + i * n; - float *outbuf = output + i * n; + float *outbuf = output + i * stride; __m512 merr = _mm512_set1_ps(preMax[i] - max[i]); merr = BertUtil::vexp(merr); __m512 vfac = _mm512_set1_ps(preSum[i] / sum[i]); @@ -678,12 +705,16 @@ class DecoderUtil { // sum += sum(exp(A[i])) // output = output * preSum / sum + (exp(A) / sum) x B // preSum = sum - static void incrementalTileAttention(const float *A, const float *B, const float *C, const float *attnMask, int m, - int n, int k, int attnMskStride, float *preSum, float *sum, float *preMax, float *max, float refac, - float *AB, float *expABC, float *output) { - sgemm(A, B, AB, m, k, n, false, true); + template + static void incrementalTileAttention(const T *A, const T *B, const T *C, const float *attnMask, + int m, int n, int k, int attnMskStride, float *preSum, float *sum, float *preMax, float *max, + float refac, float *AB, float *expABC, float *output, int qStride, int kStride, int vStride, int stride) { + sgemm(A, B, AB, m, k, n, qStride, kStride, k, false, true); softmaxTile(AB, sum, max, preSum, preMax, refac, attnMask, m, k, attnMskStride); - sgemm(AB, C, expABC, m, n, k, false, false); - updateOutTile(output, expABC, preSum, sum, preMax, max, m, n); + + single_thread_cvt2bf16_inplace(AB, m, k, k); + sgemm((T*)AB, C, expABC, m, n, k, k, vStride, n, false, false); + updateOutTile(output, expABC, preSum, sum, preMax, max, m, n, stride); } + };
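
Two illustrative sketches of the math this patch implements follow. The function and variable names in them are placeholders, not the library's APIs; the real code paths go through MMHelper / DecoderUtil with packed (and possibly quantized) weights and AVX-512 / BF16 GEMMs.

First, the "concat-SiLU MLP": gate and up weights are column-concatenated so a single GEMM produces [x*Wgate | x*Wup], and DecoderUtil::siluSum folds silu(gate) * up into the left half of the buffer, which then feeds the down projection. A minimal scalar reference, assuming plain row-major float buffers and toy sizes:

// Illustrative scalar reference (not the library API): fused gate/up projection
// followed by the siluSum step, i.e. out[:, :N] = silu(x*Wgate) * (x*Wup).
#include <cmath>
#include <cstdio>
#include <vector>

// in:   M x K, row-major
// catW: K x 2N, row-major; columns 0..N-1 = Wgate, columns N..2N-1 = Wup
// out:  M x 2N; after the call the left N columns hold silu(gate) * up
static void catGateUpProjRef(const std::vector<float> &in, const std::vector<float> &catW,
                             std::vector<float> &out, int M, int K, int N) {
    for (int i = 0; i < M; ++i)
        for (int j = 0; j < 2 * N; ++j) {
            float acc = 0.f;
            for (int k = 0; k < K; ++k)
                acc += in[i * K + k] * catW[k * 2 * N + j];
            out[i * 2 * N + j] = acc;
        }
    for (int i = 0; i < M; ++i)
        for (int j = 0; j < N; ++j) {
            float gate = out[i * 2 * N + j];
            float up = out[i * 2 * N + N + j];
            out[i * 2 * N + j] = gate / (1.f + std::exp(-gate)) * up; // silu(gate) * up
        }
}

int main() {
    int M = 2, K = 4, N = 3; // toy sizes only
    std::vector<float> in(M * K, 0.5f), catW(K * 2 * N, 0.1f), out(M * 2 * N, 0.f);
    catGateUpProjRef(in, catW, out, M, K, N);
    printf("fused[0][0] = %f\n", out[0]); // only the left N columns feed the down projection
    return 0;
}

Second, the BF16-GEMM MHA path: in my reading, incrementalTileAttention applies the streaming-softmax ("flash attention") recurrence over key/value tiles, keeping a running max, running sum, and already-normalized output per query row, so the full score matrix is never materialized and the per-thread buffers in scaledDpAttention only need to hold one srcBlk x tgtBlk tile. A scalar model of that recurrence for a single row, again with placeholder names:

// Illustrative scalar model of the streaming-softmax update applied per tile.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

struct RunningRow {
    float m = -std::numeric_limits<float>::infinity(); // running max of scores
    float l = 0.f;                                      // running sum of exp(scores - m)
    std::vector<float> o;                               // running, already-normalized output
};

// scores: q*k (scaled + masked) for one tile; vTile: tileLen x headSize, row-major
static void updateTile(RunningRow &r, const std::vector<float> &scores,
                       const std::vector<float> &vTile, int tileLen, int headSize) {
    float tileMax = *std::max_element(scores.begin(), scores.begin() + tileLen);
    float newMax = std::max(r.m, tileMax);
    float corr = std::exp(r.m - newMax); // rescales the statistics of previous tiles
    std::vector<float> p(tileLen);
    float tileSum = 0.f;
    for (int j = 0; j < tileLen; ++j) {
        p[j] = std::exp(scores[j] - newMax);
        tileSum += p[j];
    }
    float newSum = r.l * corr + tileSum;
    for (int d = 0; d < headSize; ++d) {
        float acc = 0.f;
        for (int j = 0; j < tileLen; ++j)
            acc += p[j] * vTile[j * headSize + d];
        r.o[d] = r.o[d] * (r.l * corr / newSum) + acc / newSum;
    }
    r.m = newMax;
    r.l = newSum;
}

int main() {
    int headSize = 4;
    RunningRow r;
    r.o.assign(headSize, 0.f);
    std::vector<float> scores = {0.3f, -0.1f};           // one tile of two keys
    std::vector<float> vTile = {1, 2, 3, 4, 5, 6, 7, 8}; // 2 x headSize
    updateTile(r, scores, vTile, 2, headSize);
    printf("o[0] = %f\n", r.o[0]); // equals softmax-weighted V once all tiles are processed
    return 0;
}

After the last tile, o equals softmax(q*K^T*scale + mask) * V for that row, which is what updateOutTile accumulates blockwise with the preSum/sum and preMax/max rescaling factors.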