Support (Bias)SkipLayerNormalization fusion in GPT2 #13988

Merged 23 commits on Dec 22, 2022

4 changes: 3 additions & 1 deletion docs/ContribOperators.md
@@ -3840,7 +3840,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>1D bias tensor with shape (hidden_size)</dd>
</dl>

#### Outputs (1 - 3)
#### Outputs (1 - 4)

<dl>
<dt><tt>output</tt> : T</dt>
@@ -3849,6 +3849,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Saved mean used during training to speed up gradient computation</dd>
<dt><tt>inv_std_var</tt> (optional) : U</dt>
<dd>Saved inverse standard variance used during training to speed up gradient computation.</dd>
<dt><tt>input_skip_sum</tt> (optional) : T</dt>
<dd>Sum of the input and skip inputs with shape (batch_size, sequence_length, hidden_size).</dd>
</dl>

#### Type Constraints
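The new optional fourth output is what makes the fused node usable in GPT2: the residual `Add` result feeds both the `LayerNormalization` and the next block, so after fusion `input_skip_sum` re-exposes that sum. Below is a minimal single-row reference sketch of the updated semantics (an illustrative helper, not ORT code); note that the sum is captured before the bias is folded in, matching the CPU kernel change later in this diff.

```cpp
#include <cmath>
#include <cstddef>

// Reference semantics for one (batch, sequence) row of SkipLayerNormalization
// with the optional input_skip_sum output. Hypothetical helper for illustration.
void skip_layer_norm_row(const float* input, const float* skip,
                         const float* gamma, const float* beta,  // beta may be null
                         const float* bias,                      // bias may be null
                         float epsilon, std::size_t hidden_size,
                         float* output, float* input_skip_sum) { // sum output may be null
  float mean = 0.f, mean_square = 0.f;
  for (std::size_t h = 0; h < hidden_size; ++h) {
    float value = input[h] + skip[h];
    if (input_skip_sum != nullptr) input_skip_sum[h] = value;  // sum excludes bias
    if (bias != nullptr) value += bias[h];
    output[h] = value;  // stash the pre-normalization value
    mean += value;
    mean_square += value * value;
  }
  mean /= hidden_size;
  const float inv_std =
      1.f / std::sqrt(mean_square / hidden_size - mean * mean + epsilon);
  for (std::size_t h = 0; h < hidden_size; ++h) {
    output[h] = (output[h] - mean) * inv_std * gamma[h] +
                (beta != nullptr ? beta[h] : 0.f);
  }
}
```
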
6 changes: 3 additions & 3 deletions docs/OperatorKernels.md
@@ -438,7 +438,7 @@ Do not modify directly.*
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Range|*in* start:**T**<br> *in* limit:**T**<br> *in* delta:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
|SampleOp|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_sum:**T**|1+|**T** = tensor(double), tensor(float)|
|SparseToDenseMatMul|*in* A:**T**<br> *in* B:**T1**<br> *out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)<br/> **T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|Tokenizer|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(string)|
|TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
@@ -797,7 +797,7 @@ Do not modify directly.*
|RemovePadding|*in* input:**T**<br> *in* sequence_token_count:**M**<br> *out* output:**T**<br> *out* token_offset:**M**<br> *out* cumulated_seq_len:**M**<br> *out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)|
|RestorePadding|*in* input:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|Rfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**|1+|**T** = tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|Trilu|*in* X:**T**<br> *in* k:**tensor(int64)**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
| |
@@ -1130,7 +1130,7 @@ Do not modify directly.*
|QLinearAdd|*in* A:**T**<br> *in* A_scale:**tensor(float)**<br> *in* A_zero_point:**T**<br> *in* B:**T**<br> *in* B_scale:**tensor(float)**<br> *in* B_zero_point:**T**<br> *in* C_scale:**tensor(float)**<br> *in* C_zero_point:**T**<br> *out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearSigmoid|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* X_zero_point:**T**<br> *in* Y_scale:**tensor(float)**<br> *in* Y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**|1+|**T** = tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
| |
| |
|**Operator Domain:** *com.microsoft.dml*||||
19 changes: 16 additions & 3 deletions onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
@@ -39,6 +39,9 @@ Status SkipLayerNorm<T>::Compute(OpKernelContext* p_ctx) const {
const Tensor* beta = p_ctx->Input<Tensor>(3);
const Tensor* bias = p_ctx->Input<Tensor>(4);
Tensor* output = p_ctx->Output(0, input->Shape());
// For inferencing, we support one more optional output which is the sum
// of the input and skip tensors
Tensor* skip_input_add_output = p_ctx->Output(3, input->Shape());

const auto& input_dims = input->Shape().GetDims();
if (input_dims.size() != 3) {
@@ -98,18 +101,28 @@ Status SkipLayerNorm<T>::Compute(OpKernelContext* p_ctx) const {

T* output_data = output->MutableData<T>();

// For inferencing, we support one more optional output which is the sum
// of the input and skip tensors
T* skip_input_add_output_data = skip_input_add_output != nullptr ? skip_input_add_output->MutableData<T>() : nullptr;

concurrency::ThreadPool::TryBatchParallelFor(
p_ctx->GetOperatorThreadPool(), static_cast<int32_t>(task_count),
[&](ptrdiff_t task_idx) {
const T* p_input = input_data + task_idx * hidden_size;
const T* p_skip = skip_data + task_idx * hidden_size;
T* p_output = output_data + task_idx * hidden_size;
auto offset = task_idx * hidden_size;

const T* p_input = input_data + offset;
const T* p_skip = skip_data + offset;
T* p_output = output_data + offset;
T* p_skip_input_add_output_data = skip_input_add_output_data != nullptr ? skip_input_add_output_data + offset : nullptr;

T mean = 0;
T mean_square = 0;

for (int64_t h = 0; h < hidden_size; h++) {
T value = p_input[h] + p_skip[h];
if (nullptr != p_skip_input_add_output_data) {
p_skip_input_add_output_data[h] = value;
}
if (nullptr != bias_data) {
value += bias_data[h];
}
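The parallel loop above assigns one task per (batch_size × sequence_length) row, each spanning `hidden_size` contiguous elements, and gates the extra store on the optional output pointer. A serial sketch of the same decomposition, reusing the hypothetical `skip_layer_norm_row` helper from the earlier example:

```cpp
#include <cstddef>

// Serial equivalent of the TryBatchParallelFor body: same offset math,
// same nullptr gating for the optional input+skip sum output.
void skip_layer_norm_batch(const float* input, const float* skip,
                           const float* gamma, const float* beta,
                           const float* bias, float epsilon,
                           std::ptrdiff_t task_count, std::size_t hidden_size,
                           float* output, float* input_skip_sum) {
  for (std::ptrdiff_t task_idx = 0; task_idx < task_count; ++task_idx) {
    const std::ptrdiff_t offset =
        task_idx * static_cast<std::ptrdiff_t>(hidden_size);
    skip_layer_norm_row(input + offset, skip + offset, gamma, beta, bias,
                        epsilon, hidden_size, output + offset,
                        input_skip_sum != nullptr ? input_skip_sum + offset
                                                  : nullptr);
  }
}
```
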
5 changes: 5 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc
@@ -41,6 +41,10 @@ Status SkipLayerNorm<T>::ComputeInternal(OpKernelContext* ctx) const {

Tensor* output = ctx->Output(0, input->Shape());

// For inferencing, we support one more optional output which is the sum
// of the input and skip tensors
Tensor* skip_input_add_output = ctx->Output(3, input->Shape());

if (input->Shape() != skip->Shape()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"skip is expected to have same shape as input");
@@ -99,6 +103,7 @@ Status SkipLayerNorm<T>::ComputeInternal(OpKernelContext* ctx) const {
return LaunchSkipLayerNormKernel<CudaT>(
Stream(ctx),
reinterpret_cast<CudaT*>(output->MutableData<T>()),
skip_input_add_output != nullptr ? reinterpret_cast<CudaT*>(skip_input_add_output->MutableData<T>()) : nullptr,
reinterpret_cast<const CudaT*>(input->Data<T>()),
reinterpret_cast<const CudaT*>(skip->Data<T>()),
reinterpret_cast<const CudaT*>(gamma->Data<T>()),
82 changes: 56 additions & 26 deletions onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu
@@ -50,8 +50,9 @@ half maybe2half(float x) {

template <typename T, unsigned TPB>
__global__ void SkipLayerNormKernel(
const int ld, const T* input, const T* skip, const T* beta, const T* gamma, const T* bias,
const T epsilon, T* output) {
const int ld, const T* input, const T* skip,
const T* beta, const T* gamma, const T* bias,
const T epsilon, T* output, T* skip_input_add_output) {
const T reverse_ld = T(1.f / ld);
const int offset = blockIdx.x * ld;

@@ -61,6 +62,11 @@ __global__ void SkipLayerNormKernel(

for (int i = threadIdx.x; i < ld; i += TPB) {
const int idx = offset + i;

if (skip_input_add_output != nullptr) {
skip_input_add_output[idx] = input[idx] + skip[idx];
}

const T val = (bias == nullptr) ? input[idx] + skip[idx] : input[idx] + skip[idx] + bias[i];
const T rldval = reverse_ld * val;
thread_data = pair_sum(thread_data, cub::KeyValuePair<T, T>(rldval, rldval * val));
@@ -74,13 +80,14 @@
template <typename T, unsigned TPB, int ILP>
__global__ void SkipLayerNormKernelSmall(
const int ld, const T* input, const T* skip, const T* beta, const T* gamma,
const T* bias, const T epsilon, T* output, bool hasBias) {
const T* bias, const T epsilon, T* output, T* skip_input_add_output,
bool hasBias, bool hasSkipInputAdditionOutput) {
const T rld = T(1.f / ld);
const int idx = blockIdx.x * ld + threadIdx.x * ILP; // grid_size = n / ld

using VecT = aligned_vector<T, ILP>;

T input_v[ILP], skip_v[ILP], bias_v[ILP];
T input_v[ILP], skip_v[ILP], bias_v[ILP], skip_input_add_output_v[ILP];

VecT* input_val = reinterpret_cast<VecT*>(&input_v);
*input_val = *reinterpret_cast<const VecT*>(&input[idx]);
@@ -100,103 +107,126 @@ __global__ void SkipLayerNormKernelSmall(
T rldvalsq_sum = T(0.f);
#pragma unroll
for (int i = 0; i < ILP; i++) {
if (hasSkipInputAdditionOutput) {
skip_input_add_output_v[i] = input_v[i] + skip_v[i];
}

input_v[i] += hasBias ? skip_v[i] + bias_v[i] : skip_v[i];
const T rldval = rld * input_v[i];
rldval_sum += rldval;
rldvalsq_sum += rldval * input_v[i];
}

if (hasSkipInputAdditionOutput) {
*(reinterpret_cast<VecT*>(&skip_input_add_output[idx])) = *reinterpret_cast<VecT*>(&skip_input_add_output_v);
}

thread_data = cub::KeyValuePair<T, T>(rldval_sum, rldvalsq_sum);
}
LayerNormSmall<T, TPB, ILP>(input_v, thread_data, ld, idx, beta, gamma, epsilon, output);
}

template <typename T>
Status LaunchSkipLayerNormKernel(
cudaStream_t stream, T* output, const T* input, const T* skip, const T* gamma,
cudaStream_t stream, T* output, T* skip_input_add_output, const T* input, const T* skip, const T* gamma,
const T* beta, const T* bias, float epsilon, const int ld, const int element_count,
size_t element_size) {
// this must be true because n is the total size of the tensor
assert(element_count % ld == 0);
bool hasBias = (bias == nullptr) ? false : true;
bool hasSkipInputAdditionOutput = (skip_input_add_output == nullptr) ? false : true;

if (0 == (ld % 4)) {
const int grid_size = element_count / ld;
if (ld <= 32) {
constexpr int block_size = 32;
SkipLayerNormKernelSmall<T, block_size, 1>
<<<grid_size, block_size, 0, stream>>>(ld, input, skip, beta, gamma, bias,
maybe2half<T>(epsilon), output, hasBias);
maybe2half<T>(epsilon), output,
skip_input_add_output, hasBias, hasSkipInputAdditionOutput);
} else if (ld <= 64) {
constexpr int block_size = 64 / 2;
SkipLayerNormKernelSmall<T, block_size, 2>
<<<grid_size, block_size, 0, stream>>>(ld, input, skip, beta, gamma, bias,
maybe2half<T>(epsilon), output, hasBias);
maybe2half<T>(epsilon), output,
skip_input_add_output, hasBias, hasSkipInputAdditionOutput);
} else if (ld <= 128) {
constexpr int block_size = 128 / 4;
SkipLayerNormKernelSmall<T, block_size, 4>
<<<grid_size, block_size, 0, stream>>>(ld, input, skip, beta, gamma, bias,
maybe2half<T>(epsilon), output, hasBias);
maybe2half<T>(epsilon), output,
skip_input_add_output, hasBias, hasSkipInputAdditionOutput);
} else if (ld <= 384) {
constexpr int block_size = 384 / 4;
SkipLayerNormKernelSmall<T, block_size, 4>
<<<grid_size, block_size, 0, stream>>>(ld, input, skip, beta, gamma, bias,
maybe2half<T>(epsilon), output, hasBias);
maybe2half<T>(epsilon), output,
skip_input_add_output, hasBias, hasSkipInputAdditionOutput);
} else if (ld <= 768) {
constexpr int block_size = 768 / 4;
SkipLayerNormKernelSmall<T, block_size, 4>
<<<grid_size, block_size, 0, stream>>>(ld, input, skip, beta, gamma, bias,
maybe2half<T>(epsilon), output, hasBias);
maybe2half<T>(epsilon), output,
skip_input_add_output, hasBias, hasSkipInputAdditionOutput);
} else if (ld <= 1024) {
constexpr int block_size = 1024 / 4;
SkipLayerNormKernelSmall<T, block_size, 4>
<<<grid_size, block_size, 0, stream>>>(ld, input, skip, beta, gamma, bias,
maybe2half<T>(epsilon), output, hasBias);
maybe2half<T>(epsilon), output,
skip_input_add_output, hasBias, hasSkipInputAdditionOutput);
} else {
constexpr int block_size = 256;
SkipLayerNormKernel<T, block_size>
<<<grid_size, block_size, 0, stream>>>(ld, input, skip, beta, gamma, bias,
maybe2half<T>(epsilon), output);
maybe2half<T>(epsilon), output, skip_input_add_output);
}
} else {
const int grid_size = element_count / ld;
if (ld <= 32) {
constexpr int block_size = 32;
SkipLayerNormKernelSmall<T, block_size, 1>
<<<grid_size, block_size, 0, stream>>>(ld, input, skip, beta, gamma, bias,
maybe2half<T>(epsilon), output, hasBias);
maybe2half<T>(epsilon), output,
skip_input_add_output, hasBias, hasSkipInputAdditionOutput);
} else if (ld <= 64) {
constexpr int block_size = 64;
SkipLayerNormKernelSmall<T, block_size, 1>
<<<grid_size, block_size, 0, stream>>>(ld, input, skip, beta, gamma, bias,
maybe2half<T>(epsilon), output, hasBias);
maybe2half<T>(epsilon), output,
skip_input_add_output, hasBias, hasSkipInputAdditionOutput);
} else if (ld <= 128) {
constexpr int block_size = 128;
SkipLayerNormKernelSmall<T, block_size, 1>
<<<grid_size, block_size, 0, stream>>>(ld, input, skip, beta, gamma, bias,
maybe2half<T>(epsilon), output, hasBias);
maybe2half<T>(epsilon), output,
skip_input_add_output, hasBias, hasSkipInputAdditionOutput);
} else if (ld == 384) {
constexpr int block_size = 384;
SkipLayerNormKernelSmall<T, block_size, 1>
<<<grid_size, block_size, 0, stream>>>(ld, input, skip, beta, gamma, bias,
maybe2half<T>(epsilon), output, hasBias);
maybe2half<T>(epsilon), output,
skip_input_add_output, hasBias, hasSkipInputAdditionOutput);
} else {
constexpr int block_size = 256;
SkipLayerNormKernel<T, block_size>
<<<grid_size, block_size, 0, stream>>>(ld, input, skip, beta, gamma, bias,
maybe2half<T>(epsilon), output);
maybe2half<T>(epsilon), output, skip_input_add_output);
}
}
return CUDA_CALL(cudaGetLastError());
}

template Status LaunchSkipLayerNormKernel<float>(cudaStream_t stream, float* output, const float* input,
const float* skip, const float* gamma, const float* beta,
const float* bias, float epsilon, const int ld,
const int element_count, size_t element_size);

template Status LaunchSkipLayerNormKernel<half>(cudaStream_t stream, half* output, const half* input,
const half* skip, const half* gamma, const half* beta,
const half* bias, float epsilon, const int ld,
const int element_count, size_t element_size);
template Status LaunchSkipLayerNormKernel<float>(cudaStream_t stream, float* output, float* skip_input_add_output,
const float* input,
const float* skip, const float* gamma, const float* beta,
const float* bias, float epsilon, const int ld,
const int element_count, size_t element_size);

template Status LaunchSkipLayerNormKernel<half>(cudaStream_t stream, half* output, half* skip_input_add_output,
const half* input,
const half* skip, const half* gamma, const half* beta,
const half* bias, float epsilon, const int ld,
const int element_count, size_t element_size);

} // namespace cuda
} // namespace contrib
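Each `SkipLayerNormKernelSmall` launch dedicates one block per row and picks a `(block_size, ILP)` pair so that `block_size * ILP` covers the hidden dimension `ld`; wider rows fall back to the strided `SkipLayerNormKernel`. A condensed sketch of the vectorized (`ld % 4 == 0`) selection, for illustration only:

```cpp
// Condensed view of the ld % 4 == 0 dispatch above: one CUDA block per row,
// with block_size * ILP >= ld so every element is touched exactly once.
struct LaunchConfig { int block_size; int ilp; };

LaunchConfig PickSmallKernelConfig(int ld) {
  if (ld <= 32)   return {32, 1};        // 32 threads * 1 element
  if (ld <= 64)   return {64 / 2, 2};    // 32 threads * 2 elements
  if (ld <= 128)  return {128 / 4, 4};   // 32 threads * 4 elements
  if (ld <= 384)  return {384 / 4, 4};
  if (ld <= 768)  return {768 / 4, 4};
  if (ld <= 1024) return {1024 / 4, 4};
  return {256, 1};  // falls back to the strided SkipLayerNormKernel
}
```
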
19 changes: 10 additions & 9 deletions onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h
@@ -11,15 +11,16 @@ namespace cuda {
template <typename T>
Status LaunchSkipLayerNormKernel(
cudaStream_t stream,
T* output, // output tensor
const T* input, // input tensor
const T* skip, // skip tensor
const T* gamma, // Layer normalization gamma tensor
const T* beta, // Layer normalization beta tensor
const T* bias, // Layer normalization bias tensor
float epsilon, // Layer normalization epsilon
int hidden_size, // hidden size, it is the leading dimension (ld)
int element_count, // number of elements in input tensor
T* output, // normalized output tensor
T* skip_input_add_output, // optional output: sum of the input and skip tensors
const T* input, // input tensor
const T* skip, // skip tensor
const T* gamma, // Layer normalization gamma tensor
const T* beta, // Layer normalization beta tensor
const T* bias, // Layer normalization bias tensor
float epsilon, // Layer normalization epsilon
int hidden_size, // hidden size, it is the leading dimension (ld)
int element_count, // number of elements in input tensor
size_t element_size);

} // namespace cuda
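For completeness, a hedged host-side sketch of the updated entry point (include path, namespace usage, and buffer setup are illustrative; all `d_*` pointers are assumed already allocated on the device): pass a real buffer to request the fourth output, or `nullptr` to skip the extra store.

```cpp
#include <cuda_runtime.h>

#include "contrib_ops/cuda/bert/skip_layer_norm_impl.h"  // illustrative path

using namespace onnxruntime;
using contrib::cuda::LaunchSkipLayerNormKernel;

// Sketch: run the fused kernel twice, with and without the optional
// input + skip sum output.
Status RunSkipLayerNormExample(cudaStream_t stream,
                               float* d_output, float* d_input_skip_sum,
                               const float* d_input, const float* d_skip,
                               const float* d_gamma, const float* d_beta,
                               const float* d_bias,
                               int batch_size, int sequence_length,
                               int hidden_size) {
  const int element_count = batch_size * sequence_length * hidden_size;

  // Request the input + skip sum alongside the normalized output.
  ORT_RETURN_IF_ERROR(LaunchSkipLayerNormKernel<float>(
      stream, d_output, /*skip_input_add_output=*/d_input_skip_sum,
      d_input, d_skip, d_gamma, d_beta, d_bias,
      /*epsilon=*/1e-12f, hidden_size, element_count, sizeof(float)));

  // Or skip the extra store entirely when the 4th output is not defined.
  return LaunchSkipLayerNormKernel<float>(
      stream, d_output, /*skip_input_add_output=*/nullptr,
      d_input, d_skip, d_gamma, d_beta, d_bias,
      /*epsilon=*/1e-12f, hidden_size, element_count, sizeof(float));
}
```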