diff --git a/.github/workflows/nv-pre-compile-ops.yml b/.github/workflows/nv-pre-compile-ops.yml index e03cbb353bd3..ccb6c25e14f7 100644 --- a/.github/workflows/nv-pre-compile-ops.yml +++ b/.github/workflows/nv-pre-compile-ops.yml @@ -33,7 +33,7 @@ jobs: #python -c "import torch; print('CUDA available:', torch.cuda.is_available())" - name: Compile DeepSpeed Ops run: | - TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 pip3 install . + TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install . - name: DS Report run: | ds_report diff --git a/README.md b/README.md index 4999a485f4ce..459bf6d82433 100755 --- a/README.md +++ b/README.md @@ -15,11 +15,11 @@ ## Latest News DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat). +* [2023/09] Announcing the DeepSpeed4Science Initiative: Enabling large-scale scientific discovery through sophisticated AI system technologies [[DeepSpeed4Science website](https://deepspeed4science.ai/)] [[Tutorials](https://www.deepspeed.ai/deepspeed4science/)] * [2023/08] [DeepSpeed ZeRO-Inference: 20X faster inference through weight quantization and KV cache offloading](https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/README.md) * [2023/08] [DeepSpeed-Chat: Llama/Llama-2 system support, efficiency boost, and training stability improvements](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/ds-chat-release-8-31/README.md) * [2023/08] [DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ulysses) * [2023/06] [ZeRO++: A leap in speed for LLM and chat model training with 4X less communication](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/)[[English](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/zeropp/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/zeropp/japanese/README.md)] -* [2023/04] 🚀 [DeepSpeed Chat: Easy, Fast and Affordable RLHF Training of ChatGPT-like Models at All Scales](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/japanese/README.md)]🚀 --- @@ -35,9 +35,9 @@ --- -# DeepSpeed's three innovation pillars +# DeepSpeed's four innovation pillars - + ## DeepSpeed-Training @@ -53,6 +53,10 @@ DeepSpeed brings together innovations in parallelism technology such as tensor, To further increase the inference efficiency, DeepSpeed offers easy-to-use and flexible-to-compose compression techniques for researchers and practitioners to compress their models while delivering faster speed, smaller model size, and significantly reduced compression cost. 
Moreover, SoTA innovations on compression like ZeroQuant and XTC are included under the compression pillar. Learn more: [DeepSpeed-Compression](https://www.deepspeed.ai/compression) +## DeepSpeed4Science + +In line with Microsoft's mission to solve humanity's most pressing challenges, the DeepSpeed team at Microsoft is responding to this opportunity by launching a new initiative called *DeepSpeed4Science*, aiming to build unique capabilities through AI system technology innovations to help domain experts to unlock today's biggest science mysteries. Learn more: [DeepSpeed4Science website](https://deepspeed4science.ai/) and [tutorials](https://www.deepspeed.ai/deepspeed4science/) + --- # DeepSpeed Software Suite diff --git a/csrc/deepspeed4science/evoformer_attn/attention.cpp b/csrc/deepspeed4science/evoformer_attn/attention.cpp new file mode 100644 index 000000000000..ac3364539ff1 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/attention.cpp @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include + +void attention_impl(torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& bias1, + torch::Tensor& bias2, + torch::Tensor& o, + torch::Tensor& lse); +void attention(torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& bias1, + torch::Tensor& bias2, + torch::Tensor& o, + torch::Tensor& lse) +{ + attention_impl(q, k, v, bias1, bias2, o, lse); +} + +void attention_back_impl(torch::Tensor& go, + torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& o, + torch::Tensor& lse, + torch::Tensor& delta, + torch::Tensor& bias1, + torch::Tensor& bias2, + torch::Tensor& gq, + torch::Tensor& gk, + torch::Tensor& gv, + torch::Tensor& gb1, + torch::Tensor& gb2); +void attention_bwd(torch::Tensor& go, + torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& o, + torch::Tensor& lse, + torch::Tensor& delta, + torch::Tensor& bias1, + torch::Tensor& bias2, + torch::Tensor& gq, + torch::Tensor& gk, + torch::Tensor& gv, + torch::Tensor& gb1, + torch::Tensor& gb2) +{ + attention_back_impl(go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("attention", &attention, ""); + m.def("attention_bwd", &attention_bwd, ""); +} diff --git a/csrc/deepspeed4science/evoformer_attn/attention.cu b/csrc/deepspeed4science/evoformer_attn/attention.cu new file mode 100644 index 000000000000..b409dd2e401d --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/attention.cu @@ -0,0 +1,160 @@ +// Copyright (c) Microsoft Corporation. 
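The `attention` / `attention_bwd` bindings in attention.cpp above take every tensor by reference and treat an empty tensor (size(0) == 0) as "not provided". Below is a minimal C++ sketch of driving the forward binding; the shapes and dtypes are assumptions inferred from attention.cu further down (q/k/v/o viewed as [B, N, seq_len, heads, head_dim]), and the helper name and sizes are illustrative only, not part of this patch.

#include <torch/torch.h>

// Hypothetical driver: one forward pass with both pair biases and the
// logsumexp output disabled via empty tensors, as the size(0) == 0 checks allow.
void run_evoformer_attn_forward_once()
{
    const int64_t B = 1, N = 4, S = 128, H = 4, D = 32;  // assumed example sizes
    auto opts = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA);
    auto q = torch::randn({B, N, S, H, D}, opts);
    auto k = torch::randn({B, N, S, H, D}, opts);
    auto v = torch::randn({B, N, S, H, D}, opts);
    auto o = torch::empty_like(q);
    auto none = torch::empty({0}, opts);                      // bias1/bias2 unused
    auto lse = torch::empty({0}, opts.dtype(torch::kFloat));  // skip logsumexp (inference only)
    attention(q, k, v, none, none, o, lse);                   // binding declared in attention.cpp
}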
+// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include +#include "gemm_kernel_utils.h" +#include "kernel_forward.h" +#include "transform/bias_broadcast.h" + +template + class Broadcast1_, + template + class Broadcast2_> +typename std::enable_if::value>::type attention_impl_template( + torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& bias1, + torch::Tensor& bias2, + torch::Tensor& o, + float* lse_ptr) +{ + EVOFORMER_CHECK(false, "Unsupported GPU and data type combination") +} + +template + class Broadcast1_, + template + class Broadcast2_> +typename std::enable_if::value>::type attention_impl_template( + torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& bias1, + torch::Tensor& bias2, + torch::Tensor& o, + float* lse_ptr) +{ + // Attention definition goes here, replaced with BroadcastType1 and + // BroadcastType2 + using Attention = AttentionKernel; + + static_assert(!Attention::kNeedsOutputAccumulatorBuffer, + "This test does not support output accumulator buffer"); + int head_size = q.size(-1); + int head_number = q.size(-2); + int seq_length = q.size(-3); + auto q_view = q.view({-1, seq_length, head_number, head_size}); + auto k_view = k.view({-1, seq_length, head_number, head_size}); + auto v_view = v.view({-1, seq_length, head_number, head_size}); + auto o_view = o.view({-1, seq_length, head_number, head_size}); + int batch_size = q_view.size(0); + auto q_ptr = reinterpret_cast(q.data_ptr()); + auto k_ptr = reinterpret_cast(k.data_ptr()); + auto v_ptr = reinterpret_cast(v.data_ptr()); + auto o_ptr = reinterpret_cast(o.data_ptr()); + + auto bias1_ptr = reinterpret_cast(bias1.data_ptr()); + auto bias2_ptr = reinterpret_cast(bias2.data_ptr()); + + typename Attention::Params p; + { // set parameters + p.query_ptr = q_ptr; + p.key_ptr = k_ptr; + p.value_ptr = v_ptr; + p.logsumexp_ptr = lse_ptr; // Only needed for bw + p.output_accum_ptr = nullptr; + p.output_ptr = o_ptr; + p.scale = 1.0f / sqrt(float(head_size)); + + p.bias1_ptr = bias1_ptr; + p.bias2_ptr = bias2_ptr; + p.B = q.size(0); + p.N = q.size(1); + + p.num_heads = head_number; + p.num_batches = batch_size; + p.head_dim = head_size; + p.head_dim_value = head_size; + p.num_queries = seq_length; + p.num_keys = seq_length; + + // All tensors are in BMHK shapes + p.q_strideH = q_view.stride(-2); + p.k_strideH = k_view.stride(-2); + p.v_strideH = v_view.stride(-2); + p.q_strideM = q_view.stride(-3); + p.k_strideM = k_view.stride(-3); + p.v_strideM = v_view.stride(-3); + p.o_strideM = o_view.stride(-3); + p.q_strideB = q_view.stride(-4); + p.k_strideB = k_view.stride(-4); + p.v_strideB = v_view.stride(-4); + } + + constexpr auto kernel_fn = attention_kernel_batched_impl; + int smem_bytes = sizeof(typename Attention::SharedStorage); + if (smem_bytes > 0xc000) { + cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); + } + if (!Attention::check_supported(p)) { throw std::runtime_error("Parameters not supported"); } + kernel_fn<<>>(p); +} + +#define CODE(scalar_t, torch_scalar_t) \ + do { \ + if (bias1.size(0) == 0 && bias2.size(0) == 0) { \ + attention_impl_template(q, k, v, bias1, bias2, o, lse_ptr); \ + } else if (bias1.size(0) == 0) { \ + attention_impl_template(q, k, v, bias1, bias2, o, lse_ptr); \ + } else if (bias2.size(0) == 0) { \ + attention_impl_template(q, k, v, bias1, bias2, o, lse_ptr); \ + } else { \ + attention_impl_template( \ + q, k, v, bias1, bias2, o, lse_ptr); \ + } \ + } while (0) + +// Function to select 
and call the correct template based on biases sizes +void attention_impl(torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& bias1, + torch::Tensor& bias2, + torch::Tensor& o, + torch::Tensor& lse) +{ + auto lse_ptr = lse.size(0) == 0 ? nullptr : reinterpret_cast(lse.data_ptr()); + cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); + DISPATCH_ARCHTAG(prop->major * 10 + prop->minor, + DISPATCH_TYPES(q, ([&]() { CODE(scalar_t, torch_scalar_t); }))); +} diff --git a/csrc/deepspeed4science/evoformer_attn/attention_back.cu b/csrc/deepspeed4science/evoformer_attn/attention_back.cu new file mode 100644 index 000000000000..1be647b5e552 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/attention_back.cu @@ -0,0 +1,218 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include +#include +#include "gemm_kernel_utils.h" +#include "kernel_backward.h" +#include "transform/bias_broadcast.h" + +constexpr auto kBlockSizeI = 64; +constexpr auto kBlockSizeJ = 64; + +template + class Broadcast1_, + template + class Broadcast2_> +typename std::enable_if::value>::type attention_back_impl_template( + torch::Tensor& go, + torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& o, + torch::Tensor& lse, + torch::Tensor& delta, + torch::Tensor& bias1, + torch::Tensor& bias2, + torch::Tensor& gq, + torch::Tensor& gk, + torch::Tensor& gv, + torch::Tensor& gb1, + torch::Tensor& gb2) +{ + EVOFORMER_CHECK(false, "Unsupported GPU and data type combination") +} + +template + class Broadcast1_, + template + class Broadcast2_> +typename std::enable_if::value>::type attention_back_impl_template( + torch::Tensor& go, + torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& o, + torch::Tensor& lse, + torch::Tensor& delta, + torch::Tensor& bias1, + torch::Tensor& bias2, + torch::Tensor& gq, + torch::Tensor& gk, + torch::Tensor& gv, + torch::Tensor& gb1, + torch::Tensor& gb2) +{ + constexpr bool kPreload_ = arch::kMinComputeCapability >= 80; + using Kernel = AttentionBackwardKernel; + int head_size = q.size(-1); + int head_number = q.size(-2); + int seq_length = q.size(-3); + auto q_view = q.view({-1, seq_length, head_number, head_size}); + auto k_view = k.view({-1, seq_length, head_number, head_size}); + auto v_view = v.view({-1, seq_length, head_number, head_size}); + auto o_view = o.view({-1, seq_length, head_number, head_size}); + auto do_view = go.view({-1, seq_length, head_number, head_size}); + auto dk_view = gk.view({-1, seq_length, head_number, head_size}); + auto dv_view = gv.view({-1, seq_length, head_number, head_size}); + auto dq_view = gq.view({-1, seq_length, head_number, head_size}); + auto q_ptr = reinterpret_cast(q.data_ptr()); + auto k_ptr = reinterpret_cast(k.data_ptr()); + auto v_ptr = reinterpret_cast(v.data_ptr()); + auto o_ptr = reinterpret_cast(o.data_ptr()); + auto do_ptr = reinterpret_cast(go.data_ptr()); + auto dk_ptr = reinterpret_cast(gk.data_ptr()); + auto dv_ptr = reinterpret_cast(gv.data_ptr()); + auto dq_ptr = reinterpret_cast(gq.data_ptr()); + auto db1_ptr = gb1.size(0) > 0 ? reinterpret_cast(gb1.data_ptr()) : nullptr; + auto db2_ptr = gb2.size(0) > 0 ? 
reinterpret_cast(gb2.data_ptr()) : nullptr; + auto lse_ptr = reinterpret_cast(lse.data_ptr()); + auto delta_ptr = reinterpret_cast(delta.data_ptr()); + auto bias1_ptr = reinterpret_cast(bias1.data_ptr()); + auto bias2_ptr = reinterpret_cast(bias2.data_ptr()); + static_assert(Kernel::kKernelComputesDelta, "Kernel must compute delta"); + + typename Kernel::Params p; + p.query_ptr = q_ptr; + p.key_ptr = k_ptr; + p.value_ptr = v_ptr; + p.logsumexp_ptr = lse_ptr; + p.output_ptr = o_ptr; + p.grad_output_ptr = do_ptr; + p.delta_ptr = delta_ptr; + p.grad_query_ptr = dq_ptr; + p.grad_key_ptr = dk_ptr; + p.grad_value_ptr = dv_ptr; + + p.grad_bias1_ptr = db1_ptr; + p.grad_bias2_ptr = db2_ptr; + p.B = q.size(0); + p.N = q.size(1); + p.bias1_ptr = bias1.size(0) ? bias1_ptr : nullptr; + p.bias2_ptr = bias2.size(0) ? bias2_ptr : nullptr; + + p.scale = 1.0f / sqrtf(head_size); + + p.head_dim = head_size; + p.head_dim_value = head_size; + p.num_queries = seq_length; + p.num_keys = seq_length; + p.num_heads = head_number; + + p.q_strideM = q_view.stride(-3); + p.k_strideM = k_view.stride(-3); + p.v_strideM = v_view.stride(-3); + p.gO_strideM = do_view.stride(-3); + p.o_strideH = o_view.stride(-2); + p.q_strideH = q_view.stride(-2); + p.k_strideH = k_view.stride(-2); + p.v_strideH = v_view.stride(-2); + p.o_strideB = o_view.stride(-4); + p.q_strideB = q_view.stride(-4); + p.k_strideB = k_view.stride(-4); + p.v_strideB = v_view.stride(-4); + p.lse_strideB = lse.stride(-3); + p.lse_strideH = lse.stride(-2); + p.delta_strideB = delta.stride(-3); + p.delta_strideH = delta.stride(-2); + p.num_batches = q_view.size(-4); + + p.gO_strideB = do_view.stride(-4); + p.gQ_strideB = dq_view.stride(-4); + p.gK_strideB = dk_view.stride(-4); + p.gV_strideB = dv_view.stride(-4); + p.gO_strideH = do_view.stride(-2); + p.gQ_strideH = dq_view.stride(-2); + p.gK_strideH = dk_view.stride(-2); + p.gV_strideH = dv_view.stride(-2); + + torch::Tensor workspace = torch::empty(p.workspace_size() / 4, lse.options()); + p.workspace = workspace.data_ptr(); + + auto kernel_fn = attention_kernel_backward_batched_impl; + size_t smem_bytes = sizeof(typename Kernel::SharedStorage); + cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, int(smem_bytes)); + if (!Kernel::check_supported(p)) { throw std::runtime_error("Unsupported parameters"); } + kernel_fn<<>>(p); +} + +#define CODE(scalar_t, torch_scalar_t) \ + do { \ + if (bias1.size(0) == 0 && bias2.size(0) == 0) { \ + attention_back_impl_template( \ + go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); \ + } else if (bias1.size(0) > 0 && bias2.size(0) > 0) { \ + attention_back_impl_template( \ + go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); \ + } else if (bias1.size(0) > 0) { \ + attention_back_impl_template( \ + go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); \ + } else { \ + attention_back_impl_template( \ + go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); \ + } \ + } while (0) + +void attention_back_impl(torch::Tensor& go, + torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& o, + torch::Tensor& lse, + torch::Tensor& delta, + torch::Tensor& bias1, + torch::Tensor& bias2, + torch::Tensor& gq, + torch::Tensor& gk, + torch::Tensor& gv, + torch::Tensor& gb1, + torch::Tensor& gb2) +{ + cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); + DISPATCH_ARCHTAG(prop->major * 10 + prop->minor, + DISPATCH_TYPES(q, ([&]() { CODE(scalar_t, torch_scalar_t); }))); +} diff --git 
a/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_grad_bias.h b/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_grad_bias.h new file mode 100644 index 000000000000..17b6479ed8c5 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_grad_bias.h @@ -0,0 +1,250 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once +#include +#include +#include "../iterators/predicated_tile_iterator_atomic.h" +#include "cutlass/epilogue/threadblock/epilogue.h" + +namespace cutlass { +namespace epilogue { +namespace threadblock { +template +struct EpilogueTensorOpAffineRankN : public DefaultEpilogueTensorOpAffineRankN { + using Base = DefaultEpilogueTensorOpAffineRankN; + using OutputTileIterator = + cutlass::epilogue::threadblock::PredicatedTileIteratorAffineRankNAtomic< + typename Base::OutputTileThreadMap, + typename Base::ElementOutput, + Rank>; + + using Epilogue = + cutlass::epilogue::threadblock::Epilogue; +}; + +template +struct EpilogueVoltaTensorOpAffineRankN + : public DefaultEpilogueVoltaTensorOpAffineRankN { + using Base = DefaultEpilogueVoltaTensorOpAffineRankN; + using OutputTileIterator = + cutlass::epilogue::threadblock::PredicatedTileIteratorAffineRankNAtomic< + typename Base::OutputTileThreadMap, + typename Base::ElementOutput, + Rank>; + + using Epilogue = + cutlass::epilogue::threadblock::Epilogue; +}; + +template +struct EpilogueTensorOp : public DefaultEpilogueTensorOp { + using Base = DefaultEpilogueTensorOp; + using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorAtomic< + typename Base::OutputTileThreadMap, + typename Base::ElementOutput, + ScatterD, + PermuteDLayout>; + using Epilogue = + cutlass::epilogue::threadblock::Epilogue; +}; + +template +struct EpilogueVoltaTensorOp : public DefaultEpilogueVoltaTensorOp { + using Base = DefaultEpilogueVoltaTensorOp; + using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorAtomic< + typename Base::OutputTileThreadMap, + typename Base::ElementOutput, + ScatterD, + PermuteDLayout>; + using Epilogue = + cutlass::epilogue::threadblock::Epilogue; +}; +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +template +struct BiasGradEpilogue { + using Epilogue = + typename cutlass::epilogue::threadblock::EpilogueTensorOp::Epilogue; +}; + +template +struct BiasGradEpilogue { + using Epilogue = + typename cutlass::epilogue::threadblock::EpilogueVoltaTensorOp::Epilogue; +}; + +template +struct BiasGradEpilogueAffineRankN { + using Epilogue = typename cutlass::epilogue::threadblock::EpilogueTensorOpAffineRankN< + Rank, + Shape_, + WarpMmaTensorOp_, + PartitionsK, + OutputOp_, + ElementsPerAccess>::Epilogue; +}; + +template +struct BiasGradEpilogueAffineRankN { + using Epilogue = typename cutlass::epilogue::threadblock::EpilogueVoltaTensorOpAffineRankN< + Rank, + Shape_, + WarpMmaTensorOp_, + PartitionsK, + OutputOp_, + ElementsPerAccess>::Epilogue; +}; diff --git a/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_pipelined.h b/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_pipelined.h new file mode 100644 index 000000000000..3b7b32d61452 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_pipelined.h @@ -0,0 +1,592 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. 
+ + File copied from "cutlass/epilogue/threadblock/epilogue.h" + then modified to: + (1) load 2 source fragments at the same time (pipelining) + (2) support reading from a different dtype + (3) pass the row id to the OutputOp if it takes it + (see MemoryEfficientAttentionNormalize) + Note that in general the fragment passed to the OutputOp could + span multiple rows but it does not happen with the configurations we have +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/functional.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +template +struct ApplyEpilogueOp { + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum, + typename Op::FragmentOutput const& source) + { + return output_op(accum, source); + } + static CUTLASS_DEVICE typename Op::FragmentOutput + apply(Op const& output_op, int row_id, typename Op::FragmentAccumulator const& accum) + { + return output_op(accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Epilogue operator +template ::value), + typename OutputTileSourceIterator_ = + OutputTileIterator_ ///< Tile iterator reading tensors + > +class EpiloguePipelined : public EpilogueBase { +public: + using Base = EpilogueBase; + + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + using OutputTileIterator = OutputTileIterator_; + using OutputTileSourceIterator = OutputTileSourceIterator_; + using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; + using WarpTileIterator = WarpTileIterator_; + using SharedLoadIterator = SharedLoadIterator_; + using OutputOp = OutputOp_; + using Padding = Padding_; + + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; + + /// The complete warp-level accumulator tile + using AccumulatorTile = typename Base::AccumulatorTile; + + /// Accumulator element + using ElementAccumulator = typename WarpTileIterator::Element; + + /// Output element + using ElementOutput = typename OutputTileIterator::Element; + using ElementSource = typename OutputTileSourceIterator::Element; + + /// Output access size + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + /// Tensor reference to destination tensor + using TensorRef = typename OutputTileIterator::TensorRef; + + /// Tensor reference to sync tensor + using SyncTensorRef = typename cutlass::TensorRef; + + /// Const tensor reference to source tensor + using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; + + /// Array type used to output + using OutputAccessType = + Array; + using SourceAccessType = Array; + + /// Array type used by output functor + using AccumulatorAccessType = + Array; + + /// Number of warps + 
using WarpCount = typename Base::WarpCount; + + static int constexpr kSmemTiles = + Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK; + static int constexpr kSmemPointerOffset = + Base::SharedStorage::StorageShape::kCount / kSmemTiles; + +public: + static_assert(OutputTileSourceIterator::Fragment::kElements == + OutputTileIterator::Fragment::kElements, + "Mismatch between input tile and output tile iterator (kElements)"); + static_assert(OutputTileSourceIterator::kIterations == OutputTileIterator::kIterations, + "Mismatch between input tile and output tile iterator (kIterations)"); + static_assert(SharedLoadIterator::Fragment::kElements == + OutputTileIterator::Fragment::kElements, + "Mismatch between shared load iterator and output tile iterator."); + + static_assert(OutputTileIterator::kElementsPerAccess, + "OutputTileIterator::kElementsPerAccess must not be zero."); + + static_assert(!(OutputTileIterator::Fragment::kElements % + OutputTileIterator::kElementsPerAccess), + "Divisibility"); + +private: + /// Loads fragment from shared memory aligned with output tensor + SharedLoadIterator shared_load_iterator_; + +public: + /// Constructor + CUTLASS_DEVICE + EpiloguePipelined(typename Base::SharedStorage& shared_storage, ///< Shared storage object + int thread_idx, ///< ID of a thread within the threadblock + int warp_idx, ///< ID of warp within threadblock + int lane_idx ///< Id of thread within warp + ) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + shared_load_iterator_(shared_storage.reference(), thread_idx) + { + } + + /// Streams the result to global memory + CUTLASS_DEVICE + void operator()(OutputOp const& output_op, ///< Output operator + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& accumulators, ///< Complete warp-level accumulator tile + OutputTileSourceIterator source_iterator) + { ///< Threadblock tile coordinate in GEMM (in units + ///< of threadblock tiles) + + if (!output_op.is_source_needed()) { + compute_source_not_needed_(output_op, destination_iterator, accumulators); + } else { + compute_source_needed_(output_op, destination_iterator, accumulators, source_iterator); + } + } + CUTLASS_DEVICE + void operator()(OutputOp const& output_op, ///< Output operator + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& accumulators) + { ///< Complete warp-level accumulator tile + compute_source_not_needed_(output_op, destination_iterator, accumulators); + } + +private: + template + struct acc2smem_source_not_needed; + + template + struct acc2smem_source_not_needed> { + template + CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator& warp_tile_iterator) + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { ++accum_fragment_iterator; } + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + typename AccumulatorFragmentIterator::Fragment accum_fragment; + + accum_fragment_iterator.load(accum_fragment); + ++accum_fragment_iterator; + + warp_tile_iterator.store(accum_fragment); + if (p < Base::kFragmentsPerIteration - 1) { + warp_tile_iterator.add_pointer_offset(kSmemPointerOffset); + } + } + + if (Base::kFragmentsPerIteration > 1) { + warp_tile_iterator.add_pointer_offset(kSmemPointerOffset * + (1 - Base::kFragmentsPerIteration)); + } + } + + CUTLASS_DEVICE + static void push(size_t pos, + AccumulatorFragmentIterator const& iterator_begin, + 
WarpTileIterator& warp_tile_iterator) + { + int dummy[] = { + (pos == (Seq * Base::kFragmentsPerIteration)) && + (helper(iterator_begin, warp_tile_iterator), + 0)...}; + + CUTLASS_UNUSED(dummy[0]); + } + }; + + static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, + "One of these must be exactly 1."); + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_not_needed_( + OutputOp const& output_op, ///< Output operator + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& accumulators ///< Complete warp-level accumulator tile + ) + { + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + +#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration \ + : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; + iter += Base::kFragmentsPerIteration) { + // + // Convert and store fragment + // + + __syncthreads(); + + acc2smem_source_not_needed>:: + push(iter, accum_fragment_iterator, this->warp_tile_iterator_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); + + if (p < Base::kFragmentsPerIteration - 1) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + } else if (kPartitionsK > 1) { + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = + add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + + shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * + kSmemPointerOffset); + } + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment; + + apply_output_operator_source_not_needed_(destination_iterator.thread_start_row(), + output_fragment, + output_op, + aligned_accum_fragment[0]); + + // + // Store the final result + // + + destination_iterator.store(output_fragment); + ++destination_iterator; + } + + if (Base::kFragmentsPerIteration > 1) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset * + (1 - Base::kFragmentsPerIteration)); + } + } + } + + template + struct acc2smem_source_needed; + + template + struct acc2smem_source_needed> { + template + CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator& warp_tile_iterator) + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { ++accum_fragment_iterator; } + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + accum_fragment_iterator.load(accum_fragment); + warp_tile_iterator.store(accum_fragment); + } + + CUTLASS_DEVICE + static void push(size_t pos, + AccumulatorFragmentIterator const& iterator_begin, + WarpTileIterator& warp_tile_iterator) + { + int dummy[] = {(pos == Seq) && (helper(iterator_begin, warp_tile_iterator), 0)...}; + } + }; + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_needed_( + OutputOp const& output_op, ///< Output operator + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& accumulators, 
///< Complete warp-level accumulator tile + OutputTileSourceIterator source_iterator ///< Threadblock tile coordinate in GEMM (in units + ///< of threadblock tiles) + ) + { + typename OutputTileSourceIterator::Fragment source_fragment[2]; + + source_fragment[0].clear(); + source_iterator.load(source_fragment[0]); + ++source_iterator; + source_fragment[1].clear(); + + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + +#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + if (iter > 0) { __syncthreads(); } + // + // Load the source for next iteration (pipelining) + // + + if (iter + 1 < OutputTileIterator::kIterations) { + source_iterator.load(source_fragment[(iter + 1) % 2]); + } + ++source_iterator; + acc2smem_source_needed>:: + push(iter, accum_fragment_iterator, this->warp_tile_iterator_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); + + // If the number of k-slices is > 1 - perform a reduction amongst the + // k-slices + if (kPartitionsK > 1) { + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = + add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + + shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); + } + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment; + + apply_output_operator_(destination_iterator.thread_start_row(), + output_fragment, + output_op, + aligned_accum_fragment[0], + source_fragment[iter % 2]); + + // + // Store the final result + // + + destination_iterator.store(output_fragment); + ++destination_iterator; + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_(int begin_row, + typename OutputTileIterator::Fragment& output_fragment, + OutputOp const& output_op, ///< Output operator + typename SharedLoadIterator::Fragment const& aligned_accum_fragment, + typename OutputTileSourceIterator::Fragment const& source_fragment) + { + OutputAccessType* output_frag_ptr = reinterpret_cast(&output_fragment); + + AccumulatorAccessType const* compute_frag_ptr = + reinterpret_cast(&aligned_accum_fragment); + + SourceAccessType const* source_frag_ptr = + reinterpret_cast(&source_fragment); + + int const kOutputOpIterations = + OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + // Call the output operator + output_frag_ptr[i] = ApplyEpilogueOp::apply( + output_op, + begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess), + compute_frag_ptr[i], + source_frag_ptr[i]); + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_source_not_needed_( + int begin_row, + typename OutputTileIterator::Fragment& output_fragment, + OutputOp const& output_op, ///< Output operator + typename SharedLoadIterator::Fragment const& aligned_accum_fragment) + { + OutputAccessType* 
output_frag_ptr = reinterpret_cast(&output_fragment); + + AccumulatorAccessType const* compute_frag_ptr = + reinterpret_cast(&aligned_accum_fragment); + + int const kOutputOpIterations = + OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + // Call the output operator + output_frag_ptr[i] = ApplyEpilogueOp::apply( + output_op, + begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess), + compute_frag_ptr[i]); + } + } + + // This should be constexpr, but it's only supported on c++14 + static int CUTLASS_HOST_DEVICE getRowOffset(int i) + { + using ThreadMap = typename OutputTileIterator::ThreadMap; + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + int frag_idx = ThreadMap::kElementsPerAccess * + (frag_row_idx * ThreadMap::Iterations::kColumn + column); + if (i < frag_idx + ThreadMap::kElementsPerAccess) { return row_offset; } + } + } + } + } + return -1; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_rescale_output.h b/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_rescale_output.h new file mode 100644 index 000000000000..f81a09f74f1e --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_rescale_output.h @@ -0,0 +1,251 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory + to match canonical tensor layouts in global memory. Epilogues support + conversion and reduction operations. + + This is a copy of cutlass/epilogue/threadblock/epilogue.h that can + handle "row_id" as a first argument, as uses it to get the corresponding + `m_prime` / `s_prime` to rescale the output. +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/functional.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "epilogue_pipelined.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies a linear combination operator to an array of elements. 
+// output <- alpha * accumulator + beta * source +// with: +// alpha = 1 / s_prime (to normalize when isLast=True, 1 otherwise) +// beta = alpha / m_prime (renormalize the output when the max changes) +// source is the current output +template , + ///< but we use 64 or 32 sometimes when there are not enough data + ///< to store + typename ElementAccumulator_, ///< Accumulator data type + typename ElementCompute_, ///< Data type used to compute linear combination + bool isFirst, + bool isLast, + typename FragmentAlphaBeta_, + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest> +class MemoryEfficientAttentionNormalize { +public: + using ElementOutput = ElementOutput_; + using ElementSource = ElementSource_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + + static int const kCount = Count; + + using FragmentOutput = Array; + using FragmentSource = Array; + using FragmentAccumulator = Array; + using ComputeFragment = Array; + using FragmentAlphaBeta = FragmentAlphaBeta_; + + static FloatRoundStyle const kRound = Round; + +private: + // + // Data members + // + + FragmentAlphaBeta const& s_prime_; + FragmentAlphaBeta const& m_prime_; + +public: + /// Constructs the function object, possibly loading from pointers in host + /// memory + CUTLASS_HOST_DEVICE + MemoryEfficientAttentionNormalize(FragmentAlphaBeta const& s_prime, + FragmentAlphaBeta const& m_prime) + : s_prime_(s_prime), m_prime_(m_prime) + { + } + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { return !isFirst; } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) {} + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()(int row, + FragmentAccumulator const& accumulator, + FragmentSource const& source) const + { + assert(!isFirst); + + // Convert source to internal compute numeric type + NumericArrayConverter source_converter; + NumericArrayConverter + accumulator_converter; + + // Convert to destination numeric type + NumericArrayConverter destination_converter; + + ComputeFragment converted_source = source_converter(source); + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + ComputeFragment intermediate; + + multiplies mul_add_source; + multiply_add mul_add_accumulator; + + ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1; + ElementCompute beta = alpha * m_prime_[row]; + + intermediate = mul_add_source(beta, converted_source); // X = beta * C + + intermediate = mul_add_accumulator( + alpha, converted_accumulator, intermediate); // D = alpha * Accum + X + + return destination_converter(intermediate); + } + + /// Computes linear scaling: D = alpha * accumulator + CUTLASS_HOST_DEVICE + FragmentOutput operator()(int row, FragmentAccumulator const& accumulator) const + { + assert(isFirst); + + // Convert source to internal compute numeric type + NumericArrayConverter + accumulator_converter; + + // Convert to destination numeric type + NumericArrayConverter destination_converter; + + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + ComputeFragment intermediate; + multiplies mul_accumulator; + + ElementCompute alpha = isLast ? 
(1 / s_prime_[row]) : 1; + + intermediate = mul_accumulator(alpha, converted_accumulator); // X = alpha * C + uniform + + return destination_converter(intermediate); + } +}; + +} // namespace thread + +namespace threadblock { +template +struct ApplyEpilogueOp< + thread::MemoryEfficientAttentionNormalize> { + using Op = thread::MemoryEfficientAttentionNormalize; + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum, + typename Op::FragmentSource const& source) + { + return output_op(row_id, accum, source); + } + static CUTLASS_DEVICE typename Op::FragmentOutput + apply(Op const& output_op, int row_id, typename Op::FragmentAccumulator const& accum) + { + return output_op(row_id, accum); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_thread_apply_logsumexp.h b/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_thread_apply_logsumexp.h new file mode 100644 index 000000000000..46fb2bf17c1c --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_thread_apply_logsumexp.h @@ -0,0 +1,168 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/*! \file + \brief Functor performing linear combination operations used by epilogues. 
+*/ + +#pragma once + +#include + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct ArrayExponential { + CUTLASS_HOST_DEVICE + Array operator()( + Array const& input) const + { + Array result; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ElementsPerAccess; ++i) { result[i] = expf(input[i]); } + + return result; + } +}; + +template +struct ArrayExponential { + CUTLASS_DEVICE + Array operator()(Array const& input) const + { + Array result; + + int const kVectorCount = ElementsPerAccess / 2; + + __half2 const* input_ptr = reinterpret_cast<__half2 const*>(input.raw_data()); + __half2* res_ptr = reinterpret_cast<__half2*>(result.raw_data()); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kVectorCount; ++i) { res_ptr[i] = h2exp(input_ptr[i]); } + + return result; + } +}; +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies: +/// output <- (input - lse).exp() +template +class ApplyLogSumExp { +public: + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + using ElementLSE = ElementLSE_; + + static int const kElementsPerAccess = ElementsPerAccess; + static int const kCount = kElementsPerAccess; + static const ScaleType::Kind kScale = cutlass::epilogue::thread::ScaleType::NoBetaScaling; + + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using FragmentCompute = Array; + using FragmentLSE = Array; + using FragmentScaleBias = FragmentLSE; // Used by epilogue_smem_accumulator.h + +public: + // + // Methods + // + + CUTLASS_HOST_DEVICE + ApplyLogSumExp() {} + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { return true; } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) {} + + CUTLASS_HOST_DEVICE + FragmentOutput operator()(FragmentAccumulator const& AB, + FragmentLSE const& scale_unused, + // bias used as LSE + FragmentLSE const& bias) const + { + FragmentCompute frag_AB = + NumericArrayConverter()(AB); + FragmentCompute frag_lse_compute = + NumericArrayConverter()(bias); + FragmentCompute frag_compute; + + minus minus_lse; + detail::ArrayExponential apply_exp; + frag_compute = minus_lse(frag_AB, frag_lse_compute); + frag_compute = apply_exp(frag_compute); + + return NumericArrayConverter()( + frag_compute); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma.h b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma.h new file mode 100644 index 000000000000..75833bbfe7d2 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma.h @@ -0,0 +1,119 @@ 
+/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include "custom_mma_multistage.h" +#include "custom_mma_pipelined.h" +#include "cutlass/gemm/threadblock/mma_multistage.h" +#include "cutlass/gemm/threadblock/mma_pipelined.h" + +template +struct MakeCustomMma; + +template +struct MakeCustomMma, + kMaxK> { + // Reduce the number of stages if we don't need that many + static int constexpr kStages = + kMaxK == cutlass::platform::numeric_limits::max() + ? Stages + : cutlass::const_min(Stages, (kMaxK + int(Shape::kK) - 1) / int(Shape::kK)); + using Mma = cutlass::gemm::threadblock::CustomMmaMultistage; +}; + +template +struct MakeCustomMma, + kMaxK> { + using Mma = cutlass::gemm::threadblock::CustomMmaPipelined; +}; diff --git a/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_base.h b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_base.h new file mode 100644 index 000000000000..bbf91240b900 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_base.h @@ -0,0 +1,181 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class CustomMmaBase { +public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. 
+ using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = + GemmShape; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + template + struct OperandSharedStorage { + AlignedBuffer buffer; + using TensorRef = TensorRef; + + CUTLASS_DEVICE + static OperandLayout Layout() + { + return OperandLayout::packed({OperandShape::kRow, OperandShape::kColumn}); + } + + /// Returns a TensorRef to the operand + CUTLASS_HOST_DEVICE + TensorRef ref() { return TensorRef{buffer.data(), Layout()}; } + }; + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape; + + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape; + + using SharedStorageA = + OperandSharedStorage; + using SharedStorageB = + OperandSharedStorage; + using TensorRefA = typename SharedStorageA::TensorRef; + using TensorRefB = typename SharedStorageB::TensorRef; + + struct SharedStorage { + /// Buffer for A operand + SharedStorageA operand_A; + + /// Buffer for B operand + SharedStorageB operand_B; + }; + +protected: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + CustomMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorageA& shared_storageA, + SharedStorageB& shared_storageB, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storageA.ref(), lane_idx), + warp_tile_iterator_B_(shared_storageB.ref(), lane_idx) + { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_multistage.h b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_multistage.h new file mode 100644 index 000000000000..50ba58b1d1dd --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_multistage.h @@ -0,0 +1,706 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/cache_operation.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "custom_mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
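+/// This multistage variant keeps several threadblock-wide tiles of A and B in
+/// flight in shared memory and overlaps cp.async global->shared copies with
+/// warp-level MMA work; cp.async requires Sm80 or newer.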
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Upper boundon the K dimension + int kMaxK = cutlass::platform::numeric_limits::max(), + /// Used for partial specialization + typename Enable = bool> +class CustomMmaMultistage : public CustomMmaBase { +public: + ///< Base class + using Base = CustomMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. 
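+    /// Detail derives, from the iterator thread maps, how many cp.async
+    /// instructions are required to stage one tile of A and of B, and how those
+    /// copies are split into groups so that one group can be issued alongside
+    /// each warp-level MMA iteration.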
+ struct Detail { + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + }; + + static bool const kSmemContainsEntireMat = kMaxK <= Shape::kK * Stages; + static constexpr int kNumStagesConcurrentLoad = kSmemContainsEntireMat ? Stages : Stages - 1; + +private: + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + +private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + bool prologue_done_; + + // Set to `True` to ensure the accumulator will be zero outside the GEMM + // footprint + bool zero_outside_bounds_; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + CustomMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storageA.ref(), thread_idx), + smem_iterator_B_(shared_storageB.ref(), thread_idx), + prologue_done_(false), + zero_outside_bounds_(false) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + CUTLASS_DEVICE + CustomMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& st, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int 
warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : CustomMmaMultistage(st.operand_A, st.operand_B, thread_idx, warp_idx, lane_idx) + { + } + + CUTLASS_DEVICE + bool set_prologue_done(bool value) { prologue_done_ = value; } + + CUTLASS_DEVICE + bool set_zero_outside_bounds(bool value) { zero_outside_bounds_ = value; } + + template + CUTLASS_DEVICE static void prologue(typename Base::SharedStorage& shared_storage, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) + { + prologue(shared_storage.operand_A, + shared_storage.operand_B, + iterator_A, + iterator_B, + thread_idx, + problem_size_k); + } + + template + CUTLASS_DEVICE static void prologue(typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) + { + SmemIteratorA smem_iterator_A(shared_storageA.ref(), thread_idx); + SmemIteratorB smem_iterator_B(shared_storageB.ref(), thread_idx); + int32_t iter = (problem_size_k + Base::Shape::kK - 1) / Base::Shape::kK; + _prologue(iterator_A, iterator_B, iter, smem_iterator_A, smem_iterator_B); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA& iterator_A, + IteratorB& iterator_B, + int group_start_A = 0, + int group_start_B = 0) + { + iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast(this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (zero_outside_bounds_ || + SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast(this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + if (zero_outside_bounds_ || + SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + template + CUTLASS_DEVICE static void 
_prologue(IteratorA& iterator_A, + IteratorB& iterator_B, + int32_t& gemm_k_iterations, + SmemIteratorA& smem_iterator_A_, + SmemIteratorB& smem_iterator_B_) + { + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < kNumStagesConcurrentLoad; ++stage, --gemm_k_iterations) { + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast(smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + if (kLoadA) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + } + + ++iterator_A; + } + + ++smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast(smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + if (kLoadB) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + } + + ++iterator_B; + } + + ++smem_iterator_B_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + smem_iterator_A_.add_tile_offset({0, 1}); + smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< initial value of accumulator + FragmentC const& src_accum) + { + // + // Prologue + // + + if (!prologue_done_) { + _prologue( + iterator_A, iterator_B, gemm_k_iterations, smem_iterator_A_, smem_iterator_B_); + } else if (!kSmemContainsEntireMat) { + _prologue( + iterator_A, iterator_B, gemm_k_iterations, smem_iterator_A_, smem_iterator_B_); + } else { + gemm_k_iterations -= kNumStagesConcurrentLoad; + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for + // some kernels so that all accumulator elements outside the GEMM footprint + // are zero. 
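+        // (When kClearLastStage is selected, the block below stores zeroed
+        // AccessType fragments through copies of the shared-memory iterators,
+        // covering one full stage of A and one full stage of B.)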
+ // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + /// Iterator to write threadblock-scoped tile of A operand to shared + /// memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast(last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared + /// memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast(last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA warp_loaded_frag_A[2]; + WarpLoadedFragmentB warp_loaded_frag_B[2]; + WarpTransformedFragmentA warp_transformed_frag_A[2]; + WarpTransformedFragmentB warp_transformed_frag_B[2]; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma.transform(warp_transformed_frag_A[0], + warp_transformed_frag_B[0], + warp_loaded_frag_A[0], + warp_loaded_frag_B[0]); + + // tf32x3 kernels use staging accumulation. warp_mma uses a temporary + // accumulator and this temporary accumulator is added to the final + // accumulator once in every mainloop iteration. + plus plus_accum; + + FragmentC tmp_accum; + + if (platform::is_same::value || + platform::is_same::value) { + tmp_accum.clear(); + } + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-kNumStagesConcurrentLoad);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + + // In case of a non-circular buffer ("kSmemContainsEntireMat") + // make sure we don't load out of bounds data. 
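+                // (The prefetch below is skipped only when the whole K extent is
+                // already resident in shared memory and this is the last k-group
+                // of the final mainloop iteration, i.e. there is nothing left to
+                // read.)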
+ if (!kSmemContainsEntireMat || gemm_k_iterations > (-kNumStagesConcurrentLoad) || + warp_mma_k < Base::kWarpGemmIterations - 1) { + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + } + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k > 0) + warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B[warp_mma_k % 2]); + + if (platform::is_same::value || + platform::is_same::value) { + warp_mma(tmp_accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + tmp_accum); + + if (warp_mma_k == 0) { + accum = plus_accum(accum, tmp_accum); + tmp_accum.clear(); + } + } else { + warp_mma(accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + accum); + } + + // Issue global->shared copies for the this stage + if (!kSmemContainsEntireMat && warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance( + iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + if (!kSmemContainsEntireMat) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); + } + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. 
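+                    // (cp_async_wait allows at most kStages-2 copy stages to remain
+                    // in flight, so the stage consumed next is resident in shared
+                    // memory before __syncthreads() publishes it to the block.)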
+ cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (!kSmemContainsEntireMat && smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations) + warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + } + } + + if (platform::is_same::value || + platform::is_same::value) { + accum = plus_accum(accum, tmp_accum); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_pipelined.h b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_pipelined.h new file mode 100644 index 000000000000..07b26ca31299 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_pipelined.h @@ -0,0 +1,388 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "custom_mma_base.h" +#include "cutlass/gemm/gemm.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to A operand + typename TransformA_ = NumericArrayConverter, + /// + /// Transformation applied to B operand + typename TransformB_ = NumericArrayConverter, + /// Used for partial specialization + typename Enable = bool> +class CustomMmaPipelined : public CustomMmaBase { +public: + ///< Base class + using Base = CustomMmaBase; + + using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + using TransformA = TransformA_; + using TransformB = TransformB_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from 
global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages == 2), "MmaPipelined requires kStages set to value 2"); + + static bool const kSmemContainsEntireMat = false; + +private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + +protected: + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + CustomMmaPipelined(typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ) + : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storageA.ref(), thread_idx), + smem_iterator_B_(shared_storageB.ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + CUTLASS_DEVICE + CustomMmaPipelined( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& st, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : CustomMmaPipelined(st.operand_A, st.operand_B, thread_idx, warp_idx, lane_idx) + { + } + + CUTLASS_DEVICE + bool set_prologue_done(bool value) + { + // NOT IMPLEMENTED FOR PIPELINED + } + + CUTLASS_DEVICE + bool set_zero_outside_bounds(bool value) + { + // NOT NEEDED FOR PIPELINED + // shared memory will always be zero-filled + } + + template + CUTLASS_DEVICE static void prologue(typename Base::SharedStorage& shared_storage, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) + { + prologue(shared_storage.operand_A, + shared_storage.operand_B, + iterator_A, + iterator_B, + thread_idx, + problem_size_k); + } + + 
template + CUTLASS_DEVICE static void prologue(typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) + { + // NOT IMPLEMENTED FOR PIPELINED + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + FragmentC const& src_accum, ///< source accumulator tile + TransformA transform_A = TransformA(), ///< transformation applied to A fragment + TransformB transform_B = TransformB()) + { ///< transformation applied to B fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A; + FragmentB tb_frag_B; + + tb_frag_A.clear(); + tb_frag_B.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + this->smem_iterator_A_.store(transform_A(tb_frag_A)); + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations <= 1); + iterator_B.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* + // issuing shared memory loads (which have the tightest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
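+                // (Double buffering: on the last k-group the staged global-memory
+                // fragments are written to the other half of the shared-memory
+                // buffer, and the write/read halves are swapped via
+                // smem_write_stage_idx.)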
+ + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + // Write fragments to shared memory + this->smem_iterator_A_.store(transform_A(tb_frag_A)); + + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k == 0) { + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations <= 2); + iterator_B.clear_mask(gemm_k_iterations <= 2); + } + + warp_mma(accum, warp_frag_A[warp_mma_k % 2], warp_frag_B[warp_mma_k % 2], accum); + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/deepspeed4science/evoformer_attn/gemm/find_default_mma.h b/csrc/deepspeed4science/evoformer_attn/gemm/find_default_mma.h new file mode 100644 index 000000000000..163dcbf85259 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/gemm/find_default_mma.h @@ -0,0 +1,191 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/*! \file + \brief Cutlass provides helper template functions to figure out the right + datastructures to instantiate to run a GEMM with various parameters (see + `cutlass/gemm/threadblock/default_mma.h`). However, due to template + instantiation priority rules, it will only create an MmaMultiStage with + kStages=3 (otherwise creates an MmePipelined - which is not compatible with + FastF32). kStages=3 uses too much shared memory and we want to use kStages=2, + so we just copy-pasted some code from `default_mma.h` and + `default_mma_core.h` files and wrapped this template to allow our usecase. + + This is really only for the FastF32 case - aka using TensorCores with fp32. +*/ + +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Layout type for C and D matrix operand + typename LayoutC, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + typename Enable_ = void> +struct FindDefaultMma { + static constexpr bool AccumulatorsInRowMajor = false; + static constexpr SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone; + using DefaultMma = cutlass::gemm::threadblock::DefaultMma; +}; + +/// Specialization for sm80 / FastF32 / multistage with kStages=2 +template +struct FindDefaultMma 1)>::type> { + using LayoutC = layout::RowMajor; + using OperatorClass = arch::OpClassTensorOp; + using ArchTag = arch::Sm80; + + using DefaultMma_ = cutlass::gemm::threadblock::DefaultMma; + struct DefaultMma : DefaultMma_ { + using MmaCore_ = typename DefaultMma_::MmaCore; + // Define the threadblock-scoped multistage matrix multiply + 
using ThreadblockMma = + cutlass::gemm::threadblock::MmaMultistage; + }; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/csrc/deepspeed4science/evoformer_attn/gemm/mma_accum_lambda_iterator.h b/csrc/deepspeed4science/evoformer_attn/gemm/mma_accum_lambda_iterator.h new file mode 100644 index 000000000000..5e2f0cf681bf --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/gemm/mma_accum_lambda_iterator.h @@ -0,0 +1,347 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include "cutlass/functional.h" +#include "cutlass/gemm/warp/mma_simt_tile_iterator.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" +#include "cutlass/matrix_shape.h" + +/* +TensorCores have different accumulator layouts. +This file provides a class to easily map the accumulator +i-th element with the corresponding matrix row/col. 
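+
+For illustration only (the callback names below are hypothetical), iterateRows
+is driven with three callbacks - one at the start of each accumulator row, one
+per element (receiving row, column and fragment index), and one at the end of
+the row - roughly:
+
+    Iterator::iterateRows(
+        lane_offset,
+        [&](int row) { row_max = -INFINITY; },           // begin a row
+        [&](int row, int col, int idx) {                 // visit one element
+            row_max = max(row_max, accum[idx]);
+        },
+        [&](int row) { finish_row(row, row_max); });     // end of the row
+
+reduceSameRow can then combine the per-thread partial (e.g. row_max) across the
+threads of the warp that share that row.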
+*/ + +template +struct AccumLambdaIteratorSm80 { + static_assert(cutlass::platform::is_same::value, + "only RowMajor is supported"); + + using Policy = typename T::Policy; + using InstructionShape = typename T::InstructionShape; + using OpDelta = typename T::OpDelta; + using Shape = typename T::Shape; + static int const kElementsPerAccess = InstructionShape::kN / 4; + static int const kRowsPerTile = 8; + static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; + + static cutlass::MatrixCoord CUTLASS_DEVICE + get_lane_offset(int8_t lane_id, int8_t warp_id, typename T::TensorCoord const& tile_offset) + { + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + return cutlass::MatrixCoord( + quad + tile_offset.row() * Shape::kRow, + lane_in_quad * kElementsPerAccess + tile_offset.column() * Shape::kColumn); + } + + template + CUTLASS_DEVICE static void iterateRows(cutlass::MatrixCoord& lane_offset, + FA beginRow, + FB op, + FC endRow) + { + // See cutlass/gemm/warp/mma_tensor_op_tile_iterator.h + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < kAccumulatorRows; ++row) { + int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + row * kRowsPerTile + + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { + int mma_accum_start = kAccumulatorRows * kElementsPerAccess * + (mma_n * Policy::MmaIterations::kRow + mma_m); + CUTLASS_PRAGMA_UNROLL + for (int col = 0; col < kElementsPerAccess; ++col) { + int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col + + lane_offset.column(); + int idx = mma_accum_start + row * kElementsPerAccess + col; + op(accum_m, accum_n, idx); + } + } + + endRow(accum_m); + } + } + } + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) + { + // In each warp, 4 threads will work on the same row + // - the ones with the same `quad` + auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1); + myValue = fn(myValue, otherV); + otherV = __shfl_xor_sync(0xffffffff, myValue, 2); + myValue = fn(myValue, otherV); + int lane_in_quad = (lane_id & 3); + return lane_in_quad == 0; + } +}; + +template +struct AccumLambdaIteratorSm70 { + static_assert(cutlass::platform::is_same::value, + "only RowMajor is supported"); + + using Policy = typename T::Policy; + using InstructionShape = typename T::InstructionShape; + using OpDelta = typename T::OpDelta; + using Shape = typename T::Shape; + using Element = accum_t; + + static int const kElementsPerPartial = 4; + using EleShapePerPatial = + typename cutlass::platform::conditional::value, + cutlass::MatrixShape<2, 2>, + cutlass::MatrixShape<1, 4>>::type; + static int const kElementsPerMma = 8; + static int const kAccumulatorPatials = 2; + using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>; + + static cutlass::MatrixCoord CUTLASS_DEVICE + get_lane_offset(int8_t lane_id, int8_t warp_id, typename T::TensorCoord const& tile_offset) + { + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + int accum_m, accum_n; + + if (cutlass::platform::is_same::value) { + // (quad[2],quad[0])+lane_in_quad[0] + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1); + // (quad[1])+lane_in_quad[1] + accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials + + (lane_in_quad & 2); + } else { + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + lane_in_quad; // 
(quad[2],quad[0]) + accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials; + } + return cutlass::MatrixCoord(accum_m + tile_offset.row() * Shape::kRow, + accum_n + tile_offset.column() * Shape::kColumn); + } + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) + { + static_assert(cutlass::platform::is_same::value, + "update to support non-float accum"); + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16 + // T0 & T2 share same line within a quad + auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 1); + myValue = fn(myValue, otherV); + // quad 0 and quad 2 are on the same lines + otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 3); + myValue = fn(myValue, otherV); + return (lane_id & ((1 << 1) | (1 << 3))) == 0; + } + + template + CUTLASS_DEVICE static void iterateRows(cutlass::MatrixCoord& lane_offset, + FA beginRow, + FB op, + FC endRow) + { + CUTLASS_PRAGMA_UNROLL + for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < EleShapePerPatial::kRow; ++m) { + int accum_m = tile_m * Policy::InterleavedTile::kRow + + mma_m * QuadShapePerPatialMma::kRow + m * 2 + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < kAccumulatorPatials; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { + int mma_accum_start = + (((tile_n * Policy::TileIterations::kRow + tile_m) * + Policy::MmaIterations::kColumn + + mma_n) * + Policy::MmaIterations::kRow + + mma_m) * + kElementsPerMma; + int accum_n = tile_n * Policy::InterleavedTile::kColumn + + mma_n * QuadShapePerPatialMma::kColumn + + p * Policy::InterleavedTile::kColumn / 2 + n + + lane_offset.column(); + int idx = mma_accum_start + p * kElementsPerPartial + + m * EleShapePerPatial::kColumn + n; + op(accum_m, accum_n, idx); + } + } + } + } + endRow(accum_m); + } + } + } + } +}; + +template +struct AccumLambdaIteratorSimt { + using Policy = typename T::Policy; + using Iterations = typename T::Iterations; + using Element = typename T::Element; + using Delta = typename T::Delta; + using Shape = typename T::Shape; + static_assert(cutlass::platform::is_same::value, + "only RowMajor is supported"); + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) + { + CUTLASS_PRAGMA_UNROLL + for (int bit = 1; bit < Policy::WarpShape::kColumn; bit *= 2) { + auto otherV = __shfl_xor_sync(0xffffffff, myValue, bit); + myValue = fn(myValue, otherV); + } + return (lane_id & (Policy::WarpShape::kColumn - 1)) == 0; + } + + template + CUTLASS_DEVICE static void iterateRows(cutlass::MatrixCoord& lane_offset, + FA beginRow, + FB op, + FC endRow) + { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { + int accum_m = mma_m * Delta::kRow + m + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { + int accum_n = mma_n * Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN + + lane_offset.column(); + CUTLASS_PRAGMA_UNROLL + for (int n 
= 0; n < Policy::LaneMmaShape::kN; ++n) { + int idx = n + Policy::LaneMmaShape::kN * + (mma_n + Iterations::kColumn * + (m + mma_m * Policy::LaneMmaShape::kM)); + op(accum_m, accum_n + n, idx); + } + } + endRow(accum_m); + } + } + } + + static cutlass::MatrixCoord CUTLASS_DEVICE + get_lane_offset(int8_t lane_id, int8_t warp_id, typename T::TensorCoord const& tile_offset) + { + static_assert(cutlass::platform::is_same>::value, + ""); + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + cutlass::MatrixCoord lane_offset = + lane_layout.inverse(lane_id) * + cutlass::MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN); + return lane_offset + tile_offset * cutlass::MatrixCoord(Shape::kRow, Shape::kColumn); + } +}; + +template +struct DefaultMmaAccumLambdaIterator; + +// Simt +template +struct DefaultMmaAccumLambdaIterator< + cutlass::gemm::warp::MmaSimtTileIterator, + accum_t, + kWarpSize> { + using WarpIterator = + typename cutlass::gemm::warp::MmaSimtTileIterator; + using Iterator = AccumLambdaIteratorSimt; +}; + +// TensorOp - Volta +template +struct DefaultMmaAccumLambdaIterator< + cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator>, + accum_t, + kWarpSize> { + using WarpIterator = typename cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + cutlass::MatrixShape<1, 1>>; + using Iterator = AccumLambdaIteratorSm70; +}; + +// TensorOp - Sm75+ +template +struct DefaultMmaAccumLambdaIterator< + cutlass::gemm::warp:: + MmaTensorOpAccumulatorTileIterator, + accum_t, + kWarpSize> { + using WarpIterator = typename cutlass::gemm::warp:: + MmaTensorOpAccumulatorTileIterator; + using Iterator = AccumLambdaIteratorSm80; +}; diff --git a/csrc/deepspeed4science/evoformer_attn/gemm/mma_from_smem.h b/csrc/deepspeed4science/evoformer_attn/gemm/mma_from_smem.h new file mode 100644 index 000000000000..40d3265c7a63 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/gemm/mma_from_smem.h @@ -0,0 +1,1939 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/functional.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/platform/platform.h" +#include "cutlass/transform/threadblock/vector_iterator.h" + +#include "../epilogue/epilogue_thread_apply_logsumexp.h" +#include "../gemm/mma_accum_lambda_iterator.h" +#include "../gemm_kernel_utils.h" +#include "../iterators/make_residual_last.h" +#include "../iterators/transpose_warp_iterator.h" +#include "../iterators/warp_iterator_from_smem.h" +#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/gemm/threadblock/mma_multistage.h" +#include "cutlass/gemm/threadblock/mma_pipelined.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +/// Shared storage object needed by accumulator +/// From 13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h +template +class AccumulatorSharedStorage { +public: + // + // Type definitions + // + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using Padding = Padding_; + + /// Tensor reference to the accumulator + using TensorRefAccum = cutlass::TensorRef; + + /// Shape of the accumulator matrix in shared memory + using ShapeAccum = + cutlass::MatrixShape; + +public: + // + // Data members + // + + /// Buffer for accumulator + cutlass::AlignedBuffer accum; + +public: + // + // Methods + // + + /// Returns a layout object for the Accum matrix + CUTLASS_DEVICE + static Layout LayoutAccum() { return Layout::packed({ShapeAccum::kRow, ShapeAccum::kColumn}); } + + /// Returns a TensorRef to the Accumulator + CUTLASS_HOST_DEVICE + TensorRefAccum accum_ref() { return TensorRefAccum{accum.data(), LayoutAccum()}; } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Taken from +// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to 
compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + // Maximum value for K + int kMaxK, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class MmaBaseFromSharedMemory { +public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = + GemmShape; + using WarpCount1 = WarpCount; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); + static int const kWarpGemmIterations1 = kWarpGemmIterations; + + /// Number of stages + static int const kStages = Stages; + + /// If this is true, we fill the entire shmem buffer at start + /// and don't need to iterate through it in a circular fashion + static bool const kSmemContainsEntireB = kMaxK <= Shape::kK * kStages; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = TensorRef; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape; + + public: + // + // Data members + // + + /// Buffer for B operand + AlignedBuffer operand_B; + + public: + // + // Methods + // + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() + { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { return TensorRefB{operand_B.data(), LayoutB()}; } + }; + +protected: + // + // Data members + // + + // /// Iterator to load a warp-scoped tile of A operand from shared memory + // typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + MmaBaseFromSharedMemory( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage& shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) + { + } +}; + +namespace { + +// has necessary trait compliance with WarpIteratorFromSmem but doesn't do +// anything, can be default initialized, and uses fragment that takes up +// (almost) no space. this warp iterator is selected at compile time when +// elementwise on-the-fly scaling for operand A is disabled, in which case +// operations related to loading scale factors for operand A get wiped out by +// the compiler. 
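+// The selection is made with cutlass::platform::conditional on the
+// ScaleOperandA flag (see WarpIteratorAScale below): when scaling is enabled
+// the real warp iterator over the A_scale tile is used, otherwise this no-op
+// stand-in is substituted so the pipelined and multistage mainloops can stay
+// identical in both cases.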
+template +class NoOpWarpIteratorScale { +public: + // in pipelined+multistage MMA implementations we keep an array of fragments. + // if we aren't using scaling we don't want to waste registers on fragments + // of scale elements, so ideally this would be sized 0. + // Since arrays of zero-sized objects are not allowed, using size as 1. + // The compiler will most likely wipe it out anyways. + using Fragment = cutlass::Array; + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale() {} + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale(TensorRef const&, int) {} + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale& add_tile_offset(typename TensorRef::TensorCoord const&) { return *this; } + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale& operator++() { return *this; } + + CUTLASS_DEVICE + void load(Fragment&) const {} +}; + +// if scaling is enabled, performs fragment elementwise multiplication between +// fragment and its scaling factor. +template +class FragmentElementwiseScaler; + +// specialization for scaling being enabled. +template +class FragmentElementwiseScaler { +public: + // cast scale_frag to correct type then apply elementwise to fragment + CUTLASS_DEVICE + static Fragment apply(Fragment frag, FragmentScale const& scale_frag) + { + Fragment converted_scale_frag = + cutlass::NumericArrayConverter()(scale_frag); + return cutlass::multiplies()(frag, converted_scale_frag); + } +}; + +// specialization for scaling being disabled. doesn't do anything and should +// just get wiped out by the compiler. +template +class FragmentElementwiseScaler { +public: + CUTLASS_DEVICE + static Fragment apply(Fragment frag, FragmentScale const&) { return frag; } +}; +} // namespace + +//////////////////////////////////////////////////////////////////////////////// +// Taken from +// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
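+/// Variant of MmaPipelined in which operand A is read from an accumulator
+/// tile already held in shared memory (through WarpIteratorA), while operand B
+/// is streamed from global memory through a double-buffered shared-memory
+/// tile; operand A can optionally be scaled element-wise on the fly.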
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + // BEGIN smem + /// Iterates over the intermediate accumulator tile in shared memory + typename WarpIteratorA, + /// whether or not to perform elementwise multiplication of A + // by another matrix (A_scale) that is also kept in shared memory prior + // to matmul A @ B + bool ScaleOperandA_, + // Accumulator type + typename AccumulatorSharedStorage, + // END smem + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to B operand + typename TransformB_ = NumericArrayConverter, + /// Used for partial specialization + typename Enable = bool> +class MmaPipelinedFromSharedMemory + : public MmaBaseFromSharedMemory { +public: + ///< Base class + using Base = MmaBaseFromSharedMemory; + + using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + static constexpr bool ScaleOperandA = ScaleOperandA_; + + ///< loads fragments of A_scale from shared memory if operand A scaling is + ///< enabled. otherwise no-op. + using WarpIteratorAScale = typename cutlass::platform::conditional< + ScaleOperandA, + WarpIteratorA, + NoOpWarpIteratorScale>::type; + + using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using SmemIteratorB = SmemIteratorB_; + + using TransformB = TransformB_; + + // + // Dependent types + // + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages == 2), "MmaPipelined requires kStages set to value 2"); + +private: + using WarpFragmentA = typename Operator::FragmentA; + + /// fragment type of OperandA elementwise scaling matrix. (almost) empty + /// if operand A scaling is disabled. + using WarpFragmentAScale = typename WarpIteratorAScale::Fragment; + + using WarpFragmentB = typename Operator::FragmentB; + + /// applies scaling factor to operand A fragment if operand A scaling is + /// enabled. otherwise no-op. 
+ using FragmentAScaler = + FragmentElementwiseScaler; + +protected: + // /// Iterator to write threadblock-scoped tile of A operand to shared memory + // SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to load a warp-scoped tile of A operand from intermediate + /// accumulator tile + WarpIteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of A_scale from intermediate + /// accumulator tile (only used if ScaleOperandA_ is true) + WarpIteratorAScale warp_tile_iterator_A_scale_; + +public: + /// constructor for MMA with operand A scaling enabled. + CUTLASS_DEVICE + MmaPipelinedFromSharedMemory( + // shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + // warp iterator over A tile held in shared memory + WarpIteratorA warp_iter_a, + // warp iterator over A_scale tile held in shared memory + WarpIteratorAScale warp_iter_a_scale, + int thread_idx, + int warp_idx, + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A_(warp_iter_a), + warp_tile_iterator_A_scale_(warp_iter_a_scale), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_A_scale_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + /// Construct from tensor references + CUTLASS_DEVICE + MmaPipelinedFromSharedMemory( + typename Base::SharedStorage& shared_storage, ///< Shared storage needed for internal use + ///< by threadblock-scoped GEMM + AccumulatorSharedStorage& accumulator_shared_storage, + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx, ///< ID of each thread within a warp + int problem_size_0_n) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A_(accumulator_shared_storage.accum_ref(), lane_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + 
this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + // For API compatibility with MmaMultistageFromSharedMemory + // but not supported as it worsens perf: older gpus < sm80 don't + // support async transfers and have to waste registers + CUTLASS_DEVICE + void set_prologue_done(bool value) {} + CUTLASS_DEVICE + static void prologue(typename Base::SharedStorage& shared_storage, + IteratorB iterator_B1, + int thread_idx, + int problem_size_0_n) + { + } + + CUTLASS_DEVICE + static void drain_cp_asyncs() {} + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + // IteratorA iterator_A, ///< iterator over A + // operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + FragmentC const& src_accum, ///< source accumulator tile + // TransformA transform_A = TransformA(), ///< transformation + // applied to A fragment + TransformB transform_B = TransformB()) + { ///< transformation applied to B fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentB tb_frag_B; + + tb_frag_B.clear(); + + // The last kblock is loaded in the prolog + iterator_B.set_residual_tile(gemm_k_iterations == 1); + iterator_B.load(tb_frag_B); + + ++iterator_B; + + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + ++this->smem_iterator_B_; + + __syncthreads(); + + // remember that WarpFragmentAScale and WarpIteratorAScale are empty/no-op + // if scaling is disabled. + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentAScale warp_frag_A_scale[2]; + WarpFragmentB warp_frag_B[2]; + warp_frag_A[0].clear(); + warp_frag_A_scale[0].clear(); + warp_frag_B[0].clear(); + + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_A_scale_.load(warp_frag_A_scale[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_A_scale_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_B.set_residual_tile(gemm_k_iterations == 2); + iterator_B.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* + // issuing shared memory loads (which have the tightest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
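+                // On the last k-group the freshly fetched threadblock tile of
+                // B is written to shared memory and the double buffer is
+                // swapped; `hasNext` below suppresses the extra warp-level
+                // loads once the final gemm_k iteration has been reached.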
+ bool hasNext = true; + + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + // Write fragments to shared memory + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory SMEM: Don't reset iterator A, as + // we are continuing our iteration at this point + if (smem_write_stage_idx == 1) { + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); + } + + smem_write_stage_idx ^= 1; + hasNext = gemm_k_iterations > 1; + } + + // Only read the next if we need to + if (hasNext) { + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_A_scale_.load(warp_frag_A_scale[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_A_scale_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k == 0) { + iterator_B.load(tb_frag_B); + + ++iterator_B; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_B.set_residual_tile(gemm_k_iterations == 3); + iterator_B.clear_mask(gemm_k_iterations <= 2); + } + } + + warp_mma(accum, + FragmentAScaler::apply(warp_frag_A[warp_mma_k % 2], + warp_frag_A_scale[warp_mma_k % 2]), + warp_frag_B[warp_mma_k % 2], + accum); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Taken from +// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
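+/// cp.async-based (Sm80+) counterpart of the pipelined variant above: operand
+/// A is again read from the accumulator tile held in shared memory, while
+/// operand B is staged asynchronously from global memory across kStages
+/// shared-memory buffers.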
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Iterates over the intermediate accumulator tile in shared memory + typename WarpIteratorA1_, + /// whether or not to perform elementwise multiplication of A + // by another matrix (A_scale) that is also kept in shared memory prior + // to matmul A @ B + bool ScaleOperandA_, + // Accumulator type + typename AccumulatorSharedStorage, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB1_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB1_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB1, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy1_, + /// Number of stages, + int Stages_, + int kMaxK_, + /// Used for partial specialization + typename Enable = bool> +class MmaMultistageFromSharedMemory + : public MmaBaseFromSharedMemory { +public: + ///< Base class + using Base = MmaBaseFromSharedMemory; + + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape1 = Shape1_; + ///< Iterates over tiles of B operand in global memory + using IteratorB1 = IteratorB1_; + using IteratorB = IteratorB1; + ///< Policy describing tuning details + using Policy1 = Policy1_; + + using SmemIteratorB1 = SmemIteratorB1_; + using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate + ///< accumulator tile in shared memory + static constexpr bool ScaleOperandA = ScaleOperandA_; + + ///< warp level iterator over A_scale matrix tile kept in shared memory. + ///< if elementwise A scaling is disabled then everything this does is no-op. + using WarpIteratorAScale = typename cutlass::platform::conditional< + ScaleOperandA, + WarpIteratorA1, + NoOpWarpIteratorScale>::type; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1; + static constexpr bool kSmemContainsEntireB = Base::kSmemContainsEntireB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC1 = typename Policy1::Operator::FragmentC; + using FragmentC = FragmentC1; + + /// Warp-level Mma + using Operator1 = typename Policy1::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on B operand + static ComplexTransform const kTransformB1 = Operator1::kTransformB; + + /// Internal structure exposed for introspection. + struct Detail { + static_assert(Base::kWarpGemmIterations1 > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand B + static int const TBLoadIterationsB1 = IteratorB1::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB1 = + (TBLoadIterationsB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1; + }; + + static constexpr int kNumStagesConcurrentLoad = kSmemContainsEntireB ? 
Base::kStages + : Base::kStages - 1; + +private: + using WarpLoadedFragmentA1 = typename Operator1::FragmentA; + /// fragment of OperandA scale matrix. if operand A scaling is disabled this + /// is (almost) empty. + using WarpLoadedFragmentA1Scale = typename WarpIteratorAScale::Fragment; + using WarpLoadedFragmentB1 = typename Operator1::FragmentB; + using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA; + using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; + + /// applies elementwise scaling to fragment of A. if operand A scaling is + /// disabled this is a no-op. + using FragmentAScaler = + FragmentElementwiseScaler; + +private: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A1 operand from intermediate + /// accumulator tile + WarpIteratorA1 warp_tile_iterator_A1_; + + /// Iterator to load a warp-scoped tile of A1_scale operand from shared memory + /// if operand A scaling is disabled everything this does is a no-op. + WarpIteratorAScale warp_tile_iterator_A1_scale_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB1 smem_iterator_B1_; + + bool prologue_done_; + +public: + /// constructor for MMA with operand A scaling enabled. + CUTLASS_DEVICE + MmaMultistageFromSharedMemory( + // shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + // warp level iterator over operand A tile kept in shared memory + WarpIteratorA1 warp_tile_iterator_A1, + // warp level iterator over operand A elementwise scale tile kept in + // shared memory. + WarpIteratorAScale warp_tile_iterator_A1_scale, + int thread_idx, + int warp_idx, + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A1_(warp_tile_iterator_A1), + warp_tile_iterator_A1_scale_(warp_tile_iterator_A1_scale), + smem_iterator_B1_(shared_storage.operand_B_ref(), thread_idx), + prologue_done_(false) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + int warp_idx_mn_1 = warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; + int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; + + // Add per-warp offsets in units of warp-level tiles + warp_tile_iterator_A1_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + warp_tile_iterator_A1_scale_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1}); + } + + /// Construct from tensor references + CUTLASS_DEVICE + MmaMultistageFromSharedMemory( + typename Base::SharedStorage& shared_storage, ///< Shared storage needed for internal use + ///< by threadblock-scoped GEMM + AccumulatorSharedStorage& accumulator_shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx, + ///< GEMM0 N is used for accumulator extent + int problem_size_0_n) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + 
warp_tile_iterator_A1_(accumulator_shared_storage.accum_ref(), lane_idx), + smem_iterator_B1_(shared_storage.operand_B_ref(), thread_idx), + prologue_done_(false) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn_1 = warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); + + int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; + int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; + + // Add per-warp offsets in units of warp-level tiles + warp_tile_iterator_A1_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1}); + } + + CUTLASS_DEVICE + void set_prologue_done(bool value) { prologue_done_ = value; } + + CUTLASS_DEVICE + static void prologue(typename Base::SharedStorage& shared_storage, + IteratorB iterator_B1, + int thread_idx, + int problem_size_0_n) + { + SmemIteratorB1 smem_iterator_B1(shared_storage.operand_B_ref(), thread_idx); + _prologue(iterator_B1, + (problem_size_0_n + Base::Shape::kK - 1) / Base::Shape::kK, + smem_iterator_B1); + } + + CUTLASS_DEVICE + static void drain_cp_asyncs() + { + // commit and drain all pending and predicated cp.async pnz from the GEMM + // mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance_1(IteratorB1& iterator_B1, int group_start_B1 = 0) + { + iterator_B1.set_iteration_index(group_start_B1 * IteratorB1::kAccessesPerVector); + this->smem_iterator_B1_.set_iteration_index(group_start_B1); + + // Load for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) { + if (group_start_B1 + j < Detail::TBLoadIterationsB1) { + typename IteratorB1::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B1.get(); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B1.valid()); + + ++iterator_B1; + } + ++this->smem_iterator_B1_; + } + } + } + + CUTLASS_DEVICE + static void _prologue(IteratorB& iterator_B1, + int32_t gemm_k_iterations_1, + SmemIteratorB1& smem_iterator_B1_) + { + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < kNumStagesConcurrentLoad; ++stage, --gemm_k_iterations_1) { + iterator_B1.set_residual_tile(gemm_k_iterations_1 == 1); + iterator_B1.clear_mask(gemm_k_iterations_1 == 0); + + iterator_B1.set_iteration_index(0); + smem_iterator_B1_.set_iteration_index(0); + + // Load for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) { + typename IteratorB1::AccessType* dst_ptr = + reinterpret_cast(smem_iterator_B1_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + 
cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B1.get(), iterator_B1.valid()); + + ++iterator_B1; + } + + ++smem_iterator_B1_; + } + + // Move to the next stage + iterator_B1.add_tile_offset({1, 0}); + + smem_iterator_B1_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + iterator_B1.set_residual_tile(gemm_k_iterations_1 == 1); + iterator_B1.clear_mask(gemm_k_iterations_1 == 0); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations_1_, + ///< destination accumulator tile + FragmentC1& accum, + ///< iterator over B1 operand in global memory + IteratorB1 iterator_B1, + ///< initial value of accumulator + FragmentC1 const& src_accum) + { + // 2nd Gemm + + // + // Prologue + // + // Perform accumulation in the 'd' output operand + accum = src_accum; + + if (!prologue_done_) { + _prologue(iterator_B1, gemm_k_iterations_1_, smem_iterator_B1_); + } else if (!kSmemContainsEntireB) { + // Restore the iterators increments + + int gemm_k_iterations_1 = gemm_k_iterations_1_; + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < kNumStagesConcurrentLoad; ++stage, --gemm_k_iterations_1) { + iterator_B1.set_iteration_index(0); + this->smem_iterator_B1_.set_iteration_index(0); + + // Load for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { ++iterator_B1; } + ++this->smem_iterator_B1_; + } + iterator_B1.add_tile_offset({1, 0}); + this->smem_iterator_B1_.add_tile_offset({1, 0}); + } + iterator_B1.set_residual_tile(gemm_k_iterations_1 <= 1); + iterator_B1.clear_mask(gemm_k_iterations_1 <= 0); + } + + // DEPBAR+SYNC + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // remember that WarpFragmentAScale and WarpIteratorAScale are no-op/empty + // if scaling is disabled. + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; + WarpLoadedFragmentA1Scale warp_loaded_frag_A1_scale[2]; + WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; + WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; + WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; + + Operator1 warp_mma1; + + warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]); + ++warp_tile_iterator_A1_; + + warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[0]); + ++warp_tile_iterator_A1_scale_; + + this->warp_tile_iterator_B_.set_kgroup_index(0); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B1[0]); + ++this->warp_tile_iterator_B_; + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma1.transform( + warp_transformed_frag_A1[0], + warp_transformed_frag_B1[0], + FragmentAScaler::apply(warp_loaded_frag_A1[0], warp_loaded_frag_A1_scale[0]), + warp_loaded_frag_B1[0]); + + // tf32x3 kernels use staging accumulation. warp_mma uses a temporary + // accumulator and this temporary accumulator is added to the final + // accumulator once in every mainloop iteration. 
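+        // The temporary accumulator is only used for the tf32x3-style math
+        // operators checked below; for all other operators the warp-level MMA
+        // accumulates directly into `accum` and `tmp_accum` stays unused.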
+ plus plus_accum; + + FragmentC1 tmp_accum; + + if (platform::is_same::value || + platform::is_same::value) { + tmp_accum.clear(); + } + + // + // Mainloop + // + + CUTLASS_PRAGMA_UNROLL + for (int gemm_k_iterations_1 = gemm_k_iterations_1_ - (Base::kStages - 1); + gemm_k_iterations_1 > (-Base::kStages + 1); + gemm_k_iterations_1--) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) { + // Load warp-level tile from accumulator fragment (A) + // or shared memory (operand B) + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations1); + // skip warp tile loading for the last kgroup (we are out of the buf) + if (gemm_k_iterations_1 > (-Base::kStages + 2) || + warp_mma_k < Base::kWarpGemmIterations1 - 1) { + warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2]); + warp_tile_iterator_A1_scale_.load( + warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + } + ++warp_tile_iterator_A1_; + ++warp_tile_iterator_A1_scale_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k > 0) + warp_mma1.transform( + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + FragmentAScaler::apply(warp_loaded_frag_A1[warp_mma_k % 2], + warp_loaded_frag_A1_scale[warp_mma_k % 2]), + warp_loaded_frag_B1[warp_mma_k % 2]); + + if (platform::is_same::value || + platform::is_same::value) { + warp_mma1(tmp_accum, + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + tmp_accum); + + if (warp_mma_k == 0) { + accum = plus_accum(accum, tmp_accum); + tmp_accum.clear(); + } + } else { + warp_mma1(accum, + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + accum); + } + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations1 - 1) { + int group_start_iteration_B1; + + group_start_iteration_B1 = warp_mma_k * Detail::kAccessesPerGroupB1; + + if (!kSmemContainsEntireB) { + copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); + } + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations1) { + int group_start_iteration_B1; + group_start_iteration_B1 = (warp_mma_k + 1) * Detail::kAccessesPerGroupB1; + + if (!kSmemContainsEntireB) { + copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); + } + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. 
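+                    // (cp_async_wait<N> blocks until at most N of the
+                    // committed cp.async groups are still in flight, so the
+                    // oldest staged tile is guaranteed to be resident in
+                    // shared memory.)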
+ arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_B1.add_tile_offset({1, 0}); + + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (!kSmemContainsEntireB) { + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy1::kPartitionsK * + Base::kWarpGemmIterations1, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + } + + iterator_B1.set_residual_tile(gemm_k_iterations_1 == 2); + iterator_B1.clear_mask(gemm_k_iterations_1 == 1); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations1) + warp_mma1.transform( + warp_transformed_frag_A1[(warp_mma_k + 1) % 2], + warp_transformed_frag_B1[(warp_mma_k + 1) % 2], + FragmentAScaler::apply(warp_loaded_frag_A1[(warp_mma_k + 1) % 2], + warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]), + warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + } + } + + if (platform::is_same::value || + platform::is_same::value) { + accum = plus_accum(accum, tmp_accum); + } + } +}; + +template +struct DefaultWarpIteratorAFromSharedMemory {}; + +// TensorOp - Ampere half +template +struct DefaultWarpIteratorAFromSharedMemory< + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + RegularWarpIterator, + Policy, + typename platform::enable_if<(sizeof_bits::value == 16 && + Policy::Operator::Policy::OpDelta::kRow == 1)>::type> { + static constexpr auto kWarpSize = 32; + using OpDelta = typename Policy::Operator::Policy::OpDelta; + using WarpShape = cutlass::MatrixShape<32, 32>; + + using WarpIterator = + cutlass::gemm::warp::WarpIteratorFromSmem; +}; + +// TensorOp - Ampere f32 +template +struct DefaultWarpIteratorAFromSharedMemory< + WarpShape, + cutlass::gemm::GemmShape<16, 8, 8>, + RegularWarpIterator, + Policy, + typename platform::enable_if<(sizeof_bits::value != 16 || + Policy::Operator::Policy::OpDelta::kRow != 1)>::type> { + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + static constexpr auto kWarpSize = 32; + using OpDelta = typename Policy::Operator::Policy::OpDelta; + + using WarpIterator = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< + cutlass::MatrixShape, + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element, + cutlass::layout::RowMajor, + cutlass::MatrixShape, + OpDelta::kRow, + kWarpSize>; +}; + +// TensorOp - Volta +template +struct DefaultWarpIteratorAFromSharedMemory, + RegularWarpIterator, + Policy> { + using InstructionShape = cutlass::gemm::GemmShape<16, 16, 4>; + static constexpr auto kWarpSize = 32; + using OpDelta = typename Policy::Operator::Policy::OpDelta; + + using WarpIterator = cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator< + cutlass::MatrixShape<32, 32>, // MatrixShape, + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element, + cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, + cutlass::MatrixShape<16, 4>, + OpDelta::kRow, + kWarpSize>; +}; + +// Simt +template +struct DefaultWarpIteratorAFromSharedMemory, + RegularWarpIterator, + Policy> { + using InstructionShape = 
cutlass::gemm::GemmShape<1, 1, 1>; + static constexpr auto kWarpSize = 32; + + // We just use the same iterator, as we reproduced the same shared-memory + // schema. Just modify it to handle non-complete tiles. + using WarpIterator = RegularWarpIterator; +}; + +// Converts a "regular" Mma into their counterpart from shared memory +template +struct DefaultMmaFromSharedMemory; + +// Mma pipelined +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to A operand + typename TransformA_, + /// Transformation applied to B operand + typename TransformB_, + typename AccumulatorSharedStorage_, + /// whether or not to apply elementwise multiplication of operand A by + /// another matrix in shared memory before usage in A @ B + bool kScaleOperandA, + bool kTransposeA> +struct DefaultMmaFromSharedMemory, + AccumulatorSharedStorage_, + kScaleOperandA, + kTransposeA> { + static constexpr int kWarpSize = 32; + using SmemAccumulatorLayout = cutlass::layout::RowMajor; + + using RegularMma = MmaPipelined; + + using WarpShape = typename Policy_::Operator::Shape; + using InstructionShape = typename Policy_::Operator::InstructionShape; + using ArchMmaOperator = typename Policy_::Operator; + + static constexpr bool kIsTransposedA = false; + using WarpIteratorA = + typename DefaultWarpIteratorAFromSharedMemory::WarpIterator; + using IteratorB = + typename cutlass::transform::threadblock::MakeIteratorResidualLast::Iterator; + + using Mma = + typename cutlass::gemm::threadblock::MmaPipelinedFromSharedMemory; +}; + +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Use zfill or 
predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + typename AccumulatorSharedStorage_, + /// whether or not to apply elementwise multiplication of operand A by + /// another matrix in shared memory before usage in A @ B + bool kScaleOperandA, + bool kTransposeA> +struct DefaultMmaFromSharedMemory, + AccumulatorSharedStorage_, + kScaleOperandA, + kTransposeA> { + static constexpr int kWarpSize = 32; + + using RegularMma = MmaMultistage; + + using WarpShape = typename Policy_::Operator::Shape; + using InstructionShape = typename Policy_::Operator::InstructionShape; + using WarpIteratorA_ = + typename DefaultWarpIteratorAFromSharedMemory::WarpIterator; + using WarpIteratorTranspose = TransposeWarpIterator; + static constexpr bool kIsTransposedA = WarpIteratorTranspose::kSupportsTranspose && kTransposeA; + using WarpIteratorA = typename platform:: + conditional::type; + + static int constexpr kMaxK = kIsTransposedA ? AccumulatorSharedStorage_::Shape::kM + : AccumulatorSharedStorage_::Shape::kN; + // Reduce the number of stages if we don't need that many + static int constexpr kStagesMax = (kMaxK + int(Shape_::kK) - 1) / int(Shape_::kK); + static int constexpr kStages = cutlass::const_min(Stages, kStagesMax); + + using IteratorB = + typename cutlass::transform::threadblock::MakeIteratorResidualLast::Iterator; + using Mma = typename cutlass::gemm::threadblock::MmaMultistageFromSharedMemory< + Shape_, + WarpIteratorA, + kScaleOperandA, + AccumulatorSharedStorage_, + IteratorB, + SmemIteratorB_, + RegularMma::kCacheOpB, + ElementC_, + LayoutC_, + Policy_, + kStages, + kMaxK>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct B2bGemm; + +// Tensor Cores >= Sm75 specialization (Ampere ...) +template < /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Element type + typename Element_, + /// Layout of operand in memory + typename Layout_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + /// Interval between adjacent *MMA instructions (in units of MMA + /// instructions, concept: MatrixShape) + typename OpDelta_, + typename Operator, + typename scalar_t, + typename WarpShape_, + typename ThreadblockShape_> +struct B2bGemm< + cutlass::gemm::warp:: + MmaTensorOpAccumulatorTileIterator, + Operator, + scalar_t, + WarpShape_, + ThreadblockShape_> { + using IteratorC = typename cutlass::gemm::warp:: + MmaTensorOpAccumulatorTileIterator; + using FragmentC = typename IteratorC::Fragment; + using InstructionShape = InstructionShape_; + using WarpShape = WarpShape_; + using ThreadblockShape = ThreadblockShape_; + using accum_t = Element_; + using lse_scalar_t = float; + + using SmemAccumulatorLayout = cutlass::layout::RowMajor; + + // Iterator to load accumulators (results of matmul in registers) + using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape, + InstructionShape, + accum_t, + typename Operator::Policy::Operator::FragmentC, + cutlass::layout::RowMajor>; + + // Iterator to store to shared-memory + using SmemIteratorD0 = + typename cutlass::epilogue::warp::TileIteratorTensorOp; + using AccumulatorSharedStorage = + cutlass::gemm::threadblock::AccumulatorSharedStorage; + // We need to provide an operation for the epilogue. 
Let's create an + // operation that does nothing (ScaleType::Nothing), just converts + // from accum_t (float) -> scalar_t (can be half) + using OutputOpNoOp = cutlass::epilogue::thread::LinearCombination< + typename SmemIteratorD0::Element, // ElementOutput + FragmentIteratorAccumulator::Fragment::kElements, + accum_t, // ElementAccumulator + typename SmemIteratorD0::Element, // ElementCompute + cutlass::epilogue::thread::ScaleType::Nothing>; + using Epilogue = cutlass::epilogue::threadblock::EpilogueSmemAccumulator< + SmemIteratorD0, + FragmentIteratorAccumulator, + SmemIteratorD0, // ScaleBiasIterator + // - not used + OutputOpNoOp>; + + // Epilogue 2: with LSE (for backwards pass) + static int const kElementsPerAccess = 2; // TODO: Why 2? + using IteratorAccumulatorLSE = cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + // Shape + cutlass::MatrixShape, + // WarpShape + cutlass::MatrixShape, + lse_scalar_t, + cutlass::layout::RowMajor, + kElementsPerAccess>>; + using EpilogueOpApplyLSE = cutlass::epilogue::thread::ApplyLogSumExp< + scalar_t, // ElementOutput_ + lse_scalar_t, // ElementLSE_ + accum_t, // ElementAccumulator_ + accum_t, // ElementCompute_ + 128 / cutlass::sizeof_bits::value + // FragmentIteratorAccumulator::Fragment::kElements + // InstructionShape::kM * InstructionShape::kN / 32 + >; + using EpilogueWithLSE = + cutlass::epilogue::threadblock::EpilogueSmemAccumulator; + + static void CUTLASS_DEVICE accumToSmem(AccumulatorSharedStorage& shared_storage, + FragmentC const& accum, + int lane_id, + cutlass::MatrixCoord const& tile_coords) + { + SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id); + smem_iterator_attn.add_tile_offset( + tile_coords * cutlass::MatrixCoord{SmemIteratorD0::TileIterations::kRow, + SmemIteratorD0::TileIterations::kColumn}); + Epilogue epilogue; + epilogue(OutputOpNoOp({}), smem_iterator_attn, accum); + } + + static void CUTLASS_DEVICE accumApplyLSEToSmem(AccumulatorSharedStorage& shared_storage, + FragmentC& accum, + lse_scalar_t const* lse, + int32_t lse_extents, + int thread_id, + int warp_id, + int lane_id, + cutlass::MatrixCoord const& tile_coords) + { + constexpr int32_t kAlignLSE = 32; + IteratorAccumulatorLSE iterator_lse( + lse, + {(int32_t)0, (int32_t)ceil_div(lse_extents, kAlignLSE) * kAlignLSE}, + thread_id, + warp_id, + cutlass::MatrixCoord{0, 0} // offset + ); + + SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id); + smem_iterator_attn.add_tile_offset( + tile_coords * cutlass::MatrixCoord{SmemIteratorD0::TileIterations::kRow, + SmemIteratorD0::TileIterations::kColumn}); + EpilogueWithLSE epilogue; + EpilogueOpApplyLSE minus_lse_exp({}); + epilogue(minus_lse_exp, + smem_iterator_attn, + accum, + // scale - unused + iterator_lse, + // bias + iterator_lse); + } +}; + +// Volta Specialization +// only supported for f16 +template +struct B2bGemm, + float, + cutlass::layout::RowMajor, + cutlass::gemm::GemmShape<16, 16, 4>, + cutlass::MatrixShape<1, 1>>, + Operator, + cutlass::half_t, + WarpShape_, + ThreadblockShape_> { + using IteratorC = cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + cutlass::MatrixShape<32, 32>, + float, + cutlass::layout::RowMajor, + cutlass::gemm::GemmShape<16, 16, 4>, + cutlass::MatrixShape<1, 1>>; + using scalar_t = cutlass::half_t; + using accum_t = IteratorC::Element; + using WarpShape = WarpShape_; + using ThreadblockShape = ThreadblockShape_; + using FragmentC = IteratorC::Fragment; + using 
lse_scalar_t = float; + + using SmemAccumulatorLayout = cutlass::layout::RowMajor; + using SmemIteratorD0 = + cutlass::epilogue::warp::TileIteratorVoltaTensorOp, + scalar_t, + SmemAccumulatorLayout>; + + // // Storage in shared-memory for Q.Kt + using AccumulatorSharedStorage = cutlass::gemm::threadblock::AccumulatorSharedStorage< + ThreadblockShape, + scalar_t, + cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise< + 16, + 32>, // typename SmemIteratorD0::TensorLayout, + cutlass::MatrixShape<0, 0> // Padding + >; + + using OutputLayout = cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>; + using TensorRef = cutlass::TensorRef; + using Policy = typename IteratorC::Policy; + using Element = accum_t; + // Those are MmaVoltaTensorOpAccumulatorTileIterator private fields + // Let's copy their values + static int const kElementsPerPartial = 4; + using EleShapePerPatial = + typename cutlass::platform::conditional::value, + cutlass::MatrixShape<2, 2>, + cutlass::MatrixShape<1, 4>>::type; + static int const kElementsPerMma = 8; + static int const kAccumulatorPatials = 2; + using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>; + + static void CUTLASS_DEVICE accumToSmem(AccumulatorSharedStorage& shared_storage, + FragmentC const& accum, + int lane_id, + cutlass::MatrixCoord const& tile_coords) + { + // ctor - from MmaVoltaTensorOpAccumulatorTileIterator + TensorRef ref_(shared_storage.accum_ref()); + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + int accum_m, accum_n; + + if (cutlass::platform::is_same::value) { + // (quad[2],quad[0])+lane_in_quad[0] + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1); + // (quad[1])+lane_in_quad[1] + accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials + + (lane_in_quad & 2); + } else { + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + lane_in_quad; // (quad[2],quad[0]) + accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials; + } + cutlass::MatrixCoord lane_offset(accum_m, accum_n); + + // Tile offset + ref_.add_coord_offset(tile_coords * cutlass::MatrixCoord({IteratorC::Shape::kRow, + IteratorC::Shape::kColumn})); + + using AccessType = cutlass::Array; + + // store - from MmaVoltaTensorOpAccumulatorTileIterator + CUTLASS_PRAGMA_UNROLL + for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) { + CUTLASS_PRAGMA_UNROLL + for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + int mma_accum_start = (((tile_n * Policy::TileIterations::kRow + tile_m) * + Policy::MmaIterations::kColumn + + mma_n) * + Policy::MmaIterations::kRow + + mma_m) * + kElementsPerMma; + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < kAccumulatorPatials; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < EleShapePerPatial::kRow; ++m) { + int accum_m = tile_m * Policy::InterleavedTile::kRow + + mma_m * QuadShapePerPatialMma::kRow + m * 2; + int accum_n = tile_n * Policy::InterleavedTile::kColumn + + mma_n * QuadShapePerPatialMma::kColumn + + p * Policy::InterleavedTile::kColumn / 2; + int r = (accum_m + lane_offset.row()); + AccessType to_store; + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { + int idx = mma_accum_start + p * kElementsPerPartial + + m * EleShapePerPatial::kColumn + n; + int c = (accum_n + n + 
lane_offset.column()); + to_store[n] = scalar_t(accum[idx]); + } + int c = (accum_n + lane_offset.column()); + assert(r < 32); + assert(c < 32); + *reinterpret_cast(ref_.data() + ref_.offset({r, c})) = + to_store; + } + } + } + } + } + } + } + + static void CUTLASS_DEVICE accumApplyLSEToSmem(AccumulatorSharedStorage& shared_storage, + typename IteratorC::Fragment& accum, + lse_scalar_t const* lse, + int lse_extent, + int thread_id, + int warp_id, + int lane_id, + cutlass::MatrixCoord const& tile_coords) + { + // Non-optimized way to apply LSE to registers + // NOTE: accum is attn.T + // TODO: Optimize for each architecture + static constexpr int WarpSize = 32; + using AccumLambdaIterator = + typename DefaultMmaAccumLambdaIterator::Iterator; + auto lane_offset = AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords); + + cutlass::Array lse_prefetched; + lse_prefetched.clear(); + int rowIdx = 0; + int colIdx = 0; + AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + ++rowIdx; + colIdx = 0; + }, + [&](int accum_m, int accum_n, int idx) { + if (rowIdx == 1) { + lse_prefetched[colIdx] = accum_n < lse_extent + ? lse[accum_n] + : platform::numeric_limits::infinity(); + } + accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]); + ++colIdx; + }, + [&](int accum_m) {}); + accumToSmem(shared_storage, accum, lane_id, tile_coords); + } +}; + +// Simt Specialization +// for f32 on Sm70-Sm75 and f16/f32 below + +template +struct B2bGemm, + cutlass::gemm::Operand::kC, + float, + cutlass::layout::RowMajor, + OperatorPolicy, + 1, + 1>, + Operator, + scalar_t, + WarpShape_, + ThreadblockShape_> { + using IteratorC = cutlass::gemm::warp::MmaSimtTileIterator, + cutlass::gemm::Operand::kC, + float, + cutlass::layout::RowMajor, + OperatorPolicy, + 1, + 1>; + using accum_t = typename IteratorC::Element; + using WarpShape = WarpShape_; + using ThreadblockShape = ThreadblockShape_; + using FragmentC = typename IteratorC::Fragment; + using lse_scalar_t = float; + + // Storage in shared-memory for Q.Kt + using AccumulatorSharedStorage = + cutlass::gemm::threadblock::AccumulatorSharedStorage // Padding + >; + + static void CUTLASS_DEVICE accumToSmem(AccumulatorSharedStorage& shared_storage, + FragmentC const& accum, + int lane_id, + cutlass::MatrixCoord const& tile_coords) + { + using Policy = typename IteratorC::Policy; + using Element = typename IteratorC::Element; + using Iterations = typename IteratorC::Iterations; + using Delta = typename IteratorC::Delta; + + auto ref_ = shared_storage.accum_ref(); + // ctor - MmaSimtTileIterator + // compute offset based on thread ID and lane layout + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + MatrixCoord lane_offset = lane_layout.inverse(lane_id) * + MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN); + + ref_.add_coord_offset(lane_offset); + + // Tile offset + ref_.add_coord_offset(tile_coords * cutlass::MatrixCoord({IteratorC::Shape::kRow, + IteratorC::Shape::kColumn})); + + // store - MmaSimtTileIterator + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { + int r = Policy::LaneMmaShape::kM * (mma_m * Policy::WarpShape::kRow) + m; + int c = mma_n * Delta::kColumn + n; + int idx = n + Policy::LaneMmaShape::kN * + (mma_n + 
Iterations::kColumn * + (m + mma_m * Policy::LaneMmaShape::kM)); + ref_.at({r, c}) = scalar_t(accum[idx]); + } + } + } + } + } + + static void CUTLASS_DEVICE accumApplyLSEToSmem(AccumulatorSharedStorage& shared_storage, + typename IteratorC::Fragment& accum, + lse_scalar_t const* lse, + int lse_extent, + int thread_id, + int warp_id, + int lane_id, + cutlass::MatrixCoord const& tile_coords) + { + // Non-optimized way to apply LSE to registers + // NOTE: accum is attn.T + // TODO: Optimize for each architecture + static constexpr int WarpSize = 32; + using AccumLambdaIterator = + typename DefaultMmaAccumLambdaIterator::Iterator; + auto lane_offset = AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords); + + cutlass::Array lse_prefetched; + lse_prefetched.clear(); + int rowIdx = 0; + int colIdx = 0; + AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + ++rowIdx; + colIdx = 0; + }, + [&](int accum_m, int accum_n, int idx) { + if (rowIdx == 1) { + lse_prefetched[colIdx] = accum_n < lse_extent + ? lse[accum_n] + : platform::numeric_limits::infinity(); + } + accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]); + ++colIdx; + }, + [&](int accum_m) {}); + accumToSmem(shared_storage, accum, lane_id, tile_coords); + } +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h b/csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h new file mode 100644 index 000000000000..f846658f2579 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h @@ -0,0 +1,242 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include "cutlass/arch/mma.h" + +template +struct CheckArch { + static constexpr bool isPreVolta = arch::kMinComputeCapability < 70; + static constexpr bool isPreAmpere = + arch::kMinComputeCapability < 80 && arch::kMinComputeCapability >= 70; + static constexpr bool isAmpere = arch::kMinComputeCapability >= 80; +#if defined(__CUDA_ARCH__) + static constexpr bool compiler_cc = arch::kMinComputeCapability * 10 <= __CUDA_ARCH__; +#else + static constexpr bool compiler_cc = true; +#endif + static constexpr bool value = (isPreVolta && std::is_same_v) || + (isPreAmpere && !std::is_same_v) || + isAmpere && compiler_cc; +}; + +#define DISPATCH_ARCHTAG(CC, func) \ + { \ + if (CC >= 80) { \ + using ArchTag = cutlass::arch::Sm80; \ + func; \ + } else if (CC >= 75) { \ + using ArchTag = cutlass::arch::Sm75; \ + func; \ + } else if (CC >= 70) { \ + using ArchTag = cutlass::arch::Sm70; \ + func; \ + } else { \ + EVOFORMER_CHECK(false, "Only GPUs with Tensor Core are supported for now"); \ + } \ + } + +#define DISPATCH_TYPES(tensor, func) \ + { \ + if (tensor.scalar_type() == at::ScalarType::Half) { \ + using scalar_t = cutlass::half_t; \ + using torch_scalar_t = at::Half; \ + func(); \ + } else if (tensor.scalar_type() == at::ScalarType::BFloat16) { \ + using scalar_t = cutlass::bfloat16_t; \ + using torch_scalar_t = at::BFloat16; \ + func(); \ + } else { \ + EVOFORMER_CHECK(false, "Only fp16 and bf16 supported at the moment"); \ + } \ + } + +#define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F) \ + { \ + if (BOOL_V) { \ + constexpr bool BOOL_NAME = true; \ + F(); \ + } else { \ + constexpr bool BOOL_NAME = false; \ + F(); \ + } \ + } + +#ifdef TORCH_CHECK +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + EVOFORMER_CHECK(uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned") +#define EVOFORMER_CHECK TORCH_CHECK +#elif defined(__CUDACC_RTC__) +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + if (!(uint64_t(PTR) % ALIGNMENT == 0)) { return false; } +#define EVOFORMER_CHECK(COND, ERR) \ + if (!(COND)) { return false; } +#else +#include +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \ + std::cerr << #PTR " is not correctly aligned\n"; \ + return false; \ + } +#define EVOFORMER_CHECK(COND, ERR) \ + if (!(COND)) { \ + std::cerr << "[Evoformer Attention]" \ + << "'" #COND "' failed: " << ERR << "\n"; \ + return false; \ + } +#endif + +namespace gemm_kernel_utils { + +template +constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) +{ + return (n + m - 1) / m; +} + +template +constexpr CUTLASS_HOST_DEVICE integer align_up(integer n, integer m) +{ + return ((n + m - 1) / m) * m; +} + +//////////////////////////////////////////////////////////////////////////////// +// Determine the type of GEMM we do (TensorCores or not, Shapes ...) 
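+// Example (informal sketch, not part of the dispatch logic itself): callers
+// read the selected configuration from the nested typedefs, e.g.
+//
+//   using GemmType = DefaultGemmType<cutlass::arch::Sm80, cutlass::half_t>;
+//   static_assert(std::is_same_v<typename GemmType::OpClass,
+//                                cutlass::arch::OpClassTensorOp>);
+//   // GemmType::InstructionShape  == cutlass::gemm::GemmShape<16, 8, 8>
+//   // GemmType::kMinimumAlignment == 4
+//
+// Combinations without a matching specialization fall back to the Simt
+// (CUDA-core FMA) default below, which uses GemmShape<1, 1, 1>.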
+// TODO: Maybe we could rely on Cutlass's DefaultGemm templates +//////////////////////////////////////////////////////////////////////////////// + +// Fallback to Simt (FMA on cuda cores) if not in a special case below +template +struct DefaultGemmType { + static constexpr int ThreadK = 8; + static constexpr int WarpK = 8; + static constexpr int kMinimumAlignment = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using OpClass = cutlass::arch::OpClassSimt; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Specialization for tensorcores with f32 +template +struct DefaultGemmType< + ArchTag, + float, + typename cutlass::platform::enable_if= 80>::type> { + static constexpr int ThreadK = 32; + static constexpr int WarpK = 32; + static constexpr int kMinimumAlignment = 4; + using OpClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Operator = cutlass::arch::OpMultiplyAddFastF32; +}; + +// Specialization for tensorcores with f16/bf16 - Sm75+ +template +struct DefaultGemmType< + ArchTag, + scalar_t, + typename cutlass::platform::enable_if= 75 && + cutlass::sizeof_bits::value == 16>::type> { + static constexpr int ThreadK = 32; + static constexpr int WarpK = 32; + static constexpr int kMinimumAlignment = 4; + using OpClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Specialization for tensorcores with f16 - Volta +template <> +struct DefaultGemmType { + static constexpr int ThreadK = 32; + static constexpr int WarpK = 32; + static constexpr int kMinimumAlignment = 2; + using OpClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Enables to do +// `auto x = kCondition ? fa(arg) : fb(arg)` +// when `fa` and `fb` have different types +template +struct call_conditional; + +template +struct call_conditional { + template + static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg) -> decltype(ta(arg)) + { + return ta(arg); + } +}; + +template +struct call_conditional { + template + static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg) -> decltype(tb(arg)) + { + return tb(arg); + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Mark a variable as warp-uniform - enables some compiler optimizations +// The cheapest way to do it is just to broadcast it from lane 0 +//////////////////////////////////////////////////////////////////////////////// + +CUTLASS_DEVICE int32_t warp_uniform(int32_t value) +{ + return (int32_t)__shfl_sync(0xffffffff, (unsigned)value, 0); +} + +template +CUTLASS_DEVICE T* warp_uniform(T* ptr) +{ + struct { + union { + T* ptr; + uint32_t asInt[2]; + }; + } p; + p.ptr = ptr; + p.asInt[0] = warp_uniform(p.asInt[0]); + p.asInt[1] = warp_uniform(p.asInt[1]); + return p.ptr; +} +} // namespace gemm_kernel_utils diff --git a/csrc/deepspeed4science/evoformer_attn/iterators/epilogue_predicated_tile_iterator.h b/csrc/deepspeed4science/evoformer_attn/iterators/epilogue_predicated_tile_iterator.h new file mode 100644 index 000000000000..667f1982d30d --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/iterators/epilogue_predicated_tile_iterator.h @@ -0,0 +1,691 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. 
All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/*! \file + \brief Epilogue iterator that supports prefetching + + Mostly copied from "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +*/ + +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/transform/pitch_linear_thread_map.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load and store output tile from global memory in +/// epilogue. 
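+///
+/// Beyond the stock PredicatedTileIterator interface it adds prefetch() /
+/// prefetch_all(), which issue `prefetch.global.L1` for the addresses of the
+/// upcoming accesses. Informal usage sketch (names are illustrative; note
+/// that prefetch_all() advances the iterator it is called on, hence the copy):
+///
+///   auto prefetch_it = out_it;    // `out_it` is an already-built iterator
+///   prefetch_it.prefetch_all();   // warm L1 for the whole output tile
+///   typename decltype(out_it)::Fragment frag;
+///   out_it.load(frag);            // regular epilogue loads, now more likely
+///                                 // to hit in L1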
+/// +/// Satisfies: ReadableTileIterator | PredicatedTileIterator | +/// ForwardTileIterator +/// +template +class PredicatedTileIteratorPrefetch { +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static_assert(ThreadMap::Iterations::kRow > 0, "ThreadMap::Iterations::kRow must be > 0"); + static_assert(ThreadMap::Iterations::kGroup > 0, "ThreadMap::Iterations::kGroup must be > 0"); + static_assert(ThreadMap::Iterations::kCluster > 0, + "ThreadMap::Iterations::kCluster must be > 0"); + static_assert(ThreadMap::Iterations::kColumn > 0, "ThreadMap::Iterations::kColumn must be > 0"); + + /// Fragment object + using Fragment = Array; + + /// Memory access size + using AccessType = AlignedArray; + + // + // Parameters struct + // + + /// Uses a non-template class + struct Params : PredicatedTileIteratorParams { + using Base = PredicatedTileIteratorParams; + + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : PredicatedTileIteratorParams( + layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc()) + { + } + + CUTLASS_HOST_DEVICE + Params(Base const& base) : Base(base) {} + }; + + /// Mask object + struct Mask { + static int const kCount = ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { enable(); } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { predicates[i] = false; } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { predicates[i] = true; } + } + }; + +private: + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. 
+ PredicatedTileIteratorParams params_; + + /// Byte-level pointer + uint8_t* byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in rows + Index extent_column_; + + /// A thread's starting row position (assuming steady-state predicates have + /// been computed) + Index thread_start_row_; + + /// A thread's starting column + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + + /// Scatter indices + int const* indices_; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); + +private: + // + // Methods + // + +public: + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorPrefetch(PredicatedTileIteratorParams const& params, + Element* pointer, + TensorCoord extent, + int thread_idx, + TensorCoord threadblock_offset = TensorCoord(), + int const* indices = nullptr) + : params_(params), indices_(indices) + { + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_row_ = extent.row(); + extent_column_ = extent.column(); + + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + mask_.predicates[c] = + ((thread_offset.column() + ThreadMap::Delta::kColumn * c) < extent.column()); + } + + // Null pointer performs no accesses + if (!pointer) { mask_.clear(); } + + if (ScatterD && !indices) { mask_.clear(); } + + // Initialize pointer + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.row()) * LongIndex(params_.stride) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; + + if (ScatterD) { + byte_pointer_ = + reinterpret_cast(pointer) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; + } + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_DEVICE + void prefetch_all() + { + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < kIterations; ++iter) { + prefetch(); + ++(*this); + } + } + + CUTLASS_DEVICE + void prefetch() + { + uint8_t* byte_pointer = byte_pointer_; + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + AccessType* memory_pointer = reinterpret_cast(byte_pointer); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + // on windows using unsigned long here gives the error + // error: asm operand type size(4) does not match + // type/size implied by constraint 'l' + uint64_t addr = + (uint64_t)((void*)&memory_pointer[column * ThreadMap::Delta::kColumn / + 
kElementsPerAccess]); + asm volatile("prefetch.global.L1 [ %1 ];" : "=l"(addr) : "l"(addr)); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { byte_pointer += params_.increment_row; } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, int64_t byte_offset) const + { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * + LongIndex(params_.stride)); + } + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / + kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { byte_pointer += params_.increment_row; } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) const { load_with_byte_offset(frag, 0); } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, int64_t byte_offset) const + { + uint8_t* byte_pointer = byte_pointer_; + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * + LongIndex(params_.stride)); + } + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < 
ThreadMap::Iterations::kColumn; ++column) { + bool guard = row_guard && mask_.predicates[column]; + + if (UseCUDAStore) { + if (guard) { + memory_pointer[column * ThreadMap::Delta::kColumn / + kElementsPerAccess] = + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + + column]; + } + } else { + cutlass::arch::global_store( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / + kElementsPerAccess], + guard); + } + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { byte_pointer += params_.increment_row; } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) const { store_with_byte_offset(frag, 0); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void downsample_load_with_byte_offset(Fragment& frag, + int64_t byte_offset, + int convolution_P, + int convolution_Q, + int add_P, + int add_Q, + int problem_N) const + { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + + int input_row = output_N * 2 * convolution_P * 2 * convolution_Q + + (2 * output_P + add_P) * 2 * convolution_Q + 2 * output_Q + + add_Q; + + int64_t byte_offset = (input_row - output_row) * problem_N * sizeof(float); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / + kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + byte_pointer += params_.increment_row; + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void upsample_load_with_byte_offset(Fragment& frag, + int64_t byte_offset, + int convolution_P, + int convolution_Q, + int add_P, + int add_Q, + int problem_N) const + { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) 
{ + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + int row_add_P = add_P; + int row_add_Q = add_Q; + if (output_P > convolution_P - 2) row_add_P = 0; + if (output_Q > convolution_Q - 2) row_add_Q = 0; + + int input_row = output_N * (convolution_P / 2) * (convolution_Q / 2) + + ((output_P + row_add_P) / 2) * (convolution_Q / 2) + + (output_Q + row_add_Q) / 2; + + int64_t byte_offset = (input_row - output_row) * problem_N * sizeof(float); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / + kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + byte_pointer += params_.increment_row; + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + CUTLASS_DEVICE + MatrixCoord thread_start() const + { + return MatrixCoord(thread_start_row_, thread_start_column_); + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_row() const { return thread_start_row_; } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_column() const { return thread_start_column_; } + + /// Extent of the matrix in rows + CUTLASS_DEVICE + Index extent_row() const { return extent_row_; } + + /// Extent of the matrix in columns + CUTLASS_DEVICE + Index extent_column() const { return extent_column_; } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorPrefetch& operator++() + { + ++state_[0]; + + if (!ScatterD) { byte_pointer_ += params_.advance_row; } + + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + state_[0] = 0; + ++state_[1]; + byte_pointer_ += params_.advance_group; + + thread_start_row_ += + (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + state_[1] = 0; + ++state_[2]; + byte_pointer_ += params_.advance_cluster; + + thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * + ThreadMap::Count::kRow * ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + byte_pointer_ += params_.advance_tile; + } + } + } + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { mask_.clear(); } + + 
///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { mask_.enable(); } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask& mask) const { mask = mask_; } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const& mask) { mask_ = mask; } +}; + +template +struct MakePrefetchableIterator { + using Iterator = PredicatedTileIteratorPrefetch; +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/deepspeed4science/evoformer_attn/iterators/make_residual_last.h b/csrc/deepspeed4science/evoformer_attn/iterators/make_residual_last.h new file mode 100644 index 000000000000..ff0e324c3a6c --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/iterators/make_residual_last.h @@ -0,0 +1,91 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. 
+// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include "predicated_tile_access_iterator_residual_last.h" +#include "predicated_tile_iterator_residual_last.h" + +namespace cutlass { +namespace transform { +namespace threadblock { + +template +struct MakeIteratorResidualLast; + +template +struct MakeIteratorResidualLast< + PredicatedTileIterator> { + using Iterator = PredicatedTileIteratorResidualLast; +}; + +template +struct MakeIteratorResidualLast> { + using Iterator = PredicatedTileAccessIteratorResidualLast; +}; +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/csrc/deepspeed4science/evoformer_attn/iterators/predicated_tile_access_iterator_residual_last.h b/csrc/deepspeed4science/evoformer_attn/iterators/predicated_tile_access_iterator_residual_last.h new file mode 100644 index 000000000000..7f6a2430845a --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/iterators/predicated_tile_access_iterator_residual_last.h @@ -0,0 +1,1964 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/*! \file + \brief Templates calculating the address and predicates to the load of tiles + from pitch-linear rank=2 tensors. + + This iterator uses masks to guard out-of-bounds accesses. The first tile + this iterator visits maybe partial, then the remaining tiles are complete. + So, we only need to compute the predicates twice, once before the first tile + and once for the remaining full tiles which can share the same predicates. 
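+
+    For illustration (numbers are only an example): walking an extent of 72
+    elements in 32-wide tiles gives two full tiles plus one 8-wide partial
+    tile. The partial-tile mask is captured once at construction
+    (residual_tile_mask) and re-applied through set_residual_tile(true),
+    while every full tile reuses the shared steady-state mask.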
+ + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileAccessIteratorResidualLast +/// +template +class PredicatedTileAccessIteratorResidualLast; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for pitch-linear +/// data. +/// +template +class PredicatedTileAccessIteratorResidualLast { +public: + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + using Mask = typename UnderlyingPredicates::Mask; + + /// Uses a non-template class + struct Params : PredicatedTileAccessIteratorParams { + using Base = PredicatedTileAccessIteratorParams; + + // Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : Base(layout.stride(0), + MakePredicatedTileAccessIteratorDesc()()) + { + } + + CUTLASS_HOST_DEVICE + Params(Base const& base) : Base(base) {} + }; + +private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + +private: + // + // Data members + // + + UnderlyingPredicates the_predicates; + Mask residual_tile_mask; + + /// Parameters object with precomputed internal state + Params params_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + /// Below is used when Gather is turned on. We need to record strided_offset + /// and contiguous_offset separated to compute the offset by using + /// + /// offset = contiguous_offset + indices[strided_offset] + /// + + /// Gather indices + int const* indices_; + + Index gather_offset_strided; + +private: + /// Computes predicates based on internally tracked per-thread offset. 
+ CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent, + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) + { + the_predicates.compute_predicates_(extent, is_steady_state); + } + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + /// Gather indices + int const* indices = nullptr) + : params_(params), + pointer_(reinterpret_cast(const_cast(pointer))), + the_predicates(extent), + indices_(indices) + { + the_predicates.set_predicates(thread_id, threadblock_offset); + the_predicates.get_mask(residual_tile_mask); + + // Working around a weird compiler bug happening on P100 for the backward. + // I've seen together: the_predicates.predicates_[0] = 14 (instead of 15) + // residual_tile_mask[0] = 15 (correct) + // + // Adding prints when the value is calculated (in `compute_predicates_`) + // sometimes removes the bug. The consequence is that we skip some + // element of a tensor, leading to wrong results + // Setting `compute_predicates_`'s second argument (`is_steady_state`) to + // true also seems to get rid of the bug - at the cost of twice as many + // comparisons. +#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 700) + constexpr bool kWorkAroundCompilerBug = false; +#else + constexpr bool kWorkAroundCompilerBug = true; +#endif + the_predicates.compute_predicates_(extent, true && !kWorkAroundCompilerBug); + + // update internal pointers + Layout layout(params_.stride_); + + if (!Gather) { + add_pointer_offset(layout(the_predicates.thread_offset_)); + } else { + gather_offset_strided = the_predicates.thread_offset_.strided(); + add_pointer_offset(layout(make_Coord(the_predicates.thread_offset_.contiguous(), 0))); + } + } + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id) + : PredicatedTileAccessIteratorResidualLast(params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) + { + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { the_predicates.set_iteration_index(index); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool is_residual_tile) + { + if (is_residual_tile) { the_predicates.set_mask(residual_tile_mask); } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + pointer_ += sizeof_bits::value * pointer_offset / 8; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) + { + if (!Gather) { + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided()); + pointer_ += Shape::kContiguous * tile_offset.contiguous(); + } else { + pointer_ += params_.inc_advance_ * 
LongIndex(tile_offset.contiguous()); + pointer_ += Shape::kStrided * tile_offset.strided(); + } + } else { + add_pointer_offset(Shape::kContiguous * tile_offset.contiguous()); + gather_offset_strided += Shape::kStrided * tile_offset.strided(); + } + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const + { + if (Gather) { + assert(indices_); + + if (!valid()) { return nullptr; } + + LongIndex contiguous_offset = + the_predicates.iteration_contiguous_ * + (ThreadMap::Delta::kContiguous * sizeof_bits::value / 8) + + the_predicates.iteration_vector_; + int strided_index = gather_offset_strided + + the_predicates.iteration_strided_ * ThreadMap::Delta::kStrided; + + LongIndex strided_offset = indices_[strided_index] * LongIndex(params_.stride_) * + sizeof_bits::value / 8; + + return reinterpret_cast(pointer_ + contiguous_offset + strided_offset); + } + + return reinterpret_cast( + pointer_ + the_predicates.iteration_contiguous_ * + (ThreadMap::Delta::kContiguous * sizeof_bits::value) / + 8) + + the_predicates.iteration_vector_; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() + { + the_predicates.operator++(); + + ++the_predicates.iteration_vector_; + if (the_predicates.iteration_vector_ < kAccessesPerVector) { return *this; } + + the_predicates.iteration_vector_ = 0; + ++the_predicates.iteration_contiguous_; + + if (the_predicates.iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + the_predicates.iteration_contiguous_ = 0; + ++the_predicates.iteration_strided_; + + if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { + if (!Gather) { pointer_ += params_.inc_strided_; } + + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + the_predicates.iteration_strided_ = 0; + + if (!Gather) { + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, + // this subtraction as well as the subsequent integer addition are both + // elided by the compiler. + pointer_ -= params_.inc_advance_; + } + + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) + { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { the_predicates.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { the_predicates.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { the_predicates.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { the_predicates.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const { return the_predicates.valid(); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for column-major +/// data. 
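+///
+/// (This wrapper essentially re-maps coordinates onto the underlying
+/// pitch-linear iterator: for ColumnMajor the row is the contiguous
+/// dimension, so extents and offsets are forwarded as
+/// PitchLinearCoord(extent.row(), extent.column()) and tile offsets as
+/// {row, column}.)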
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIteratorResidualLast { +public: + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessType, + Gather>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout::PitchLinear(layout.stride(0))){}; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) : params_(base) {} + }; + +private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()), + indices) + { + } + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast(params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) + { + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + 
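+    /// When enabled, applies the residual-tile mask captured at construction
+    /// (forwarded to the underlying pitch-linear iterator).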
+ CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) + { + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { return reinterpret_cast(iterator_.get()); } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() + { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) + { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { return iterator_.valid(); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for row-major +/// data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIteratorResidualLast { +public: + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
1 : 0), + ThreadMap, + AccessType, + Gather>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout::PitchLinear(layout.stride(0))){}; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) : params_(base) {} + }; + +private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + /// Gather indices + int const* indices = nullptr) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()), + indices) + { + } + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast(params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) + { + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) + { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { return reinterpret_cast(iterator_.get()); } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() + { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. 
+ /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) + { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { return iterator_.valid(); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank 2 +/// data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIteratorResidualLast, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { +public: + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRankN<2>; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingPredicates::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + friend PredicatedTileAccessIteratorResidualLast; + + private: + /// stride of pitch-linear layout (units of Element) + Coord stride_; + /// amount (in byte) to increment pointer to move to next access along + /// contiguous dimension + LongIndex inc_contiguous_; + /// amount (in byte) to increment pointer from first access of current + /// contiguous dimension to first access of next one. + LongIndex inc_strided_; + /// amount (in byte) to increment pointer from last access of current + /// contiguous dimension to first access of next one. 
+ LongIndex inc_next_strided_; + /// amount (in byte) to increment pointer from last access to first access + /// of next tile + LongIndex inc_next_; + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_; + + public: + // Default ctor + CUTLASS_HOST_DEVICE + Params() : stride_(0), inc_contiguous_(0), inc_strided_(0), inc_next_(0), inc_advance_(0) {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : stride_({layout.stride(0), layout.stride(1)}) + { + inc_contiguous_ = (LongIndex(stride_[0]) * ThreadMap::Delta::kContiguous) * + sizeof_bits::value / 8; + + inc_strided_ = (LongIndex(stride_[1]) * ThreadMap::Delta::kStrided) * + sizeof_bits::value / 8; + + inc_next_strided_ = + inc_strided_ - LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_; + + if (kAdvanceRank) { + // advance along strided dimension + inc_advance_ = + Shape::kStrided * LongIndex(stride_[1]) * sizeof_bits::value / 8; + } else { + // advance along contiguous dimension + inc_advance_ = Shape::kContiguous * stride_[0] * sizeof_bits::value / 8; + } + + inc_next_ = inc_advance_ - + LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_ - + LongIndex(ThreadMap::Iterations::kStrided - 1) * inc_strided_; + }; + }; + +private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + // + // Data members + // + + /// Parameters object with precomputed internal state + Params params_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + UnderlyingPredicates the_predicates; + Mask residual_tile_mask; + +private: + /// Computes predicates based on internally tracked per-thread offset. 
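+ /// Forwards to the shared UnderlyingPredicates helper. The separate mask in
+ /// residual_tile_mask is swapped in via set_residual_tile() when the last,
+ /// possibly partial, tile is visited.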
+ CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent, + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) + { + the_predicates.compute_predicates_(extent, is_steady_state); + } + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : params_(params), + pointer_(reinterpret_cast(const_cast(pointer))), + the_predicates(extent) + { + the_predicates.set_predicates(thread_id, threadblock_offset); + + // update internal pointers + Layout layout(params_.stride_); + add_pointer_offset(layout(the_predicates.thread_offset_)); + } + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast(params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) + { + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { the_predicates.set_iteration_index(index); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool is_residual_tile) + { + if (is_residual_tile) { the_predicates.set_mask(residual_tile_mask); } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + pointer_ += sizeof_bits::value * pointer_offset / 8; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) + { + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset[1]); + pointer_ += Shape::kContiguous * tile_offset[0]; + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset[0]); + pointer_ += Shape::kStrided * tile_offset[1]; + } + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const + { + return reinterpret_cast(pointer_) + the_predicates.iteration_vector_; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
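+ ///
+ /// Iteration order, innermost to outermost: accesses within a vector, then
+ /// the contiguous dimension (inc_contiguous_), then the strided dimension
+ /// (inc_next_strided_). After the last strided step the pointer is bumped by
+ /// inc_next_ and immediately rewound by inc_advance_, so repeated increments
+ /// revisit the current tile; moving to a new tile is done via
+ /// add_tile_offset().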
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() + { + the_predicates.operator++(); + ++the_predicates.iteration_vector_; + if (the_predicates.iteration_vector_ < kAccessesPerVector) { return *this; } + + the_predicates.iteration_vector_ = 0; + ++the_predicates.iteration_contiguous_; + + if (the_predicates.iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + pointer_ += params_.inc_contiguous_; + return *this; + } + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + the_predicates.iteration_contiguous_ = 0; + ++the_predicates.iteration_strided_; + + if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { + pointer_ += params_.inc_next_strided_; + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + the_predicates.iteration_strided_ = 0; + + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, this + // subtraction as well as the subsequent integer addition are both elided by + // the compiler. + pointer_ -= params_.inc_advance_; + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) + { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { the_predicates.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { the_predicates.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { the_predicates.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { the_predicates.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { return the_predicates.valid(); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank 2 +/// column-major data. 
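+/// Implemented as a thin adaptor over the AffineRankN<2> specialization above:
+/// row/column coordinates are swapped into contiguous/strided order before
+/// being passed to the underlying iterator.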
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIteratorResidualLast { +public: + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))){}; + }; + +private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column())) + { + } + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast(params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) + { + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// 
Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) + { + iterator_.add_tile_offset(make_Coord(tile_offset.row(), tile_offset.column())); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { return reinterpret_cast(iterator_.get()); } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() + { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) + { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { return iterator_.valid(); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank-2 +/// row-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIteratorResidualLast { +public: + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 
1 : 0), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))){}; + }; + +private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row())) + { + } + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast(params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) + { + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) + { + iterator_.add_tile_offset(make_Coord(tile_offset.column(), tile_offset.row())); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { return reinterpret_cast(iterator_.get()); } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() + { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. 
+ /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) + { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { return iterator_.valid(); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for column-major +/// interleaved data. It is mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// + +template +class PredicatedTileAccessIteratorResidualLast, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { +public: + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::ColumnMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
0 : 1), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) : params_(base) {} + }; + +private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.row() * kInterleavedK, + extent.column() / kInterleavedK), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row() * kInterleavedK, + threadblock_offset.column() / kInterleavedK)) + { + } + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast(params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) + { + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) + { + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { return reinterpret_cast(iterator_.get()); } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() + { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. 
+ /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) + { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { return iterator_.valid(); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for row-major +/// interleaved data. +// It is mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIteratorResidualLast, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { +public: + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::RowMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
1 : 0), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) : params_(base) {} + }; + +private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.column() * kInterleavedK, + extent.row() / kInterleavedK), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column() * kInterleavedK, + threadblock_offset.row() / kInterleavedK)) + { + } + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast(params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) + { + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) + { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { return reinterpret_cast(iterator_.get()); } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() + { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. 
+ /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) + { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { return iterator_.valid(); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/deepspeed4science/evoformer_attn/iterators/predicated_tile_iterator_atomic.h b/csrc/deepspeed4science/evoformer_attn/iterators/predicated_tile_iterator_atomic.h new file mode 100644 index 000000000000..8d4173f1a6a2 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/iterators/predicated_tile_iterator_atomic.h @@ -0,0 +1,886 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once +#include +#include +#include +namespace cutlass { +namespace epilogue { +namespace threadblock { + +template +struct atomic_store {}; + +template +struct atomic_store::value>::type> { + using Element = typename AccessType::Element; + static const int kCount = AccessType::kElements; + + CUTLASS_DEVICE + atomic_store(AccessType const& D, void* ptr, bool pred_guard) + { + static_assert(!(kCount % 2), "kCount must be even"); + half2* p = reinterpret_cast(ptr); + uint const* data = reinterpret_cast(&D); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + : + : "r"((int)pred_guard)); + for (int i = 0; i < kCount / 2; i++) { + asm volatile(" @p red.relaxed.global.add.noftz.f16x2 [%0], %1;\n" + : + : "l"(p + i), "r"(data[i])); + } + asm volatile("}\n" ::); + } +}; + +template +struct atomic_store::value>::type> { + using Element = typename AccessType::Element; + static const int kCount = AccessType::kElements; + + CUTLASS_DEVICE + atomic_store(AccessType const& D, void* ptr, bool pred_guard) + { + Element* p = reinterpret_cast(ptr); + uint const* data = reinterpret_cast(&D); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + : + : "r"((int)pred_guard)); + for (int i = 0; i < kCount; i++) { + asm volatile(" @p red.relaxed.global.add.f32 [%0], %1;\n" + : + : "l"(p + i), "r"(data[i])); + } + asm volatile("}\n" ::); + } +}; + +template +class PredicatedTileIteratorAffineRankNAtomic { +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = layout::AffineRankN; + using TensorRef = TensorRef; + using TensorView = TensorView; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + 
using LongIndex = typename Layout::LongIndex; + using TensorCoord = typename Layout::TensorCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static_assert(ThreadMap::Iterations::kRow > 0, "ThreadMap::Iterations::kRow must be > 0"); + static_assert(ThreadMap::Iterations::kGroup > 0, "ThreadMap::Iterations::kGroup must be > 0"); + static_assert(ThreadMap::Iterations::kCluster > 0, + "ThreadMap::Iterations::kCluster must be > 0"); + static_assert(ThreadMap::Iterations::kColumn > 0, "ThreadMap::Iterations::kColumn must be > 0"); + static_assert(!(Layout::kRank % 2), + "Layout rank must be even. This assumes the first half of the " + "modes correspond to the 'row' " + "and the second half of the modes correspond to the 'column'"); + + static bool const kBigEndian = false; + + /// Fragment object + using Fragment = Array; + + /// Memory access size + using AccessType = AlignedArray; + + // + // Parameters struct + // + + /// Parameters structure + struct Params { + // + // Data members + // + + Layout layout; + + /// Stride in units of bytes along M modes + Coord stride_m; + + /// Stride in units of bytes along N modes + Coord stride_n; + + /// Fast divmod objects divided by tensor extents + FastDivmod divmod_m[(Layout::kRank == 2) ? 1 : (Layout::kRank / 2 - 1)]; + + /// Fast divmod objects divided by tensor extents + FastDivmod divmod_n[(Layout::kRank == 2) ? 1 : (Layout::kRank / 2 - 1)]; + + int64_t rank2_inc_col; + int64_t rank2_inc_row; + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(TensorCoord const& extent, Layout const& layout_) : layout(layout_) + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank / 2; ++i) { + stride_m[i] = OffsetBytes(layout_.stride()[i]); + stride_n[i] = OffsetBytes(layout_.stride()[i + Layout::kRank / 2]); + } + + if (kBigEndian) { + // "Big Endian" scheme + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank / 2 - 1; ++i) { + divmod_m[i] = FastDivmod(extent[i + 1]); + divmod_n[i] = FastDivmod(extent[i + Layout::kRank / 2 + 1]); + } + } else { + // "Little Endian" scheme + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank / 2 - 1; ++i) { + divmod_m[i] = FastDivmod(extent[i]); + divmod_n[i] = FastDivmod(extent[i + Layout::kRank / 2]); + } + } + } + + CUTLASS_HOST_DEVICE + Params(Layout const& layout_) : layout(layout_) + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank / 2; ++i) { + stride_m[i] = OffsetBytes(layout_.stride()[i]); + stride_n[i] = OffsetBytes(layout_.stride()[i + Layout::kRank / 2]); + } + + rank2_inc_col = ThreadMap::Delta::kColumn * stride_n[0]; + rank2_inc_row = ThreadMap::Delta::kRow * stride_m[0]; + } + }; + + /// Mask object + struct Mask { + static int const kCount = ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { enable(); } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { predicates[i] = false; } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { predicates[i] = true; } + } + }; + +private: + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. 
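+ /// Besides the per-mode byte strides and fast divmod objects, this carries
+ /// the rank-2 fast-path increments (rank2_inc_row / rank2_inc_col) consumed
+ /// by store_with_byte_offset().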
+ Params params_; + + /// Byte-level pointer + uint8_t* byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in columns + Index extent_col_; + + /// A thread's starting row position (assuming steady-state predicates have + /// been computed) + Index thread_start_row_; + + /// A thread's starting column position (assuming steady-state predicates have + /// been computed) + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + + /// Offsets in columns, cached for performance + int64_t offset_modes_n_[ThreadMap::Iterations::kColumn]; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + +private: + // + // Methods + // + +public: + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorAffineRankNAtomic( + Params const& params, + Element* pointer, + MatrixCoord extent, + int thread_idx, + MatrixCoord threadblock_offset = MatrixCoord(), + int const* indices = nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : params_(params) + { + MatrixCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_row_ = extent.row(); + extent_col_ = extent.column(); + + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + if (Layout::kRank > 2) { + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + // + // Compute coordinate and decompose into N modes + // + + int coord_n = thread_start_column_ + c * ThreadMap::Delta::kColumn; + + mask_.predicates[c] = coord_n < extent.column(); + + Coord modes_n; + + int64_t offset_modes_n = 0; + + if (kBigEndian) { + modes_n = CoordinateDecomposition(coord_n, params_.divmod_n); + + offset_modes_n = dot(modes_n, params_.stride_n); + } else { + modes_n = CoordinateDecompositionLittleEndian( + coord_n, params_.divmod_n); + + offset_modes_n = dot(modes_n, params_.stride_n); + } + + offset_modes_n_[c] = offset_modes_n; + } + + if (!pointer) { mask_.clear(); } + } + + // Initialize pointer + byte_pointer_ = reinterpret_cast(pointer); + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, int64_t byte_offset) + { + uint8_t* byte_pointer = byte_pointer_; + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + int row_begin = thread_start_row_ + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + int64_t offset_modes_m = row_begin * params_.stride_m[0]; + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + // + // Compute coordinate and decompose into M modes + // + + 
int coord_m = row * ThreadMap::Delta::kRow + row_begin; + + Coord modes_m; + + if (Layout::kRank > 2) { + if (kBigEndian) { + modes_m = CoordinateDecomposition(coord_m, + params_.divmod_m); + } else { + modes_m = CoordinateDecompositionLittleEndian( + coord_m, params_.divmod_m); + } + + offset_modes_m = dot(modes_m, params_.stride_m); + } + + // + // Compute the offset due to modes M + // + + bool row_guard = (coord_m < extent_row_); + int64_t offset_modes_n = thread_start_column_ * params_.stride_n[0]; + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + // + // Compute coordinate and decompose into N modes + // + + if (Layout::kRank > 2) { offset_modes_n = offset_modes_n_[column]; } + + // + // Compute the pointer and access + // + bool guard; + if (Layout::kRank > 2) { + guard = row_guard && mask_.predicates[column]; + } else { + guard = (coord_m < extent_row_) && + ((thread_start_column_ + ThreadMap::Delta::kColumn * column) < + extent_col_); + } + + atomic_store( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)(byte_pointer + offset_modes_m + offset_modes_n + byte_offset), + guard); + + if (Layout::kRank == 2) { offset_modes_n += params_.rank2_inc_col; } + } + + if (Layout::kRank == 2) { offset_modes_m += params_.rank2_inc_row; } + } + } + } + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { store_with_byte_offset(frag, 0); } + + CUTLASS_DEVICE + void load(Fragment& frag) {} + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorAffineRankNAtomic& operator++() + { + ++state_[0]; + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + state_[0] = 0; + ++state_[1]; + + thread_start_row_ += + (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + state_[1] = 0; + ++state_[2]; + + thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * + ThreadMap::Count::kRow * ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { state_[2] = 0; } + } + } + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { mask_.clear(); } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { mask_.enable(); } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask& mask) { mask = mask_; } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const& mask) { mask_ = mask; } +}; + +template +class PredicatedTileIteratorAtomic { +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static bool constexpr PermuteD = !layout::is_trivial_permute; + + static_assert(ThreadMap::Iterations::kRow > 0, "ThreadMap::Iterations::kRow must be > 0"); + static_assert(ThreadMap::Iterations::kGroup > 0, "ThreadMap::Iterations::kGroup must be > 0"); + static_assert(ThreadMap::Iterations::kCluster > 0, + 
"ThreadMap::Iterations::kCluster must be > 0"); + static_assert(ThreadMap::Iterations::kColumn > 0, "ThreadMap::Iterations::kColumn must be > 0"); + + /// Fragment object + using Fragment = Array; + + /// Memory access size + using AccessType = AlignedArray; + + // + // Parameters struct + // + + /// Uses a non-template class + struct Params : PredicatedTileIteratorParams { + using Base = PredicatedTileIteratorParams; + + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : PredicatedTileIteratorParams( + layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc()) + { + } + + CUTLASS_HOST_DEVICE + Params(Base const& base) : Base(base) {} + }; + + /// Mask object + struct Mask { + static int const kCount = ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { enable(); } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { predicates[i] = false; } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { predicates[i] = true; } + } + }; + +private: + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. + PredicatedTileIteratorParams params_; + + /// Byte-level pointer. This pointer is usually for both load() and store(), + /// unless PermuteD is performed. When having PermuteD, byte_pointer_ is only + /// for load(). + uint8_t* byte_pointer_; + + /// Byte-level pointer for store(). Due to PermuteD Op, store_byte_pointer_ + /// may be with different address computation compared to byte_pointer_. 
+ uint8_t* store_byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in rows + Index extent_column_; + + /// A thread's starting row position (assuming steady-state predicates have + /// been computed) + Index thread_start_row_; + + /// A thread's starting column + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + + /// Scatter indices + int const* indices_; + + /// PermuteDLayout + PermuteDLayout permute_layout_; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); + +private: + // + // Methods + // + +public: + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorAtomic(PredicatedTileIteratorParams const& params, + Element* pointer, + TensorCoord extent, + int thread_idx, + TensorCoord threadblock_offset = TensorCoord(), + int const* indices = nullptr) + : params_(params), + indices_(indices), + permute_layout_(PitchLinearCoord(extent.column(), extent.row()), + params_.stride * kElementsPerAccess / sizeof(AccessType)) + { + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_row_ = extent.row(); + extent_column_ = extent.column(); + + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + mask_.predicates[c] = + ((thread_offset.column() + ThreadMap::Delta::kColumn * c) < extent.column()); + } + + // Null pointer performs no accesses + if (!pointer) { mask_.clear(); } + + if (ScatterD && !indices) { mask_.clear(); } + + // Initialize byte_pointer_ + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.row()) * LongIndex(params_.stride) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; + + if (ScatterD) { + byte_pointer_ = + reinterpret_cast(pointer) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; + } + + // store_byte_pointer_ is set to be the same with byte_pointer_ unless + // PermuteD is used. + store_byte_pointer_ = PermuteD ? 
reinterpret_cast(pointer) : byte_pointer_; + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + store_byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, int64_t byte_offset) const + { + uint8_t* byte_pointer = store_byte_pointer_; + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * + LongIndex(params_.stride)); + } + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + bool guard = row_guard && mask_.predicates[column]; + + if (PermuteD) { + int col_offset = column * ThreadMap::Delta::kColumn; + + int col = col_offset + thread_start_column_; + int row = row_offset + thread_start_row_; + + // Locate memory_pointer + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + permute_layout_(PitchLinearCoord(col, row)) * sizeof(AccessType) / + kElementsPerAccess); + } + atomic_store( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer[0], + guard); + + if (!PermuteD) { + memory_pointer += (ThreadMap::Delta::kColumn / kElementsPerAccess); + } + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD && !PermuteD) { byte_pointer += params_.increment_row; } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) const { store_with_byte_offset(frag, 0); } + + CUTLASS_DEVICE + void load(Fragment& frag) {} + + CUTLASS_DEVICE + MatrixCoord thread_start() const + { + return MatrixCoord(thread_start_row_, thread_start_column_); + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_row() const { return thread_start_row_; } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_column() const { return thread_start_column_; } + + /// Extent of the matrix in rows + CUTLASS_DEVICE + Index extent_row() const { return extent_row_; } + + /// Extent of the matrix in columns + CUTLASS_DEVICE + Index extent_column() const { return extent_column_; } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorAtomic& 
operator++() + { + ++state_[0]; + + if (!ScatterD && !PermuteD) { store_byte_pointer_ += params_.advance_row; } + + if (!ScatterD) { byte_pointer_ += params_.advance_row; } + + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + state_[0] = 0; + ++state_[1]; + byte_pointer_ += params_.advance_group; + store_byte_pointer_ += params_.advance_group; + + thread_start_row_ += + (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + state_[1] = 0; + ++state_[2]; + byte_pointer_ += params_.advance_cluster; + store_byte_pointer_ += params_.advance_cluster; + + thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * + ThreadMap::Count::kRow * ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + byte_pointer_ += params_.advance_tile; + store_byte_pointer_ += params_.advance_tile; + + thread_start_row_ += ThreadMap::Shape::kGroup * ThreadMap::Shape::kRow * + ThreadMap::Shape::kCluster * ThreadMap::Shape::kTile; + } + } + } + + return *this; + } + + /// Advances a number of positions to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorAtomic& operator+=(int increment) + { + // Row + state_[0] += increment; + int increment_row = state_[0] / ThreadMap::Count::kRow; + state_[0] = state_[0] % ThreadMap::Count::kRow; + + byte_pointer_ += (params_.advance_row * increment); + store_byte_pointer_ += (params_.advance_row * increment); + thread_start_row_ += (ThreadMap::Shape::kRow * increment); + + // Group + state_[1] += increment_row; + int increment_group = state_[1] / ThreadMap::Count::kGroup; + state_[1] = state_[1] % ThreadMap::Count::kGroup; + + byte_pointer_ += (params_.advance_group * increment_row); + store_byte_pointer_ += (params_.advance_group * increment_row); + thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * + ThreadMap::Count::kRow * increment_row; + + // Cluster + state_[2] += increment_group; + int increment_cluster = state_[2] / ThreadMap::Count::kCluster; + state_[2] = state_[2] % ThreadMap::Count::kCluster; + + byte_pointer_ += (params_.advance_cluster * increment_group); + store_byte_pointer_ += (params_.advance_cluster * increment_group); + thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * + ThreadMap::Count::kRow * ThreadMap::Shape::kRow * increment_group; + + // Tile + byte_pointer_ += (params_.advance_tile * increment_cluster); + store_byte_pointer_ += (params_.advance_tile * increment_cluster); + thread_start_row_ += ThreadMap::Shape::kGroup * ThreadMap::Shape::kRow * + ThreadMap::Shape::kCluster * ThreadMap::Shape::kTile * + increment_cluster; + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { mask_.clear(); } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { mask_.enable(); } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask& mask) const { mask = mask_; } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const& mask) { mask_ = mask; } +}; + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass diff --git a/csrc/deepspeed4science/evoformer_attn/iterators/predicated_tile_iterator_residual_last.h b/csrc/deepspeed4science/evoformer_attn/iterators/predicated_tile_iterator_residual_last.h new file mode 100644 index 000000000000..629047dbb057 --- /dev/null +++ 
b/csrc/deepspeed4science/evoformer_attn/iterators/predicated_tile_iterator_residual_last.h @@ -0,0 +1,1938 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/*! \file + \brief Templates implementing loading of tiles from pitch-linear rank=2 + tensors. + + This iterator uses masks to guard out-of-bounds accesses. The first tile + this iterator visits maybe partial, then the remaining tiles are complete. + So, we only need to compute the predicates twice, once before the first tile + and once for the remaining full tiles which can share the same predicates. + + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. +*/ + +#pragma once + +#include "cutlass/arch/memory.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileIteratorResidualLast +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +/// Regular tile iterator using a precomputed control structure to minimize +/// register liveness and integer arithmetic. +/// +/// Layout is assumed to be invariant at the time the precomputed "Params" +/// object is constructed. 
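+///
+/// A Params object can be constructed directly from the tensor's layout; the
+/// commented-out host-side code in the example below shows the intended usage.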
+/// +/// Base pointer and tensor extents may be specified at the time the iterator is +/// constructed. Subsequently, they are assumed to be immutable. +/// +/// Adding a logical coordinate offset may be performed at the time the iterator +/// is constructed. Subsequent additions to logical coordinate offset may be +/// performed but are relatively expensive. +/// +/// Visitation order is intended to first visit a "residual" tile that may be +/// partially full in both the advance dimension and the steady-state dimension. +/// This is assumed to be the last tile in the iteration sequence. Advancing an +/// iterator that has just been constructed moves to the first tile that is full +/// in the advance dimension and recomputes predicates. Subsequent accesses may +/// be performed without updating internal predicates and are efficient in terms +/// of live register state and pointer arithmetic instructions. +/// +/// To be efficient, this assumes the iterator will be dereferenced and advanced +/// at least once outside any looping structure to minimize integer arithmetic. +/// +/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to +/// dereferencing the iterator. +/// +/// +/// Example: +/// +/// An efficient pipeline structure may be constructed as follows: +/// +// template +// __global__ void kernel( +// typename Iterator::Params params, +// typename Iterator::Element *ptr, +// TensorCoord extent) { +// +// typename Iterator::Fragment fragment; +// +// TensorCoord threadblock_offset(0, 0); +// +// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets); +// +// +// fragment = *iter; // load "residue" tile first +// ++iter; // advance to first "steady state" tile and update +// internal masks +// +// +// #pragma unroll +// for (int i = Remaining - 1; i >= 0; --i) { +// +// f(fragment); +// +// if (!i) { +// iter.clear_mask(); // light-weight operation to clear masks - +// subsequent loads become NO-OPs. +// } +// +// fragment = *iter; // load tile during "steady state" phase +// ++iter; // advance to next tile - lightweight due to +// steady-state masks +// } +// } +// +// void host(TensorView view) { +// +// using Iterator = +// transform::threadblock::PredicatedTileIteratorResidualLast; +// +// typename Iterator::Params params(view.layout()); +// +// kernel(params, view.data()); +// } +/// +/// +template +class PredicatedTileIteratorResidualLast; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileIteratorResidualLast { +public: + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + /// Type used for internal memory accesses + using AccessType = + AlignedArray::value / 8)>; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = PredicatedTileAccessIteratorResidualLast; + + static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; + + /// Fragment object to be loaded or stored + using Fragment = + cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename TileAccessIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + using Base = typename TileAccessIterator::Params::Base; + + friend PredicatedTileIteratorResidualLast; + + private: + /// Parameters object + typename TileAccessIterator::Params params_; + + public: + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout) {} + + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Base const& base) : params_(base) {} + }; + +private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + +private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + /// Gather indices + int const* indices = nullptr) + : address_iterator_(params.params_, pointer, extent, thread_id, threadblock_offset, indices) + { + } + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast(Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast(params, pointer, extent, thread_id, make_Coord(0, 0)) + { + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. 
+ /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() + { + if (kAdvanceRank) + address_iterator_.add_tile_offset({0, 1}); + else + address_iterator_.add_tile_offset({1, 0}); + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) + { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { address_iterator_.set_residual_tile(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { address_iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { address_iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { address_iterator_.get_mask(mask); } + + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) + { + load_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) + { + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + address_iterator_.set_iteration_index(idx); + char const* byte_ptr = + reinterpret_cast(address_iterator_.get()) + byte_offset; + + AccessType const* access_ptr = reinterpret_cast(byte_ptr); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, address_iterator_.valid()); + + ++address_iterator_; + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { load_with_byte_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) + { + store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) + { + address_iterator_.set_iteration_index(0); + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + char* byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType* access_ptr 
= reinterpret_cast(byte_ptr); + + if (address_iterator_.valid()) { *access_ptr = frag_ptr[idx]; } + ++address_iterator_; + } + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { store_with_byte_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileIteratorResidualLast { +public: + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = + PredicatedTileIteratorResidualLast, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessSize, + Gather>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = + cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) : params_(base) {} + }; + +private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()), + indices) + { + } + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast(Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID 
of each participating thread + ) + : PredicatedTileIteratorResidualLast(params, pointer, extent, thread_id, make_Coord(0, 0)) + { + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() + { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) + { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) + { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) + { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) + { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) + { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. 
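// The load_with_byte_offset / store_with_byte_offset bodies in the
// pitch-linear specialization above walk three nested loops -- strided, then
// contiguous, then vector -- and flatten them into one fragment slot via
// idx = v + V * (c + C * s).  The following standalone sketch, with made-up
// loop extents standing in for ThreadMap::Iterations and kAccessesPerVector,
// checks that this flattening is dense and collision-free, which is why the
// Fragment can be sized as (iterations per thread) x (elements per access).
#include <cassert>

namespace fragment_index_sketch {

inline void check_fragment_indexing()
{
    constexpr int kStrided = 4;            // stand-in for ThreadMap::Iterations::kStrided
    constexpr int kContiguous = 2;         // stand-in for ThreadMap::Iterations::kContiguous
    constexpr int kAccessesPerVector = 2;  // stand-in for kAccessesPerVector

    bool seen[kStrided * kContiguous * kAccessesPerVector] = {};

    for (int s = 0; s < kStrided; ++s) {
        for (int c = 0; c < kContiguous; ++c) {
            for (int v = 0; v < kAccessesPerVector; ++v) {
                int idx = v + kAccessesPerVector * (c + s * kContiguous);
                assert(idx >= 0 && idx < kStrided * kContiguous * kAccessesPerVector);
                assert(!seen[idx]);  // every (s, c, v) triple gets a distinct slot
                seen[idx] = true;
            }
        }
    }
}

}  // namespace fragment_index_sketch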
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileIteratorResidualLast { +public: + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = + PredicatedTileIteratorResidualLast, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessSize, + Gather>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = + cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) : params_(base) {} + }; + +private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = nullptr ///< Gather indices + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()), + indices) + { + } + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast(Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast(params, pointer, extent, thread_id, make_Coord(0, 0)) + { + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. 
+ /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() + { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) + { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) + { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) + { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) + { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) + { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for affine rank-2 data. 
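// The ColumnMajor and RowMajor specializations above add no iteration logic of
// their own; they only restate matrix coordinates in pitch-linear
// (contiguous, strided) terms and forward everything to the pitch-linear
// iterator.  A standalone sketch of that mapping follows; PitchLinearExtent,
// from_column_major and from_row_major are invented names, not CUTLASS types.
namespace layout_mapping_sketch {

struct PitchLinearExtent {
    int contiguous;  // dimension whose consecutive elements are adjacent in memory
    int strided;     // dimension that steps by the leading stride
};

// Column-major: a column is contiguous, so rows form the contiguous extent
// (mirrors PitchLinearCoord(extent.row(), extent.column()) above).
inline PitchLinearExtent from_column_major(int rows, int columns) { return {rows, columns}; }

// Row-major: a row is contiguous, so columns form the contiguous extent
// (mirrors PitchLinearCoord(extent.column(), extent.row()) above).
inline PitchLinearExtent from_row_major(int rows, int columns) { return {columns, rows}; }

}  // namespace layout_mapping_sketch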
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileIteratorResidualLast, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { +public: + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRankN<2>; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + /// Type used for internal memory accesses + using AccessType = + AlignedArray::value / 8)>; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = PredicatedTileAccessIteratorResidualLast; + + static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; + + /// Fragment object to be loaded or stored + using Fragment = + cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename TileAccessIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + friend PredicatedTileIteratorResidualLast; + + private: + /// Parameters object + typename TileAccessIterator::Params params_; + + public: + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout) {} + + CUTLASS_HOST_DEVICE + Params() {} + }; + +private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + +private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : address_iterator_(params.params_, pointer, extent, thread_id, threadblock_offset) + { + } + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast(Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast(params, pointer, extent, thread_id, make_Coord(0, 0)) + { + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. 
+ /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() + { + if (kAdvanceRank) + address_iterator_.add_tile_offset(make_Coord(0, 1)); + else + address_iterator_.add_tile_offset(make_Coord(1, 0)); + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) + { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { address_iterator_.set_residual_tile(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { address_iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { address_iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { address_iterator_.get_mask(mask); } + + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) + { + load_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) + { + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + address_iterator_.set_iteration_index(idx); + char const* byte_ptr = + reinterpret_cast(address_iterator_.get()) + byte_offset; + + AccessType const* access_ptr = reinterpret_cast(byte_ptr); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, address_iterator_.valid()); + + ++address_iterator_; + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { load_with_byte_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) + { + store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) + { + address_iterator_.set_iteration_index(0); + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + char* byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; + 
AccessType* access_ptr = reinterpret_cast(byte_ptr); + + if (address_iterator_.valid()) { *access_ptr = frag_ptr[idx]; } + ++address_iterator_; + } + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { store_with_byte_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for affine rank 2 +/// column-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileIteratorResidualLast { +public: + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = + PredicatedTileIteratorResidualLast, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = + cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))) + { + } + }; + +private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column())) + { + } + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast(Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int 
thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast(params, pointer, extent, thread_id, make_Coord(0, 0)) + { + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() + { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) + { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) + { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) + { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) + { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) + { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for affine rank 2 +/// row-major data. 
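// Unlike the pitch-linear layouts handled earlier in this file, an affine
// rank-2 layout carries an explicit stride for both dimensions, so an element
// offset is the dot product of the coordinate with the stride vector.  The
// column-major adaptor above forwards (stride(0), stride(1)) unchanged, and
// the row-major adaptor that follows swaps them.  A standalone sketch with an
// invented affine_offset() helper; it is not the CUTLASS layout API.
#include <cstdint>

namespace affine_sketch {

// Offset in elements of coordinate (c0, c1) under strides (s0, s1).
inline std::int64_t affine_offset(int c0, int c1, std::int64_t s0, std::int64_t s1)
{
    return static_cast<std::int64_t>(c0) * s0 + static_cast<std::int64_t>(c1) * s1;
}

// With s1 == 1 this degenerates to ordinary pitch-linear addressing with
// leading stride s0; with both strides free it can also describe, e.g., a
// transposed or strided view of the same buffer without copying it.
inline std::int64_t affine_offset_example() { return affine_offset(3, 5, /*s0=*/128, /*s1=*/1); }

}  // namespace affine_sketch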
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileIteratorResidualLast { +public: + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = + PredicatedTileIteratorResidualLast, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = + cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))) + { + } + }; + +private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row())) + { + } + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast(Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast(params, pointer, extent, thread_id, make_Coord(0, 0)) + { + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. 
+ /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() + { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) + { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) + { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) + { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) + { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) + { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for interleaved data. +/// It is mapped to the congruous layout. 
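// Every operator++ in this file advances by exactly one whole tile along the
// iterator's advance dimension: kAdvanceRank == 0 selects a {1, 0} tile offset
// (contiguous), anything else selects {0, 1} (strided), matching the
// add_tile_offset calls shown earlier in this patch.  A tiny standalone sketch
// of that selection; TileOffset and next_tile_offset are invented names.
namespace advance_rank_sketch {

struct TileOffset {
    int contiguous;
    int strided;
};

inline TileOffset next_tile_offset(int advance_rank)
{
    return advance_rank ? TileOffset{0, 1} : TileOffset{1, 0};
}

}  // namespace advance_rank_sketch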
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// + +template +class PredicatedTileIteratorResidualLast, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { +public: + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::ColumnMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = + cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) : params_(base) {} + }; + +private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.row() * kInterleavedK, + extent.column() / kInterleavedK), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row() * kInterleavedK, + threadblock_offset.column() / kInterleavedK)) + { + } + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast(Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast(params, pointer, extent, thread_id, make_Coord(0, 0)) + { + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void 
add_pointer_offset(LongIndex pointer_offset) + { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() + { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) + { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) + { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) + { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for interleaved-32 +/// data. It is mapped to the congruous layout. 
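// The interleaved specialization above folds the interleaving factor into the
// pitch-linear extent: a column-major interleaved matrix with E.row() rows and
// E.column() columns is handed to the underlying iterator as a pitch-linear
// tensor that is E.row() * k wide in the contiguous dimension and
// E.column() / k deep in the strided dimension; the row-major interleaved
// specialization below applies the same fold with rows and columns exchanged.
// Standalone sketch; column_major_interleaved_extent() is an invented helper
// and assumes the divided extent is a multiple of k.
namespace interleaved_sketch {

struct PitchLinearExtent {
    int contiguous;
    int strided;
};

inline PitchLinearExtent column_major_interleaved_extent(int rows, int columns, int k)
{
    return {rows * k, columns / k};
}

// Example: a 64 x 64 matrix with k = 32 is traversed as a 2048 x 2 tensor.
inline PitchLinearExtent interleaved_extent_example()
{
    return column_major_interleaved_extent(64, 64, 32);
}

}  // namespace interleaved_sketch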
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileIteratorResidualLast, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { +public: + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::RowMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = + cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) : params_(base) {} + }; + +private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_(params.params_, + pointer, + layout::PitchLinearCoord(extent.column() * kInterleavedK, + extent.row() / kInterleavedK), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column() * kInterleavedK, + threadblock_offset.row() / kInterleavedK)) + { + } + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast(Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast(params, pointer, extent, thread_id, make_Coord(0, 0)) + { + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex 
pointer_offset) + { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() + { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) + { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { iterator_.get_mask(mask); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) + { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) + { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/deepspeed4science/evoformer_attn/iterators/transpose_warp_iterator.h b/csrc/deepspeed4science/evoformer_attn/iterators/transpose_warp_iterator.h new file mode 100644 index 000000000000..2435c07f8989 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/iterators/transpose_warp_iterator.h @@ -0,0 +1,57 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include "warp_iterator_from_smem.h" + +template +struct TransposeWarpIterator { + using Iterator = char; + static bool constexpr kSupportsTranspose = false; +}; + +template < + /// Operand identity + cutlass::gemm::Operand Operand, + /// Data type of A elements + typename Element, + bool kTranspose> +struct TransposeWarpIterator< + cutlass::gemm::warp::WarpIteratorFromSmem> { + using Iterator = cutlass::gemm::warp::WarpIteratorFromSmem; + static bool constexpr kSupportsTranspose = true; +}; diff --git a/csrc/deepspeed4science/evoformer_attn/iterators/warp_iterator_from_smem.h b/csrc/deepspeed4science/evoformer_attn/iterators/warp_iterator_from_smem.h new file mode 100644 index 000000000000..7dd59832b4b0 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/iterators/warp_iterator_from_smem.h @@ -0,0 +1,269 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/*! \file + \brief Inspired from + "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" Loads tiles of GEMM + operands from a RowMajor shared-memory layout into registers to use by A100 + TensorCores. + + The difference with "mma_tensor_op_tile_access_iterator.h" is that: + (1) We use "ldmatrix" to load tiles, rather than manual loads (slightly + faster) (2) We support to transpose the operand (eg read `A.transpose()` when + the shared memory holds `A`) + + This is only implemented for the specific shapes. +*/ +#pragma once + +#include + +//////////////////////////////////////////////////////////////////////////////// +namespace cutlass { +namespace gemm { +namespace warp { + +template < + /// Operand identity + Operand Operand_, + /// Data type of A elements + typename Element_, + bool kTranspose = false> +class WarpIteratorFromSmem { +public: + /// Shape of tile to load (concept: MatrixShape) + using Shape = cutlass::MatrixShape<32, 32>; + + /// Operand tag + static Operand const kOperand = Operand_; + + /// Basic check + static_assert( + kOperand == Operand::kA || kOperand == Operand::kB, + "WarpIteratorFromSmem may only be instantiated for A or B operands to warp-level Mma."); + + /// Element type + using Element = Element_; + static_assert(sizeof_bits::value == 16, "Only supported for half"); + + /// Layout of source tile + using Layout = cutlass::layout::RowMajor; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = cutlass::MatrixShape<16, 8>; + + /// Delta between *MMA operations (in units of *MMA operations, concept: + /// MatrixShape) + static int const kOpDelta = 1; + + /// Number of participating threads + static int const kThreads = 32; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Number of elements accessed per Shared Memory load + static int const kElementsPerAccess = + (sizeof_bits::value >= 32 ? 1 : 32 / sizeof_bits::value); + + using InstructionCount = MatrixShape; + + static int const kIterations = (kOperand == Operand::kA) ? InstructionCount::kColumn + : InstructionCount::kRow; + +public: + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + using Fragment = + Array; + + /// Memory access type + // using AccessType = AlignedArray; + using AccessType = Array; + + static int constexpr kWarpShapeDivisibleInner = + (kOperand == Operand::kA ? 
InstructionShape::kColumn : InstructionShape::kRow); + static int constexpr kAccessesInner = (kWarpShapeDivisibleInner / kElementsPerAccess) / 4; + static int const kTilesPerInstruction = InstructionShape::kRow / 8; + +private: + /// Underlying tensor reference + TensorRef ref_; + + /// Origin + MatrixCoord origin_; + + /// Iterations in a tile + int iterations_; + +public: + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem(TensorRef const& ref, int lane_id) + : WarpIteratorFromSmem(ref, {Shape::kRow, Shape::kColumn}, lane_id) + { + } + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem(TensorRef const& ref, TensorCoord extent, int lane_id) + : ref_(ref), iterations_(0) + { + int ldsm_vec_num = (lane_id >> 3); + if (kOperand == Operand::kA) { + origin_ = MatrixCoord(lane_id % 8, 0); + static_assert(InstructionCount::kRow * kAccessesInner * kTilesPerInstruction == 4, ""); + CUTLASS_PRAGMA_UNROLL + for (int inst_m_idx = 0; inst_m_idx < InstructionCount::kRow; ++inst_m_idx) { + CUTLASS_PRAGMA_UNROLL + for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) { + CUTLASS_PRAGMA_UNROLL + for (int access_m_idx = 0; access_m_idx < kTilesPerInstruction; + ++access_m_idx) { + int access_idx = + access_m_idx + + kTilesPerInstruction * (inner_idx + kAccessesInner * inst_m_idx); + + MatrixCoord offset(access_m_idx * 8 + inst_m_idx * InstructionShape::kRow, + inner_idx * 4 * kElementsPerAccess); + + if (access_idx == ldsm_vec_num) { + if (kTranspose) { offset = MatrixCoord(offset.column(), offset.row()); } + origin_ += offset; + } + } + } + } + } else { + origin_ = MatrixCoord(0, lane_id % 8); + static_assert(InstructionCount::kColumn * kAccessesInner == 4, ""); + CUTLASS_PRAGMA_UNROLL + for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn; ++inst_n_idx) { + CUTLASS_PRAGMA_UNROLL + for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) { + int access_idx = inner_idx + kAccessesInner * inst_n_idx; + + MatrixCoord offset(inner_idx * 4 * kElementsPerAccess, inst_n_idx * 8); + + if (access_idx == ldsm_vec_num) { + if (kTranspose) { offset = MatrixCoord(offset.column(), offset.row()); } + origin_ += offset; + } + } + } + } + + ref_.add_coord_offset(origin_); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem& add_tile_offset(TensorCoord const& tile_offset) + { + TensorCoord coord_offset(tile_offset.row() * Shape::kRow, + tile_offset.column() * Shape::kColumn); + if (kTranspose) { coord_offset = TensorCoord{coord_offset.column(), coord_offset.row()}; } + origin_ += coord_offset; + + ref_.add_coord_offset(coord_offset); + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_DEVICE + void advance() + { + if (kOperand == Operand::kA) { + add_tile_offset({0, 1}); + } else { + add_tile_offset({1, 0}); + } + + iterations_ = 0; + } + + /// increase iterations in a tile + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem& operator++() + { + iterations_++; + + if (iterations_ >= kIterations) advance(); + + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. 
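+    /// Implementation note: each call fills the fragment for the current
+    /// iteration straight from shared memory with a single ldmatrix
+    /// (cutlass::arch::ldsm). The tile offset advances along columns for
+    /// operand A and along rows for operand B; when kTranspose is set, the
+    /// offset coordinates and the ldsm load layout are swapped instead of
+    /// shuffling registers after the load.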
+ CUTLASS_DEVICE + void load(Fragment& frag) const + { + AccessType* access_ptr = reinterpret_cast(&frag); + using LoadLayout = + typename platform::conditional::type; + + MatrixCoord offset; + if (kOperand == Operand::kA) { + offset = MatrixCoord(0, iterations_ * InstructionShape::kColumn); + } else { + offset = MatrixCoord(iterations_ * InstructionShape::kRow, 0); + } + if (kTranspose) { offset = MatrixCoord(offset.column(), offset.row()); } + cutlass::arch::ldsm(access_ptr[0], ref_.data() + ref_.offset(offset)); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass +//////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/deepspeed4science/evoformer_attn/kernel_backward.h b/csrc/deepspeed4science/evoformer_attn/kernel_backward.h new file mode 100644 index 000000000000..87e6df18bb04 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/kernel_backward.h @@ -0,0 +1,1965 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. 
+// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/fast_math.h" +#include "cutlass/functional.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" + +#include "gemm_kernel_utils.h" + +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" +#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/platform/platform.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/vector_iterator.h" +#include "epilogue/epilogue_pipelined.h" +#include "iterators/epilogue_predicated_tile_iterator.h" + +#include "epilogue/epilogue_grad_bias.h" +#include "gemm/custom_mma.h" +#include "gemm/find_default_mma.h" +#include "gemm/mma_accum_lambda_iterator.h" +#include "gemm/mma_from_smem.h" +#include "transform/bias_broadcast.h" +#include "transform/tile_smem_loader.h" + +#include + +using namespace gemm_kernel_utils; + +namespace { + +template +struct GmemTile { + /* + Helper functions to efficient store/load RF to gmem + + GEMM accumulators have a particular format on A100, and + it takes some compute/shared-memory to rearrange them to + a RowMajor or ColumnMajor format in global memory through + an Epilogue. The same complexity goes for loading into RF. + + This class loads/stores RF as they are, and can be used for + efficient accumulation across gemms for instance: + + ``` + GmemTile tile; + for (int i = 0; i < N; ++i) { + // ... + + Fragment accum; + if (i == 0) { + accum.clear(); + } else { + tile.load(accum); + } + mma(accum, ...); + if (i < N-1) { + // Store for next GEMM + tile.store(accum); + } else { + // Store in tensor (eg RowMajor) + epilogue(accum); + } + + // ... 
+ } + ``` + */ + + // 128bits per thread + using AccessType = cutlass::Array; + static constexpr int32_t kBytes = sizeof(AccessType); + static constexpr int32_t kStride = kNumThreads * AccessType::kElements; + static constexpr int32_t kNumIters = FragmentType::kElements / AccessType::kElements; + static constexpr int32_t kElementsStored = kNumThreads * FragmentType::kElements; + static_assert(FragmentType::kElements % AccessType::kElements == 0, + "fragment not aligned on 128 bits"); + + float* ptr; + + CUTLASS_DEVICE void load(FragmentType& fragment, int thread_id) + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNumIters; ++i) { + AccessType* __restrict__ gmem_ptr = reinterpret_cast( + ptr + thread_id * AccessType::kElements + i * kStride); + AccessType sub_fragment; + cutlass::arch::global_load(sub_fragment, gmem_ptr, true); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < AccessType::kElements; ++j) { + fragment[i * AccessType::kElements + j] = sub_fragment[j]; + } + } + } + + CUTLASS_DEVICE void store(FragmentType const& fragment, int thread_id) + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNumIters; ++i) { + AccessType* __restrict__ gmem_ptr = reinterpret_cast( + ptr + thread_id * AccessType::kElements + i * kStride); + AccessType sub_fragment; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < AccessType::kElements; ++j) { + sub_fragment[j] = fragment[i * AccessType::kElements + j]; + } + cutlass::arch::global_store(sub_fragment, gmem_ptr, true); + } + } +}; + +template +constexpr int getWarpsPerSm() +{ + constexpr bool is_half = !cutlass::platform::is_same::value; + if (Arch::kMinComputeCapability >= 80) { return is_half ? 12 : 8; } + return 8; +} +} // namespace + +template < + // which arch we target (eg `cutlass::arch::Sm80`) + typename ArchTag_, + // input/output type + typename scalar_t_, + // run optimized kernel because memory accesses will be aligned + bool kIsAligned_, + // use dropout if enabled + bool kApplyDropout_, + // when doing a GEMM, preload the next one (uses more shmem) + bool kPreload_, + // block dimensions + int kBlockSizeI_, + int kBlockSizeJ_, + // upperbound on `max(value.shape[-1], query.shape[-1])` + int kMaxK_ = (int)cutlass::platform::numeric_limits::max(), + template class Broadcast1_ = BroadcastNoLoad, + template class Broadcast2_ = BroadcastNoLoad> +struct AttentionBackwardKernel { + using scalar_t = scalar_t_; + using output_t = scalar_t; + using output_accum_t = float; + using lse_scalar_t = float; + using accum_t = float; + using ArchTag = ArchTag_; + static constexpr bool kIsAligned = kIsAligned_; + static constexpr bool kApplyDropout = kApplyDropout_; + static constexpr bool kPreload = kPreload_; + static constexpr int kBlockSizeI = kBlockSizeI_; + static constexpr int kBlockSizeJ = kBlockSizeJ_; + static constexpr int kMaxK = kMaxK_; + + struct Params { + // Input tensors + scalar_t* query_ptr; // [Mq, nH, K] + scalar_t* key_ptr; // [Mk, nH, K] + scalar_t* value_ptr; // [Mk, nH, Kv] + lse_scalar_t* logsumexp_ptr; // [nH, Mq] + scalar_t* output_ptr; // [Mq, nH, Kv] + scalar_t* grad_output_ptr; // [Mq, nH, Kv] + accum_t* delta_ptr; // [nH, Mq] + int32_t* cu_seqlens_q_ptr = nullptr; + int32_t* cu_seqlens_k_ptr = nullptr; + + // Output tensors + output_t* grad_query_ptr; // [Mq, nH, K] + output_t* grad_key_ptr; // [Mk, nH, K] + output_t* grad_value_ptr; // [Mk, nH, Kv] + + accum_t* grad_bias1_ptr = nullptr; + accum_t* grad_bias2_ptr = nullptr; + int32_t B = 0; + int32_t N = 0; + scalar_t* bias1_ptr = nullptr; + scalar_t* bias2_ptr = nullptr; + + // 
Accumulators + union { + output_accum_t* workspace = nullptr; // [Mq, Kq] + [Mkv, Kq] + [Mkv, Kv] + output_accum_t* workspace_gk; + }; + output_accum_t* workspace_gv; // (will be calculated by the kernel) + output_accum_t* workspace_gq; // (will be calculated by the kernel) + + // Scale + accum_t scale; + + // Dimensions/strides + int32_t head_dim = -1; + int32_t head_dim_value = -1; + int32_t num_queries = -1; + int32_t num_keys = -1; + int32_t num_heads = -1; + + int32_t q_strideM; + int32_t k_strideM; + int32_t v_strideM; + int32_t gO_strideM; + int32_t gB_strideM; + int8_t gQKV_strideM_multiplier = 1; // 3 for packed, 1 otherwise + + // RNG sequence offset based on batch_id and head_id + unsigned long long dropout_batch_head_rng_offset; + float dropout_prob = 0.0f; + + CUTLASS_HOST_DEVICE int32_t o_strideM() const { return head_dim_value * num_heads; } + CUTLASS_HOST_DEVICE int32_t gQ_strideM() const + { + return gQKV_strideM_multiplier * num_heads * head_dim; + } + CUTLASS_HOST_DEVICE int32_t gK_strideM() const + { + return gQKV_strideM_multiplier * num_heads * head_dim; + } + CUTLASS_HOST_DEVICE int32_t gV_strideM() const + { + return gQKV_strideM_multiplier * num_heads * head_dim_value; + } + + // Everything below is only used in `advance_to_block` + // and shouldn't use registers + int64_t o_strideH; + int32_t q_strideH; + int32_t k_strideH; + int32_t v_strideH; + int64_t o_strideB; + int64_t q_strideB; + int64_t k_strideB; + int64_t v_strideB; + int64_t lse_strideB; + int64_t lse_strideH; + int64_t delta_strideB; + int64_t delta_strideH; + int32_t num_batches; + + int64_t gO_strideB = 0; + int64_t gQ_strideB = 0; + int64_t gK_strideB = 0; + int64_t gV_strideB = 0; + int64_t gB_strideB = 0; + int64_t gO_strideH = 0; + int64_t gQ_strideH = 0; + int64_t gK_strideH = 0; + int64_t gV_strideH = 0; + int64_t gB_strideH = 0; + + CUTLASS_DEVICE bool advance_to_block() + { + int64_t batch_id = blockIdx.z; + int32_t head_id = blockIdx.y; + + if (kNeedsAccumGradQ || kNeedsAccumGradK || kNeedsAccumGradV) { + assert(workspace_size() == 0 || workspace != nullptr); + + workspace += (batch_id * num_heads + head_id) * workspace_strideBH(); + workspace = warp_uniform(workspace); + workspace_gv = workspace + workspace_elements_gk(); + workspace_gq = workspace_gv + workspace_elements_gv(); + } else { + workspace = nullptr; + } + + // Advance pointers that depend on the total concatenated + // number of queries, as `num_queries` is modified in the block + // below + dropout_batch_head_rng_offset = batch_id * (num_heads * num_queries * num_keys) + + head_id * (num_queries * num_keys); + logsumexp_ptr += batch_id * lse_strideB + head_id * lse_strideH; + + query_ptr += batch_id * q_strideB + head_id * q_strideH; + key_ptr += batch_id * k_strideB + head_id * k_strideH; + value_ptr += batch_id * v_strideB + head_id * v_strideH; + output_ptr += batch_id * o_strideB + head_id * o_strideH; + grad_output_ptr += batch_id * gO_strideB + head_id * gO_strideH; + delta_ptr += batch_id * delta_strideB + head_id * delta_strideH; + + grad_query_ptr += batch_id * gQ_strideB + head_id * gQ_strideH; + grad_key_ptr += batch_id * gK_strideB + head_id * gK_strideH; + grad_value_ptr += batch_id * gV_strideB + head_id * gV_strideH; + using broadcast_1 = Broadcast1_; + using broadcast_2 = Broadcast2_; + + if (broadcast_1::kEnable && grad_bias1_ptr) { + grad_bias1_ptr += batch_id * num_queries; + } + if (broadcast_2::kEnable && grad_bias2_ptr) { + auto strideB = num_heads * num_queries * num_keys; + auto strideH = num_queries * 
num_keys; + grad_bias2_ptr += (batch_id / N) * strideB + head_id * strideH; + } + if (broadcast_1::kEnable && bias1_ptr) { + bias1_ptr = broadcast_1::advance(bias1_ptr, + batch_id / N, + batch_id % N, + head_id, + num_queries * N, + num_queries, + 0); + } + if (broadcast_2::kEnable && bias2_ptr) { + auto strideB = num_heads * num_queries * num_keys; + auto strideH = num_queries * num_keys; + bias2_ptr = broadcast_2::advance( + bias2_ptr, batch_id / N, batch_id % N, head_id, strideB, 0, strideH); + } + + num_queries = warp_uniform(num_queries); + num_keys = warp_uniform(num_keys); + + query_ptr = warp_uniform(query_ptr); + key_ptr = warp_uniform(key_ptr); + value_ptr = warp_uniform(value_ptr); + logsumexp_ptr = warp_uniform(logsumexp_ptr); + output_ptr = warp_uniform(output_ptr); + grad_output_ptr = warp_uniform(grad_output_ptr); + delta_ptr = warp_uniform(delta_ptr); + + grad_query_ptr = warp_uniform(grad_query_ptr); + grad_key_ptr = warp_uniform(grad_key_ptr); + grad_value_ptr = warp_uniform(grad_value_ptr); + if (broadcast_1::kEnable) { + grad_bias1_ptr = warp_uniform(grad_bias1_ptr); + bias1_ptr = warp_uniform(bias1_ptr); + } + if (broadcast_2::kEnable) { + grad_bias2_ptr = warp_uniform(grad_bias2_ptr); + bias2_ptr = warp_uniform(bias2_ptr); + } + + return true; + } + + __host__ dim3 getBlocksGrid() const { return dim3(1, num_heads, num_batches); } + __host__ dim3 getThreadsGrid() const { return dim3(kWarpSize * kNumWarpsPerBlock, 1, 1); } + CUTLASS_HOST_DEVICE int64_t workspace_elements_gk() const + { + if (!kNeedsAccumGradK) { return 0; } + return align_up(num_keys, (int32_t)kBlockSizeJ) * + align_up(head_dim, (int32_t)kBlockSizeI); + } + CUTLASS_HOST_DEVICE int64_t workspace_elements_gv() const + { + if (!kNeedsAccumGradV) { return 0; } + return align_up(num_keys, (int32_t)kBlockSizeJ) * + align_up(head_dim_value, (int32_t)kBlockSizeI); + } + CUTLASS_HOST_DEVICE int64_t workspace_elements_gq() const + { + if (!kNeedsAccumGradQ) { return 0; } + if (num_keys <= kBlockSizeJ) { return 0; } + return align_up(num_queries, (int32_t)kBlockSizeI) * + align_up(head_dim, (int32_t)kBlockSizeJ); + } + CUTLASS_HOST_DEVICE int64_t workspace_strideBH() const + { + // Aligned on 128bits + return align_up( + workspace_elements_gk() + workspace_elements_gv() + workspace_elements_gq(), + int64_t(4)); + } + CUTLASS_HOST_DEVICE int64_t workspace_size() const + { + // Returns size of buffer we need to run this kernel + return num_batches * num_heads * workspace_strideBH() * sizeof(float); + } + }; + + static constexpr int64_t kWarpSize = 32; + + // If this is true, we store and accumulate dK/dV in RF + // rather than going back to gmem every time + static constexpr bool kIsHalf = cutlass::sizeof_bits::value <= 16; + static constexpr bool kOutputInRF = kIsHalf && kMaxK <= kBlockSizeI; + static_assert(!kPreload || (kIsHalf && ArchTag::kMinComputeCapability >= 80 && kOutputInRF), + "preload MMA not supported"); + static constexpr bool kPrologueQK = kPreload; + static constexpr bool kPrologueGV = kPreload; + static constexpr bool kPrologueDOV = kPreload; + static constexpr bool kPrologueGQ = kPreload; + static constexpr bool kPrologueGK = kPreload; + + static constexpr int64_t kNumWarpsPerBlock = (kBlockSizeI * kBlockSizeJ) / (32 * 32); + + // Compute delta for the f16 kernels + // TODO: Figure out why it's slower on the f32 kernels + // (something due to RF pressure?) 
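+    // "delta" here is Di = (dO_i * O_i).sum(-1), the per-query correction term
+    // used by the backward pass as dSij = Pij * (dPij - Di). When
+    // kKernelComputesDelta is true it is produced by computeDelta() at the
+    // start of attention_kernel() and written to p.delta_ptr; otherwise
+    // p.delta_ptr is assumed to already hold a precomputed value.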
+ // TODO: Remove condition on `kOutputInRF` - this is needed to work + // around a compiler bug on V100, not exactly sure why but I spent + // too much time on this already. Reproducible with + // (B, Mq, Mkv, K) = (1, 1, 1, 136) for instance + static constexpr bool kKernelComputesDelta = + kIsHalf && (kOutputInRF || ArchTag::kMinComputeCapability != 70); + + static constexpr bool kNeedsAccumGradQ = + !cutlass::platform::is_same::value; + static constexpr bool kNeedsAccumGradK = + !kOutputInRF && !cutlass::platform::is_same::value; + static constexpr bool kNeedsAccumGradV = + !kOutputInRF && !cutlass::platform::is_same::value; + + // Launch bounds + static constexpr int64_t kNumThreads = kWarpSize * kNumWarpsPerBlock; + static constexpr int64_t kMinBlocksPerSm = + getWarpsPerSm() / kNumWarpsPerBlock; + + using GemmType = DefaultGemmType; + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration; + static constexpr auto kOptimalAlignement = + cutlass::platform::max(DefaultConfig::kAlignmentA, DefaultConfig::kAlignmentB); + static constexpr auto kMinimumAlignment = GemmType::kMinimumAlignment; + + struct MatmulQK { + /* + attn_T = k_j @ q_i.transpose(-2, -1) # matmul + attn_T = (attn_T - logsumexp[i_start:i_end].unsqueeze(1).transpose(-2, + -1)).exp() # epilogue + + with attn_T.shape = (kBlockSizeJ, kBlockSizeI) + */ + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using DefaultMma = typename cutlass::gemm::threadblock::DefaultMma< + scalar_t, // ElementA + cutlass::layout::RowMajor, // LayoutA + kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment, + scalar_t, // ElementB + cutlass::layout::ColumnMajor, // LayoutB + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, + accum_t, // ElementC + cutlass::layout::RowMajor, // LayoutC + typename GemmType::OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + DefaultConfig::kStages, + typename GemmType::Operator, + false, // AccumulatorsInRowMajor = false, + cutlass::gemm::SharedMemoryClearOption::kNone>; + using MmaCore = typename DefaultMma::MmaCore; + using Mma = typename MakeCustomMma::Mma; + + // used for efficient load of bias tile (Bij) from global memory to shared + // memory + using BiasLoader = + TileSmemLoader, + MmaCore::kThreads, + // input restriction: kv_len has to be a multiple of this value + 128 / cutlass::sizeof_bits::value>; + + // Epilogue to store to shared-memory in a format that we can use later for + // the second matmul + using B2bGemm = + typename cutlass::gemm::threadblock::B2bGemm; + using AccumLambdaIterator = + typename DefaultMmaAccumLambdaIterator::Iterator; + using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; + }; + + struct MatmulGradV { + /* + grad_v[j_start:j_end] += attn_T @ do_i # matmul + + Dimensions: (kBlockSizeJ * kNumWarpsPerBlock, kBlockSizeI, K) + (we might need to iterate multiple times on K) + */ + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + using DefaultGemm = + cutlass::gemm::kernel::DefaultGemm; + + // if dropout: + // for computing dVj += (Pij.T * Zij) @ dOi + // Pij_dropped.T = Pij.T * Zij is computed on the fly as fragments of + // Pij.T are loaded in. 
The reason we do it this way is because Pij.T and + // Zij are reused in later steps, while Pij_dropped.T is only needed in + // this step. computing Pij_dropped.T on the fly allows us to avoid + // keeping all 3 of Pij_dropped.T, Pij.T, and Zij in shared memory at the + // same time. + // if no dropout: + // for computing dVj += Pij.T @ dOi + using DefaultMmaFromSmem = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + typename MatmulQK::AccumulatorSharedStorage, + kApplyDropout>; // kScaleOperandA + + using Mma = typename DefaultMmaFromSmem::Mma; + using WarpIteratorA = typename DefaultMmaFromSmem::WarpIteratorA; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + + // Epilogue + using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp; + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::MakePrefetchableIterator< + typename DefaultEpilogue::OutputTileIterator>::Iterator; + using AccumTileGmem = GmemTile; + }; + + struct MatmulDOIVJ { + /* + doi_t_vj = do_i @ v_j.transpose(-2, -1) # matmul + tmp = (doi_t_vj - Di.unsqueeze(1)) * attn # inplace / epilogue? + */ + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + + using ElementC = accum_t; // CSY: Change it for better accuracy + using ElementAccum = accum_t; + + // no-op output op - epilogue just stores result to global memory + using BiasGradEpilogueOutputOp = typename cutlass::epilogue::thread::LinearCombination< + ElementC, + DefaultConfig::EpilogueOutputOp::kCount, + typename DefaultConfig::EpilogueOutputOp::ElementAccumulator, + typename DefaultConfig::EpilogueOutputOp::ElementCompute, + cutlass::epilogue::thread::ScaleType::Nothing>; + + using DefaultGemm = typename cutlass::gemm::kernel::DefaultGemm< + scalar_t, // ElementA + cutlass::layout::RowMajor, // LayoutA + kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment, + scalar_t, // ElementB + cutlass::layout::ColumnMajor, // LayoutB + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, + ElementC, // ElementC + cutlass::layout::RowMajor, // LayoutC + ElementAccum, // ElementAccumulator + typename GemmType::OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + BiasGradEpilogueOutputOp, // EpilogueOutputOp + void, // ThreadblockSwizzle (not used) + // multiple preloads, dropout Zij tile, and 3 stages push us over shared + // memory capacity on A100. set a ceiling on number of stages to save + // shared memory if dropout is in use. + kPreload && kApplyDropout && (kBlockSizeI * kBlockSizeJ > 64 * 64) + ? 
cutlass::const_min(2, DefaultConfig::kStages) + : DefaultConfig::kStages, // Stages + false, // SplitKSerial + typename GemmType::Operator, + cutlass::gemm::SharedMemoryClearOption::kNone>; + using Mma = typename MakeCustomMma::Mma; + + // epilogue used to write bias gradient, which is just the output of this + // matmul with some operations applied to the fragment + using BiasGradEpilogue = typename DefaultGemm::Epilogue; + + // Epilogue to store to shared-memory in a format that we can use later for + // the second matmul + using B2bGemm = + typename cutlass::gemm::threadblock::B2bGemm; + using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; + }; + + struct MatmulGradQ { + // grad_q <- tmp @ k_j + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + using DefaultGemm = + cutlass::gemm::kernel::DefaultGemm; + + using DefaultMmaFromSmem = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + typename MatmulDOIVJ::AccumulatorSharedStorage, + false>; // kScaleOperandA + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + + // Epilogue + using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp; + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::MakePrefetchableIterator< + typename DefaultEpilogue::OutputTileIterator>::Iterator; + using AccumTileGmem = GmemTile; + }; + struct MatmulGradK { + // grad_k <- tmp.transpose(-2, -1) @ q_i + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + using DefaultGemm = + cutlass::gemm::kernel::DefaultGemm; + + using DefaultMmaFromSmemN = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + typename MatmulQK::AccumulatorSharedStorage, + false>; // kScaleOperandA + using DefaultMmaFromSmemT = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + typename MatmulDOIVJ::AccumulatorSharedStorage, + false, // kScaleOperandA + kPreload>; // kTransposeA + using DefaultMmaFromSmem = + typename cutlass::platform::conditional::type; + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + + // Epilogue + using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp; + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::MakePrefetchableIterator< + typename DefaultEpilogue::OutputTileIterator>::Iterator; + using AccumTileGmem = GmemTile; + }; + + using broadcast_1 = Broadcast1_; + using broadcast_2 = Broadcast2_; + + // shared storage for keeping Zij matrix. not needed if we aren't using + // dropout, in which case we use an empty array to save shared memory + using ZijSharedStorage = typename cutlass::platform::conditional< + kApplyDropout, + typename MatmulQK::AccumulatorSharedStorage, + // dummy shared storage object that takes up no space. 
+ typename cutlass::gemm::threadblock::AccumulatorSharedStorage< +#ifdef _WIN32 + // windows builds throw the error: + // "type containing an unknown-size array is not allowed" + // if we try to make Zij shared storage zero-sized. + // To get around this just make it sized 1 on windows. + typename cutlass::gemm::GemmShape<1, 1, 0>, +#else + typename cutlass::gemm::GemmShape<0, 0, 0>, +#endif + typename MatmulQK::AccumulatorSharedStorage::Element, + typename MatmulQK::AccumulatorSharedStorage::Layout, + typename cutlass::MatrixShape<0, 0>>>::type; + + struct SharedStoragePrologue { + struct { + cutlass::Array di; // (do_i * o_i).sum(-1) + typename MatmulQK::Mma::SharedStorageA mm_qk_k; + } persistent; + union { + struct { + // part1 - after Q.K / dV / dO.V + union { + // 1. efficient load of bias tile Bij, which is then applied to Pij + // typename MatmulQK::BiasLoader::SmemTile bias; + cutlass::AlignedBuffer bias; + // 4. store Pij. it is needed: + // - in dVj += (Pij.T * Zij) @ dOi + // - in dSij = Pij * (dPij - Di) + // 6. dVj += (Pij.T * Zij) @ dOi + // 10. write to fragment + typename MatmulQK::AccumulatorSharedStorage attn_shared_storage; + }; + // 5. store Zij. it is needed: + // - to compute Pij_dropped = Pij * Zij on the fly as fragments of Pij + // are loaded for the computation of dVj. + // - to compute dPij = (dOi @ Vj.T) * Zij + // 6. used in dVj += (Pij.T * Zij) @ dOi + // 9. used in dPij = dPij_dropped * Zij + ZijSharedStorage zij; + + union { + // 2. prologue for dVj + // 6. workspace for dVj += (Pij.T * Zij) @ dOi + typename MatmulGradV::Mma::SharedStorage mm_gradV; + // 7. dVj epilogue + typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue; + }; + + // 3. prologue for dPij_dropped + // 8. used in dPij_dropped = dOi @ Vj.T + typename MatmulDOIVJ::Mma::SharedStorage mm_doivj; + } part1; + + struct { + // part2 - dQ + union { + typename MatmulQK::AccumulatorSharedStorage + tmpT_shared_storage; // (from part1) + typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage; + }; + typename MatmulGradK::Mma::SharedStorage mm_gradK; // (preload) + typename MatmulGradQ::Mma::SharedStorage mm_gradQ; // (preload) + union { + // store dB = dSij to global memory + typename MatmulDOIVJ::BiasGradEpilogue::SharedStorage gradB_epilogue; + typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue; + }; + + } part2; + + struct { + // part3 - after last iteration on dQ's epilogue / dK + union { + typename MatmulQK::AccumulatorSharedStorage + tmpT_shared_storage; // (from part1) + typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage; + }; + typename MatmulGradK::Mma::SharedStorage mm_gradK; // (preload) + typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue_lastIter; + + typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue; + } part3; + + struct { + // part4 - after last iteration on dK's epilogue / preload next K.Q_t + typename MatmulQK::Mma::SharedStorageB mm_qk_q; + + // If we reach end of current key, dump RF->gmem with "final" epilogues + typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue_final; + typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue_final; + } part4; + }; +// =========================================== +#define FIELD(INSIDE_STRUCT, FIELDNAME) \ + CUTLASS_DEVICE auto& FIELDNAME() { return INSIDE_STRUCT.FIELDNAME; } + + FIELD(persistent, di) + FIELD(persistent, mm_qk_k) + FIELD(part1, bias) + FIELD(part1, attn_shared_storage) + FIELD(part1, zij) + FIELD(part1, mm_gradV) 
+ FIELD(part1, gradV_epilogue) + FIELD(part1, mm_doivj) + FIELD(part2, mm_gradK) + FIELD(part2, mm_gradQ) + FIELD(part2, gradB_epilogue) + FIELD(part2, gradQ_epilogue) + FIELD(part2, tmp_shared_storage) + FIELD(part3, tmpT_shared_storage) + FIELD(part3, gradQ_epilogue_lastIter) + FIELD(part3, gradK_epilogue) + FIELD(part4, mm_qk_q) + FIELD(part4, gradK_epilogue_final) + FIELD(part4, gradV_epilogue_final) + }; + + struct SharedStorageNoPrologue { + struct { + cutlass::Array di; // (do_i * o_i).sum(-1) + } persistent; + union { + struct { + // part1 - Q.K matmul + typename MatmulQK::Mma::SharedStorageA mm_qk_k; + typename MatmulQK::Mma::SharedStorageB mm_qk_q; + } part1; + + struct { + // part2 - compute gradV + union { + // 1. efficient load of bias tile Bij, which is then applied to Pij + cutlass::AlignedBuffer bias; + // 2. store Pij to shared memory. it is needed: + // - in this step, where it is used in dVj += (Pij.T * Zij) @ dOi + // - in next step where it is used in dSij = Pij * (dPij - Di) + typename MatmulQK::AccumulatorSharedStorage attn_shared_storage; + }; + // 3. store Zij. it is needed: + // - in this step, where it is used to compute Pij_dropped = Pij * Zij + // on the + // fly as fragments of Pij are loaded for the computation of dVj. + // - later to compute dPij = (dOi @ Vj.T) * Zij + ZijSharedStorage zij; + + union { + typename MatmulGradV::Mma::SharedStorage mm_gradV; + typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue; + }; + } part2; + + struct { + // part3 - DO.V matmul + union { + // first compute dPij = (dOi @ Vj.T) * Zij + // and dSij = Pij * (dPij - Di) + struct { + // (from part2) - Pij for computing dSij = Pij * (dPij - Di) + typename MatmulQK::AccumulatorSharedStorage attn_shared_storage; + // (from part2) - Zij for computing dPij = dPij_dropped * Zij + ZijSharedStorage zij; + // matmul to compute dOiVj + typename MatmulDOIVJ::Mma::SharedStorage mm_doivj; + }; + // then store dB = dSij to global memory + typename MatmulDOIVJ::BiasGradEpilogue::SharedStorage gradB_epilogue; + }; + } part3; + + struct { + // part4 - compute gradQ + typename MatmulQK::AccumulatorSharedStorage tmpT_shared_storage; // (from part2) + typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage; + union { + typename MatmulGradQ::Mma::SharedStorage mm_gradQ; + typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue; + typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue_lastIter; + }; + } part4; + + struct { + // part5 - compute gradK + typename MatmulQK::AccumulatorSharedStorage tmpT_shared_storage; // (from part2) + typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage; + union { + typename MatmulGradK::Mma::SharedStorage mm_gradK; + typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue; + }; + } part5; + + struct { + // part6 - store RF accumulated into gmem + typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue_final; + typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue_final; + } part6; + }; +// =========================================== +#define FIELD(INSIDE_STRUCT, FIELDNAME) \ + CUTLASS_DEVICE auto& FIELDNAME() { return INSIDE_STRUCT.FIELDNAME; } + + FIELD(persistent, di) + FIELD(part1, mm_qk_k) + FIELD(part1, mm_qk_q) + FIELD(part2, bias) + FIELD(part2, attn_shared_storage) + FIELD(part2, zij) + FIELD(part2, mm_gradV) + FIELD(part2, gradV_epilogue) + FIELD(part3, mm_doivj) + FIELD(part3, gradB_epilogue) + FIELD(part4, tmpT_shared_storage) + FIELD(part4, 
tmp_shared_storage) + FIELD(part4, mm_gradQ) + FIELD(part4, gradQ_epilogue) + FIELD(part4, gradQ_epilogue_lastIter) + FIELD(part5, mm_gradK) + FIELD(part5, gradK_epilogue) + FIELD(part6, gradK_epilogue_final) + FIELD(part6, gradV_epilogue_final) + }; + + using SharedStorage = typename cutlass::platform:: + conditional::type; + + struct OutputFragments { + typename MatmulGradV::Mma::FragmentC gradV; + typename MatmulGradK::Mma::FragmentC gradK; + + CUTLASS_DEVICE void clear() + { + gradV.clear(); + gradK.clear(); + } + }; + + static bool __host__ check_supported(Params const& p) + { + CHECK_ALIGNED_PTR(p.query_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.key_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.value_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.output_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.grad_output_ptr, kMinimumAlignment); + EVOFORMER_CHECK(p.lse_strideH % 8 == 0, "LSE is not correctly aligned"); + EVOFORMER_CHECK(p.lse_strideB % 8 == 0, "LSE is not correctly aligned"); + EVOFORMER_CHECK(p.num_heads <= 1 || p.q_strideH % kMinimumAlignment == 0, + "query is not correctly aligned (strideH)"); + EVOFORMER_CHECK(p.num_heads <= 1 || p.k_strideH % kMinimumAlignment == 0, + "key is not correctly aligned (strideH)"); + EVOFORMER_CHECK(p.num_heads <= 1 || p.v_strideH % kMinimumAlignment == 0, + "value is not correctly aligned (strideH)"); + EVOFORMER_CHECK(p.num_batches <= 1 || p.q_strideB % kMinimumAlignment == 0, + "query is not correctly aligned (strideB)"); + EVOFORMER_CHECK(p.num_batches <= 1 || p.k_strideB % kMinimumAlignment == 0, + "key is not correctly aligned (strideB)"); + EVOFORMER_CHECK(p.num_batches <= 1 || p.v_strideB % kMinimumAlignment == 0, + "value is not correctly aligned (strideB)"); + EVOFORMER_CHECK(p.q_strideM % kMinimumAlignment == 0, + "query is not correctly aligned (strideM)"); + EVOFORMER_CHECK(p.k_strideM % kMinimumAlignment == 0, + "key is not correctly aligned (strideM)"); + EVOFORMER_CHECK(p.v_strideM % kMinimumAlignment == 0, + "value is not correctly aligned (strideM)"); + EVOFORMER_CHECK(p.dropout_prob <= 1.0f && p.dropout_prob >= 0.0f, + "Invalid value for `dropout_prob`"); + EVOFORMER_CHECK(kApplyDropout || p.dropout_prob == 0.0f, + "Set `kApplyDropout`=True to support `dropout_prob > 0`"); + EVOFORMER_CHECK(p.head_dim > 0, "Invalid value for `head_dim`"); + EVOFORMER_CHECK(p.head_dim_value > 0, "Invalid value for `head_dim_value`"); + EVOFORMER_CHECK(p.num_queries > 0, "Invalid value for `num_queries`"); + EVOFORMER_CHECK(p.num_keys > 0, "Invalid value for `num_keys`"); + EVOFORMER_CHECK(p.num_heads > 0, "Invalid value for `num_heads`"); + EVOFORMER_CHECK(p.num_batches > 0, "Invalid value for `num_batches`"); + EVOFORMER_CHECK(p.head_dim <= kMaxK, "kMaxK: Expected `head_dim < kMaxK`"); + EVOFORMER_CHECK(p.head_dim_value <= kMaxK, "kMaxK: Expected `head_dim_value < kMaxK`"); + return true; + } + + static CUTLASS_DEVICE void attention_kernel(Params p) + { + extern __shared__ char smem_buffer[]; + SharedStorage& shared_storage = *((SharedStorage*)smem_buffer); + + uint16_t thread_id = threadIdx.x; + uint8_t warp_id = warp_uniform(thread_id / 32); + uint8_t lane_id = thread_id % 32; + + if (kPrologueQK) { + prologueQkNextIteration(shared_storage, p, 0, 0, warp_id, lane_id); + } + + // Computes (dO*out).sum(-1) and writes it to `p.delta_ptr` + if (kKernelComputesDelta) { + constexpr int kOptimalElements = 128 / cutlass::sizeof_bits::value; + if (p.head_dim_value % kOptimalElements == 0) { + for (int query_start = 0; query_start < 
p.num_queries; query_start += kBlockSizeI) { + computeDelta(p, query_start, warp_id, lane_id); + } + } else { + for (int query_start = 0; query_start < p.num_queries; query_start += kBlockSizeI) { + computeDelta<1>(p, query_start, warp_id, lane_id); + } + } + __syncthreads(); + } + + OutputFragments output_frags; + + int32_t key_start = 0; + int32_t key_end = p.num_keys / kBlockSizeJ * kBlockSizeJ; + for (; key_start < key_end; key_start += kBlockSizeJ) { + output_frags.clear(); + int32_t query_start = getQueryStart(p, key_start); + int32_t query_end = + query_start + (p.num_queries - query_start) / kBlockSizeI * kBlockSizeI; + for (; query_start < query_end; query_start += kBlockSizeI) { + processBlockIJ( + shared_storage, output_frags, p, query_start, key_start, warp_id, lane_id); + } + // last (partial) query + if (query_start < p.num_queries) { + processBlockIJ( + shared_storage, output_frags, p, query_start, key_start, warp_id, lane_id); + } + if (kOutputInRF) { + writeFragsToGmem( + shared_storage, output_frags, p, key_start, warp_id, lane_id); + } else if (getQueryStart(p, key_start) >= p.num_queries) { + zfillGradKV(p, key_start, warp_id, lane_id); + } + __syncthreads(); + } + // Last (partial) key + if (key_start != p.num_keys) { + output_frags.clear(); + int32_t query_start = getQueryStart(p, key_start); + for (; query_start < p.num_queries; query_start += kBlockSizeI) { + warp_id = warp_uniform(warp_id); + processBlockIJ( + shared_storage, output_frags, p, query_start, key_start, warp_id, lane_id); + } + if (kOutputInRF) { + writeFragsToGmem( + shared_storage, output_frags, p, key_start, warp_id, lane_id); + } else if (getQueryStart(p, key_start) >= p.num_queries) { + zfillGradKV(p, key_start, warp_id, lane_id); + } + } + } + + static CUTLASS_DEVICE void loadDi(cutlass::Array& di, + Params const& p, + int32_t query_start) + { + int32_t thread_id = threadIdx.x + threadIdx.y * blockDim.x; + if (thread_id < kBlockSizeI) { + accum_t di_rf = accum_t(0); + if (query_start + thread_id < p.num_queries) { + di_rf = p.delta_ptr[query_start + thread_id]; + } + di[thread_id] = di_rf; + } + } + + template + static CUTLASS_DEVICE void zfillGradKV(Params const& p, + int32_t key_start, + uint8_t warp_id, + uint8_t lane_id) + { + constexpr int kThreadsPerKey = 8; + constexpr int kParallelKeys = kNumThreads / kThreadsPerKey; + static_assert(kBlockSizeJ % kParallelKeys == 0, ""); + // This function is not really optimized, but should rarely be used + // It's only used when some keys are "useless" and don't attend to + // any query, due to causal masking + int thread_id = 32 * warp_id + lane_id; + int k_shift = lane_id % kThreadsPerKey; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kBlockSizeJ; j += kParallelKeys) { + int key = key_start + j + (thread_id / kThreadsPerKey); + if (!skipBoundsChecks && key >= p.num_keys) { continue; } + auto gv_ptr = p.grad_value_ptr + key * p.gV_strideM(); + auto gk_ptr = p.grad_key_ptr + key * p.gK_strideM(); + + for (int k = k_shift; k < p.head_dim_value; k += kThreadsPerKey) { + gv_ptr[k] = scalar_t(0); + } + for (int k = k_shift; k < p.head_dim; k += kThreadsPerKey) { gk_ptr[k] = scalar_t(0); } + } + } + + template + static CUTLASS_DEVICE void processBlockIJ(SharedStorage& shared_storage, + OutputFragments& output_frags, + Params& p, + int32_t query_start, + int32_t key_start, + uint8_t warp_id, + uint8_t lane_id) + { + cutlass::MatrixCoord no_offset{0, 0}; + accum_t scale = p.scale; + int16_t thread_id = 32 * warp_id + lane_id; + auto rematerializeThreadIds = 
[&]() { + // Prevents `nvcc` from keeping values deduced from + // `thread_id`, `warp_id`, ... in RF - to reduce register pressure + warp_id = warp_uniform(thread_id / 32); + lane_id = thread_id % 32; + thread_id = 32 * warp_id + lane_id; + }; + + bool isFirstQuery = (query_start == getQueryStart(p, key_start)); + int32_t next_query, next_key; + incrIteration(p, query_start, key_start, next_query, next_key); + bool isLastQuery = next_key != key_start; + __syncthreads(); + loadDi(shared_storage.di(), p, query_start); + + int32_t num_queries_in_block = + skipBoundsChecks ? MatmulQK::Mma::Shape::kN + : warp_uniform(cutlass::fast_min((int32_t)MatmulQK::Mma::Shape::kN, + p.num_queries - query_start)); + int32_t num_keys_in_block = + skipBoundsChecks ? MatmulQK::Mma::Shape::kM + : warp_uniform(cutlass::fast_min((int32_t)MatmulQK::Mma::Shape::kM, + p.num_keys - key_start)); + + auto prologueGradV = [&](int col) { + typename MatmulGradV::Mma::IteratorB iterator_dO( + {int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM + col, + {num_queries_in_block, p.head_dim_value - col}, + thread_id, + no_offset); + MatmulGradV::Mma::prologue( + shared_storage.mm_gradV(), iterator_dO, thread_id, num_queries_in_block); + }; + auto prologueGradQ = [&](int col) { + typename MatmulGradQ::Mma::IteratorB iterator_K( + {int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM + col, + {num_keys_in_block, p.head_dim - col}, + thread_id, + no_offset); + MatmulGradQ::Mma::prologue( + shared_storage.mm_gradQ(), iterator_K, thread_id, num_keys_in_block); + }; + auto prologueGradK = [&](int col) { + typename MatmulGradK::Mma::IteratorB iterator_Q( + {int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM + col, + {num_queries_in_block, p.head_dim - col}, + thread_id, + no_offset); + MatmulGradK::Mma::prologue( + shared_storage.mm_gradK(), iterator_Q, thread_id, num_queries_in_block); + }; + auto prologueDOV = [&]() { + typename MatmulDOIVJ::Mma::IteratorA iterator_A( + {int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM, + {num_queries_in_block, p.head_dim_value}, + thread_id, + no_offset); + typename MatmulDOIVJ::Mma::IteratorB iterator_B({int32_t(p.v_strideM)}, + p.value_ptr + key_start * p.v_strideM, + {p.head_dim_value, num_keys_in_block}, + thread_id, + no_offset); + MatmulDOIVJ::Mma::prologue( + shared_storage.mm_doivj(), iterator_A, iterator_B, thread_id, p.head_dim_value); + }; + + ///////////////////////////////////////////////////////////////////////////////////////////////// + // MatmulQK + ///////////////////////////////////////////////////////////////////////////////////////////////// + { + using Mma = typename MatmulQK::Mma; + + cutlass::gemm::GemmCoord problem_size(num_keys_in_block, + num_queries_in_block, + p.head_dim // k + ); + + // k_j + typename Mma::IteratorA iterator_A({int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM, + {problem_size.m(), problem_size.k()}, + thread_id, + no_offset); + + // q_i.transpose(-2, -1) + typename Mma::IteratorB iterator_B({int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM, + {problem_size.k(), problem_size.n()}, + thread_id, + no_offset); + + Mma mma( + shared_storage.mm_qk_k(), shared_storage.mm_qk_q(), thread_id, warp_id, lane_id); + + typename Mma::FragmentC accum; + + accum.clear(); + + auto gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma.set_prologue_done(kPrologueQK); + 
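+            // The matmul below recomputes the transposed attention scores for
+            // this tile, attn_T = K_j @ Q_i^T; the code that follows applies
+            // the softmax scale (plus Bij when a bias is enabled), subtracts
+            // logsumexp and exponentiates into shared memory, so Pij^T is
+            // rebuilt on the fly rather than reloaded from global memory.
+            // Out-of-bounds elements are zero-filled whenever the tile can be
+            // partial (skipBoundsChecks == false).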
mma.set_zero_outside_bounds(!skipBoundsChecks); + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + + // Epilogue: add LSE + exp and store that to our shared memory buffer + // shmem <- (matmul_result - + // logsumexp[i_start:i_end].unsqueeze(1)).exp() + int warp_idx_mn_0 = warp_id % (Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN); + auto output_tile_coords = cutlass::MatrixCoord{ + warp_idx_mn_0 % Mma::Base::WarpCount::kM, warp_idx_mn_0 / Mma::Base::WarpCount::kM}; + + if (broadcast_1::kEnable || broadcast_2::kEnable) { + cutlass::TensorRef bias_tensor_ref( + shared_storage.bias().data(), + cutlass::layout::RowMajor(MatmulQK::ThreadblockShape::kM)); + using Shape = cutlass::MatrixShape; + AttentionBiasEpilogue + bias_epilogue; + bias_epilogue(bias_tensor_ref, + p.bias1_ptr + key_start, + p.bias2_ptr + query_start * p.num_keys + key_start, + thread_id, + {num_queries_in_block, num_keys_in_block}, + p.num_keys); + // Pij += Bij, Pij is in register fragment and Bij is in shared memory + auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset( + lane_id, warp_id, output_tile_coords); + MatmulQK::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_n) {}, + [&](int accum_m, int accum_n, int idx) { + // remember we are transposed + accum[idx] = accum[idx] * scale + bias_tensor_ref.at({accum_n, accum_m}); + }, + [&](int accum_n) {}); + } else { + accum = cutlass::multiplies()(scale, accum); + } + + __syncthreads(); + if (kPrologueGV) { prologueGradV(0); } + if (kPrologueDOV) { prologueDOV(); } + + MatmulQK::B2bGemm::accumApplyLSEToSmem(shared_storage.attn_shared_storage(), + accum, + p.logsumexp_ptr + query_start, + problem_size.n(), + thread_id, + warp_id, + lane_id, + output_tile_coords); + + __syncthreads(); + } + rematerializeThreadIds(); + + ///////////////////////////////////////////////////////////////////////////////////////////////// + // GradV matmul + // + // grad_v[j_start:j_end] += attn_T @ do_i + ///////////////////////////////////////////////////////////////////////////////////////////////// + constexpr bool kSingleIterationGradV = kMaxK <= MatmulGradV::ThreadblockShape::kN; + for (int col = 0; col < (kSingleIterationGradV ? 
1 : p.head_dim_value); + col += MatmulGradV::ThreadblockShape::kN) { + using Mma = typename MatmulGradV::Mma; + using AccumTileGmem = typename MatmulGradQ::AccumTileGmem; + + cutlass::gemm::GemmCoord problem_size( + num_keys_in_block, p.head_dim_value - col, num_queries_in_block); + auto createEpilogueIter = [&]() { + return typename MatmulGradV::OutputTileIterator( + typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()}, + p.grad_value_ptr + key_start * p.gV_strideM() + col, + {num_keys_in_block, p.head_dim_value - col}, + thread_id); + }; + typename Mma::IteratorB iterator_B({int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM + col, + {num_queries_in_block, p.head_dim_value - col}, + thread_id, + no_offset); + + // if dropout: dVj += (Pij.T * Zij) @ dOi + // otherwise: dVj += Pij.T @ dOi + Mma mma(shared_storage.mm_gradV(), + // operand A: Pij + typename MatmulGradV::WarpIteratorA( + shared_storage.attn_shared_storage().accum_ref(), lane_id), + // if we're using dropout, operand A is Pij_dropped = Pij * Zij + // which is computed on the fly as fragments of Pij are loaded in + typename Mma::WarpIteratorAScale(shared_storage.zij().accum_ref(), lane_id), + thread_id, + warp_id, + lane_id); + + int storage_id = col / MatmulGradV::ThreadblockShape::kN; + AccumTileGmem gmem_tile{p.workspace_gv + storage_id * AccumTileGmem::kElementsStored}; + if (!kOutputInRF) { + if (isFirstQuery || !kNeedsAccumGradV) { + output_frags.gradV.clear(); + } else { + gmem_tile.load(output_frags.gradV, thread_id); + } + } + mma.set_prologue_done(kPrologueGV); + + auto gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + __syncthreads(); + + mma(gemm_k_iterations, output_frags.gradV, iterator_B, output_frags.gradV); + __syncthreads(); + if (kPrologueGV && !kSingleIterationGradV && + col + MatmulGradV::ThreadblockShape::kN < p.head_dim_value) { + prologueGradV(col + MatmulGradV::ThreadblockShape::kN); + } + + if (!kOutputInRF) { + if (kNeedsAccumGradV && !isLastQuery) { + gmem_tile.store(output_frags.gradV, thread_id); + } else { + accumulateInGmem(shared_storage.gradV_epilogue(), + output_frags.gradV, + createEpilogueIter(), + isFirstQuery || kNeedsAccumGradV, + warp_id, + lane_id); + } + } + } + __syncthreads(); + ///////////////////////////////////////////////////////////////////////////////////////////////// + // MatmulDOIVJ + ///////////////////////////////////////////////////////////////////////////////////////////////// + { + using Mma = typename MatmulDOIVJ::Mma; + // do_i + typename Mma::IteratorA iterator_A({int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM, + {num_queries_in_block, p.head_dim_value}, + thread_id, + no_offset); + + // v_j.transpose(-2, -1) + typename Mma::IteratorB iterator_B({int32_t(p.v_strideM)}, + p.value_ptr + key_start * p.v_strideM, + {p.head_dim_value, num_keys_in_block}, + thread_id, + no_offset); + + Mma mma(shared_storage.mm_doivj(), thread_id, warp_id, lane_id); + mma.set_prologue_done(kPrologueDOV); + mma.set_zero_outside_bounds(!skipBoundsChecks); + + typename Mma::FragmentC accum; + + accum.clear(); + + auto gemm_k_iterations = (p.head_dim_value + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + __syncthreads(); + if (kPrologueGQ) { prologueGradQ(0); } + if (kPrologueGK) { prologueGradK(0); } + + int warp_idx_mn_0 = warp_id % 
(Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN); + auto output_tile_coords = cutlass::MatrixCoord{ + warp_idx_mn_0 % Mma::Base::WarpCount::kM, warp_idx_mn_0 / Mma::Base::WarpCount::kM}; + // TODO: This must be terribly inefficient. There must be a better way + // tmp [RF] <- (accum [RF] - Di [smem] ) * attn_T.T [smem] + // attn_shared_storage [smem] <- tmp.T + // tmp_shared_storage [smem] <- tmp + { + using LambdaIterator = + typename DefaultMmaAccumLambdaIterator::Iterator; + auto lane_offset = + LambdaIterator::get_lane_offset(lane_id, warp_id, output_tile_coords); + + auto attn_T = shared_storage.attn_shared_storage().accum_ref(); + accum_t current_di; + // dSij = (dPij - Di) * Pij + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { current_di = shared_storage.di()[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + if (skipBoundsChecks || + (accum_m < num_queries_in_block && accum_n < num_keys_in_block)) { + accum_t attn = attn_T.at({accum_n, accum_m}); + accum[idx] = (accum[idx] - current_di) * attn; + } else { + accum[idx] = 0; + } + }, + [&](int accum_m) { + + }); + + using DefaultGemm = typename MatmulDOIVJ::DefaultGemm; + using OutputOp = typename MatmulDOIVJ::BiasGradEpilogueOutputOp; + if (broadcast_1::kEnable && p.grad_bias1_ptr) { + using Epilogue = + typename BiasGradEpilogueAffineRankN::Epilogue; + cutlass::layout::AffineRankN<2> layout({0, 1}); + auto dst_ptr = p.grad_bias1_ptr + key_start; + typename Epilogue::OutputTileIterator output_iter( + {layout}, + dst_ptr, + {num_queries_in_block, num_keys_in_block}, + (int)thread_id); + Epilogue epilogue(shared_storage.gradB_epilogue(), + (int)thread_id, + (int)warp_id, + (int)lane_id); + epilogue(OutputOp(1), output_iter, accum); + } + + if (broadcast_2::kEnable && p.grad_bias2_ptr) { + if (broadcast_1::kEnable) { __syncthreads(); } + using Epilogue = + typename BiasGradEpilogue::Epilogue; + typename Epilogue::OutputTileIterator::Params params{p.num_keys}; + auto dst_ptr = p.grad_bias2_ptr + query_start * p.num_keys + key_start; + typename Epilogue::OutputTileIterator output_iter( + params, dst_ptr, {num_queries_in_block, num_keys_in_block}, (int)thread_id); + Epilogue epilogue(shared_storage.gradB_epilogue(), + (int)thread_id, + (int)warp_id, + (int)lane_id); + epilogue(OutputOp(1), output_iter, accum); + } + + accum = accum * scale; + + __syncthreads(); + if (!MatmulGradK::DefaultMmaFromSmem::kIsTransposedA) { + auto tmpT = shared_storage.tmpT_shared_storage().accum_ref(); + // attn <- attn_T.T + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + tmpT.at({accum_n, accum_m}) = scalar_t(accum[idx]); + }, + [&](int accum_m) {}); + } + } + + MatmulDOIVJ::B2bGemm::accumToSmem( + shared_storage.tmp_shared_storage(), accum, lane_id, output_tile_coords); + __syncthreads(); + } + p.head_dim = warp_uniform(p.head_dim); + p.k_strideM = warp_uniform(p.k_strideM); + rematerializeThreadIds(); + ///////////////////////////////////////////////////////////////////////////////////////////////// + // GradQ matmul + // + // grad_q[i_start:i_end] += tmp @ k_j + ///////////////////////////////////////////////////////////////////////////////////////////////// + // Skip the loop & associated branches if we know at compile time the number + // of iterations + constexpr bool kSingleIterationGradQ = kMaxK <= MatmulGradQ::ThreadblockShape::kN; + for (int col = 0; col < (kSingleIterationGradQ ? 
1 : p.head_dim); + col += MatmulGradQ::ThreadblockShape::kN) { + using Mma = typename MatmulGradQ::Mma; + using AccumTileGmem = typename MatmulGradQ::AccumTileGmem; + + cutlass::gemm::GemmCoord problem_size( + num_queries_in_block, + false ? MatmulGradQ::ThreadblockShape::kN : p.head_dim - col, + num_keys_in_block); + + // k_j + typename Mma::IteratorB iterator_B({int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM + col, + {problem_size.k(), problem_size.n()}, + thread_id, + no_offset); + + auto a = shared_storage.tmp_shared_storage().accum_ref(); + Mma mma(shared_storage.mm_gradQ(), + shared_storage.tmp_shared_storage(), + thread_id, + warp_id, + lane_id, + problem_size.k()); + + typename Mma::FragmentC accum; + + bool isFirst = key_start == 0; + int col_id = col / MatmulGradQ::ThreadblockShape::kN; + int num_cols = + kSingleIterationGradQ ? 1 : ceil_div(p.head_dim, MatmulGradQ::ThreadblockShape::kN); + int storage_id = (col_id + query_start / kBlockSizeI * num_cols); + AccumTileGmem gmem_tile{p.workspace_gq + storage_id * AccumTileGmem::kElementsStored}; + if (isFirst || !kNeedsAccumGradQ) { + accum.clear(); + } else { + gmem_tile.load(accum, thread_id); + } + + auto gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + __syncthreads(); + mma.set_prologue_done(kPrologueGQ); + mma(gemm_k_iterations, accum, iterator_B, accum); + __syncthreads(); + bool isLastColumn = kSingleIterationGradQ || + (col + MatmulGradQ::ThreadblockShape::kN >= p.head_dim); + if (kPrologueGQ && !isLastColumn) { + prologueGradQ(col + MatmulGradQ::ThreadblockShape::kN); + } + + // Output results + int32_t next_query, next_key; + incrIteration(p, p.num_queries, key_start, next_query, next_key); + bool isLast = next_query > query_start || next_key >= p.num_keys; + if (kNeedsAccumGradQ && !isLast) { + gmem_tile.store(accum, thread_id); + } else { + typename MatmulGradQ::OutputTileIterator output_it( + typename MatmulGradQ::OutputTileIterator::Params{p.gQ_strideM()}, + p.grad_query_ptr + query_start * p.gQ_strideM() + col, + {problem_size.m(), problem_size.n()}, + thread_id); + accumulateInGmem(isLastColumn + ? shared_storage.gradQ_epilogue_lastIter() + : shared_storage.gradQ_epilogue(), + accum, + output_it, + isFirst || kNeedsAccumGradQ, + warp_id, + lane_id); + } + } + ///////////////////////////////////////////////////////////////////////////////////////////////// + // GradK matmul + // + // grad_k[i_start:i_end] += tmp.transpose(-2, -1) @ q_i + ///////////////////////////////////////////////////////////////////////////////////////////////// + rematerializeThreadIds(); + + constexpr bool kSingleIterationGradK = kMaxK <= MatmulGradK::ThreadblockShape::kN; + for (int col = 0; col < (kSingleIterationGradK ? 1 : p.head_dim); + col += MatmulGradK::ThreadblockShape::kN) { + using Mma = typename MatmulGradK::Mma; + using AccumTileGmem = typename MatmulGradQ::AccumTileGmem; + + cutlass::gemm::GemmCoord problem_size( + num_keys_in_block, + false ? MatmulGradK::ThreadblockShape::kN : p.head_dim - col, + num_queries_in_block); + auto createEpilogueIter = [&]() { + return typename MatmulGradK::OutputTileIterator( + typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()}, + p.grad_key_ptr + key_start * p.gK_strideM() + col, + {num_keys_in_block, + false ? 
MatmulGradK::ThreadblockShape::kN : p.head_dim - col}, + thread_id); + }; + + // q_i + typename Mma::IteratorB iterator_B({int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM + col, + {problem_size.k(), problem_size.n()}, + thread_id, + no_offset); + + auto getTmp = [&](int) { return &shared_storage.tmp_shared_storage(); }; + auto getTmpT = [&](int) { return &shared_storage.tmpT_shared_storage(); }; + // this is basically: + // opA = kIsTransposedA ? getTmp() : getTmpT(); + bool constexpr kIsTransposedA = MatmulGradK::DefaultMmaFromSmem::kIsTransposedA; + auto& opA = + *call_conditional::apply( + getTmp, getTmpT, 0); + Mma mma(shared_storage.mm_gradK(), opA, thread_id, warp_id, lane_id, problem_size.k()); + + int storage_id = col / MatmulGradK::ThreadblockShape::kN; + AccumTileGmem gmem_tile{p.workspace_gk + storage_id * AccumTileGmem::kElementsStored}; + if (!kOutputInRF) { + if (isFirstQuery || !kNeedsAccumGradK) { + output_frags.gradK.clear(); + } else { + gmem_tile.load(output_frags.gradK, thread_id); + } + } + mma.set_prologue_done(kPrologueGK); + + auto gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + __syncthreads(); + + mma(gemm_k_iterations, output_frags.gradK, iterator_B, output_frags.gradK); + __syncthreads(); + bool isLastColumn = kSingleIterationGradK || + col + MatmulGradK::ThreadblockShape::kN >= p.head_dim; + if (kPrologueGK && !isLastColumn) { + prologueGradK(col + MatmulGradK::ThreadblockShape::kN); + } + + if (kPrologueQK && isLastColumn) { + int32_t next_query, next_key; + incrIteration(p, query_start, key_start, next_query, next_key); + DISPATCH_BOOL(next_key != key_start, kForceReloadK, ([&]() { + prologueQkNextIteration( + shared_storage, p, next_query, next_key, warp_id, lane_id); + })); + } + + // Output results + if (!kOutputInRF) { + if (kNeedsAccumGradK && !isLastQuery) { + gmem_tile.store(output_frags.gradK, thread_id); + } else { + accumulateInGmem(isLastColumn + ? 
shared_storage.gradK_epilogue_final() + : shared_storage.gradK_epilogue(), + output_frags.gradK, + createEpilogueIter(), + isFirstQuery || kNeedsAccumGradK, + warp_id, + lane_id); + __syncthreads(); + } + } + } + } + + static CUTLASS_DEVICE int32_t getQueryStart(Params const& p, int32_t key_start) { return 0; }; + + static CUTLASS_DEVICE void incrIteration(Params const& p, + int32_t query_start, + int32_t key_start, + int32_t& next_query, + int32_t& next_key) + { + next_query = query_start + kBlockSizeI; + next_key = key_start; + if (next_query >= p.num_queries) { + next_key = key_start + kBlockSizeJ; + next_query = getQueryStart(p, next_key); + } + } + + template + static CUTLASS_DEVICE void prologueQkNextIteration(SharedStorage& shared_storage, + Params const& p, + int32_t query_start, + int32_t key_start, + uint8_t warp_id, + uint8_t lane_id) + { + if (query_start >= p.num_queries || key_start >= p.num_keys) { return; } + + static constexpr bool kReloadK = kForceReloadK || !MatmulQK::Mma::kSmemContainsEntireMat; + int thread_id = 32 * warp_id + lane_id; + typename MatmulQK::Mma::IteratorA iterator_A({int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM, + {p.num_keys - key_start, p.head_dim}, + thread_id, + cutlass::MatrixCoord{0, 0}); + + typename MatmulQK::Mma::IteratorB iterator_B({int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM, + {p.head_dim, p.num_queries - query_start}, + thread_id, + cutlass::MatrixCoord{0, 0}); + + MatmulQK::Mma::prologue(shared_storage.mm_qk_k(), + shared_storage.mm_qk_q(), + iterator_A, + iterator_B, + thread_id, + p.head_dim); + } + + template + static CUTLASS_DEVICE void writeFragsToGmem(SharedStorage& shared_storage, + OutputFragments& output_frags, + Params const& p, + int32_t key_start, + uint8_t warp_id, + uint8_t lane_id) + { + uint16_t thread_id = 32 * warp_id + lane_id; + int32_t num_keys_in_block = + skipBoundsChecks + ? MatmulQK::Mma::Shape::kM + : cutlass::fast_min((int32_t)MatmulQK::Mma::Shape::kM, p.num_keys - key_start); + typename MatmulGradV::OutputTileIterator outputV_it( + typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()}, + p.grad_value_ptr + key_start * p.gV_strideM(), + {num_keys_in_block, p.head_dim_value}, + thread_id); + accumulateInGmem(shared_storage.gradV_epilogue_final(), + output_frags.gradV, + outputV_it, + true, + warp_id, + lane_id); + + typename MatmulGradK::OutputTileIterator outputK_it( + typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()}, + p.grad_key_ptr + key_start * p.gK_strideM(), + {num_keys_in_block, false ? MatmulGradK::ThreadblockShape::kN : p.head_dim}, + thread_id); + accumulateInGmem(shared_storage.gradK_epilogue_final(), + output_frags.gradK, + outputK_it, + true, + warp_id, + lane_id); + } + + template + static CUTLASS_DEVICE void accumulateInGmem( + typename MatmulT::DefaultEpilogue::SharedStorage& epilogue_smem, + typename MatmulT::Mma::FragmentC const& accum, + typename MatmulT::OutputTileIterator output_it, + bool first, + uint8_t warp_id, + uint8_t lane_id) + { + using DefaultEpilogue = typename MatmulT::DefaultEpilogue; + using DefaultOutputOp = typename MatmulT::DefaultOutputOp; + using Mma = typename MatmulT::Mma; + int thread_id = 32 * warp_id + lane_id; + DISPATCH_BOOL( + first, kIsFirst, ([&]() { + static constexpr auto ScaleType = + kIsFirst ? 
cutlass::epilogue::thread::ScaleType::Nothing + : cutlass::epilogue::thread::ScaleType::NoBetaScaling; + using EpilogueOutputOp = typename cutlass::epilogue::thread::LinearCombination< + typename DefaultOutputOp::ElementOutput, + DefaultOutputOp::kCount, + typename DefaultOutputOp::ElementAccumulator, + typename DefaultOutputOp::ElementCompute, + ScaleType>; + using Epilogue = typename cutlass::epilogue::threadblock::EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename MatmulT::OutputTileIterator, + typename DefaultEpilogue::AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true // IterationsUnroll + >; + EpilogueOutputOp rescale({1, 1}); + Epilogue epilogue(epilogue_smem, thread_id, warp_id, lane_id); + epilogue(rescale, output_it, accum, output_it); + })); + } + + template + static CUTLASS_DEVICE void computeDelta(Params const& p, + int32_t query_start, + uint8_t warp_id, + uint8_t lane_id) + { + // Each thread computes one value for Delta + // Depending on warp configuration, we might have multiple + // threads of the same warp working on the same row + using AccessType = cutlass::Array; + static_assert(kNumThreads >= kBlockSizeI, ""); + static constexpr int kNumThreadsPerLine = kNumThreads / kBlockSizeI; + int16_t thread_id = 32 * warp_id + lane_id; + + int16_t laneFirstCol = kElementsPerAccess * (lane_id % kNumThreadsPerLine); + int16_t laneRow = thread_id / kNumThreadsPerLine; + bool rowPred = (query_start + laneRow) < p.num_queries; + bool pred = rowPred; + + // on windows, previous syntax __restrict__ AccessType* + // resulted in error: "restrict" is not allowed + const AccessType* __restrict__ grad_output_ptr = reinterpret_cast( + p.grad_output_ptr + (query_start + laneRow) * p.gO_strideM + laneFirstCol); + const AccessType* __restrict__ output_ptr = reinterpret_cast( + p.output_ptr + (query_start + laneRow) * p.o_strideM() + laneFirstCol); + + static constexpr int64_t kMaxIters = kMaxK / (kElementsPerAccess * kNumThreadsPerLine); + constexpr int kPipelineStages = 2; + accum_t delta_value = accum_t(0); + using GlobalLoad = cutlass::arch::global_load; + AccessType frag_grad_output[kPipelineStages]; + AccessType frag_output[kPipelineStages]; + + auto loadAndIncrement = [&](int ld_pos, bool is_valid) { + frag_grad_output[ld_pos].clear(); + frag_output[ld_pos].clear(); + GlobalLoad(frag_grad_output[ld_pos], grad_output_ptr, is_valid); + GlobalLoad(frag_output[ld_pos], output_ptr, is_valid); + grad_output_ptr += kNumThreadsPerLine; + output_ptr += kNumThreadsPerLine; + }; + + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < kPipelineStages - 1; ++iter) { + int ld_pos = iter % kPipelineStages; + pred = pred && (laneFirstCol + iter * kElementsPerAccess * kNumThreadsPerLine) < + p.head_dim_value; + loadAndIncrement(ld_pos, pred); + } + auto columnIteration = [&](int iter) { + // Load for next iter + int ld_pos = (iter + kPipelineStages - 1) % kPipelineStages; + pred = pred && (laneFirstCol + (iter + kPipelineStages - 1) * kElementsPerAccess * + kNumThreadsPerLine) < p.head_dim_value; + loadAndIncrement(ld_pos, pred); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < AccessType::kElements; ++i) { + delta_value += accum_t(frag_output[iter % kPipelineStages][i]) * + accum_t(frag_grad_output[iter % kPipelineStages][i]); + } + }; + + // If we have a small 
lower-bound for K, we can unroll the loop + if (kMaxK <= 256) { + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < kMaxIters; ++iter) { columnIteration(iter); } + } else { + int num_iters = ceil_div(p.head_dim_value, kElementsPerAccess * kNumThreadsPerLine) * + (kElementsPerAccess * kNumThreadsPerLine); + for (int iter = 0; iter < num_iters; ++iter) { columnIteration(iter); } + } + + // Reduce between workers + static_assert(kNumThreadsPerLine == 1 || kNumThreadsPerLine == 2 || kNumThreadsPerLine == 4, + ""); + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kNumThreadsPerLine; i *= 2) { + delta_value = delta_value + __shfl_xor_sync(0xffffffff, delta_value, i); + } + + // Store in gmem + if (rowPred) { p.delta_ptr[query_start + laneRow] = delta_value; } + } +}; + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_backward_batched_impl(typename AK::Params p) +{ + if (!p.advance_to_block()) { return; } + AK::attention_kernel(p); +} + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_backward_batched(typename AK::Params params); diff --git a/csrc/deepspeed4science/evoformer_attn/kernel_forward.h b/csrc/deepspeed4science/evoformer_attn/kernel_forward.h new file mode 100644 index 000000000000..e3b11ebcc661 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/kernel_forward.h @@ -0,0 +1,986 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. 
+// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include + +#include "cutlass/bfloat16.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/vector.h" +#include "cutlass/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/platform/platform.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "epilogue/epilogue_pipelined.h" +#include "epilogue/epilogue_rescale_output.h" +#include "gemm/find_default_mma.h" +#include "gemm/mma_from_smem.h" +#include "gemm_kernel_utils.h" +#include "transform/bias_broadcast.h" +#include "transform/tile_smem_loader.h" + +#include + +using namespace gemm_kernel_utils; + +namespace { +template +constexpr int getWarpsPerSm() +{ + return (Arch::kMinComputeCapability >= 80 && !cutlass::platform::is_same::value + ? 16 + : 12); +} +static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) +{ + // source: https://stackoverflow.com/a/51549250 + return (value >= 0) ? 
__int_as_float(atomicMax((int*)addr, __float_as_int(value))) + : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); +} +} // namespace + +template < + // The datatype of Q/K/V + typename scalar_t_, + // Architecture we are targeting (eg `cutlass::arch::Sm80`) + typename ArchTag, + // If Q/K/V are correctly aligned in memory and we can run a fast kernel + bool isAligned_, + int kQueriesPerBlock, + int kKeysPerBlock_, + bool kSingleValueIteration_, // = `value.shape[-1] <= kKeysPerBlock` + // This is quite slower on V100 for some reason + // Set to false if you know at compile-time you will never need dropout + bool kSupportsBias_ = false, + template class Broadcast1_ = BroadcastNoLoad, + template class Broadcast2_ = BroadcastNoLoad> +struct AttentionKernel { + using scalar_t = scalar_t_; + using accum_t = float; + using lse_scalar_t = float; + using output_t = scalar_t; + // Accumulator between 2 iterations + // Using `accum_t` improves perf on f16 at the cost of + // numerical errors + using output_accum_t = accum_t; + static constexpr bool kSupportsBias = kSupportsBias_; + static constexpr int kKeysPerBlock = kKeysPerBlock_; + static constexpr bool kIsAligned = isAligned_; + static constexpr bool kSingleValueIteration = kSingleValueIteration_; + static constexpr int32_t kAlignLSE = 32; // block size of backward + static constexpr bool kPreloadV = + ArchTag::kMinComputeCapability >= 80 && cutlass::sizeof_bits::value == 16; + static constexpr bool kKeepOutputInRF = kSingleValueIteration; + static constexpr bool kNeedsOutputAccumulatorBuffer = + !kKeepOutputInRF && !cutlass::platform::is_same::value; + + static_assert(kQueriesPerBlock % 32 == 0, ""); + static_assert(kKeysPerBlock % 32 == 0, ""); + static constexpr int kNumWarpsPerBlock = kQueriesPerBlock * kKeysPerBlock / (32 * 32); + static constexpr int kWarpSize = 32; + + // Launch bounds + static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock; + static constexpr int kMinBlocksPerSm = getWarpsPerSm() / kNumWarpsPerBlock; + + struct Params { + // Input tensors + scalar_t* query_ptr; // [num_queries, num_heads, head_dim] + scalar_t* key_ptr; // [num_keys, num_heads, head_dim] + scalar_t* value_ptr; // [num_keys, num_heads, head_dim_value] + + // Output tensors + output_t* output_ptr; // [num_queries, num_heads, head_dim_value] + output_accum_t* output_accum_ptr; // [num_queries, num_heads, head_dim_value] + lse_scalar_t* logsumexp_ptr; // [num_heads, num_queries] - can be null + + // Scale + accum_t scale; + + // Dimensions/strides + int32_t head_dim; + int32_t head_dim_value; + int32_t num_queries; + int32_t num_keys; + + int32_t q_strideM; + int32_t k_strideM; + int32_t v_strideM; + // int32_t bias_strideM = 0; + + int32_t o_strideM = 0; + + // Everything below is only used in `advance_to_block` + // and shouldn't use registers + int32_t q_strideH; + int32_t k_strideH; + int32_t v_strideH; + // int32_t bias_strideH = 0; + + int64_t q_strideB; + int64_t k_strideB; + int64_t v_strideB; + // int32_t bias_strideB = 0; + + int32_t num_batches; + int32_t num_heads; + + // Parameters for biases + scalar_t* bias1_ptr = nullptr; + scalar_t* bias2_ptr = nullptr; + int32_t B = 0; + int32_t N = 0; + + // Moves pointers to what we should process + // Returns "false" if there is no work to do + CUTLASS_DEVICE bool advance_to_block() + { + auto batch_id = blockIdx.z; + auto head_id = blockIdx.y; + auto query_start = blockIdx.x * kQueriesPerBlock; + + auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE; + 
+ query_ptr += batch_id * q_strideB; + key_ptr += batch_id * k_strideB; + value_ptr += batch_id * v_strideB; + output_ptr += int64_t(batch_id * num_queries) * o_strideM; + if (output_accum_ptr != nullptr) { + output_accum_ptr += int64_t(batch_id * num_queries) * (head_dim_value * num_heads); + } + + int64_t q_start = 0, k_start = 0; + // Advance to the current batch / head / query_start + query_ptr += (q_start + query_start) * q_strideM + head_id * q_strideH; + key_ptr += k_start * k_strideM + head_id * k_strideH; + + value_ptr += k_start * v_strideM + head_id * v_strideH; + output_ptr += int64_t(q_start + query_start) * o_strideM + head_id * head_dim_value; + + if (output_accum_ptr != nullptr) { + output_accum_ptr += int64_t(q_start + query_start) * (head_dim_value * num_heads) + + head_id * head_dim_value; + } else { + // Accumulate directly in the destination buffer (eg for f32) + output_accum_ptr = (accum_t*)output_ptr; + } + + if (logsumexp_ptr != nullptr) { + // lse[batch_id, head_id, query_start] + logsumexp_ptr += batch_id * lse_dim * num_heads + head_id * lse_dim + query_start; + } + + using broadcast_1 = Broadcast1_; + if (kSupportsBias && broadcast_1::kEnable && bias1_ptr) { + bias1_ptr = broadcast_1::advance(bias1_ptr, + batch_id / N, + batch_id % N, + head_id, + num_queries * N, + num_queries, + 0); + } + using broadcast_2 = Broadcast2_; + if (kSupportsBias && broadcast_2::kEnable && bias2_ptr) { + auto strideB = num_heads * num_queries * num_keys; + auto strideH = num_queries * num_keys; + bias2_ptr = broadcast_2::advance( + bias2_ptr, batch_id / N, batch_id % N, head_id, strideB, 0, strideH); + } + + num_queries -= query_start; + num_batches = 0; // no longer used after + + // If num_queries == 1, and there is only one key head we're wasting + // 15/16th of tensor core compute In that case : + // - we only launch kernels for head_id % kQueriesPerBlock == 0 + // - we iterate over heads instead of queries (strideM = strideH) + if (num_queries == 1 && k_strideH == 0 && v_strideH == 0) { + if (head_id % kQueriesPerBlock != 0) return false; + q_strideM = q_strideH; + num_queries = num_heads; + num_heads = 1; // unused but here for intent + o_strideM = head_dim_value; + } + + // Make sure the compiler knows these variables are the same on all + // the threads of the warp. + query_ptr = warp_uniform(query_ptr); + key_ptr = warp_uniform(key_ptr); + value_ptr = warp_uniform(value_ptr); + output_ptr = warp_uniform(output_ptr); + output_accum_ptr = warp_uniform(output_accum_ptr); + logsumexp_ptr = warp_uniform(logsumexp_ptr); + num_queries = warp_uniform(num_queries); + num_keys = warp_uniform(num_keys); + num_heads = warp_uniform(num_heads); + head_dim = warp_uniform(head_dim); + head_dim_value = warp_uniform(head_dim_value); + o_strideM = warp_uniform(o_strideM); + if (kSupportsBias && broadcast_1::kEnable) { bias1_ptr = warp_uniform(bias1_ptr); } + if (kSupportsBias && broadcast_2::kEnable) { bias2_ptr = warp_uniform(bias2_ptr); } + return true; + } + + __host__ dim3 getBlocksGrid() const + { + return dim3(ceil_div(num_queries, (int32_t)kQueriesPerBlock), num_heads, num_batches); + } + + __host__ dim3 getThreadsGrid() const { return dim3(kWarpSize, kNumWarpsPerBlock, 1); } + }; + + struct MM0 { + /* + In this first matmul, we compute a block of `Q @ K.T`. 
+ While the calculation result is still hot in registers, we update + `mi`, `m_prime`, `s_prime` in shared-memory, and then store this value + into a shared-memory ("AccumulatorSharedStorage") that is used later as + operand A for the second matmul (see MM1) + */ + using GemmType = DefaultGemmType; + + using OpClass = typename GemmType::OpClass; + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration; + static constexpr int kAlignmentA = kIsAligned ? DefaultConfig::kAlignmentA + : GemmType::kMinimumAlignment; + static constexpr int kAlignmentB = kIsAligned ? DefaultConfig::kAlignmentB + : GemmType::kMinimumAlignment; + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + kAlignmentA, + scalar_t, // ElementB, + cutlass::layout::ColumnMajor, // LayoutB, + kAlignmentB, + accum_t, + cutlass::layout::RowMajor, // LayoutC, + OpClass, + ArchTag, // ArchTag + ThreadblockShape, // ThreadblockShape + WarpShape, // WarpShape + typename GemmType::InstructionShape, // InstructionShape + DefaultConfig::kStages, // Should use `DefaultConfig::kStages`, but that + // uses too much smem + typename GemmType::Operator // Operator + >::DefaultMma; + using MmaCore = typename DefaultMma::MmaCore; + using IteratorA = typename DefaultMma::IteratorA; + using IteratorB = typename DefaultMma::IteratorB; + using Mma = typename DefaultMma::ThreadblockMma; + using AccumLambdaIterator = + typename DefaultMmaAccumLambdaIterator::Iterator; + static_assert(MmaCore::WarpCount::kM * MmaCore::WarpCount::kN * MmaCore::WarpCount::kK == + kNumWarpsPerBlock, + ""); + + // used for efficient load of bias tile Bij from global to shared memory + using BiasLoader = + TileSmemLoader, + MmaCore::kThreads, + // input restriction: kv_len has to be a multiple of this value + 128 / cutlass::sizeof_bits::value>; + + // Epilogue to store to shared-memory in a format that we can use later for + // the second matmul + using B2bGemm = + typename cutlass::gemm::threadblock::B2bGemm; + using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; + }; + + struct MM1 { + /** + Second matmul: perform `attn @ V` where `attn` is the attention (not + normalized) and stored in shared memory + */ + using GemmType = DefaultGemmType; + + using OpClass = typename GemmType::OpClass; + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration; + static constexpr int kAlignmentA = DefaultConfig::kAlignmentA; // from smem + static constexpr int kAlignmentB = kIsAligned ? 
DefaultConfig::kAlignmentB + : GemmType::kMinimumAlignment; + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + using LayoutB = cutlass::layout::RowMajor; + using DefaultGemm = + cutlass::gemm::kernel::DefaultGemm; + + using DefaultMmaFromSmem = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + typename MM0::AccumulatorSharedStorage, + false>; // kScaleOperandA + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + static_assert(WarpCount::kM * WarpCount::kN * WarpCount::kK == kNumWarpsPerBlock, ""); + + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = typename cutlass::epilogue::threadblock::PredicatedTileIterator< + typename DefaultEpilogue::OutputTileIterator::ThreadMap, + output_t>; + using OutputTileIteratorAccum = + typename cutlass::epilogue::threadblock::PredicatedTileIterator< + typename DefaultEpilogue::OutputTileIterator::ThreadMap, + output_accum_t>; + + struct SharedStorageMM1 { + typename Mma::SharedStorage mm; + }; + }; + + static constexpr int64_t kAlignmentQ = MM0::kAlignmentA; + static constexpr int64_t kAlignmentK = MM0::kAlignmentB; + static constexpr int64_t kAlignmentV = 1; + + // Shared storage - depends on kernel params + struct ScalingCoefs { + cutlass::Array m_prime; + cutlass::Array s_prime; + cutlass::Array mi; + }; + + struct SharedStorageEpilogueAtEnd : ScalingCoefs { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + union { + // typename MM0::BiasLoader::SmemTile bias; + cutlass::AlignedBuffer bias; + typename MM0::AccumulatorSharedStorage si; + }; + typename MM1::SharedStorageMM1 mm1; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& epilogue_shared_storage() + { + return epilogue; + } + }; + + struct SharedStorageEpilogueInLoop : ScalingCoefs { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + union { + // typename MM0::BiasLoader::SmemTile bias; + cutlass::AlignedBuffer bias; + typename MM0::AccumulatorSharedStorage si; + }; + typename MM1::SharedStorageMM1 mm1; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + }; + + CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& epilogue_shared_storage() + { + return after_mm0.epilogue; + } + }; + + using SharedStorage = + typename cutlass::platform::conditional::type; + + static bool __host__ check_supported(Params const& p) + { + CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ); + CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK); + CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV); + EVOFORMER_CHECK(p.q_strideM % kAlignmentQ == 0, "query is not correctly aligned (strideM)"); + EVOFORMER_CHECK(p.k_strideM % kAlignmentK == 0, "key is not correctly aligned (strideM)"); + EVOFORMER_CHECK(p.v_strideM % kAlignmentV == 0, "value is not correctly aligned (strideM)"); + EVOFORMER_CHECK(p.num_heads <= 1 || p.q_strideH % kAlignmentQ == 0, + "query is not correctly aligned (strideH)"); + EVOFORMER_CHECK(p.num_heads <= 1 || p.k_strideH % kAlignmentK == 0, + "key is not correctly aligned 
(strideH)"); + EVOFORMER_CHECK(p.num_heads <= 1 || p.v_strideH % kAlignmentV == 0, + "value is not correctly aligned (strideH)"); + return true; + } + + static void CUTLASS_DEVICE attention_kernel(Params& p) + { + // In this block, we will only ever: + // - read query[query_start:query_end, :] + // - write to output[query_start:query_end, :] + + extern __shared__ char smem_buffer[]; + SharedStorage& shared_storage = *((SharedStorage*)smem_buffer); + auto& m_prime = shared_storage.m_prime; + auto& s_prime = shared_storage.s_prime; + auto& mi = shared_storage.mi; + const uint32_t query_start = blockIdx.x * kQueriesPerBlock; + + static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); + if (thread_id() < kQueriesPerBlock) { + s_prime[thread_id()] = accum_t(0); + m_prime[thread_id()] = -cutlass::platform::numeric_limits::infinity(); + mi[thread_id()] = -cutlass::platform::numeric_limits::infinity(); + } + typename MM1::Mma::FragmentC accum_o; + accum_o.clear(); + + auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator { + using OutputTileIterator = typename MM1::OutputTileIterator; + return OutputTileIterator( + typename OutputTileIterator::Params{(int32_t)p.o_strideM}, + p.output_ptr, + typename OutputTileIterator::TensorCoord{p.num_queries, p.head_dim_value}, + thread_id(), + {0, col}); + }; + + auto createOutputAccumIter = [&](int col) -> typename MM1::OutputTileIteratorAccum { + using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum; + return OutputTileIteratorAccum( + typename OutputTileIteratorAccum::Params{(int32_t)(p.head_dim_value * p.num_heads)}, + p.output_accum_ptr, + typename OutputTileIteratorAccum::TensorCoord{p.num_queries, p.head_dim_value}, + thread_id(), + {0, col}); + }; + + // Iterate through keys + for (int32_t iter_key_start = 0; iter_key_start < p.num_keys; + iter_key_start += kKeysPerBlock) { + int32_t problem_size_0_m = cutlass::fast_min((int32_t)kQueriesPerBlock, p.num_queries); + int32_t problem_size_0_n = + cutlass::fast_min(int32_t(kKeysPerBlock), p.num_keys - iter_key_start); + int32_t const& problem_size_0_k = p.head_dim; + int32_t const& problem_size_1_n = p.head_dim_value; + int32_t const& problem_size_1_k = problem_size_0_n; + + auto prologueV = [&](int blockN) { + typename MM1::Mma::IteratorB iterator_V( + typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)}, + p.value_ptr + iter_key_start * p.v_strideM, + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + MM1::Mma::prologue( + shared_storage.after_mm0.mm1.mm, iterator_V, thread_id(), problem_size_1_k); + }; + + __syncthreads(); // Need to have shared memory initialized, and `m_prime` + // updated from end of prev iter + // + // MATMUL: Q.K_t + // + // Computes the block-matrix product of: + // (a) query[query_start:query_end, :] + // with + // (b) key[iter_key_start:iter_key_start + kKeysPerBlock] + // and stores that into `shared_storage.si` + // + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {0, 0, 0}; + + cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * MM0::Mma::Shape::kM, + tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), + tb_tile_offset.n() * MM0::Mma::Shape::kN}; + + // Construct iterators to A and B operands + typename MM0::IteratorA iterator_A( + typename MM0::IteratorA::Params(typename MM0::MmaCore::LayoutA(p.q_strideM)), + p.query_ptr, + {problem_size_0_m, problem_size_0_k}, + thread_id(), + tb_offset_A); + + typename 
MM0::IteratorB iterator_B( + typename MM0::IteratorB::Params(typename MM0::MmaCore::LayoutB(p.k_strideM)), + p.key_ptr + iter_key_start * p.k_strideM, + {problem_size_0_k, problem_size_0_n}, + thread_id(), + tb_offset_B); + + auto my_warp_id = warp_id(); + auto my_lane_id = lane_id(); + + // Construct thread-scoped matrix multiply + typename MM0::Mma mma(shared_storage.mm0, thread_id(), my_warp_id, my_lane_id); + + typename MM0::Mma::FragmentC accum; + + accum.clear(); + + auto gemm_k_iterations = + (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + __syncthreads(); + + if (kPreloadV) { + prologueV(0); + } else { + MM1::Mma::drain_cp_asyncs(); + } + + typename MM0::Mma::Operator::IteratorC::TensorCoord iteratorC_tile_offset = { + (tb_tile_offset.m() * MM0::Mma::WarpCount::kM) + + (my_warp_id % MM0::Mma::WarpCount::kM), + (tb_tile_offset.n() * MM0::Mma::WarpCount::kN) + + (my_warp_id / MM0::Mma::WarpCount::kM)}; + + // multiply by scaling factor + // if (kSupportsBias) { + // accum = + // cutlass::multiplies()(p.scale, + // accum); + // } + + if (kSupportsBias) { + cutlass::TensorRef bias_tensor_ref( + shared_storage.after_mm0.bias.data(), + cutlass::layout::RowMajor(MM0::ThreadblockShape::kN)); + using Shape = + cutlass::MatrixShape; + AttentionBiasEpilogue + bias_epilogue; + bias_epilogue(bias_tensor_ref, + p.bias1_ptr + iter_key_start, + p.bias2_ptr + query_start * p.num_keys + iter_key_start, + thread_id(), + {problem_size_0_m, problem_size_0_n}, + p.num_keys); + // Pij += Bij, Pij is in register fragment and Bij is in shared memory + auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset( + lane_id(), warp_id(), iteratorC_tile_offset); + MM0::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (accum_m < problem_size_0_m && accum_n < problem_size_0_n) { + accum[idx] = + accum[idx] * p.scale + bias_tensor_ref.at({accum_m, accum_n}); + } + }, + [&](int accum_m) {}); + } + + DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] { + DISPATCH_BOOL( + p.num_keys - iter_key_start >= kKeysPerBlock, kFullColumns, ([&] { + // Update `mi` from accum stored in registers + // Also does accum[i] <- exp(accum[i] - mi) + iterative_softmax(accum_o, + accum, + mi, + m_prime, + s_prime, + lane_id(), + thread_id(), + warp_id(), + p.num_keys - iter_key_start, + iteratorC_tile_offset, + kSupportsBias ? 1.0f : p.scale); + })); + })); + + // Output results to shared-memory + int warp_idx_mn_0 = + my_warp_id % (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN); + auto output_tile_coords = + cutlass::MatrixCoord{warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM, + warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM}; + + MM0::B2bGemm::accumToSmem( + shared_storage.after_mm0.si, accum, my_lane_id, output_tile_coords); + + __syncthreads(); + + // + // MATMUL: Attn . V + // Run the matmul `attn @ V` for a block of attn and V. + // `attn` is read from shared memory (in `shared_storage_si`) + // `V` is read from global memory (with iterator_B) + // + + const int64_t nBlockN = + kSingleValueIteration + ? 
1 + : ceil_div((int64_t)problem_size_1_n, int64_t(MM1::ThreadblockShape::kN)); + for (int blockN = 0; blockN < nBlockN; ++blockN) { + int gemm_k_iterations = + (problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add and store it in accum + // (in registers) + if (!kPreloadV) { + __syncthreads(); // we share shmem between mma and epilogue + } + + typename MM1::Mma::IteratorB iterator_V( + typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)}, + p.value_ptr + iter_key_start * p.v_strideM, + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + typename MM1::Mma mma_pv(shared_storage.after_mm0.mm1.mm, + shared_storage.after_mm0.si, + (int)thread_id(), + (int)warp_id(), + (int)lane_id(), + (int)problem_size_1_k); + mma_pv.set_prologue_done(kPreloadV); + if (!kKeepOutputInRF) { accum_o.clear(); } + mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o); + __syncthreads(); + + if (kPreloadV && !kSingleValueIteration && blockN + 1 < nBlockN) { + prologueV(blockN + 1); + } + + if (!kKeepOutputInRF) { + MM1::Mma::drain_cp_asyncs(); + DISPATCH_BOOL( + iter_key_start == 0, kIsFirst, ([&] { + DISPATCH_BOOL( + (iter_key_start + kKeysPerBlock) >= p.num_keys, kIsLast, ([&] { + using DefaultEpilogue = typename MM1::DefaultEpilogue; + using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp; + using ElementCompute = typename DefaultOp::ElementCompute; + using EpilogueOutputOp = typename cutlass::epilogue::thread:: + MemoryEfficientAttentionNormalize< + typename cutlass::platform:: + conditional:: + type, + output_accum_t, + DefaultOp::kCount, + typename DefaultOp::ElementAccumulator, + ElementCompute, + kIsFirst, + kIsLast, + cutlass::Array>; + using Epilogue = + typename cutlass::epilogue::threadblock::EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename MM1::Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename cutlass::platform::conditional< + kIsLast, + typename MM1::OutputTileIterator, + typename MM1::OutputTileIteratorAccum>::type, + typename DefaultEpilogue::AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true, // IterationsUnroll + typename MM1::OutputTileIteratorAccum // Read + // iterator + >; + + int col = blockN * MM1::Mma::Shape::kN; + auto source_iter = createOutputAccumIter(col); + auto dest_iter = + call_conditional:: + apply(createOutputIter, createOutputAccumIter, col); + EpilogueOutputOp rescale(s_prime, m_prime); + Epilogue epilogue(shared_storage.epilogue_shared_storage(), + thread_id(), + warp_id(), + lane_id()); + epilogue(rescale, dest_iter, accum_o, source_iter); + })); + })); + if (!kSingleValueIteration) { __syncthreads(); } + } + } + __syncthreads(); // we modify `m_prime` after + } + + if (kKeepOutputInRF) { + constexpr bool kIsFirst = true; + constexpr bool kIsLast = true; + using DefaultEpilogue = typename MM1::DefaultEpilogue; + using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp; + using ElementCompute = typename DefaultOp::ElementCompute; + using EpilogueOutputOp = + typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize< + output_t, // output + output_accum_t, // source + DefaultOp::kCount, + typename DefaultOp::ElementAccumulator, // accum + output_accum_t, // compute + kIsFirst, + kIsLast, + cutlass::Array>; + using 
Epilogue = typename cutlass::epilogue::threadblock::EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename MM1::Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename MM1::OutputTileIterator, // destination + typename DefaultEpilogue::AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true, // IterationsUnroll + typename MM1::OutputTileIteratorAccum // source tile + >; + auto dest_iter = createOutputIter(0); + EpilogueOutputOp rescale(s_prime, m_prime); + Epilogue epilogue( + shared_storage.epilogue_shared_storage(), thread_id(), warp_id(), lane_id()); + MM1::Mma::drain_cp_asyncs(); + epilogue(rescale, dest_iter, accum_o); + } + + // 7. Calculate logsumexp + // To make the backward easier, we pad logsumexp with `inf` + // this avoids a few bound checks, and is not more expensive during fwd + static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); + if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) { + auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE; + if (thread_id() < p.num_queries) { + p.logsumexp_ptr[thread_id()] = + accum_t(mi[thread_id()]) + cutlass::fast_log(accum_t(s_prime[thread_id()])); + } else if (thread_id() < lse_dim) { + p.logsumexp_ptr[thread_id()] = + cutlass::platform::numeric_limits::infinity(); + } + } + } + + template + CUTLASS_DEVICE static void iterative_softmax( + typename WarpIteratorC::Fragment& frag_o, // output so far + typename WarpIteratorC::Fragment& frag, + cutlass::Array& mi, + cutlass::Array& m_prime, + cutlass::Array& s_prime, + int8_t lane_id, + int8_t thread_id, + int8_t warp_id, + int16_t max_col, + typename WarpIteratorC::TensorCoord const& tile_offset, + float scaling) + { + /* Iterates on the accumulator and corresponding position on result matrix + + (1) Update `mi[r]` to the max value of the row `r` + (2) In a second iteration do the following: + (a) accum <- exp(accum - mi) + (b) m_prime <- exp(m_prime - mi) + (c) s_prime <- s_prime * m_prime + sum(accum) + + All of this is done on registers, before we store all of this + on shared memory for the next matmul with Value. + */ + using Fragment = typename WarpIteratorC::Fragment; + using LambdaIterator = + typename DefaultMmaAccumLambdaIterator::Iterator; + // Convert to `accum_t` (rather than double) + constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E + if (!kIsFirst) { + if (thread_id < kQueriesPerBlock) { m_prime[thread_id] = mi[thread_id]; } + __syncthreads(); + } + + auto lane_offset = LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset); + + // First update `mi` to the max per-row + { + accum_t max; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { max = -cutlass::platform::numeric_limits::infinity(); }, + [&](int accum_m, int accum_n, int idx) { + if (kFullColumns || accum_n < max_col) { + max = cutlass::fast_max(max, frag[idx]); + } + }, + [&](int accum_m) { + // Having 4x atomicMax seems faster than reduce within warp + // first... 
+ atomicMaxFloat(&mi[accum_m], max * scaling); + }); + } + frag = cutlass::multiplies()(scaling * kLog2e, frag); + + // Make sure we all share the update values for `mi` + __syncthreads(); + + if (thread_id < kQueriesPerBlock) { + auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id])); + m_prime[thread_id] = m_prime_exp; + s_prime[thread_id] *= m_prime_exp; + } + __syncthreads(); // Update output fragments + if (kKeepOutputInRF && !kIsFirst) { + accum_t mp; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { mp = m_prime[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; }, + [&](int accum_m) {}); + __syncthreads(); + } + // Update accum_m, accum_n, ... + { + accum_t mi_row, total_row; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { mi_row = kLog2e * mi[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + frag[idx] = (kFullColumns || accum_n < max_col) ? exp2f(frag[idx] - mi_row) + : accum_t(0.0); + }, + [&](int accum_m) {}); + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { total_row = 0.0; }, + [&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; }, + [&](int accum_m) { + if (LambdaIterator::reduceSameRow( + lane_id, total_row, [](accum_t a, accum_t b) { return a + b; })) { + atomicAdd(&s_prime[accum_m], total_row); + } + }); + } + } + + static CUTLASS_DEVICE int8_t lane_id() { return threadIdx.x; } + static CUTLASS_DEVICE int8_t warp_id() { return threadIdx.y; } + static CUTLASS_DEVICE int16_t thread_id() { return threadIdx.x + threadIdx.y * blockDim.x; } +}; + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_batched_impl(typename AK::Params p) +{ + if (!p.advance_to_block()) { return; } + AK::attention_kernel(p); +} + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_batched(typename AK::Params params); diff --git a/csrc/deepspeed4science/evoformer_attn/transform/bias_broadcast.h b/csrc/deepspeed4science/evoformer_attn/transform/bias_broadcast.h new file mode 100644 index 000000000000..4b8fd8a293d6 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/transform/bias_broadcast.h @@ -0,0 +1,148 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +// This does nothing. +template +struct BroadcastNoLoad { + using Fragment = + cutlass::Array; + static const bool kEnable = false; + CUTLASS_DEVICE static void load(Fragment& frag, + scalar_t* ptr, + int thread_id, + const cutlass::MatrixCoord& extent, + int stride) + { + } + CUTLASS_DEVICE static scalar_t* + advance(scalar_t* ptr, int B_id, int N_id, int H_id, int strideB, int strideN, int strideH) + { + return ptr; + } +}; + +// This is to load the bias matrix from the global memory with on-the-fly +// broadcast. The shape in global memory is [B, N, 1, 1, L]. 
Each time we load +// the last dimension as a L row vector, and we further broadcast the L vector +// to a tile of size [L, L] by repeating the L vector L times +template +struct BroadcastA : public BroadcastNoLoad { + using Base = BroadcastNoLoad; + static const bool kEnable = true; + using layout = cutlass::layout::AffineRank2RowMajor; + + using GmemTileIterator = cutlass::transform::threadblock:: + PredicatedTileIterator; + using Fragment = typename GmemTileIterator::Fragment; + + CUTLASS_DEVICE static void load(Fragment& frag, + scalar_t* ptr, + int thread_id, + const cutlass::MatrixCoord& extent, + int stride) + { + GmemTileIterator iter({layout(0, 1)}, ptr, extent, thread_id); + iter.load(frag); + } + + CUTLASS_DEVICE static scalar_t* + advance(scalar_t* ptr, int B_id, int N_id, int H_id, int strideB, int strideN, int strideH) + { + return ptr + B_id * strideB + N_id * strideN; + } +}; + +// This is to load the bias matrix from the global memory with on-the-fly +// broadcast. The shape in global memory is [B, 1, H, L, L]. Each time we load +// a [L, L] matrix. Different N use the same bias matrix when B and H are the +// same. +template +struct BroadcastB : public BroadcastNoLoad { + using Base = BroadcastNoLoad; + static const bool kEnable = true; + using layout = cutlass::layout::RowMajor; + + using GmemTileIterator = cutlass::transform::threadblock:: + PredicatedTileIterator; + using Fragment = typename GmemTileIterator::Fragment; + + CUTLASS_DEVICE static void load(Fragment& frag, + scalar_t* ptr, + int thread_id, + const cutlass::MatrixCoord& extent, + int stride) + { + GmemTileIterator iter({layout(stride)}, ptr, extent, thread_id); + iter.load(frag); + } + + CUTLASS_DEVICE static scalar_t* + advance(scalar_t* ptr, int B_id, int N_id, int H_id, int strideB, int strideN, int strideH) + { + return ptr + B_id * strideB + H_id * strideH; + } +}; + +template + class Broadcast1_, + template + class Broadcast2_> +struct AttentionBiasEpilogue { + using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< + cutlass::layout::PitchLinearShape, + kThreads, + 128 / cutlass::sizeof_bits::value>; + + using Broadcast1 = Broadcast1_; + using Broadcast2 = Broadcast2_; + + Broadcast1 broadcast1; + Broadcast2 broadcast2; + + using Ref = cutlass::TensorRef; + using SmemTileIterator = cutlass::transform::threadblock:: + RegularTileIterator; + + CUTLASS_DEVICE void operator()(const Ref& ref, + scalar_t* ptr1, + scalar_t* ptr2, + int thread_id, + const cutlass::MatrixCoord& extent, + int stride) + { + static_assert(Broadcast1::Fragment::kElements == Broadcast2::Fragment::kElements, + "The two broadcast fragments must have the same number of " + "elements"); + typename SmemTileIterator::Fragment frag; + frag.clear(); + float* frag_ptr = reinterpret_cast(&frag); + if (Broadcast1::kEnable) { + typename Broadcast1::Fragment frag1; + frag1.clear(); + broadcast1.load(frag1, ptr1, thread_id, extent, stride); + scalar_t* frag1_ptr = reinterpret_cast(&frag1); + for (int i = 0; i < Broadcast1::Fragment::kElements; ++i) { + frag_ptr[i] += static_cast(frag1_ptr[i]); + } + } + if (Broadcast2::kEnable) { + typename Broadcast2::Fragment frag2; + frag2.clear(); + broadcast2.load(frag2, ptr2, thread_id, extent, stride); + scalar_t* frag2_ptr = reinterpret_cast(&frag2); + for (int i = 0; i < Broadcast2::Fragment::kElements; ++i) { + frag_ptr[i] += static_cast(frag2_ptr[i]); + } + } + SmemTileIterator iter(ref, thread_id); + iter.store(frag); + __syncthreads(); + } +}; diff --git 
a/csrc/deepspeed4science/evoformer_attn/transform/tile_smem_loader.h b/csrc/deepspeed4science/evoformer_attn/transform/tile_smem_loader.h new file mode 100644 index 000000000000..5f3f0ee507e5 --- /dev/null +++ b/csrc/deepspeed4science/evoformer_attn/transform/tile_smem_loader.h @@ -0,0 +1,93 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Copyright (c) Microsoft Corporation. 
+// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once +#include +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/numeric_types.h" +#include "cutlass/platform/platform.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +template // thread access width in elements +class TileSmemLoader { +public: + using Shape = ThreadblockTileShape; + using SmemTile = cutlass::AlignedBuffer; + + using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< + cutlass::layout::PitchLinearShape, // strided + Threads, // Threads + ElementsPerAccess>; // ElementsPerAccess + + using GmemTileIterator = cutlass::transform::threadblock::PredicatedTileIterator< + ThreadblockTileShape, // Shape + scalar_t, // Element + cutlass::layout::RowMajor, // Layout + 0, // AdvanceRank + ThreadMap>; // ThreadMap + + using SmemTileIterator = + cutlass::transform::threadblock::RegularTileIterator; // ThreadMap + + using Fragment = typename GmemTileIterator::Fragment; + + /// load a tile from global memory into shared memory + CUTLASS_DEVICE + static void load(GmemTileIterator tile_load_iter, SmemTileIterator tile_store_iter) + { + Fragment tb_frag; + tb_frag.clear(); + tile_load_iter.load(tb_frag); + tile_store_iter.store(tb_frag); + + __syncthreads(); + } +}; diff --git a/deepspeed/ops/deepspeed4science/__init__.py b/deepspeed/ops/deepspeed4science/__init__.py new file mode 100644 index 000000000000..1c5fd280fc32 --- /dev/null +++ b/deepspeed/ops/deepspeed4science/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .evoformer_attn import DS4Sci_EvoformerAttention, EvoformerFusedAttention diff --git a/deepspeed/ops/deepspeed4science/evoformer_attn.py b/deepspeed/ops/deepspeed4science/evoformer_attn.py new file mode 100644 index 000000000000..ba7e20e51d50 --- /dev/null +++ b/deepspeed/ops/deepspeed4science/evoformer_attn.py @@ -0,0 +1,102 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import numpy as np +from deepspeed.ops.op_builder import EvoformerAttnBuilder +from deepspeed.accelerator import get_accelerator + +kernel_ = None + + +def _attention(Q, K, V, bias1, bias2): + assert Q.shape[-3] > 16, "seq_len must be greater than 16" + O = torch.empty_like(Q, dtype=Q.dtype) + assert get_accelerator().on_accelerator(Q), "Q must be on cuda" + assert get_accelerator().on_accelerator(K), "K must be on cuda" + assert get_accelerator().on_accelerator(V), "V must be on cuda" + assert get_accelerator().on_accelerator(bias1), "bias1 must be on cuda" + assert get_accelerator().on_accelerator(bias2), "bias2 must be on cuda" + global kernel_ + if kernel_ is None: + kernel_ = EvoformerAttnBuilder().load() + nheads = Q.shape[-2] + nq = (Q.shape[-3] + 31) // 32 * 32 + nb = np.prod(Q.shape[:-3]) + lse = torch.empty((nb, nheads, nq), dtype=torch.float32, device=Q.device) + kernel_.attention(Q, K, V, bias1, bias2, O, lse) + return O, lse + + +def attention_bwd(dO, Q, K, V, O, lse, bias1, bias2): + assert max(Q.shape[-1], V.shape[-1]) <= 64, "Hidden size is too large. 
Need to change kMax to a larger value" + dQ = torch.empty_like(Q, dtype=Q.dtype) + dK = torch.empty_like(K, dtype=K.dtype) + dV = torch.empty_like(V, dtype=V.dtype) + assert get_accelerator().on_accelerator(dO), "dO must be on cuda" + assert get_accelerator().on_accelerator(Q), "Q must be on cuda" + assert get_accelerator().on_accelerator(K), "K must be on cuda" + assert get_accelerator().on_accelerator(V), "V must be on cuda" + assert get_accelerator().on_accelerator(O), "O must be on cuda" + global kernel_ + if kernel_ is None: + kernel_ = EvoformerAttnBuilder().load() + delta = torch.empty_like(lse) + dB1 = torch.zeros_like(bias1, dtype=torch.float32) + dB2 = torch.zeros_like(bias2, dtype=torch.float32) + kernel_.attention_bwd(dO, Q, K, V, O, lse, delta, bias1, bias2, dQ, dK, dV, dB1, dB2) + return dQ, dK, dV, dB1.to(dO.dtype), dB2.to(dO.dtype) + + +class EvoformerFusedAttention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, bias1=None, bias2=None): + """ + q, k, v: are in shape [*, L, H, D] + """ + bias1_ = bias1.contiguous() if bias1 is not None else torch.tensor([], dtype=q.dtype, device=q.device) + bias2_ = bias2.contiguous() if bias2 is not None else torch.tensor([], dtype=q.dtype, device=q.device) + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + o, lse = _attention(q, k, v, bias1_, bias2_) + ctx.save_for_backward(q, k, v, o, lse, bias1_, bias2_) + return o + + @staticmethod + def backward(ctx, grad_output): + q, k, v, o, lse, bias1, bias2 = ctx.saved_tensors + dQ, dK, dV, dB1, dB2 = attention_bwd(grad_output, q, k, v, o, lse, bias1, bias2) + if bias1.numel() == 0: + dB1 = None + if bias2.numel() == 0: + dB2 = None + return dQ, dK, dV, dB1, dB2 + + +def DS4Sci_EvoformerAttention(Q, K, V, biases): + assert len(biases) <= 2 + + if (len(biases) == 0): + biases.append(None) + + if (len(biases) == 1): + biases.append(None) + + bias_1_shape = lambda x: (x.shape[0], x.shape[1], 1, 1, x.shape[2]) + bias_2_shape = lambda x: (x.shape[0], 1, x.shape[3], x.shape[2], x.shape[2]) + + if biases[0] is not None: + assert biases[0].shape == bias_1_shape(Q) + else: + biases[0] = Q.new_zeros(bias_1_shape(Q)) + + if biases[1] is not None: + assert biases[1].shape == bias_2_shape(Q) + else: + biases[1] = Q.new_zeros(bias_2_shape(Q)) + + return EvoformerFusedAttention.apply(Q, K, V, biases[0], biases[1]) diff --git a/docs/_config.yml b/docs/_config.yml index 7127b8459fe2..ac8d9028e58f 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -41,6 +41,7 @@ collections: - cifar-10.md - curriculum-learning.md - data-efficiency.md + - ds4sci_evoformerattention.md - flops-profiler.md - pytorch-profiler.md - autotuning.md diff --git a/docs/_data/navigation.yml b/docs/_data/navigation.yml index d0587867c260..217d56c14812 100755 --- a/docs/_data/navigation.yml +++ b/docs/_data/navigation.yml @@ -17,6 +17,8 @@ lnav: url: /inference/ - title: 'Compression' url: /compression/ + - title: 'Science' + url: /deepspeed4science/ - title: 'Getting Started' url: /getting-started/ - title: 'ds_config' @@ -67,6 +69,8 @@ lnav: url: /tutorials/curriculum-learning/ - title: 'Data Efficiency' url: /tutorials/data-efficiency/ + - title: 'DS4Sci_EvoformerAttention' + url: /tutorials/ds4sci_evoformerattention/ - title: 'Flops Profiler' url: /tutorials/flops-profiler/ - title: 'PyTorch Profiler' diff --git a/docs/_pages/deepspeed4science.md b/docs/_pages/deepspeed4science.md new file mode 100755 index 000000000000..aa11d6f9eaf1 --- /dev/null +++ b/docs/_pages/deepspeed4science.md @@ -0,0 +1,39 
@@ +--- +title: "DeepSpeed4Science Overview and Tutorial" +permalink: /deepspeed4science/ +toc: true +toc_label: "Contents" +toc_sticky: true +--- + +In line with Microsoft's mission to solve humanity's most pressing challenges, the DeepSpeed team at Microsoft is responding to this opportunity by launching a new initiative called *DeepSpeed4Science*, aiming to build unique capabilities through AI system technology innovations to help domain experts to unlock today's biggest science mysteries. This page serves as an overview page for all technologies released (or to be released in the future) as part of the DeepSpeed4Science initiative, making it easier for scientists to shop for techniques they need. Details of the DeepSpeed4Science initiative can be found at [our website](https://deepspeed4science.ai/). For each technique we will introduce what is it for, when to use it, links to how to use it, and existing scientific applications of the techniques (we welcome users to contribute more showcases if you apply our techniques in your scientific research): + +* [2023/09] We are releasing two techniques: [DeepSpeed4Science large-scale training framework](#new-megatron-deepspeed-for-large-scale-ai4science-model-training), [DS4Sci_EvoformerAttention](#memory-efficient-evoformerattention-kernels) and their scientific applications in structural biology research. + + +## New Megatron-DeepSpeed for Large-Scale AI4Science Model Training + +We are proud to introduce [new Megatron-DeepSpeed](https://github.com/microsoft/Megatron-DeepSpeed), which is an updated framework for large-scale model training. We rebased and enabled DeepSpeed with the newest Megatron-LM for long sequence support and many other capabilities. With the new Megatron-DeepSpeed, users can now train their large AI4Science models like GenSLMS with much longer sequences via a synergetic combination of ZeRO-style data parallelism, tensor parallelism, sequence parallelism, pipeline parallelism, model state offloading, and several newly added memory optimization techniques such as attention mask offloading and position embedding partitoining. + +![new Megatron-DeepSpeed](/assets/images/new-megatron-ds.png){: .align-center} +

+The figure illustrates the system capability, in terms of supported sequence lengths, for training a 33B-parameter GPT-like model using our new Megatron-DeepSpeed framework. The results show that the new Megatron-DeepSpeed enables 9x longer sequence lengths than NVIDIA's Megatron-LM without triggering out-of-memory errors. +

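For readers who want to see what the ZeRO and offloading pieces of this recipe look like in practice, here is a minimal, hypothetical configuration sketch. It is illustrative only and not part of this PR: the dummy `model` and the exact config values are placeholders, and the tensor/sequence/pipeline parallel settings are passed as Megatron-DeepSpeed launcher arguments rather than through this config.

```python
import torch
import deepspeed

# Hypothetical minimal DeepSpeed config: ZeRO stage 3 with parameter and
# optimizer state offloading to CPU, one ingredient of the memory savings
# described above.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "bf16": {"enabled": True},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
    "zero_optimization": {
        "stage": 3,
        "offload_param": {"device": "cpu"},
        "offload_optimizer": {"device": "cpu"},
    },
}

model = torch.nn.Linear(1024, 1024)  # placeholder for a real Megatron-DeepSpeed model

# Run under the `deepspeed` launcher so distributed initialization happens.
engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config=ds_config)
```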
+ +To see how the new Megatron-DeepSpeed helps enable new system capabilities, such as training models with massive sequence lengths, please read our [tutorial](https://github.com/microsoft/Megatron-DeepSpeed/tree/main/examples_deepspeed/deepspeed4science/megatron_long_seq_support). + +Meanwhile, our new Megatron-DeepSpeed has been applied to [GenSLMs](https://github.com/ramanathanlab/genslm), a 2022 [ACM Gordon Bell award](https://www.acm.org/media-center/2022/november/gordon-bell-special-prize-covid-research-2022)-winning genome-scale language model from Argonne National Lab. To achieve their scientific goals, GenSLMs and similar models require very long sequence support for both training and inference, beyond what generic LLM long-sequence strategies provide. By leveraging DeepSpeed4Science's new Megatron-DeepSpeed, the GenSLMs team is able to train their 25B model with a 512K sequence length, much longer than their original 42K sequence length. Detailed information about this application can be found at [our website](https://deepspeed4science.ai/). The GenSLMs team also hosts an [example](https://github.com/ramanathanlab/genslm/tree/main/examples/long-sequences) of how to use DeepSpeed4Science in the GenSLMs repo. + + +## Memory-Efficient EvoformerAttention Kernels + +[Evoformer](https://www.nature.com/articles/s41586-021-03819-2) is a key building block for scientific models such as DeepMind's AlphaFold. However, Evoformer's multiple sequence alignment (MSA) attention frequently runs into memory explosion problems during training and inference, for example in protein structure prediction models. Existing techniques such as FlashAttention cannot effectively support Evoformer because EvoformerAttention uses row-wise, column-wise, and triangle attention, which differ from standard Transformer self-attention and cross-attention and therefore require custom optimizations. To mitigate the memory explosion problem, we introduce `DS4Sci_EvoformerAttention`, a collection of kernels that improve the memory efficiency of Evoformer variants. `DS4Sci_EvoformerAttention` is easy to use. To see how you can use it, please refer to our [tutorial](/tutorials/ds4sci_evoformerattention/); a minimal usage sketch is also shown after the figure below. + +`DS4Sci_EvoformerAttention` has already been applied to [OpenFold](https://github.com/aqlaboratory/openfold), a community reproduction of DeepMind's AlphaFold2 that makes it possible to train or finetune AlphaFold2 on new datasets. With the DS4Sci_EvoformerAttention kernels, the OpenFold team is able to reduce the peak memory requirement by 13x without accuracy loss. Detailed information about this application can be found at [our website](https://deepspeed4science.ai/). + + + +![DS4Sci_EvoformerAttention](/assets/images/evoformer.png){: .align-center} +

+The figure shows that DeepSpeed's DS4Sci_EvoformerAttention kernels help reduce OpenFold's peak memory requirement for training by 13x. +

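For readers skimming this overview, below is a minimal usage sketch of the MSA row-wise variant, distilled from the tutorial added in this PR. The tensor sizes are arbitrary examples; as the tutorial notes, the inputs must be `torch.float16` or `torch.bfloat16`, `CUTLASS_PATH` must point to a CUTLASS checkout, and the kernels are JIT-compiled on the first call.

```python
import torch
from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention

batch, n_seq, n_res, heads, dim = 1, 64, 128, 4, 32  # arbitrary example sizes

Q = torch.randn(batch, n_seq, n_res, heads, dim, dtype=torch.float16, device="cuda")
K = torch.randn(batch, n_seq, n_res, heads, dim, dtype=torch.float16, device="cuda")
V = torch.randn(batch, n_seq, n_res, heads, dim, dtype=torch.float16, device="cuda")

# MSA row-wise attention takes two additive biases: a residue mask and a bias
# derived from the pair representation. Zeros mean nothing is masked out;
# real models typically place large negative values at masked positions.
res_mask = torch.zeros(batch, n_seq, 1, 1, n_res, dtype=torch.float16, device="cuda")
pair_bias = torch.randn(batch, 1, heads, n_res, n_res, dtype=torch.float16, device="cuda")

out = DS4Sci_EvoformerAttention(Q, K, V, [res_mask, pair_bias])
print(out.shape)  # [batch, n_seq, n_res, heads, dim]
```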
+ diff --git a/docs/_tutorials/ds4sci_evoformerattention.md b/docs/_tutorials/ds4sci_evoformerattention.md new file mode 100644 index 000000000000..a951943dfa5b --- /dev/null +++ b/docs/_tutorials/ds4sci_evoformerattention.md @@ -0,0 +1,74 @@ +--- +title: "DS4Sci_EvoformerAttention eliminates memory explosion problems for scaling Evoformer-centric structural biology models" +tags: training inference +--- + +## 1. What is DS4Sci_EvoformerAttention +`DS4Sci_EvoformerAttention` is a collection of kernels built to scale the [Evoformer](https://www.nature.com/articles/s41586-021-03819-2) computation to a larger number of sequences and residues by reducing the memory footprint and increasing the training speed. + +## 2. When to use DS4Sci_EvoformerAttention +`DS4Sci_EvoformerAttention` is most beneficial when the number of sequences and residues is large. The forward kernel is optimized to accelerate computation, so it is beneficial to use it during inference for the various attention mechanisms. The associated backward kernel can be used during training to reduce the memory footprint at the cost of some extra computation. It is therefore beneficial to use `DS4Sci_EvoformerAttention` in training for memory-constrained operations such as MSA row-wise attention and MSA column-wise attention. + +## 3. How to use DS4Sci_EvoformerAttention + +### 3.1 Installation + +`DS4Sci_EvoformerAttention` is released as part of DeepSpeed >= 0.10.3. `DS4Sci_EvoformerAttention` is implemented using [CUTLASS](https://github.com/NVIDIA/cutlass). You need to clone the CUTLASS repository and specify its path in the environment variable `CUTLASS_PATH`. + +```shell +git clone https://github.com/NVIDIA/cutlass +export CUTLASS_PATH=/path/to/cutlass +``` +The kernels will be compiled when `DS4Sci_EvoformerAttention` is called for the first time. + +`DS4Sci_EvoformerAttention` requires GPUs with compute capability 7.0 or higher (NVIDIA V100 or later) and CUDA 11.3 or newer. CUDA 11.7 or later is recommended for better performance. Note that the performance of the backward kernel on V100 is currently not as good as on A100. + +### 3.2 Unit test and benchmark + +The unit test and benchmark are available in the `tests` folder of the DeepSpeed repo. You can use the following commands to run them. + +```shell +pytest -s tests/unit/ops/deepspeed4science/test_DS4Sci_EvoformerAttention.py +python tests/benchmarks/DS4Sci_EvoformerAttention_bench.py +``` + +### 3.3 Applying DS4Sci_EvoformerAttention to your own model + +To use `DS4Sci_EvoformerAttention` in your own model, import `DS4Sci_EvoformerAttention` from `deepspeed.ops.deepspeed4science`. + +```python +from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention +``` + +`DS4Sci_EvoformerAttention` supports four attention mechanisms in Evoformer (MSA row-wise, MSA column-wise, and two kinds of triangular attention) by using different inputs, as shown in the following examples. In the examples, we denote the number of sequences as `N_seq` and the number of residues as `N_res`. The hidden dimension `Dim` and the number of heads `Head` differ among the attention mechanisms. Note that `DS4Sci_EvoformerAttention` requires the input tensors to be in `torch.float16` or `torch.bfloat16` data type. + +(a) **MSA row-wise attention** builds attention weights for residue pairs and integrates the information from the pair representation as an additional bias term. 
+```python +# Q, K, V: [Batch, N_seq, N_res, Head, Dim] +# res_mask: [Batch, N_seq, 1, 1, N_res] +# pair_bias: [Batch, 1, Head, N_res, N_res] +out = DS4Sci_EvoformerAttention(Q, K, V, [res_mask, pair_bias]) +``` + +(b) **MSA column-wise attention** lets the elements that belong to the same target residue exchange information. +```python +# Q, K, V: [Batch, N_res, N_seq, Head, Dim] +# res_mask: [Batch, N_seq, 1, 1, N_res] +out = DS4Sci_EvoformerAttention(Q, K, V, [res_mask]) +``` + +(c) **Triangular self-attention** updates the pair representation. There are two kinds of triangular self-attention: around the starting node and around the ending node. Below is an example of triangular self-attention around the starting node; the version around the ending node is similar. +```python +# Q, K, V: [Batch, N_res, N_res, Head, Dim] +# res_mask: [Batch, N_res, 1, 1, N_res] +# right_edges: [Batch, 1, Head, N_res, N_res] +out = DS4Sci_EvoformerAttention(Q, K, V, [res_mask, right_edges]) +``` + +## 4. DS4Sci_EvoformerAttention scientific application + +### 4.1 DS4Sci_EvoformerAttention eliminates memory explosion problems for scaling Evoformer-centric structural biology models in OpenFold + +[OpenFold](https://github.com/aqlaboratory/openfold) is a community reproduction of DeepMind's AlphaFold2 that makes it possible to train or finetune AlphaFold2 on new datasets. Training AlphaFold2 incurs a memory explosion problem because it contains several custom Evoformer attention variants that manifest unusually large activations. By leveraging DeepSpeed4Science's DS4Sci_EvoformerAttention kernels, the OpenFold team is able to reduce the peak memory requirement by 13x without accuracy loss. Detailed information about this application can be found at [our website](https://deepspeed4science.ai/). + + diff --git a/docs/assets/images/3pillars.png b/docs/assets/images/3pillars.png deleted file mode 100755 index c2943ca912a1..000000000000 Binary files a/docs/assets/images/3pillars.png and /dev/null differ diff --git a/docs/assets/images/DeepSpeed-pillars.png b/docs/assets/images/DeepSpeed-pillars.png new file mode 100644 index 000000000000..e41a02a86058 Binary files /dev/null and b/docs/assets/images/DeepSpeed-pillars.png differ diff --git a/docs/assets/images/evoformer.png b/docs/assets/images/evoformer.png new file mode 100755 index 000000000000..a3da3b18febd Binary files /dev/null and b/docs/assets/images/evoformer.png differ diff --git a/docs/assets/images/new-megatron-ds.png b/docs/assets/images/new-megatron-ds.png new file mode 100755 index 000000000000..a8f408338afe Binary files /dev/null and b/docs/assets/images/new-megatron-ds.png differ diff --git a/docs/index.md b/docs/index.md index b4ae1b84cdea..6454aad0069a 100755 --- a/docs/index.md +++ b/docs/index.md @@ -7,11 +7,11 @@ title: "Latest News" --- DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat). 
+* [2023/09] Announcing the DeepSpeed4Science Initiative: Enabling large-scale scientific discovery through sophisticated AI system technologies [[DeepSpeed4Science website](https://deepspeed4science.ai/)] [[Tutorials](/deepspeed4science/)] * [2023/08] [DeepSpeed ZeRO-Inference: 20X faster inference through weight quantization and KV cache offloading](https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/README.md) * [2023/08] [DeepSpeed-Chat: Llama/Llama-2 system support, efficiency boost, and training stability improvements](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/ds-chat-release-8-31/README.md) * [2023/08] [DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ulysses) * [2023/06] [ZeRO++: A leap in speed for LLM and chat model training with 4X less communication](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/)[[English](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/zeropp/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/zeropp/japanese/README.md)] -* [2023/04] 🚀 [DeepSpeed Chat: Easy, Fast and Affordable RLHF Training of ChatGPT-like Models at All Scales](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/japanese/README.md)]🚀 # Extreme Speed and Scale for DL Training and Inference @@ -24,9 +24,9 @@ title: "Latest News" * Achieve extreme compression for an unparalleled inference latency and model size reduction with low costs -# DeepSpeed has three innovation pillars: +# DeepSpeed has four innovation pillars: -![Three innovation pillars](/assets/images/3pillars.png){: .align-center} +[![Four innovation pillars](/assets/images/DeepSpeed-pillars.png){: .align-center}](https://deepspeed4science.ai/) ## DeepSpeed-Training @@ -41,6 +41,10 @@ DeepSpeed brings together innovations in parallelism technology such as tensor, To further increase the inference efficiency, DeepSpeed offers easy-to-use and flexible-to-compose compression techniques for researchers and practitioners to compress their models while delivering faster speed, smaller model size, and significantly reduced compression cost. Moreover, SoTA innovations on compression like ZeroQuant and XTC are included under the DeepSpeed-Compression pillar. Learn more: [DeepSpeed-Compression](https://www.deepspeed.ai/compression) +## DeepSpeed4Science + +In line with Microsoft's mission to solve humanity's most pressing challenges, the DeepSpeed team at Microsoft is responding to this opportunity by launching a new initiative called *DeepSpeed4Science*, aiming to build unique capabilities through AI system technology innovations to help domain experts to unlock today's biggest science mysteries. 
Learn more: [DeepSpeed4Science website](https://deepspeed4science.ai/) and [tutorials](/deepspeed4science/) + # DeepSpeed Software Suite ## DeepSpeed Library diff --git a/op_builder/evoformer_attn.py b/op_builder/evoformer_attn.py new file mode 100644 index 000000000000..e557ba8ed0c1 --- /dev/null +++ b/op_builder/evoformer_attn.py @@ -0,0 +1,60 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import CUDAOpBuilder, installed_cuda_version +import os + + +class EvoformerAttnBuilder(CUDAOpBuilder): + BUILD_VAR = "DS_BUILD_EVOFORMER_ATTN" + NAME = "evoformer_attn" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + self.cutlass_path = os.environ.get('CUTLASS_PATH') + + def absolute_name(self): + return f'deepspeed.ops.{self.NAME}_op' + + def extra_ldflags(self): + if not self.is_rocm_pytorch(): + return ['-lcurand'] + else: + return [] + + def sources(self): + src_dir = 'csrc/deepspeed4science/evoformer_attn' + return [f'{src_dir}/attention.cpp', f'{src_dir}/attention_back.cu', f'{src_dir}/attention.cu'] + + def is_compatible(self, verbose=True): + try: + import torch + except ImportError: + self.warning("Please install torch if trying to pre-compile kernels") + return False + if self.cutlass_path is None: + self.warning("Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH") + return False + with open(f'{self.cutlass_path}/CHANGELOG.md', 'r') as f: + if '3.1.0' not in f.read(): + self.warning("Please use CUTLASS version >= 3.1.0") + return False + cuda_okay = True + if not self.is_rocm_pytorch() and torch.cuda.is_available(): #ignore-cuda + sys_cuda_major, _ = installed_cuda_version() + torch_cuda_major = int(torch.version.cuda.split('.')[0]) + cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda + if cuda_capability < 7: + self.warning("Please use a GPU with compute capability >= 7.0") + cuda_okay = False + if torch_cuda_major < 11 or sys_cuda_major < 11: + self.warning("Please use CUDA 11+") + cuda_okay = False + return super().is_compatible(verbose) and cuda_okay + + def include_paths(self): + includes = [f'{self.cutlass_path}/include', f'{self.cutlass_path}/tools/util/include'] + return includes diff --git a/tests/benchmarks/DS4Sci_EvoformerAttention_bench.py b/tests/benchmarks/DS4Sci_EvoformerAttention_bench.py new file mode 100644 index 000000000000..f11e69cb4320 --- /dev/null +++ b/tests/benchmarks/DS4Sci_EvoformerAttention_bench.py @@ -0,0 +1,108 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +This script is to test the correctness of the DS4Sci_EvoformerAttention op. +To run the script, +1. Clone the CUTLASS repo. E.g. git clone https://github.com/NVIDIA/cutlass.git +2. Specify the CUTLASS_PATH environment variable. E.g. export CUTLASS_PATH=$(pwd)/cutlass +3. Run the script. E.g. 
python DS4Sci_EvoformerAttention_bench.py +""" + +import contextlib +import torch +from typing import List +from torch.nn import functional as F +from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention +from deepspeed.accelerator import get_accelerator + + +def attention_reference( + q_input: torch.Tensor, # [*, Dim_Q, H, C_hid] + k_input: torch.Tensor, # [*, Dim_Q, H, C_hid] + v_input: torch.Tensor, # [*, Dim_Q, H, C_hid] + biases: List[torch.Tensor], + sm_scale: float) -> torch.Tensor: + # Original shape: [*, Dim_Q, H, C_hid] -> Transpose to: [*, H, Dim_Q, C_hid] + q = q_input.transpose(-2, -3) + k = k_input.transpose(-2, -3) + v = v_input.transpose(-2, -3) + + # Now, q, k, v are in shape: [*, H, Dim_Q, C_hid] + + # Transpose k to shape [*, H, C_hid, Dim_Q] + k_t = k.transpose(-1, -2) + + # Now, q and k_t are in shapes: [*, H, Dim_Q, C_hid] and [*, H, C_hid, Dim_Q] respectively + + # [*, H, Dim_Q, Dim_Q] + a = torch.matmul(q, k_t) * sm_scale + + for b in biases: + a += b + + a = F.softmax(a, dim=-1) + + # Now, a is in shape [*, H, Dim_Q, Dim_Q], v is in shape [*, H, Dim_Q, C_hid] + + # Matmul operation results in [*, H, Dim_Q, C_hid] + a_v = torch.matmul(a, v) + + # [*, Dim_Q, H, C_hid] + o = a_v.transpose(-2, -3) + + return o + + +dtype = torch.float16 + +N = 256 +heads = 4 +dim = 32 +seq_len = 256 + + +@contextlib.contextmanager +def cuda_timer(res_list): + start = get_accelerator().Event(enable_timing=True) + end = get_accelerator().Event(enable_timing=True) + start.record() + yield + end.record() + get_accelerator().synchronize() + res_list.append(start.elapsed_time(end)) + + +def benchmark(): + ours_fw = [] + ours_bw = [] + baseline_fw = [] + baseline_bw = [] + # Sweep batch sizes 1-16; tensors are recreated with the current batch_size. + for batch_size in range(1, 17): + Q = torch.randn(batch_size, N, seq_len, heads, dim, dtype=dtype, device="cuda", requires_grad=True) + K = torch.randn(batch_size, N, seq_len, heads, dim, dtype=dtype, device="cuda", requires_grad=True) + V = torch.randn(batch_size, N, seq_len, heads, dim, dtype=dtype, device="cuda", requires_grad=True) + bias1 = torch.randn(batch_size, N, 1, 1, seq_len, dtype=dtype, device="cuda", requires_grad=True) + bias2 = torch.randn(batch_size, 1, heads, seq_len, seq_len, dtype=dtype, device="cuda", requires_grad=True) + # warm up + DS4Sci_EvoformerAttention(Q, K, V, [bias1, bias2]) + with cuda_timer(ours_fw): + out = DS4Sci_EvoformerAttention(Q, K, V, [bias1, bias2]) + d_out = torch.rand_like(out) + with cuda_timer(ours_bw): + out.backward(d_out) + # warm up + attention_reference(Q, K, V, [bias1, bias2], 1 / (dim**0.5)) + with cuda_timer(baseline_fw): + ref_out = attention_reference(Q, K, V, [bias1, bias2], 1 / (dim**0.5)) + with cuda_timer(baseline_bw): + ref_out.backward(d_out) + + print("batch size\tours (FW)\tbaseline (FW)\tours (BW)\tbaseline (BW)") + for i in range(len(ours_fw)): + print(f"{i+1}\t{ours_fw[i]}\t{baseline_fw[i]}\t{ours_bw[i]}\t{baseline_bw[i]}") + + +benchmark() diff --git a/tests/unit/ops/deepspeed4science/test_DS4Sci_EvoformerAttention.py b/tests/unit/ops/deepspeed4science/test_DS4Sci_EvoformerAttention.py new file mode 100644 index 000000000000..f8cd46e29228 --- /dev/null +++ b/tests/unit/ops/deepspeed4science/test_DS4Sci_EvoformerAttention.py @@ -0,0 +1,110 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import List + +import pytest +import torch +from torch.nn import functional as F +import deepspeed +from deepspeed.ops.op_builder import EvoformerAttnBuilder +from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention +from deepspeed.accelerator import get_accelerator +from unit.util import skip_on_arch + +if not deepspeed.ops.__compatible_ops__[EvoformerAttnBuilder.NAME]: + pytest.skip("DS4Sci_EvoformerAttention ops are not available on this system", allow_module_level=True) + + +def attention_reference( + q_input: torch.Tensor, # [*, Dim_Q, H, C_hid] + k_input: torch.Tensor, # [*, Dim_Q, H, C_hid] + v_input: torch.Tensor, # [*, Dim_Q, H, C_hid] + biases: List[torch.Tensor], + sm_scale: float) -> torch.Tensor: + q = q_input.transpose(-2, -3) + k = k_input.transpose(-2, -3) + v = v_input.transpose(-2, -3) + k_t = k.transpose(-1, -2) + a = torch.matmul(q, k_t) * sm_scale + + for b in biases: + a += b + + a = F.softmax(a, dim=-1) + a_v = torch.matmul(a, v) + o = a_v.transpose(-2, -3) + + return o + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("tensor_shape", [(1, 256, 256, 4, 32), (1, 512, 256, 8, 8)]) +def test_DS4Sci_EvoformerAttention(dtype, tensor_shape): + skip_on_arch(8 if dtype == torch.bfloat16 else 7) + batch, n, seq_len, heads, dim = tensor_shape + Q = torch.randn(batch, + n, + seq_len, + heads, + dim, + dtype=dtype, + device=get_accelerator().device_name(), + requires_grad=True) + K = torch.randn(batch, + n, + seq_len, + heads, + dim, + dtype=dtype, + device=get_accelerator().device_name(), + requires_grad=True) + V = torch.randn(batch, + n, + seq_len, + heads, + dim, + dtype=dtype, + device=get_accelerator().device_name(), + requires_grad=True) + bias1 = torch.randn(batch, + n, + 1, + 1, + seq_len, + dtype=dtype, + device=get_accelerator().device_name(), + requires_grad=True) + bias2 = torch.randn(batch, + 1, + heads, + seq_len, + seq_len, + dtype=dtype, + device=get_accelerator().device_name(), + requires_grad=True) + dummy_out = torch.rand_like(Q, dtype=dtype, device=get_accelerator().device_name()) + ref_out = attention_reference(Q, K, V, [bias1, bias2], 1 / (dim**0.5)) + ref_out.backward(dummy_out) + ref_dv, V.grad = V.grad.clone(), None + ref_dk, K.grad = K.grad.clone(), None + ref_dq, Q.grad = Q.grad.clone(), None + ref_db1, bias1.grad = bias1.grad.clone(), None + ref_db2, bias2.grad = bias2.grad.clone(), None + + out = DS4Sci_EvoformerAttention(Q, K, V, [bias1, bias2]) + out.backward(dummy_out) + dv, v_grad = V.grad.clone(), None + dk, k_grad = K.grad.clone(), None + dq, q_grad = Q.grad.clone(), None + db1, bias1.grad = bias1.grad.clone(), None + db2, bias2.grad = bias2.grad.clone(), None + + assert torch.allclose(ref_out, out, atol=2e-2, rtol=0), f"\n{ref_out} \n {out}" + assert torch.allclose(ref_dv, dv, atol=2e-2, rtol=0), f"\n{ref_dv} \n {dv}" + assert torch.allclose(ref_dk, dk, atol=2e-2, rtol=0), f"\n{ref_dk} \n {dk}" + assert torch.allclose(ref_dq, dq, atol=2e-2, rtol=0), f"\n{ref_dq} \n {dq}" + assert torch.allclose(ref_db1, db1, atol=2e-2, rtol=1e-2), f"{ref_db1} \n {db1}" + assert torch.allclose(ref_db2, db2, atol=2e-2, rtol=1e-2), f"{ref_db2} \n {db2}"
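The new unit test gates itself on `deepspeed.ops.__compatible_ops__[EvoformerAttnBuilder.NAME]`. Downstream code can use the same check to fall back to a plain PyTorch path when the fused kernels cannot be built. The sketch below is illustrative only; the fallback mirrors the `attention_reference` helper used by the test and benchmark above and is not an API added by this PR.

```python
from typing import List

import torch
import torch.nn.functional as F

import deepspeed
from deepspeed.ops.op_builder import EvoformerAttnBuilder


def _reference_attention(q, k, v, biases: List[torch.Tensor]) -> torch.Tensor:
    # Same math as the test's reference: softmax(q @ k^T * scale + biases) @ v,
    # with inputs shaped [*, Len, Heads, Dim] and scale = 1/sqrt(Dim).
    q, k, v = (t.transpose(-2, -3) for t in (q, k, v))
    logits = torch.matmul(q, k.transpose(-1, -2)) * (q.shape[-1]**-0.5)
    for b in biases:
        logits = logits + b
    probs = F.softmax(logits, dim=-1)
    return torch.matmul(probs, v).transpose(-2, -3)


if deepspeed.ops.__compatible_ops__[EvoformerAttnBuilder.NAME]:
    from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention as evoformer_attention
else:
    evoformer_attention = _reference_attention  # slower, memory-hungry fallback
```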