Skip to content

Commit

Permalink
support alibi embedding (#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
aurora327 authored Nov 10, 2023
1 parent c294e7c commit bcf9a3a
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 0 deletions.
56 changes: 56 additions & 0 deletions src/layers/alibi_embedding.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (c) 2023 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ============================================================================
#include "alibi_embedding.h"
#include <cmath>
#include "compile_util.h"

// Records the head count and sequence length as the current capacities, then
// builds the relative-position table followed by the per-head slope table.
AlibiEmbedding::AlibiEmbedding(const int headNum, const int seqLen)
    : maxLen(seqLen), maxHeadNums(headNum) {
    // Position table first, slope table second.
    alibiGetRelativePos(seqLen);
    alibiGetSlope(headNum);
}

void AlibiEmbedding::alibiGetBias(const int headIdx, const int seqLen, float *biasMatrx) {
REQUIRES(headIdx < maxHeadNums, "Alibi Embedding ERROR, headIdx is exceeds max head nums.");
if (seqLen > maxLen) {
maxLen = seqLen;
alibiGetRelativePos(maxLen);
}
for (size_t i = 0; i < seqLen; i++) {
for (size_t j = 0; j < seqLen; j++) {
int index = i * seqLen + j;
biasMatrx[index] = posMatrix[index] * slopeM[headIdx];
}
}
}

// Allocates a 64-byte aligned, row-major seqLen x seqLen table and stores
// j - i at (row i, column j). The new pointer is assigned to posMatrix; this
// function does not release whatever posMatrix referred to before the call.
void AlibiEmbedding::alibiGetRelativePos(const int seqLen) {
    const size_t alignment = 64;
    // Do the size arithmetic in size_t: seqLen * seqLen in int overflows for
    // long sequences, and a negative seqLen must not wrap to a huge size.
    const size_t side = seqLen > 0 ? static_cast<size_t>(seqLen) : 0;
    // aligned_alloc expects the size to be an integral multiple of the
    // alignment, so round the byte count up (and never request 0 bytes).
    size_t bytes = (side * side * sizeof(int) + alignment - 1) / alignment * alignment;
    if (bytes == 0) { bytes = alignment; }
    posMatrix = static_cast<int *>(aligned_alloc(alignment, bytes));

    for (int i = 0; i < seqLen; i++) {
        int *row = posMatrix + static_cast<size_t>(i) * side;
        for (int j = 0; j < seqLen; j++) {
            row[j] = j - i;
        }
    }
}

// Allocates a 64-byte aligned array of headNum floats, assigns it to slopeM,
// and fills entry i with 1 / x^(i + 1), where x = (2^8)^(1 / headNum).
// This function does not release whatever slopeM referred to before the call.
void AlibiEmbedding::alibiGetSlope(const int headNum) {
    const size_t alignment = 64;
    const size_t count = headNum > 0 ? static_cast<size_t>(headNum) : 0;
    // aligned_alloc expects the size to be an integral multiple of the
    // alignment, so round the byte count up (and never request 0 bytes).
    size_t bytes = (count * sizeof(float) + alignment - 1) / alignment * alignment;
    if (bytes == 0) { bytes = alignment; }
    slopeM = static_cast<float *>(aligned_alloc(alignment, bytes));

    // Nothing to fill, and 1.0 / headNum below would divide by zero.
    if (headNum <= 0) { return; }

    float x = std::pow(2, 8);
    x = std::pow(x, 1.0 / headNum);
    for (int i = 0; i < headNum; i++) {
        slopeM[i] = 1 / std::pow(x, i + 1);
    }
}
41 changes: 41 additions & 0 deletions src/layers/alibi_embedding.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (c) 2023 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ============================================================================
#pragma once
#include <cstdlib>
#include <iostream>

// Helper that owns two heap tables used to build ALiBi attention biases:
// a relative-position table (posMatrix) and a per-head slope table (slopeM).
//
// The object owns both allocations and frees them in its destructor, so it is
// deliberately non-copyable: a member-wise copy would free each table twice.
class AlibiEmbedding {
public:
    // headNum: number of heads the slope table is built for.
    // seqLen: initial side length of the relative-position table.
    AlibiEmbedding(const int headNum, const int seqLen);

    ~AlibiEmbedding() {
        maxLen = 0;
        maxHeadNums = 0;
        // free(nullptr) is a no-op, so no explicit null checks are needed.
        free(posMatrix);
        free(slopeM);
        posMatrix = nullptr;
        slopeM = nullptr;
    }

    // Copying would duplicate the owning raw pointers and double-free them.
    AlibiEmbedding(const AlibiEmbedding &) = delete;
    AlibiEmbedding &operator=(const AlibiEmbedding &) = delete;

    // Allocates the relative-position table for a seqLen x seqLen grid.
    void alibiGetRelativePos(const int seqLen);

    // Allocates and fills the slope table with one entry per head.
    void alibiGetSlope(const int headNum);

    // Writes the seqLen x seqLen bias matrix of head headIdx into bias_matrx.
    // headIdx must be in [0, headNum), where headNum is the constructor argument.
    void alibiGetBias(const int headIdx, const int seqLen, float *bias_matrx);

private:
    int maxLen = 0; // side length the position table was last sized for
    int maxHeadNums = 0; // number of entries in slopeM
    int *posMatrix = nullptr; // owned; released with free()
    float *slopeM = nullptr; // owned; released with free()
};
2 changes: 2 additions & 0 deletions tests/ut/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ foreach(src ${sources})
${SRC_DIR}/utils/shm_reduction.cpp
${SRC_DIR}/kernels/gemm_kernel_ext.cpp
${SRC_DIR}/kernels/gemm_kernel_ext_fp16.cpp)
elseif(${executable} STREQUAL "alibi_embedding_test")
add_executable(alibi_embedding_test ${src} ${SRC_DIR}/layers/alibi_embedding.cpp)
elseif(${executable} STREQUAL "rotary_embedding_test")
add_executable(rotary_embedding_test ${src} ${SRC_DIR}/layers/rotary_embedding.cpp)
elseif(${executable} STREQUAL "gemm_kernel_ext_test")
Expand Down
47 changes: 47 additions & 0 deletions tests/ut/alibi_embedding_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright (c) 2023 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ============================================================================
#include <ctime>
#include <iostream>

#include "alibi_embedding.h"
#include "gtest/gtest.h"

// Returns true when the first `size` elements of result and groundTruth differ
// element-wise by at most 0.001; a size of zero or less compares as equal.
//
// The difference is range-checked directly instead of going through an
// unqualified abs(): depending on which headers are visible, abs() can bind to
// the int overload and truncate every sub-1.0 error to zero. The negated form
// of the check also makes a NaN on either side count as a mismatch.
static bool compare(const float *result, const float *groundTruth, const int size) {
    const float diff = 0.001f;
    for (int i = 0; i < size; ++i) {
        const float delta = groundTruth[i] - result[i];
        if (!(delta <= diff && delta >= -diff)) { return false; }
    }
    return true;
}

// Checks the bias matrix produced for head 4 of 6 at sequence length 6
// against reference values, using the tolerance applied by compare().
TEST(AlibiEmbedding, AlibiEmbeddingTest) {
    int headNum = 6;
    int seqLen = 6;
    int headIdx = 4;
    int size = 36;

    // Reference output, one matrix row per line (row i, column j).
    float groundTruth[36] = {
            0.0000, 0.0098, 0.0197, 0.0295, 0.0394, 0.0492,
            -0.0098, 0.0000, 0.0098, 0.0197, 0.0295, 0.0394,
            -0.0197, -0.0098, 0.0000, 0.0098, 0.0197, 0.0295,
            -0.0295, -0.0197, -0.0098, 0.0000, 0.0098, 0.0197,
            -0.0394, -0.0295, -0.0197, -0.0098, 0.0000, 0.0098,
            -0.0492, -0.0394, -0.0295, -0.0197, -0.0098, 0.0000};

    AlibiEmbedding alibi(headNum, seqLen);

    float *biasMatrx = (float *)malloc(seqLen * seqLen * sizeof(float));
    alibi.alibiGetBias(headIdx, seqLen, biasMatrx);

    EXPECT_TRUE(compare(biasMatrx, groundTruth, size));

    free(biasMatrx);
}

// Test entry point: lets GoogleTest consume its command-line flags, then runs
// every registered test and returns the resulting status as the exit code.
int main(int argc, char **argv) {
    ::testing::InitGoogleTest(&argc, argv);
    const int status = RUN_ALL_TESTS();
    return status;
}

0 comments on commit bcf9a3a

Please sign in to comment.