Accelerate first token gen with BF16-gemm MHA and concat-Silu MLP #106

Merged: 1 commit, Dec 6, 2023
1 change: 1 addition & 0 deletions .gitignore
@@ -21,6 +21,7 @@ dist/
# 3drparty
/3rdparty/ig
/3rdparty/mklml
/3rdparty/mkl
/3rdparty/oneCCL
/3rdparty/oneccl
/3rdparty/jsoncpp
5 changes: 4 additions & 1 deletion CMakeLists.txt
@@ -55,6 +55,7 @@ endif()
include("cmake/mklml.cmake")
include("cmake/onednn.cmake")
include("cmake/xdnn.cmake")
include("cmake/mkl.cmake")
include(GNUInstallDirs)

set(DEPEND_LIST "onednn" "xdnn_lib")
@@ -63,6 +64,7 @@ include_directories(${CMAKE_SOURCE_DIR}/3rdparty/)
include_directories(${CMAKE_SOURCE_DIR}/3rdparty/onednn/include)
include_directories(${CMAKE_SOURCE_DIR}/3rdparty/onednn/build/include)
include_directories(${CMAKE_SOURCE_DIR}/3rdparty/xdnn)
include_directories(${CMAKE_SOURCE_DIR}/3rdparty/mkl/include)
include_directories(${CMAKE_SOURCE_DIR}/include)
include_directories(${CMAKE_SOURCE_DIR}/src/kernels)
include_directories(${CMAKE_SOURCE_DIR}/src/layers)
@@ -74,6 +76,7 @@ include_directories(${CMAKE_SOURCE_DIR}/src/common)
link_directories(${CMAKE_SOURCE_DIR}/src/kernels)
link_directories(${CMAKE_SOURCE_DIR}/3rdparty/onednn/build/src)
link_directories(${CMAKE_SOURCE_DIR}/3rdparty/xdnn)
link_directories(${CMAKE_SOURCE_DIR}/3rdparty/mkl/lib)

find_package(oneCCL REQUIRED)

@@ -93,7 +96,7 @@ else()
link_directories(${CMAKE_SOURCE_DIR}/3rdparty/oneccl/build/_install/lib/prov)
endif()

set(3RDPART_LIB_LIST ${MPI_LIBS} "ccl" "dnnl" "numa")
set(3RDPART_LIB_LIST ${MPI_LIBS} "ccl" "dnnl" "numa" "mkl_rt")
option(BUILD_WITH_SHARED_LIBS "Build with shared libraries" OFF)
if(BUILD_WITH_SHARED_LIBS)
message(STATUS "Notice: Building with shared libraries.")
29 changes: 29 additions & 0 deletions cmake/mkl.cmake
@@ -0,0 +1,29 @@
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

cmake_minimum_required(VERSION 3.18)

# Avoid warning about DOWNLOAD_EXTRACT_TIMESTAMP in CMake 3.24:
if(${CMAKE_VERSION} VERSION_GREATER_EQUAL "3.24.0")
cmake_policy(SET CMP0135 NEW)
endif()

find_package (Python COMPONENTS Interpreter Development)
execute_process(COMMAND ${Python_EXECUTABLE} -m pip install --prefix=${CMAKE_SOURCE_DIR}/3rdparty/mkl mkl mkl-include
RESULT_VARIABLE EXIT_CODE
OUTPUT_QUIET)
execute_process(COMMAND ln -sf ${CMAKE_SOURCE_DIR}/3rdparty/mkl/lib/libmkl_rt.so.2 ${CMAKE_SOURCE_DIR}/3rdparty/mkl/lib/libmkl_rt.so
RESULT_VARIABLE EXIT_CODE
OUTPUT_QUIET)
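
For context on why mkl_rt gets linked: the PR title's "BF16-gemm MHA" implies the attention GEMMs take bfloat16 inputs with float accumulation, which MKL exposes through its CBLAS interface. A minimal sketch of such a call (this assumes oneMKL's cblas_gemm_bf16bf16f32 entry point and the MKL_BF16 type; it is illustrative and not code from this PR):

```cpp
// Sketch: BF16-input, FP32-accumulate GEMM via MKL's CBLAS API (libmkl_rt).
// Assumes oneMKL's cblas_gemm_bf16bf16f32 and MKL_BF16; illustrative, not from this PR.
#include <mkl.h>
#include <cstdint>
#include <cstring>
#include <vector>

// Truncate a float to bfloat16 (keep the upper 16 bits).
static MKL_BF16 toBf16(float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    uint16_t hi = static_cast<uint16_t>(bits >> 16);
    MKL_BF16 out;
    std::memcpy(&out, &hi, sizeof(out));
    return out;
}

int main() {
    const MKL_INT M = 4, N = 8, K = 16;
    std::vector<MKL_BF16> A(M * K, toBf16(0.5f)), B(K * N, toBf16(2.0f));
    std::vector<float> C(M * N, 0.0f);

    // C = 1.0f * A * B + 0.0f * C; inputs are BF16, output/accumulation is FP32.
    cblas_gemm_bf16bf16f32(CblasRowMajor, CblasNoTrans, CblasNoTrans,
                           M, N, K, 1.0f, A.data(), K, B.data(), K,
                           0.0f, C.data(), N);
    return 0;
}
```

The symlink created above provides the unversioned libmkl_rt.so name that the linker needs when resolving "mkl_rt" at build time.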
2 changes: 1 addition & 1 deletion requirements.txt
@@ -4,4 +4,4 @@ sentencepiece==0.1.99
tokenizers==0.13.3
torch==2.0.1+cpu
transformers==4.30.0
accelerate==0.23.0
accelerate==0.23.0
3 changes: 2 additions & 1 deletion src/common/transformer_ctx.h
@@ -158,13 +158,14 @@ struct DecoderContext {
int vCols = kCols;
int qkvCols = qCols + kCols + vCols;
int qkvStride = (qkvCols % 512 == 0 ? qkvCols + pad : qkvCols); // stride for the concatenated QKV
int mlpFactor = (this->actType == SILU || this->actType == SWIGLU) ? 2 : 1;
int imCols = splitIdx < (intermediateSize % numSplit) ? (intermediateSize / numSplit + 1)
: (intermediateSize / numSplit);
int imStride = (imCols % 512 == 0 ? imCols + pad : imCols); // stride for intermediate output

int normSize = batchSize * inputSeqLen * hiddenStride;
int qkvSize = batchSize * inputSeqLen * qkvStride;
int imOutSize = batchSize * inputSeqLen * imStride;
int imOutSize = batchSize * inputSeqLen * imStride * mlpFactor;

int presentSeqLen = preSeqLen + 1;
int paddedSize = (presentSeqLen + 15) / 16 * 16;
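
The mlpFactor added above doubles imOutSize when the activation is SiLU or SwiGLU, so the intermediate buffer can hold the gate and up projections side by side; that is the "concat-SiLU MLP" from the PR title. A minimal sketch of what the fused activation then computes over such a buffer (the exact in-memory layout used by the kernels is not visible in this hunk, so the names and the [gate | up] layout below are illustrative assumptions):

```cpp
// Illustrative only: concat-SiLU over one token's row holding [gate | up], each imCols wide.
#include <cmath>

inline float silu(float x) { return x / (1.0f + std::exp(-x)); }

// imRow: 2 * imCols floats for one token; out: imCols floats.
void concatSiluRow(const float *imRow, float *out, int imCols) {
    const float *gate = imRow;
    const float *up = imRow + imCols;
    for (int i = 0; i < imCols; ++i)
        out[i] = silu(gate[i]) * up[i]; // SwiGLU-style gating
}
```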
105 changes: 66 additions & 39 deletions src/layers/attention.h
@@ -275,7 +275,7 @@ class Attention {
inputBuffer.Assign(presult, rows, cols, stride);
}
// TODO: support large inputSeqLen when pastSeqLen > 0
if (ctx->inputSeqLen > 256 && pastSeqLen == 0)
if (ctx->inputSeqLen >= 1024 && pastSeqLen == 0)
flashAttention(
ctx, qkvGroupMatMul, resultBuffer2, resultBuffer1, presentKey, presentValue, attnMask, pastSeqLen);
else
@@ -702,34 +702,50 @@ class Attention {
} // end for b
}

template <typename KVCacheT>
void flashAttention(DecoderContext *ctx, hpj::Matrix<float> &qkvMatMul, hpj::Matrix<float> &tmpRes,
template <typename KVCacheT, typename AttnT = bfloat16_t>
void flashAttention(DecoderContext *ctx, hpj::Matrix<float> &qkvMatMul, hpj::Matrix<float> &tmpBuf,
hpj::Matrix<float> &result, KVCacheTensor<KVCacheT> &presentKey, KVCacheTensor<KVCacheT> &presentValue,
const float *attnMask, int pastSeqLen) {

// How many heads this task should do
int batchSize = ctx->batchSize;
int respQHeads = this->endQHead - this->startQHead;
int respKVHeads = this->endKVHead - this->startKVHead;
int qkvCols = respQHeads + respKVHeads * 2;
int headSize = ctx->attHeadSize;
int qCols = respQHeads * headSize;
int kvCols = respKVHeads * headSize;
int qkvCols = qCols + kvCols * 2;
float scale = ctx->attFactor;
int srcLen = ctx->inputSeqLen;
int tgtLen = pastSeqLen + srcLen;

float *transQKV = (float *)malloc(sizeof(float) * batchSize * qkvCols * srcLen * headSize);

DecoderUtil::transposeQKV(qkvMatMul.Data(), transQKV, batchSize, srcLen, respQHeads, respKVHeads, headSize);
// TODO: kv dtype conversion for prefixSharing
AttnT *k, *v;
if constexpr (std::is_same_v<AttnT, bfloat16_t>) {
#pragma omp parallel for collapse(3)
for (int b = 0; b < batchSize; ++b)
for (int seq = 0; seq < srcLen; ++seq)
for (int i = qCols; i < qkvCols; i += headSize) {
const float *srcPtr = qkvMatMul.Data() + b * srcLen * qkvCols + seq * qkvCols + i;
bfloat16_t *dstPtr
= (bfloat16_t *)tmpBuf.Data() + b * srcLen * kvCols * 2 + seq * kvCols * 2 + i - qCols;
bfloat16_t::cvt_float_to_bfloat16(srcPtr, dstPtr, headSize);
}

float *query = transQKV;
float *key = transQKV + batchSize * respQHeads * srcLen * headSize;
float *value = transQKV + batchSize * (respQHeads + respKVHeads) * srcLen * headSize;
k = (AttnT *)tmpBuf.Data();
v = (AttnT *)tmpBuf.Data() + kvCols;
} else {
k = qkvMatMul.Data() + respQHeads * headSize;
v = qkvMatMul.Data() + (respQHeads + respKVHeads) * headSize;
}

scaledDpAttention(query, key, value, attnMask, scale, batchSize, srcLen, tgtLen, respQHeads, respKVHeads,
headSize, tmpRes.Data());
DecoderUtil::transposeAttnResult(
tmpRes.Data(), result.Data(), batchSize, srcLen, respQHeads, headSize, result.Stride());
float *query = qkvMatMul.Data();
// [batch, src, head, headsize]
scaledDpAttention<AttnT>(query, k, v, attnMask, scale, batchSize, srcLen, tgtLen, respQHeads, respKVHeads,
headSize, result.Data(), qkvCols, kvCols * 2, ctx->hiddenSize);

float *key = qkvMatMul.Data() + respQHeads * headSize;
float *value = qkvMatMul.Data() + (respQHeads + respKVHeads) * headSize;
// For group attention, #kvHeads != #qHeads, so the current key/values need to be copied to the cache separately
// When the M dimension is split there are also multiple tasks per copy, so do the copy separately
#pragma omp parallel for collapse(3)
@@ -739,10 +755,10 @@
// Re-layout is needed: (bs, seq=1, hidden_size) -> (seq=1, bs, hidden_size)
// Note: for group attention, there are fewer key/value heads than query heads
for (int seq = 0; seq < tgtLen; ++seq) {
auto srcK = key + b * respKVHeads * tgtLen * headSize + i * tgtLen * headSize + seq * headSize;
auto srcK = key + b * tgtLen * qkvCols + seq * qkvCols + i * headSize;
auto dstK = presentKey.getSequence(pastSeqLen + seq, b, i);

auto srcV = value + b * respKVHeads * tgtLen * headSize + i * tgtLen * headSize + seq * headSize;
auto srcV = value + b * tgtLen * qkvCols + seq * qkvCols + i * headSize;
auto dstV = presentValue.getSequence(pastSeqLen + seq, b, i);

if constexpr (std::is_same_v<KVCacheT, float>) {
@@ -755,60 +771,67 @@
}
}
}
free(transQKV);
}
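
Compared with the previous version, the explicit transposeQKV/transposeAttnResult round trip is gone: Q stays in the packed [batch, seq, qCols + 2 * kvCols] GEMM output, K and V are converted to bfloat16 into tmpBuf, and row strides (qStride, kvStride, and the output stride) tell scaledDpAttention how to walk each buffer. A small sketch of the packed addressing this relies on (hypothetical helper, not code from the PR):

```cpp
// Hypothetical helper illustrating the packed QKV layout used after this change:
// each (batch b, token s) row is [ q0..qCols | k0..kvCols | v0..kvCols ].
#include <vector>

struct PackedQKV {
    const float *data; // qkvMatMul.Data()
    int seqLen;        // tokens per batch entry
    int qCols, kvCols; // per-row widths
    int rowStride() const { return qCols + 2 * kvCols; }

    const float *q(int b, int s) const { return data + (b * seqLen + s) * rowStride(); }
    const float *k(int b, int s) const { return q(b, s) + qCols; }
    const float *v(int b, int s) const { return q(b, s) + qCols + kvCols; }
};

int main() {
    int batch = 2, seq = 3, qCols = 8, kvCols = 4;
    std::vector<float> buf(batch * seq * (qCols + 2 * kvCols), 0.0f);
    PackedQKV qkv{buf.data(), seq, qCols, kvCols};
    const float *v12 = qkv.v(1, 2); // V block of the last token in batch entry 1
    (void)v12;
    return 0;
}
```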

// scaled dot-product attention: bmm1 + softmax + bmm2
void scaledDpAttention(const float *query, const float *key, const float *value, const float *attnMask, float scale,
int batchSize, int srcLen, int tgtLen, int numQHead, int numKVHead, int headSize, float *output) {
template <typename AttnT>
void scaledDpAttention(const float *query, const AttnT *key, const AttnT *value, const float *attnMask, float scale,
int batchSize, int srcLen, int tgtLen, int numQHead, int numKVHead, int headSize, float *output,
int qStride, int kvStride, int stride) {
// output = trans(softmax(query * trans(key)) * value)
int nth = omp_get_max_threads();
int minBlk = (nth >= batchSize * numQHead ? 256 : 512);
int srcBlk = std::min(minBlk, srcLen);
int tgtBlk = std::min(minBlk, tgtLen);
// closest power of 2 below srcLen / 2
int minBlk = (int)std::pow(2, int(std::log2(srcLen / 2)));
Contributor: Is there any design principle behind this block size? If so, please add a comment.

Contributor (Author): The current block size is just derived from practical experience. Added comments.
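
For readers skimming the heuristic, a small worked example (values illustrative only): with srcLen = 1200, log2(600) is about 9.2, so minBlk = 2^9 = 512, srcBlk = min(256, 512) = 256, and tgtBlk = min(512, tgtLen).

```cpp
// Illustrative only: reproduce the block-size heuristic below for one input size.
#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
    int srcLen = 1200, tgtLen = 1200;
    int minBlk = (int)std::pow(2, (int)std::log2(srcLen / 2)); // 2^9 = 512
    int srcBlk = std::min(256, minBlk);                        // 256
    int tgtBlk = std::min(512, tgtLen);                        // 512
    std::printf("minBlk=%d srcBlk=%d tgtBlk=%d\n", minBlk, srcBlk, tgtBlk);
    return 0;
}
```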

// Split the sequence to keep a moderate sync frequency and to keep the intermediate
// result [srcSeq * tgtSeq] in cache. The current block size is derived from practical experience.
int srcBlk = std::min(256, minBlk);
int tgtBlk = std::min(512, tgtLen);
float refac = scale;
int numGroup = numQHead / numKVHead;

int numArr = 6;
int arrStride = (4 + tgtBlk + headSize) * srcBlk;
float *thrBuf = (float *)malloc(sizeof(float) * nth * arrStride);
float **thrPtrBuf = (float **)malloc(sizeof(float *) * nth * numArr);
int numArr = 7;
int arrStride = (4 + tgtBlk + 2 * headSize) * srcBlk;
float *thrBuf = (float *)SimpleMemPool::instance().getBuffer("threadBuffers", nth * arrStride * sizeof(float));
float **thrPtrBuf
= (float **)SimpleMemPool::instance().getBuffer("threadPtrBuffers", nth * numArr * sizeof(float *));

float **preSum = thrPtrBuf;
float **sum = thrPtrBuf + nth;
float **preMax = thrPtrBuf + nth * 2;
float **max = thrPtrBuf + nth * 3;
float **qkArr = thrPtrBuf + nth * 4;
float **expQkvArr = thrPtrBuf + nth * 5;
float **qArr = thrPtrBuf + nth * 6;

for (int i = 0; i < nth; ++i) {
preSum[i] = thrBuf + srcBlk * i;
sum[i] = thrBuf + srcBlk * nth + srcBlk * i;
preMax[i] = thrBuf + srcBlk * nth * 2 + srcBlk * i;
max[i] = thrBuf + srcBlk * nth * 3 + srcBlk * i;
qkArr[i] = thrBuf + srcBlk * nth * 4 + srcBlk * tgtBlk * i;
expQkvArr[i] = thrBuf + srcBlk * nth * (4 + tgtBlk) + srcBlk * headSize * i;
qArr[i] = thrBuf + srcBlk * nth * (4 + tgtBlk + headSize) + srcBlk * headSize * i;
}

#pragma omp parallel for collapse(3)
for (int i = 0; i < batchSize; ++i) {
for (int j = 0; j < numQHead; ++j) {
for (int m = 0; m < srcLen; m += srcBlk) {
int tid = omp_get_thread_num();
int tgtOff = i * numKVHead * tgtLen * headSize + (j / numGroup) * tgtLen * headSize;
const float *k = key + tgtOff;
const float *v = value + tgtOff;
const float *attnMsk = getMask(attnMask, i, j, srcLen, tgtLen) + m * tgtLen;

int qRealBlk = std::min(srcBlk, srcLen - m);
int srcOff = i * numQHead * tgtLen * headSize + j * tgtLen * headSize;
const float *q = query + srcOff + m * headSize;
float *out = output + srcOff + m * headSize;
int srcOff = i * srcLen * qStride + j * headSize;
int outOff = i * srcLen * stride + j * headSize;
const float *qbuf = query + srcOff + m * qStride;
AttnT *q = (AttnT *)qArr[tid];
float *out = output + outOff + m * stride;

// reset output and convert the query block to AttnT
for (int ii = 0; ii < qRealBlk; ++ii) {
#pragma omp simd
for (int jj = 0; jj < headSize; ++jj) {
out[ii * headSize + jj] = 0; // reset output
out[ii * stride + jj] = 0; // reset output
q[ii * headSize + jj] = (AttnT)(qbuf[ii * qStride + jj]); // convert query to AttnT
}
}
// reset sum
@@ -819,22 +842,25 @@
preMax[tid][ii] = std::numeric_limits<float>::lowest();
max[tid][ii] = std::numeric_limits<float>::lowest();
}

int tgtOff = i * tgtLen * kvStride + (j / numGroup) * headSize;
const float *attnMsk = getMask(attnMask, i, j, srcLen, tgtLen) + m * tgtLen;
const AttnT *k = key + tgtOff;
const AttnT *v = value + tgtOff;
// split the target len dimension
for (int b = 0; b < tgtLen; b += tgtBlk) {
int kvRealBlk = std::min(tgtBlk, tgtLen - b);
// TODO: mask out
const float *kBlk = k + b * headSize;
const float *vBlk = v + b * headSize;
const AttnT *kBlk = k + b * kvStride;
const AttnT *vBlk = v + b * kvStride;

DecoderUtil::incrementalTileAttention(q, kBlk, vBlk, attnMsk + b, qRealBlk, headSize, kvRealBlk,
tgtLen, preSum[tid], sum[tid], preMax[tid], max[tid], refac, qkArr[tid], expQkvArr[tid],
out);
out, headSize, kvStride, kvStride, stride);
}
}
}
}
free(thrPtrBuf);
free(thrBuf);
return;
}
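
The per-thread preSum/sum/preMax/max arrays above are the bookkeeping of a streaming (online) softmax: each new target-length tile updates a running max and running sum and rescales the partial output, so that after the last tile the result equals a softmax over the full row. A single-row sketch of that update, under the assumption that this is what DecoderUtil::incrementalTileAttention implements (not the actual kernel):

```cpp
// Single-row streaming-softmax sketch. runMax starts at numeric_limits<float>::lowest(),
// runSum and acc at zero, matching the reset loops in scaledDpAttention.
#include <algorithm>
#include <cmath>
#include <limits>

// scores: tileLen scaled QK^T (+mask) values; vTile: tileLen x headSize, row-major.
void streamingSoftmaxTile(const float *scores, int tileLen, const float *vTile, int headSize,
                          float &runMax, float &runSum, float *acc) {
    float tileMax = runMax;
    for (int t = 0; t < tileLen; ++t) tileMax = std::max(tileMax, scores[t]);

    float rescale = std::exp(runMax - tileMax); // correct earlier tiles for the new max
    runSum *= rescale;
    for (int d = 0; d < headSize; ++d) acc[d] *= rescale;

    for (int t = 0; t < tileLen; ++t) {
        float p = std::exp(scores[t] - tileMax);
        runSum += p;
        for (int d = 0; d < headSize; ++d) acc[d] += p * vTile[t * headSize + d];
    }
    runMax = tileMax;
    // After the final tile: output[d] = acc[d] / runSum.
}

int main() {
    const int headSize = 2;
    float scores1[] = {0.1f, 0.7f}, scores2[] = {1.2f};
    float v1[] = {1.f, 0.f, 0.f, 1.f}, v2[] = {0.5f, 0.5f};
    float runMax = std::numeric_limits<float>::lowest(), runSum = 0.f, acc[2] = {0.f, 0.f};
    streamingSoftmaxTile(scores1, 2, v1, headSize, runMax, runSum, acc);
    streamingSoftmaxTile(scores2, 1, v2, headSize, runMax, runSum, acc);
    // acc / runSum now equals a softmax over all three scores applied to the V rows.
    return 0;
}
```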

@@ -878,6 +904,7 @@ class Attention {
}

virtual const float *getMask(const float *attnMask, int bId, int hId, int srcLen, int tgtLen) {
// Would mask be different for each sample in one batch?
return attnMask + bId * srcLen * tgtLen;
}
