Accelerate first token gen with BF16-gemm MHA and concat-Silu MLP (#106)
abenmao authored Dec 6, 2023
1 parent 6bf1e1f commit 605e62e
Showing 10 changed files with 376 additions and 191 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -21,6 +21,7 @@ dist/
# 3rdparty
/3rdparty/ig
/3rdparty/mklml
/3rdparty/mkl
/3rdparty/oneCCL
/3rdparty/oneccl
/3rdparty/jsoncpp
5 changes: 4 additions & 1 deletion CMakeLists.txt
@@ -55,6 +55,7 @@ endif()
include("cmake/mklml.cmake")
include("cmake/onednn.cmake")
include("cmake/xdnn.cmake")
include("cmake/mkl.cmake")
include(GNUInstallDirs)

set(DEPEND_LIST "onednn" "xdnn_lib")
@@ -63,6 +64,7 @@ include_directories(${CMAKE_SOURCE_DIR}/3rdparty/)
include_directories(${CMAKE_SOURCE_DIR}/3rdparty/onednn/include)
include_directories(${CMAKE_SOURCE_DIR}/3rdparty/onednn/build/include)
include_directories(${CMAKE_SOURCE_DIR}/3rdparty/xdnn)
include_directories(${CMAKE_SOURCE_DIR}/3rdparty/mkl/include)
include_directories(${CMAKE_SOURCE_DIR}/include)
include_directories(${CMAKE_SOURCE_DIR}/src/kernels)
include_directories(${CMAKE_SOURCE_DIR}/src/layers)
@@ -74,6 +76,7 @@ include_directories(${CMAKE_SOURCE_DIR}/src/common)
link_directories(${CMAKE_SOURCE_DIR}/src/kernels)
link_directories(${CMAKE_SOURCE_DIR}/3rdparty/onednn/build/src)
link_directories(${CMAKE_SOURCE_DIR}/3rdparty/xdnn)
link_directories(${CMAKE_SOURCE_DIR}/3rdparty/mkl/lib)

find_package(oneCCL REQUIRED)

@@ -93,7 +96,7 @@ else()
link_directories(${CMAKE_SOURCE_DIR}/3rdparty/oneccl/build/_install/lib/prov)
endif()

set(3RDPART_LIB_LIST ${MPI_LIBS} "ccl" "dnnl" "numa")
set(3RDPART_LIB_LIST ${MPI_LIBS} "ccl" "dnnl" "numa" "mkl_rt")
option(BUILD_WITH_SHARED_LIBS "Build with shared libraries" OFF)
if(BUILD_WITH_SHARED_LIBS)
message(STATUS "Notice: Building with shared libraries.")
29 changes: 29 additions & 0 deletions cmake/mkl.cmake
@@ -0,0 +1,29 @@
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

cmake_minimum_required(VERSION 3.18)

# Avoid warning about DOWNLOAD_EXTRACT_TIMESTAMP in CMake 3.24:
if(${CMAKE_VERSION} VERSION_GREATER_EQUAL "3.24.0")
cmake_policy(SET CMP0135 NEW)
endif()

find_package (Python COMPONENTS Interpreter Development)
execute_process(COMMAND ${Python_EXECUTABLE} -m pip install --prefix=${CMAKE_SOURCE_DIR}/3rdparty/mkl mkl mkl-include
RESULT_VARIABLE EXIT_CODE
OUTPUT_QUIET)
execute_process(COMMAND ln -sf ${CMAKE_SOURCE_DIR}/3rdparty/mkl/lib/libmkl_rt.so.2 ${CMAKE_SOURCE_DIR}/3rdparty/mkl/lib/libmkl_rt.so
RESULT_VARIABLE EXIT_CODE
OUTPUT_QUIET)
2 changes: 1 addition & 1 deletion requirements.txt
@@ -4,4 +4,4 @@ sentencepiece==0.1.99
tokenizers==0.13.3
torch==2.0.1+cpu
transformers==4.30.0
accelerate==0.23.0
accelerate==0.23.0
3 changes: 2 additions & 1 deletion src/common/transformer_ctx.h
@@ -158,13 +158,14 @@ struct DecoderContext {
int vCols = kCols;
int qkvCols = qCols + kCols + vCols;
int qkvStride = (qkvCols % 512 == 0 ? qkvCols + pad : qkvCols); // stride for the concatenated QKV
int mlpFactor = (this->actType == SILU || this->actType == SWIGLU) ? 2 : 1;
int imCols = splitIdx < (intermediateSize % numSplit) ? (intermediateSize / numSplit + 1)
: (intermediateSize / numSplit);
int imStride = (imCols % 512 == 0 ? imCols + pad : imCols); // stride for intermediate output

int normSize = batchSize * inputSeqLen * hiddenStride;
int qkvSize = batchSize * inputSeqLen * qkvStride;
int imOutSize = batchSize * inputSeqLen * imStride;
int imOutSize = batchSize * inputSeqLen * imStride * mlpFactor;

int presentSeqLen = preSeqLen + 1;
int paddedSize = (presentSeqLen + 15) / 16 * 16;
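For readers of the transformer_ctx.h change above: with a SiLU/SwiGLU MLP the gate and up projections are written into one concatenated intermediate buffer (the "concat-Silu MLP" in the commit title), so the intermediate output allocation doubles. A minimal sizing sketch follows; the names (intermediateBufferElems, ActType) are illustrative only, not the repository's API.

// Sizing sketch (hypothetical names): with SiLU/SwiGLU the gate and up
// projections share one concatenated intermediate buffer, so it must hold
// two copies of imStride per token; other activations need only one.
#include <cstddef>

enum ActType { RELU, GELU, SILU, SWIGLU };

std::size_t intermediateBufferElems(int batchSize, int inputSeqLen, int imStride, ActType act) {
    int mlpFactor = (act == SILU || act == SWIGLU) ? 2 : 1;  // gate + up vs. a single branch
    return static_cast<std::size_t>(batchSize) * inputSeqLen * imStride * mlpFactor;
}
// e.g. batchSize=1, inputSeqLen=1024, imStride=11008, SILU -> 22,544,384 floats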
105 changes: 66 additions & 39 deletions src/layers/attention.h
@@ -275,7 +275,7 @@ class Attention {
inputBuffer.Assign(presult, rows, cols, stride);
}
// TODO: support large inputSeqLen when pastSeqLen > 0
if (ctx->inputSeqLen > 256 && pastSeqLen == 0)
if (ctx->inputSeqLen >= 1024 && pastSeqLen == 0)
flashAttention(
ctx, qkvGroupMatMul, resultBuffer2, resultBuffer1, presentKey, presentValue, attnMask, pastSeqLen);
else
@@ -702,34 +702,50 @@ class Attention {
} // end for b
}

template <typename KVCacheT>
void flashAttention(DecoderContext *ctx, hpj::Matrix<float> &qkvMatMul, hpj::Matrix<float> &tmpRes,
template <typename KVCacheT, typename AttnT = bfloat16_t>
void flashAttention(DecoderContext *ctx, hpj::Matrix<float> &qkvMatMul, hpj::Matrix<float> &tmpBuf,
hpj::Matrix<float> &result, KVCacheTensor<KVCacheT> &presentKey, KVCacheTensor<KVCacheT> &presentValue,
const float *attnMask, int pastSeqLen) {

// How many heads this task should do
int batchSize = ctx->batchSize;
int respQHeads = this->endQHead - this->startQHead;
int respKVHeads = this->endKVHead - this->startKVHead;
int qkvCols = respQHeads + respKVHeads * 2;
int headSize = ctx->attHeadSize;
int qCols = respQHeads * headSize;
int kvCols = respKVHeads * headSize;
int qkvCols = qCols + kvCols * 2;
float scale = ctx->attFactor;
int srcLen = ctx->inputSeqLen;
int tgtLen = pastSeqLen + srcLen;

float *transQKV = (float *)malloc(sizeof(float) * batchSize * qkvCols * srcLen * headSize);

DecoderUtil::transposeQKV(qkvMatMul.Data(), transQKV, batchSize, srcLen, respQHeads, respKVHeads, headSize);
// TODO: kv dtype conversion for prefixSharing
AttnT *k, *v;
if constexpr (std::is_same_v<AttnT, bfloat16_t>) {
#pragma omp parallel for collapse(3)
for (int b = 0; b < batchSize; ++b)
for (int seq = 0; seq < srcLen; ++seq)
for (int i = qCols; i < qkvCols; i += headSize) {
const float *srcPtr = qkvMatMul.Data() + b * srcLen * qkvCols + seq * qkvCols + i;
bfloat16_t *dstPtr
= (bfloat16_t *)tmpBuf.Data() + b * srcLen * kvCols * 2 + seq * kvCols * 2 + i - qCols;
bfloat16_t::cvt_float_to_bfloat16(srcPtr, dstPtr, headSize);
}

float *query = transQKV;
float *key = transQKV + batchSize * respQHeads * srcLen * headSize;
float *value = transQKV + batchSize * (respQHeads + respKVHeads) * srcLen * headSize;
k = (AttnT *)tmpBuf.Data();
v = (AttnT *)tmpBuf.Data() + kvCols;
} else {
k = qkvMatMul.Data() + respQHeads * headSize;
v = qkvMatMul.Data() + (respQHeads + respKVHeads) * headSize;
}

scaledDpAttention(query, key, value, attnMask, scale, batchSize, srcLen, tgtLen, respQHeads, respKVHeads,
headSize, tmpRes.Data());
DecoderUtil::transposeAttnResult(
tmpRes.Data(), result.Data(), batchSize, srcLen, respQHeads, headSize, result.Stride());
float *query = qkvMatMul.Data();
// [batch, src, head, headsize]
scaledDpAttention<AttnT>(query, k, v, attnMask, scale, batchSize, srcLen, tgtLen, respQHeads, respKVHeads,
headSize, result.Data(), qkvCols, kvCols * 2, ctx->hiddenSize);

float *key = qkvMatMul.Data() + respQHeads * headSize;
float *value = qkvMatMul.Data() + (respQHeads + respKVHeads) * headSize;
// For group attention, as #kvHeads != #qHeads, need to copy current key/values to cache separately
// When M dimension is split, there are also multiple tasks per copy, so do the copy separately
#pragma omp parallel for collapse(3)
@@ -739,10 +755,10 @@
// Re-layout is needed: (bs, seq=1, hidden_size) -> (seq=1, bs, hidden_size)
// Note: for group attention, there are fewer key/value heads than query heads
for (int seq = 0; seq < tgtLen; ++seq) {
auto srcK = key + b * respKVHeads * tgtLen * headSize + i * tgtLen * headSize + seq * headSize;
auto srcK = key + b * tgtLen * qkvCols + seq * qkvCols + i * headSize;
auto dstK = presentKey.getSequence(pastSeqLen + seq, b, i);

auto srcV = value + b * respKVHeads * tgtLen * headSize + i * tgtLen * headSize + seq * headSize;
auto srcV = value + b * tgtLen * qkvCols + seq * qkvCols + i * headSize;
auto dstV = presentValue.getSequence(pastSeqLen + seq, b, i);

if constexpr (std::is_same_v<KVCacheT, float>) {
@@ -755,60 +771,67 @@ class Attention {
}
}
}
free(transQKV);
}

// scaled dot-product attention: bmm1 + softmax + bmm2
void scaledDpAttention(const float *query, const float *key, const float *value, const float *attnMask, float scale,
int batchSize, int srcLen, int tgtLen, int numQHead, int numKVHead, int headSize, float *output) {
template <typename AttnT>
void scaledDpAttention(const float *query, const AttnT *key, const AttnT *value, const float *attnMask, float scale,
int batchSize, int srcLen, int tgtLen, int numQHead, int numKVHead, int headSize, float *output,
int qStride, int kvStride, int stride) {
// output = trans(softmax(query * trans(key)) * value)
int nth = omp_get_max_threads();
int minBlk = (nth >= batchSize * numQHead ? 256 : 512);
int srcBlk = std::min(minBlk, srcLen);
int tgtBlk = std::min(minBlk, tgtLen);
// largest power of 2 not exceeding srcLen / 2
int minBlk = (int)std::pow(2, int(std::log2(srcLen / 2)));
// Split sequence to make sure a moderate sync frequency and the intermediate
// result [srcSeq * tgtSeq] in cache. The current block size is derived from practical experience.
int srcBlk = std::min(256, minBlk);
int tgtBlk = std::min(512, tgtLen);
float refac = scale;
int numGroup = numQHead / numKVHead;

int numArr = 6;
int arrStride = (4 + tgtBlk + headSize) * srcBlk;
float *thrBuf = (float *)malloc(sizeof(float) * nth * arrStride);
float **thrPtrBuf = (float **)malloc(sizeof(float *) * nth * numArr);
int numArr = 7;
int arrStride = (4 + tgtBlk + 2 * headSize) * srcBlk;
float *thrBuf = (float *)SimpleMemPool::instance().getBuffer("threadBuffers", nth * arrStride * sizeof(float));
float **thrPtrBuf
= (float **)SimpleMemPool::instance().getBuffer("threadPtrBuffers", nth * numArr * sizeof(float *));

float **preSum = thrPtrBuf;
float **sum = thrPtrBuf + nth;
float **preMax = thrPtrBuf + nth * 2;
float **max = thrPtrBuf + nth * 3;
float **qkArr = thrPtrBuf + nth * 4;
float **expQkvArr = thrPtrBuf + nth * 5;
float **qArr = thrPtrBuf + nth * 6;

for (int i = 0; i < nth; ++i) {
preSum[i] = thrBuf + srcBlk * i;
sum[i] = thrBuf + srcBlk * nth + srcBlk * i;
preMax[i] = thrBuf + srcBlk * nth * 2 + srcBlk * i;
max[i] = thrBuf + srcBlk * nth * 3 + srcBlk * i;
qkArr[i] = thrBuf + srcBlk * nth * 4 + srcBlk * tgtBlk * i;
expQkvArr[i] = thrBuf + srcBlk * nth * (4 + tgtBlk) + srcBlk * headSize * i;
qArr[i] = thrBuf + srcBlk * nth * (4 + tgtBlk + headSize) + srcBlk * headSize * i;
}

#pragma omp parallel for collapse(3)
for (int i = 0; i < batchSize; ++i) {
for (int j = 0; j < numQHead; ++j) {
for (int m = 0; m < srcLen; m += srcBlk) {
int tid = omp_get_thread_num();
int tgtOff = i * numKVHead * tgtLen * headSize + (j / numGroup) * tgtLen * headSize;
const float *k = key + tgtOff;
const float *v = value + tgtOff;
const float *attnMsk = getMask(attnMask, i, j, srcLen, tgtLen) + m * tgtLen;

int qRealBlk = std::min(srcBlk, srcLen - m);
int srcOff = i * numQHead * tgtLen * headSize + j * tgtLen * headSize;
const float *q = query + srcOff + m * headSize;
float *out = output + srcOff + m * headSize;
int srcOff = i * srcLen * qStride + j * headSize;
int outOff = i * srcLen * stride + j * headSize;
const float *qbuf = query + srcOff + m * qStride;
AttnT *q = (AttnT *)qArr[tid];
float *out = output + outOff + m * stride;

// reset out
for (int ii = 0; ii < qRealBlk; ++ii) {
#pragma omp simd
for (int jj = 0; jj < headSize; ++jj) {
out[ii * headSize + jj] = 0; // reset output
out[ii * stride + jj] = 0; // reset output
q[ii * headSize + jj] = (AttnT)(qbuf[ii * qStride + jj]); // convert query to AttnT
}
}
// reset sum
Expand All @@ -819,22 +842,25 @@ class Attention {
preMax[tid][ii] = std::numeric_limits<float>::lowest();
max[tid][ii] = std::numeric_limits<float>::lowest();
}

int tgtOff = i * tgtLen * kvStride + (j / numGroup) * headSize;
const float *attnMsk = getMask(attnMask, i, j, srcLen, tgtLen) + m * tgtLen;
const AttnT *k = key + tgtOff;
const AttnT *v = value + tgtOff;
// split the target len dimension
for (int b = 0; b < tgtLen; b += tgtBlk) {
int kvRealBlk = std::min(tgtBlk, tgtLen - b);
// TODO: mask out
const float *kBlk = k + b * headSize;
const float *vBlk = v + b * headSize;
const AttnT *kBlk = k + b * kvStride;
const AttnT *vBlk = v + b * kvStride;

DecoderUtil::incrementalTileAttention(q, kBlk, vBlk, attnMsk + b, qRealBlk, headSize, kvRealBlk,
tgtLen, preSum[tid], sum[tid], preMax[tid], max[tid], refac, qkArr[tid], expQkvArr[tid],
out);
out, headSize, kvStride, kvStride, stride);
}
}
}
}
free(thrPtrBuf);
free(thrBuf);
return;
}

@@ -878,6 +904,7 @@
}

virtual const float *getMask(const float *attnMask, int bId, int hId, int srcLen, int tgtLen) {
// Would mask be different for each sample in one batch?
return attnMask + bId * srcLen * tgtLen;
}

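The scaledDpAttention rewrite above tiles the target sequence and folds each tile into the output using running max/sum statistics (the preSum/preMax/sum/max buffers passed to DecoderUtil::incrementalTileAttention). Below is a minimal, self-contained sketch of that streaming-softmax idea for a single query row in plain float; the function name onlineSoftmaxRow and the flat layouts are illustrative only, not the repository's API.

// Streaming ("online") softmax attention for one query row: keys/values are
// visited in blocks of tgtBlk, and running max/sum statistics let each block's
// partial result be folded into the output without ever materializing the full
// [srcLen x tgtLen] score matrix.
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

void onlineSoftmaxRow(const float *q, const float *k, const float *v, float scale,
                      int tgtLen, int headSize, int tgtBlk, float *out) {
    float runMax = std::numeric_limits<float>::lowest();
    float runSum = 0.0f;
    std::fill(out, out + headSize, 0.0f);
    std::vector<float> scores(tgtBlk);

    for (int b = 0; b < tgtLen; b += tgtBlk) {
        int blk = std::min(tgtBlk, tgtLen - b);

        // scores = (q . K_block^T) * scale, tracking the block maximum
        float blkMax = std::numeric_limits<float>::lowest();
        for (int t = 0; t < blk; ++t) {
            float dot = 0.0f;
            for (int d = 0; d < headSize; ++d)
                dot += q[d] * k[(b + t) * headSize + d];
            scores[t] = dot * scale;
            blkMax = std::max(blkMax, scores[t]);
        }

        // new running max; rescale previous accumulators by exp(oldMax - newMax)
        float newMax = std::max(runMax, blkMax);
        float rescale = std::exp(runMax - newMax);
        runSum *= rescale;
        for (int d = 0; d < headSize; ++d)
            out[d] *= rescale;

        // accumulate this block's exp-weighted values
        for (int t = 0; t < blk; ++t) {
            float p = std::exp(scores[t] - newMax);
            runSum += p;
            for (int d = 0; d < headSize; ++d)
                out[d] += p * v[(b + t) * headSize + d];
        }
        runMax = newMax;
    }

    // final normalization by the accumulated softmax denominator
    for (int d = 0; d < headSize; ++d)
        out[d] /= runSum;
}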
