Skip to content

Commit

Permalink
Add INT8 support for fused_multi_transformer_op (#45284) (#46169)
Browse files Browse the repository at this point in the history
Co-authored-by: RichardWooSJTU <[email protected]>
  • Loading branch information
minghaoBD and RichardWooSJTU authored Sep 19, 2022
1 parent e5dc9d6 commit db368d5
Show file tree
Hide file tree
Showing 22 changed files with 4,168 additions and 1,428 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,8 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) {
auto var_data_type = var_node->Var()->GetDataType();
VLOG(5) << "var_name is " << var_name << ", data type is "
<< var_data_type;
if (var_data_type == paddle::framework::proto::VarType::FP16) {
if (var_data_type == paddle::framework::proto::VarType::FP16 &&
t->dtype() != paddle::experimental::DataType::FLOAT16) {
framework::Tensor half_tensor;
half_tensor.set_type(paddle::experimental::DataType::FLOAT16);
half_tensor.Resize(t->dims());
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/operators/fused/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ register_operators(
fused_transformer_op
fused_feedforward_op
fused_multi_transformer_op
fused_multi_transformer_int8_op
fused_bias_dropout_residual_layer_norm_op
resnet_unit_op
fused_gemm_epilogue_op
Expand Down Expand Up @@ -119,6 +120,7 @@ if(WITH_GPU OR WITH_ROCM)
# fused_attention_op
op_library(fused_attention_op)
op_library(fused_multi_transformer_op)
op_library(fused_multi_transformer_int8_op)
op_library(fused_bias_dropout_residual_layer_norm_op)
endif()
# resnet_unit needs cudnn 8.0 above
Expand Down
30 changes: 24 additions & 6 deletions paddle/fluid/operators/fused/attention_layer_norm.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ limitations under the License. */
namespace paddle {
namespace operators {

template <typename T>
// NOTE: T must be the same as OutType in ComputeBackward
template <typename T, typename InType = T, typename OutType = T>
class AttnLayerNorm {
public:
AttnLayerNorm(const phi::GPUContext& dev_ctx,
Expand All @@ -33,25 +34,42 @@ class AttnLayerNorm {

~AttnLayerNorm() {}

void ComputeForward(const T* x_data,
void ComputeForward(const InType* x_data,
const LayerNormParamType<T>* scale_data,
const LayerNormParamType<T>* bias_data,
T* y_data,
OutType* y_data,
LayerNormParamType<T>* mean_data,
LayerNormParamType<T>* var_data) {
LayerNormParamType<T>* var_data,
const float* dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
auto stream = dev_ctx_.stream();

switch (GetDesiredBlockDim(feature_size_)) {
FIXED_BLOCK_DIM_CASE(
LayerNormForward<T, LayerNormParamType<T>, kBlockDim>
LayerNormForward<T,
LayerNormParamType<T>,
kBlockDim,
false,
InType,
OutType>
<<<batch_size_, kBlockDim, 0, stream>>>(x_data,
scale_data,
bias_data,
y_data,
mean_data,
var_data,
epsilon_,
feature_size_));
feature_size_,
dequant_out_scale_data,
quant_out_scale_offset,
quant_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound));
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Feature_size must be larger than 1"));
Expand Down
189 changes: 189 additions & 0 deletions paddle/fluid/operators/fused/attn_gemm_int8.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <iostream>
#include <vector>
#include "paddle/fluid/operators/fused/cublaslt.h"
#include "paddle/fluid/operators/fused/quant_dequant_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T>
class AttnMatmulINT8 {
public:
AttnMatmulINT8(
const phi::GPUContext& dev_ctx, int m, int n, int k, bool compute_bias)
: dev_ctx_(dev_ctx), m_(m), n_(n), k_(k), compute_bias_(compute_bias) {
auto helper = std::make_shared<CublasLtHelper>(m, k, n);
helpers_.emplace_back(helper);
}
~AttnMatmulINT8() {}

// This function is used to execute GEMM, with input and output's types are
// both T.
void ComputeForward(const framework::Tensor* weight,
const framework::Tensor* input,
framework::Tensor* input_tmp,
const framework::Tensor* bias,
framework::Tensor* output,
framework::Tensor* output_tmp,
framework::Tensor* bias_out,
const float quant_in_scale,
const framework::Tensor* dequant_out_scale,
const int quant_out_scale_offset,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
quantize_kernel_launcher<T>(input->data<T>(),
input_tmp->data<int8_t>(),
quant_in_scale,
m_,
k_,
quant_round_type,
quant_max_bound,
quant_min_bound,
dev_ctx_.stream());

helpers_[0]->GEMM(input_tmp->data<int8_t>(),
weight->data<int8_t>(),
output_tmp->data<int32_t>(),
dev_ctx_.stream());

dequantize_kernel_launcher<T>(output_tmp->data<int32_t>(),
output->data<T>(),
m_,
n_,
dev_ctx_.stream(),
quant_in_scale,
dequant_out_scale->data<float>(),
quant_out_scale_offset);

if (compute_bias_) {
// bias_out = output + bias
std::vector<const framework::Tensor*> ins = {output, bias};
std::vector<framework::Tensor*> outs = {bias_out};
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
PADDLE_ENFORCE_EQ(cudaGetLastError(),
cudaSuccess,
platform::errors::Fatal(
"cuda error occured after computing bias. "
"But it does not mean this error is caused by "
"bias computing"));
}
}

// This function is used to execute GEMM, with input and output's types are
// both INT8.
void ComputeForwardINT8ToINT8(const framework::Tensor* weight,
framework::Tensor* input,
const framework::Tensor* bias,
framework::Tensor* output,
framework::Tensor* bias_out) {
helpers_[0]->GEMM(input->data<int8_t>(),
weight->data<int8_t>(),
output->data<int32_t>(),
dev_ctx_.stream());
}

// This function is used to execute GEMM, with input and output's types are
// INT8 and T.
void ComputeForwardINT8ToT(const framework::Tensor* weight,
const float quant_in_scale,
framework::Tensor* input,
const framework::Tensor* bias,
framework::Tensor* output,
framework::Tensor* output_tmp,
framework::Tensor* bias_out,
const framework::Tensor* dequant_out_scale,
const int quant_out_scale_offset) {
helpers_[0]->GEMM(input->data<int8_t>(),
weight->data<int8_t>(),
output_tmp->data<int32_t>(),
dev_ctx_.stream());

dequantize_kernel_launcher<T>(output_tmp->data<int32_t>(),
output->data<T>(),
m_,
n_,
dev_ctx_.stream(),
quant_in_scale,
dequant_out_scale->data<float>(),
quant_out_scale_offset);

if (compute_bias_) {
// bias_out = output + bias
std::vector<const framework::Tensor*> ins = {output, bias};
std::vector<framework::Tensor*> outs = {bias_out};
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
PADDLE_ENFORCE_EQ(cudaGetLastError(),
cudaSuccess,
platform::errors::Fatal(
"cuda error occured after computing bias. "
"But it does not mean this error is caused by "
"bias computing"));
}
}

// This function is used to execute GEMM, with input and output's types are T
// and INT8.
void ComputeForwardTToINT8(const framework::Tensor* weight,
const float quant_in_scale,
const framework::Tensor* input,
framework::Tensor* input_tmp,
const framework::Tensor* bias,
framework::Tensor* output,
framework::Tensor* bias_out,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
quantize_kernel_launcher<T>(input->data<T>(),
input_tmp->data<int8_t>(),
quant_in_scale,
m_,
k_,
quant_round_type,
quant_max_bound,
quant_min_bound,
dev_ctx_.stream());

helpers_[0]->GEMM(input_tmp->data<int8_t>(),
weight->data<int8_t>(),
output->data<int32_t>(),
dev_ctx_.stream());
}

private:
const phi::GPUContext& dev_ctx_;

int m_; // m
int n_; // n
int k_; // k

int compute_bias_;
std::vector<std::shared_ptr<CublasLtHelper>> helpers_;
};

} // namespace operators
} // namespace paddle
Loading

0 comments on commit db368d5

Please sign in to comment.