From 812e5501d070c42ea3166a97a149ba33af0685a5 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Thu, 24 Aug 2023 12:44:21 -0700 Subject: [PATCH 01/10] commit --- .vscode/settings.json | 28 +- .../layers/attention_layers/CMakeLists.txt | 20 + .../LlamaContextAttentionLayer.cc | 809 +++++++++++ .../LlamaContextAttentionLayer.h | 139 ++ .../LlamaDecoderSelfAttentionLayer.cc | 731 ++++++++++ .../LlamaDecoderSelfAttentionLayer.h | 202 +++ ...ensorParallelLlamaContextAttentionLayer.cc | 162 +++ ...TensorParallelLlamaContextAttentionLayer.h | 76 + ...rParallelLlamaDecoderSelfAttentionLayer.cc | 241 +++ ...orParallelLlamaDecoderSelfAttentionLayer.h | 109 ++ src/fastertransformer/models/CMakeLists.txt | 1 + .../models/llama/CMakeLists.txt | 69 + src/fastertransformer/models/llama/Llama.cc | 1287 +++++++++++++++++ src/fastertransformer/models/llama/Llama.h | 237 +++ .../models/llama/LlamaContextDecoder.cc | 646 +++++++++ .../models/llama/LlamaContextDecoder.h | 128 ++ .../models/llama/LlamaDecoder.cc | 400 +++++ .../models/llama/LlamaDecoder.h | 109 ++ .../models/llama/LlamaDecoderLayerWeight.cc | 398 +++++ .../models/llama/LlamaDecoderLayerWeight.h | 74 + .../models/llama/LlamaWeight.cc | 321 ++++ .../models/llama/LlamaWeight.h | 114 ++ .../triton_backend/CMakeLists.txt | 1 + .../triton_backend/llama/CMakeLists.txt | 25 + .../triton_backend/llama/LlamaTritonModel.cc | 273 ++++ .../triton_backend/llama/LlamaTritonModel.h | 87 ++ .../llama/LlamaTritonModelInstance.cc | 264 ++++ .../llama/LlamaTritonModelInstance.h | 82 ++ .../transformer_triton_backend.hpp | 1 + 29 files changed, 7032 insertions(+), 2 deletions(-) create mode 100644 src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc create mode 100644 src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.h create mode 100644 src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc create mode 100644 src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.h create mode 100644 src/fastertransformer/layers/attention_layers/TensorParallelLlamaContextAttentionLayer.cc create mode 100644 src/fastertransformer/layers/attention_layers/TensorParallelLlamaContextAttentionLayer.h create mode 100644 src/fastertransformer/layers/attention_layers/TensorParallelLlamaDecoderSelfAttentionLayer.cc create mode 100644 src/fastertransformer/layers/attention_layers/TensorParallelLlamaDecoderSelfAttentionLayer.h create mode 100644 src/fastertransformer/models/llama/CMakeLists.txt create mode 100644 src/fastertransformer/models/llama/Llama.cc create mode 100644 src/fastertransformer/models/llama/Llama.h create mode 100644 src/fastertransformer/models/llama/LlamaContextDecoder.cc create mode 100644 src/fastertransformer/models/llama/LlamaContextDecoder.h create mode 100644 src/fastertransformer/models/llama/LlamaDecoder.cc create mode 100644 src/fastertransformer/models/llama/LlamaDecoder.h create mode 100644 src/fastertransformer/models/llama/LlamaDecoderLayerWeight.cc create mode 100644 src/fastertransformer/models/llama/LlamaDecoderLayerWeight.h create mode 100644 src/fastertransformer/models/llama/LlamaWeight.cc create mode 100644 src/fastertransformer/models/llama/LlamaWeight.h create mode 100644 src/fastertransformer/triton_backend/llama/CMakeLists.txt create mode 100644 src/fastertransformer/triton_backend/llama/LlamaTritonModel.cc create mode 100644 src/fastertransformer/triton_backend/llama/LlamaTritonModel.h create mode 100644 
src/fastertransformer/triton_backend/llama/LlamaTritonModelInstance.cc create mode 100644 src/fastertransformer/triton_backend/llama/LlamaTritonModelInstance.h diff --git a/.vscode/settings.json b/.vscode/settings.json index 6f535da99..6df8277d0 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -67,6 +67,30 @@ "unordered_set": "cpp", "future": "cpp", "cfenv": "cpp", - "typeindex": "cpp" + "typeindex": "cpp", + "__bit_reference": "cpp", + "__bits": "cpp", + "__config": "cpp", + "__debug": "cpp", + "__errc": "cpp", + "__hash_table": "cpp", + "__locale": "cpp", + "__mutex_base": "cpp", + "__node_handle": "cpp", + "__split_buffer": "cpp", + "__threading_support": "cpp", + "__tree": "cpp", + "__tuple": "cpp", + "__verbose_abort": "cpp", + "bit": "cpp", + "ios": "cpp", + "locale": "cpp", + "queue": "cpp", + "stack": "cpp", + "variant": "cpp", + "__nullptr": "cpp", + "__string": "cpp", + "compare": "cpp", + "concepts": "cpp" } -} \ No newline at end of file +} diff --git a/src/fastertransformer/layers/attention_layers/CMakeLists.txt b/src/fastertransformer/layers/attention_layers/CMakeLists.txt index 1f0e93b1b..628b3083a 100644 --- a/src/fastertransformer/layers/attention_layers/CMakeLists.txt +++ b/src/fastertransformer/layers/attention_layers/CMakeLists.txt @@ -39,6 +39,16 @@ set_property(TARGET DecoderSelfAttentionLayer PROPERTY POSITION_INDEPENDENT_CODE set_property(TARGET DecoderSelfAttentionLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(DecoderSelfAttentionLayer PUBLIC -lcublas -lcudart cublasMMWrapper memory_utils decoder_masked_multihead_attention fpA_intB_gemm int8_gemm tensor nvtx_utils) +add_library(LlamaDecoderSelfAttentionLayer STATIC LlamaDecoderSelfAttentionLayer.cc) +set_property(TARGET LlamaDecoderSelfAttentionLayer PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET LlamaDecoderSelfAttentionLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +target_link_libraries(LlamaDecoderSelfAttentionLayer PUBLIC -lcublas -lcudart cublasMMWrapper memory_utils decoder_masked_multihead_attention fpA_intB_gemm int8_gemm tensor nvtx_utils) + +add_library(LlamaContextAttentionLayer STATIC LlamaContextAttentionLayer.cc) +set_property(TARGET LlamaContextAttentionLayer PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET LlamaContextAttentionLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +target_link_libraries(LlamaContextAttentionLayer PUBLIC -lcublas -lcudart cublasMMWrapper memory_utils unfused_attention_kernels trt_fused_multi_head_attention fpA_intB_gemm int8_gemm nvtx_utils) + add_library(GptContextAttentionLayer STATIC GptContextAttentionLayer.cc) set_property(TARGET GptContextAttentionLayer PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET GptContextAttentionLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) @@ -54,6 +64,11 @@ set_property(TARGET TensorParallelDecoderSelfAttentionLayer PROPERTY POSITION_IN set_property(TARGET TensorParallelDecoderSelfAttentionLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(TensorParallelDecoderSelfAttentionLayer PUBLIC -lcudart DecoderSelfAttentionLayer nccl_utils custom_ar_kernels nvtx_utils) +add_library(TensorParallelLlamaDecoderSelfAttentionLayer STATIC TensorParallelLlamaDecoderSelfAttentionLayer.cc) +set_property(TARGET TensorParallelLlamaDecoderSelfAttentionLayer PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET TensorParallelLlamaDecoderSelfAttentionLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +target_link_libraries(TensorParallelLlamaDecoderSelfAttentionLayer 
PUBLIC -lcudart LlamaDecoderSelfAttentionLayer nccl_utils custom_ar_kernels repeat_kv_kernels nvtx_utils) + add_library(TensorParallelDecoderCrossAttentionLayer STATIC TensorParallelDecoderCrossAttentionLayer.cc) set_property(TARGET TensorParallelDecoderCrossAttentionLayer PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET TensorParallelDecoderCrossAttentionLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) @@ -64,6 +79,11 @@ set_property(TARGET TensorParallelGptContextAttentionLayer PROPERTY POSITION_IND set_property(TARGET TensorParallelGptContextAttentionLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(TensorParallelGptContextAttentionLayer PUBLIC -lcudart GptContextAttentionLayer nccl_utils custom_ar_kernels nvtx_utils) +add_library(TensorParallelLlamaContextAttentionLayer STATIC TensorParallelLlamaContextAttentionLayer.cc) +set_property(TARGET TensorParallelLlamaContextAttentionLayer PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET TensorParallelLlamaContextAttentionLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +target_link_libraries(TensorParallelLlamaContextAttentionLayer PUBLIC -lcudart LlamaContextAttentionLayer nccl_utils custom_ar_kernels repeat_kv_kernels nvtx_utils) + add_library(TensorParallelUnfusedAttentionLayer STATIC TensorParallelUnfusedAttentionLayer.cc) set_property(TARGET TensorParallelUnfusedAttentionLayer PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET TensorParallelUnfusedAttentionLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) diff --git a/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc new file mode 100644 index 000000000..91de7d46d --- /dev/null +++ b/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc @@ -0,0 +1,809 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.h" +#include "src/fastertransformer/kernels/unfused_attention_kernels.h" +#include "src/fastertransformer/utils/nvtx_utils.h" +#include "src/fastertransformer/kernels/repeat_kv_kernels.h" +#include + +namespace fastertransformer { + +template +void LlamaContextAttentionLayer::forward(TensorMap* output_tensors, + TensorMap* input_tensors, + const AttentionWeight* attention_weights) +{ + // input_tensors: + // input_query [token_num, hidden_dimension] + // attention_mask [batch_size, 1, seq_len, seq_len + max_prompt_length] + // attention_type [1] + // is_final_layer [1], bool on cpu + // layer_id [1], int on cpu + // padding_offset, int, [token_num] (optional) + // cu_seqlens, int, [batch_size] (optional) + // d_prefix_prompt_batch [global_batch_size], (optional) + // each element contains ptr with buffer shape[2, local_head_num_, prompt_length, size_per_head] + // d_prefix_prompt_lengths [batch_size], int (optional) + // linear_bias_slopes [head_num] (optional) + + // output_tensors: + // hidden_features [token_num, hidden_dimension] + // key_cache [batch, local_head_num, size_per_head // x, max_seq_len, x] + // value_cache [batch, local_head_num, max_seq_len, size_per_head] + printf("LlamaContextAttentionLayer::forward\n"); + FT_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); + FT_CHECK(output_tensors->at("key_cache").shape.size() == 5); + FT_CHECK(output_tensors->at("value_cache").shape.size() == 4 + || output_tensors->at("value_cache").shape.size() == 3); + const int request_batch_size = input_tensors->at("attention_mask").shape[0]; + const int request_seq_len = input_tensors->at("attention_mask").shape[2]; + const int max_prompt_length = + input_tensors->at("attention_mask").shape[3] - input_tensors->at("attention_mask").shape[2]; + const int layer_id = input_tensors->getVal("layer_id"); + const T** d_prefix_prompt_batch = input_tensors->getPtr("d_prefix_prompt_batch", nullptr); + const int* d_prefix_prompt_lengths = input_tensors->getPtr("d_prefix_prompt_lengths", nullptr); + const int* padding_offset = input_tensors->getPtr("padding_offset", nullptr); + int* cu_seqlens = input_tensors->getPtr("cu_seqlens", nullptr); + T* linear_bias_slopes = input_tensors->getPtr("linear_bias_slopes", nullptr); + /* float* attention_query_dynamic_scale = input_tensors->getPtr("attention_query_dynamic_scale", + * nullptr); */ + + T* attention_out = output_tensors->at("hidden_features").getPtr(); + T* attention_input = input_tensors->at("input_query").getPtr(); + T* attention_mask = input_tensors->at("attention_mask").getPtr(); + + const AttentionType attention_type = input_tensors->getVal("attention_type"); + FT_CHECK_WITH_INFO(attention_type != AttentionType::FUSED_PADDED_MHA, + "Llama Context FUSED_PADDED_MHA is not supported !"); + + printf("attention buffer alloc %d %d\n", request_batch_size, request_seq_len + max_prompt_length); + PUSH_RANGE("attention buffer alloc"); + allocateBuffer(request_batch_size, request_seq_len + max_prompt_length, attention_type != AttentionType::FUSED_MHA); + POP_RANGE; + sync_check_cuda_error(); + printf("attention buffer alloc done\n"); + const bool is_final = input_tensors->at("is_final_layer").getVal(); + + const int m = input_tensors->at("input_query").shape[0]; + + PUSH_RANGE("qkv_gemm"); + +#ifdef SPARSITY_ENABLED + const int m_padded = 8 * div_up(m, 8); + bool use_sparse = sparse_ && cublas_wrapper_->isUseSparse(1, 3 * local_hidden_units_, m_padded, hidden_units_); +#else + 
constexpr bool use_sparse = false; +#endif + + if (use_sparse) { +#ifdef SPARSITY_ENABLED + cublas_wrapper_->SpGemm(CUBLAS_OP_N, + CUBLAS_OP_N, + 3 * local_hidden_units_, + m_padded, + hidden_units_, + attention_weights->query_weight.sp_kernel, + attention_input, + qkv_buf_); +#endif + } + else if (int8_mode_ == 1) { + FT_CHECK(weight_only_int8_fc_runner_.get() != NULL && attention_weights->query_weight.int8_kernel != NULL + && attention_weights->query_weight.weight_only_quant_scale != NULL); + + weight_only_int8_fc_runner_->gemm(attention_input, + reinterpret_cast(attention_weights->query_weight.int8_kernel), + attention_weights->query_weight.weight_only_quant_scale, + qkv_buf_, + m, + 3 * local_hidden_units_, + hidden_units_, + mixed_gemm_workspace_, + mixed_gemm_ws_bytes_, + stream_); + } + else if (int8_mode_ == 2) { + cublas_wrapper_->Int8Gemm(3 * local_hidden_units_, + m, + hidden_units_, + attention_weights->query_weight.int8_kernel, + hidden_units_, + input_tensors->at("input_query").getPtr(), + hidden_units_, + reinterpret_cast(qkv_buf_), + 3 * local_hidden_units_, + attention_weights->query_weight.scale_inter, + true); + } + else { + size_t local_qkv_size = local_hidden_units_ + 2 * local_kv_head_num_ * size_per_head_; + cublas_wrapper_->Gemm(CUBLAS_OP_N, + CUBLAS_OP_N, + local_qkv_size, // n + m, + hidden_units_, // k + attention_weights->query_weight.kernel, + local_qkv_size, // n + attention_input, + hidden_units_, // k + qkv_buf_tmp_, + local_qkv_size /* n */); + invokeRepeatKv(qkv_buf_, + qkv_buf_tmp_, + local_head_num_, + local_kv_head_num_, + size_per_head_, + m, + stream_); + + // { + // const int head_num = 6; + // const int kv_head_num = 2; + // const int size_per_head = 3; + // const int token_num = 3; + // const int qkv_size = (head_num+2*kv_head_num) * size_per_head; + // const int dst_size = 3 * head_num * size_per_head * token_num; + // const int src_size = qkv_size * token_num; + // float * dst = new float[dst_size]; + // float * src = new float[src_size]; + // for (int t=0; t < token_num; t++) { + // for (int i=0; i < qkv_size; i++) { + // if (i < head_num * size_per_head) { + // src[t*qkv_size+i] = i; + // } + // else if (i - head_num * size_per_head < kv_head_num * size_per_head) { + // src[t*qkv_size+i] = i/10.f; + // } else { + // src[t*qkv_size+i] = i*10.f; + // } + // } + // } + + // float* dst_buf = nullptr; + // dst_buf = (float*)allocator_->reMalloc(dst_buf, sizeof(float)*dst_size, true); + // float* src_buf = nullptr; + // src_buf = (float*)allocator_->reMalloc(src_buf, sizeof(float)*src_size, true); + + // cudaMemcpy(src_buf, src, sizeof(float)*src_size, cudaMemcpyHostToDevice); + // for (int t=0; t < token_num; t++) { + // for (int i=0; i < qkv_size; i++) { + // printf("%f ", src[t*qkv_size+i]); + // } + // printf("\n"); + // } + // invokeRepeatKv(dst_buf, + // src_buf, + // head_num, + // kv_head_num, + // size_per_head, + // token_num, + // stream_); + // sync_check_cuda_error(); + // cudaMemcpy(dst, dst_buf, sizeof(float)*dst_size, cudaMemcpyDeviceToHost); + // printf("after: \n"); + // int j = 0; + // for (int t=0; t < token_num; t++) { + // for (int i=0; i < 3 * head_num * size_per_head; i++) { + // printf("%f ", dst[j++]); + // } + // printf("\n"); + // } + + // } + // { + // printf("test\n"); + // int m = 3; + // int n = m; + // int k = 2; + // int st = m*k; + // half* A = new half[st]; + // half* B = new half[st]; + // half* C = new half[m*n]; + // int t = 0; + // for (int i=0; i(float(i)); + // t = i+st; + // B[i] = static_cast(float(i)); + 
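In the default (non-quantized) branch above, the QKV projection produces a packed row of (local_head_num_ + 2 * local_kv_head_num_) * size_per_head_ elements per token, and invokeRepeatKv then expands the grouped K/V heads back to the [Q | K | V] layout of 3 * local_head_num_ * size_per_head_ elements per token that invokeAddFusedQKVBiasTranspose and the masked-attention kernels expect. The kernel itself is not shown in this hunk (it comes from the repeat_kv_kernels target added in the CMake changes); the following host-side reference is only a sketch of the assumed semantics, with illustrative names:

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // Reference expansion for grouped-query attention: each of the kv_head_num K/V heads is
    // replicated head_num / kv_head_num times so the output row is [Q heads | K heads | V heads].
    // This only mirrors what invokeRepeatKv is assumed to do on the GPU; it is not the actual kernel.
    void repeat_kv_reference(std::vector<float>&       dst,
                             const std::vector<float>& src,
                             int head_num, int kv_head_num, int size_per_head, int token_num)
    {
        const int src_row = (head_num + 2 * kv_head_num) * size_per_head;
        const int dst_row = 3 * head_num * size_per_head;
        const int group   = head_num / kv_head_num;  // query heads sharing one K/V head (assumes divisibility)
        dst.assign(static_cast<std::size_t>(token_num) * dst_row, 0.f);
        for (int t = 0; t < token_num; ++t) {
            const float* q = src.data() + static_cast<std::size_t>(t) * src_row;
            const float* k = q + head_num * size_per_head;
            const float* v = k + kv_head_num * size_per_head;
            float*       o = dst.data() + static_cast<std::size_t>(t) * dst_row;
            std::copy(q, q + head_num * size_per_head, o);  // Q is copied unchanged
            for (int h = 0; h < head_num; ++h) {
                const int kv_h = h / group;  // assumed head grouping; hypothetical, not verified against the kernel
                std::copy(k + kv_h * size_per_head, k + (kv_h + 1) * size_per_head,
                          o + (head_num + h) * size_per_head);
                std::copy(v + kv_h * size_per_head, v + (kv_h + 1) * size_per_head,
                          o + (2 * head_num + h) * size_per_head);
            }
        }
    }

For example, with a hypothetical tensor-parallel shard of local_head_num_ = 8, local_kv_head_num_ = 1 and size_per_head_ = 128, the packed row is 8*128 + 2*128 = 1280 elements and the expanded row is 3*8*128 = 3072 elements, which is also why qkv_buf_tmp_ and qkv_buf_ are sized differently in allocateBuffer.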
// printf("%f %f\n", (float)A[i], (float)B[i]); + // } + // half* a_buf = nullptr; + // a_buf = (half*)allocator_->reMalloc(a_buf, sizeof(half)*st, true); + // half* b_buf = nullptr; + // b_buf = (half*)allocator_->reMalloc(b_buf, sizeof(half)*st, true); + // half* c_buf = nullptr; + // c_buf = (half*)allocator_->reMalloc(c_buf, sizeof(half)*m*n, true); + + // cudaMemcpy(a_buf, A, sizeof(half)*st, cudaMemcpyHostToDevice); + // cudaMemcpy(b_buf, B, sizeof(half)*st, cudaMemcpyHostToDevice); + // sync_check_cuda_error(); + // cublas_wrapper_->Gemm(CUBLAS_OP_N, + // CUBLAS_OP_N, + // m, // n + // n, + // k, // k + // a_buf, + // m, // n + // b_buf, + // k, // k + // c_buf, + // m /* n */); + // sync_check_cuda_error(); + // printf("cudaMemcpy\n"); + // cudaMemcpy(C, a_buf, sizeof(half) * m * k, cudaMemcpyDeviceToHost); + // sync_check_cuda_error(); + // for (int i=0; iGemm(CUBLAS_OP_N, + // CUBLAS_OP_N, + // 3 * local_hidden_units_, // n + // m, + // hidden_units_, // k + // attention_weights->query_weight.kernel, + // 3 * local_hidden_units_, // n + // attention_input, + // hidden_units_, // k + // qkv_buf_, + // 3 * local_hidden_units_ /* n */); + } + + { + // printf("qkv_buf_\n"); + // int sz = 100; + // T *qkv_buf = new T[sz]; + // cudaMemcpy(qkv_buf, qkv_buf_, sizeof(T)*sz, cudaMemcpyDeviceToHost); + // sync_check_cuda_error(); + // for (int i=0; iGemm: done\n"); + // IDEA: append prefix prompt key value here + PrefixPromptBatchWeightsParam param{d_prefix_prompt_batch, + d_prefix_prompt_lengths, + max_prompt_length, + (size_t)layer_id * 2 * local_head_num_ * size_per_head_}; + + if (padding_offset != nullptr) { + // q_buf_2_, k_buf_2_ and v_buf_2_ are continuous + cudaMemsetAsync( + q_buf_2_, 0, request_batch_size * request_seq_len * 3 * local_hidden_units_ * sizeof(T), stream_); + } + invokeAddFusedQKVBiasTranspose(q_buf_2_, + k_buf_2_, + v_buf_2_, + param, // prefix prompt + qkv_buf_, + attention_weights->query_weight.bias, + padding_offset, + request_batch_size, + request_seq_len, + m, + local_head_num_, + size_per_head_, + rotary_embedding_dim_, + neox_rotary_style_, + attention_weights->query_weight.scale_out, + int8_mode_, + stream_); + sync_check_cuda_error(); + + const int max_seq_len = (int)(output_tensors->at("key_cache").shape[3]); // max output seq length + // Use batch major + // put k/v_buf from shape [B, H, PL + L, Dh] + // to cache [B, H, Dh/x, PL + L, x] and [B, H, PL + L, Dh/x, x], PL denotes prompt length + invokeTranspose4dBatchMajor(output_tensors->getPtr("key_cache"), + output_tensors->getPtr("value_cache"), + k_buf_2_, + v_buf_2_, + request_batch_size, + max_prompt_length + request_seq_len, // max input length + prefix prompt length + max_seq_len, + size_per_head_, + local_head_num_, + stream_); + // IDEA : after this, k_cache = (batch_size, num_heads, Dh/x, prefix_prompt_len + L, x) + // k_cache = (batch_size, num_heads, prefix_prompt_len + L, Dh) + sync_check_cuda_error(); + + // TODO: fmha kernels doesn't support different seq lengths of q and kv + if (attention_type == AttentionType::FUSED_MHA) { + dispatcher_fp16->setup_causal_masked_fmha(request_seq_len, request_batch_size); + dispatcher_fp16->run_causal_masked_fmha(qkv_buf_, cu_seqlens, qkv_buf_3_, true, stream_); + } + // NOTE: qkv buffer shape (batch_size, num_heads,L or prompt_len + L, Dh) + + POP_RANGE; + if (is_final == false) { + const cudaDataType_t gemm_data_type = getCudaDataType(); + const int attention_seq_len_1 = request_seq_len; // q length + const int attention_seq_len_2 = max_prompt_length + 
request_seq_len; // kv length + const T qk_scale = static_cast(1.0f / sqrtf(size_per_head_ * 1.0f)); + if (attention_type != AttentionType::FUSED_MHA) { + if (is_qk_buf_float_ == true && gemm_data_type != CUDA_R_32F) { + PUSH_RANGE("Q*K batch gemm"); + cublas_wrapper_->stridedBatchedGemm(CUBLAS_OP_T, + CUBLAS_OP_N, + attention_seq_len_2, // n + attention_seq_len_1, // m + size_per_head_, // k + 1.0f, + k_buf_2_, + gemm_data_type, + size_per_head_, // k + attention_seq_len_2 * size_per_head_, // n * k + q_buf_2_, + gemm_data_type, + size_per_head_, // k + attention_seq_len_1 * size_per_head_, // m * k + 0.0f, + qk_buf_float_, + CUDA_R_32F, + attention_seq_len_2, // n + attention_seq_len_2 * attention_seq_len_1, + request_batch_size * local_head_num_, // global batch size + CUDA_R_32F); + + sync_check_cuda_error(); + POP_RANGE; + + PUSH_RANGE("softmax"); + MaskedSoftmaxParam param; + param.attention_score = qk_buf_; // (batch_size, head_num, q_length, k_length) + param.qk = qk_buf_float_; // (batch_size, head_num, q_length, k_length) + param.attention_mask = attention_mask; // (batch_size, q_length, k_length) + param.batch_size = request_batch_size; + param.q_length = attention_seq_len_1; + param.k_length = attention_seq_len_2; + param.num_heads = local_head_num_; + param.qk_scale = qk_scale; + param.linear_bias_slopes = const_cast(linear_bias_slopes); // (head_num,), optional + invokeMaskedSoftmax(param, stream_); + sync_check_cuda_error(); + POP_RANGE; + } + else { + PUSH_RANGE("Q*K batch gemm"); + cublas_wrapper_->stridedBatchedGemm(CUBLAS_OP_T, + CUBLAS_OP_N, + attention_seq_len_2, + attention_seq_len_1, + size_per_head_, + k_buf_2_, + size_per_head_, + attention_seq_len_2 * size_per_head_, + q_buf_2_, + size_per_head_, + attention_seq_len_1 * size_per_head_, + qk_buf_, + attention_seq_len_2, + attention_seq_len_2 * attention_seq_len_1, + request_batch_size * local_head_num_); + + POP_RANGE; + PUSH_RANGE("softmax"); + MaskedSoftmaxParam param; + param.attention_score = qk_buf_; // (batch_size, head_num, q_length, k_length) + param.qk = qk_buf_; // (batch_size, head_num, q_length, k_length) + param.attention_mask = attention_mask; // (batch_size, q_length, k_length) + param.batch_size = request_batch_size; + param.q_length = attention_seq_len_1; + param.k_length = attention_seq_len_2; + param.num_heads = local_head_num_; + param.qk_scale = qk_scale; + param.linear_bias_slopes = const_cast(linear_bias_slopes); // (head_num,), optional + invokeMaskedSoftmax(param, stream_); + sync_check_cuda_error(); + POP_RANGE; + } + + PUSH_RANGE("QK*V batch gemm"); + + cublas_wrapper_->stridedBatchedGemm(CUBLAS_OP_N, + CUBLAS_OP_N, + size_per_head_, + attention_seq_len_1, + attention_seq_len_2, + v_buf_2_, + size_per_head_, + attention_seq_len_2 * size_per_head_, + qk_buf_, + attention_seq_len_2, + attention_seq_len_1 * attention_seq_len_2, + qkv_buf_2_, + size_per_head_, + attention_seq_len_1 * size_per_head_, + request_batch_size * local_head_num_); + + // transpose (batch_size, num_heads, L, Dh) to (batch_size, L, num_heads * Dh) + if (padding_offset == nullptr) { + invokeTransposeQKV(qkv_buf_3_, + qkv_buf_2_, + request_batch_size, + attention_seq_len_1, + local_head_num_, + size_per_head_, + attention_weights->attention_output_weight.scale, + int8_mode_, + stream_); + sync_check_cuda_error(); + } + else { + invokeTransposeAttentionOutRemovePadding(qkv_buf_2_, + qkv_buf_3_, + m, + request_batch_size, + attention_seq_len_1, + local_head_num_, + size_per_head_, + padding_offset, + 
attention_weights->attention_output_weight.scale, + int8_mode_, + stream_); + } + POP_RANGE; + } + sync_check_cuda_error(); + + PUSH_RANGE("proj gemm"); +#ifdef SPARSITY_ENABLED + bool use_sparse = sparse_ && cublas_wrapper_->isUseSparse(1, hidden_units_, m_padded, local_hidden_units_); +#endif + + if (use_sparse) { +#ifdef SPARSITY_ENABLED + cublas_wrapper_->SpGemm(CUBLAS_OP_N, + CUBLAS_OP_N, + hidden_units_, + m_padded, + local_hidden_units_, + attention_weights->attention_output_weight.sp_kernel, + qkv_buf_3_, + attention_out); +#endif + } + else { + if (int8_mode_ == 1) { + FT_CHECK(weight_only_int8_fc_runner_.get() != NULL + && attention_weights->attention_output_weight.int8_kernel != NULL + && attention_weights->attention_output_weight.weight_only_quant_scale != NULL); + + weight_only_int8_fc_runner_->gemm( + qkv_buf_3_, + reinterpret_cast(attention_weights->attention_output_weight.int8_kernel), + attention_weights->attention_output_weight.weight_only_quant_scale, + attention_out, + m, + hidden_units_, + local_hidden_units_, + mixed_gemm_workspace_, + mixed_gemm_ws_bytes_, + stream_); + } + else if (int8_mode_ == 2) { + int8_fc_runner_->gemm(reinterpret_cast(qkv_buf_3_), + attention_weights->attention_output_weight.int8_kernel, + QuantMode::PerTensorQuant, + attention_weights->attention_output_weight.scale_inter, + attention_weights->attention_output_weight.scale_out, + output_tensors->at("hidden_features").getPtr(), + m, + hidden_units_, + local_hidden_units_, + nullptr, + 0, + stream_); + } + else { + cublas_wrapper_->Gemm(CUBLAS_OP_N, + CUBLAS_OP_N, + hidden_units_, + m, + local_hidden_units_, + attention_weights->attention_output_weight.kernel, + hidden_units_, + qkv_buf_3_, + local_hidden_units_, + attention_out, + hidden_units_); + } + } + POP_RANGE; + } + + if (is_free_buffer_after_forward_ == true) { + freeBuffer(); + } + sync_check_cuda_error(); + FT_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); +} + +template +LlamaContextAttentionLayer::LlamaContextAttentionLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t kv_head_num, + size_t size_per_head, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool is_qk_buf_float, + bool sparse, + int int8_mode): + BaseAttentionLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, sparse), + max_batch_size_(max_batch_size), + max_seq_len_(max_seq_len), + head_num_(head_num), + kv_head_num_(kv_head_num), + size_per_head_(size_per_head), + hidden_units_(head_num * size_per_head), + local_head_num_(head_num), + local_kv_head_num_(kv_head_num), + local_hidden_units_(local_head_num_ * size_per_head), + rotary_embedding_dim_(0), + neox_rotary_style_(false), + is_qk_buf_float_(is_qk_buf_float || int8_mode == 2), + weight_only_int8_fc_runner_(int8_mode == 1 ? std::make_shared>() : nullptr), + int8_fc_runner_(int8_mode == 2 ? 
std::make_shared>() : nullptr), + int8_mode_(int8_mode) +{ +} + +template +LlamaContextAttentionLayer::LlamaContextAttentionLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t local_head_num, + size_t local_kv_head_num, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool is_qk_buf_float, + bool sparse, + int int8_mode): + BaseAttentionLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, sparse), + max_batch_size_(max_batch_size), + max_seq_len_(max_seq_len), + head_num_(head_num), + kv_head_num_(kv_head_num), + size_per_head_(size_per_head), + hidden_units_(head_num * size_per_head), + local_head_num_(local_head_num), + local_kv_head_num_(local_kv_head_num), + local_hidden_units_(local_head_num_ * size_per_head), + rotary_embedding_dim_(0), + neox_rotary_style_(false), + is_qk_buf_float_(is_qk_buf_float || int8_mode == 2), + weight_only_int8_fc_runner_(int8_mode == 1 ? std::make_shared>() : nullptr), + int8_fc_runner_(int8_mode == 2 ? std::make_shared>() : nullptr), + int8_mode_(int8_mode) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + dispatcher_fp16.reset(new FusedMHARunnerFP16v2(local_head_num_, size_per_head_, sm_, 1.0f)); +} + +template +LlamaContextAttentionLayer::LlamaContextAttentionLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t local_head_num, + size_t local_kv_head_num, + size_t rotary_embedding_dim, + bool neox_rotary_style, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool is_qk_buf_float, + bool sparse, + int int8_mode): + BaseAttentionLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, sparse), + max_batch_size_(max_batch_size), + max_seq_len_(max_seq_len), + head_num_(head_num), + kv_head_num_(kv_head_num), + size_per_head_(size_per_head), + hidden_units_(head_num * size_per_head), + local_head_num_(local_head_num), + local_kv_head_num_(local_kv_head_num), + local_hidden_units_(local_head_num_ * size_per_head), + rotary_embedding_dim_(rotary_embedding_dim), + neox_rotary_style_(neox_rotary_style), + is_qk_buf_float_(is_qk_buf_float), + weight_only_int8_fc_runner_(int8_mode == 1 ? std::make_shared>() : nullptr), + int8_fc_runner_(int8_mode == 2 ? 
std::make_shared>() : nullptr), + int8_mode_(int8_mode) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + dispatcher_fp16.reset(new FusedMHARunnerFP16v2(local_head_num_, size_per_head_, sm_, 1.0f)); +} + +template +LlamaContextAttentionLayer::LlamaContextAttentionLayer(LlamaContextAttentionLayer const& attention_layer): + BaseAttentionLayer(attention_layer.stream_, + attention_layer.cublas_wrapper_, + attention_layer.allocator_, + attention_layer.is_free_buffer_after_forward_, + attention_layer.sparse_), + max_batch_size_(attention_layer.max_batch_size_), + max_seq_len_(attention_layer.max_seq_len_), + head_num_(attention_layer.head_num_), + kv_head_num_(attention_layer.kv_head_num_), + size_per_head_(attention_layer.size_per_head_), + hidden_units_(attention_layer.hidden_units_), + local_head_num_(attention_layer.local_head_num_), + local_kv_head_num_(attention_layer.local_kv_head_num_), + local_hidden_units_(attention_layer.local_hidden_units_), + rotary_embedding_dim_(attention_layer.rotary_embedding_dim_), + neox_rotary_style_(attention_layer.neox_rotary_style_), + is_qk_buf_float_(attention_layer.is_qk_buf_float_), + weight_only_int8_fc_runner_(attention_layer.weight_only_int8_fc_runner_), + int8_fc_runner_(attention_layer.int8_fc_runner_), + int8_mode_(attention_layer.int8_mode_) +{ +} + +template +LlamaContextAttentionLayer::~LlamaContextAttentionLayer() +{ + cublas_wrapper_ = nullptr; + freeBuffer(); +} + +template +void LlamaContextAttentionLayer::allocateBuffer() +{ + FT_CHECK(false); +} + +template +void LlamaContextAttentionLayer::allocateBuffer(size_t batch_size, size_t seq_len, bool allocate_qk_buf) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + // const auto type_size = int8_mode_ == 2 ? sizeof(int8_t) : sizeof(T); + // NOTE (perkzz): use sizeof(T) here for cutlass int8 kernels. + const auto type_size = sizeof(T); + qkv_buf_ = (T*)allocator_->reMalloc(qkv_buf_, type_size * 3 * batch_size * seq_len * local_hidden_units_, true); + size_t local_qkv_size = local_hidden_units_ + 2 * local_kv_head_num_ * size_per_head_; + qkv_buf_tmp_ = (T*)allocator_->reMalloc(qkv_buf_tmp_, type_size * batch_size * seq_len * local_qkv_size, true); + q_buf_2_ = (T*)allocator_->reMalloc(q_buf_2_, sizeof(T) * batch_size * seq_len * 3 * local_hidden_units_, true); + k_buf_2_ = q_buf_2_ + batch_size * seq_len * local_hidden_units_; + v_buf_2_ = k_buf_2_ + batch_size * seq_len * local_hidden_units_; + + // save memory usage when using fmha + if (allocate_qk_buf) { + qk_buf_ = (T*)allocator_->reMalloc(qk_buf_, sizeof(T) * batch_size * local_head_num_ * seq_len * seq_len, true); + } + else { + allocator_->free((void**)(&qk_buf_)); + } + qkv_buf_2_ = (T*)allocator_->reMalloc(qkv_buf_2_, sizeof(T) * batch_size * seq_len * local_hidden_units_, true); + qkv_buf_3_ = (T*)allocator_->reMalloc(qkv_buf_3_, type_size * batch_size * seq_len * local_hidden_units_, true); + + if (is_qk_buf_float_ == true) { + if (allocate_qk_buf) { + qk_buf_float_ = (float*)allocator_->reMalloc( + qk_buf_float_, sizeof(float) * batch_size * local_head_num_ * seq_len * seq_len, true); + } + else { + allocator_->free((void**)(&qk_buf_float_)); + } + } + + if (int8_mode_ == 1) { + // We use max_size for n and k since we reuse buffers for both FCs and want to allocate the max + // possible memory that would be required by any of the individual gemms. 
+ const int max_size = std::max(hidden_units_, 3 * local_hidden_units_); + mixed_gemm_ws_bytes_ = weight_only_int8_fc_runner_->getWorkspaceSize(batch_size * seq_len, max_size, max_size); + mixed_gemm_workspace_ = (char*)allocator_->reMalloc(mixed_gemm_workspace_, mixed_gemm_ws_bytes_, false); + } + + if (int8_mode_ == 1) { + // We use max_size for n and k since we reuse buffers for both FCs and want to allocate the max + // possible memory that would be required by any of the individual gemms. + const int max_size = std::max(hidden_units_, 3 * local_hidden_units_); + mixed_gemm_ws_bytes_ = weight_only_int8_fc_runner_->getWorkspaceSize(batch_size * seq_len, max_size, max_size); + mixed_gemm_workspace_ = (char*)allocator_->reMalloc(mixed_gemm_workspace_, mixed_gemm_ws_bytes_, false); + } + else if (int8_mode_ == 2) { + const int max_size = std::max(hidden_units_, 3 * local_hidden_units_); + int8_gemm_ws_bytes_ = int8_fc_runner_->getWorkspaceSize(batch_size * seq_len, max_size, max_size); + int8_gemm_workspace_ = (char*)allocator_->reMalloc(int8_gemm_workspace_, int8_gemm_ws_bytes_, false); + } + is_allocate_buffer_ = true; +} + +template +void LlamaContextAttentionLayer::freeBuffer() +{ + if (is_allocate_buffer_) { + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + allocator_->free((void**)(&qkv_buf_)); + allocator_->free((void**)(&qkv_buf_tmp_)); + allocator_->free((void**)(&q_buf_2_)); + allocator_->free((void**)(&qk_buf_)); + allocator_->free((void**)(&qkv_buf_2_)); + allocator_->free((void**)(&qkv_buf_3_)); + + if (is_qk_buf_float_ == true) { + allocator_->free((void**)(&qk_buf_float_)); + } + + allocator_->free((void**)(&mixed_gemm_workspace_)); + mixed_gemm_ws_bytes_ = 0; + + allocator_->free((void**)(&int8_gemm_workspace_)); + int8_gemm_ws_bytes_ = 0; + + is_allocate_buffer_ = false; + } +} + +template class LlamaContextAttentionLayer; +template class LlamaContextAttentionLayer; +#ifdef ENABLE_BF16 +template class LlamaContextAttentionLayer<__nv_bfloat16>; +#endif + +} // namespace fastertransformer diff --git a/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.h b/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.h new file mode 100644 index 000000000..50837db23 --- /dev/null +++ b/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.h @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "3rdparty/trt_fused_multihead_attention/qkvToContext.h" +#include "src/fastertransformer/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h" +#include "src/fastertransformer/kernels/cutlass_kernels/int8_gemm/int8_gemm.h" +#include "src/fastertransformer/layers/attention_layers/BaseAttentionLayer.h" + +namespace fastertransformer { + +template +class LlamaContextAttentionLayer: public BaseAttentionLayer { +private: + // buffer handling + size_t max_batch_size_ = 0; + size_t max_seq_len_ = 0; + + // metadata + const size_t head_num_; + const size_t kv_head_num_; + const size_t size_per_head_; + const size_t hidden_units_; + const size_t local_head_num_; + const size_t local_kv_head_num_; + const size_t local_hidden_units_; + const size_t rotary_embedding_dim_; + const bool neox_rotary_style_; + + // fmha runner + int sm_ = getSMVersion(); + std::unique_ptr dispatcher_fp16; + + void allocateBuffer() override; + void allocateBuffer(size_t batch_size, size_t seq_len, bool allocate_qk_buf); + void freeBuffer() override; + + using BaseAttentionLayer::is_free_buffer_after_forward_; + using BaseAttentionLayer::is_allocate_buffer_; + using BaseAttentionLayer::cublas_wrapper_; + + bool is_qk_buf_float_; + + std::shared_ptr> weight_only_int8_fc_runner_; + std::shared_ptr> int8_fc_runner_; + +protected: + using BaseAttentionLayer::allocator_; + using BaseAttentionLayer::stream_; + using BaseAttentionLayer::sparse_; + T* qkv_buf_ = nullptr; + T* qkv_buf_tmp_ = nullptr; + T* q_buf_2_ = nullptr; + T* k_buf_2_ = nullptr; + T* v_buf_2_ = nullptr; + T* qk_buf_ = nullptr; + float* qk_buf_float_ = nullptr; + T* qkv_buf_2_ = nullptr; + T* qkv_buf_3_ = nullptr; + char* mixed_gemm_workspace_ = nullptr; + size_t mixed_gemm_ws_bytes_ = 0; + char* int8_gemm_workspace_ = nullptr; + size_t int8_gemm_ws_bytes_ = 0; + + // int8_mode_ == 0 means we don't use any mechanism related to INT8. 
+ // int8_mode_ == 1 for weight quantized only gemm for GPT + // int8_mode_ == 2 for SmoothQuant O3 (per tensor scales) + const int int8_mode_ = 0; + +public: + LlamaContextAttentionLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t kv_head_num, + size_t size_per_head, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool is_qk_buf_float, + bool sparse = false, + int int8_mode = 0); + + LlamaContextAttentionLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t local_head_num, + size_t local_kv_head_num, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool is_qk_buf_float, + bool sparse = false, + int int8_mode = 0); + + LlamaContextAttentionLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t local_head_num, + size_t local_kv_head_num, + size_t rotary_embedding_dim, + bool neox_rotary_style_, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool is_qk_buf_float, + bool sparse = false, + int int8_mode = 0); + + LlamaContextAttentionLayer(LlamaContextAttentionLayer const& attention_layer); + + virtual ~LlamaContextAttentionLayer(); + + void + forward(TensorMap* output_tensors, TensorMap* input_tensors, const AttentionWeight* attention_weights) override; +}; + +} // namespace fastertransformer diff --git a/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc new file mode 100644 index 000000000..5d12ff9a4 --- /dev/null +++ b/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc @@ -0,0 +1,731 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/utils/logger.h" +#include "src/fastertransformer/utils/memory_utils.h" +#include "src/fastertransformer/kernels/repeat_kv_kernels.h" +#include "src/fastertransformer/utils/nvtx_utils.h" + +namespace fastertransformer { + +template +struct SATypeConverter { + using Type = T; +}; + +template<> +struct SATypeConverter { + using Type = uint16_t; +}; + +template +void fusedQKV_masked_attention_dispatch(const T* qkv_buf, + const T* qkv_bias, + const T* relative_attention_bias, + T* key_cache, + T* value_cache, + const int* cache_indir, + T* context_buf, + const bool* finished, + const int* sequence_lengths, + const int max_batch_size, + const int inference_batch_size, + const int beam_width, + const int head_num, + const int size_per_head, + const int rotary_embedding_dim, + const bool neox_rotary_style, + const int memory_max_len, + const int* prefix_prompt_lengths, + const int max_prefix_prompt_length, + const int max_input_len, + const int* total_padding_tokens, + const int step, + const float q_scaling, + const int relative_attention_bias_stride, + const T* linear_bias_slopes, + const bool* masked_tokens, + const int* ia3_tasks, + const T* ia3_key_weights, + const T* ia3_value_weights, + const float* qkv_scale_out, + const float* attention_out_scale, + const int int8_mode, + cudaStream_t stream) +{ + using DataType = typename SATypeConverter::Type; + // Prepare the parameters. + Masked_multihead_attention_params params; + memset(¶ms, 0, sizeof(params)); + int hidden_units = head_num * size_per_head; + if (qkv_bias != nullptr) { + params.q_bias = reinterpret_cast(qkv_bias); + params.k_bias = reinterpret_cast(qkv_bias) + hidden_units; + params.v_bias = reinterpret_cast(qkv_bias) + 2 * hidden_units; + } + else { + params.q_bias = nullptr; + params.k_bias = nullptr; + params.v_bias = nullptr; + } + + // Set the output buffer. + params.out = reinterpret_cast(context_buf); + + // Set the input buffers. 
+ params.q = reinterpret_cast(qkv_buf); + if (int8_mode != 2) { + params.k = reinterpret_cast(qkv_buf) + hidden_units; + params.v = reinterpret_cast(qkv_buf) + 2 * hidden_units; + } + else { + params.k = reinterpret_cast(reinterpret_cast(qkv_buf) + hidden_units); + params.v = reinterpret_cast(reinterpret_cast(qkv_buf) + 2 * hidden_units); + } + params.stride = 3 * hidden_units; + params.finished = const_cast(finished); + + params.k_cache = reinterpret_cast(key_cache); + params.v_cache = reinterpret_cast(value_cache); + params.cache_indir = cache_indir; + params.batch_size = inference_batch_size; + params.beam_width = beam_width; + params.memory_max_len = memory_max_len; + params.prefix_prompt_lengths = prefix_prompt_lengths; + params.max_prefix_prompt_length = max_prefix_prompt_length; + params.length_per_sample = sequence_lengths; // max_input_length + current output length + // timestep adding max_prefix_prompt_length for shared memory size calculation and rotary embedding computation + params.timestep = step + max_prefix_prompt_length - 1; + params.num_heads = head_num; + params.hidden_size_per_head = size_per_head; + params.rotary_embedding_dim = rotary_embedding_dim; + params.neox_rotary_style = neox_rotary_style; + // Note: keep norm factor (sqrt(K_dim)) when adopting megatron T5 structure (may adjust) + params.inv_sqrt_dh = 1.F / (sqrtf((float)params.hidden_size_per_head) * q_scaling); + + params.total_padding_tokens = total_padding_tokens; + if (relative_attention_bias != nullptr) { + params.relative_attention_bias = reinterpret_cast(relative_attention_bias); + } + params.relative_attention_bias_stride = relative_attention_bias_stride; + params.masked_tokens = masked_tokens; + + // The slope of linear position bias per head, e.g., ALiBi. + if (linear_bias_slopes != nullptr) { + params.linear_bias_slopes = reinterpret_cast(linear_bias_slopes); + } + params.max_input_length = max_input_len; + + params.ia3_tasks = ia3_tasks; + params.ia3_key_weights = reinterpret_cast(ia3_key_weights); + params.ia3_value_weights = reinterpret_cast(ia3_value_weights); + + params.int8_mode = int8_mode; + if (int8_mode == 2) { + params.qkv_scale_out = qkv_scale_out; + params.attention_out_scale = attention_out_scale; + } + + PUSH_RANGE("scaled dot-product fusion"); + masked_multihead_attention(params, stream); + POP_RANGE; +} + +#define INSTANTIATE_FUSEDQKV_MASKED_ATTENTION_DISPATCH(T) \ + template void fusedQKV_masked_attention_dispatch(const T* qkv_buf, \ + const T* qkv_bias, \ + const T* relative_attention_bias, \ + T* key_cache, \ + T* value_cache, \ + const int* cache_indir, \ + T* context_buf, \ + const bool* finished, \ + const int* sequence_lengths, \ + const int max_batch_size, \ + const int inference_batch_size, \ + const int beam_width, \ + const int head_num, \ + const int size_per_head, \ + const int rotary_embedding_dim, \ + const bool neox_rotary_style, \ + const int memory_max_len, \ + const int* prefix_prompt_lengths, \ + const int max_prefix_prompt_length, \ + const int max_input_len, \ + const int* total_padding_tokens, \ + const int step, \ + const float q_scaling, \ + const int relative_attention_bias_stride, \ + const T* linear_bias_slopes, \ + const bool* masked_tokens, \ + const int* ia3_tasks, \ + const T* ia3_key_weights, \ + const T* ia3_value_weights, \ + const float* qkv_scale_out, \ + const float* attention_out_scale, \ + const int int8_mode, \ + cudaStream_t stream) + +INSTANTIATE_FUSEDQKV_MASKED_ATTENTION_DISPATCH(float); 
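The key/value caches consumed by this dispatch follow the shapes documented in the surrounding forward() comments: key_cache is [batch, local_head_num, size_per_head/x, memory_max_len, x] and value_cache is [batch, local_head_num, memory_max_len, size_per_head]. Note that the caches are sized by local_head_num_ rather than local_kv_head_num_, because K/V are expanded by invokeRepeatKv before they are written. A compile-time sketch of the resulting per-layer footprint, with illustrative numbers only (not taken from this patch):

    #include <cstddef>

    // Hypothetical configuration: half precision, one tensor-parallel shard of a large model.
    constexpr std::size_t batch          = 8;
    constexpr std::size_t local_head_num = 8;     // the caches hold the *expanded* head count
    constexpr std::size_t size_per_head  = 128;
    constexpr std::size_t memory_max_len = 2048;
    constexpr std::size_t elt_size       = 2;     // sizeof(half)

    // Key cache plus value cache, each batch * local_head_num * memory_max_len * size_per_head elements.
    constexpr std::size_t kv_cache_bytes_per_layer =
        2 * batch * local_head_num * memory_max_len * size_per_head * elt_size;

    static_assert(kv_cache_bytes_per_layer == 64ull * 1024 * 1024,
                  "64 MiB per layer for this example configuration");

Storing only the local_kv_head_num_ grouped heads would shrink this figure by a factor of local_head_num_ / local_kv_head_num_; the expanded layout instead keeps the existing cache transpose and masked-attention kernels unchanged.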
+INSTANTIATE_FUSEDQKV_MASKED_ATTENTION_DISPATCH(half); +#ifdef ENABLE_BF16 +INSTANTIATE_FUSEDQKV_MASKED_ATTENTION_DISPATCH(__nv_bfloat16); +#endif + +#undef INSTANTIATE_FUSEDQKV_MASKED_ATTENTION_DISPATCH + +template +void LlamaDecoderSelfAttentionLayer::allocateBuffer() +{ + FT_CHECK_WITH_INFO(false, "Deprecated. Use `allocateBuffer(size_t batch_size)` instead"); +} + +template +void LlamaDecoderSelfAttentionLayer::allocateBuffer(size_t batch_size) +{ + const size_t type_size = int8_mode_ == 2 ? sizeof(int8_t) : sizeof(T); + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + const size_t local_q_kv_head_num = local_head_num_ + 2 * local_kv_head_num_; + qkv_buf_tmp_ = + reinterpret_cast(allocator_->reMalloc(qkv_buf_tmp_, type_size * batch_size * size_per_head_ * local_q_kv_head_num, false)); + qkv_buf_ = + reinterpret_cast(allocator_->reMalloc(qkv_buf_, type_size * batch_size * 3 * local_hidden_units_, false)); + context_buf_ = + reinterpret_cast(allocator_->reMalloc(context_buf_, type_size * batch_size * local_hidden_units_, false)); + + if (int8_mode_ == 1) { + // We use max_size for n and k since we reuse buffers for both FCs and want to allocate the max + // possible memory that would be required by any of the individual gemms. + const int max_size = std::max(d_model_, 3 * local_hidden_units_); + mixed_gemm_ws_bytes_ = weight_only_int8_fc_runner_->getWorkspaceSize(batch_size, max_size, max_size); + mixed_gemm_workspace_ = (char*)allocator_->reMalloc(mixed_gemm_workspace_, mixed_gemm_ws_bytes_, false); + } + else if (int8_mode_ == 2) { + const int max_size = std::max(d_model_, 3 * local_hidden_units_); + int8_gemm_ws_bytes_ = int8_fc_runner_->getWorkspaceSize(batch_size, max_size, max_size); + int8_gemm_workspace_ = (char*)allocator_->reMalloc(int8_gemm_workspace_, int8_gemm_ws_bytes_, false); + } + + is_allocate_buffer_ = true; +} + +template +void LlamaDecoderSelfAttentionLayer::freeBuffer() +{ + if (is_allocate_buffer_) { + allocator_->free((void**)(&qkv_buf_)); + allocator_->free((void**)(&qkv_buf_tmp_)); + allocator_->free((void**)(&context_buf_)); + is_allocate_buffer_ = false; + + if (mixed_gemm_workspace_) { + allocator_->free((void**)(&mixed_gemm_workspace_)); + mixed_gemm_ws_bytes_ = 0; + } + } +} + +template +bool LlamaDecoderSelfAttentionLayer::isValidBatchSize(size_t batch_size) +{ + if (batch_size <= max_batch_size_) { + return true; + } + else { + freeBuffer(); + max_batch_size_ = batch_size * 1.2; + return true; + } +} + +template +LlamaDecoderSelfAttentionLayer::LlamaDecoderSelfAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t local_head_num, + size_t local_kv_head_num, + size_t rotary_embedding_dim, + bool neox_rotary_style, + size_t d_model, + const float q_scaling, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse, + int int8_mode): + BaseAttentionLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, sparse), + max_batch_size_(max_batch_size), + head_num_(head_num), + kv_head_num_(kv_head_num), + size_per_head_(size_per_head), + hidden_units_(head_num_ * size_per_head_), + local_head_num_(local_head_num), + local_kv_head_num_(local_kv_head_num), + local_hidden_units_(local_head_num_ * size_per_head_), + rotary_embedding_dim_(rotary_embedding_dim), + neox_rotary_style_(neox_rotary_style), + d_model_(d_model), + q_scaling_(q_scaling), + int8_fc_runner_(int8_mode == 2 ? 
std::make_shared>() : nullptr), + int8_mode_(int8_mode) +{ + FT_CHECK(size_per_head_ == 32 || size_per_head_ == 48 || size_per_head_ == 64 || size_per_head_ == 80 + || size_per_head_ == 96 || size_per_head_ == 128 || size_per_head_ == 144 || size_per_head_ == 160 + || size_per_head_ == 192 || size_per_head_ == 224 || size_per_head_ == 256); + if (int8_mode_ == 1) { + FT_CHECK_WITH_INFO(!(std::is_same::value), "Weight only quant not supported for fp32."); + weight_only_int8_fc_runner_ = std::make_shared>(); + } +} + +template +LlamaDecoderSelfAttentionLayer::LlamaDecoderSelfAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t kv_head_num, + size_t size_per_head, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse, + int int8_mode): + LlamaDecoderSelfAttentionLayer(max_batch_size, + head_num, + kv_head_num, + size_per_head, + head_num, + kv_head_num, + 0, + false, + head_num * size_per_head, + 1.0f, + stream, + cublas_wrapper, + allocator, + is_free_buffer_after_forward, + sparse, + int8_mode) +{ +} + +template +LlamaDecoderSelfAttentionLayer::LlamaDecoderSelfAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t kv_head_num, + size_t size_per_head, + const float q_scaling, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse, + int int8_mode): + LlamaDecoderSelfAttentionLayer(max_batch_size, + head_num, + kv_head_num, + size_per_head, + head_num, + kv_head_num, + 0, + false, + head_num * size_per_head, + q_scaling, + stream, + cublas_wrapper, + allocator, + is_free_buffer_after_forward, + sparse, + int8_mode) +{ +} + +template +LlamaDecoderSelfAttentionLayer::LlamaDecoderSelfAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t local_head_num, + size_t local_kv_head_num, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse, + int int8_mode): + LlamaDecoderSelfAttentionLayer(max_batch_size, + head_num, + kv_head_num, + size_per_head, + local_head_num, + local_kv_head_num, + 0, + false, + head_num * size_per_head, + 1.0f, + stream, + cublas_wrapper, + allocator, + is_free_buffer_after_forward, + sparse, + int8_mode) +{ +} + +template +LlamaDecoderSelfAttentionLayer::LlamaDecoderSelfAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t local_head_num, + size_t local_kv_head_num, + size_t d_model, + const float q_scaling, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse, + int int8_mode): + LlamaDecoderSelfAttentionLayer(max_batch_size, + head_num, + kv_head_num, + size_per_head, + local_head_num, + local_kv_head_num, + 0, + false, + d_model, + q_scaling, + stream, + cublas_wrapper, + allocator, + is_free_buffer_after_forward, + sparse, + int8_mode) +{ +} + +template +LlamaDecoderSelfAttentionLayer::LlamaDecoderSelfAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t local_head_num, + size_t local_kv_head_num, + size_t rotary_embedding_dim, + bool neox_rotary_style, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse, + int int8_mode): + LlamaDecoderSelfAttentionLayer(max_batch_size, + head_num, 
+ kv_head_num, + size_per_head, + local_head_num, + local_kv_head_num, + rotary_embedding_dim, + neox_rotary_style, + head_num * size_per_head, + 1.0f, + stream, + cublas_wrapper, + allocator, + is_free_buffer_after_forward, + sparse, + int8_mode) +{ +} + +template +LlamaDecoderSelfAttentionLayer::LlamaDecoderSelfAttentionLayer(LlamaDecoderSelfAttentionLayer const& attention_layer): + LlamaDecoderSelfAttentionLayer(attention_layer.max_batch_size_, + attention_layer.head_num_, + attention_layer.kv_head_num_, + attention_layer.size_per_head_, + attention_layer.local_head_num_, + attention_layer.local_kv_head_num_, + attention_layer.rotary_embedding_dim_, + attention_layer.neox_rotary_style_, + attention_layer.d_model_, + attention_layer.q_scaling_, + attention_layer.stream_, + attention_layer.cublas_wrapper_, + attention_layer.allocator_, + attention_layer.is_free_buffer_after_forward_, + attention_layer.sparse_, + attention_layer.int8_mode_) +{ +} + +template +LlamaDecoderSelfAttentionLayer::~LlamaDecoderSelfAttentionLayer() +{ + cublas_wrapper_ = nullptr; + freeBuffer(); +} + +template +void LlamaDecoderSelfAttentionLayer::forward(TensorMap* output_tensors, + TensorMap* input_tensors, + const AttentionWeight* attention_weights) +{ + // input tensors: + // input_query [batch_size, d_model_], + // sequence_lengths [batch_size] + // step [1] on cpu + // finished [batch_size] (optional) + // total_padding_tokens [batch_size] (optional) + // max_input_length [1] on cpu (optional) + // masked_tokens [batch_size, memory_len], (optional) + // cache_indirection [batch_size / beam_width, beam_width, memory_max_len] (optional) + // d_prefix_prompt_lengths [batch_size] (optional) + // max_prefix_prompt_length [1] on cpu (optional) + // relative_attention_bias [1, head_num, step, step] or [1, head_num, max_seq_len, max_seq_len] (optional) + // linear_bias_slopes [head_num] (optional) + // ia3_tasks [batch_size] (optional) + + // output tensors: + // attention_output [batch_size, d_model_], + // key_cache [batch, local_head_num, size_per_head // x, memory_max_len, x] + // value_cache [batch, local_head_num, memory_max_len, size_per_head] + + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + FT_CHECK(output_tensors->at("key_cache").shape.size() == 5 || output_tensors->at("key_cache").shape.size() == 3); + FT_CHECK(output_tensors->at("value_cache").shape.size() == 4 + || output_tensors->at("value_cache").shape.size() == 3); + allocateBuffer(input_tensors->at("input_query").shape[0]); + + const T* attention_input = input_tensors->getPtr("input_query"); + const int* sequence_lengths = input_tensors->getPtr("sequence_lengths"); + const bool* finished = input_tensors->getPtr("finished", nullptr); + const bool* masked_tokens = input_tensors->getPtr("masked_tokens", nullptr); + const int* cache_indir = input_tensors->getPtr("cache_indirection", nullptr); + const T* relative_attention_bias = input_tensors->getPtr("relative_attention_bias", nullptr); + const int relative_attention_bias_stride = + input_tensors->isExist("relative_attention_bias") ? 
input_tensors->at("relative_attention_bias").shape[3] : 0; + const T* linear_bias_slopes = input_tensors->getPtr("linear_bias_slopes", nullptr); + const bool has_ia3 = input_tensors->isExist("ia3_tasks"); + + T* attention_out = output_tensors->getPtr("hidden_features"); + T* key_cache = output_tensors->getPtr("key_cache"); + T* value_cache = output_tensors->getPtr("value_cache"); + + const int batch_size = input_tensors->at("input_query").shape[0]; + const int beam_width = cache_indir != nullptr ? input_tensors->at("cache_indirection").shape[1] : 1; + const int memory_max_len = output_tensors->at("key_cache").shape[3]; + + const int* d_prefix_prompt_lengths = input_tensors->getPtr("d_prefix_prompt_lengths", nullptr); + const int max_prefix_prompt_length = input_tensors->getVal("max_prefix_prompt_length", 0); + + const int m_padded = 8 * div_up(batch_size, 8); +#ifdef SPARSITY_ENABLED + bool use_sparse_gemm = sparse_ && cublas_wrapper_->isUseSparse(1, 3 * local_hidden_units_, m_padded, d_model_); +#else + constexpr bool use_sparse_gemm = false; +#endif + + PUSH_RANGE("qkv_gemm"); + if (use_sparse_gemm) { +#ifdef SPARSITY_ENABLED + cublas_wrapper_->SpGemm(CUBLAS_OP_N, + CUBLAS_OP_N, + 3 * local_hidden_units_, + m_padded, + d_model_, + attention_weights->query_weight.sp_kernel, + attention_input, + qkv_buf_); +#endif + } + else { + if (int8_mode_ == 1) { + FT_CHECK(weight_only_int8_fc_runner_.get() != NULL && attention_weights->query_weight.int8_kernel != NULL + && attention_weights->query_weight.weight_only_quant_scale != NULL); + + weight_only_int8_fc_runner_->gemm( + attention_input, + reinterpret_cast(attention_weights->query_weight.int8_kernel), + attention_weights->query_weight.weight_only_quant_scale, + qkv_buf_, + batch_size, + 3 * local_hidden_units_, + d_model_, + mixed_gemm_workspace_, + mixed_gemm_ws_bytes_, + stream_); + } + else if (int8_mode_ == 2) { + // Here, we set per_column_scaling to be true because q, k, v may + // use different scales. So, we pass a pointer with shape [3, local_hidden_units_] like + // [s_q, s_q, ..., s_q, s_k, s_k, ..., s_k, s_v, s_v, ..., s_v], + // where s_q are scales of q, s_k are scales of k and s_v are scales of v. 
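+            // Note: this int8 branch (like the weight-only int8 branch above) keeps the
+            // original 3 * local_hidden_units_ QKV GEMM; only the default (non-quantized)
+            // branch below projects to the packed (local_head_num_ + 2 * local_kv_head_num_)
+            // * size_per_head_ layout and then expands K/V with invokeRepeatKv.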
+ cublas_wrapper_->Int8Gemm(3 * local_hidden_units_, + batch_size, + d_model_, + attention_weights->query_weight.int8_kernel, + d_model_, + input_tensors->getPtr("input_query"), + d_model_, + reinterpret_cast(qkv_buf_), + 3 * local_hidden_units_, + attention_weights->query_weight.scale_inter, + true); + } + else { + size_t local_qkv_size = local_hidden_units_ + 2 * local_kv_head_num_ * size_per_head_; + cublas_wrapper_->Gemm(CUBLAS_OP_N, + CUBLAS_OP_N, + local_qkv_size, // n + batch_size, + d_model_, // k + attention_weights->query_weight.kernel, + local_qkv_size, // n + attention_input, + d_model_, // k + qkv_buf_tmp_, + local_qkv_size /* n */); + invokeRepeatKv(qkv_buf_, + qkv_buf_tmp_, + local_head_num_, + local_kv_head_num_, + size_per_head_, + batch_size, + stream_); + } + } + sync_check_cuda_error(); + POP_RANGE; + fusedQKV_masked_attention_dispatch( + qkv_buf_, + attention_weights->query_weight.bias, + relative_attention_bias, + key_cache, + value_cache, + cache_indir, + context_buf_, + finished, + sequence_lengths, // NOTE: current seq len including padding (fixed after meeting the finished id) + batch_size, + batch_size, + beam_width, + local_head_num_, + size_per_head_, + rotary_embedding_dim_, + neox_rotary_style_, + memory_max_len, + d_prefix_prompt_lengths, + max_prefix_prompt_length, + input_tensors->getVal("max_input_length", 0), + input_tensors->getPtr("total_padding_tokens", nullptr), + input_tensors->getVal("step"), + q_scaling_, + relative_attention_bias_stride, + linear_bias_slopes, + masked_tokens, + input_tensors->getPtr("ia3_tasks", nullptr), + has_ia3 ? attention_weights->ia3_key_weight.kernel : nullptr, + has_ia3 ? attention_weights->ia3_value_weight.kernel : nullptr, + int8_mode_ == 2 ? attention_weights->query_weight.scale_out : nullptr, + int8_mode_ == 2 ? 
attention_weights->attention_output_weight.scale : nullptr, + int8_mode_, + stream_); + sync_check_cuda_error(); + + PUSH_RANGE("proj gemm"); +#ifdef SPARSITY_ENABLED + use_sparse_gemm = sparse_ && cublas_wrapper_->isUseSparse(1, d_model_, m_padded, local_hidden_units_); +#endif + + if (use_sparse_gemm) { +#ifdef SPARSITY_ENABLED + cublas_wrapper_->SpGemm(CUBLAS_OP_N, + CUBLAS_OP_N, + d_model_, + m_padded, + local_hidden_units_, + attention_weights->attention_output_weight.sp_kernel, + context_buf_, + attention_out); +#endif + } + else { + if (int8_mode_ == 1) { + FT_CHECK(weight_only_int8_fc_runner_.get() != NULL + && attention_weights->attention_output_weight.int8_kernel != NULL + && attention_weights->attention_output_weight.weight_only_quant_scale != NULL); + + weight_only_int8_fc_runner_->gemm( + context_buf_, + reinterpret_cast(attention_weights->attention_output_weight.int8_kernel), + attention_weights->attention_output_weight.weight_only_quant_scale, + attention_out, + batch_size, + d_model_, + local_hidden_units_, + mixed_gemm_workspace_, + mixed_gemm_ws_bytes_, + stream_); + } + else if (int8_mode_ == 2) { + int8_fc_runner_->gemm(reinterpret_cast(context_buf_), + attention_weights->attention_output_weight.int8_kernel, + QuantMode::PerTensorQuant, + attention_weights->attention_output_weight.scale_inter, + attention_weights->attention_output_weight.scale_out, + output_tensors->getPtr("hidden_features"), + batch_size, + d_model_, + local_hidden_units_, + nullptr, + 0, + stream_); + } + else { + cublas_wrapper_->Gemm(CUBLAS_OP_N, + CUBLAS_OP_N, + d_model_, // n + batch_size, + local_hidden_units_, // k + attention_weights->attention_output_weight.kernel, + d_model_, // n + context_buf_, + local_hidden_units_, // k + attention_out, + d_model_ /* n */); + } + sync_check_cuda_error(); + } + POP_RANGE; + + if (is_free_buffer_after_forward_) { + freeBuffer(); + } +} + +template class LlamaDecoderSelfAttentionLayer; +template class LlamaDecoderSelfAttentionLayer; +#ifdef ENABLE_BF16 +template class LlamaDecoderSelfAttentionLayer<__nv_bfloat16>; +#endif + +} // namespace fastertransformer diff --git a/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.h b/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.h new file mode 100644 index 000000000..0a3c6e052 --- /dev/null +++ b/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.h @@ -0,0 +1,202 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "src/fastertransformer/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h" +#include "src/fastertransformer/kernels/cutlass_kernels/int8_gemm/int8_gemm.h" +#include "src/fastertransformer/kernels/matrix_vector_multiplication.h" +#include "src/fastertransformer/layers/attention_layers/BaseAttentionLayer.h" + +namespace fastertransformer { + +template +class LlamaDecoderSelfAttentionLayer: public BaseAttentionLayer { +private: + // buffer handling + size_t max_batch_size_; + + // metadata + const size_t head_num_; + const size_t kv_head_num_; + const size_t size_per_head_; + const size_t hidden_units_; + const size_t local_head_num_; + const size_t local_kv_head_num_; + const size_t local_hidden_units_; + const size_t d_model_; + const float q_scaling_; + const size_t rotary_embedding_dim_; + const bool neox_rotary_style_; + + std::shared_ptr> weight_only_int8_fc_runner_; + std::shared_ptr> int8_fc_runner_; + + void allocateBuffer() override; + void freeBuffer() override; + bool isValidBatchSize(size_t batch_size); + void allocateBuffer(size_t batch_size); + + using BaseAttentionLayer::is_free_buffer_after_forward_; + using BaseAttentionLayer::is_allocate_buffer_; + using BaseAttentionLayer::cublas_wrapper_; + +protected: + T* qkv_buf_ = nullptr; + T* qkv_buf_tmp_ = nullptr; + T* context_buf_ = nullptr; + char* mixed_gemm_workspace_ = nullptr; + size_t mixed_gemm_ws_bytes_ = 0; + char* int8_gemm_workspace_ = nullptr; + size_t int8_gemm_ws_bytes_ = 0; + using BaseAttentionLayer::stream_; + using BaseAttentionLayer::sparse_; + using BaseAttentionLayer::allocator_; + + // int8_mode_ == 0 means we don't use any mechanism related to INT8. + // int8_mode_ == 1 for weight quantized only gemm for GPT + // int8_mode_ == 2 for SmoothQuant O3 (per tensor scales) + const int int8_mode_ = 0; + +public: + LlamaDecoderSelfAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t local_head_num, + size_t local_kv_head_num, + size_t rotary_embedding_dim, + bool neox_rotary_style, + size_t d_model, + const float q_scaling, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse = false, + int int8_mode = 0); + + LlamaDecoderSelfAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t kv_head_num, + size_t size_per_head, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse = false, + int int8_mode = 0); + + LlamaDecoderSelfAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t kv_head_num, + size_t size_per_head, + const float q_scaling, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse = false, + int int8_mode = 0); + + LlamaDecoderSelfAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t local_head_num, + size_t local_kv_head_num, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse = false, + int int8_mode = 0); + + LlamaDecoderSelfAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t local_head_num, + size_t local_kv_head_num, + size_t d_model, + const float q_scaling, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool 
is_free_buffer_after_forward, + bool sparse = false, + int int8_mode = 0); + + LlamaDecoderSelfAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t local_head_num, + size_t local_kv_head_num, + size_t rotary_embedding_dim, + bool neox_rotary_style, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool sparse = false, + int int8_mode = 0); + + LlamaDecoderSelfAttentionLayer(LlamaDecoderSelfAttentionLayer const& attention_layer); + + ~LlamaDecoderSelfAttentionLayer(); + + void + forward(TensorMap* output_tensors, TensorMap* input_tensors, const AttentionWeight* attention_weights) override; +}; + +template +void fusedQKV_masked_attention_dispatch(const T* qkv_buf, + const T* qkv_bias, + const T* relative_attention_bias, + T* key_cache, + T* value_cache, + const int* cache_indir, + T* context_buf, + const bool* finished, + const int* sequence_lengths, + const int max_batch_size, + const int inference_batch_size, + const int beam_width, + const int head_num, + const int kv_head_num, + const int size_per_head, + const int rotary_embedding_dim, + const bool neox_rotary_style, + const int memory_max_len, + const int* prefix_prompt_lengths, + const int max_prefix_prompt_length, + const int max_input_len, + const int* total_padding_tokens, + const int step, + const float q_scaling, + const int relative_attention_bias_stride, + const T* linear_bias_slopes, + const bool* masked_tokens, + const int* ia3_tasks, + const T* ia3_key_weights, + const T* ia3_value_weights, + const float* qkv_scale_out, + const float* attention_out_scale, + const int int8_mode, + cudaStream_t stream); + +} // namespace fastertransformer diff --git a/src/fastertransformer/layers/attention_layers/TensorParallelLlamaContextAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/TensorParallelLlamaContextAttentionLayer.cc new file mode 100644 index 000000000..45252cfae --- /dev/null +++ b/src/fastertransformer/layers/attention_layers/TensorParallelLlamaContextAttentionLayer.cc @@ -0,0 +1,162 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/fastertransformer/layers/attention_layers/TensorParallelLlamaContextAttentionLayer.h" +#include "src/fastertransformer/utils/nvtx_utils.h" + +namespace fastertransformer { + +template +void TensorParallelLlamaContextAttentionLayer::forward(TensorMap* output_tensors, + TensorMap* input_tensors, + const AttentionWeight* attention_weights) +{ + // input_tensors: + // input_query [batch_size * seq_len, hidden_dimension] + // attention_mask [batch_size, 1, seq_len, seq_len] + // is_final_layer [1], bool on cpu + + // output_tensors: + // hidden_features [batch_size * seq_len, hidden_dimension] + // key_cache [batch, local_head_num, size_per_head // x, max_seq_len, x] + // value_cache [batch, local_head_num, max_seq_len, size_per_head] + + const size_t size = output_tensors->at("hidden_features").size(); + + bool use_custom_all_reduce_kernel = false; + if (do_all_reduce_ && enable_custom_all_reduce_ && custom_all_reduce_comm_ != nullptr) { + std::vector reduce_tensor{output_tensors->at("hidden_features")}; + use_custom_all_reduce_kernel = custom_all_reduce_comm_->swapInternalBuffer(&reduce_tensor, size); + } + + LlamaContextAttentionLayer::forward(output_tensors, input_tensors, attention_weights); + + PUSH_RANGE("all reduce sum"); + T* attention_out = output_tensors->getPtr("hidden_features"); + if (do_all_reduce_ && tensor_para_.world_size_ > 1) { + if (!use_custom_all_reduce_kernel) { + ftNcclAllReduceSum(attention_out, attention_out, size, tensor_para_, LlamaContextAttentionLayer::stream_); + } + else { + custom_all_reduce_comm_->customAllReduce(size, LlamaContextAttentionLayer::stream_); + } + sync_check_cuda_error(); + } + POP_RANGE; +} + +template +TensorParallelLlamaContextAttentionLayer::TensorParallelLlamaContextAttentionLayer( + size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t kv_head_num, + size_t size_per_head, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool do_all_reduce, + bool is_free_buffer_after_forward, + bool is_qk_buf_float, + bool sparse, + int int8_mode, + std::shared_ptr custom_all_reduce_comm, + int enable_custom_all_reduce): + LlamaContextAttentionLayer(max_batch_size, + max_seq_len, + head_num, + kv_head_num, + size_per_head, + head_num / tensor_para.world_size_, + kv_head_num / tensor_para.world_size_, + stream, + cublas_wrapper, + allocator, + is_free_buffer_after_forward, + is_qk_buf_float, + sparse, + int8_mode), + tensor_para_(tensor_para), + custom_all_reduce_comm_(custom_all_reduce_comm), + enable_custom_all_reduce_(enable_custom_all_reduce), + do_all_reduce_(do_all_reduce) +{ + FT_CHECK(head_num % tensor_para_.world_size_ == 0); +} + +template +TensorParallelLlamaContextAttentionLayer::TensorParallelLlamaContextAttentionLayer( + size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t rotary_embedding_dim, + bool neox_rotary_style, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool do_all_reduce, + bool is_free_buffer_after_forward, + bool is_qk_buf_float, + bool sparse, + int int8_mode, + std::shared_ptr custom_all_reduce_comm, + int enable_custom_all_reduce): + LlamaContextAttentionLayer(max_batch_size, + max_seq_len, + head_num, + kv_head_num, + size_per_head, + head_num / tensor_para.world_size_, + kv_head_num / tensor_para.world_size_, + rotary_embedding_dim, + neox_rotary_style, + stream, + cublas_wrapper, + allocator, + 
is_free_buffer_after_forward, + is_qk_buf_float, + sparse, + int8_mode), + tensor_para_(tensor_para), + custom_all_reduce_comm_(custom_all_reduce_comm), + enable_custom_all_reduce_(enable_custom_all_reduce), + do_all_reduce_(do_all_reduce) +{ + FT_CHECK(head_num % tensor_para_.world_size_ == 0); +} + +template +TensorParallelLlamaContextAttentionLayer::TensorParallelLlamaContextAttentionLayer( + TensorParallelLlamaContextAttentionLayer const& attention_layer): + LlamaContextAttentionLayer(attention_layer), + tensor_para_(attention_layer.tensor_para_), + custom_all_reduce_comm_(attention_layer.custom_all_reduce_comm_), + enable_custom_all_reduce_(attention_layer.enable_custom_all_reduce_), + do_all_reduce_(attention_layer.do_all_reduce_) +{ +} + +template class TensorParallelLlamaContextAttentionLayer; +template class TensorParallelLlamaContextAttentionLayer; +#ifdef ENABLE_BF16 +template class TensorParallelLlamaContextAttentionLayer<__nv_bfloat16>; +#endif + +} // namespace fastertransformer diff --git a/src/fastertransformer/layers/attention_layers/TensorParallelLlamaContextAttentionLayer.h b/src/fastertransformer/layers/attention_layers/TensorParallelLlamaContextAttentionLayer.h new file mode 100644 index 000000000..27ff18384 --- /dev/null +++ b/src/fastertransformer/layers/attention_layers/TensorParallelLlamaContextAttentionLayer.h @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.h" +#include "src/fastertransformer/utils/custom_ar_comm.h" +#include "src/fastertransformer/utils/nccl_utils.h" + +namespace fastertransformer { + +template +class TensorParallelLlamaContextAttentionLayer: public LlamaContextAttentionLayer { +private: + NcclParam tensor_para_; + std::shared_ptr custom_all_reduce_comm_; + int enable_custom_all_reduce_; + bool do_all_reduce_; + +public: + TensorParallelLlamaContextAttentionLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t kv_head_num, + size_t size_per_head, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool do_all_reduce, + bool is_free_buffer_after_forward, + bool is_qk_buf_float, + bool sparse = false, + int int8_mode = 0, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce = 0); + + TensorParallelLlamaContextAttentionLayer(size_t max_batch_size, + size_t max_seq_len, + size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t rotary_embedding_dim, + bool neox_rotary_style, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool do_all_reduce, + bool is_free_buffer_after_forward, + bool is_qk_buf_float, + bool sparse = false, + int int8_mode = 0, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce = 0); + + TensorParallelLlamaContextAttentionLayer(TensorParallelLlamaContextAttentionLayer const& attention_layer); + + void + forward(TensorMap* output_tensors, TensorMap* input_tensors, const AttentionWeight* attention_weights) override; +}; + +} // namespace fastertransformer diff --git a/src/fastertransformer/layers/attention_layers/TensorParallelLlamaDecoderSelfAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/TensorParallelLlamaDecoderSelfAttentionLayer.cc new file mode 100644 index 000000000..a48d8a18d --- /dev/null +++ b/src/fastertransformer/layers/attention_layers/TensorParallelLlamaDecoderSelfAttentionLayer.cc @@ -0,0 +1,241 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/fastertransformer/layers/attention_layers/TensorParallelLlamaDecoderSelfAttentionLayer.h" +#include "src/fastertransformer/utils/nvtx_utils.h" + +namespace fastertransformer { + +template +TensorParallelLlamaDecoderSelfAttentionLayer::TensorParallelLlamaDecoderSelfAttentionLayer( + size_t max_batch_size, + size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t rotary_embedding_dim, + bool neox_rotary_style, + size_t d_model, + float q_scaling, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool do_all_reduce, + bool is_free_buffer_after_forward, + bool is_sparse, + int int8_mode, + std::shared_ptr custom_all_reduce_comm, + int enable_custom_all_reduce): + LlamaDecoderSelfAttentionLayer(max_batch_size, + head_num, + kv_head_num, + size_per_head, + head_num / tensor_para.world_size_, + kv_head_num / tensor_para.world_size_, + rotary_embedding_dim, + neox_rotary_style, + d_model, + q_scaling, // NOTE + stream, + cublas_wrapper, + allocator, + is_free_buffer_after_forward, + is_sparse, + int8_mode), + do_all_reduce_(do_all_reduce), + tensor_para_(tensor_para), + custom_all_reduce_comm_(custom_all_reduce_comm), + enable_custom_all_reduce_(enable_custom_all_reduce) +{ + FT_CHECK(head_num % tensor_para_.world_size_ == 0); +} + +template +TensorParallelLlamaDecoderSelfAttentionLayer::TensorParallelLlamaDecoderSelfAttentionLayer( + size_t max_batch_size, + size_t head_num, + size_t kv_head_num, + size_t size_per_head, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool do_all_reduce, + bool is_free_buffer_after_forward, + bool is_sparse, + int int8_mode, + std::shared_ptr custom_all_reduce_comm, + int enable_custom_all_reduce): + TensorParallelLlamaDecoderSelfAttentionLayer(max_batch_size, + head_num, + kv_head_num, + size_per_head, + 0, + false, + head_num * size_per_head, + 1.0f, + tensor_para, + stream, + cublas_wrapper, + allocator, + do_all_reduce, + is_free_buffer_after_forward, + is_sparse, + int8_mode, + custom_all_reduce_comm, + enable_custom_all_reduce) + +{ +} + +template +TensorParallelLlamaDecoderSelfAttentionLayer::TensorParallelLlamaDecoderSelfAttentionLayer( + size_t max_batch_size, + size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t d_model, + float q_scaling, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool do_all_reduce, + bool is_free_buffer_after_forward, + bool is_sparse, + int int8_mode, + std::shared_ptr custom_all_reduce_comm, + int enable_custom_all_reduce): + TensorParallelLlamaDecoderSelfAttentionLayer(max_batch_size, + head_num, + kv_head_num, + size_per_head, + 0, + false, + d_model, + q_scaling, + tensor_para, + stream, + cublas_wrapper, + allocator, + do_all_reduce, + is_free_buffer_after_forward, + is_sparse, + int8_mode, + custom_all_reduce_comm, + enable_custom_all_reduce) +{ +} + +template +TensorParallelLlamaDecoderSelfAttentionLayer::TensorParallelLlamaDecoderSelfAttentionLayer( + size_t max_batch_size, + size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t rotary_embedding_dim, + bool neox_rotary_style, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool do_all_reduce, + bool is_free_buffer_after_forward, + bool is_sparse, + int int8_mode, + std::shared_ptr custom_all_reduce_comm, + int enable_custom_all_reduce): + 
TensorParallelLlamaDecoderSelfAttentionLayer(max_batch_size, + head_num, + kv_head_num, + size_per_head, + rotary_embedding_dim, + neox_rotary_style, + head_num * size_per_head, + 1.0f, + tensor_para, + stream, + cublas_wrapper, + allocator, + do_all_reduce, + is_free_buffer_after_forward, + is_sparse, + int8_mode, + custom_all_reduce_comm, + enable_custom_all_reduce) +{ +} + +template +TensorParallelLlamaDecoderSelfAttentionLayer::TensorParallelLlamaDecoderSelfAttentionLayer( + TensorParallelLlamaDecoderSelfAttentionLayer const& attention_layer): + LlamaDecoderSelfAttentionLayer(attention_layer), + do_all_reduce_(attention_layer.do_all_reduce_), + tensor_para_(attention_layer.tensor_para_), + custom_all_reduce_comm_(attention_layer.custom_all_reduce_comm_), + enable_custom_all_reduce_(attention_layer.enable_custom_all_reduce_) +{ +} + +template +void TensorParallelLlamaDecoderSelfAttentionLayer::forward(TensorMap* output_tensors, + TensorMap* input_tensors, + const AttentionWeight* attention_weights) +{ + // input tensors: + // attention_input [batch_size, hidden_dimension], + // finished [batch_size], + // sequence_lengths [batch_size] + // input_lengths [batch_size] + // max_input_length [1] on cpu + // step [1] on cpu + + // output tensors: + // attention_output [batch_size, hidden_dimension], + // key_cache [batch, head_num, size_per_head // x, max_seq_len, x] + // value_cache [batch, head_num, max_seq_len, size_per_head] + + const size_t size = output_tensors->at("hidden_features").size(); + + bool use_custom_all_reduce_kernel = false; + if (enable_custom_all_reduce_ && custom_all_reduce_comm_ != nullptr && do_all_reduce_) { + std::vector reduce_tensor{output_tensors->at("hidden_features")}; + use_custom_all_reduce_kernel = custom_all_reduce_comm_->swapInternalBuffer(&reduce_tensor, size); + } + + LlamaDecoderSelfAttentionLayer::forward(output_tensors, input_tensors, attention_weights); + + PUSH_RANGE("all reduce sum"); + T* attention_out = output_tensors->getPtr("hidden_features"); + if (tensor_para_.world_size_ > 1 && do_all_reduce_) { + if (!use_custom_all_reduce_kernel) { + ftNcclAllReduceSum(attention_out, attention_out, size, tensor_para_, LlamaDecoderSelfAttentionLayer::stream_); + } + else { + custom_all_reduce_comm_->customAllReduce(size, LlamaDecoderSelfAttentionLayer::stream_); + } + sync_check_cuda_error(); + } + POP_RANGE; +} + +template class TensorParallelLlamaDecoderSelfAttentionLayer; +template class TensorParallelLlamaDecoderSelfAttentionLayer; +#ifdef ENABLE_BF16 +template class TensorParallelLlamaDecoderSelfAttentionLayer<__nv_bfloat16>; +#endif + +} // namespace fastertransformer diff --git a/src/fastertransformer/layers/attention_layers/TensorParallelLlamaDecoderSelfAttentionLayer.h b/src/fastertransformer/layers/attention_layers/TensorParallelLlamaDecoderSelfAttentionLayer.h new file mode 100644 index 000000000..3df664b22 --- /dev/null +++ b/src/fastertransformer/layers/attention_layers/TensorParallelLlamaDecoderSelfAttentionLayer.h @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.h" +#include "src/fastertransformer/utils/custom_ar_comm.h" +#include "src/fastertransformer/utils/nccl_utils.h" + +namespace fastertransformer { + +template +class TensorParallelLlamaDecoderSelfAttentionLayer: public LlamaDecoderSelfAttentionLayer { +private: + NcclParam tensor_para_; + std::shared_ptr custom_all_reduce_comm_; + int enable_custom_all_reduce_; + bool do_all_reduce_; + +protected: +public: + TensorParallelLlamaDecoderSelfAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t rotary_embedding_dim, + bool neox_rotary_style, + size_t d_model, + float q_scaling, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool do_all_reduce, + bool is_free_buffer_after_forward, + bool is_sparse = false, + int int8_mode = 0, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce = 0); + + TensorParallelLlamaDecoderSelfAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t kv_head_num, + size_t size_per_head, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool do_all_reduce, + bool is_free_buffer_after_forward, + bool is_sparse = false, + int int8_mode = 0, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce = 0); + + TensorParallelLlamaDecoderSelfAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t d_model, + float q_scaling, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool do_all_reduce, + bool is_free_buffer_after_forward, + bool sparse = false, + int int8_mode = 0, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce = 0); + + TensorParallelLlamaDecoderSelfAttentionLayer(size_t max_batch_size, + size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t rotary_embedding_dim, + bool neox_rotary_style, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool do_all_reduce, + bool is_free_buffer_after_forward, + bool sparse = false, + int int8_mode = 0, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce = 0); + + TensorParallelLlamaDecoderSelfAttentionLayer(TensorParallelLlamaDecoderSelfAttentionLayer const& attention_layer); + + void + forward(TensorMap* output_tensors, TensorMap* input_tensors, const AttentionWeight* attention_weights) override; +}; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/CMakeLists.txt b/src/fastertransformer/models/CMakeLists.txt index 248b4af3d..c19f79d70 100644 --- a/src/fastertransformer/models/CMakeLists.txt +++ b/src/fastertransformer/models/CMakeLists.txt @@ -37,3 +37,4 @@ add_subdirectory(vit) add_subdirectory(vit_int8) add_subdirectory(wenet) +add_subdirectory(llama) diff --git 
a/src/fastertransformer/models/llama/CMakeLists.txt b/src/fastertransformer/models/llama/CMakeLists.txt new file mode 100644 index 000000000..50834131a --- /dev/null +++ b/src/fastertransformer/models/llama/CMakeLists.txt @@ -0,0 +1,69 @@ +# Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cmake_minimum_required(VERSION 3.8) + +add_library(LlamaDecoderLayerWeight STATIC LlamaDecoderLayerWeight.cc) +set_property(TARGET LlamaDecoderLayerWeight PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET LlamaDecoderLayerWeight PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +target_link_libraries(LlamaDecoderLayerWeight PUBLIC memory_utils cuda_utils logger) + +add_library(LlamaDecoder STATIC LlamaDecoder.cc) +set_property(TARGET LlamaDecoder PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET LlamaDecoder PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +target_link_libraries(LlamaDecoder PUBLIC -lcudart cublasMMWrapper + TensorParallelLlamaDecoderSelfAttentionLayer + TensorParallelSiluFfnLayer + layernorm_kernels + add_residual_kernels + LlamaDecoderLayerWeight + tensor + nccl_utils + cuda_utils + logger) + +add_library(LlamaContextDecoder STATIC LlamaContextDecoder.cc) +set_property(TARGET LlamaContextDecoder PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET LlamaContextDecoder PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +target_link_libraries(LlamaContextDecoder PUBLIC -lcudart cublasMMWrapper + TensorParallelLlamaContextAttentionLayer + TensorParallelSiluFfnLayer + layernorm_kernels + add_residual_kernels + gpt_kernels + tensor + nccl_utils + cuda_utils + logger) + +add_library(LlamaWeight STATIC LlamaWeight.cc) +set_property(TARGET LlamaWeight PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET LlamaWeight PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +target_link_libraries(LlamaWeight PUBLIC LlamaDecoderLayerWeight cuda_utils logger) + +add_library(Llama STATIC Llama.cc) +set_property(TARGET Llama PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET Llama PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +target_link_libraries(Llama PUBLIC -lcudart + LlamaDecoder + LlamaContextDecoder + decoding_kernels + gpt_kernels + DynamicDecodeLayer + BaseBeamSearchLayer + bert_preprocess_kernels + tensor + LlamaWeight + cuda_utils + logger) diff --git a/src/fastertransformer/models/llama/Llama.cc b/src/fastertransformer/models/llama/Llama.cc new file mode 100644 index 000000000..9fcdd9169 --- /dev/null +++ b/src/fastertransformer/models/llama/Llama.cc @@ -0,0 +1,1287 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/fastertransformer/models/llama/Llama.h" +#include "src/fastertransformer/kernels/bert_preprocess_kernels.h" +#include "src/fastertransformer/kernels/decoding_kernels.h" +#include "src/fastertransformer/kernels/gpt_kernels.h" +#include "src/fastertransformer/layers/beam_search_layers/BaseBeamSearchLayer.h" +#include + +namespace fastertransformer { + +template +void Llama::initialize() +{ + gpt_context_decoder_ = new LlamaContextDecoder(head_num_, + kv_head_num_, + size_per_head_, + inter_size_, + num_layer_, + rotary_embedding_dim_, + neox_rotary_style_, + use_gptj_residual_, + layernorm_eps_, + tensor_para_, + pipeline_para_, + stream_, + cublas_wrapper_, + allocator_, + is_free_buffer_after_forward_, + is_context_qk_buf_float_, + attention_type_, + int8_mode_, + custom_all_reduce_comm_, + enable_custom_all_reduce_); + + gpt_decoder_ = new LlamaDecoder(head_num_, + kv_head_num_, + size_per_head_, + inter_size_, + num_layer_, + rotary_embedding_dim_, + neox_rotary_style_, + use_gptj_residual_, + layernorm_eps_, + tensor_para_, + pipeline_para_, + stream_, + cublas_wrapper_, + allocator_, + is_free_buffer_after_forward_, + int8_mode_, + custom_all_reduce_comm_, + enable_custom_all_reduce_); + + dynamic_decode_layer_ = new DynamicDecodeLayer(vocab_size_, + vocab_size_padded_, + 0, // end_id, deprecated + stream_, + cublas_wrapper_, + allocator_, + is_free_buffer_after_forward_, + cuda_device_prop_); + + // parse env overrides + if (std::getenv("LLAMA_STREAM_CB_STEP") != nullptr) { + try { + int callback_step_from_env = stoi( + std::string(std::getenv("LLAMA_STREAM_CB_STEP")) + ); + token_generated_cb_step_ = callback_step_from_env; + FT_LOG_INFO("Override stream callback step to %d from LLAMA_STREAM_CB_STEP", + token_generated_cb_step_); + } catch (...) 
{ + FT_LOG_WARNING("convert LLAMA_STREAM_CB_STEP err, use default value %d", + token_generated_cb_step_); + } + } +} + +template +void Llama::allocateBuffer() +{ + FT_CHECK(false); +} + +template +void Llama::allocateBuffer( + size_t batch_size, size_t beam_width, size_t max_seq_len, size_t max_cache_seq_len, size_t max_input_len) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + const size_t batchxbeam = batch_size * beam_width; + const size_t self_cache_size = (num_layer_ / pipeline_para_.world_size_) * batchxbeam * max_cache_seq_len + * hidden_units_ / tensor_para_.world_size_; + + if (vocab_size_ != vocab_size_padded_) { + padded_embedding_kernel_ = + (T*)(allocator_->reMalloc(padded_embedding_kernel_, sizeof(T) * hidden_units_ * vocab_size_padded_, true)); + padded_embedding_kernel_ptr_ = padded_embedding_kernel_; + } + + input_attention_mask_ = (T*)(allocator_->reMalloc( + input_attention_mask_, sizeof(T) * batchxbeam * max_seq_len * max_cache_seq_len, false)); + decoder_input_buf_ = (T*)(allocator_->reMalloc(decoder_input_buf_, sizeof(T) * batchxbeam * hidden_units_, false)); + decoder_output_buf_ = + (T*)(allocator_->reMalloc(decoder_output_buf_, sizeof(T) * batchxbeam * hidden_units_, false)); + normed_decoder_output_buf_ = + (T*)(allocator_->reMalloc(normed_decoder_output_buf_, sizeof(T) * batchxbeam * hidden_units_, false)); + logits_buf_ = (float*)(allocator_->reMalloc(logits_buf_, sizeof(float) * batchxbeam * vocab_size_padded_, false)); + nccl_logits_buf_ = + (float*)(allocator_->reMalloc(nccl_logits_buf_, sizeof(float) * batchxbeam * vocab_size_padded_, false)); + cum_log_probs_ = (float*)(allocator_->reMalloc(cum_log_probs_, sizeof(float) * batchxbeam, false)); + finished_buf_ = (bool*)(allocator_->reMalloc(finished_buf_, sizeof(bool) * batchxbeam, false)); + h_finished_buf_ = new bool[batchxbeam]; + sequence_lengths_ = (int*)(allocator_->reMalloc(sequence_lengths_, sizeof(int) * batchxbeam, false)); + + key_cache_ = (T*)(allocator_->reMalloc(key_cache_, sizeof(T) * self_cache_size * 2, true)); + value_cache_ = key_cache_ + self_cache_size; + if (beam_width > 1) { + cache_indirections_[0] = + (int*)(allocator_->reMalloc(cache_indirections_[0], sizeof(int) * batchxbeam * max_seq_len * 2, true)); + cache_indirections_[1] = cache_indirections_[0] + batchxbeam * max_seq_len; + } + + // prompt_learning weight batch ptrs + prompt_learning_weight_batch_ = + (const T**)(allocator_->reMalloc(prompt_learning_weight_batch_, sizeof(T*) * batchxbeam, false)); + tiled_prompt_lengths_buf_ = + (int*)(allocator_->reMalloc(tiled_prompt_lengths_buf_, sizeof(int) * batchxbeam, false)); + + tiled_input_ids_buf_ = + (int*)(allocator_->reMalloc(tiled_input_ids_buf_, sizeof(int) * batchxbeam * max_input_len, true)); + tiled_input_lengths_buf_ = (int*)(allocator_->reMalloc(tiled_input_lengths_buf_, sizeof(int) * batchxbeam, true)); + tiled_total_padding_count_ = + (int*)allocator_->reMalloc(tiled_total_padding_count_, batchxbeam * sizeof(int), false); + + transposed_output_ids_buf_ = + (int*)(allocator_->reMalloc(transposed_output_ids_buf_, sizeof(int) * batchxbeam * max_seq_len, true)); + output_ids_buf_ = (int*)(allocator_->reMalloc(output_ids_buf_, sizeof(int) * batchxbeam * max_seq_len, true)); + parent_ids_buf_ = (int*)(allocator_->reMalloc(parent_ids_buf_, sizeof(int) * batchxbeam * max_seq_len, true)); + seq_limit_len_ = (uint32_t*)(allocator_->reMalloc(seq_limit_len_, sizeof(uint32_t) * batch_size, false)); + masked_tokens_ = (bool*)(allocator_->reMalloc(masked_tokens_, sizeof(bool) * batchxbeam * 
max_cache_seq_len, true)); + + start_ids_buf_ = (int*)(allocator_->reMalloc(start_ids_buf_, sizeof(int) * batch_size, false)); + end_ids_buf_ = (int*)(allocator_->reMalloc(end_ids_buf_, sizeof(int) * batch_size, false)); + + context_decoder_input_buf_ = (T*)(allocator_->reMalloc( + context_decoder_input_buf_, sizeof(T) * batchxbeam * max_input_len * hidden_units_, false)); + context_decoder_output_buf_ = (T*)(allocator_->reMalloc( + context_decoder_output_buf_, sizeof(T) * batchxbeam * max_input_len * hidden_units_, false)); + output_log_probs_buf_ = + (float*)(allocator_->reMalloc(output_log_probs_buf_, sizeof(float) * batchxbeam * max_seq_len, false)); + + generation_should_stop_ = (bool*)allocator_->reMalloc(generation_should_stop_, sizeof(bool), true, true); + + if (shared_contexts_ratio_ > 0.0f) { + shared_contexts_idx_ = (int*)allocator_->reMalloc(shared_contexts_idx_, batch_size * sizeof(int), false); + batch_to_compact_idx_ = (int*)allocator_->reMalloc(batch_to_compact_idx_, batchxbeam * sizeof(int), false); + compact_idx_ = (int*)allocator_->reMalloc(compact_idx_, batch_size * sizeof(int), false); + compact_size_ = (int*)allocator_->reMalloc(compact_size_, sizeof(int), false); + } + + is_allocate_buffer_ = true; +} + +template +void Llama::freeBuffer() +{ + if (is_allocate_buffer_) { + if (vocab_size_ != vocab_size_padded_) { + padded_embedding_kernel_ptr_ = nullptr; + allocator_->free((void**)(&padded_embedding_kernel_)); + } + + allocator_->free((void**)(&input_attention_mask_)); + allocator_->free((void**)(&decoder_input_buf_)); + allocator_->free((void**)(&decoder_output_buf_)); + allocator_->free((void**)(&normed_decoder_output_buf_)); + allocator_->free((void**)(&logits_buf_)); + allocator_->free((void**)(&nccl_logits_buf_)); + allocator_->free((void**)(&cum_log_probs_)); + allocator_->free((void**)(&finished_buf_)); + delete[] h_finished_buf_; + allocator_->free((void**)(&sequence_lengths_)); + + allocator_->free((void**)(&key_cache_)); + if (cache_indirections_[0] != nullptr) { + allocator_->free((void**)(&cache_indirections_)[0]); + } + + allocator_->free((void**)(&prompt_learning_weight_batch_)); + allocator_->free((void**)(&tiled_prompt_lengths_buf_)); + + allocator_->free((void**)(&tiled_input_ids_buf_)); + allocator_->free((void**)(&tiled_input_lengths_buf_)); + allocator_->free((void**)(&tiled_total_padding_count_)); + + allocator_->free((void**)(&transposed_output_ids_buf_)); + allocator_->free((void**)(&output_ids_buf_)); + allocator_->free((void**)(&parent_ids_buf_)); + allocator_->free((void**)(&seq_limit_len_)); + allocator_->free((void**)(&masked_tokens_)); + + allocator_->free((void**)(&start_ids_buf_)); + allocator_->free((void**)(&end_ids_buf_)); + + allocator_->free((void**)(&context_decoder_input_buf_)); + allocator_->free((void**)(&context_decoder_output_buf_)); + allocator_->free((void**)(&output_log_probs_buf_)); + + allocator_->free((void**)(&generation_should_stop_), true); + + if (shared_contexts_ratio_ > 0.0f) { + allocator_->free((void**)(&shared_contexts_idx_)); + allocator_->free((void**)(&compact_size_)); + } + + is_allocate_buffer_ = false; + } +} + +template +Llama::Llama(size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t vocab_size, + size_t rotary_embedding_dim, + float layernorm_eps, + int start_id, + int end_id, + int prompt_learning_start_id, // only needed by p/prompt-tuning + PromptLearningType prompt_learning_type, + bool use_gptj_residual, + float 
beam_search_diversity_rate, + size_t top_k, + float top_p, + unsigned long long random_seed, + float temperature, + float len_penalty, + float repetition_penalty, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + cudaDeviceProp* cuda_device_prop, + AttentionType attention_type, + int int8_mode, + std::shared_ptr custom_all_reduce_comm, + int enable_custom_all_reduce, + float shared_contexts_ratio): + BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, cuda_device_prop), + head_num_(head_num), + kv_head_num_(kv_head_num), + size_per_head_(size_per_head), + inter_size_(inter_size), + num_layer_(num_layer), + vocab_size_(vocab_size), + rotary_embedding_dim_(rotary_embedding_dim), + layernorm_eps_(layernorm_eps), + start_id_(start_id), + end_id_(end_id), + prompt_learning_start_id_(prompt_learning_start_id), + prompt_learning_type_(prompt_learning_type), + use_gptj_residual_(use_gptj_residual), + hidden_units_(head_num * size_per_head), + local_head_num_(head_num / 1), + local_kv_head_num_(kv_head_num / 1), + attention_type_(attention_type), + int8_mode_(int8_mode), + shared_contexts_ratio_(shared_contexts_ratio) +{ + tensor_para_.world_size_ = 1; + tensor_para_.rank_ = 0; + pipeline_para_.world_size_ = 1; + pipeline_para_.rank_ = 0; + + int local_vacab_size = ceil(vocab_size_ / 1.f / tensor_para_.world_size_); + if (std::is_same::value +#ifdef ENABLE_BF16 + || std::is_same<__nv_bfloat16, T>::value +#endif + ) { + local_vacab_size = ceil(local_vacab_size / 8.f) * 8; + } + vocab_size_padded_ = (size_t)local_vacab_size * tensor_para_.world_size_; + initialize(); +} + +template +Llama::Llama(size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t vocab_size, + size_t rotary_embedding_dim, + float layernorm_eps, + int start_id, + int end_id, + int prompt_learning_start_id, // only needed by p/prompt-tuning + PromptLearningType prompt_learning_type, + bool use_gptj_residual, + float beam_search_diversity_rate, + size_t top_k, + float top_p, + unsigned long long random_seed, + float temperature, + float len_penalty, + float repetition_penalty, + NcclParam tensor_para, + NcclParam pipeline_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + cudaDeviceProp* cuda_device_prop, + AttentionType attention_type, + int int8_mode, + std::shared_ptr custom_all_reduce_comm, + int enable_custom_all_reduce, + float shared_contexts_ratio): + BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, cuda_device_prop), + head_num_(head_num), + kv_head_num_(kv_head_num), + size_per_head_(size_per_head), + inter_size_(inter_size), + num_layer_(num_layer), + vocab_size_(vocab_size), + rotary_embedding_dim_(rotary_embedding_dim), + layernorm_eps_(layernorm_eps), + start_id_(start_id), + end_id_(end_id), + prompt_learning_start_id_(prompt_learning_start_id), + prompt_learning_type_(prompt_learning_type), + use_gptj_residual_(use_gptj_residual), + hidden_units_(head_num * size_per_head), + tensor_para_(tensor_para), + pipeline_para_(pipeline_para), + local_head_num_(head_num / tensor_para.world_size_), + local_kv_head_num_(kv_head_num / tensor_para.world_size_), + custom_all_reduce_comm_(custom_all_reduce_comm), + enable_custom_all_reduce_(enable_custom_all_reduce), + attention_type_(attention_type), + int8_mode_(int8_mode), + shared_contexts_ratio_(shared_contexts_ratio) +{ + int 
local_vacab_size = ceil(vocab_size_ / 1.f / tensor_para_.world_size_); + if (std::is_same::value) { + local_vacab_size = ceil(local_vacab_size / 8.f) * 8; + } + vocab_size_padded_ = (size_t)local_vacab_size * tensor_para_.world_size_; + initialize(); +} + +template +Llama::Llama(Llama const& gpt): + BaseLayer(gpt), + head_num_(gpt.head_num_), + kv_head_num_(gpt.kv_head_num_), + size_per_head_(gpt.size_per_head_), + inter_size_(gpt.inter_size_), + num_layer_(gpt.num_layer_), + vocab_size_(gpt.vocab_size_), + rotary_embedding_dim_(gpt.rotary_embedding_dim_), + layernorm_eps_(gpt.layernorm_eps_), + start_id_(gpt.start_id_), + end_id_(gpt.end_id_), + prompt_learning_start_id_(gpt.prompt_learning_start_id_), + prompt_learning_type_(gpt.prompt_learning_type_), + use_gptj_residual_(gpt.use_gptj_residual_), + hidden_units_(gpt.hidden_units_), + tensor_para_(gpt.tensor_para_), + pipeline_para_(gpt.pipeline_para_), + local_head_num_(gpt.local_head_num_), + local_kv_head_num_(gpt.local_kv_head_num_), + vocab_size_padded_(gpt.vocab_size_padded_), + custom_all_reduce_comm_(gpt.custom_all_reduce_comm_), + enable_custom_all_reduce_(gpt.enable_custom_all_reduce_), + attention_type_(gpt.attention_type_), + int8_mode_(gpt.int8_mode_), + shared_contexts_ratio_(gpt.shared_contexts_ratio_) +{ + initialize(); +} + +template +Llama::~Llama() +{ + delete gpt_decoder_; + delete dynamic_decode_layer_; + delete gpt_context_decoder_; + freeBuffer(); +} + +template +void Llama::registerCallback(callback_sig* fn, void* ctx) +{ + token_generated_cb_ = fn; + token_generated_ctx_ = ctx; +} + +template +void Llama::unRegisterCallback() +{ + token_generated_cb_ = nullptr; + token_generated_ctx_ = nullptr; +} + +template +void Llama::forward(std::vector* output_tensors, + const std::vector* input_tensors, + const LlamaWeight* gpt_weights) +{ + FT_CHECK(false); +} + +template +void Llama::forward(std::unordered_map* output_tensors, + const std::unordered_map* input_tensors, + const LlamaWeight* gpt_weights) +{ + // input_tensors: + // input_ids [batch_size, max_input_length] + // input_lengths [batch_size] + // prompt_learning_task_name_ids [batch_size] on cpu, optional + // output_seq_len [batch_size] on cpu + // start_id [batch_size] on cpu, optional + // end_id [batch_size] on cpu, optional + // stop_words_list [batch_size, 2, stop_words_length], optional + // bad_words_list [2, bad_words_length] or [batch_size, 2, bad_words_length], optional + // runtime_top_k [1] or [batch_size] on cpu, optional, uint. + // runtime_top_p [1] or [batch_size] on cpu, optional, float. + // beam_search_diversity_rate [1] or [batch_size] on cpu, optional, float. + // temperature [1] or [batch_size] on cpu, optional, float. + // len_penalty [1] or [batch_size] on cpu, optional, float. + // repetition_penalty [1] or [batch_size] on cpu, optional, float. + // min_length [1] or [batch_size] on cpu, optional, int + // random_seed [1] or [batch_size] on cpu, optional, unsigned long long int. 
+ // request_prompt_lengths [batch_size], optional + // request_prompt_embedding [batch_size, max_prompt_length, hidden_units], float, optional + // requst_prompt_type [batch_size], int, optional + // top_p_decay [batch_size] on gpu, float, optional + // top_p_min [batch_size] on gpu, float, optional + // top_p_reset_ids [batch_size] on gpu, uint32, optional + + // output_tensors: + // output_ids [batch_size, beam_width, max_output_seq_len] + // sequence_length [batch_size, beam_width] + // output_log_probs [batch_size, beam_width, request_output_seq_len], must be float*. + // optional. It leads to additional computing cost. If we don't need this result, don't put it. + // cum_log_probs [batch_size, beam], optional, must be float*. + // optional. It leads to additional computing cost. If we don't need this result, don't put it. + + // Step is from max_input_length ~ max_output_seq_len, + // When step = k, we put output ids and caches at step k, and the sequence_length would be k - 1 before + // complete this step. + // When there is no input_ids, put the start token at step 0 of output_ids_buf_. After forward, only copy + // the step 1 ~ max_output_seq_len of output_ids_buf_ to output_tensors->at(0).data + + FT_CHECK_WITH_INFO(input_tensors->size() >= 3, "input_tensors->size() >= 3"); + FT_CHECK_WITH_INFO(output_tensors->size() >= 2, "output_tensors->size() >= 2"); + FT_CHECK(input_tensors->at("input_ids").shape.size() == 2); + FT_CHECK(input_tensors->at("input_lengths").shape.size() == 1); + FT_CHECK(input_tensors->find("output_seq_len") != input_tensors->end() + && input_tensors->at("output_seq_len").shape.size() == 1); + FT_CHECK(output_tensors->at("output_ids").shape.size() == 3); + FT_CHECK(output_tensors->at("sequence_length").shape.size() == 2); + FT_CHECK_WITH_INFO(input_tensors->at("input_ids").shape[0] == output_tensors->at("output_ids").shape[0], + "input_tensors->at(\"input_ids\").shape[0] == output_tensors->at(\"output_ids\").shape[0]"); + + const size_t batch_size = output_tensors->at("output_ids").shape[0]; + const size_t beam_width = output_tensors->at("output_ids").shape[1]; + + PromptLearningType request_prompt_type = PromptLearningType::no_prompt; + int valid_prompt_inputs = input_tensors->count("request_prompt_type") + + input_tensors->count("request_prompt_lengths") + + input_tensors->count("request_prompt_embedding"); + + if (valid_prompt_inputs == 3) { + request_prompt_type = static_cast(input_tensors->at("request_prompt_type").getVal()); + FT_LOG_INFO("Apply prompt embedding from input, will ignore task name ids"); + } + else if (valid_prompt_inputs > 0) { + FT_LOG_WARNING( + "Prompts not applied: request_prompt_embedding, request_prompt_lengths, request_prompt_type are all needed!"); + } + if (request_prompt_type == PromptLearningType::prefix_prompt) { + FT_LOG_WARNING("Request prompt doesn't support prefix prompt currently!"); + } + + // Prefix Prompt Inputs + // Padding works as follows: p p x x i i i x x --> p p i i i x x x x (p denotes prompt, i denotes input, x denotes + // pad) + // TODO (perkzz): move unnecessary paddings + const int* prompt_learning_task_name_ids = + input_tensors->count("prompt_learning_task_name_ids") ? 
+ input_tensors->at("prompt_learning_task_name_ids").getPtr() : + nullptr; + has_prefix_prompt_ = + (prompt_learning_task_name_ids != nullptr) && (prompt_learning_type_ == PromptLearningType::prefix_prompt); + int max_prefix_prompt_length = 0; + + FT_CHECK_WITH_INFO( + !(prompt_learning_task_name_ids != nullptr + && (prompt_learning_type_ == PromptLearningType::no_prompt + || prompt_learning_type_ == PromptLearningType::soft_prompt)), + "prompt_learning_type is prefix_prompt either p_prompt_tuning when prompt_learning_task_name_ids are provided."); + + // NOTE: Prefix Prompt PreProcessing + // get prefix_prompt_weight for each batch --> shape [batch, beam_width] + // --> ptrs with shape [num_layers, 2, num_heads, perfix_seq_len, size_per_head] + std::vector prefix_prompt_weight_batch_ptrs; + std::vector prefix_prompt_lengths; + if (has_prefix_prompt_) { + for (int bs_id = 0; bs_id < batch_size; ++bs_id) { + int task_id = prompt_learning_task_name_ids[bs_id]; + // throw errors when prompt task_name_ids are not found + std::pair prefix_prompt_weight_length_pair; + try { + prefix_prompt_weight_length_pair = gpt_weights->prompt_learning_table.at(task_id); + } + catch (const std::out_of_range& oor) { + FT_LOG_ERROR("prefix_prompt_weights_lengths not found for prompt task id: " + task_id); + throw oor; + } + for (int bw_id = 0; bw_id < beam_width; ++bw_id) { + prefix_prompt_weight_batch_ptrs.push_back(prefix_prompt_weight_length_pair.first); + prefix_prompt_lengths.push_back(prefix_prompt_weight_length_pair.second); + } + } + + max_prefix_prompt_length = *max_element(prefix_prompt_lengths.begin(), prefix_prompt_lengths.end()); + + FT_LOG_DEBUG("max_prefix_prompt_length: %d", max_prefix_prompt_length); + + if (max_prefix_prompt_length == 0) { + has_prefix_prompt_ = false; + FT_LOG_DEBUG("prompts are not applied !"); + } + } + + int max_input_length = input_tensors->at("input_ids").shape[1]; + FT_CHECK_WITH_INFO(!(max_input_length == 0 && max_prefix_prompt_length > 0), + "Prefix Prompt should come with inputs!"); + + // Prefix Soft Prompt + has_prefix_soft_prompt_ = request_prompt_type == PromptLearningType::soft_prompt; + const size_t max_prefix_soft_prompt_length = + has_prefix_soft_prompt_ ? input_tensors->at("request_prompt_embedding").shape[1] : 0; + const size_t limit_len_offset = max_prefix_soft_prompt_length + (max_input_length == 0 ? 1 : 0); + const size_t max_output_seq_len = input_tensors->at("output_seq_len").max() + limit_len_offset; + const size_t max_seq_len = max_output_seq_len; + // max cache seq len should include max prefix prompt length as it has k/v states + const size_t max_cache_seq_len = max_output_seq_len + max_prefix_prompt_length; + if (max_cache_seq_len < max_seq_len) { + FT_LOG_WARNING("max_cache_seq_len (%d) is less than max_seq_len (%d). " + "Note that this reduces the memory cost of k/v cache, but may hurt the accuracy.", + max_cache_seq_len, + max_seq_len); + } + else if (max_cache_seq_len > max_seq_len) { + FT_LOG_WARNING("max_cache_seq_len (%d) is larger than max_seq_len (%d). " + "This may lead to additional memory cost. 
Suggest to use smaller max_cache_seq_len.", + max_cache_seq_len, + max_seq_len); + } + const cudaDataType_t gemm_data_type = getCudaDataType(); + allocateBuffer( + batch_size, beam_width, max_seq_len, max_cache_seq_len, max_input_length + max_prefix_soft_prompt_length); + setSeqLimitLen(seq_limit_len_, input_tensors->at("output_seq_len"), limit_len_offset, batch_size); + + sync_check_cuda_error(); + { + TensorMap input_map(*input_tensors); + dynamic_decode_layer_->setup(batch_size, beam_width, &input_map); + handleOptArg(&input_map, "start_id", start_ids_buf_, start_id_, batch_size); + handleOptArg(&input_map, "end_id", end_ids_buf_, end_id_, batch_size); + } + + const DataType data_type = getTensorType(); + + const std::vector self_k_cache_shape = {num_layer_ / pipeline_para_.world_size_, + batch_size * beam_width, + local_head_num_, + size_per_head_ / (16 / sizeof(T)), + max_cache_seq_len, + 16 / sizeof(T)}; + const std::vector self_v_cache_shape = {num_layer_ / pipeline_para_.world_size_, + batch_size * beam_width, + local_head_num_, + max_cache_seq_len, + size_per_head_}; + + // initialize the output ids and parent ids + cudaMemsetAsync(output_ids_buf_, 0, sizeof(int) * batch_size * beam_width * max_seq_len, stream_); + cudaMemsetAsync(parent_ids_buf_, 0, sizeof(int) * batch_size * beam_width * max_seq_len, stream_); + cudaMemsetAsync(masked_tokens_, false, sizeof(bool) * batch_size * beam_width * max_cache_seq_len, stream_); + cudaMemsetAsync(tiled_total_padding_count_, 0, sizeof(int) * batch_size * beam_width, stream_); + if (beam_width > 1) { + cudaMemsetAsync(cache_indirections_[0], 0, 2 * sizeof(int) * batch_size * beam_width * max_seq_len, stream_); + } + + int compact_size; + bool use_shared_contexts = (shared_contexts_ratio_ > 0.0f) && (max_input_length >= 1) && (batch_size > 1); + if (use_shared_contexts) { + invokeFindContextDups(shared_contexts_idx_, + batch_to_compact_idx_, + compact_idx_, + compact_size_, + input_tensors->at("input_ids").getPtr(), + batch_size, + beam_width, + max_input_length, + stream_); + cudaD2Hcpy(&compact_size, compact_size_, 1); + use_shared_contexts = compact_size <= shared_contexts_ratio_ * batch_size; + sync_check_cuda_error(); + } + + // Prefix prompts + if (has_prefix_prompt_) { + cudaMemcpyAsync(prompt_learning_weight_batch_, + prefix_prompt_weight_batch_ptrs.data(), + sizeof(T*) * batch_size * beam_width, + cudaMemcpyDefault, + stream_); + cudaMemcpyAsync(tiled_prompt_lengths_buf_, + prefix_prompt_lengths.data(), + sizeof(int) * batch_size * beam_width, + cudaMemcpyDefault, + stream_); + } + + sync_check_cuda_error(); + + // handle first step + if (has_prefix_prompt_ || has_prefix_soft_prompt_ || max_input_length > 1) { + invokeTileGptInputs(tiled_input_ids_buf_, + tiled_input_lengths_buf_, + input_tensors->at("input_ids").getPtr(), + input_tensors->at("input_lengths").getPtr(), + batch_size, + beam_width, + max_input_length, + stream_); + sync_check_cuda_error(); + + if (has_prefix_soft_prompt_) { + inputIdsEmbeddingLookupPosEncodingSoftPromptParam param; + param.from_tensor = context_decoder_input_buf_; + param.output_ids = output_ids_buf_; + param.input_lengths = tiled_input_lengths_buf_; + param.embedding_table = gpt_weights->pre_decoder_embedding_table; + param.pos_table = gpt_weights->position_encoding_table; + param.prefix_soft_prompt_embedding = input_tensors->at("request_prompt_embedding").getPtr(); + param.prefix_soft_prompt_lengths = input_tensors->at("request_prompt_lengths").getPtr(); + param.input_ids = tiled_input_ids_buf_; + 
param.start_step = 1; + param.max_input_length = max_input_length; + param.max_prefix_soft_prompt_length = max_prefix_soft_prompt_length; + param.batch_size = batch_size; + param.beam_width = beam_width; + param.hidden_units = hidden_units_; + param.stream = stream_; + + invokeInputIdsEmbeddingLookupPosEncodingSoftPrompt(param); + sync_check_cuda_error(); + max_input_length += max_prefix_soft_prompt_length; // view soft_prompt as input + } + else { + invokeInputIdsEmbeddingLookupPosEncoding(context_decoder_input_buf_, + output_ids_buf_, + gpt_weights->pre_decoder_embedding_table, + gpt_weights->position_encoding_table, + pPromptTuningParam{}, // no p/prompt tuning + tiled_input_ids_buf_, + 1, + max_input_length, + max_input_length, + batch_size * beam_width, + hidden_units_, + stream_); + sync_check_cuda_error(); + } + + invokeBuildDecoderAttentionMask(input_attention_mask_, + tiled_input_lengths_buf_, + tiled_prompt_lengths_buf_, + batch_size * beam_width, + max_input_length, + max_prefix_prompt_length, + stream_); + sync_check_cuda_error(); + + std::unordered_map decoder_input_tensors{ + {"decoder_input", + Tensor{MEMORY_GPU, + data_type, + {batch_size * beam_width, (size_t)max_input_length, hidden_units_}, + context_decoder_input_buf_}}, + {"attention_mask", + Tensor{MEMORY_GPU, + data_type, + {batch_size * beam_width, + 1, + (size_t)max_input_length, + (size_t)(max_input_length + max_prefix_prompt_length)}, + input_attention_mask_}}, + {"input_lengths", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size * beam_width}, tiled_input_lengths_buf_}}, + {"d_prefix_prompt_batch", + Tensor{MEMORY_GPU, + data_type, + {batch_size * beam_width}, + has_prefix_prompt_ ? prompt_learning_weight_batch_ : nullptr}}, + {"d_prefix_prompt_lengths", + Tensor{MEMORY_GPU, + TYPE_INT32, + {batch_size * beam_width}, + has_prefix_prompt_ ? 
tiled_prompt_lengths_buf_ : nullptr}}}; + + if (use_shared_contexts) { + decoder_input_tensors.insert( + {"compact_idx", Tensor(MEMORY_GPU, TYPE_INT32, {(size_t)compact_size}, compact_idx_)}); + decoder_input_tensors.insert( + {"batch_to_compact_idx", + Tensor(MEMORY_GPU, TYPE_INT32, {batch_size * beam_width}, batch_to_compact_idx_)}); + } + + std::unordered_map decoder_output_tensors{ + {"decoder_output", + Tensor{MEMORY_GPU, + data_type, + {batch_size * beam_width, (size_t)max_input_length, hidden_units_}, + context_decoder_output_buf_}}, + {"key_cache", Tensor{MEMORY_GPU, data_type, self_k_cache_shape, key_cache_}}, + {"value_cache", Tensor{MEMORY_GPU, data_type, self_v_cache_shape, value_cache_}}, + {"last_token_hidden_units", + Tensor{MEMORY_GPU, data_type, {batch_size * beam_width, hidden_units_}, decoder_output_buf_}}}; + + gpt_context_decoder_->forward( + &decoder_output_tensors, &decoder_input_tensors, &gpt_weights->decoder_layer_weights); + sync_check_cuda_error(); + printf("gpt_context_decoder_->forward done\n"); + invokeDecodingInitialize(finished_buf_, + sequence_lengths_, + nullptr, + cum_log_probs_, + start_ids_buf_, + batch_size, + beam_width, + max_input_length - 1, + stream_); + sync_check_cuda_error(); + } + else if (max_input_length == 0) { + FT_CHECK(prompt_learning_type_ == PromptLearningType::no_prompt + && request_prompt_type == PromptLearningType::no_prompt); // Not support prompts in this case + max_input_length++; + invokeDecodingInitialize(finished_buf_, + sequence_lengths_, + output_ids_buf_, + cum_log_probs_, + start_ids_buf_, + batch_size, + beam_width, + max_input_length - 1, + stream_); + std::vector h_input_lengths(batch_size * beam_width, 1); + cudaMemcpyAsync(tiled_input_lengths_buf_, + h_input_lengths.data(), + sizeof(int) * batch_size * beam_width, + cudaMemcpyHostToDevice, + stream_); + sync_check_cuda_error(); + } + else if (max_input_length == 1) { + FT_CHECK(prompt_learning_type_ == PromptLearningType::no_prompt + && request_prompt_type == PromptLearningType::no_prompt); // Not support prompts in this case + invokeDecodingInitialize(finished_buf_, + sequence_lengths_, + nullptr, + cum_log_probs_, + start_ids_buf_, + batch_size, + beam_width, + max_input_length - 1, + stream_); + sync_check_cuda_error(); + invokeTileGptInputs(tiled_input_ids_buf_, + tiled_input_lengths_buf_, + input_tensors->at("input_ids").getPtr(), + input_tensors->at("input_lengths").getPtr(), + batch_size, + beam_width, + max_input_length, + stream_); + sync_check_cuda_error(); + + cudaMemcpyAsync(output_ids_buf_, + tiled_input_ids_buf_, + sizeof(int) * batch_size * beam_width, + cudaMemcpyDeviceToDevice, + stream_); + } + + if (vocab_size_ == vocab_size_padded_) { + padded_embedding_kernel_ptr_ = gpt_weights->post_decoder_embedding.kernel; + } + else { + cudaMemcpyAsync(padded_embedding_kernel_, + gpt_weights->post_decoder_embedding.kernel, + sizeof(T) * vocab_size_ * hidden_units_, + cudaMemcpyDeviceToDevice, + stream_); + sync_check_cuda_error(); + } + + invokeMaskPaddingTokens(masked_tokens_, + input_tensors->at("input_lengths").getPtr(), // not_tiled + tiled_prompt_lengths_buf_, + max_cache_seq_len, + max_input_length + max_prefix_prompt_length, + 0, + batch_size, + beam_width, + stream_); + + for (int step = max_input_length; step < (int)max_output_seq_len; step++) { + const int src_indir_idx = (step - max_input_length) % 2; + const int tgt_indir_idx = 1 - src_indir_idx; + + const size_t local_batch_size = getLocalBatchSize(batch_size, 1, pipeline_para_.world_size_); + 
FT_CHECK(batch_size % local_batch_size == 0); + const size_t iteration_num = batch_size / local_batch_size; + *generation_should_stop_ = true; + + for (uint ite = 0; ite < iteration_num; ++ite) { + const int id_offset = ite * local_batch_size * beam_width; + const int hidden_units_offset = id_offset * hidden_units_; + const int vocab_size_units_offset = id_offset * vocab_size_padded_; + + if (!(max_input_length > 1 && step == max_input_length)) { + if (pipeline_para_.rank_ == 0) { + invokeEmbeddingLookupPosEncodingPadCount(decoder_input_buf_ + hidden_units_offset, + gpt_weights->pre_decoder_embedding_table, + gpt_weights->position_encoding_table, + output_ids_buf_ + id_offset, + tiled_total_padding_count_ + id_offset, + local_batch_size * beam_width, + hidden_units_, + (T)(1.0f), + step - 1, + batch_size * beam_width, + 0, + stream_); + sync_check_cuda_error(); + } + std::unordered_map decoder_input_tensors{ + {"decoder_input", + Tensor{MEMORY_GPU, + data_type, + {local_batch_size * beam_width, hidden_units_}, + decoder_input_buf_ + hidden_units_offset}}, + {"finished", + Tensor{MEMORY_GPU, TYPE_BOOL, {local_batch_size * beam_width}, finished_buf_ + id_offset}}, + {"sequence_lengths", + Tensor{MEMORY_GPU, TYPE_INT32, {local_batch_size * beam_width}, sequence_lengths_ + id_offset}}, + {"total_padding_tokens", + Tensor{MEMORY_GPU, + TYPE_INT32, + {local_batch_size * beam_width}, + tiled_total_padding_count_ + id_offset}}, + {"d_prefix_prompt_lengths", + Tensor{MEMORY_GPU, + TYPE_INT32, + {local_batch_size}, + has_prefix_prompt_ ? (tiled_prompt_lengths_buf_ + id_offset) : nullptr}}, + {"max_prefix_prompt_length", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &max_prefix_prompt_length}}, + {"max_input_length", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &max_input_length}}, + {"step", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &step}}, + {"ite", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &ite}}, + {"cache_indirection", + Tensor{MEMORY_GPU, + TYPE_INT32, + {local_batch_size, beam_width, max_output_seq_len}, + beam_width > 1 ? 
cache_indirections_[src_indir_idx] + id_offset * max_output_seq_len : + nullptr}}, + {"masked_tokens", + Tensor{MEMORY_GPU, + TYPE_BOOL, + {local_batch_size * beam_width, max_cache_seq_len}, + masked_tokens_ + id_offset * max_cache_seq_len}}}; + std::unordered_map decoder_output_tensors{ + {"decoder_output", + Tensor{MEMORY_GPU, + data_type, + {local_batch_size * beam_width, hidden_units_}, + decoder_output_buf_ + hidden_units_offset}}, + {"key_cache", Tensor{MEMORY_GPU, data_type, self_k_cache_shape, key_cache_}}, + {"value_cache", Tensor{MEMORY_GPU, data_type, self_v_cache_shape, value_cache_}}}; + gpt_decoder_->forward( + &decoder_output_tensors, &decoder_input_tensors, &gpt_weights->decoder_layer_weights); + } + + if (pipeline_para_.rank_ == pipeline_para_.world_size_ - 1) { + invokeGeneralT5LayerNorm(normed_decoder_output_buf_ + hidden_units_offset, + decoder_output_buf_ + hidden_units_offset, + gpt_weights->post_decoder_layernorm.gamma, + (const T*)nullptr, + layernorm_eps_, + local_batch_size * beam_width, + hidden_units_, + stream_); + sync_check_cuda_error(); + + + if (tensor_para_.world_size_ == 1) { + float alpha = 1.0f; + float beta = 0.0f; + cublas_wrapper_->Gemm(CUBLAS_OP_T, + CUBLAS_OP_N, + vocab_size_padded_, // n + local_batch_size * beam_width, + hidden_units_, // k + &alpha, + padded_embedding_kernel_ptr_, + gemm_data_type, + hidden_units_, // k + normed_decoder_output_buf_ + hidden_units_offset, + gemm_data_type, + hidden_units_, // k + &beta, + logits_buf_ + vocab_size_units_offset, + CUDA_R_32F, + vocab_size_padded_, /* n */ + CUDA_R_32F, + cublasGemmAlgo_t(-1)); + } + else { + FT_CHECK(vocab_size_padded_ % tensor_para_.world_size_ == 0); + const int local_vocab_size = vocab_size_padded_ / tensor_para_.world_size_; + float alpha = 1.0f; + float beta = 0.0f; + cublas_wrapper_->Gemm(CUBLAS_OP_T, + CUBLAS_OP_N, + local_vocab_size, // n + local_batch_size * beam_width, + hidden_units_, // k + &alpha, + padded_embedding_kernel_ptr_ + + tensor_para_.rank_ * local_vocab_size * hidden_units_, + gemm_data_type, + hidden_units_, // k + normed_decoder_output_buf_ + hidden_units_offset, + gemm_data_type, + hidden_units_, // k + &beta, + nccl_logits_buf_ + vocab_size_units_offset + + tensor_para_.rank_ * local_batch_size * beam_width * local_vocab_size, + CUDA_R_32F, + local_vocab_size, /* n */ + CUDA_R_32F, + cublasGemmAlgo_t(-1)); + + + ftNcclAllGather(nccl_logits_buf_ + vocab_size_units_offset, + nccl_logits_buf_ + vocab_size_units_offset, + local_batch_size * beam_width * local_vocab_size, + tensor_para_.rank_, + tensor_para_, + stream_); + invokeTransposeAxis01(logits_buf_ + vocab_size_units_offset, + nccl_logits_buf_ + vocab_size_units_offset, + tensor_para_.world_size_, + local_batch_size * beam_width, + local_vocab_size, + stream_); + } + + + int tmp_local_batch_size = local_batch_size; + bool is_initialize_random_table = step == max_input_length; + std::unordered_map dynamic_decode_input_tensors{ + {"logits", + Tensor{MEMORY_GPU, TYPE_FP32, {batch_size, beam_width, vocab_size_padded_}, logits_buf_}}, + // {"embedding_bias", Tensor{MEMORY_GPU, data_type, {vocab_size_padded_}, nullptr}}, + {"step", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &step}}, + {"max_input_length", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &max_input_length}}, + {"input_lengths", + Tensor{MEMORY_GPU, TYPE_INT32, {batch_size, beam_width}, tiled_input_lengths_buf_}}, + {"sequence_limit_length", Tensor{MEMORY_GPU, TYPE_UINT32, {batch_size}, seq_limit_len_}}, + {"ite", Tensor{MEMORY_CPU, TYPE_UINT32, {1}, &ite}}, + 
{"src_cache_indirection", + Tensor{MEMORY_GPU, + TYPE_INT32, + {local_batch_size, beam_width, max_output_seq_len}, + cache_indirections_[src_indir_idx] + id_offset * max_output_seq_len}}, + {"local_batch_size", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &tmp_local_batch_size}}, + {"end_id", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size}, end_ids_buf_}}, + {"is_initialize_random_table", Tensor{MEMORY_CPU, TYPE_BOOL, {1}, &is_initialize_random_table}}}; + + for (auto t = input_tensors->begin(); t != input_tensors->end(); ++t) { + if (dynamic_decode_input_tensors.find(t->first) == dynamic_decode_input_tensors.end()) { + dynamic_decode_input_tensors.insert(*t); + } + } + + // common outputs + bool subbatch_should_stop = false; + std::unordered_map dynamic_decode_output_tensors{ + {"output_ids", + Tensor{MEMORY_GPU, TYPE_INT32, {max_seq_len, batch_size, beam_width}, output_ids_buf_}}, + {"finished", Tensor{MEMORY_GPU, TYPE_BOOL, {batch_size * beam_width}, finished_buf_}}, + // cum_log_probs is necessary for beam search, while it is optional for sampling. + {"cum_log_probs", + Tensor{MEMORY_GPU, + TYPE_FP32, + {batch_size * beam_width}, + ((beam_width > 1) || (output_tensors->count("cum_log_probs") > 0)) ? cum_log_probs_ : + nullptr}}, + {"output_log_probs", + Tensor{MEMORY_GPU, + TYPE_FP32, + {max_seq_len, batch_size, beam_width}, + output_tensors->count("output_log_probs") > 0 + && output_tensors->at("output_log_probs").data != nullptr ? + output_log_probs_buf_ : + nullptr}}, + {"parent_ids", + Tensor{MEMORY_GPU, TYPE_INT32, {max_seq_len, batch_size, beam_width}, parent_ids_buf_}}, + {"sequence_length", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size * beam_width}, sequence_lengths_}}, + {"tgt_cache_indirection", + Tensor{MEMORY_GPU, + TYPE_INT32, + {local_batch_size, beam_width, max_output_seq_len}, + cache_indirections_[tgt_indir_idx] + id_offset * max_output_seq_len}}, + {"should_stop", Tensor{MEMORY_CPU, TYPE_BOOL, {1}, &subbatch_should_stop}}}; + + for (auto t = output_tensors->begin(); t != output_tensors->end(); ++t) { + // Handle exceptions. 
+ if (t->first == "cum_log_probs" || t->first == "output_log_probs") { + continue; + } + dynamic_decode_output_tensors.insert(*t); + } + + dynamic_decode_layer_->forward(&dynamic_decode_output_tensors, &dynamic_decode_input_tensors); + *generation_should_stop_ &= subbatch_should_stop; + } + } + + if (pipeline_para_.world_size_ > 1) { + ftNcclGroupStart(); + ftNcclBroadCast(output_ids_buf_ + step * batch_size * beam_width, + batch_size * beam_width, + pipeline_para_.world_size_ - 1, + pipeline_para_, + stream_); + + ftNcclBroadCast( + sequence_lengths_, batch_size * beam_width, pipeline_para_.world_size_ - 1, pipeline_para_, stream_); + + ftNcclBroadCast(generation_should_stop_, 1, pipeline_para_.world_size_ - 1, pipeline_para_, stream_); + + if (beam_width > 1) { + ftNcclBroadCast(cache_indirections_[tgt_indir_idx], + batch_size * beam_width * max_output_seq_len, + pipeline_para_.world_size_ - 1, + pipeline_para_, + stream_); + } + ftNcclGroupEnd(); + // throw errors when detected + ftNcclStreamSynchronize(tensor_para_, pipeline_para_, stream_); + sync_check_cuda_error(); + } + + if (*generation_should_stop_) { + break; + } + if (token_generated_cb_ && (step + 1) % token_generated_cb_step_ == 0 && step + 1 < (int)max_output_seq_len) { + setOutputTensors(output_tensors, input_tensors, max_input_length, max_output_seq_len); + sendTensorsToFirstPipelineNode(output_tensors, input_tensors); + + if (pipeline_para_.rank_ == 0 && tensor_para_.rank_ == 0) { + token_generated_cb_(output_tensors, token_generated_ctx_); + } + } + if (step == max_input_length) { + /* We have just finished processing input: update the padding count: + * total_padding_count += (max_input_length - input_lengths) + * if has prefix prompts, += (max_prefix_prompt_length - prompt_length) + */ + invokeUpdatePaddingCount(tiled_total_padding_count_, + input_tensors->at("input_lengths").getPtr(), // not_tiled + has_prefix_prompt_ ? tiled_prompt_lengths_buf_ : (const int*)nullptr, + max_input_length, + has_prefix_prompt_ ? 
max_prefix_prompt_length : 0, + batch_size, + beam_width, + stream_); + } + } + + setOutputTensors(output_tensors, input_tensors, max_input_length, max_output_seq_len); + sendTensorsToFirstPipelineNode(output_tensors, input_tensors); +} + +template +void Llama::sendTensorsToFirstPipelineNode(std::unordered_map* output_tensors, + const std::unordered_map* input_tensors) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + if (pipeline_para_.world_size_ == 1) { + // throw errors when detected + ftNcclStreamSynchronize(tensor_para_, pipeline_para_, stream_); + return; + } + + const auto pp_rank = pipeline_para_.rank_; + + ftNcclGroupStart(); + for (auto const& it : *output_tensors) { + if (it.second.data == nullptr) { + continue; + } + + if (pp_rank == pipeline_para_.world_size_ - 1) { + ftNcclSend(it.second.getPtr(), it.second.sizeBytes(), 0, pipeline_para_, stream_); + } + else if (pp_rank == 0) { + ftNcclRecv(it.second.getPtr(), + it.second.sizeBytes(), + pipeline_para_.world_size_ - 1, + pipeline_para_, + stream_); + } + } + ftNcclGroupEnd(); + // throw errors when detected + ftNcclStreamSynchronize(tensor_para_, pipeline_para_, stream_); +} + +template +void Llama::setOutputTensors(std::unordered_map* output_tensors, + const std::unordered_map* input_tensors, + const size_t max_input_length, + const size_t max_output_seq_len) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + if (pipeline_para_.rank_ != pipeline_para_.world_size_ - 1) { + return; + } + + const size_t batch_size = output_tensors->at("output_ids").shape[0]; + const size_t beam_width = output_tensors->at("output_ids").shape[1]; + int* sequence_lengths = output_tensors->at("sequence_length").getPtr(); + const size_t max_prefix_soft_prompt_length = + has_prefix_soft_prompt_ ? input_tensors->at("request_prompt_embedding").shape[1] : 0; + + cudaAutoCpy(sequence_lengths, sequence_lengths_, output_tensors->at("sequence_length").size(), stream_); + if (input_tensors->at("input_ids").shape[1] == 0) { + // TODO: D2D sequence_lenghts + if (beam_width > 1) { + // For beam search, do gather_tree + // take output_parent_ids as inter buffer + invokeGatherTree(transposed_output_ids_buf_, + sequence_lengths, + max_output_seq_len, + batch_size, + beam_width, + output_ids_buf_ + batch_size * beam_width, + parent_ids_buf_ + batch_size * beam_width, + end_ids_buf_, + stream_); + + // transpose and take output_parent_ids as inter buffer + invokeTransposeAxis01(output_tensors->at("output_ids").getPtr(), + transposed_output_ids_buf_, + max_output_seq_len - 1, + batch_size * beam_width, + 1, + stream_); + } + else { + // For sampling, only copy the results to output_tensor + invokeTransposeAxis01(output_tensors->at("output_ids").getPtr(), + output_ids_buf_ + batch_size * beam_width, + max_output_seq_len - 1, + batch_size * beam_width, + 1, + stream_); + } + } + else { + + // For sampling, it is equivalent to all parent ids are 0. + gatherTreeParam param; + param.beams = transposed_output_ids_buf_; + param.max_sequence_lengths = sequence_lengths; + // add sequence_length 1 here because the sequence_length of time step t is t - 1 + param.max_sequence_length_final_step = 1; + param.max_time = max_output_seq_len; + param.batch_size = batch_size; + param.beam_width = beam_width; + param.step_ids = output_ids_buf_; + param.parent_ids = beam_width == 1 ? nullptr : parent_ids_buf_; + param.end_tokens = end_ids_buf_; + param.max_input_length = max_input_length; + param.prefix_soft_prompt_lengths = + has_prefix_soft_prompt_ ? 
input_tensors->at("request_prompt_lengths").getPtr() : nullptr; + param.input_lengths = tiled_input_lengths_buf_; + param.max_prefix_soft_prompt_length = max_prefix_soft_prompt_length; + param.max_input_without_prompt_length = max_input_length; + param.stream = stream_; + param.output_ids = output_tensors->at("output_ids").getPtr(); + invokeGatherTree(param); + sync_check_cuda_error(); + } + if ((output_tensors->count("output_log_probs") > 0 && output_tensors->at("output_log_probs").data != nullptr)) { + invokeTransposeAxis01(output_tensors->at("output_log_probs").getPtr(), + output_log_probs_buf_, + input_tensors->at("output_seq_len").max() - max_input_length, + batch_size * beam_width, + 1, + stream_); + } + // Return the cumulative log probability if requested. + if (output_tensors->count("cum_log_probs") > 0) { + Tensor cum_log_probs = output_tensors->at("cum_log_probs"); + FT_CHECK_WITH_INFO(cum_log_probs.size() == batch_size * beam_width, + "The shape of cum_log_probs does not match with batch_size x beam_width."); + cudaAutoCpy(cum_log_probs.getPtr(), cum_log_probs_, cum_log_probs.size(), stream_); + } +} + +template +size_t Llama::getPipelineParallelRank() +{ + return pipeline_para_.rank_; +} + +template +size_t Llama::getPipelineParallelSize() +{ + return pipeline_para_.world_size_; +} + +template +size_t Llama::getTensorParallelRank() +{ + return tensor_para_.rank_; +} + +template +size_t Llama::getTensorParallelSize() +{ + return tensor_para_.world_size_; +} + +template +bool* Llama::getFinishBuffer() +{ + return finished_buf_; +} + +template class Llama; +template class Llama; +#ifdef ENABLE_BF16 +template class Llama<__nv_bfloat16>; +#endif +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/Llama.h b/src/fastertransformer/models/llama/Llama.h new file mode 100644 index 000000000..abec1659e --- /dev/null +++ b/src/fastertransformer/models/llama/Llama.h @@ -0,0 +1,237 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include + +#include "src/fastertransformer/layers/DynamicDecodeLayer.h" +#include "src/fastertransformer/models/llama/LlamaContextDecoder.h" +#include "src/fastertransformer/models/llama/LlamaDecoder.h" +#include "src/fastertransformer/models/llama/LlamaWeight.h" +#include "src/fastertransformer/utils/custom_ar_comm.h" +#include "src/fastertransformer/utils/prompt_learning.h" + +namespace fastertransformer { + +template +class Llama: public BaseLayer { +private: + // meta data + size_t head_num_; + size_t kv_head_num_; + size_t size_per_head_; + size_t inter_size_; + size_t num_layer_; + size_t vocab_size_; + size_t rotary_embedding_dim_; + float layernorm_eps_; + + static constexpr bool neox_rotary_style_ = true; + float shared_contexts_ratio_; + + int start_id_; + int end_id_; + size_t hidden_units_; + + size_t local_head_num_; + size_t local_kv_head_num_; + NcclParam tensor_para_; + NcclParam pipeline_para_; + + std::shared_ptr custom_all_reduce_comm_; + int enable_custom_all_reduce_; + + AttentionType attention_type_; + const int int8_mode_ = 0; + + size_t vocab_size_padded_; + const bool is_context_qk_buf_float_ = + (std::getenv("CONTEXT_ATTENTION_BMM1_HALF_ACCUM") == nullptr || + std::string(std::getenv("CONTEXT_ATTENTION_BMM1_HALF_ACCUM")) != "ON"); + + // Residual Type + const bool use_gptj_residual_ = false; + + // Prompt Learning Parameters + PromptLearningType prompt_learning_type_; + int prompt_learning_start_id_; // start_id for prompt_learning (only needed by prefix prompts) + bool has_prefix_prompt_; + bool has_prefix_soft_prompt_; + + LlamaDecoder* gpt_decoder_; + LlamaContextDecoder* gpt_context_decoder_; + DynamicDecodeLayer* dynamic_decode_layer_; + + void allocateBuffer() override; + void allocateBuffer( + size_t batch_size, size_t beam_width, size_t max_seq_len, size_t max_cache_seq_len, size_t max_input_len); + void freeBuffer() override; + + void initialize(); + +protected: + T* padded_embedding_kernel_; + const T* padded_embedding_kernel_ptr_; + + T* input_attention_mask_; + + T* decoder_input_buf_; + T* decoder_output_buf_; + T* normed_decoder_output_buf_; + + float* logits_buf_; + float* nccl_logits_buf_; + float* cum_log_probs_; + + bool* finished_buf_; + bool* h_finished_buf_; + int* sequence_lengths_ = nullptr; + int* tiled_total_padding_count_ = nullptr; + uint32_t* seq_limit_len_ = nullptr; + + T* key_cache_; + T* value_cache_; + int* cache_indirections_[2] = {nullptr, nullptr}; + + // prompt_learning weight_batch ptrs + const T** prompt_learning_weight_batch_; + int* tiled_prompt_lengths_buf_; // only needed by prefix prompts + + int* tiled_input_ids_buf_; + int* tiled_input_lengths_buf_; + int* transposed_output_ids_buf_; + int* output_ids_buf_; + int* parent_ids_buf_; + int* start_ids_buf_; + int* end_ids_buf_; + bool* masked_tokens_ = nullptr; + + bool* generation_should_stop_ = nullptr; + + int* shared_contexts_idx_ = nullptr; + int* compact_idx_ = nullptr; + int* batch_to_compact_idx_ = nullptr; + int* compact_size_ = nullptr; + + T* context_decoder_input_buf_; + T* context_decoder_output_buf_; + float* output_log_probs_buf_; + + // function pointer callback + using callback_sig = void(std::unordered_map*, void*); + callback_sig* token_generated_cb_ = nullptr; + void* token_generated_ctx_ = nullptr; + + // callback step + size_t token_generated_cb_step_ = 5; // default 5, override by env LLAMA_STREAM_CB_STEP + + void setOutputTensors(std::unordered_map* output_tensors, + const std::unordered_map* input_tensors, + const 
size_t max_input_length, + const size_t max_seq_len); + void sendTensorsToFirstPipelineNode(std::unordered_map* output_tensors, + const std::unordered_map* input_tensors); + +public: + Llama(size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t vocab_size, + size_t rotary_embedding_dim, + float layernorm_eps, + int start_id, + int end_id, + int prompt_learning_start_id, // only needed by p/prompt-tuning + PromptLearningType prompt_learning_type, + bool use_gptj_residual, + float beam_search_diversity_rate, + size_t top_k, + float top_p, + unsigned long long random_seed, + float temperature, + float len_penalty, + float repetition_penalty, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + cudaDeviceProp* cuda_device_prop = nullptr, + AttentionType attention_type = AttentionType::UNFUSED_MHA, + int int8_mode = 0, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce = 0, + float shared_contexts_ratio = 1.0f); + + Llama(size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t vocab_size, + size_t rotary_embedding_dim, + float layernorm_eps, + int start_id, + int end_id, + int prompt_learning_start_id, // only needed by p/prompt-tuning + PromptLearningType prompt_learning_type, + bool use_gptj_residual, + float beam_search_diversity_rate, + size_t top_k, + float top_p, + unsigned long long random_seed, + float temperature, + float len_penalty, + float repetition_penalty, + NcclParam tensor_para, + NcclParam pipeline_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + cudaDeviceProp* cuda_device_prop = nullptr, + AttentionType attention_type = AttentionType::UNFUSED_MHA, + int int8_mode = 0, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce = 0, + float shared_contexts_ratio = 1.0f); + + Llama(Llama const& Llama); + + ~Llama(); + + void forward(std::vector* output_tensors, + const std::vector* input_tensors, + const LlamaWeight* gpt_weights); + + void forward(std::unordered_map* output_tensors, + const std::unordered_map* input_tensors, + const LlamaWeight* gpt_weights); + + size_t getPipelineParallelRank(); + size_t getPipelineParallelSize(); + size_t getTensorParallelRank(); + size_t getTensorParallelSize(); + bool* getFinishBuffer(); + + void registerCallback(callback_sig* fn, void* ctx); + void unRegisterCallback(); +}; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/LlamaContextDecoder.cc b/src/fastertransformer/models/llama/LlamaContextDecoder.cc new file mode 100644 index 000000000..8082a2f13 --- /dev/null +++ b/src/fastertransformer/models/llama/LlamaContextDecoder.cc @@ -0,0 +1,646 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/fastertransformer/models/llama/LlamaContextDecoder.h" +#include "src/fastertransformer/kernels/bert_preprocess_kernels.h" +#include "src/fastertransformer/kernels/gpt_kernels.h" + +#include "src/fastertransformer/layers/TensorParallelSiluFfnLayer.h" +#include "src/fastertransformer/layers/attention_layers/TensorParallelLlamaContextAttentionLayer.h" + +namespace fastertransformer { + +template +void LlamaContextDecoder::initialize() +{ + self_attention_layer_ = new TensorParallelLlamaContextAttentionLayer(0, // max_batch_size + 0, // max_seq_len + head_num_, + kv_head_num_, + size_per_head_, + rotary_embedding_dim_, + neox_rotary_style_, + tensor_para_, + stream_, + cublas_wrapper_, + allocator_, + !use_gptj_residual_, + is_free_buffer_after_forward_, + is_qk_buf_float_, + false, + int8_mode_, + custom_all_reduce_comm_, + enable_custom_all_reduce_); + + + ffn_layer_ = new TensorParallelSiluFfnLayer(0, // max_batch_size + 0, // max_seq_len + head_num_, + size_per_head_, + 0, // expert_num + inter_size_, + tensor_para_, + stream_, + cublas_wrapper_, + allocator_, + !use_gptj_residual_, + is_free_buffer_after_forward_, + false, + true, // use_gated_activation = true; + custom_all_reduce_comm_, + enable_custom_all_reduce_, + int8_mode_); +} + +template +void LlamaContextDecoder::allocateBuffer() +{ + FT_CHECK(false); +} + +template +void LlamaContextDecoder::allocateBuffer(size_t batch_size, size_t seq_len, bool use_shared_contexts) +{ + decoder_normed_input_ = reinterpret_cast( + allocator_->reMalloc(decoder_normed_input_, sizeof(T) * batch_size * seq_len * hidden_units_, false)); + self_attn_output_ = reinterpret_cast( + allocator_->reMalloc(self_attn_output_, sizeof(T) * batch_size * seq_len * hidden_units_, false)); + ffn_output_ = reinterpret_cast( + allocator_->reMalloc(ffn_output_, sizeof(T) * batch_size * seq_len * hidden_units_, false)); + decoder_layer_output_ = reinterpret_cast( + allocator_->reMalloc(decoder_layer_output_, sizeof(T) * batch_size * seq_len * hidden_units_, false)); + h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true); + padding_offset_ = + reinterpret_cast(allocator_->reMalloc(padding_offset_, sizeof(int) * batch_size * seq_len, false)); + cu_seqlens_ = reinterpret_cast(allocator_->reMalloc(cu_seqlens_, sizeof(int) * (batch_size + 1), false)); + + if (use_shared_contexts) { + compact_decoder_features_ = reinterpret_cast( + allocator_->reMalloc(compact_decoder_features_, sizeof(T) * batch_size * seq_len * hidden_units_, false)); + compact_attention_mask_ = reinterpret_cast( + allocator_->reMalloc(compact_attention_mask_, sizeof(T) * batch_size * seq_len * seq_len, false)); + compact_input_lengths_ = + reinterpret_cast(allocator_->reMalloc(compact_input_lengths_, sizeof(int) * batch_size, false)); + k_cache_layer_ = reinterpret_cast( + allocator_->reMalloc(k_cache_layer_, sizeof(T) * batch_size * seq_len * hidden_units_, false)); + v_cache_layer_ = reinterpret_cast( + allocator_->reMalloc(v_cache_layer_, sizeof(T) * batch_size * seq_len * hidden_units_, false)); + } + + is_allocate_buffer_ = true; +} + +template +void LlamaContextDecoder::freeBuffer() +{ + if (is_allocate_buffer_ == true) { + allocator_->free((void**)(&decoder_normed_input_)); + allocator_->free((void**)(&self_attn_output_)); + allocator_->free((void**)(&ffn_output_)); + allocator_->free((void**)(&decoder_layer_output_)); + allocator_->free((void**)(&h_pinned_token_num_ptr_), true); + 
allocator_->free((void**)(&padding_offset_)); + allocator_->free((void**)(&cu_seqlens_)); + if (compact_decoder_features_ != nullptr) { + allocator_->free((void**)(&compact_decoder_features_)); + allocator_->free((void**)(&compact_attention_mask_)); + allocator_->free((void**)(&compact_input_lengths_)); + allocator_->free((void**)(&k_cache_layer_)); + allocator_->free((void**)(&v_cache_layer_)); + } + is_allocate_buffer_ = false; + } +} + +template +bool LlamaContextDecoder::isValidLayerParallelId(uint l) +{ + int local_num_layer = (int)(ceil(num_layer_ * 1.0f / pipeline_para_.world_size_)); + return l < num_layer_ && (l >= local_num_layer * pipeline_para_.rank_) + && (l < local_num_layer * (pipeline_para_.rank_ + 1)); +} + +template +bool LlamaContextDecoder::isFirstLayerParallelId(uint l) +{ + int local_num_layer = (int)(ceil(num_layer_ * 1.0f / pipeline_para_.world_size_)); + return l < num_layer_ && (l == local_num_layer * pipeline_para_.rank_); +} + +template +bool LlamaContextDecoder::isLastLayerParallelId(uint l) +{ + int local_num_layer = (int)(ceil(num_layer_ * 1.0f / pipeline_para_.world_size_)); + return l < num_layer_ && (l == local_num_layer * (pipeline_para_.rank_ + 1) - 1); +} + +template +int LlamaContextDecoder::getFirstLayerParallelId() +{ + int local_num_layer = (int)(ceil(num_layer_ * 1.0f / pipeline_para_.world_size_)); + return local_num_layer * pipeline_para_.rank_; +} + +template +LlamaContextDecoder::LlamaContextDecoder(size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t rotary_embedding_dim, + bool neox_rotary_style, + bool use_gptj_residual, + float layernorm_eps, + NcclParam tensor_para, + NcclParam pipeline_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool is_qk_buf_float, + AttentionType attention_type, + int int8_mode, + std::shared_ptr custom_all_reduce_comm, + int enable_custom_all_reduce): + BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), + head_num_(head_num), + kv_head_num_(kv_head_num), + size_per_head_(size_per_head), + inter_size_(inter_size), + num_layer_(num_layer), + rotary_embedding_dim_(rotary_embedding_dim), + neox_rotary_style_(neox_rotary_style), + use_gptj_residual_(use_gptj_residual), + layernorm_eps_(layernorm_eps), + hidden_units_(head_num * size_per_head), + tensor_para_(tensor_para), + pipeline_para_(pipeline_para), + is_qk_buf_float_(is_qk_buf_float), + attention_type_(attention_type), + int8_mode_(int8_mode), + custom_all_reduce_comm_(custom_all_reduce_comm), + enable_custom_all_reduce_(enable_custom_all_reduce) +{ + initialize(); +} + +template +LlamaContextDecoder::LlamaContextDecoder(LlamaContextDecoder const& decoder): + BaseLayer(decoder.stream_, decoder.cublas_wrapper_, decoder.allocator_, decoder.is_free_buffer_after_forward_), + head_num_(decoder.head_num_), + kv_head_num_(decoder.kv_head_num_), + size_per_head_(decoder.size_per_head_), + inter_size_(decoder.inter_size_), + num_layer_(decoder.num_layer_), + rotary_embedding_dim_(decoder.rotary_embedding_dim_), + neox_rotary_style_(decoder.neox_rotary_style_), + use_gptj_residual_(decoder.use_gptj_residual_), + layernorm_eps_(decoder.layernorm_eps_), + hidden_units_(decoder.hidden_units_), + tensor_para_(decoder.tensor_para_), + pipeline_para_(decoder.pipeline_para_), + is_qk_buf_float_(decoder.is_qk_buf_float_), + attention_type_(decoder.attention_type_), + int8_mode_(decoder.int8_mode_), + 
custom_all_reduce_comm_(decoder.custom_all_reduce_comm_), + enable_custom_all_reduce_(decoder.enable_custom_all_reduce_) +{ + initialize(); +} + +template +LlamaContextDecoder::~LlamaContextDecoder() +{ + delete self_attention_layer_; + delete ffn_layer_; + freeBuffer(); +} + +template +void LlamaContextDecoder::forward(std::vector* output_tensors, + const std::vector* input_tensors, + const std::vector*>* gpt_decoder_layer_weight) +{ + std::unordered_map input_tensors_map{{"decoder_input", input_tensors->at(0)}, + {"attention_mask", input_tensors->at(1)}, + {"input_lengths", input_tensors->at(2)}}; + std::unordered_map output_tensors_map{{"decoder_output", output_tensors->at(0)}, + {"key_cache", output_tensors->at(1)}, + {"value_cache", output_tensors->at(2)}, + {"last_token_hidden_units", output_tensors->at(3)}}; + + forward(&output_tensors_map, &input_tensors_map, gpt_decoder_layer_weight); +} + +template +void LlamaContextDecoder::forward(std::unordered_map* output_tensors, + const std::unordered_map* input_tensors, + const std::vector*>* gpt_decoder_layer_weight) +{ + // input tensors: + // decoder_input [batch_size, seq_len, hidden_dimension], + // attention_mask [batch_size, 1, seq_len, seq_len + max_prompt_length] + // input_lengths [batch_size] + // d_prefix_prompt_batch [batch_size], + // each element contains ptr with buffer shape[2, local_head_num_, prompt_length, size_per_head] + // prefix_prompt_lengths [batch size] + + // output tensors: + // decoder_output [batch_size, seq_len, hidden_dimension], + // key_cache [num_layer, batch, local_head_num, size_per_head // x, max_seq_len, x] + // value_cache [num_layer, batch, local_head_num, max_seq_len, size_per_head] + // last_token_hidden_units [batch_size, hidden_dimension] + + // To use layer/pipeline parallelism, we view the shape of 'batch_size' to 'ite * local_batch_size'. + // For example, the shape of decoder_input becomes [ite, batch_size, seq_len, hidden_dimension] during + // computing. + + FT_CHECK(input_tensors->size() >= 5); + FT_CHECK(output_tensors->size() == 4); + + const bool use_shared_contexts = input_tensors->find("compact_idx") != input_tensors->end(); + FT_CHECK(!use_shared_contexts || (input_tensors->find("batch_to_compact_idx") != input_tensors->end())); + const size_t request_batch_size = input_tensors->at("decoder_input").shape[0]; + // compacted batch size. + const size_t batch_size = + use_shared_contexts ? input_tensors->at("compact_idx").shape[0] : input_tensors->at("decoder_input").shape[0]; + const int seq_len = input_tensors->at("decoder_input").shape[1]; // max_input_len + // The maximum length of generation. 
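+    // (taken from the value cache, whose shape is [num_layer, batch, local_head_num, max_seq_len, size_per_head])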
+ const size_t max_seq_len = output_tensors->at("value_cache").shape[3]; + + const int max_prompt_length = + input_tensors->at("attention_mask").shape[3] - input_tensors->at("attention_mask").shape[2]; + const DataType data_type = getTensorType(); + allocateBuffer(batch_size, seq_len, use_shared_contexts); + + T* decoder_input = input_tensors->at("decoder_input").getPtr(); + T* decoder_output = output_tensors->at("decoder_output").getPtr(); + const T* attention_mask = input_tensors->at("attention_mask").getPtr(); + const T** d_prefix_prompt_batch = input_tensors->at("d_prefix_prompt_batch").getPtr(); + const int* d_prefix_prompt_lengths = input_tensors->at("d_prefix_prompt_lengths").getPtr(); + + if (use_shared_contexts) { + invokeCompactInputs(compact_decoder_features_, + compact_attention_mask_, + compact_input_lengths_, + decoder_input, + attention_mask, + input_tensors->at("input_lengths").getPtr(), + input_tensors->at("compact_idx").getPtr(), + batch_size, + seq_len, + hidden_units_, + stream_); + } + + const int local_batch_size = getLocalBatchSize(batch_size, seq_len, pipeline_para_.world_size_); + FT_CHECK(batch_size % local_batch_size == 0); + const int iteration_num = batch_size / local_batch_size; + + Tensor& k_cache = output_tensors->at("key_cache"); + Tensor& v_cache = output_tensors->at("value_cache"); + std::vector self_k_cache_size; + self_k_cache_size.push_back(local_batch_size); + for (auto t = k_cache.shape.begin() + 2; t != k_cache.shape.end(); ++t) { + self_k_cache_size.push_back(*t); + } + std::vector self_v_cache_size; + self_v_cache_size.push_back(local_batch_size); + for (auto t = v_cache.shape.begin() + 2; t != v_cache.shape.end(); ++t) { + self_v_cache_size.push_back(*t); + } + + if (use_shared_contexts) { + // we use k_cache_layer_ and v_cache_layer_ + self_k_cache_size[3] = seq_len; + self_v_cache_size[2] = seq_len; + } + + AttentionType attention_type = (d_prefix_prompt_lengths != nullptr) ? + getUnfusedAttentionType(attention_type_) : + attention_type_; + printf("attention_type: %d\n", attention_type); + const bool is_unpadded_mha = isUnPaddedMHA(attention_type); + + for (int ite = 0; ite < iteration_num; ite++) { + size_t h_token_num = local_batch_size * seq_len; + if (is_unpadded_mha) { + const int* base_input_lengths = + use_shared_contexts ? compact_input_lengths_ : input_tensors->at("input_lengths").getPtr(); + invokeGetPaddingOffsetAndCuSeqLens(h_pinned_token_num_ptr_, + &h_token_num, + padding_offset_, + cu_seqlens_, + base_input_lengths + ite * local_batch_size, + local_batch_size, + seq_len, + stream_); + } + for (int l = 0; l < num_layer_; l++) { + if (isValidLayerParallelId(l) == false) { + continue; + } + + if (l == 0 && is_unpadded_mha) { + const T* base_input = (use_shared_contexts ? compact_decoder_features_ : decoder_input); + invokeRemovePadding(decoder_layer_output_, + base_input + ite * local_batch_size * seq_len * hidden_units_, + padding_offset_, + h_token_num, + hidden_units_, + stream_); + } + + const bool is_final = false; // TODO(bhsueh) remove this flag + T* layer_input = decoder_layer_output_; + T* layer_output = decoder_layer_output_; + if (!is_unpadded_mha) { + if (l == 0) { + layer_input = use_shared_contexts ? compact_decoder_features_ : decoder_input; + layer_input += ite * local_batch_size * seq_len * hidden_units_; + } + if (l == num_layer_ - 1) { + layer_output = use_shared_contexts ? 
compact_decoder_features_ : decoder_output;
+                    layer_output += ite * local_batch_size * seq_len * hidden_units_;
+                }
+            }
+
+            if (isFirstLayerParallelId(l) && pipeline_para_.rank_ != 0 && pipeline_para_.world_size_ > 1) {
+                int data_size = h_token_num * hidden_units_ / tensor_para_.world_size_;
+                ftNcclRecv(layer_input + data_size * tensor_para_.rank_,
+                           data_size,
+                           pipeline_para_.rank_ - 1,
+                           pipeline_para_,
+                           stream_);
+                if (tensor_para_.world_size_ > 1) {
+                    ftNcclAllGather(layer_input, layer_input, data_size, tensor_para_.rank_, tensor_para_, stream_);
+                }
+            }
+            // TODO: The layer norm used here differs from the NeoX one; it is not clear whether it needs to be
+            // changed to the int8 layer norm.
+            invokeGeneralT5LayerNorm(decoder_normed_input_,
+                                     layer_input,
+                                     gpt_decoder_layer_weight->at(l)->pre_layernorm_weights.gamma,
+                                     (const T*)nullptr,
+                                     layernorm_eps_,
+                                     h_token_num,
+                                     hidden_units_,
+                                     stream_);
+
+            sync_check_cuda_error();
+
+            const T* attention_ptr = use_shared_contexts ? compact_attention_mask_ : attention_mask;
+
+            TensorMap self_attention_input_tensors{
+                {"input_query",
+                 Tensor{MEMORY_GPU, data_type, {h_token_num, (size_t)hidden_units_}, decoder_normed_input_}},
+                {"attention_mask",
+                 Tensor{MEMORY_GPU,
+                        data_type,
+                        {(size_t)local_batch_size, (size_t)1, (size_t)seq_len, (size_t)(seq_len + max_prompt_length)},
+                        attention_ptr + local_batch_size * ite * seq_len * (seq_len + max_prompt_length)}},
+                {"attention_type", Tensor{MEMORY_CPU, TYPE_VOID, {1}, &attention_type}},
+                {"is_final_layer", Tensor{MEMORY_CPU, TYPE_BOOL, {(size_t)1}, &is_final}},
+                {"layer_id", Tensor{MEMORY_CPU, TYPE_INT32, {(size_t)1}, &l}}};
+            self_attention_input_tensors.insertIfValid(
+                "d_prefix_prompt_batch",
+                Tensor{MEMORY_GPU,
+                       data_type,
+                       {(size_t)local_batch_size},
+                       d_prefix_prompt_batch != nullptr ? d_prefix_prompt_batch + ite * local_batch_size : nullptr});
+            self_attention_input_tensors.insertIfValid("d_prefix_prompt_lengths",
+                                                       Tensor{MEMORY_GPU,
+                                                              TYPE_INT32,
+                                                              {(size_t)local_batch_size},
+                                                              d_prefix_prompt_lengths != nullptr ?
+                                                                  d_prefix_prompt_lengths + ite * local_batch_size :
+                                                                  nullptr});
+
+            if (is_unpadded_mha) {
+                self_attention_input_tensors.insert("padding_offset",
+                                                    Tensor{MEMORY_GPU, TYPE_INT32, {h_token_num}, padding_offset_});
+                self_attention_input_tensors.insert(
+                    "cu_seqlens", Tensor{MEMORY_GPU, TYPE_INT32, {size_t(local_batch_size + 1)}, cu_seqlens_});
+            }
+
+            size_t cache_offset = l - getFirstLayerParallelId();
+            for (auto t = k_cache.shape.begin() + 1; t != k_cache.shape.end(); ++t) {
+                cache_offset *= *t;
+            };
+            size_t ite_cache_offset = ite * local_batch_size;
+            for (auto t = k_cache.shape.begin() + 2; t != k_cache.shape.end(); ++t) {
+                ite_cache_offset *= *t;
+            }
+            cache_offset += ite_cache_offset;
+
+            T* k_cache_ptr = use_shared_contexts ? k_cache_layer_ : k_cache.getPtrWithOffset(cache_offset);
+            T* v_cache_ptr = use_shared_contexts ? 
v_cache_layer_ : v_cache.getPtrWithOffset(cache_offset); + + TensorMap self_attention_output_tensors{ + {"hidden_features", + Tensor{MEMORY_GPU, data_type, {h_token_num, (size_t)hidden_units_}, self_attn_output_}}, + {"key_cache", Tensor{MEMORY_GPU, data_type, self_k_cache_size, k_cache_ptr}}, + {"value_cache", Tensor{MEMORY_GPU, data_type, self_v_cache_size, v_cache_ptr}}}; + + self_attention_layer_->forward(&self_attention_output_tensors, + &self_attention_input_tensors, + &gpt_decoder_layer_weight->at(l)->self_attention_weights); + + #ifdef ENABLE_FLEX_DEBUG + if (l == 0) { + printf("%d %d: %d %d\n", l, ite, h_token_num, hidden_units_); + T *self_attn_output = new T[h_token_num * hidden_units_]; + cudaMemcpy(self_attn_output, self_attn_output_, sizeof(T)*h_token_num * hidden_units_, cudaMemcpyDeviceToHost); + sync_check_cuda_error(); + int k = 0; + for (int i=0; i(cache_layer_offset), + v_cache.getPtrWithOffset(cache_layer_offset), + k_cache_layer_, + v_cache_layer_, + input_tensors->at("batch_to_compact_idx").getPtr(), + request_batch_size, // batch_size (uncompact) + v_cache.shape[2], // local_head_num + max_seq_len, + seq_len, + size_per_head_, + local_batch_size, + ite, + stream_); + sync_check_cuda_error(); + } + + if (is_final == false) { + if (use_gptj_residual_) { + invokeGeneralLayerNorm(decoder_normed_input_, + layer_input, + gpt_decoder_layer_weight->at(l)->post_attention_layernorm_weights.gamma, + gpt_decoder_layer_weight->at(l)->post_attention_layernorm_weights.beta, + layernorm_eps_, + h_token_num, + hidden_units_, + (float*)nullptr, + int8_mode_, + stream_); + } + else { + // TODO: modify or not ? + invokeGeneralAddResidualT5PreLayerNorm( + self_attn_output_, + decoder_normed_input_, + layer_input, + gpt_decoder_layer_weight->at(l)->post_attention_layernorm_weights.gamma, + layernorm_eps_, + h_token_num, + hidden_units_, + stream_); + } + + TensorMap ffn_input_tensors( + {{"ffn_input", + Tensor{MEMORY_GPU, data_type, {h_token_num, (size_t)hidden_units_}, decoder_normed_input_}}}); + TensorMap ffn_output_tensors({{"ffn_output", + Tensor{MEMORY_GPU, + data_type, + {h_token_num, (size_t)hidden_units_}, + use_gptj_residual_ ? 
ffn_output_ : layer_output}}}); + ffn_layer_->forward( + &ffn_output_tensors, &ffn_input_tensors, &gpt_decoder_layer_weight->at(l)->ffn_weights); + + if (use_gptj_residual_) { + // Original workflow: + // layer_output = layer_input + reduceSum(ffn_output + self_attn_output + ffn_output_bias) + // Our workflow: + // layer_output = reduceSum(ffn_output + self_attn_output + ffn_output_bias + layer_input / + // TP_size) + // They are equivalent on math, but we can use same buffer for layer_input and layer_output + + invokeAddBiasAttentionFfnResidual(layer_output, + ffn_output_, + self_attn_output_, + layer_input, + gpt_decoder_layer_weight->at(l)->ffn_weights.output_weight.bias, + h_token_num, + hidden_units_, + tensor_para_.world_size_, + stream_); + if (tensor_para_.world_size_ > 1) { + ftNcclAllReduceSum( + layer_output, layer_output, h_token_num * hidden_units_, tensor_para_, stream_); + } + } + else { + invokeAddBiasResidual(layer_output, + self_attn_output_, + gpt_decoder_layer_weight->at(l)->ffn_weights.output_weight.bias, + h_token_num, + hidden_units_, + stream_); + } + + sync_check_cuda_error(); + #ifdef ENABLE_FLEX_DEBUG + if (l == 1) { + printf("%d %d: %d %d\n", l, ite, h_token_num, hidden_units_); + T *self_attn_output = new T[h_token_num * hidden_units_]; + cudaMemcpy(self_attn_output, layer_output, sizeof(T)*h_token_num * hidden_units_, cudaMemcpyDeviceToHost); + sync_check_cuda_error(); + int k = 0; + for (int i=0; i 1) { + int data_size = h_token_num * hidden_units_ / tensor_para_.world_size_; + ftNcclSend(layer_output + data_size * tensor_para_.rank_, + data_size, + pipeline_para_.rank_ + 1, + pipeline_para_, + stream_); + } + + if ((l == num_layer_ - 1) && is_unpadded_mha) { + T* base_ptr = use_shared_contexts ? compact_decoder_features_ : decoder_output; + invokeRebuildPadding(base_ptr + ite * local_batch_size * seq_len * hidden_units_, + decoder_layer_output_, + padding_offset_, + h_token_num, + head_num_ * size_per_head_, + stream_); + } + } + } + } + if (use_shared_contexts) { + invokeUnCompactOutputs(decoder_output, + compact_decoder_features_, + input_tensors->at("batch_to_compact_idx").getPtr(), + request_batch_size, // batch + seq_len * hidden_units_, + stream_); + sync_check_cuda_error(); + } + + // TODO(bhsueh) We could optimize this point by only computing the last token for the last layer + invokeLookupHiddenStateOfLastToken(output_tensors->at("last_token_hidden_units").getPtr(), + output_tensors->at("decoder_output").getPtr(), + input_tensors->at("input_lengths").getPtr(), + seq_len, + request_batch_size, + hidden_units_, + stream_); + sync_check_cuda_error(); + if (is_free_buffer_after_forward_ == true) { + freeBuffer(); + } +} + +template class LlamaContextDecoder; +template class LlamaContextDecoder; +#ifdef ENABLE_BF16 +template class LlamaContextDecoder<__nv_bfloat16>; +#endif + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/LlamaContextDecoder.h b/src/fastertransformer/models/llama/LlamaContextDecoder.h new file mode 100644 index 000000000..a47b9589b --- /dev/null +++ b/src/fastertransformer/models/llama/LlamaContextDecoder.h @@ -0,0 +1,128 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include "src/fastertransformer/kernels/add_residual_kernels.h" +#include "src/fastertransformer/kernels/layernorm_kernels.h" +#include "src/fastertransformer/layers/BaseLayer.h" +#include "src/fastertransformer/layers/FfnLayer.h" +#include "src/fastertransformer/layers/attention_layers/BaseAttentionLayer.h" +#include "src/fastertransformer/models/llama/LlamaDecoderLayerWeight.h" +#include "src/fastertransformer/utils/Tensor.h" +#include "src/fastertransformer/utils/allocator.h" +#include "src/fastertransformer/utils/cublasMMWrapper.h" +#include "src/fastertransformer/utils/custom_ar_comm.h" +#include "src/fastertransformer/utils/nccl_utils.h" + +namespace fastertransformer { + +template +class LlamaContextDecoder: public BaseLayer { +private: + // meta data + size_t head_num_; + size_t kv_head_num_; + size_t size_per_head_; + size_t inter_size_; + size_t num_layer_; + size_t rotary_embedding_dim_; + bool neox_rotary_style_; + bool use_gptj_residual_; + float layernorm_eps_; + + // calculated data + size_t hidden_units_; + + NcclParam tensor_para_; + NcclParam pipeline_para_; + + std::shared_ptr custom_all_reduce_comm_; + int enable_custom_all_reduce_; + + AttentionType attention_type_; + + int int8_mode_ = 0; + + bool is_qk_buf_float_; + + BaseAttentionLayer* self_attention_layer_; + FfnLayer* ffn_layer_; + + void allocateBuffer() override; + void allocateBuffer(size_t batch_size, size_t seq_len, bool use_shared_contexts); + void freeBuffer() override; + + bool isValidLayerParallelId(uint l); + bool isFirstLayerParallelId(uint l); + bool isLastLayerParallelId(uint l); + int getFirstLayerParallelId(); + + void initialize(); + +protected: + T* decoder_normed_input_ = nullptr; + T* self_attn_output_ = nullptr; + T* ffn_output_ = nullptr; + T* decoder_layer_output_ = nullptr; + size_t* h_pinned_token_num_ptr_ = nullptr; + int* padding_offset_ = nullptr; + int* cu_seqlens_ = nullptr; + + T* compact_decoder_features_ = nullptr; + T* compact_attention_mask_ = nullptr; + int* compact_input_lengths_ = nullptr; + T* k_cache_layer_ = nullptr; + T* v_cache_layer_ = nullptr; + +public: + LlamaContextDecoder(size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t rotary_embedding_dim, + bool neox_rotary_style, + bool use_gptj_residual, + float layernorm_eps, + NcclParam tensor_para, + NcclParam pipeline_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool is_qk_buf_float, + AttentionType attention_type = AttentionType::FUSED_MHA, + int int8_mode = 0, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce_ = 0); + + LlamaContextDecoder(LlamaContextDecoder const& decoder); + + ~LlamaContextDecoder(); + + void forward(std::vector* output_tensors, + const std::vector* input_tensors, + const std::vector*>* decoder_layer_weights); + + void forward(std::unordered_map* output_tensors, + const std::unordered_map* input_tensors, + const std::vector*>* gpt_decoder_layer_weight); +}; + +} // 
namespace fastertransformer diff --git a/src/fastertransformer/models/llama/LlamaDecoder.cc b/src/fastertransformer/models/llama/LlamaDecoder.cc new file mode 100644 index 000000000..c82de8568 --- /dev/null +++ b/src/fastertransformer/models/llama/LlamaDecoder.cc @@ -0,0 +1,400 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/fastertransformer/models/llama/LlamaDecoder.h" +#include "src/fastertransformer/layers/TensorParallelSiluFfnLayer.h" +#include "src/fastertransformer/layers/attention_layers/TensorParallelLlamaDecoderSelfAttentionLayer.h" + +namespace fastertransformer { + +template +void LlamaDecoder::initialize() +{ + self_attention_layer_ = new TensorParallelLlamaDecoderSelfAttentionLayer(0, // max_batch_size + head_num_, + kv_head_num_, + size_per_head_, + rotary_embedding_dim_, + neox_rotary_style_, + tensor_para_, + stream_, + cublas_wrapper_, + allocator_, + !use_gptj_residual_, + is_free_buffer_after_forward_, + false, + int8_mode_, + custom_all_reduce_comm_, + enable_custom_all_reduce_); + + // TODO: SiLu ftn layer not support int8 + ffn_layer_ = new TensorParallelSiluFfnLayer(0, // max_batch_size + 1, + head_num_, + size_per_head_, + 0, // expert_num + inter_size_, + tensor_para_, + stream_, + cublas_wrapper_, + allocator_, + !use_gptj_residual_, + is_free_buffer_after_forward_, + false, + true, // use_gated_activation = true; + custom_all_reduce_comm_, + enable_custom_all_reduce_, + int8_mode_); +} + +template +void LlamaDecoder::allocateBuffer() +{ + FT_CHECK(false); +} + +template +void LlamaDecoder::allocateBuffer(size_t batch_size) +{ + decoder_normed_input_ = reinterpret_cast( + allocator_->reMalloc(decoder_normed_input_, sizeof(T) * batch_size * hidden_units_, false)); + self_attn_output_ = + reinterpret_cast(allocator_->reMalloc(self_attn_output_, sizeof(T) * batch_size * hidden_units_, false)); + ffn_output_ = + reinterpret_cast(allocator_->reMalloc(ffn_output_, sizeof(T) * batch_size * hidden_units_, false)); + decoder_layer_output_ = reinterpret_cast( + allocator_->reMalloc(decoder_layer_output_, sizeof(T) * batch_size * hidden_units_, false)); + is_allocate_buffer_ = true; +} + +template +void LlamaDecoder::freeBuffer() +{ + if (is_allocate_buffer_ == true) { + allocator_->free((void**)(&decoder_normed_input_)); + allocator_->free((void**)(&self_attn_output_)); + allocator_->free((void**)(&ffn_output_)); + allocator_->free((void**)(&decoder_layer_output_)); + is_allocate_buffer_ = false; + } +} + +template +bool LlamaDecoder::isValidLayerParallelId(uint l) +{ + int local_num_layer = (int)(ceil(num_layer_ * 1.0f / pipeline_para_.world_size_)); + return l < num_layer_ && (l >= local_num_layer * pipeline_para_.rank_) + && (l < local_num_layer * (pipeline_para_.rank_ + 1)); +} + +template +bool LlamaDecoder::isFirstLayerParallelId(uint l) +{ + int local_num_layer = (int)(ceil(num_layer_ * 1.0f / pipeline_para_.world_size_)); + return l < num_layer_ && (l == 
local_num_layer * pipeline_para_.rank_); +} + +template +bool LlamaDecoder::isLastLayerParallelId(uint l) +{ + int local_num_layer = (int)(ceil(num_layer_ * 1.0f / pipeline_para_.world_size_)); + return l < num_layer_ && (l == local_num_layer * (pipeline_para_.rank_ + 1) - 1); +} + +template +int LlamaDecoder::getFirstLayerParallelId() +{ + int local_num_layer = (int)(ceil(num_layer_ * 1.0f / pipeline_para_.world_size_)); + return local_num_layer * pipeline_para_.rank_; +} + +template +LlamaDecoder::LlamaDecoder(size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t rotary_embedding_dim, + bool neox_rotary_style, + bool use_gptj_residual, + float layernorm_eps, + NcclParam tensor_para, + NcclParam pipeline_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + int int8_mode, + std::shared_ptr custom_all_reduce_comm, + int enable_custom_all_reduce): + BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), + head_num_(head_num), + kv_head_num_(kv_head_num), + size_per_head_(size_per_head), + inter_size_(inter_size), + num_layer_(num_layer), + rotary_embedding_dim_(rotary_embedding_dim), + neox_rotary_style_(neox_rotary_style), + use_gptj_residual_(use_gptj_residual), + layernorm_eps_(layernorm_eps), + hidden_units_(head_num_ * size_per_head), + tensor_para_(tensor_para), + pipeline_para_(pipeline_para), + int8_mode_(int8_mode), + custom_all_reduce_comm_(custom_all_reduce_comm), + enable_custom_all_reduce_(enable_custom_all_reduce) +{ + initialize(); +} + +template +LlamaDecoder::LlamaDecoder(LlamaDecoder const& decoder): + BaseLayer(decoder.stream_, decoder.cublas_wrapper_, decoder.allocator_, decoder.is_free_buffer_after_forward_), + head_num_(decoder.head_num_), + kv_head_num_(decoder.kv_head_num_), + size_per_head_(decoder.size_per_head_), + inter_size_(decoder.inter_size_), + num_layer_(decoder.num_layer_), + rotary_embedding_dim_(decoder.rotary_embedding_dim_), + neox_rotary_style_(decoder.neox_rotary_style_), + use_gptj_residual_(decoder.use_gptj_residual_), + layernorm_eps_(decoder.layernorm_eps_), + hidden_units_(decoder.hidden_units_), + tensor_para_(decoder.tensor_para_), + pipeline_para_(decoder.pipeline_para_), + int8_mode_(decoder.int8_mode_), + custom_all_reduce_comm_(decoder.custom_all_reduce_comm_), + enable_custom_all_reduce_(decoder.enable_custom_all_reduce_) +{ + initialize(); +} + +template +LlamaDecoder::~LlamaDecoder() +{ + delete self_attention_layer_; + delete ffn_layer_; + freeBuffer(); +} + +template +void LlamaDecoder::forward(std::vector* output_tensors, + const std::vector* input_tensors, + const std::vector*>* gpt_decoder_layer_weight) +{ + FT_CHECK(false); +} + +template +void LlamaDecoder::forward(std::unordered_map* output_tensors, + const std::unordered_map* input_tensors, + const std::vector*>* gpt_decoder_layer_weight) +{ + // input tensors: + // decoder_input [local_batch_size, hidden_dimension], + // finished [local_batch_size], + // sequence_lengths [local_batch_size] + // total_padding_tokens [local_batch_size], + // max_input_length [1] on cpu + // d_prefix_prompt_lengths [local_batch_size], on GPU + // max_prefix_prompt_length [1] on cpu + // step [1] on cpu + // ite [1] on cpu + // cache_indirection [local_batch_size / beam_width, beam_width, memory_len] + // Here, local_batch_size contains the beam_width, so local_batch_size / beam_width + // is real local_batch_size. 
+ // masked_tokens[local_batch_size, memory_len] + + // output tensors: + // decoder_output [local_batch_size, hidden_dimension], + // key_cache [num_layer, batch_size, head_num, size_per_head // x, memory_len, x] + // value_cache [num_layer, batch_size, head_num, memory_len, size_per_head] + + FT_CHECK(input_tensors->size() == 11); + FT_CHECK(output_tensors->size() == 3); + + const DataType data_type = getTensorType(); + const size_t local_batch_size = input_tensors->at("decoder_input").shape[0]; + allocateBuffer(local_batch_size); + const int ite = input_tensors->at("ite").getVal(); + + T* decoder_input = input_tensors->at("decoder_input").getPtr(); + T* decoder_output = output_tensors->at("decoder_output").getPtr(); + + Tensor& k_cache = output_tensors->at("key_cache"); + Tensor& v_cache = output_tensors->at("value_cache"); + std::vector self_k_cache_size; + self_k_cache_size.push_back(local_batch_size); + for (auto t = k_cache.shape.begin() + 2; t != k_cache.shape.end(); ++t) { + self_k_cache_size.push_back(*t); + } + #ifdef ENABLE_FLEX_DEBUG + printf("self_k_cache_size: "); + for (int i=0; i self_v_cache_size; + self_v_cache_size.push_back(local_batch_size); + for (auto t = v_cache.shape.begin() + 2; t != v_cache.shape.end(); ++t) { + self_v_cache_size.push_back(*t); + } + + for (uint l = 0; l < num_layer_; l++) { + if (isValidLayerParallelId(l) == false) { + continue; + } + T* layer_input = (l == 0) ? decoder_input : decoder_layer_output_; + T* layer_output = (l == num_layer_ - 1) ? decoder_output : decoder_layer_output_; + + if (isFirstLayerParallelId(l) == true && pipeline_para_.rank_ != 0 && pipeline_para_.world_size_ > 1) { + int data_size = local_batch_size * hidden_units_ / tensor_para_.world_size_; + // ftNcclRecv(layer_input, local_batch_size * hidden_units_, pipeline_para_.rank_ - 1, pipeline_para_, + // stream_); + + ftNcclRecv(layer_input + data_size * tensor_para_.rank_, + data_size, + pipeline_para_.rank_ - 1, + pipeline_para_, + stream_); + if (tensor_para_.world_size_ > 1) { + ftNcclAllGather(layer_input, layer_input, data_size, tensor_para_.rank_, tensor_para_, stream_); + } + } + + // TODO(zhwang): NO int8 support here, add later. 
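+ // Pre-attention norm below: invokeGeneralT5LayerNorm with a nullptr beta applies an RMSNorm-style scale (gamma only) to layer_input before the self-attention layer.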
+ invokeGeneralT5LayerNorm(decoder_normed_input_, + layer_input, + gpt_decoder_layer_weight->at(l)->pre_layernorm_weights.gamma, + (const T*)nullptr, + layernorm_eps_, + local_batch_size, + hidden_units_, + stream_); + sync_check_cuda_error(); + + TensorMap self_attention_input_tensors(*input_tensors); + self_attention_input_tensors.insert( + "input_query", Tensor{MEMORY_GPU, data_type, {local_batch_size, hidden_units_}, decoder_normed_input_}); + + size_t cache_offset = l - getFirstLayerParallelId(); + for (auto t = k_cache.shape.begin() + 1; t != k_cache.shape.end(); ++t) { + cache_offset *= *t; + }; + size_t ite_cache_offset = ite * local_batch_size; + for (auto t = k_cache.shape.begin() + 2; t != k_cache.shape.end(); ++t) { + ite_cache_offset *= *t; + } + cache_offset += ite_cache_offset; + + TensorMap self_attention_output_tensors{ + {"hidden_features", Tensor{MEMORY_GPU, data_type, {local_batch_size, hidden_units_}, self_attn_output_}}, + {"key_cache", Tensor{MEMORY_GPU, data_type, self_k_cache_size, k_cache.getPtrWithOffset(cache_offset)}}, + {"value_cache", Tensor{MEMORY_GPU, data_type, self_v_cache_size, v_cache.getPtrWithOffset(cache_offset)}}}; + + self_attention_layer_->forward(&self_attention_output_tensors, + &self_attention_input_tensors, + &gpt_decoder_layer_weight->at(l)->self_attention_weights); + if (use_gptj_residual_) { + invokeGeneralLayerNorm(decoder_normed_input_, + layer_input, + gpt_decoder_layer_weight->at(l)->post_attention_layernorm_weights.gamma, + gpt_decoder_layer_weight->at(l)->post_attention_layernorm_weights.beta, + layernorm_eps_, + local_batch_size, + hidden_units_, + (float*)nullptr, + int8_mode_, + stream_); + } + else { + invokeGeneralAddResidualT5PreLayerNorm( + self_attn_output_, + decoder_normed_input_, + layer_input, + gpt_decoder_layer_weight->at(l)->post_attention_layernorm_weights.gamma, + layernorm_eps_, + local_batch_size, + hidden_units_, + stream_); + } + + TensorMap ffn_input_tensors( + {{"ffn_input", Tensor{MEMORY_GPU, data_type, {local_batch_size, hidden_units_}, decoder_normed_input_}}}); + TensorMap ffn_output_tensors({{"ffn_output", + Tensor{MEMORY_GPU, + data_type, + {local_batch_size, hidden_units_}, + use_gptj_residual_ ? 
ffn_output_ : layer_output}}}); + ffn_layer_->forward(&ffn_output_tensors, &ffn_input_tensors, &gpt_decoder_layer_weight->at(l)->ffn_weights); + + if (use_gptj_residual_) { + // Original workflow: + // layer_output = layer_input + reduceSum(ffn_output + self_attn_output + ffn_output_bias) + // Our workflow: + // layer_output = reduceSum(ffn_output + self_attn_output + ffn_output_bias + layer_input / TP_size) + // They are equivalent on math, but we can use same buffer for layer_input and layer_output + invokeAddBiasAttentionFfnResidual(layer_output, + ffn_output_, + self_attn_output_, + layer_input, + gpt_decoder_layer_weight->at(l)->ffn_weights.output_weight.bias, + local_batch_size, + hidden_units_, + tensor_para_.world_size_, + stream_); + if (tensor_para_.world_size_ > 1) { + ftNcclAllReduceSum(layer_output, layer_output, local_batch_size * hidden_units_, tensor_para_, stream_); + } + } + else { + invokeAddBiasResidual(layer_output, + self_attn_output_, + gpt_decoder_layer_weight->at(l)->ffn_weights.output_weight.bias, + local_batch_size, + hidden_units_, + stream_); + } + + sync_check_cuda_error(); + + if (isLastLayerParallelId(l) == true && pipeline_para_.rank_ != pipeline_para_.world_size_ - 1 + && pipeline_para_.world_size_ > 1) { + int data_size = local_batch_size * hidden_units_ / tensor_para_.world_size_; + // ftNcclSend(layer_output, local_batch_size * hidden_units_, pipeline_para_.rank_ + 1, pipeline_para_, + // stream_); + + ftNcclSend(layer_output + data_size * tensor_para_.rank_, + data_size, + pipeline_para_.rank_ + 1, + pipeline_para_, + stream_); + } + } + + if (is_free_buffer_after_forward_ == true) { + freeBuffer(); + } +} + +template class LlamaDecoder; +template class LlamaDecoder; +#ifdef ENABLE_BF16 +template class LlamaDecoder<__nv_bfloat16>; +#endif + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/LlamaDecoder.h b/src/fastertransformer/models/llama/LlamaDecoder.h new file mode 100644 index 000000000..b6b3db89d --- /dev/null +++ b/src/fastertransformer/models/llama/LlamaDecoder.h @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include + +#include "src/fastertransformer/kernels/add_residual_kernels.h" +#include "src/fastertransformer/kernels/layernorm_kernels.h" +#include "src/fastertransformer/layers/BaseLayer.h" +#include "src/fastertransformer/layers/FfnLayer.h" +#include "src/fastertransformer/layers/attention_layers/BaseAttentionLayer.h" +#include "src/fastertransformer/models/llama/LlamaDecoderLayerWeight.h" +#include "src/fastertransformer/utils/Tensor.h" +#include "src/fastertransformer/utils/allocator.h" +#include "src/fastertransformer/utils/cublasMMWrapper.h" +#include "src/fastertransformer/utils/custom_ar_comm.h" +#include "src/fastertransformer/utils/nccl_utils.h" + +namespace fastertransformer { + +template +class LlamaDecoder: public BaseLayer { +private: +protected: + void allocateBuffer() override; + void allocateBuffer(size_t batch_size); + void freeBuffer() override; + bool isValidLayerParallelId(uint l); + bool isFirstLayerParallelId(uint l); + bool isLastLayerParallelId(uint l); + int getFirstLayerParallelId(); + virtual void initialize(); + + // meta data + size_t head_num_; + size_t kv_head_num_; + size_t size_per_head_; + size_t inter_size_; + size_t num_layer_; + size_t rotary_embedding_dim_; + bool neox_rotary_style_; + bool use_gptj_residual_; + size_t hidden_units_; + float layernorm_eps_; + + NcclParam tensor_para_; + NcclParam pipeline_para_; + + std::shared_ptr custom_all_reduce_comm_; + int enable_custom_all_reduce_; + + T* decoder_normed_input_ = nullptr; + T* self_attn_output_ = nullptr; + T* ffn_output_ = nullptr; + T* decoder_layer_output_ = nullptr; + + BaseAttentionLayer* self_attention_layer_; + FfnLayer* ffn_layer_; + + int int8_mode_ = 0; + +public: + LlamaDecoder(size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t rotary_embedding_dim, + bool neox_rotary_style, + bool use_gptj_residual, + float layernorm_eps, + NcclParam tensor_para, + NcclParam pipeline_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + int int8_mode = 0, + std::shared_ptr custom_all_reduce_comm = nullptr, + int enable_custom_all_reduce_ = 0); + + LlamaDecoder(LlamaDecoder const& decoder); + + virtual ~LlamaDecoder(); + + virtual void forward(std::unordered_map* output_tensors, + const std::unordered_map* input_tensors, + const std::vector*>* decoder_layer_weights); + + virtual void forward(std::vector* output_tensors, + const std::vector* input_tensors, + const std::vector*>* decoder_layer_weights); +}; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/LlamaDecoderLayerWeight.cc b/src/fastertransformer/models/llama/LlamaDecoderLayerWeight.cc new file mode 100644 index 000000000..daf890c87 --- /dev/null +++ b/src/fastertransformer/models/llama/LlamaDecoderLayerWeight.cc @@ -0,0 +1,398 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/fastertransformer/models/llama/LlamaDecoderLayerWeight.h" +#include "src/fastertransformer/utils/memory_utils.h" + +namespace fastertransformer { + +template +LlamaDecoderLayerWeight::LlamaDecoderLayerWeight(const int head_num, + const int kv_head_num, + const int size_per_head, + const int inter_size, + const int tensor_para_size, + const int tensor_para_rank, + const bool use_gptj_residual, + const int int8_mode): + head_num_(head_num), + kv_head_num_(kv_head_num), + size_per_head_(size_per_head), + hidden_units_(head_num * size_per_head), + inter_size_(inter_size), + tensor_para_size_(tensor_para_size), + tensor_para_rank_(tensor_para_rank), + int8_mode_(int8_mode), + use_gptj_residual_(use_gptj_residual) +{ + mallocWeights(); + setWeightPtr(); + + FT_CHECK_WITH_INFO(int8_mode_ != 2, "Llama doesn't support int8_model == 2"); + FT_CHECK_WITH_INFO(!(std::is_same::value && int8_mode_ == 1), + "Weight only quant does not work with FP32 compute."); +} + +template +LlamaDecoderLayerWeight::LlamaDecoderLayerWeight(const int int8_mode): int8_mode_(int8_mode) +{ +} + +template +LlamaDecoderLayerWeight::~LlamaDecoderLayerWeight() +{ + if (is_maintain_buffer == true) { + for (int i = 0; i < 14; i++) { + if (!use_gptj_residual_ && i != attention_dense_bias_weight_id) { + cudaFree(weights_ptr[i]); + } + } + + pre_layernorm_weights.beta = nullptr; + pre_layernorm_weights.gamma = nullptr; + self_attention_weights.query_weight.kernel = nullptr; + self_attention_weights.query_weight.bias = nullptr; + self_attention_weights.attention_output_weight.kernel = nullptr; + self_attention_weights.attention_output_weight.bias = nullptr; + post_attention_layernorm_weights.beta = nullptr; + post_attention_layernorm_weights.gamma = nullptr; + + ffn_weights.intermediate_weight.kernel = nullptr; + ffn_weights.intermediate_weight.bias = nullptr; + ffn_weights.intermediate_weight2.kernel = nullptr; + ffn_weights.intermediate_weight2.bias = nullptr; + ffn_weights.output_weight.kernel = nullptr; + ffn_weights.output_weight.bias = nullptr; + + if (int8_mode_ != 0) { + for (int i = 0; i < int8_weights_ptr.size(); i++) { + if (int8_weights_ptr[i] != nullptr) { + deviceFree(int8_weights_ptr[i]); + } + } + + if (int8_mode_ == 1) { + for (int i = 0; i < weight_only_scale_ptr.size(); i++) { + if (weight_only_scale_ptr[i] != nullptr) { + deviceFree(weight_only_scale_ptr[i]); + } + } + } + + self_attention_weights.query_weight.int8_kernel = nullptr; + self_attention_weights.query_weight.weight_only_quant_scale = nullptr; + self_attention_weights.attention_output_weight.int8_kernel = nullptr; + self_attention_weights.attention_output_weight.weight_only_quant_scale = nullptr; + + // 作一下标记 intermediate_weight => gate_proj; intermediate_weight2 => up_proj; output_weight => down_proj. 
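+ // Note: intermediate_weight => gate_proj; intermediate_weight2 => up_proj; output_weight => down_proj.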
+ ffn_weights.intermediate_weight.int8_kernel = nullptr; + ffn_weights.intermediate_weight.weight_only_quant_scale = nullptr; + ffn_weights.intermediate_weight2.int8_kernel = nullptr; + ffn_weights.intermediate_weight2.weight_only_quant_scale = nullptr; + ffn_weights.output_weight.int8_kernel = nullptr; + ffn_weights.output_weight.weight_only_quant_scale = nullptr; + } + + is_maintain_buffer = false; + } +} + +template +void LlamaDecoderLayerWeight::copyFrom(const LlamaDecoderLayerWeight& other) +{ + int qkv_size = hidden_units_ + 2 * size_per_head_ * kv_head_num_; + cudaD2Dcpy(weights_ptr[0], other.weights_ptr[0], hidden_units_); + cudaD2Dcpy(weights_ptr[1], other.weights_ptr[1], hidden_units_); + cudaD2Dcpy(weights_ptr[3], other.weights_ptr[3], 3 * hidden_units_ / tensor_para_size_); + if (!use_gptj_residual_) { + cudaD2Dcpy(weights_ptr[5], other.weights_ptr[5], hidden_units_); + } + cudaD2Dcpy(weights_ptr[7], other.weights_ptr[7], inter_size_ / tensor_para_size_); + cudaD2Dcpy(weights_ptr[9], other.weights_ptr[9], inter_size_ / tensor_para_size_); + cudaD2Dcpy(weights_ptr[11], other.weights_ptr[11], hidden_units_); + cudaD2Dcpy(weights_ptr[12], other.weights_ptr[12], hidden_units_); + cudaD2Dcpy(weights_ptr[13], other.weights_ptr[13], hidden_units_); + if (int8_mode_ == 0) { + cudaD2Dcpy(weights_ptr[2], other.weights_ptr[2], qkv_size * hidden_units_ / tensor_para_size_); + cudaD2Dcpy(weights_ptr[4], other.weights_ptr[4], hidden_units_ / tensor_para_size_ * hidden_units_); + cudaD2Dcpy(weights_ptr[6], other.weights_ptr[6], hidden_units_ * inter_size_ / tensor_para_size_); + cudaD2Dcpy(weights_ptr[8], other.weights_ptr[8], hidden_units_ * inter_size_ / tensor_para_size_); + cudaD2Dcpy(weights_ptr[10], other.weights_ptr[10], inter_size_ / tensor_para_size_ * hidden_units_); + } + else { + cudaD2Dcpy(int8_weights_ptr[0], other.int8_weights_ptr[0], qkv_size * hidden_units_ / tensor_para_size_); + cudaD2Dcpy(int8_weights_ptr[1], other.int8_weights_ptr[1], hidden_units_ / tensor_para_size_ * hidden_units_); + cudaD2Dcpy(int8_weights_ptr[2], other.int8_weights_ptr[2], hidden_units_ * inter_size_ / tensor_para_size_); + cudaD2Dcpy(int8_weights_ptr[3], other.int8_weights_ptr[3], hidden_units_ * inter_size_ / tensor_para_size_); + cudaD2Dcpy(int8_weights_ptr[4], other.int8_weights_ptr[4], inter_size_ / tensor_para_size_ * hidden_units_); + + if (int8_mode_ == 1) { + cudaD2Dcpy(weight_only_scale_ptr[0], other.weight_only_scale_ptr[0], qkv_size / tensor_para_size_); + cudaD2Dcpy(weight_only_scale_ptr[1], other.weight_only_scale_ptr[1], hidden_units_); + cudaD2Dcpy(weight_only_scale_ptr[2], other.weight_only_scale_ptr[2], inter_size_ / tensor_para_size_); + + // TODO: 不太清楚这里存的缩放因子对应的是gate_pro_weight 还是给 up_proj/down_proj用的,后面做一下验证,回来再改一下 + cudaD2Dcpy(weight_only_scale_ptr[3], other.weight_only_scale_ptr[3], inter_size_ / tensor_para_size_); + cudaD2Dcpy(weight_only_scale_ptr[4], other.weight_only_scale_ptr[4], hidden_units_); + } + } +} + +template +LlamaDecoderLayerWeight::LlamaDecoderLayerWeight(const LlamaDecoderLayerWeight& other): + head_num_(other.head_num_), + kv_head_num_(other.kv_head_num_), + size_per_head_(other.size_per_head_), + hidden_units_(other.hidden_units_), + inter_size_(other.inter_size_), + tensor_para_size_(other.tensor_para_size_), + tensor_para_rank_(other.tensor_para_rank_), + int8_mode_(other.int8_mode_), + use_gptj_residual_(other.use_gptj_residual_) +{ + mallocWeights(); + copyFrom(other); + setWeightPtr(); +} + +template +LlamaDecoderLayerWeight& 
LlamaDecoderLayerWeight::operator=(const LlamaDecoderLayerWeight& other) +{ + head_num_ = other.head_num_; + kv_head_num_ = other.kv_head_num_; + size_per_head_ = other.size_per_head_; + hidden_units_ = other.hidden_units_; + inter_size_ = other.inter_size_; + tensor_para_size_ = other.tensor_para_size_; + tensor_para_rank_ = other.tensor_para_rank_; + int8_mode_ = other.int8_mode_; + use_gptj_residual_ = other.use_gptj_residual_; + + mallocWeights(); + + copyFrom(other); + setWeightPtr(); + return *this; +} + +template +void LlamaDecoderLayerWeight::loadModel(std::string dir_path, FtCudaDataType model_file_type) +{ + FT_CHECK(is_maintain_buffer == true); + const std::string rank_spec = std::to_string(tensor_para_rank_); + + // fill all bias to zeros + deviceFill(weights_ptr[0], (size_t)hidden_units_, (T)0.0); + loadWeightFromBin( + weights_ptr[1], {(size_t)hidden_units_}, dir_path + ".input_layernorm.weight.bin", model_file_type); + + int qkv_size = hidden_units_ + 2 * size_per_head_ * kv_head_num_; + + deviceFill(weights_ptr[3], (size_t)(3 * hidden_units_ / tensor_para_size_), (T)0.0); + + if (!use_gptj_residual_) { + deviceFill(weights_ptr[5], (size_t)hidden_units_, (T)0.0); + } + + // FIXME(sunpeng17): check if the weights are correct + // loadWeightFromBin(weights_ptr[6], + // {(size_t)hidden_units_, (size_t)(inter_size_ / tensor_para_size_)}, + // dir_path + ".mlp.gate_proj.weight." + rank_spec + ".bin", + // model_file_type); + + deviceFill(weights_ptr[7], (size_t)(inter_size_ / tensor_para_size_), (T)0.0); + + deviceFill(weights_ptr[9], (size_t)(inter_size_ / tensor_para_size_), (T)0.0); + + // loadWeightFromBin(weights_ptr[10], + // {(size_t)(inter_size_ / tensor_para_size_), (size_t)hidden_units_}, + // dir_path + ".mlp.down_proj.weight." + rank_spec + ".bin", + // model_file_type); + + + deviceFill(weights_ptr[11], (size_t)(hidden_units_), (T)0.0); + + deviceFill(weights_ptr[12], (size_t)(hidden_units_), (T)0.0); + loadWeightFromBin( + weights_ptr[13], {(size_t)hidden_units_}, dir_path + ".post_attention_layernorm.weight.bin", model_file_type); + + if (int8_mode_ == 0) { + loadWeightFromBin(weights_ptr[2], + {(size_t)hidden_units_, (size_t)(qkv_size / tensor_para_size_)}, + dir_path + ".attention.query_key_value.weight." + rank_spec + ".bin", + model_file_type); + // { + // printf("qkv_size: %d\n", qkv_size); + // printf("w2\n"); + // int sz = 100; + // T *qkv_buf = new T[sz]; + // cudaMemcpy(qkv_buf, weights_ptr[2], sizeof(T)*sz, cudaMemcpyDeviceToHost); + // sync_check_cuda_error(); + // for (int i=0; i(weights_ptr[4], + {(size_t)(hidden_units_ / tensor_para_size_), (size_t)hidden_units_}, + dir_path + ".attention.dense.weight." + rank_spec + ".bin", + model_file_type); + + loadWeightFromBin(weights_ptr[6], + {(size_t)hidden_units_, (size_t)(inter_size_ / tensor_para_size_)}, + dir_path + ".mlp.gate_proj.weight." + rank_spec + ".bin", + model_file_type); + + loadWeightFromBin(weights_ptr[8], + {(size_t)(inter_size_ / tensor_para_size_), (size_t)hidden_units_}, + dir_path + ".mlp.up_proj.weight." + rank_spec + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr[10], + {(size_t)(inter_size_ / tensor_para_size_), (size_t)hidden_units_}, + dir_path + ".mlp.down_proj.weight." + rank_spec + ".bin", + model_file_type); + } + else if (int8_mode_ == 1) { + loadWeightFromBinAndQuantizeForWeightOnly(int8_weights_ptr[0], + weight_only_scale_ptr[0], + {(size_t)hidden_units_, (size_t)(qkv_size / tensor_para_size_)}, + dir_path + ".attention.query_key_value.weight." 
+ rank_spec + ".bin", + model_file_type); + + loadWeightFromBinAndQuantizeForWeightOnly(int8_weights_ptr[1], + weight_only_scale_ptr[1], + {(size_t)(hidden_units_ / tensor_para_size_), (size_t)hidden_units_}, + dir_path + ".attention.dense.weight." + rank_spec + ".bin", + model_file_type); + + loadWeightFromBinAndQuantizeForWeightOnly(int8_weights_ptr[2], + weight_only_scale_ptr[2], + {(size_t)hidden_units_, (size_t)(inter_size_ / tensor_para_size_)}, + dir_path + ".mlp.gate_proj.weight." + rank_spec + ".bin", + model_file_type); + + loadWeightFromBinAndQuantizeForWeightOnly(int8_weights_ptr[3], + weight_only_scale_ptr[3], + {(size_t)hidden_units_, (size_t)(inter_size_ / tensor_para_size_)}, + dir_path + ".mlp.up_proj.weight." + rank_spec + ".bin", + model_file_type); + loadWeightFromBinAndQuantizeForWeightOnly(int8_weights_ptr[4], + weight_only_scale_ptr[4], + {(size_t)(inter_size_ / tensor_para_size_), (size_t)hidden_units_}, + dir_path + ".mlp.down_proj.weight." + rank_spec + ".bin", + model_file_type); + + } +} + +template +void LlamaDecoderLayerWeight::setWeightPtr() +{ + pre_layernorm_weights.beta = weights_ptr[0]; + pre_layernorm_weights.gamma = weights_ptr[1]; + self_attention_weights.query_weight.kernel = weights_ptr[2]; + self_attention_weights.query_weight.bias = weights_ptr[3]; + self_attention_weights.attention_output_weight.kernel = weights_ptr[4]; + self_attention_weights.attention_output_weight.bias = use_gptj_residual_ ? nullptr : weights_ptr[5]; + + ffn_weights.intermediate_weight.kernel = weights_ptr[6]; + ffn_weights.intermediate_weight.bias = weights_ptr[7]; + ffn_weights.intermediate_weight2.kernel = weights_ptr[8]; + ffn_weights.intermediate_weight2.bias = weights_ptr[9]; + ffn_weights.output_weight.kernel = weights_ptr[10]; + ffn_weights.output_weight.bias = weights_ptr[11]; + + post_attention_layernorm_weights.beta = weights_ptr[12]; + post_attention_layernorm_weights.gamma = weights_ptr[13]; + + if (int8_mode_ != 0) { + self_attention_weights.query_weight.int8_kernel = int8_weights_ptr[0]; + self_attention_weights.attention_output_weight.int8_kernel = int8_weights_ptr[1]; + ffn_weights.intermediate_weight.int8_kernel = int8_weights_ptr[2]; + ffn_weights.intermediate_weight2.int8_kernel = int8_weights_ptr[3]; + ffn_weights.output_weight.int8_kernel = int8_weights_ptr[4]; + + if (int8_mode_ == 1) { + self_attention_weights.query_weight.weight_only_quant_scale = weight_only_scale_ptr[0]; + self_attention_weights.attention_output_weight.weight_only_quant_scale = weight_only_scale_ptr[1]; + ffn_weights.intermediate_weight.weight_only_quant_scale = weight_only_scale_ptr[2]; + ffn_weights.intermediate_weight2.weight_only_quant_scale = weight_only_scale_ptr[3]; + ffn_weights.output_weight.weight_only_quant_scale = weight_only_scale_ptr[4]; + } + } + + is_maintain_buffer = true; +} + +template +void LlamaDecoderLayerWeight::mallocWeights() +{ + deviceMalloc(&weights_ptr[0], hidden_units_); // pre layernorm beta + deviceMalloc(&weights_ptr[1], hidden_units_); // pre layernorm gamma + // deviceMalloc(&weights_ptr[2], hidden_units_ * 3 * hidden_units_ / tensor_para_size_); // qkv kernel + int qkv_size = hidden_units_ + 2 * size_per_head_ * kv_head_num_; + deviceMalloc(&weights_ptr[3], 3 * hidden_units_ / tensor_para_size_); // qkv bias + // deviceMalloc(&weights_ptr[4], hidden_units_ / tensor_para_size_ * hidden_units_); // attention output weight + if (!use_gptj_residual_) { + deviceMalloc(&weights_ptr[5], hidden_units_); // attention output bias + } + + // 
deviceMalloc(&weights_ptr[6], hidden_units_ * inter_size_ / tensor_para_size_); // intermediate_weight kernel + deviceMalloc(&weights_ptr[7], inter_size_ / tensor_para_size_); // intermediate_weight bias + // deviceMalloc(&weights_ptr[8], hidden_units_ * inter_size_ / tensor_para_size_); // intermediate_weight2 kernel + deviceMalloc(&weights_ptr[9], inter_size_ / tensor_para_size_); // intermediate_weight2 bias + // deviceMalloc(&weights_ptr[10], inter_size_ / tensor_para_size_ * hidden_units_); // output_weight kernel + deviceMalloc(&weights_ptr[11], hidden_units_); // output_weight bias + deviceMalloc(&weights_ptr[12], hidden_units_); // post attn layernorm beta + deviceMalloc(&weights_ptr[13], hidden_units_); // post attn layernorm gamma + + if (int8_mode_ == 0) { + deviceMalloc(&weights_ptr[2], qkv_size * hidden_units_ / tensor_para_size_); // qkv weight + deviceMalloc(&weights_ptr[4], hidden_units_ / tensor_para_size_ * hidden_units_); // attention output weight + deviceMalloc(&weights_ptr[6], hidden_units_ * inter_size_ / tensor_para_size_); // intermediate_weight kernel + deviceMalloc(&weights_ptr[8], hidden_units_ * inter_size_ / tensor_para_size_); // intermediate_weight2 kernel + deviceMalloc(&weights_ptr[10], inter_size_ / tensor_para_size_ * hidden_units_); // output_weight kernel + } + else { + // Alloc FFN and Attention int8 weights + deviceMalloc(&int8_weights_ptr[0], qkv_size * hidden_units_ / tensor_para_size_); + deviceMalloc(&int8_weights_ptr[1], hidden_units_ / tensor_para_size_ * hidden_units_); + deviceMalloc(&int8_weights_ptr[2], hidden_units_ * inter_size_ / tensor_para_size_); + deviceMalloc(&int8_weights_ptr[3], hidden_units_ * inter_size_ / tensor_para_size_); + deviceMalloc(&int8_weights_ptr[4], inter_size_ / tensor_para_size_ * hidden_units_); + + + if (int8_mode_ == 1) { + // Alloc scales for weight only quant for attention and FFN weights + deviceMalloc(&weight_only_scale_ptr[0], qkv_size / tensor_para_size_); + deviceMalloc(&weight_only_scale_ptr[1], hidden_units_); + deviceMalloc(&weight_only_scale_ptr[2], inter_size_ / tensor_para_size_); + deviceMalloc(&weight_only_scale_ptr[3], inter_size_ / tensor_para_size_); + deviceMalloc(&weight_only_scale_ptr[4], hidden_units_); + } + } + +} + +template struct LlamaDecoderLayerWeight; +template struct LlamaDecoderLayerWeight; +#ifdef ENABLE_BF16 +template class LlamaDecoderLayerWeight<__nv_bfloat16>; +#endif + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/LlamaDecoderLayerWeight.h b/src/fastertransformer/models/llama/LlamaDecoderLayerWeight.h new file mode 100644 index 000000000..59a100f08 --- /dev/null +++ b/src/fastertransformer/models/llama/LlamaDecoderLayerWeight.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include + +#include "src/fastertransformer/kernels/layernorm_kernels.h" +#include "src/fastertransformer/layers/FfnWeight.h" +#include "src/fastertransformer/layers/attention_layers/AttentionWeight.h" +#include "src/fastertransformer/utils/cuda_utils.h" + +namespace fastertransformer { + +template +struct LlamaDecoderLayerWeight { +public: + LlamaDecoderLayerWeight() = default; + LlamaDecoderLayerWeight(const int int8_mode); + LlamaDecoderLayerWeight(const int head_num, + const int kv_head_num, + const int size_per_head, + const int inter_size, + const int tensor_para_size = 1, + const int tensor_para_rank = 0, + const bool use_gptj_residual = true, + const int int8_mode = 0); + ~LlamaDecoderLayerWeight(); + LlamaDecoderLayerWeight(const LlamaDecoderLayerWeight& other); + LlamaDecoderLayerWeight& operator=(const LlamaDecoderLayerWeight& other); + + void loadModel(std::string dir_path, FtCudaDataType model_file_type); + + LayerNormWeight pre_layernorm_weights; + AttentionWeight self_attention_weights; + LayerNormWeight post_attention_layernorm_weights; + FfnWeight ffn_weights; + +private: + int head_num_; + int kv_head_num_; + int size_per_head_; + int hidden_units_; + int inter_size_; + int tensor_para_size_; + int tensor_para_rank_; + bool use_gptj_residual_; + const int attention_dense_bias_weight_id = 5; + bool is_maintain_buffer = false; + T* weights_ptr[14]; + int int8_mode_ = 0; + + std::vector int8_weights_ptr = std::vector(5, nullptr); + std::vector weight_only_scale_ptr = std::vector(5, nullptr); + + void setWeightPtr(); + void mallocWeights(); + void copyFrom(const LlamaDecoderLayerWeight& other); +}; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/LlamaWeight.cc b/src/fastertransformer/models/llama/LlamaWeight.cc new file mode 100644 index 000000000..a1bda4053 --- /dev/null +++ b/src/fastertransformer/models/llama/LlamaWeight.cc @@ -0,0 +1,321 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/fastertransformer/models/llama/LlamaWeight.h" + +namespace fastertransformer { + +template +LlamaWeight::LlamaWeight(const int head_num, + const int kv_head_num, + const int size_per_head, + const int inter_size, + const int vocab_size, + const int num_layer, + const int max_seq_len, + const int tensor_para_size, + const int tensor_para_rank, + const int layer_para_size, + const int layer_para_rank, + const bool use_gptj_residual, + const int int8_mode, + PromptLearningType prompt_learning_type, + std::map> prompt_learning_pair): + head_num_(head_num), + kv_head_num_(kv_head_num), + size_per_head_(size_per_head), + hidden_units_(head_num * size_per_head), + inter_size_(inter_size), + vocab_size_(vocab_size), + num_layer_(num_layer), + max_seq_len_(max_seq_len), + tensor_para_size_(tensor_para_size), + tensor_para_rank_(tensor_para_rank), + layer_para_size_(layer_para_size), + layer_para_rank_(layer_para_rank), + use_gptj_residual_(use_gptj_residual), + int8_mode_(int8_mode), + prompt_learning_type_(prompt_learning_type), + prompt_learning_pair_(prompt_learning_pair) +{ + FT_CHECK(num_layer_ % layer_para_size_ == 0); + // set prompt weight size + if (prompt_learning_type_ == PromptLearningType::prefix_prompt) { + prompt_token_weight_size_ = 2 * num_layer_ * hidden_units_ / tensor_para_size_; + } + else if (prompt_learning_type_ == PromptLearningType::p_prompt_tuning) { + prompt_token_weight_size_ = hidden_units_; + } + + // set if load and malloc prompt weights + malloc_load_prompt_weights_ = !prompt_learning_pair_.empty() + && (prompt_learning_type_ == PromptLearningType::p_prompt_tuning + || prompt_learning_type_ == PromptLearningType::prefix_prompt); + + decoder_layer_weights.reserve(num_layer_); + for (int l = 0; l < num_layer_; l++) { + if (isValidLayerParallelId(l)) { + decoder_layer_weights.push_back(new LlamaDecoderLayerWeight( + head_num_, kv_head_num_, size_per_head_, inter_size_, tensor_para_size_, tensor_para_rank_, use_gptj_residual_, int8_mode_)); + } + else { + // Layer-parallelism: allocate empty layer because + // this rank does not compute it: + decoder_layer_weights.push_back(new LlamaDecoderLayerWeight(0, 0, 0, 0)); + } + } + + mallocWeights(); + setWeightPtr(); +} + +template +LlamaWeight::~LlamaWeight() +{ + if (is_maintain_buffer == true) { + for (int i = 0; i < weights_ptr.size(); i++) { + deviceFree(weights_ptr[i]); + } + + pre_decoder_embedding_table = nullptr; + post_decoder_layernorm.beta = nullptr; + post_decoder_layernorm.gamma = nullptr; + post_decoder_embedding.kernel = nullptr; + post_decoder_embedding.bias = nullptr; + is_maintain_buffer = false; + } +} + +template +LlamaWeight::LlamaWeight(const LlamaWeight& other): + head_num_(other.head_num_), + kv_head_num_(other.kv_head_num_), + size_per_head_(other.size_per_head_), + hidden_units_(other.hidden_units_), + inter_size_(other.inter_size_), + vocab_size_(other.vocab_size_), + num_layer_(other.num_layer_), + max_seq_len_(other.max_seq_len_), + tensor_para_size_(other.tensor_para_size_), + tensor_para_rank_(other.tensor_para_rank_), + layer_para_size_(other.layer_para_size_), + layer_para_rank_(other.layer_para_rank_), + use_gptj_residual_(other.use_gptj_residual_), + int8_mode_(other.int8_mode_), + prompt_token_weight_size_(other.prompt_token_weight_size_), + malloc_load_prompt_weights_(other.malloc_load_prompt_weights_), + prompt_learning_type_(other.prompt_learning_type_), + prompt_learning_pair_(other.prompt_learning_pair_) +{ + mallocWeights(); + cudaD2Dcpy(weights_ptr[0], 
other.weights_ptr[0], vocab_size_ * hidden_units_); + cudaD2Dcpy(weights_ptr[1], other.weights_ptr[1], hidden_units_); + cudaD2Dcpy(weights_ptr[2], other.weights_ptr[2], hidden_units_); + cudaD2Dcpy(weights_ptr[3], other.weights_ptr[3], hidden_units_ * vocab_size_); + + // prompt learning table: malloc weights and set weight ptr + if (malloc_load_prompt_weights_) { + for (auto const& prompt : prompt_learning_pair_) { + std::string task_name = prompt.first; + int task_name_id = prompt.second.first; + int prompt_length = prompt.second.second; + size_t prompt_id = num_base_weights + (size_t)task_name_id; + + // cuda device to device memcpy prompt table weights buffer memory + cudaD2Dcpy(weights_ptr[prompt_id], other.weights_ptr[prompt_id], prompt_length * prompt_token_weight_size_); + } + } + + setWeightPtr(); + + decoder_layer_weights.clear(); + decoder_layer_weights.reserve(num_layer_); + for (int l = 0; l < num_layer_; l++) { + decoder_layer_weights.push_back(other.decoder_layer_weights[l]); + } +} + +template +LlamaWeight& LlamaWeight::operator=(const LlamaWeight& other) +{ + head_num_ = other.head_num_; + kv_head_num_ = other.kv_head_num_; + size_per_head_ = other.size_per_head_; + hidden_units_ = other.hidden_units_; + inter_size_ = other.inter_size_; + vocab_size_ = other.vocab_size_; + num_layer_ = other.num_layer_; + max_seq_len_ = other.max_seq_len_; + tensor_para_size_ = other.tensor_para_size_; + tensor_para_rank_ = other.tensor_para_rank_; + layer_para_size_ = other.layer_para_size_; + layer_para_rank_ = other.layer_para_rank_; + use_gptj_residual_ = other.use_gptj_residual_; + int8_mode_ = other.int8_mode_; + prompt_token_weight_size_ = other.prompt_token_weight_size_; + malloc_load_prompt_weights_ = other.malloc_load_prompt_weights_; + prompt_learning_type_ = other.prompt_learning_type_; + prompt_learning_pair_ = other.prompt_learning_pair_; + + mallocWeights(); + cudaD2Dcpy(weights_ptr[0], other.weights_ptr[0], vocab_size_ * hidden_units_); + cudaD2Dcpy(weights_ptr[1], other.weights_ptr[1], hidden_units_); + cudaD2Dcpy(weights_ptr[2], other.weights_ptr[2], hidden_units_); + cudaD2Dcpy(weights_ptr[3], other.weights_ptr[3], hidden_units_ * vocab_size_); + + // prompt learning table: malloc weights and set weight ptr + if (malloc_load_prompt_weights_) { + for (auto const& prompt : prompt_learning_pair_) { + std::string task_name = prompt.first; + int task_name_id = prompt.second.first; + int prompt_length = prompt.second.second; + size_t prompt_id = num_base_weights + (size_t)task_name_id; + + // cuda device to device memcpy prompt table weights buffer memory + cudaD2Dcpy(weights_ptr[prompt_id], other.weights_ptr[prompt_id], prompt_length * prompt_token_weight_size_); + } + } + + setWeightPtr(); + + decoder_layer_weights.clear(); + decoder_layer_weights.reserve(num_layer_); + for (int l = 0; l < num_layer_; l++) { + decoder_layer_weights.push_back(other.decoder_layer_weights[l]); + } + return *this; +} + +template +void LlamaWeight::setWeightPtr() +{ + prompt_learning_table.resize(prompt_learning_pair_.size()); + + pre_decoder_embedding_table = weights_ptr[0]; + post_decoder_layernorm.beta = weights_ptr[1]; + post_decoder_layernorm.gamma = weights_ptr[2]; + post_decoder_embedding.kernel = weights_ptr[3]; + post_decoder_embedding.bias = nullptr; + + // prompt learning tables: set weight ptr + if (malloc_load_prompt_weights_) { + for (auto const& prompt : prompt_learning_pair_) { + int task_name_id = prompt.second.first; + int prompt_length = prompt.second.second; + size_t 
task_weight_id = num_base_weights + (size_t)task_name_id; + + // set weight ptr + prompt_learning_table[task_name_id] = {weights_ptr[task_weight_id], prompt_length}; + } + } +} + +template +void LlamaWeight::mallocWeights() +{ + weights_ptr.resize(num_base_weights + prompt_learning_pair_.size()); + + deviceMalloc(&weights_ptr[0], vocab_size_ * hidden_units_); + deviceMalloc(&weights_ptr[1], hidden_units_); + deviceMalloc(&weights_ptr[2], hidden_units_); + deviceMalloc(&weights_ptr[3], hidden_units_ * vocab_size_); + + // prompt learning tables: malloc weights + if (malloc_load_prompt_weights_) { + for (auto const& prompt : prompt_learning_pair_) { + int task_name_id = prompt.second.first; + int prompt_length = prompt.second.second; + size_t task_weight_id = num_base_weights + (size_t)task_name_id; + + // malloc weights + T* prompt_weights_ptr = nullptr; + deviceMalloc(&prompt_weights_ptr, prompt_length * prompt_token_weight_size_); + weights_ptr[task_weight_id] = prompt_weights_ptr; + } + } + is_maintain_buffer = true; +} + +template +void LlamaWeight::loadModel(std::string dir_path) +{ + FtCudaDataType model_file_type = getModelFileType(dir_path + "/config.ini", "llama"); + FT_CHECK(is_maintain_buffer == true); + + loadWeightFromBin( + weights_ptr[0], {(size_t)(vocab_size_ * hidden_units_)}, dir_path + "/model.wte.weight.bin", model_file_type); + deviceFill(weights_ptr[1], (size_t)hidden_units_, (T)0.0); + loadWeightFromBin( + weights_ptr[2], {(size_t)hidden_units_}, dir_path + "/model.final_layernorm.weight.bin", model_file_type); + loadWeightFromBin(weights_ptr[3], + {(size_t)(vocab_size_ * hidden_units_)}, + dir_path + "/model.lm_head.weight.bin", + model_file_type); + + // prompt table: load weights from bin + if (malloc_load_prompt_weights_) { + for (auto const& prompt : prompt_learning_pair_) { + std::string task_name = prompt.first; + int task_name_id = prompt.second.first; + int prompt_length = prompt.second.second; + size_t task_weight_id = num_base_weights + (size_t)task_name_id; + + std::string prompt_weight_path_name = (prompt_learning_type_ == PromptLearningType::p_prompt_tuning) ? + (dir_path + "/model.prompt_table." + task_name + ".weight.bin") : + (dir_path + "/model.prefix_prompt." + task_name + ".weight." + + std::to_string(tensor_para_rank_) + ".bin"); + + if (prompt_length > 0) { + loadWeightFromBin(weights_ptr[task_weight_id], + {(size_t)(prompt_length * (int)prompt_token_weight_size_)}, + prompt_weight_path_name, + model_file_type); + } + } + } + + for (int l = 0; l < num_layer_; l++) { + if (isValidLayerParallelId(l)) { + decoder_layer_weights[l]->loadModel(dir_path + "/model.layers." 
+ std::to_string(l), model_file_type); + } + } +} + +template +void LlamaWeight::resizeLayer(const int num_layer) +{ + num_layer_ = num_layer; + decoder_layer_weights.reserve(num_layer_); + for (int l = 0; l < num_layer_; l++) { + decoder_layer_weights.push_back(new LlamaDecoderLayerWeight()); + } +} + +template +bool LlamaWeight::isValidLayerParallelId(int l) +{ + int local_num_layer = (int)(ceil(num_layer_ * 1.0f / layer_para_size_)); + return l < num_layer_ && (l >= local_num_layer * layer_para_rank_) + && (l < local_num_layer * (layer_para_rank_ + 1)); +} + +template struct LlamaWeight; +template struct LlamaWeight; +#ifdef ENABLE_BF16 +template class LlamaWeight<__nv_bfloat16>; +#endif + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/LlamaWeight.h b/src/fastertransformer/models/llama/LlamaWeight.h new file mode 100644 index 000000000..4bbc98a6f --- /dev/null +++ b/src/fastertransformer/models/llama/LlamaWeight.h @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "src/fastertransformer/kernels/layernorm_kernels.h" +#include "src/fastertransformer/models/llama/LlamaDecoderLayerWeight.h" +#include "src/fastertransformer/utils/memory_utils.h" +#include "src/fastertransformer/utils/prompt_learning.h" + +namespace fastertransformer { + +template +struct LlamaWeight { + + LlamaWeight() = default; + LlamaWeight( + const int head_num, + const int kv_head_num, + const int size_per_head, + const int inter_size, + const int vocab_size, + const int num_layer, + const int max_seq_len, + const int tensor_para_size = 1, + const int tensor_para_rank = 0, + const int layer_para_size = 1, + const int layer_para_rank = 0, + const bool use_gptj_residual_ = false, + const int int8_mode = 0, + PromptLearningType prompt_learning_type = PromptLearningType::no_prompt, + std::map> prompt_learning_pair = std::map>{}); + + ~LlamaWeight(); + LlamaWeight(const LlamaWeight& other); + LlamaWeight& operator=(const LlamaWeight& other); + + void loadModel(std::string dir_path); + + void resizeLayer(const int num_layer); + + std::vector*> decoder_layer_weights; + const T* pre_decoder_embedding_table = nullptr; + // GPT-J does not use embedding table, but we leave the ptr such that + // GptNeoX::forward and Gpt::forward become identical + const T* position_encoding_table = nullptr; + + /* + prompt_learning_pair = vectors of [weight ptr, prompt length] pair + prompt_length is stored here for compatible prompt learning table + prefix_prompt weights store as shape [num_layers, 2, num_heads, perfix_seq_len, size_per_head] + p/prompt tuning weights store as shape [prompt_len, hidden_units] + idx is the task_name_id of the prompt tables + */ + std::vector> prompt_learning_table = {}; + + LayerNormWeight post_decoder_layernorm; + DenseWeight post_decoder_embedding; + + inline void setMaxSeqLen(size_t max_seq_len) + { + max_seq_len_ = max_seq_len; + } + +private: + 
void setWeightPtr(); + void mallocWeights(); + bool isValidLayerParallelId(int l); + + int head_num_; + int kv_head_num_; + int size_per_head_; + int hidden_units_; + int inter_size_; + int vocab_size_; + int num_layer_; + int max_seq_len_; + + int tensor_para_size_; + int tensor_para_rank_; + int layer_para_size_; + int layer_para_rank_; + + size_t int8_mode_ = 0; + + // residual type + bool use_gptj_residual_; + + // prompt learning pair (task_name, (task_name_id, prompt_len)) + PromptLearningType prompt_learning_type_; + std::map> prompt_learning_pair_; + bool malloc_load_prompt_weights_ = false; + // each prompt token's weight size + size_t prompt_token_weight_size_ = 0; + + bool is_maintain_buffer = false; + const size_t num_base_weights = 4; + std::vector weights_ptr = std::vector(num_base_weights); +}; + +} // namespace fastertransformer diff --git a/src/fastertransformer/triton_backend/CMakeLists.txt b/src/fastertransformer/triton_backend/CMakeLists.txt index 56cda1bde..037c36c36 100644 --- a/src/fastertransformer/triton_backend/CMakeLists.txt +++ b/src/fastertransformer/triton_backend/CMakeLists.txt @@ -27,3 +27,4 @@ if (ENABLE_FP8) endif() add_subdirectory(bert) add_subdirectory(deberta) +add_subdirectory(llama) diff --git a/src/fastertransformer/triton_backend/llama/CMakeLists.txt b/src/fastertransformer/triton_backend/llama/CMakeLists.txt new file mode 100644 index 000000000..d5ba5547e --- /dev/null +++ b/src/fastertransformer/triton_backend/llama/CMakeLists.txt @@ -0,0 +1,25 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cmake_minimum_required(VERSION 3.8) + +set(parallel_gpt_triton_backend_files + LlamaTritonModel.cc + LlamaTritonModelInstance.cc +) + +add_library(LlamaTritonBackend STATIC ${parallel_gpt_triton_backend_files}) +set_property(TARGET LlamaTritonBackend PROPERTY POSITION_INDEPENDENT_CODE ON) +target_link_libraries(LlamaTritonBackend PRIVATE TransformerTritonBackend Llama tensor memory_utils -lcublasLt) +target_compile_features(LlamaTritonBackend PRIVATE cxx_std_14) diff --git a/src/fastertransformer/triton_backend/llama/LlamaTritonModel.cc b/src/fastertransformer/triton_backend/llama/LlamaTritonModel.cc new file mode 100644 index 000000000..a2d66f524 --- /dev/null +++ b/src/fastertransformer/triton_backend/llama/LlamaTritonModel.cc @@ -0,0 +1,273 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/fastertransformer/triton_backend/llama/LlamaTritonModel.h" +#include "3rdparty/INIReader.h" +#include "src/fastertransformer/triton_backend/llama/LlamaTritonModelInstance.h" +#include "src/fastertransformer/triton_backend/transformer_triton_backend.hpp" +#include "src/fastertransformer/utils/allocator.h" + +namespace ft = fastertransformer; + +std::shared_ptr AbstractTransformerModel::createLlamaModel(std::string inifile) +{ + INIReader reader = INIReader(inifile); + if (reader.ParseError() < 0) { + std::cout << "[ERROR] Can't load '" << inifile << "'\n"; + return nullptr; + } + + const std::string data_type = reader.Get("ft_instance_hyperparameter", "data_type"); + int tensor_para_size = reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"); + std::string model_dir = reader.Get("ft_instance_hyperparameter", "model_dir"); + + if (data_type == "half" || data_type == "fp16") { + return std::make_shared>( + reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"), + reader.GetInteger("ft_instance_hyperparameter", "pipeline_para_size"), + reader.GetInteger("ft_instance_hyperparameter", "enable_custom_all_reduce", 0), + model_dir, + reader.GetInteger("ft_instance_hyperparameter", "int8_mode", 0)); + } +#ifdef ENABLE_BF16 + else if (data_type == "bf16") { + return std::make_shared>( + reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"), + reader.GetInteger("ft_instance_hyperparameter", "pipeline_para_size"), + reader.GetInteger("ft_instance_hyperparameter", "enable_custom_all_reduce", 0), + model_dir, + reader.GetInteger("ft_instance_hyperparameter", "int8_mode", 0)); + } +#endif + else { + return std::make_shared>( + reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"), + reader.GetInteger("ft_instance_hyperparameter", "pipeline_para_size"), + reader.GetInteger("ft_instance_hyperparameter", "enable_custom_all_reduce", 0), + model_dir, + reader.GetInteger("ft_instance_hyperparameter", "int8_mode", 0)); + } +} + +template +LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size, + size_t pipeline_para_size, + int enable_custom_all_reduce, + std::string model_dir, + int int8_mode): + tensor_para_size_(tensor_para_size), + pipeline_para_size_(pipeline_para_size), + shared_weights_(std::vector>>(ft::getDeviceCount())), + enable_custom_all_reduce_(enable_custom_all_reduce), + int8_mode_(int8_mode) +{ + model_dir_ = model_dir; + const std::string inifile{model_dir + "/config.ini"}; + INIReader reader = INIReader(inifile); + if (reader.ParseError() < 0) { + std::cout << "[ERROR] Can't load '" << inifile << "'\n"; + ft::FT_CHECK(false); + } + + model_name_ = reader.Get("llama", "model_name"); + head_num_ = reader.GetInteger("llama", "head_num"); + kv_head_num_ = reader.GetInteger("llama", "kv_head_num", head_num_); + size_per_head_ = reader.GetInteger("llama", "size_per_head"); + inter_size_ = reader.GetInteger("llama", "inter_size"); + num_layer_ = reader.GetInteger("llama", "num_layer"); + vocab_size_ = reader.GetInteger("llama", "vocab_size"); + rotary_embedding_dim_ = reader.GetInteger("llama", "rotary_embedding"); + layernorm_eps_ = reader.GetFloat("llama", "layernorm_eps"); + start_id_ = reader.GetInteger("llama", "start_id"); + end_id_ = reader.GetInteger("llama", "end_id"); + use_gptj_residual_ = false; + + num_tasks_ = reader.GetInteger("llama", "num_tasks", 0); + + prompt_learning_start_id_ = reader.GetInteger("llama", "prompt_learning_start_id", end_id_ + 1); + prompt_learning_type_ = + 
static_cast(reader.GetInteger("llama", "prompt_learning_type", 0)); + + for (int task_name_id = 0; task_name_id < num_tasks_; task_name_id++) { + std::string config_task_name = "task_" + std::to_string(task_name_id); + std::string task_name = reader.Get(config_task_name, "task_name"); + const int prompt_length = reader.GetInteger(config_task_name, "prompt_length", 0); + prompt_learning_table_pair_.insert({task_name, {task_name_id, prompt_length}}); + } +} + +template +std::unique_ptr LlamaTritonModel::createModelInstance( + int device_id, + int rank, + cudaStream_t stream, + std::pair, std::vector> nccl_params, + std::shared_ptr custom_all_reduce_comm) +{ + ft::check_cuda_error(cudaSetDevice(device_id)); + const int comms_rank = device_id % (tensor_para_size_ * pipeline_para_size_); + + std::unique_ptr> allocator( + new ft::Allocator(device_id)); + + allocator->setStream(stream); + + cublasHandle_t cublas_handle; + cublasLtHandle_t cublaslt_handle; + + cublasCreate(&cublas_handle); + cublasLtCreate(&cublaslt_handle); + cublasSetStream(cublas_handle, stream); + + std::unique_ptr cublas_algo_map(new ft::cublasAlgoMap("gemm_config.in")); + std::unique_ptr cublas_wrapper_mutex(new std::mutex()); + std::unique_ptr cublas_wrapper(new ft::cublasMMWrapper( + cublas_handle, cublaslt_handle, stream, cublas_algo_map.get(), cublas_wrapper_mutex.get(), allocator.get())); + + std::unique_ptr cuda_device_prop_ptr(new cudaDeviceProp); + ft::check_cuda_error(cudaGetDeviceProperties(cuda_device_prop_ptr.get(), device_id)); + + if (std::is_same::value) { + cublas_wrapper->setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F); + } +#ifdef ENABLE_BF16 + else if (std::is_same::value) { + cublas_wrapper->setBF16GemmConfig(); + } +#endif + else if (std::is_same::value) { + cublas_wrapper->setFP32GemmConfig(); + } + + ft::NcclParam tensor_para = nccl_params.first[comms_rank]; + ft::NcclParam pipeline_para = nccl_params.second[comms_rank]; + + ft::AttentionType attention_type = ft::getAttentionType(size_per_head_, + ft::getSMVersion(), + true, // remove_padding + 0, // gpt supports any-seq-length fmha + true, // is_fuse + false, // with_relative_position_bias + true); // causal_mask + auto gpt = std::make_unique>( + ft::Llama(head_num_, + kv_head_num_, + size_per_head_, + inter_size_, + num_layer_, + vocab_size_, + rotary_embedding_dim_, + layernorm_eps_, + start_id_, + end_id_, + prompt_learning_start_id_, // p/prompt tuning virtual token start id + prompt_learning_type_, + use_gptj_residual_, + 0.0f, // beam_search_diversity_rate_, + 0, // top_k_, + 0.0f, // top_p_, + 0, // random seed, note that all gpus should use same seed + 0.0f, // temperature_, + 0.0f, // len_penalty_, + 0.0f, // repetition_penalty_, + tensor_para, + pipeline_para, + stream, + cublas_wrapper.get(), + allocator.get(), + false, + cuda_device_prop_ptr.get(), + attention_type, + int8_mode_, + custom_all_reduce_comm, + enable_custom_all_reduce_)); + + return std::unique_ptr>( + new LlamaTritonModelInstance(std::move(gpt), + shared_weights_[device_id], + std::move(allocator), + std::move(cublas_algo_map), + std::move(cublas_wrapper_mutex), + std::move(cublas_wrapper), + std::move(cuda_device_prop_ptr))); +} + +template +void LlamaTritonModel::createSharedWeights(int device_id, int rank) +{ + ft::check_cuda_error(cudaSetDevice(device_id)); + const int tensor_para_rank = rank % tensor_para_size_; + const int pipeline_para_rank = rank / tensor_para_size_; + shared_weights_[device_id] = std::make_shared>(head_num_, + kv_head_num_, + 
size_per_head_, + inter_size_, + vocab_size_, + num_layer_, + 0, // max_seq_len, deprecated + tensor_para_size_, + tensor_para_rank, + pipeline_para_size_, + pipeline_para_rank, + use_gptj_residual_, + int8_mode_, + prompt_learning_type_, + prompt_learning_table_pair_); + shared_weights_[device_id]->loadModel(model_dir_); + return; +} + +template +std::string LlamaTritonModel::toString() +{ + std::stringstream ss; + ss << "Model: " + << "\nhead_num: " << head_num_ << "\nsize_per_head: " << size_per_head_ << "\ninter_size: " << inter_size_ + << "\nnum_layer: " << num_layer_ << "\nvocab_size: " << vocab_size_ << "\nlayernorm_eps: " << layernorm_eps_ + << "\nstart_id: " << start_id_ << "\nend_id: " << end_id_ << "\nuse_gptj_residual: " << use_gptj_residual_ + << "\nprompt_learning_type_: " << static_cast(prompt_learning_type_) << "\nkv_head_num: " << kv_head_num_ + << "\nprompt_learning_start_id_: " << prompt_learning_start_id_ << "\ntensor_para_size: " << tensor_para_size_ + << "\npipeline_para_size: " << pipeline_para_size_ << "\nenable_custom_all_reduce: " << enable_custom_all_reduce_ + << "\nint8_mode: " << int8_mode_ + << "\nmodel_name: " << model_name_ << "\nmodel_dir: " << model_dir_ << std::endl; + return ss.str(); +} + +template +void LlamaTritonModel::createCustomComms( + std::vector>* custom_all_reduce_comms, int world_size) +{ + using commDataType = typename ft::CustomARCommTypeConverter::Type; + ft::initCustomAllReduceComm(custom_all_reduce_comms, enable_custom_all_reduce_, world_size); +} + +template +int LlamaTritonModel::getTensorParaSize() +{ + return tensor_para_size_; +} + +template +int LlamaTritonModel::getPipelineParaSize() +{ + return pipeline_para_size_; +} + +template struct LlamaTritonModel; +template struct LlamaTritonModel; +#ifdef ENABLE_BF16 +template class LlamaTritonModel<__nv_bfloat16>; +#endif diff --git a/src/fastertransformer/triton_backend/llama/LlamaTritonModel.h b/src/fastertransformer/triton_backend/llama/LlamaTritonModel.h new file mode 100644 index 000000000..c1b48f118 --- /dev/null +++ b/src/fastertransformer/triton_backend/llama/LlamaTritonModel.h @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "src/fastertransformer/models/llama/Llama.h" +#include "src/fastertransformer/triton_backend/transformer_triton_backend.hpp" +#include "src/fastertransformer/utils/cuda_utils.h" +#include "src/fastertransformer/utils/custom_ar_comm.h" +#include "src/fastertransformer/utils/nccl_utils.h" +#include + +namespace ft = fastertransformer; + +template +struct LlamaTritonModel: public AbstractTransformerModel { + LlamaTritonModel(size_t tensor_para_size, + size_t pipeline_para_size, + int enable_custom_all_reduce, + std::string model_dir, + int int8_mode); + + ~LlamaTritonModel() = default; + + virtual std::unique_ptr + createModelInstance(int deviceId, + int rank, + cudaStream_t stream, + std::pair, std::vector> nccl_params, + std::shared_ptr custom_all_reduce_comm = nullptr) override; + + virtual void createSharedWeights(int deviceId, int rank) override; + + virtual void createCustomComms(std::vector>* custom_all_reduce_comms, + int world_size) override; + + virtual std::string toString() override; + virtual int getTensorParaSize() override; + virtual int getPipelineParaSize() override; + +private: + size_t head_num_; + size_t kv_head_num_; + size_t size_per_head_; + size_t inter_size_; + size_t num_layer_; + size_t vocab_size_; + size_t rotary_embedding_dim_; + float layernorm_eps_; + int start_id_; + int end_id_; + size_t tensor_para_size_; + size_t pipeline_para_size_; + + // shared weights for each device + std::vector>> shared_weights_; + + // residual type + bool use_gptj_residual_ = false; + + int int8_mode_ = 0; + + // number of tasks (for prefix-prompt, p/prompt-tuning) + size_t num_tasks_ = 0; + int prompt_learning_start_id_ = 0; + ft::PromptLearningType prompt_learning_type_ = ft::PromptLearningType::no_prompt; + std::map> prompt_learning_table_pair_ = {}; + + bool is_fp16_; + int enable_custom_all_reduce_ = 0; + + std::string model_name_; + std::string model_dir_; +}; diff --git a/src/fastertransformer/triton_backend/llama/LlamaTritonModelInstance.cc b/src/fastertransformer/triton_backend/llama/LlamaTritonModelInstance.cc new file mode 100644 index 000000000..e46adf87d --- /dev/null +++ b/src/fastertransformer/triton_backend/llama/LlamaTritonModelInstance.cc @@ -0,0 +1,264 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/fastertransformer/triton_backend/llama/LlamaTritonModelInstance.h" +#include "src/fastertransformer/triton_backend/transformer_triton_backend.hpp" +#include "src/fastertransformer/triton_backend/triton_utils.hpp" +#include "src/fastertransformer/utils/Tensor.h" +#include +#include +#include +#include + +namespace ft = fastertransformer; + +template +void triton_stream_callback(std::unordered_map* output_tensors, void* ctx) +{ + LlamaTritonModelInstance* model = reinterpret_cast*>(ctx); + auto result = LlamaTritonModelInstance::convert_outputs(*output_tensors); + + model->stream_cb_(result, model->stream_ctx_); +} + +template +LlamaTritonModelInstance::LlamaTritonModelInstance( + std::unique_ptr> gpt, + std::shared_ptr> gpt_weight, + std::unique_ptr> allocator, + std::unique_ptr cublas_algo_map, + std::unique_ptr cublas_wrapper_mutex, + std::unique_ptr cublas_wrapper, + std::unique_ptr cuda_device_prop_ptr): + gpt_(std::move(gpt)), + gpt_weight_(gpt_weight), + allocator_(std::move(allocator)), + cublas_algo_map_(std::move(cublas_algo_map)), + cublas_wrapper_mutex_(std::move(cublas_wrapper_mutex)), + cublas_wrapper_(std::move(cublas_wrapper)), + cuda_device_prop_ptr_(std::move(cuda_device_prop_ptr)) +{ +} + +template +std::unordered_map LlamaTritonModelInstance::convert_inputs( + std::shared_ptr> input_tensors) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + + move_tensor_H2D(input_tensors->at("input_ids"), d_input_ids_, &allocator_); + move_tensor_H2D(input_tensors->at("input_lengths"), d_input_lengths_, &allocator_); + + const size_t request_batch_size = input_tensors->at("input_ids").shape[0]; + const size_t input_data_len = input_tensors->at("input_ids").shape[1]; + h_total_output_lengths_ = reinterpret_cast(malloc(request_batch_size * sizeof(uint32_t))); + for (int i = 0; i < request_batch_size; ++i) { + h_total_output_lengths_[i] = + reinterpret_cast(input_tensors->at("request_output_len").data)[i] + input_data_len; + } + + std::unordered_map ft_input_tensors = std::unordered_map{ + {"input_ids", as_GPU_tensor(input_tensors->at("input_ids"), d_input_ids_)}, + {"input_lengths", as_GPU_tensor(input_tensors->at("input_lengths"), d_input_lengths_)}, + {"output_seq_len", + ft::Tensor{ft::MEMORY_CPU, + ft::TYPE_UINT32, + {input_tensors->at("request_output_len").shape[0]}, + h_total_output_lengths_}}}; + + if (input_tensors->find("bad_words_list") != input_tensors->end()) { + move_tensor_H2D(input_tensors->at("bad_words_list"), d_input_bad_words_, &allocator_); + ft_input_tensors.insert( + {"bad_words_list", as_GPU_tensor(input_tensors->at("bad_words_list"), d_input_bad_words_)}); + } + + if (input_tensors->find("stop_words_list") != input_tensors->end()) { + move_tensor_H2D(input_tensors->at("stop_words_list"), d_input_stop_words_, &allocator_); + ft_input_tensors.insert( + {"stop_words_list", as_GPU_tensor(input_tensors->at("stop_words_list"), d_input_stop_words_)}); + } + + if (input_tensors->count("request_prompt_embedding") && input_tensors->count("request_prompt_lengths") + && input_tensors->count("request_prompt_type")) { + + move_tensor_H2D(input_tensors->at("request_prompt_lengths"), d_request_prompt_lengths_, &allocator_); + ft_input_tensors.insert( + {"request_prompt_lengths", + as_GPU_tensor(input_tensors->at("request_prompt_lengths"), d_request_prompt_lengths_)}); + + move_tensor_H2D(input_tensors->at("request_prompt_embedding"), d_request_prompt_embedding_, &allocator_); + ft_input_tensors.insert( + {"request_prompt_embedding", + 
as_GPU_tensor(input_tensors->at("request_prompt_embedding"), d_request_prompt_embedding_)}); + } + + if (input_tensors->find("top_p_decay") != input_tensors->end()) { + move_tensor_H2D(input_tensors->at("top_p_decay"), d_top_p_decay_, &allocator_); + ft_input_tensors.insert({"top_p_decay", as_GPU_tensor(input_tensors->at("top_p_decay"), d_top_p_decay_)}); + } + if (input_tensors->find("top_p_min") != input_tensors->end()) { + move_tensor_H2D(input_tensors->at("top_p_min"), d_top_p_min_, &allocator_); + ft_input_tensors.insert({"top_p_min", as_GPU_tensor(input_tensors->at("top_p_min"), d_top_p_min_)}); + } + if (input_tensors->find("top_p_reset_ids") != input_tensors->end()) { + move_tensor_H2D(input_tensors->at("top_p_reset_ids"), d_top_p_reset_ids_, &allocator_); + ft_input_tensors.insert( + {"top_p_reset_ids", as_GPU_tensor(input_tensors->at("top_p_reset_ids"), d_top_p_reset_ids_)}); + } + + for (auto t = input_tensors->begin(); t != input_tensors->end(); ++t) { + if (t->first.find("input_ids") == std::string::npos && t->first.find("input_lengths") == std::string::npos + && t->first.find("output_seq_len") == std::string::npos + && t->first.find("prefix_soft_prompt_embedding") == std::string::npos + && t->first.find("prefix_soft_prompt_lengths") == std::string::npos) { + if (ft_input_tensors.count(t->first) == 0) { + ft_input_tensors.insert({t->first, t->second.convertTritonTensorToFt()}); + } + } + } + + return ft_input_tensors; +} + +template +std::shared_ptr> +LlamaTritonModelInstance::convert_outputs(const std::unordered_map& output_tensors) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + std::unordered_map* outputs_mapping = + new std::unordered_map(); + + for (auto it = output_tensors.begin(); it != output_tensors.end(); it++) { + outputs_mapping->insert({it->first, triton::Tensor::convertFtTensorToTriton(it->second)}); + } + + return std::shared_ptr>(outputs_mapping); +} + +template +std::shared_ptr> +LlamaTritonModelInstance::forward(std::shared_ptr> input_tensors) +{ + ft::FT_CHECK(false); + return nullptr; +} + +template +std::shared_ptr> +LlamaTritonModelInstance::forward(std::shared_ptr> input_tensors) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + FT_CHECK_WITH_INFO(input_tensors->at("input_ids").shape.size() == 2, + "input_tensors->at(\"input_ids\").shape.size() == 2"); + FT_CHECK_WITH_INFO(input_tensors->at("input_lengths").shape.size() == 1, + "input_tensors->at(\"input_lengths\").shape.size() == 1"); + + const uint32_t request_batch_size = input_tensors->at("input_ids").shape[0]; + const uint32_t max_request_output_len = (size_t)*std::max_element( + (int*)input_tensors->at("request_output_len").data, + (int*)input_tensors->at("request_output_len").data + input_tensors->at("request_output_len").shape[0]); + const uint32_t total_output_len = max_request_output_len + input_tensors->at("input_ids").shape[1]; + const uint32_t beam_width = + input_tensors->count("beam_width") ? 
(size_t)(*(uint*)input_tensors->at("beam_width").data) : 1; + + allocateBuffer(request_batch_size, beam_width, total_output_len, max_request_output_len); + + std::unordered_map ft_input_tensors = convert_inputs(input_tensors); + + std::unordered_map output_tensors = std::unordered_map{ + {"output_ids", + ft::Tensor{ft::MEMORY_GPU, + ft::TYPE_UINT32, + std::vector{request_batch_size, beam_width, total_output_len}, + d_output_ids_}}, + {"sequence_length", + ft::Tensor{ft::MEMORY_GPU, + ft::TYPE_UINT32, + std::vector{request_batch_size, beam_width}, + d_sequence_lengths_}}}; + + if (input_tensors->count("is_return_log_probs") && *((bool*)input_tensors->at("is_return_log_probs").data)) { + output_tensors.insert({"output_log_probs", + ft::Tensor{ft::MEMORY_GPU, + ft::TYPE_FP32, + std::vector{request_batch_size, beam_width, max_request_output_len}, + d_output_log_probs_}}); + output_tensors.insert({"cum_log_probs", + ft::Tensor{ft::MEMORY_GPU, + ft::TYPE_FP32, + std::vector{request_batch_size, beam_width}, + d_cum_log_probs_}}); + } + try { + if (stream_cb_ != nullptr) { + gpt_->registerCallback(triton_stream_callback, this); + } + + gpt_->forward(&output_tensors, &ft_input_tensors, gpt_weight_.get()); + + if (stream_cb_ != nullptr) { + gpt_->unRegisterCallback(); + } + } + catch (...) { + h_exception_ = std::current_exception(); + output_tensors.insert({"error_message", ft::Tensor{ft::MEMORY_CPU, ft::TYPE_BYTES, {1}, &h_exception_}}); + } + + if (h_total_output_lengths_ != nullptr) { + free(h_total_output_lengths_); + h_total_output_lengths_ = nullptr; + } + + return convert_outputs(output_tensors); +} + +template +LlamaTritonModelInstance::~LlamaTritonModelInstance() +{ + freeBuffer(); +} + +template +void LlamaTritonModelInstance::allocateBuffer(const size_t request_batch_size, + const size_t beam_width, + const size_t total_output_len, + const size_t max_request_output_len) +{ + d_output_ids_ = (int*)(allocator_->reMalloc( + d_output_ids_, sizeof(int) * request_batch_size * beam_width * total_output_len, false)); + d_sequence_lengths_ = + (int*)(allocator_->reMalloc(d_sequence_lengths_, sizeof(int) * request_batch_size * beam_width, false)); + d_output_log_probs_ = (float*)(allocator_->reMalloc( + d_output_log_probs_, sizeof(float) * request_batch_size * beam_width * max_request_output_len, false)); + d_cum_log_probs_ = + (float*)(allocator_->reMalloc(d_cum_log_probs_, sizeof(float) * request_batch_size * beam_width, false)); +} + +template +void LlamaTritonModelInstance::freeBuffer() +{ + allocator_->free((void**)(&d_output_ids_)); + allocator_->free((void**)(&d_sequence_lengths_)); + allocator_->free((void**)(&d_output_log_probs_)); + allocator_->free((void**)(&d_cum_log_probs_)); +} + +template struct LlamaTritonModelInstance; +template struct LlamaTritonModelInstance; +#ifdef ENABLE_BF16 +template class LlamaTritonModelInstance<__nv_bfloat16>; +#endif diff --git a/src/fastertransformer/triton_backend/llama/LlamaTritonModelInstance.h b/src/fastertransformer/triton_backend/llama/LlamaTritonModelInstance.h new file mode 100644 index 000000000..a75e0692d --- /dev/null +++ b/src/fastertransformer/triton_backend/llama/LlamaTritonModelInstance.h @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "src/fastertransformer/models/llama/Llama.h" +#include "src/fastertransformer/triton_backend/llama/LlamaTritonModel.h" +#include "src/fastertransformer/triton_backend/transformer_triton_backend.hpp" +#include + +namespace ft = fastertransformer; + +template +struct LlamaTritonModelInstance: AbstractTransformerModelInstance { + + LlamaTritonModelInstance(std::unique_ptr> gpt, + std::shared_ptr> gpt_weight, + std::unique_ptr> allocator, + std::unique_ptr cublas_algo_map, + std::unique_ptr cublas_wrapper_mutex, + std::unique_ptr cublas_wrapper, + std::unique_ptr cuda_device_prop_ptr); + ~LlamaTritonModelInstance(); + + std::shared_ptr> + forward(std::shared_ptr> input_tensors) override; + + std::shared_ptr> + forward(std::shared_ptr> input_tensors) override; + + static std::shared_ptr> + convert_outputs(const std::unordered_map& output_tensors); + +private: + const std::unique_ptr> allocator_; + const std::unique_ptr> gpt_; + const std::shared_ptr> gpt_weight_; + const std::unique_ptr cublas_algo_map_; + const std::unique_ptr cublas_wrapper_mutex_; + const std::unique_ptr cublas_wrapper_; + const std::unique_ptr cuda_device_prop_ptr_; + + std::unordered_map + convert_inputs(std::shared_ptr> input_tensors); + + void allocateBuffer(const size_t request_batch_size, + const size_t beam_width, + const size_t total_output_len, + const size_t max_request_output_len); + void freeBuffer(); + + int* d_input_ids_ = nullptr; + int* d_input_lengths_ = nullptr; + int* d_input_bad_words_ = nullptr; + int* d_input_stop_words_ = nullptr; + int* d_request_prompt_lengths_ = nullptr; + T* d_request_prompt_embedding_ = nullptr; + float* d_top_p_decay_ = nullptr; + float* d_top_p_min_ = nullptr; + int* d_top_p_reset_ids_ = nullptr; + + int* d_output_ids_ = nullptr; + int* d_sequence_lengths_ = nullptr; + float* d_output_log_probs_ = nullptr; + float* d_cum_log_probs_ = nullptr; + + uint32_t* h_total_output_lengths_ = nullptr; + std::exception_ptr h_exception_ = nullptr; +}; diff --git a/src/fastertransformer/triton_backend/transformer_triton_backend.hpp b/src/fastertransformer/triton_backend/transformer_triton_backend.hpp index 47cf6750c..edffabfd7 100644 --- a/src/fastertransformer/triton_backend/transformer_triton_backend.hpp +++ b/src/fastertransformer/triton_backend/transformer_triton_backend.hpp @@ -293,6 +293,7 @@ struct AbstractTransformerModel { static std::shared_ptr createGptNeoXModel(std::string inifile); static std::shared_ptr createT5Model(std::string model_dir); static std::shared_ptr createT5EncoderModel(std::string model_dir); + static std::shared_ptr createLlamaModel(std::string model_dir); std::pair, std::vector> createNcclParams(const int node_id, const int device_id_start = 0, const bool multi_node = false); From 1edcc80b829284cc76316577111ddd1f481f64ad Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Thu, 24 Aug 2023 13:47:04 -0700 Subject: [PATCH 02/10] commit --- src/fastertransformer/models/llama/Llama.cc | 1 - .../models/llama/LlamaContextDecoder.cc | 41 ------------------- .../models/llama/LlamaDecoder.cc | 7 ---- 
.../models/llama/LlamaDecoderLayerWeight.cc | 13 ------ 4 files changed, 62 deletions(-) diff --git a/src/fastertransformer/models/llama/Llama.cc b/src/fastertransformer/models/llama/Llama.cc index 9fcdd9169..c139aa9f8 100644 --- a/src/fastertransformer/models/llama/Llama.cc +++ b/src/fastertransformer/models/llama/Llama.cc @@ -755,7 +755,6 @@ void Llama::forward(std::unordered_map* output_ten gpt_context_decoder_->forward( &decoder_output_tensors, &decoder_input_tensors, &gpt_weights->decoder_layer_weights); sync_check_cuda_error(); - printf("gpt_context_decoder_->forward done\n"); invokeDecodingInitialize(finished_buf_, sequence_lengths_, nullptr, diff --git a/src/fastertransformer/models/llama/LlamaContextDecoder.cc b/src/fastertransformer/models/llama/LlamaContextDecoder.cc index 8082a2f13..25a506ace 100644 --- a/src/fastertransformer/models/llama/LlamaContextDecoder.cc +++ b/src/fastertransformer/models/llama/LlamaContextDecoder.cc @@ -332,7 +332,6 @@ void LlamaContextDecoder::forward(std::unordered_map* AttentionType attention_type = (d_prefix_prompt_lengths != nullptr) ? getUnfusedAttentionType(attention_type_) : attention_type_; - printf("attention_type: %d\n", attention_type); const bool is_unpadded_mha = isUnPaddedMHA(attention_type); for (int ite = 0; ite < iteration_num; ite++) { @@ -458,27 +457,6 @@ void LlamaContextDecoder::forward(std::unordered_map* &self_attention_input_tensors, &gpt_decoder_layer_weight->at(l)->self_attention_weights); - #ifdef ENABLE_FLEX_DEBUG - if (l == 0) { - printf("%d %d: %d %d\n", l, ite, h_token_num, hidden_units_); - T *self_attn_output = new T[h_token_num * hidden_units_]; - cudaMemcpy(self_attn_output, self_attn_output_, sizeof(T)*h_token_num * hidden_units_, cudaMemcpyDeviceToHost); - sync_check_cuda_error(); - int k = 0; - for (int i=0; i::forward(std::unordered_map* } sync_check_cuda_error(); - #ifdef ENABLE_FLEX_DEBUG - if (l == 1) { - printf("%d %d: %d %d\n", l, ite, h_token_num, hidden_units_); - T *self_attn_output = new T[h_token_num * hidden_units_]; - cudaMemcpy(self_attn_output, layer_output, sizeof(T)*h_token_num * hidden_units_, cudaMemcpyDeviceToHost); - sync_check_cuda_error(); - int k = 0; - for (int i=0; i 1) { int data_size = h_token_num * hidden_units_ / tensor_para_.world_size_; diff --git a/src/fastertransformer/models/llama/LlamaDecoder.cc b/src/fastertransformer/models/llama/LlamaDecoder.cc index c82de8568..497b954d4 100644 --- a/src/fastertransformer/models/llama/LlamaDecoder.cc +++ b/src/fastertransformer/models/llama/LlamaDecoder.cc @@ -241,13 +241,6 @@ void LlamaDecoder::forward(std::unordered_map* for (auto t = k_cache.shape.begin() + 2; t != k_cache.shape.end(); ++t) { self_k_cache_size.push_back(*t); } - #ifdef ENABLE_FLEX_DEBUG - printf("self_k_cache_size: "); - for (int i=0; i self_v_cache_size; self_v_cache_size.push_back(local_batch_size); for (auto t = v_cache.shape.begin() + 2; t != v_cache.shape.end(); ++t) { diff --git a/src/fastertransformer/models/llama/LlamaDecoderLayerWeight.cc b/src/fastertransformer/models/llama/LlamaDecoderLayerWeight.cc index daf890c87..34ad480cf 100644 --- a/src/fastertransformer/models/llama/LlamaDecoderLayerWeight.cc +++ b/src/fastertransformer/models/llama/LlamaDecoderLayerWeight.cc @@ -234,19 +234,6 @@ void LlamaDecoderLayerWeight::loadModel(std::string dir_path, FtCudaDataType {(size_t)hidden_units_, (size_t)(qkv_size / tensor_para_size_)}, dir_path + ".attention.query_key_value.weight." 
+ rank_spec + ".bin", model_file_type); - // { - // printf("qkv_size: %d\n", qkv_size); - // printf("w2\n"); - // int sz = 100; - // T *qkv_buf = new T[sz]; - // cudaMemcpy(qkv_buf, weights_ptr[2], sizeof(T)*sz, cudaMemcpyDeviceToHost); - // sync_check_cuda_error(); - // for (int i=0; i(weights_ptr[4], {(size_t)(hidden_units_ / tensor_para_size_), (size_t)hidden_units_}, From fdda8511f45781476c2257c54b7129d9d2d06f94 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Thu, 24 Aug 2023 13:50:28 -0700 Subject: [PATCH 03/10] commit --- CMakeLists.txt | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 22c0f9c1f..8880526e4 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -333,6 +333,8 @@ add_library(transformer-shared SHARED $ $ $ + $ + $ $ $ $ @@ -353,6 +355,12 @@ add_library(transformer-shared SHARED $ $ $ + $ + $ + $ + $ + $ + $ $ $ $ @@ -361,6 +369,8 @@ add_library(transformer-shared SHARED $ $ $ + $ + $ $ $ $ @@ -394,6 +404,7 @@ add_library(transformer-shared SHARED $ $ $ + $ $ $ $ From c08edd33bfac23756625ded657e1b50e64012f79 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Thu, 24 Aug 2023 13:51:11 -0700 Subject: [PATCH 04/10] commit --- examples/cpp/llama/CMakeLists.txt | 22 + examples/cpp/llama/bad_words.csv | 2 + examples/cpp/llama/check_with_huggingface.py | 16 + .../cpp/llama/huggingface_llama_convert.py | 233 ++++++++ .../cpp/llama/huggingface_llama_convert2.py | 259 +++++++++ examples/cpp/llama/llama_config.ini | 34 ++ examples/cpp/llama/llama_example.cc | 533 ++++++++++++++++++ examples/cpp/llama/llama_triton_example.cc | 448 +++++++++++++++ examples/cpp/llama/model_config.json | 1 + examples/cpp/llama/start_ids.csv | 8 + examples/cpp/llama/stop_words.csv | 2 + 11 files changed, 1558 insertions(+) create mode 100644 examples/cpp/llama/CMakeLists.txt create mode 100644 examples/cpp/llama/bad_words.csv create mode 100644 examples/cpp/llama/check_with_huggingface.py create mode 100644 examples/cpp/llama/huggingface_llama_convert.py create mode 100644 examples/cpp/llama/huggingface_llama_convert2.py create mode 100644 examples/cpp/llama/llama_config.ini create mode 100644 examples/cpp/llama/llama_example.cc create mode 100644 examples/cpp/llama/llama_triton_example.cc create mode 100644 examples/cpp/llama/model_config.json create mode 100644 examples/cpp/llama/start_ids.csv create mode 100644 examples/cpp/llama/stop_words.csv diff --git a/examples/cpp/llama/CMakeLists.txt b/examples/cpp/llama/CMakeLists.txt new file mode 100644 index 000000000..cdf9033dd --- /dev/null +++ b/examples/cpp/llama/CMakeLists.txt @@ -0,0 +1,22 @@ +# Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
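+# A minimal build sketch (paths and flags are illustrative only, assuming the
+# usual FasterTransformer CMake flow; set SM to the target GPU architecture):
+#   mkdir -p build && cd build
+#   cmake -DSM=80 -DCMAKE_BUILD_TYPE=Release -DBUILD_MULTI_GPU=ON ..
+#   make -j llama_example llama_triton_example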
+ +add_executable(llama_example llama_example.cc) +target_link_libraries(llama_example PUBLIC -lcublas -lcublasLt -lcudart + Llama nvtx_utils gpt_example_utils word_list mpi_utils nccl_utils) + +add_executable(llama_triton_example llama_triton_example.cc) +target_link_libraries(llama_triton_example PUBLIC -lcublas -lcublasLt -lcudart -lpthread + LlamaTritonBackend TransformerTritonBackend custom_ar_comm + gpt_example_utils word_list mpi_utils nccl_utils nvtx_utils) diff --git a/examples/cpp/llama/bad_words.csv b/examples/cpp/llama/bad_words.csv new file mode 100644 index 000000000..6a1126ebd --- /dev/null +++ b/examples/cpp/llama/bad_words.csv @@ -0,0 +1,2 @@ +7768,3908 +1,2 diff --git a/examples/cpp/llama/check_with_huggingface.py b/examples/cpp/llama/check_with_huggingface.py new file mode 100644 index 000000000..d1f356cc1 --- /dev/null +++ b/examples/cpp/llama/check_with_huggingface.py @@ -0,0 +1,16 @@ +import transformers + +from transformers import LlamaForCausalLM, LlamaTokenizer + +tokenizer = LlamaTokenizer.from_pretrained('/data/llama-7b-hf') + +prompt = "Hey, are you consciours? Can you talk to me?" +inputs = tokenizer(prompt, return_tensors='pt') +model = LlamaForCausalLM.from_pretrained("/data/llama-7b-hf") +hf_config = vars(model.config) +print(hf_config) +generated_ids = model.forward(inputs.input_ids, output_hidden_states=True) +print(generated_ids) + +tokens = [0,18637,29892,526,366,1136,455,2470,29973,1815,366,5193,304,592,29973,18637,29892,526,366,1136,455,2470,29973,1815,366,5193,304,592,29973,18637,29892,526,366,1136,455,2470,29973,1815,366,5193,304,592,29973,18637,29892,526,366] +print(tokenizer.decode(tokens)) diff --git a/examples/cpp/llama/huggingface_llama_convert.py b/examples/cpp/llama/huggingface_llama_convert.py new file mode 100644 index 000000000..d771c0b2c --- /dev/null +++ b/examples/cpp/llama/huggingface_llama_convert.py @@ -0,0 +1,233 @@ +# Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
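+# Example invocation (paths are placeholders; the flags match the argparse
+# options defined at the bottom of this script):
+#   python huggingface_llama_convert.py \
+#       -in_file /path/to/llama-7b-hf \
+#       -saved_dir /path/to/llama-7b-ft \
+#       -infer_gpu_num 1 -weight_data_type fp16 -model_name llama_7b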
+ +import argparse +import configparser +import numpy as np +from pathlib import Path + +import torch +import os +from transformers import LlamaForCausalLM, AutoConfig + +def get_weight_data_type(data_type): + if data_type == "fp32": + return np.float32 + elif data_type == "fp16": + return np.float16 + else: + assert False, f"Invalid weight data type {data_type}" + + +def split_and_convert_process(saved_dir, factor, key, val): + if key.find("input_layernorm.weight") != -1 or key.find("post_attention_layernorm.weight") != -1: + # shared weights, only need to convert the weights of rank 0 + saved_path = saved_dir + "/" + key + ".bin" + val.tofile(saved_path) + elif key.find("attention.dense.weight") != -1 or key.find("mlp.down_proj.weight") != -1: + split_vals = np.split(val, factor, axis=0) + for j in range(factor): + saved_path = saved_dir + "/" + key + ".%d.bin" % j + split_vals[j].tofile(saved_path) + elif key.find("mlp.gate_proj.weight") != -1 or key.find("mlp.up_proj.weight") != -1: + split_vals = np.split(val, factor, axis=-1) + for j in range(factor): + saved_path = saved_dir + "/" + key + ".%d.bin" % j + split_vals[j].tofile(saved_path) + elif key.find("attention.query_key_value.weight") != -1: + split_vals = np.split(val, factor, axis=-1) + for j in range(factor): + saved_path = saved_dir + "/" + key + ".%d.bin" % j + split_vals[j].tofile(saved_path) + else: + print("[ERROR] cannot find key '{}'".format(key)) + +def split_and_convert(args): + saved_dir = args.saved_dir + "/%d-gpu/" % args.infer_gpu_num + + if(os.path.exists(saved_dir) == False): + os.makedirs(saved_dir) + + t_gpu_num = args.trained_gpu_num + i_gpu_num = args.infer_gpu_num + assert(i_gpu_num % t_gpu_num == 0) + + factor = (int)(i_gpu_num / t_gpu_num) + # load position_embedding from rank 0 + # model = torch.load(ckpt_name) + print(f'load model from {args.in_file}') + # model = LlamaForCausalLM.from_pretrained(args.in_file, device_map='auto') + config = AutoConfig.from_pretrained(args.in_file) + # num_layers = 3 + # config.num_hidden_layers = num_layers + print(config) + state_dict = {} + for f in os.listdir(args.in_file): + if not f.endswith('.bin'): + continue + w = torch.load(os.path.join(args.in_file, f), map_location='cpu') + keys = list(w.keys()) + for k in keys: + if 'model.layers.' 
not in k: + continue + l = int(k.split('.')[2]) + if l < config.num_hidden_layers: + continue + del w[k] + state_dict.update(w) + + model = LlamaForCausalLM.from_pretrained(None, config=config, state_dict=state_dict) + hf_config = vars(model.config) + print(f"hf_config: {hf_config}") + + print("named parameters:") + for name, param in model.named_parameters(): + print(f"- {name}") + + hidden_size = hf_config["hidden_size"] + head_num = hf_config["num_attention_heads"] + kv_head_num = hf_config["num_key_value_heads"] + head_size = hidden_size // head_num + # num_layers = hf_config["num_hidden_layers"] + + + np_weight_data_type = get_weight_data_type(args.weight_data_type) + + try: + model_name = args.model_name + config = configparser.ConfigParser() + config['llama'] = {} + config['llama']['model_name'] = model_name + config['llama']["head_num"] = str(head_num) + config['llama']["kv_head_num"] = str(kv_head_num) + config['llama']["size_per_head"] = str(head_size) + config['llama']["inter_size"] = str(hf_config["intermediate_size"]) + config['llama']["num_layer"] = str(num_layers) + config['llama']["rotary_embedding"] = str(head_size) + config['llama']['layernorm_eps'] = str(hf_config["rms_norm_eps"]) + config['llama']["vocab_size"] = str(hf_config["vocab_size"]) + config['llama']["start_id"] = str(hf_config["bos_token_id"]) + config['llama']["end_id"] = str(hf_config["eos_token_id"]) + config['llama']["weight_data_type"] = args.weight_data_type + + with open((Path(saved_dir) / f"config.ini").as_posix(), 'w') as configfile: + config.write(configfile) + except Exception as e: + print(f"Fail to save the config in config.ini.") + print(e) + + param_to_weights = lambda param: param.detach().cpu().numpy().astype(np_weight_data_type) + + # layer-wise weights, example: + # - model.layers.0.self_attn.q_proj.weight + # - model.layers.0.self_attn.k_proj.weight + # - model.layers.0.self_attn.v_proj.weight + # - model.layers.0.self_attn.o_proj.weight + # - model.layers.0.mlp.gate_proj.weight + # - model.layers.0.mlp.down_proj.weight + # - model.layers.0.mlp.up_proj.weight + # - model.layers.0.input_layernorm.weight + # - model.layers.0.post_attention_layernorm.weight + for l in range(num_layers): + print(f"converting layer {l}") + # first merge QKV into a single weight + # concat direct to FT shape: [hidden_size, 3, head_num, head_size] + # copied from huggingface_gptj_ckpt_convert.py + # qkv_weights = np.stack([ + # param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.q_proj.weight']), + # param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.k_proj.weight']), + # param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.v_proj.weight']), + # ]) + # qkv_weights = np.transpose(qkv_weights, (2, 0, 1)) + q_proj = param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.q_proj.weight']) + k_proj = param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.k_proj.weight']) + v_proj = param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.v_proj.weight']) + q_proj = np.split(q_proj, factor, axis=0) + k_proj = np.split(k_proj, factor, axis=0) + v_proj = np.split(v_proj, factor, axis=0) + for j in range(factor): + qkv_weights = np.concatenate((q_proj[j], k_proj[j], v_proj[j]), axis=0) + print(qkv_weights.shape) + # qkv_weights = np.transpose(qkv_weights, (2, 0, 1)) + qkv_weights = np.transpose(qkv_weights) + qkv_weights_base_name = f'model.layers.{l}.attention.query_key_value.weight' + saved_path = saved_dir + "/" + qkv_weights_base_name + ".%d.bin" % j + 
qkv_weights.tofile(saved_path) + # qkv_weights = np.concatenate(( + # param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.q_proj.weight']), + # param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.k_proj.weight']), + # param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.v_proj.weight']), + # ), axis=0) + # print(qkv_weights.shape) + # # qkv_weights = np.transpose(qkv_weights, (2, 0, 1)) + # qkv_weights = np.transpose(qkv_weights) + # qkv_weights_base_name = f'model.layers.{l}.attention.query_key_value.weight' + # split_and_convert_process(saved_dir, factor, qkv_weights_base_name, qkv_weights) + + # attention dense + o_weight = param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.o_proj.weight']).T + o_weight_base_name = f'model.layers.{l}.attention.dense.weight' + split_and_convert_process(saved_dir, factor, o_weight_base_name, o_weight) + + # MLP + mlp_down_weight = param_to_weights(model.state_dict()[f'model.layers.{l}.mlp.down_proj.weight']).T + mlp_down_base_name = f'model.layers.{l}.mlp.down_proj.weight' + split_and_convert_process(saved_dir, factor, mlp_down_base_name, mlp_down_weight) + + mlp_gate_weight = param_to_weights(model.state_dict()[f'model.layers.{l}.mlp.gate_proj.weight']).T + mlp_gate_base_name = f'model.layers.{l}.mlp.gate_proj.weight' + split_and_convert_process(saved_dir, factor, mlp_gate_base_name, mlp_gate_weight) + + mlp_up_weight = param_to_weights(model.state_dict()[f'model.layers.{l}.mlp.up_proj.weight']).T + mlp_up_base_name = f'model.layers.{l}.mlp.up_proj.weight' + split_and_convert_process(saved_dir, factor, mlp_up_base_name, mlp_up_weight) + + # LayerNorm + input_ln_weight = param_to_weights(model.state_dict()[f'model.layers.{l}.input_layernorm.weight']) + input_ln_base_name = f'model.layers.{l}.input_layernorm.weight' + split_and_convert_process(saved_dir, factor, input_ln_base_name, input_ln_weight) + + post_attn_ln_weight = param_to_weights(model.state_dict()[f'model.layers.{l}.post_attention_layernorm.weight']) + post_attn_ln_base_name = f'model.layers.{l}.post_attention_layernorm.weight' + split_and_convert_process(saved_dir, factor, post_attn_ln_base_name, post_attn_ln_weight) + + print(f"done layer {l}") + + + # final common weights + for name, param in model.named_parameters(): + if name == 'model.embed_tokens.weight': + param.detach().cpu().numpy().astype(np_weight_data_type).tofile(saved_dir + "model.wte.weight.bin") + elif name == 'model.norm.weight': + param.detach().cpu().numpy().astype(np_weight_data_type).tofile(saved_dir + "model.final_layernorm.weight.bin") + elif name == 'lm_head.weight': + param.detach().cpu().numpy().astype(np_weight_data_type).tofile(saved_dir + "model.lm_head.weight.bin") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument('-saved_dir', '-o', type=str, help='file name of output file', required=True) + parser.add_argument('-in_file', '-i', type=str, help='file name of input checkpoint file', required=True) + parser.add_argument('-trained_gpu_num', '-t_g', type=int, help='How many gpus for inference', default=1) + parser.add_argument('-infer_gpu_num', '-i_g', type=int, help='How many gpus for inference', required=True) + parser.add_argument("-weight_data_type", type=str, default="fp32", choices=["fp32", "fp16", "bf16"]) + parser.add_argument('-model_name', '-m_n', type=str, help='model name', required=True) + + args = parser.parse_args() + print("\n=============== Argument 
===============") + for key in vars(args): + print("{}: {}".format(key, vars(args)[key])) + print("========================================") + + split_and_convert(args) diff --git a/examples/cpp/llama/huggingface_llama_convert2.py b/examples/cpp/llama/huggingface_llama_convert2.py new file mode 100644 index 000000000..95c8a1af8 --- /dev/null +++ b/examples/cpp/llama/huggingface_llama_convert2.py @@ -0,0 +1,259 @@ +# Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import configparser +import numpy as np +from pathlib import Path + +import torch +import os +from transformers import LlamaForCausalLM, AutoConfig +# using numpy extension: https://github.com/GreenWaves-Technologies/bfloat16 +# install the library with `pip install bfloat16` +from bfloat16 import bfloat16 + +def get_weight_data_type(data_type): + if data_type == "fp32": + return np.float32 + elif data_type == "fp16": + return np.float16 + elif data_type == "bf16": + return bfloat16 + else: + assert False, f"Invalid weight data type {data_type}" + + +def split_and_convert_process(saved_dir, factor, key, val): + if key.find("input_layernorm.weight") != -1 or key.find("post_attention_layernorm.weight") != -1: + # shared weights, only need to convert the weights of rank 0 + saved_path = saved_dir + "/" + key + ".bin" + val.tofile(saved_path) + elif key.find("attention.dense.weight") != -1 or key.find("mlp.down_proj.weight") != -1: + split_vals = np.split(val, factor, axis=0) + for j in range(factor): + saved_path = saved_dir + "/" + key + ".%d.bin" % j + split_vals[j].tofile(saved_path) + elif key.find("mlp.gate_proj.weight") != -1 or key.find("mlp.up_proj.weight") != -1: + split_vals = np.split(val, factor, axis=-1) + for j in range(factor): + saved_path = saved_dir + "/" + key + ".%d.bin" % j + split_vals[j].tofile(saved_path) + elif key.find("attention.query_key_value.weight") != -1: + split_vals = np.split(val, factor, axis=-1) + for j in range(factor): + saved_path = saved_dir + "/" + key + ".%d.bin" % j + split_vals[j].tofile(saved_path) + else: + print("[ERROR] cannot find key '{}'".format(key)) + +def split_and_convert(args): + saved_dir = args.saved_dir + "/%d-gpu/" % args.infer_gpu_num + + if(os.path.exists(saved_dir) == False): + os.makedirs(saved_dir) + + t_gpu_num = args.trained_gpu_num + i_gpu_num = args.infer_gpu_num + assert(i_gpu_num % t_gpu_num == 0) + + factor = (int)(i_gpu_num / t_gpu_num) + print(f'load model from {args.in_file}') + # model = LlamaForCausalLM.from_pretrained(args.in_file, device_map='auto') + config = AutoConfig.from_pretrained(args.in_file) + # num_layers = 3 + # config.num_hidden_layers = num_layers + + hf_config = vars(config) + print(f"hf_config: {hf_config}") + + hidden_size = hf_config["hidden_size"] + head_num = hf_config["num_attention_heads"] + kv_head_num = hf_config["num_key_value_heads"] + head_size = hidden_size // head_num + # num_layers = hf_config["num_hidden_layers"] + + + 
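+    # Resolve the numpy dtype used for the exported weight files (bf16 relies
+    # on the optional `bfloat16` package imported above), then write the model
+    # geometry into config.ini so the C++ runtime can read it at load time.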
np_weight_data_type = get_weight_data_type(args.weight_data_type) + + try: + model_name = args.model_name + config = configparser.ConfigParser() + config['llama'] = {} + config['llama']['model_name'] = model_name + config['llama']["head_num"] = str(head_num) + config['llama']["kv_head_num"] = str(kv_head_num) + config['llama']["size_per_head"] = str(head_size) + config['llama']["inter_size"] = str(hf_config["intermediate_size"]) + config['llama']["num_layer"] = str(hf_config["num_hidden_layers"]) + config['llama']["rotary_embedding"] = str(head_size) + config['llama']['layernorm_eps'] = str(hf_config["rms_norm_eps"]) + config['llama']["vocab_size"] = str(hf_config["vocab_size"]) + config['llama']["start_id"] = str(hf_config["bos_token_id"]) + config['llama']["end_id"] = str(hf_config["eos_token_id"]) + config['llama']["weight_data_type"] = args.weight_data_type + + with open((Path(saved_dir) / f"config.ini").as_posix(), 'w') as configfile: + config.write(configfile) + except Exception as e: + print(f"Fail to save the config in config.ini.") + print(e) + + param_to_weights = lambda param: param.detach().cpu().float().numpy().astype(np_weight_data_type) + + def get_param(key, cache, loaded): + if key in cache: + return param_to_weights(cache[key]) + if key in loaded: + return param_to_weights(loaded[key]) + return None + + def clear_param(key, cache, loaded): + if key in cache: + del cache[key] + if key in loaded: + del loaded[key] + + def try_dump(key, cache, loaded, save_name, saved_dir, factor, transpose=True): + weight = get_param(key, cache, loaded) + if weight is None: + return + if transpose: + weight = weight.T + split_and_convert_process(saved_dir, factor, save_name, weight) + clear_param(key, state_dict, w) + # layer-wise weights, example: + # - model.layers.0.self_attn.q_proj.weight + # - model.layers.0.self_attn.k_proj.weight + # - model.layers.0.self_attn.v_proj.weight + # - model.layers.0.self_attn.o_proj.weight + # - model.layers.0.mlp.gate_proj.weight + # - model.layers.0.mlp.down_proj.weight + # - model.layers.0.mlp.up_proj.weight + # - model.layers.0.input_layernorm.weight + # - model.layers.0.post_attention_layernorm.weight + state_dict = {} + for f in os.listdir(args.in_file): + if not f.endswith('.bin'): + continue + f = os.path.join(args.in_file, f) + print(f'processing {f}') + w = torch.load(f, map_location='cpu') + for l in range(hf_config["num_hidden_layers"]): + # first merge QKV into a single weight + # concat direct to FT shape: [hidden_size, 3, head_num, head_size] + # copied from huggingface_gptj_ckpt_convert.py + # qkv_weights = np.stack([ + # param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.q_proj.weight']), + # param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.k_proj.weight']), + # param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.v_proj.weight']), + # ]) + # qkv_weights = np.transpose(qkv_weights, (2, 0, 1)) + q_key = f'model.layers.{l}.self_attn.q_proj.weight' + k_key = f'model.layers.{l}.self_attn.k_proj.weight' + v_key = f'model.layers.{l}.self_attn.v_proj.weight' + q_proj = get_param(q_key, state_dict, w) + k_proj = get_param(k_key, state_dict, w) + v_proj = get_param(v_key, state_dict, w) + + if q_proj is not None and k_proj is not None and v_proj is not None: + q_proj = np.split(q_proj, factor, axis=0) + k_proj = np.split(k_proj, factor, axis=0) + v_proj = np.split(v_proj, factor, axis=0) + for j in range(factor): + qkv_weights = np.concatenate((q_proj[j], k_proj[j], v_proj[j]), axis=0) + qkv_weights = 
np.transpose(qkv_weights) + qkv_weights_base_name = f'model.layers.{l}.attention.query_key_value.weight' + saved_path = saved_dir + "/" + qkv_weights_base_name + ".%d.bin" % j + qkv_weights.tofile(saved_path) + clear_param(q_key, state_dict, w) + clear_param(k_key, state_dict, w) + clear_param(v_key, state_dict, w) + + # attention dense + try_dump(key=f'model.layers.{l}.self_attn.o_proj.weight', + cache=state_dict, + loaded=w, + save_name=f'model.layers.{l}.attention.dense.weight', + saved_dir=saved_dir, + factor=factor) + + # MLP + try_dump(key=f'model.layers.{l}.mlp.down_proj.weight', + cache=state_dict, + loaded=w, + save_name=f'model.layers.{l}.mlp.down_proj.weight', + saved_dir=saved_dir, + factor=factor) + try_dump(key=f'model.layers.{l}.mlp.gate_proj.weight', + cache=state_dict, + loaded=w, + save_name=f'model.layers.{l}.mlp.gate_proj.weight', + saved_dir=saved_dir, + factor=factor) + try_dump(key=f'model.layers.{l}.mlp.up_proj.weight', + cache=state_dict, + loaded=w, + save_name=f'model.layers.{l}.mlp.up_proj.weight', + saved_dir=saved_dir, + factor=factor) + + # LayerNorm + try_dump(key=f'model.layers.{l}.input_layernorm.weight', + cache=state_dict, + loaded=w, + save_name=f'model.layers.{l}.input_layernorm.weight', + saved_dir=saved_dir, + factor=factor, + transpose=False) + try_dump(key=f'model.layers.{l}.post_attention_layernorm.weight', + cache=state_dict, + loaded=w, + save_name=f'model.layers.{l}.post_attention_layernorm.weight', + saved_dir=saved_dir, + factor=factor, + transpose=False) + to_del = [] + for name, param in w.items(): + if name == 'model.embed_tokens.weight': + param.detach().cpu().float().numpy().astype(np_weight_data_type).tofile(saved_dir + "model.wte.weight.bin") + elif name == 'model.norm.weight': + param.detach().cpu().float().numpy().astype(np_weight_data_type).tofile(saved_dir + "model.final_layernorm.weight.bin") + elif name == 'lm_head.weight': + param.detach().cpu().float().numpy().astype(np_weight_data_type).tofile(saved_dir + "model.lm_head.weight.bin") + else: + continue + to_del.append(param) + # for k in to_del: + # del w[k] + print(w.keys()) + state_dict.update(w) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument('-saved_dir', '-o', type=str, help='file name of output file', required=True) + parser.add_argument('-in_file', '-i', type=str, help='file name of input checkpoint file', required=True) + parser.add_argument('-trained_gpu_num', '-t_g', type=int, help='How many gpus for inference', default=1) + parser.add_argument('-infer_gpu_num', '-i_g', type=int, help='How many gpus for inference', required=True) + parser.add_argument("-weight_data_type", type=str, default="fp32", choices=["fp32", "fp16", "bf16"]) + parser.add_argument('-model_name', '-m_n', type=str, help='model name', required=True) + + args = parser.parse_args() + print("\n=============== Argument ===============") + for key in vars(args): + print("{}: {}".format(key, vars(args)[key])) + print("========================================") + + split_and_convert(args) diff --git a/examples/cpp/llama/llama_config.ini b/examples/cpp/llama/llama_config.ini new file mode 100644 index 000000000..38501e362 --- /dev/null +++ b/examples/cpp/llama/llama_config.ini @@ -0,0 +1,34 @@ +[ft_instance_hyperparameter] +data_type=fp16 +enable_custom_all_reduce=0 + +tensor_para_size=4 +pipeline_para_size=1 + +model_name=llama_7b +model_dir=/model/llama2-70B-hf-ft/4-gpu + +[request] +beam_width=1 # beam width for beam 
search +top_k=1 ; k value for top k sampling +top_p=0.0 ; p value for top p sampling +temperature=1.0 ; Use for sampling +repetition_penalty=1.0 ; Use for sampling +presence_penalty=0.0 ; Only one of repetition_penalty and presence_penalty are allowed. +len_penalty=0.0 +beam_search_diversity_rate=0.0 +request_batch_size=8 # determine by the request +request_output_len=32 # determine by the request + +[llama_7b] +head_num = 64 +kv_head_num = 8 +size_per_head = 128 +inter_size = 28672 +num_layer = 3 +rotary_embedding = 128 +layernorm_eps = 1e-05 +vocab_size = 32000 +start_id = 1 +end_id = 2 +weight_data_type = fp16 diff --git a/examples/cpp/llama/llama_example.cc b/examples/cpp/llama/llama_example.cc new file mode 100644 index 000000000..dca47d169 --- /dev/null +++ b/examples/cpp/llama/llama_example.cc @@ -0,0 +1,533 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "3rdparty/INIReader.h" +#include "examples/cpp/multi_gpu_gpt/gpt_example_utils.h" +#include "src/fastertransformer/models/llama/Llama.h" +#include "src/fastertransformer/utils/mpi_utils.h" +#include "src/fastertransformer/utils/nccl_utils.h" +#include "src/fastertransformer/utils/nvtx_utils.h" +#include "src/fastertransformer/utils/word_list.h" + +#include +#include +#include +#include +#include +#include + +using namespace fastertransformer; + +template +void llama_example(const INIReader reader); + +int main(int argc, char* argv[]) +{ + fastertransformer::mpi::initialize(&argc, &argv); + srand(0); + + std::string ini_name; + if (argc == 2) { + ini_name = std::string(argv[1]); + } + else { + ini_name = "/data/FasterTransformer/examples/cpp/llama/llama_config.ini"; + } + + INIReader reader = INIReader(ini_name); + if (reader.ParseError() < 0) { + std::cout << "[ERROR] Can't load '" << ini_name << "'\n"; + return -1; + } + const std::string data_type = reader.Get("ft_instance_hyperparameter", "data_type"); + + if (data_type == "fp32") { + llama_example(reader); + } + else if (data_type == "fp16") { + llama_example(reader); + } +#ifdef ENABLE_BF16 + else if (data_type == "bf16") { + llama_example<__nv_bfloat16>(reader); + } +#endif + else { + FT_LOG_ERROR("is_fp16 should be 0 (use float) or 1 (use half)."); + return -1; + } + mpi::finalize(); + return 0; +} + +template +void llama_example(const INIReader reader) +{ + const std::string model_name = reader.Get("ft_instance_hyperparameter", "model_name"); + std::string model_dir = std::string(reader.Get("ft_instance_hyperparameter", "model_dir")); + + int tensor_para_size = reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"); + int pipeline_para_size = reader.GetInteger("ft_instance_hyperparameter", "pipeline_para_size"); + int int8_mode = reader.GetInteger("ft_instance_hyperparameter", "int8_mode", 0); + + const size_t head_num = reader.GetInteger(model_name, "head_num"); + const size_t kv_head_num = reader.GetInteger(model_name, "kv_head_num"); + const size_t size_per_head 
= reader.GetInteger(model_name, "size_per_head"); + const size_t vocab_size = reader.GetInteger(model_name, "vocab_size"); + const size_t decoder_layers = reader.GetInteger(model_name, "num_layer"); + const size_t rotary_embedding_dim = reader.GetInteger(model_name, "rotary_embedding"); + const float layernorm_eps = reader.GetFloat(model_name, "layernorm_eps"); + const int start_id = reader.GetInteger(model_name, "start_id"); + const int end_id = reader.GetInteger(model_name, "end_id"); + + const size_t hidden_units = head_num * size_per_head; + const size_t inter_size = reader.GetInteger(model_name, "inter_size"); + + const size_t beam_width = reader.GetInteger("request", "beam_width"); + const uint top_k = (uint)reader.GetInteger("request", "top_k"); + const float top_p = reader.GetFloat("request", "top_p"); + const float temperature = reader.GetFloat("request", "temperature"); + const float repetition_penalty = reader.GetFloat("request", "repetition_penalty", 1.0f); + const float presence_penalty = reader.GetFloat("request", "presence_penalty", 0.0f); + const float len_penalty = reader.GetFloat("request", "len_penalty"); + const float beam_search_diversity_rate = reader.GetFloat("request", "beam_search_diversity_rate"); + const int min_length = reader.GetInteger("request", "min_length", 0); + const size_t request_batch_size = reader.GetInteger("request", "request_batch_size"); + // The length of tokens we hope this model to generate + const int request_output_len = reader.GetInteger("request", "request_output_len"); + + FT_CHECK(head_num % tensor_para_size == 0); + FT_CHECK(decoder_layers % pipeline_para_size == 0); + FT_CHECK_WITH_INFO( + repetition_penalty == 1.0f || presence_penalty == 0.0f, + fmtstr("Found ambiguous parameters repetition_penalty (%f) and presence_penalty (%f) " + "which are mutually exclusive. 
Please remove one of repetition_penalty or presence_penalty " + "or set to a default value.", + repetition_penalty, + presence_penalty)); + + // Prepare the parallelism parameters + int rank = mpi::getCommWorldRank(); + int world_size = mpi::getCommWorldSize(); + // world_size = 4; + if (rank == 0) { + printf("Total ranks: %d.\n", world_size); + } + int device, device_count; + check_cuda_error(cudaGetDeviceCount(&device_count)); + check_cuda_error(cudaSetDevice(rank % device_count)); + check_cuda_error(cudaGetDevice(&device)); + + struct cudaDeviceProp prop; + check_cuda_error(cudaGetDeviceProperties(&prop, device)); + printf("Device %s\n", prop.name); + + printf("P%d is running with GPU #%d.\n", rank, device); + if (tensor_para_size * pipeline_para_size != world_size) { + if (world_size % pipeline_para_size) { + printf("[ERROR] tensor_para_size * pipeline_para_size should equal to world_size \n"); + exit(-1); + } + tensor_para_size = world_size / pipeline_para_size; + printf("[INFO] Setting tensor_para_size to %d \n", tensor_para_size); + } + + const int layers_per_group = decoder_layers / pipeline_para_size; + if (layers_per_group * pipeline_para_size != (int)decoder_layers) { + printf("[ERROR] layers_per_group (%d) * pipeline_para_size (%d) should equal to decoder_layers (%ld) \n", + layers_per_group, + pipeline_para_size, + decoder_layers); + exit(-1); + } + + // assume gpu_num = k * n, + // tensor parallelism group size is n + // pipeline parallelism group size is k + NcclParam tensor_para; + NcclParam pipeline_para; + ftNcclInitialize(tensor_para, pipeline_para, tensor_para_size, pipeline_para_size); + + // Handle bad_words dictionary + std::vector bad_words; + read_word_list("/data/FasterTransformer/examples/cpp/llama/bad_words.csv", bad_words); + + int* d_bad_words = nullptr; + deviceMalloc(&d_bad_words, bad_words.size(), false); + cudaH2Dcpy(d_bad_words, bad_words.data(), bad_words.size()); + + // Handle stop_words dictionary + std::vector stop_words; + read_word_list("/data/FasterTransformer/examples/cpp/llama/stop_words.csv", stop_words); + + const size_t stop_words_len = stop_words.size() / 2; + // Tile with same dict for each element + std::vector tiled_stop_words; + for (int i = 0; i < request_batch_size; i++) { + tiled_stop_words.insert(tiled_stop_words.end(), stop_words.begin(), stop_words.end()); + } + + + int* d_stop_words = nullptr; + deviceMalloc(&d_stop_words, tiled_stop_words.size(), false); + cudaH2Dcpy(d_stop_words, tiled_stop_words.data(), tiled_stop_words.size()); + + // Read ids of request from file. + size_t max_input_len = -1; + std::vector v_start_lengths; + std::vector v_start_ids; + read_start_ids(request_batch_size, + &v_start_lengths, + &v_start_ids, + max_input_len, + end_id, + 1, + "/data/FasterTransformer/examples/cpp/llama/start_ids.csv"); + + + int* d_input_ids; + int* d_input_lengths; + if (max_input_len == 0) { + // unconditional case, no input ids, so do nothing. + d_input_ids = nullptr; + d_input_lengths = nullptr; + } + else { + // conditional case. 
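+        // Copy the tokenized prompts read from start_ids.csv to the GPU. The
+        // batch is laid out as [request_batch_size, max_input_len]; shorter
+        // requests were padded up to max_input_len by read_start_ids() above.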
+ deviceMalloc(&d_input_ids, request_batch_size * max_input_len, false); + deviceMalloc(&d_input_lengths, request_batch_size, false); + cudaH2Dcpy(d_input_ids, v_start_ids.data(), request_batch_size * max_input_len); + cudaH2Dcpy(d_input_lengths, v_start_lengths.data(), request_batch_size); + } + std::vector start_ids(request_batch_size, start_id); + std::vector end_ids(request_batch_size, end_id); + + // Prompt Learning Configurations + // NOTE: if you don't need prefix prompts, remember to set max_prefix_len to 0 and others to nullptr + int prompt_learning_start_id = reader.GetInteger(model_name, "prompt_learning_start_id", end_id + 1); + fastertransformer::PromptLearningType prompt_learning_type = + static_cast(reader.GetInteger(model_name, "prompt_learning_type", 0)); + + // NOTE: specify task names, take name id, prompt length in order to load those prompt learning tables. + // NOTE: Please make sure task ids are continuous and start from 0 + // for example: + // std::map> prefix_prompt_table_pair{{"no_prompt", {0, 0}}, + // {"prompt_1", {1, 1}}, + // {"prompt_2", {2, 2}}, + // {"prompt_3", {3, 3}}, + // {"prompt_4", {4, 4}}, + // {"prompt_5", {5, 5}}}; + + std::map> prefix_prompt_table_pair; + + // NOTE: get prompt table pairs from configuration files + const int num_tasks = reader.GetInteger(model_name, "num_tasks", 0); + for (int task_name_id = 0; task_name_id < num_tasks; task_name_id++) { + std::string config_task_name = model_name + "_task_" + std::to_string(task_name_id); + std::string task_name = reader.Get(config_task_name, "task_name"); + const int prompt_length = reader.GetInteger(config_task_name, "prompt_length", 0); + prefix_prompt_table_pair.insert({task_name, {task_name_id, prompt_length}}); + } + + // NOTE: task_name_ids for each sequence in one batch + // Each sequence can have different prompt learning task ids + std::vector prefix_prompt_task_ids(request_batch_size, 0); + + // Set different task ids + for (int i = 0; i < request_batch_size; i++) { + prefix_prompt_task_ids[i] = (num_tasks > 0) ? 
i % num_tasks : 0; + } + + const int total_output_len = max_input_len + request_output_len; + + cudaStream_t stream; + cublasHandle_t cublas_handle; + cublasLtHandle_t cublaslt_handle; + cudaStreamCreate(&stream); + cublasCreate(&cublas_handle); + cublasLtCreate(&cublaslt_handle); + cublasSetStream(cublas_handle, stream); + cublasAlgoMap* cublas_algo_map = new cublasAlgoMap("gemm_config.in"); + + Allocator allocator(getDevice()); + + std::mutex* cublas_wrapper_mutex = new std::mutex(); + cublasMMWrapper cublas_wrapper = + cublasMMWrapper(cublas_handle, cublaslt_handle, stream, cublas_algo_map, cublas_wrapper_mutex, &allocator); + if (std::is_same::value) { + cublas_wrapper.setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F); + } +#ifdef ENABLE_BF16 + else if (std::is_same::value) { + cublas_wrapper.setBF16GemmConfig(); + } +#endif + else if (std::is_same::value) { + cublas_wrapper.setFP32GemmConfig(); + } + + const bool use_gptj_residual = false; + printf("kv_head_num: %d\n", kv_head_num); + fastertransformer::LlamaWeight gpt_weights(head_num, + kv_head_num, + size_per_head, + inter_size, + vocab_size, + decoder_layers, + 0, // max_seq_len, deprecated + tensor_para.world_size_, + tensor_para.rank_, + pipeline_para.world_size_, + pipeline_para.rank_, + use_gptj_residual, + int8_mode, + prompt_learning_type, + prefix_prompt_table_pair); + + gpt_weights.loadModel(model_dir); + unsigned long long random_seed; + if (rank == 0) { + random_seed = (unsigned long long)(0); + } + if (world_size > 1) { + mpi::bcast(&random_seed, 1, mpi::MPI_TYPE_UNSIGNED_LONG_LONG, 0, mpi::COMM_WORLD); + } + + AttentionType attention_type = getAttentionType(size_per_head, + getSMVersion(), + true, // remove_padding + 0, // gpt supports any-seq-length fmha + true, // is_fuse + false, // with_relative_position_bias + true); // causal_mask + + Llama gpt = Llama(head_num, + kv_head_num, + size_per_head, + inter_size, + decoder_layers, + vocab_size, + rotary_embedding_dim, + layernorm_eps, + start_id, + end_id, + prompt_learning_start_id, + prompt_learning_type, + use_gptj_residual, + 0.0f, + top_k, + top_p, + random_seed, + temperature, + len_penalty, + repetition_penalty, + tensor_para, + pipeline_para, + stream, + &cublas_wrapper, + &allocator, + false, + &prop, + attention_type, + int8_mode, + nullptr, + 0, + 1.0f); + + int* d_output_ids; + int* d_sequence_lengths; + + + deviceMalloc(&d_output_ids, request_batch_size * beam_width * total_output_len, false); + deviceMalloc(&d_sequence_lengths, request_batch_size * beam_width, false); + + std::vector output_seq_len(request_batch_size, total_output_len); + std::unordered_map input_tensors = std::unordered_map{ + {"input_ids", + Tensor{MEMORY_GPU, TYPE_INT32, std::vector{request_batch_size, (size_t)max_input_len}, d_input_ids}}, + {"input_lengths", Tensor{MEMORY_GPU, TYPE_INT32, std::vector{request_batch_size}, d_input_lengths}}, + // NOTE: if you need prefix prompts, remember to add prefix_prompt_task_ids here + // {"prompt_learning_task_name_ids", Tensor{MEMORY_CPU, TYPE_INT32, std::vector{request_batch_size}, + // prefix_prompt_task_ids.data()}}, + {"output_seq_len", + Tensor{MEMORY_CPU, TYPE_UINT32, std::vector{request_batch_size}, output_seq_len.data()}}, + {"bad_words_list", Tensor{MEMORY_GPU, TYPE_INT32, {2, bad_words.size() / 2}, d_bad_words}}, + {"stop_words_list", Tensor{MEMORY_GPU, TYPE_INT32, {request_batch_size, 2, stop_words_len}, d_stop_words}}, + {"temperature", Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &temperature}}, + {"len_penalty", 
Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &len_penalty}}, + {"min_length", Tensor{MEMORY_CPU, TYPE_INT32, std::vector{1}, &min_length}}, + {"start_id", Tensor{MEMORY_CPU, TYPE_INT32, std::vector{request_batch_size}, start_ids.data()}}, + {"end_id", Tensor{MEMORY_CPU, TYPE_INT32, std::vector{request_batch_size}, end_ids.data()}}}; + + if (repetition_penalty != 1.0f) { + input_tensors.insert( + {"repetition_penalty", Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &repetition_penalty}}); + } + if (presence_penalty != 0.0f) { + input_tensors.insert( + {"presence_penalty", Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &presence_penalty}}); + } + + if (num_tasks > 0) { + // Prefix Prompt Task Name Ids here + input_tensors.insert( + {"prompt_learning_task_name_ids", + Tensor{MEMORY_CPU, TYPE_INT32, std::vector{request_batch_size}, prefix_prompt_task_ids.data()}}); + } + + if (top_k == 0 && top_p == 0.0f) { + FT_CHECK(beam_width > 1); + input_tensors.insert({"beam_search_diversity_rate", + Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &beam_search_diversity_rate}}); + } + else { + input_tensors.insert({"random_seed", Tensor{MEMORY_CPU, TYPE_UINT64, std::vector{1}, &random_seed}}); + if (top_p != 0.0f) { + input_tensors.insert({"runtime_top_p", Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &top_p}}); + } + if (top_k != 0) { + input_tensors.insert({"runtime_top_k", Tensor{MEMORY_CPU, TYPE_UINT32, std::vector{1}, &top_k}}); + } + } + + std::unordered_map output_tensors = std::unordered_map{ + {"output_ids", + Tensor{MEMORY_GPU, + TYPE_INT32, + std::vector{request_batch_size, beam_width, (size_t)total_output_len}, + d_output_ids}}, + {"sequence_length", + Tensor{MEMORY_GPU, TYPE_INT32, std::vector{request_batch_size, beam_width}, d_sequence_lengths}}, + {"output_log_probs", + Tensor{MEMORY_GPU, + TYPE_FP32, + std::vector{(size_t)request_output_len, request_batch_size, beam_width}, + nullptr}}}; + + print_mem_usage(); + + int ite = 1; + cudaDeviceSynchronize(); + mpi::barrier(); + + cudaProfilerStart(); + // warm up + ite = 1; + ft_nvtx::setScope("warmup_time"); + PUSH_RANGE("warmup time") + + for (int i = 0; i < ite; ++i) { + gpt.forward(&output_tensors, &input_tensors, &gpt_weights); + } + + cudaDeviceSynchronize(); + mpi::barrier(); + + POP_RANGE; + ft_nvtx::resetScope(); + + + if (rank == 0) { + + std::string fName = "out"; + auto outFile = std::ofstream(fName, std::ios::out); + if (!outFile.is_open()) { + printf("[WARNING] Cannot write results into output file %s \n", fName.c_str()); + } + else { + size_t outCount = total_output_len * request_batch_size * beam_width; + int* hBuf = new int[outCount]; + + cudaD2Hcpy(hBuf, d_output_ids, outCount); + + { + std::cout << "Writing " << outCount << " elements\n"; + int zeroCount = 0; + for (size_t i = 0; i < outCount; i++) { + if (hBuf[i] == int(0)) { + zeroCount++; + } + outFile << hBuf[i] << " "; + if ((i + 1) % (total_output_len) == 0) { + outFile << std::endl; + } + printf("%5d ", hBuf[i]); + // if (i < 10) { + // printf("%5d ", hBuf[i]); + // } + if ((i + 1) % (total_output_len) == 0) { + std::cout << std::endl; + } + } + std::cout << std::endl << "zeroCount = " << zeroCount << std::endl; + } + delete[] hBuf; + } + } + return; + // test time + struct timeval start, end; + mpi::barrier(); + cudaDeviceSynchronize(); + gettimeofday(&start, NULL); + + ft_nvtx::setScope("total_time"); + PUSH_RANGE("total time") + for (int i = 0; i < ite; ++i) { + gpt.forward(&output_tensors, &input_tensors, &gpt_weights); + } + cudaDeviceSynchronize(); + 
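+    // Make sure every rank has finished setup before the profiled warm-up
+    // iterations below, so the timings are not skewed by stragglers.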
mpi::barrier(); + + POP_RANGE; + ft_nvtx::resetScope(); + gettimeofday(&end, NULL); + + cudaProfilerStop(); + + printf("[INFO] request_batch_size %ld beam_width %ld head_num %ld size_per_head %ld total_output_len %d" + " decoder_layers %ld vocab_size %ld FT-CPP-decoding-beamsearch-time %.2f ms\n", + request_batch_size, + beam_width, + head_num, + size_per_head, + total_output_len, + decoder_layers, + vocab_size, + ((end.tv_sec - start.tv_sec) * 1000 + (end.tv_usec - start.tv_usec) * 0.001) / ite); + + ftNcclParamDestroy(tensor_para); + ftNcclParamDestroy(pipeline_para); + + delete cublas_algo_map; + delete cublas_wrapper_mutex; + + cudaFree(d_bad_words); + cudaFree(d_stop_words); + if (d_input_ids != nullptr) { + cudaFree(d_input_ids); + } + if (d_input_lengths != nullptr) { + cudaFree(d_input_lengths); + } + if (d_output_ids != nullptr) { + deviceFree(d_output_ids); + } + if (d_sequence_lengths != nullptr) { + deviceFree(d_sequence_lengths); + } + return; +} diff --git a/examples/cpp/llama/llama_triton_example.cc b/examples/cpp/llama/llama_triton_example.cc new file mode 100644 index 000000000..25d769e1b --- /dev/null +++ b/examples/cpp/llama/llama_triton_example.cc @@ -0,0 +1,448 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "3rdparty/INIReader.h" +#include "examples/cpp/multi_gpu_gpt/gpt_example_utils.h" +#include "src/fastertransformer/triton_backend/llama/LlamaTritonModel.h" +#include "src/fastertransformer/triton_backend/llama/LlamaTritonModelInstance.h" +#include "src/fastertransformer/utils/custom_ar_comm.h" +#include "src/fastertransformer/utils/mpi_utils.h" +#include "src/fastertransformer/utils/nccl_utils.h" +#include "src/fastertransformer/utils/nvtx_utils.h" +#include "src/fastertransformer/utils/word_list.h" + +#include +#include + +namespace ft = fastertransformer; + +struct RequestParam { + int beam_width; + int request_output_len; + float beam_search_diversity_rate; + uint runtime_top_k; + float runtime_top_p; + float temperature; + float len_penalty; + float repetition_penalty; + float presence_penalty; + int min_length; + unsigned long long int random_seed; + int start_id; + int end_id; +}; + +std::vector>> +broadCastRequest(const std::vector& v_start_ids, + const std::vector& v_start_lengths, + const std::vector& v_bad_words, + const int node_id, + const int gpu_count, + const RequestParam param, + std::vector* pointer_record) +{ + // broadcast the request to all nodes, and copy "gpu_count" copies on different gpu + int size_1 = v_start_ids.size(); + int size_2 = v_start_lengths.size(); + int size_bad_words = v_bad_words.size(); + ft::mpi::bcast(&size_1, 1, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + ft::mpi::bcast(&size_2, 1, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + ft::mpi::bcast(&size_bad_words, 1, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + + std::vector v_input_ids(size_1); + std::vector v_input_lengths(size_2); + std::vector v_input_bad_words(size_bad_words); + + if (node_id == 0) { + memcpy(v_input_ids.data(), v_start_ids.data(), size_1 * sizeof(int)); + memcpy(v_input_lengths.data(), v_start_lengths.data(), size_2 * sizeof(int)); + memcpy(v_input_bad_words.data(), v_bad_words.data(), size_bad_words * sizeof(int)); + } + ft::mpi::barrier(); + + int request_batch_size = size_2; + int max_input_len = size_1 / size_2; + + ft::mpi::bcast(v_input_ids.data(), size_1, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + ft::mpi::bcast(v_input_lengths.data(), size_2, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + ft::mpi::bcast(v_input_bad_words.data(), size_bad_words, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + + std::vector>> request_list; + for (int device_id = 0; device_id < gpu_count; device_id++) { + ft::check_cuda_error(cudaSetDevice(device_id)); + + int* d_input_ids; + int* d_input_lengths; + int* d_input_bad_words; + + if (max_input_len == 0) { + // unconditional case, no input ids, so do nothing. + d_input_ids = nullptr; + d_input_lengths = nullptr; + max_input_len = 0; + } + else { + // conditional case. 
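// Note on the two cases: max_input_len == 0 means no prompt tokens were provided, so
// generation is unconditional and the input id/length tensors stay null; otherwise the
// prompt ids and lengths broadcast above are copied onto every GPU on this node, one
// copy per device.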
+ ft::deviceMalloc(&d_input_ids, size_1, false); + ft::deviceMalloc(&d_input_lengths, size_2, false); + ft::cudaH2Dcpy(d_input_ids, v_input_ids.data(), size_1); + ft::cudaH2Dcpy(d_input_lengths, v_input_lengths.data(), size_2); + } + ft::deviceMalloc(&d_input_bad_words, size_bad_words, false); + ft::cudaH2Dcpy(d_input_bad_words, v_input_bad_words.data(), size_bad_words); + + uint32_t* request_output_len_ptr = (uint32_t*)malloc(request_batch_size * sizeof(uint32_t)); + for (int i = 0; i < request_batch_size; i++) { + request_output_len_ptr[i] = param.request_output_len; + } + + int* start_ids_ptr = (int*)malloc(request_batch_size * sizeof(int)); + int* end_ids_ptr = (int*)malloc(request_batch_size * sizeof(int)); + for (int i = 0; i < request_batch_size; i++) { + start_ids_ptr[i] = param.start_id; + end_ids_ptr[i] = param.end_id; + } + pointer_record->push_back(start_ids_ptr); + pointer_record->push_back(end_ids_ptr); + + request_list.push_back(std::shared_ptr>( + new std::unordered_map{ + {"input_ids", + triton::Tensor{triton::MEMORY_GPU, + triton::TYPE_INT32, + std::vector{(size_t)request_batch_size, (size_t)max_input_len}, + d_input_ids}}, + {"input_lengths", + triton::Tensor{triton::MEMORY_GPU, + triton::TYPE_INT32, + std::vector{(size_t)request_batch_size}, + d_input_lengths}}, + {"request_output_len", + triton::Tensor{triton::MEMORY_CPU, + triton::TYPE_INT32, + std::vector{(size_t)request_batch_size}, + request_output_len_ptr}}, + {"bad_words_list", + triton::Tensor{ + triton::MEMORY_GPU, triton::TYPE_INT32, {2, v_input_bad_words.size() / 2}, d_input_bad_words}}, + {"start_id", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_INT32, {(size_t)request_batch_size}, start_ids_ptr}}, + {"end_id", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_INT32, {(size_t)request_batch_size}, end_ids_ptr}}})); + + int* beam_width_ptr = new int(param.beam_width); + pointer_record->push_back(beam_width_ptr); + request_list[device_id]->insert( + {"beam_width", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_INT32, std::vector{1}, beam_width_ptr}}); + if (param.beam_width > 1) { + float* beam_search_diversity_rate_ptr = new float(param.beam_search_diversity_rate); + pointer_record->push_back(beam_search_diversity_rate_ptr); + request_list[device_id]->insert( + {"beam_search_diversity_rate", + triton::Tensor{ + triton::MEMORY_CPU, triton::TYPE_FP32, std::vector{1}, beam_search_diversity_rate_ptr}}); + } + else { + if (param.runtime_top_p != 0.0f) { + float* runtime_top_p_ptr = new float(param.runtime_top_p); + pointer_record->push_back(runtime_top_p_ptr); + request_list[device_id]->insert( + {"runtime_top_p", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_FP32, std::vector{1}, runtime_top_p_ptr}}); + } + if (param.runtime_top_k != 0) { + uint* runtime_top_k_ptr = new uint(param.runtime_top_k); + pointer_record->push_back(runtime_top_k_ptr); + request_list[device_id]->insert( + {"runtime_top_k", + triton::Tensor{ + triton::MEMORY_CPU, triton::TYPE_UINT32, std::vector{1}, runtime_top_k_ptr}}); + } + } + float* temperature_ptr = new float(param.temperature); + pointer_record->push_back(temperature_ptr); + request_list[device_id]->insert( + {"temperature", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_FP32, std::vector{1}, temperature_ptr}}); + float* len_penalty_ptr = new float(param.len_penalty); + pointer_record->push_back(len_penalty_ptr); + request_list[device_id]->insert( + {"len_penalty", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_FP32, std::vector{1}, len_penalty_ptr}}); + if 
(param.repetition_penalty != 1.0f) { + float* repetition_penalty_ptr = new float(param.repetition_penalty); + pointer_record->push_back(repetition_penalty_ptr); + request_list[device_id]->insert( + {"repetition_penalty", + triton::Tensor{ + triton::MEMORY_CPU, triton::TYPE_FP32, std::vector{1}, repetition_penalty_ptr}}); + } + if (param.presence_penalty != 0.0f) { + float* presence_penalty_ptr = new float(param.presence_penalty); + pointer_record->push_back(presence_penalty_ptr); + request_list[device_id]->insert( + {"presence_penalty", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_FP32, std::vector{1}, presence_penalty_ptr}}); + } + int* min_length_ptr = new int(param.min_length); + pointer_record->push_back(min_length_ptr); + request_list[device_id]->insert( + {"min_length", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_INT32, std::vector{1}, min_length_ptr}}); + unsigned long long int* random_seed_ptr = new unsigned long long int(param.random_seed); + pointer_record->push_back(random_seed_ptr); + request_list[device_id]->insert( + {"random_seed", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_UINT64, std::vector{1}, random_seed_ptr}}); + + pointer_record->push_back(d_input_ids); + pointer_record->push_back(d_input_lengths); + pointer_record->push_back(d_input_bad_words); + pointer_record->push_back(request_output_len_ptr); + } + + return request_list; +} + +std::vector>> +prepareRequest(std::string ini_name, const int node_id, const int gpu_count, std::vector* pointer_record) +{ + INIReader reader = INIReader(ini_name); + if (reader.ParseError() < 0) { + std::cout << "[ERROR] Can't load '" << ini_name << "'\n"; + ft::FT_CHECK(false); + } + + const size_t request_batch_size = reader.GetInteger("request", "request_batch_size"); + + const int start_id = reader.GetInteger("llama_7b", "start_id"); + const int end_id = reader.GetInteger("llama_7b", "end_id"); + + std::vector v_start_ids; + std::vector v_start_lengths; + + size_t max_input_len = 0; + ft::read_start_ids(request_batch_size, + &v_start_lengths, + &v_start_ids, + max_input_len, + end_id, + 1, + "../examples/cpp/llama/start_ids.csv"); + + std::vector v_bad_words; + ft::read_word_list("../examples/cpp/llama/bad_words.csv", v_bad_words); + + RequestParam param; + param.beam_width = reader.GetInteger("request", "beam_width"); + param.request_output_len = reader.GetInteger("request", "request_output_len"); + param.beam_search_diversity_rate = reader.GetFloat("request", "beam_search_diversity_rate"); + param.runtime_top_k = reader.GetInteger("request", "top_k"); + param.runtime_top_p = reader.GetFloat("request", "top_p"); + param.temperature = reader.GetFloat("request", "temperature"); + param.len_penalty = reader.GetFloat("request", "len_penalty"); + param.repetition_penalty = reader.GetFloat("request", "repetition_penalty", 1.0f); + param.presence_penalty = reader.GetFloat("request", "presence_penalty", 0.0f); + param.min_length = reader.GetInteger("request", "min_length", 0); + param.random_seed = (unsigned long long int)0; + param.start_id = start_id; + param.end_id = end_id; + + auto request_list = + broadCastRequest(v_start_ids, v_start_lengths, v_bad_words, node_id, gpu_count, param, pointer_record); + return request_list; +} + +int threadCreateModelInstances(std::shared_ptr model, + std::vector>* model_instances, + const int device_id, + const int rank, + std::pair, std::vector> nccl_params, + std::shared_ptr custom_all_reduce_comm = nullptr) +{ + printf("[INFO] rank = %d \n", rank); + 
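// This worker thread pins one GPU (cudaSetDevice), creates the weight shard for its
// tensor/pipeline-parallel rank (createSharedWeights), and builds a model instance
// bound to its own CUDA stream (createModelInstance). The caller computes rank as
// node_id * gpu_count + device_id, so every (node, GPU) pair maps to a unique global rank.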
ft::check_cuda_error(cudaSetDevice(device_id)); + cudaStream_t stream; + ft::check_cuda_error(cudaStreamCreate(&stream)); + model->createSharedWeights(device_id, rank); + auto model_instance = model->createModelInstance(device_id, rank, stream, nccl_params, custom_all_reduce_comm); + model_instances->at(device_id) = std::move(model_instance); + printf("model instance %d is created \n", device_id); + ft::print_mem_usage(); + return 0; +} + +int threadForward(std::unique_ptr* model_instance, + std::shared_ptr> request, + std::shared_ptr>* output_tensors, + const int device_id) +{ + ft::check_cuda_error(cudaSetDevice(device_id)); + *output_tensors = (*model_instance)->forward(request); + return 0; +} + +int main(int argc, char* argv[]) +{ + /* + Prepare the nccl ids, node id, device id and world size + by MPI or triton + */ + + // MPICHECK(MPI_Init(&argc, &argv)); + ft::mpi::initialize(&argc, &argv); + int node_id = ft::mpi::getCommWorldRank(); + int node_num = ft::mpi::getCommWorldSize(); + std::cout << "node_id: " << node_id << ", node_num: " << node_num << std::endl; + + // Note: Only supports that all nodes have same gpu count + const int gpu_count = ft::getDeviceCount(); + const int world_size = node_num * gpu_count; + std::string ini_name = argc >= 2 ? std::string(argv[1]) : "/data/FasterTransformer/examples/cpp/llama/llama_config.ini"; + + // step 1: Create model + std::shared_ptr model = AbstractTransformerModel::createLlamaModel(ini_name); + int tensor_para_size = model->getTensorParaSize(); + int pipeline_para_size = model->getPipelineParaSize(); + FT_CHECK_WITH_INFO(world_size == (tensor_para_size * pipeline_para_size), + "World Size != Tensor Parallel Size * Pipeline Parallel Size !"); + + std::cout << model->toString(); + + // step 2: Initialize the NCCL + std::pair, std::vector> nccl_comms = model->createNcclParams(node_id); + cudaDeviceSynchronize(); + + // Optional Step: create custom all reduce comm + std::vector> custom_all_reduce_comms; + model->createCustomComms(&custom_all_reduce_comms, world_size); + + // step 3: Create model instances + std::vector> model_instances((size_t)gpu_count); + std::vector threads; + for (int device_id = 0; device_id < gpu_count; device_id++) { + const int rank = node_id * gpu_count + device_id; + threads.push_back(std::thread(threadCreateModelInstances, + model, + &model_instances, + device_id, + rank, + nccl_comms, + custom_all_reduce_comms[rank])); + } + for (auto& t : threads) { + t.join(); + } + + // step 4: prepare request + std::vector pointer_record; // Used to prevent the pointers are release after leaving functions + std::vector>> request_list = + prepareRequest(ini_name, node_id, gpu_count, &pointer_record); + printf("[INFO] request is created \n"); + + // step 5: Forward + std::vector>> output_tensors_lists( + (size_t)gpu_count); + for (int i = 0; i < 2; i++) { + threads.clear(); + for (int device_id = 0; device_id < gpu_count; device_id++) { + threads.push_back(std::thread(threadForward, + &model_instances[device_id], + request_list[device_id], + &output_tensors_lists[device_id], + device_id)); + } + for (auto& t : threads) { + t.join(); + } + } + printf("[INFO] forward is completed. 
\n"); + + const int* d_output_ids = (const int*)output_tensors_lists[0].get()->at("output_ids").data; + const int batch_size = output_tensors_lists[0].get()->at("output_ids").shape[0]; + const int beam_width = output_tensors_lists[0].get()->at("output_ids").shape[1]; + const int seq_len = output_tensors_lists[0].get()->at("output_ids").shape[2]; + // step 6: check results + if (node_id == 0) { + + std::string fName = "out"; + auto outFile = std::ofstream(fName, std::ios::out); + if (!outFile.is_open()) { + printf("[WARNING] Cannot write results into output file %s \n", fName.c_str()); + } + else { + size_t outCount = batch_size * beam_width * seq_len; + int* hBuf = new int[outCount]; + ft::cudaD2Hcpy(hBuf, d_output_ids, outCount); + + { + std::cout << "Writing " << outCount << " elements\n"; + int zeroCount = 0; + for (size_t i = 0; i < outCount; i++) { + if (hBuf[i] == int(0)) + zeroCount++; + outFile << hBuf[i] << " "; + if ((i + 1) % (seq_len) == 0) + outFile << std::endl; + + // if (i < 10) + printf("%5d ", hBuf[i]); + // if ((i + 1) % (seq_len) == 0 && i < 10) + // std::cout << std::endl; + } + std::cout << std::endl << "zeroCount = " << zeroCount << std::endl; + } + delete[] hBuf; + } + } + + // test time + struct timeval start, end; + ft::mpi::barrier(); + cudaDeviceSynchronize(); + gettimeofday(&start, NULL); + + const int ite = 1; + for (int i = 0; i < ite; i++) { + threads.clear(); + for (int device_id = 0; device_id < gpu_count; device_id++) { + threads.push_back(std::thread(threadForward, + &model_instances[device_id], + request_list[device_id], + &output_tensors_lists[device_id], + device_id)); + } + for (auto& t : threads) { + t.join(); + } + } + + cudaDeviceSynchronize(); + ft::mpi::barrier(); + + gettimeofday(&end, NULL); + + printf("[INFO] batch_size %d beam_width %d seq_len %d" + " FT-CPP-GPT-Triton-time %.2f ms\n", + batch_size, + beam_width, + seq_len, + ((end.tv_sec - start.tv_sec) * 1000 + (end.tv_usec - start.tv_usec) * 0.001) / ite); + + ft::mpi::finalize(); + return 0; +} diff --git a/examples/cpp/llama/model_config.json b/examples/cpp/llama/model_config.json new file mode 100644 index 000000000..70266f26b --- /dev/null +++ b/examples/cpp/llama/model_config.json @@ -0,0 +1 @@ +{"vocab_size": 32000, "max_position_embeddings": 2048, "hidden_size": 4096, "intermediate_size": 11008, "num_hidden_layers": 32, "num_attention_heads": 32, "hidden_act": "silu", "initializer_range": 0.02, "rms_norm_eps": 1e-06, "use_cache": True, "return_dict": True, "output_hidden_states": False, "output_attentions": False, "torchscript": False, "torch_dtype": torch.float16, "use_bfloat16": False, "tf_legacy_loss": False, "pruned_heads": {}, "tie_word_embeddings": False, "is_encoder_decoder": False, "is_decoder": False, "cross_attention_hidden_size": None, "add_cross_attention": False, "tie_encoder_decoder": False, "max_length": 20, "min_length": 0, "do_sample": False, "early_stopping": False, "num_beams": 1, "num_beam_groups": 1, "diversity_penalty": 0.0, "temperature": 1.0, "top_k": 50, "top_p": 1.0, "typical_p": 1.0, "repetition_penalty": 1.0, "length_penalty": 1.0, "no_repeat_ngram_size": 0, "encoder_no_repeat_ngram_size": 0, "bad_words_ids": None, "num_return_sequences": 1, "chunk_size_feed_forward": 0, "output_scores": False, "return_dict_in_generate": False, "forced_bos_token_id": None, "forced_eos_token_id": None, "remove_invalid_values": False, "exponential_decay_length_penalty": None, "suppress_tokens": None, "begin_suppress_tokens": None, "architectures": ["LLaMAForCausalLM"], 
"finetuning_task": None, "id2label": {0: "LABEL_0", 1: "LABEL_1"}, "label2id": {"LABEL_0": 0, "LABEL_1": 1}, "tokenizer_class": None, "prefix": None, "bos_token_id": 0, "pad_token_id": -1, "eos_token_id": 1, "sep_token_id": None, "decoder_start_token_id": None, "task_specific_params": None, "problem_type": None, "_name_or_path": "/data/llama-7b-hf/", "_commit_hash": None, "transformers_version": "4.27.0.dev0", "max_sequence_length": 2048, "model_type": "llama"} diff --git a/examples/cpp/llama/start_ids.csv b/examples/cpp/llama/start_ids.csv new file mode 100644 index 000000000..d1ed9fb33 --- /dev/null +++ b/examples/cpp/llama/start_ids.csv @@ -0,0 +1,8 @@ +1, 18637, 29892, 526, 366, 1136, 455, 2470, 29973, 1815, 366, 5193, 304, 592, 29973 +1, 18637, 29892, 526, 366, 1136, 455, 2470, 29973, 1815, 366, 5193, 304, 592, 29973 +1, 18637, 29892, 526, 366, 1136, 455, 2470, 29973, 1815, 366, 5193, 304, 592, 29973 +1, 18637, 29892, 526, 366, 1136, 455, 2470, 29973, 1815, 366, 5193, 304, 592, 29973 +1, 18637, 29892, 526, 366, 1136, 455, 2470, 29973, 1815, 366, 5193, 304, 592, 29973 +1, 18637, 29892, 526, 366, 1136, 455, 2470, 29973, 1815, 366, 5193, 304, 592, 29973 +1, 18637, 29892, 526, 366, 1136, 455, 2470, 29973, 1815, 366, 5193, 304, 592, 29973 +1, 18637, 29892, 526, 366, 1136, 455, 2470, 29973, 1815, 366, 5193, 304, 592, 29973 diff --git a/examples/cpp/llama/stop_words.csv b/examples/cpp/llama/stop_words.csv new file mode 100644 index 000000000..9b9b09eba --- /dev/null +++ b/examples/cpp/llama/stop_words.csv @@ -0,0 +1,2 @@ +287, 4346, 12 +3, -1, -1 From da73711746cb9228858ce4623f8c35f137a05694 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Thu, 24 Aug 2023 13:51:55 -0700 Subject: [PATCH 05/10] commit --- examples/cpp/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/cpp/CMakeLists.txt b/examples/cpp/CMakeLists.txt index da24d72c6..980331b68 100644 --- a/examples/cpp/CMakeLists.txt +++ b/examples/cpp/CMakeLists.txt @@ -25,6 +25,7 @@ add_subdirectory(vit_int8) add_subdirectory(wenet) add_subdirectory(gptj) +add_subdirectory(llama) add_subdirectory(gptneox) add_subdirectory(multi_gpu_gpt) From 5f3302b9b56d02087384ee7562f6aa38faf1a717 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Thu, 24 Aug 2023 14:01:00 -0700 Subject: [PATCH 06/10] commit --- src/fastertransformer/kernels/CMakeLists.txt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/fastertransformer/kernels/CMakeLists.txt b/src/fastertransformer/kernels/CMakeLists.txt index fd2a1b494..c62814d21 100644 --- a/src/fastertransformer/kernels/CMakeLists.txt +++ b/src/fastertransformer/kernels/CMakeLists.txt @@ -114,6 +114,10 @@ add_library(transpose_int8_kernels STATIC transpose_int8_kernels.cu) set_property(TARGET transpose_int8_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET transpose_int8_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +add_library(repeat_kv_kernels STATIC repeat_kv_kernels.cu) +set_property(TARGET repeat_kv_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET repeat_kv_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) + add_library(matrix_transpose_kernels STATIC matrix_transpose_kernels.cu) set_property(TARGET matrix_transpose_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET matrix_transpose_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) From f30a53ebc29d6b80909d31c0a91af762a973e59c Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Thu, 24 Aug 2023 14:01:39 -0700 Subject: [PATCH 07/10] commit --- .../kernels/repeat_kv_kernels.cu | 
85 +++++++++++++++++++ .../kernels/repeat_kv_kernels.h | 29 +++++++ 2 files changed, 114 insertions(+) create mode 100644 src/fastertransformer/kernels/repeat_kv_kernels.cu create mode 100644 src/fastertransformer/kernels/repeat_kv_kernels.h diff --git a/src/fastertransformer/kernels/repeat_kv_kernels.cu b/src/fastertransformer/kernels/repeat_kv_kernels.cu new file mode 100644 index 000000000..fc0e2a347 --- /dev/null +++ b/src/fastertransformer/kernels/repeat_kv_kernels.cu @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2022. Authored by Yuqing Ding. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef CUDART_VERSION +#error CUDART_VERSION Undefined! +#elif (CUDART_VERSION >= 11050) +#include +#else +#include "3rdparty/cub/cub.cuh" +#endif + +#include "src/fastertransformer/kernels/reduce_kernel_utils.cuh" +#include "src/fastertransformer/models/wenet/WenetKernels.h" +#include "src/fastertransformer/utils/cuda_utils.h" +namespace fastertransformer { + + +template +__global__ void repeat_kv(T* dst, const T* src, const int kv_head_num, const int repeat_num, const int size_per_head, const int token_num) +{ + for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < kv_head_num * token_num * size_per_head; id += blockDim.x * gridDim.x) { + int token_id = id / (size_per_head * kv_head_num); + int head_id = id / size_per_head % kv_head_num; + int inner_id = id % size_per_head; + for (int r = 0; r < repeat_num; r++) { + int q_dst = 3 * repeat_num * kv_head_num * size_per_head * token_id + repeat_num * size_per_head * head_id + size_per_head * r + inner_id; + int q_src = (repeat_num + 2) * kv_head_num * size_per_head * token_id + repeat_num * size_per_head * head_id + size_per_head * r + inner_id; + dst[q_dst] = src[q_src]; + + int k_dst = 3 * repeat_num * kv_head_num * size_per_head * token_id + repeat_num * kv_head_num * size_per_head + repeat_num * size_per_head * head_id + size_per_head * r + inner_id; + int k_src = (repeat_num + 2) * kv_head_num * size_per_head * token_id + repeat_num * kv_head_num * size_per_head + size_per_head * head_id + inner_id; + dst[k_dst] = src[k_src]; + + int v_dst = 3 * repeat_num * kv_head_num * size_per_head * token_id + 2 * repeat_num * kv_head_num * size_per_head + repeat_num * size_per_head * head_id + size_per_head * r + inner_id; + int v_src = (repeat_num + 2) * kv_head_num * size_per_head * token_id + (repeat_num + 1) * kv_head_num * size_per_head + size_per_head * head_id + inner_id; + dst[v_dst] = src[v_src]; + } + } +} + +template +void invokeRepeatKv(T* dst, const T* src, const int head_num, const int kv_head_num, const int size_per_head, const int token_num, cudaStream_t stream) +{ + dim3 block, grid; + const int n = kv_head_num * token_num; + if (n <= 1024) { + block.x = n; + grid.x = size_per_head; + } + else { + block.x = 1024; + grid.x = ceil(size_per_head * n / 1024.); + } + repeat_kv<<>>(dst, src, kv_head_num, head_num/kv_head_num, size_per_head, token_num); +} + 
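// Reference-only sketch (not part of the patch) of what the kernel above computes, based
// on its index math. Per token, the source buffer (presumably the fused QKV GEMM output)
// packs head_num query heads, then kv_head_num key heads, then kv_head_num value heads.
// invokeRepeatKv expands this to a plain [Q | K | V] layout with head_num heads in each
// section: Q is copied through unchanged and every KV head is duplicated
// head_num / kv_head_num times, so code written for ordinary multi-head attention can
// consume grouped-query checkpoints where kv_head_num < head_num. The helper name
// repeatKvReference is illustrative only.
template<typename T>
void repeatKvReference(T* dst, const T* src, int head_num, int kv_head_num, int size_per_head, int token_num)
{
    const int rep        = head_num / kv_head_num;                        // query heads per KV head
    const int src_stride = (head_num + 2 * kv_head_num) * size_per_head;  // packed QKV, per token
    const int dst_stride = 3 * head_num * size_per_head;                  // expanded QKV, per token
    for (int t = 0; t < token_num; ++t) {
        const T* s = src + t * src_stride;
        T*       d = dst + t * dst_stride;
        // Q section: identical offsets on both sides, straight copy.
        for (int i = 0; i < head_num * size_per_head; ++i) {
            d[i] = s[i];
        }
        for (int h = 0; h < kv_head_num; ++h) {
            const T* k = s + (head_num + h) * size_per_head;
            const T* v = s + (head_num + kv_head_num + h) * size_per_head;
            for (int r = 0; r < rep; ++r) {
                // Each source KV head lands in `rep` consecutive destination heads.
                T* kd = d + (head_num + h * rep + r) * size_per_head;
                T* vd = d + (2 * head_num + h * rep + r) * size_per_head;
                for (int i = 0; i < size_per_head; ++i) {
                    kd[i] = k[i];
                    vd[i] = v[i];
                }
            }
        }
    }
}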
+template void +invokeRepeatKv(float* dst, const float* src, const int head_num, const int kv_head_num, const int size_per_head, const int token_num, cudaStream_t stream); +template void +invokeRepeatKv(half* dst, const half* src, const int head_num, const int kv_head_num, const int size_per_head, const int token_num, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeRepeatKv(__nv_bfloat16* dst, + const __nv_bfloat16* src, + const int head_num, + const int kv_head_num, + const int size_per_head, + const int token_num, + cudaStream_t stream); +#endif + +} // namespace fastertransformer diff --git a/src/fastertransformer/kernels/repeat_kv_kernels.h b/src/fastertransformer/kernels/repeat_kv_kernels.h new file mode 100644 index 000000000..a41601115 --- /dev/null +++ b/src/fastertransformer/kernels/repeat_kv_kernels.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2022. Authored by Yuqing Ding. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include + +namespace fastertransformer { + +template +void invokeRepeatKv(T* dst, const T* src, const int head_num, const int kv_head_num, const int size_per_head, const int token_num, cudaStream_t stream); + +} // namespace fastertransformer From a46f5764e59a87691b9b7958fa4b6e1340cadeca Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Thu, 24 Aug 2023 14:32:19 -0700 Subject: [PATCH 08/10] commit --- src/fastertransformer/layers/FfnLayer.cc | 5 +++-- src/fastertransformer/layers/FfnLayer.h | 1 + src/fastertransformer/layers/TensorParallelSiluFfnLayer.cc | 6 ++++-- src/fastertransformer/layers/TensorParallelSiluFfnLayer.h | 3 ++- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/fastertransformer/layers/FfnLayer.cc b/src/fastertransformer/layers/FfnLayer.cc index 7ac441198..4b18b54ad 100644 --- a/src/fastertransformer/layers/FfnLayer.cc +++ b/src/fastertransformer/layers/FfnLayer.cc @@ -81,7 +81,7 @@ void FfnLayer::forward(TensorMap* output_tensors, TensorMap* input_tensors, c } // TODO: INT8 and Sparsity are currently not implemented (geglu or reglu) - const bool use_gated_activation = use_gated_activation_ && ffn_weights->intermediate_weight2.kernel != nullptr; + const bool use_gated_activation = use_gated_activation_ && (ffn_weights->intermediate_weight2.kernel != nullptr || ffn_weights->intermediate_weight2.int8_kernel != nullptr); // moe can't be used with use_gated_activation currently FT_CHECK(!(use_gated_activation && use_moe)); @@ -684,6 +684,7 @@ SiluFfnLayer::SiluFfnLayer(size_t max_batch_size, IAllocator* allocator, bool is_free_buffer_after_forward, bool sparse, + int int8_mode, bool use_gated_activation): FfnLayer(max_batch_size, max_seq_len, @@ -696,7 +697,7 @@ SiluFfnLayer::SiluFfnLayer(size_t max_batch_size, allocator, is_free_buffer_after_forward, sparse, - 0, + int8_mode, use_gated_activation) { } diff --git 
a/src/fastertransformer/layers/FfnLayer.h b/src/fastertransformer/layers/FfnLayer.h index af7ae7606..f84915d2f 100644 --- a/src/fastertransformer/layers/FfnLayer.h +++ b/src/fastertransformer/layers/FfnLayer.h @@ -210,6 +210,7 @@ class SiluFfnLayer: public FfnLayer { IAllocator* allocator, bool is_free_buffer_after_forward, bool sparse = false, + int int8_mode = 0, bool use_gated_activation = false); SiluFfnLayer(SiluFfnLayer const& ffn_layer); diff --git a/src/fastertransformer/layers/TensorParallelSiluFfnLayer.cc b/src/fastertransformer/layers/TensorParallelSiluFfnLayer.cc index bfc781cc4..9af46dbd0 100644 --- a/src/fastertransformer/layers/TensorParallelSiluFfnLayer.cc +++ b/src/fastertransformer/layers/TensorParallelSiluFfnLayer.cc @@ -78,7 +78,8 @@ TensorParallelSiluFfnLayer::TensorParallelSiluFfnLayer(size_t max_b bool is_sparse, bool use_gated_activation, std::shared_ptr custom_all_reduce_comm, - int enable_custom_all_reduce): + int enable_custom_all_reduce, + int int8_mode): SiluFfnLayer(max_batch_size, max_seq_len, head_num, @@ -90,6 +91,7 @@ TensorParallelSiluFfnLayer::TensorParallelSiluFfnLayer(size_t max_b allocator, is_free_buffer_after_forward, is_sparse, + int8_mode, use_gated_activation), tensor_para_(tensor_para), custom_all_reduce_comm_(custom_all_reduce_comm), @@ -111,4 +113,4 @@ template class TensorParallelSiluFfnLayer; template class TensorParallelSiluFfnLayer<__nv_bfloat16>; #endif -} // namespace fastertransformer \ No newline at end of file +} // namespace fastertransformer diff --git a/src/fastertransformer/layers/TensorParallelSiluFfnLayer.h b/src/fastertransformer/layers/TensorParallelSiluFfnLayer.h index ae481373a..5f0e6d625 100644 --- a/src/fastertransformer/layers/TensorParallelSiluFfnLayer.h +++ b/src/fastertransformer/layers/TensorParallelSiluFfnLayer.h @@ -47,7 +47,8 @@ class TensorParallelSiluFfnLayer: public SiluFfnLayer { bool is_sparse, bool use_gated_activation = false, std::shared_ptr custom_all_reduce_comm = nullptr, - int enable_custom_all_reduce = 0); + int enable_custom_all_reduce = 0, + int int8_mode = 0); TensorParallelSiluFfnLayer(TensorParallelSiluFfnLayer const& ffn_layer); From 9c7b9934db47ba6d8034e3c54294288a165f520a Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Thu, 24 Aug 2023 14:38:10 -0700 Subject: [PATCH 09/10] commit --- src/fastertransformer/kernels/layernorm_kernels.cu | 11 +++++++++++ .../layers/adapter_layers/LinearAdapterLayer.cc | 5 +++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/fastertransformer/kernels/layernorm_kernels.cu b/src/fastertransformer/kernels/layernorm_kernels.cu index 369030b37..60e6f001a 100644 --- a/src/fastertransformer/kernels/layernorm_kernels.cu +++ b/src/fastertransformer/kernels/layernorm_kernels.cu @@ -1490,6 +1490,17 @@ template void invokeGeneralAddResidualT5PreLayerNorm(half* output, int n, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeGeneralAddResidualT5PreLayerNorm(__nv_bfloat16* output, + __nv_bfloat16* norm_output, + const __nv_bfloat16* input, + const __nv_bfloat16* gamma, + const float layernorm_eps, + int m, + int n, + cudaStream_t stream); +#endif + template void invokeGeneralAddBiasResidualT5PreLayerNorm(T* output, T* norm_output, diff --git a/src/fastertransformer/layers/adapter_layers/LinearAdapterLayer.cc b/src/fastertransformer/layers/adapter_layers/LinearAdapterLayer.cc index c5ea150f4..ee5c2f028 100644 --- a/src/fastertransformer/layers/adapter_layers/LinearAdapterLayer.cc +++ 
b/src/fastertransformer/layers/adapter_layers/LinearAdapterLayer.cc @@ -88,7 +88,8 @@ LinearAdapterLayer<T>::LinearAdapterLayer(LinearAdapterConfig const& co is_sparse, false, custom_all_reduce_comm, - enable_custom_all_reduce)}, + enable_custom_all_reduce, + 0)}, layer_norm_type_{config.layerNormType()}, layer_norm_eps_{layer_norm_eps}, max_token_size_{max_batch_size * max_seq_len}, @@ -173,4 +174,4 @@ template class LinearAdapterLayer; template class LinearAdapterLayer<__nv_bfloat16>; #endif -} // namespace fastertransformer \ No newline at end of file +} // namespace fastertransformer From f1db3c9196c3a8a17343a71beca0b15d924163e1 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Thu, 24 Aug 2023 14:49:31 -0700 Subject: [PATCH 10/10] commit --- .../layers/attention_layers/LlamaContextAttentionLayer.cc | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc index 91de7d46d..1f3734bb6 100644 --- a/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc @@ -45,7 +45,6 @@ void LlamaContextAttentionLayer<T>::forward(TensorMap* output_ten // hidden_features [token_num, hidden_dimension] // key_cache [batch, local_head_num, size_per_head // x, max_seq_len, x] // value_cache [batch, local_head_num, max_seq_len, size_per_head] - printf("LlamaContextAttentionLayer::forward\n"); FT_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); FT_CHECK(output_tensors->at("key_cache").shape.size() == 5); FT_CHECK(output_tensors->at("value_cache").shape.size() == 4 @@ -71,12 +70,10 @@ void LlamaContextAttentionLayer<T>::forward(TensorMap* output_ten FT_CHECK_WITH_INFO(attention_type != AttentionType::FUSED_PADDED_MHA, "Llama Context FUSED_PADDED_MHA is not supported !"); - printf("attention buffer alloc %d %d\n", request_batch_size, request_seq_len + max_prompt_length); PUSH_RANGE("attention buffer alloc"); allocateBuffer(request_batch_size, request_seq_len + max_prompt_length, attention_type != AttentionType::FUSED_MHA); POP_RANGE; sync_check_cuda_error(); - printf("attention buffer alloc done\n"); const bool is_final = input_tensors->at("is_final_layer").getVal<bool>(); const int m = input_tensors->at("input_query").shape[0]; @@ -312,7 +309,6 @@ void LlamaContextAttentionLayer<T>::forward(TensorMap* output_ten } sync_check_cuda_error(); - printf("cublas_wrapper_->Gemm: done\n"); // IDEA: append prefix prompt key value here PrefixPromptBatchWeightsParam<T> param{d_prefix_prompt_batch, d_prefix_prompt_lengths,