[API] Add invokeMLPLLaMA FP16 API. (#302)
* [API] Add invokeMLPLLaMA FP16 API.

* Update.

* Update.
changqi1 authored Apr 15, 2024
1 parent 03a3d49 commit 280a915
Showing 2 changed files with 74 additions and 24 deletions.
32 changes: 32 additions & 0 deletions src/layers/mlp_llama.cpp
@@ -54,6 +54,38 @@ void invokeMLPLLaMA(DataType dt, int numTokens, int hiddenSize, int intermediate
llama_mlp = it_created->second;
}

ctx->resize(1, numTokens, 0);
llama_mlp->forward(ctx, (float *)const_cast<void *>(input), (float *)output, inputStride, outputStride, false);
} else if (dt == DataType::fp16) {
static std::unordered_map<std::string, LlamaMLP<float16_t> *> llama_mlp_hub;

static DecoderContext *ctx;
if (ctx == nullptr
|| (ctx != nullptr && (ctx->hiddenSize != hiddenSize || ctx->intermediateSize != intermediateSize))) {
if (ctx != nullptr) delete ctx;
printf(">> create context: %d %d\n", hiddenSize, intermediateSize);
ctx = new DecoderContext(1, hiddenSize, 1, 1, intermediateSize, "silu", 1e-6, 0, 0, 0, 0, 0, 0, 1);
ctx->mmHelper = new MMHelper(Env::getInstance().getEngineKind(), Env::getInstance().getEngineIndex());
}

// Create the hash key from the weight addresses: if hiddenSize or intermediateSize changes, the weight memory pointers change as well, producing a new key.
std::stringstream weights_addr;
weights_addr << gateWeight << "_" << upWeight << "_" << downWeight;
std::string llama_mlp_key = weights_addr.str();
LlamaMLP<float16_t> *llama_mlp;

auto it_created = llama_mlp_hub.find(llama_mlp_key);
if (it_created == llama_mlp_hub.end()) {
// LlamaMLP<float16_t> &llama_mlp = LlamaMLP<float16_t>::getInstance();
llama_mlp = new LlamaMLP<float16_t>;
llama_mlp->setWeights(ctx, (float *)gateWeight, nullptr, nullptr, nullptr, (float *)upWeight, nullptr,
nullptr, nullptr, nullptr, nullptr, (float *)downWeight, nullptr, nullptr, false);
llama_mlp_hub[llama_mlp_key] = llama_mlp;
printf(">> create llama_mlp_key: %s\n", llama_mlp_key.c_str());
} else {
llama_mlp = it_created->second;
}

ctx->resize(1, numTokens, 0);
llama_mlp->forward(ctx, (float *)const_cast<void *>(input), (float *)output, inputStride, outputStride, false);
}
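For context (not part of this commit), a minimal caller-side sketch of the new FP16 path: the invokeMLPLLaMA signature and the FP32 weight/activation buffers mirror the unit test in tests/ut/layers_mlp_test.cpp below, while the included header name and the runFp16Mlp helper are assumptions.

// Caller-side sketch only; header name and helper are illustrative.
#include <vector>
#include "layers_mlp.h"  // assumed header exposing xft::DataType and invokeMLPLLaMA

void runFp16Mlp(int numTokens, int hiddenSize, int intermediateSize,
        const float *gateW, const float *upW, const float *downW) {
    std::vector<float> input(numTokens * hiddenSize);   // activations, filled by the caller
    std::vector<float> output(numTokens * hiddenSize);

    // DataType::fp16 routes into the LlamaMLP<float16_t> branch added above;
    // strides are passed in elements, as in the unit test (hiddenSize for both).
    invokeMLPLLaMA(xft::DataType::fp16, numTokens, hiddenSize, intermediateSize,
            (void *)output.data(), /*outputStride=*/hiddenSize,
            (const void *)input.data(), /*inputStride=*/hiddenSize,
            (const void *)gateW, (const void *)upW, (const void *)downW);
}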
66 changes: 42 additions & 24 deletions tests/ut/layers_mlp_test.cpp
@@ -89,18 +89,27 @@ static void compareMLPLLaMA(
input[i] = static_cast<float>(1.0f * rand() / RAND_MAX);
}

xft::DataType dt = xft::DataType::unknown;
if constexpr (std::is_same<T, bfloat16_t>::value) {
auto t0 = std::chrono::high_resolution_clock::now();
invokeMLPLLaMA(xft::DataType::bf16, numTokens, hiddenSize, intermediateSize, (void *)ourOutput, hiddenSize,
(const void *)input, hiddenSize, (const void *)gateW, (const void *)upW, (const void *)downW);
auto t1 = std::chrono::high_resolution_clock::now();
float during_time = std::chrono::duration<float>(t1 - t0).count();
printf("[ RUNTIME ] XFT::invokeMLPLLaMA %.6f sec\n", during_time);

refMLPLLaMA<bfloat16_t>(numTokens, hiddenSize, intermediateSize, (float *)refOutput, hiddenSize,
(const float *)input, hiddenSize, (const float *)gateW, (const float *)upW, (const float *)downW);
dt = xft::DataType::bf16;
} else if constexpr (std::is_same<T, float16_t>::value) {
dt = xft::DataType::fp16;
} else {
printf("Unsupported data type\n");
GTEST_FAIL();
return;
}

auto start = std::chrono::high_resolution_clock::now();
invokeMLPLLaMA(dt, numTokens, hiddenSize, intermediateSize, (void *)ourOutput, hiddenSize,
(const void *)input, hiddenSize, (const void *)gateW, (const void *)upW, (const void *)downW);
auto end = std::chrono::high_resolution_clock::now();
float during_time = std::chrono::duration<float>(end - start).count();
printf("[ RUNTIME ] XFT::invokeMLPLLaMA %.6f sec\n", during_time);

refMLPLLaMA<T>(numTokens, hiddenSize, intermediateSize, (float *)refOutput, hiddenSize,
(const float *)input, hiddenSize, (const float *)gateW, (const float *)upW, (const float *)downW);

for (int i = 0; i < numTokens * hiddenSize; ++i) {
EXPECT_EQ(std::abs(refOutput[i] - ourOutput[i]) > 0.01
&& std::abs((refOutput[i] - ourOutput[i]) / refOutput[i]) > 0.01,
@@ -112,7 +121,8 @@ static void compareMLPLLaMA(
free(refOutput);
}

TEST(MLPLLaMA, bfloat16_t) {
template <typename T>
void test_MLPLLaMA(void) {
int hiddenSize = 4096;
int intermediateSize = 11008;

@@ -126,26 +136,34 @@ TEST(MLPLLaMA, bfloat16_t) {
downW[i] = static_cast<float>(0.5f * rand() / RAND_MAX);
}

compareMLPLLaMA<bfloat16_t>(18, hiddenSize, intermediateSize, gateW, upW, downW);
compareMLPLLaMA<bfloat16_t>(16, hiddenSize, intermediateSize, gateW, upW, downW);
compareMLPLLaMA<bfloat16_t>(16, hiddenSize, intermediateSize, gateW, upW, downW);
compareMLPLLaMA<bfloat16_t>(16, hiddenSize, intermediateSize, gateW, upW, downW);
compareMLPLLaMA<bfloat16_t>(16, hiddenSize, intermediateSize, gateW, upW, downW);
compareMLPLLaMA<bfloat16_t>(10, hiddenSize, intermediateSize, gateW, upW, downW);
compareMLPLLaMA<bfloat16_t>(4, hiddenSize, intermediateSize, gateW, upW, downW);
compareMLPLLaMA<bfloat16_t>(2, hiddenSize, intermediateSize, gateW, upW, downW);
compareMLPLLaMA<bfloat16_t>(1, hiddenSize, intermediateSize, gateW, upW, downW);
compareMLPLLaMA<bfloat16_t>(2, hiddenSize, intermediateSize, gateW, upW, downW);
compareMLPLLaMA<bfloat16_t>(4, hiddenSize, intermediateSize, gateW, upW, downW);
compareMLPLLaMA<bfloat16_t>(6, hiddenSize, intermediateSize, gateW, upW, downW);
compareMLPLLaMA<bfloat16_t>(16, hiddenSize, intermediateSize, gateW, upW, downW);
compareMLPLLaMA<bfloat16_t>(16, hiddenSize, intermediateSize, gateW, upW, downW);
compareMLPLLaMA<T>(18, hiddenSize, intermediateSize, gateW, upW, downW);
compareMLPLLaMA<T>(16, hiddenSize, intermediateSize, gateW, upW, downW);
compareMLPLLaMA<T>(16, hiddenSize, intermediateSize, gateW, upW, downW);
compareMLPLLaMA<T>(16, hiddenSize, intermediateSize, gateW, upW, downW);
compareMLPLLaMA<T>(16, hiddenSize, intermediateSize, gateW, upW, downW);
compareMLPLLaMA<T>(10, hiddenSize, intermediateSize, gateW, upW, downW);
compareMLPLLaMA<T>(4, hiddenSize, intermediateSize, gateW, upW, downW);
compareMLPLLaMA<T>(2, hiddenSize, intermediateSize, gateW, upW, downW);
compareMLPLLaMA<T>(1, hiddenSize, intermediateSize, gateW, upW, downW);
compareMLPLLaMA<T>(2, hiddenSize, intermediateSize, gateW, upW, downW);
compareMLPLLaMA<T>(4, hiddenSize, intermediateSize, gateW, upW, downW);
compareMLPLLaMA<T>(6, hiddenSize, intermediateSize, gateW, upW, downW);
compareMLPLLaMA<T>(16, hiddenSize, intermediateSize, gateW, upW, downW);
compareMLPLLaMA<T>(16, hiddenSize, intermediateSize, gateW, upW, downW);

free(gateW);
free(upW);
free(downW);
}

TEST(MLPLLaMA, bfloat16_t) {
test_MLPLLaMA<bfloat16_t>();
}

TEST(MLPLLaMA, float16_t) {
test_MLPLLaMA<float16_t>();
}

int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
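For reference, the element-wise check in compareMLPLLaMA above flags a mismatch only when both the absolute error and the relative error exceed 0.01. A standalone sketch of that criterion (the helper name is illustrative):

#include <cmath>

// Mirrors the tolerance used by compareMLPLLaMA: an element fails only when
// BOTH its absolute error and its relative error exceed 0.01.
static bool mlpOutputMatches(float ref, float out, float tol = 0.01f) {
    float absErr = std::abs(ref - out);
    float relErr = std::abs((ref - out) / ref);
    return !(absErr > tol && relErr > tol);
}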
