diff --git a/.github/workflows/nv-pre-compile-ops.yml b/.github/workflows/nv-pre-compile-ops.yml
index e03cbb353bd3..ccb6c25e14f7 100644
--- a/.github/workflows/nv-pre-compile-ops.yml
+++ b/.github/workflows/nv-pre-compile-ops.yml
@@ -33,7 +33,7 @@ jobs:
#python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
- name: Compile DeepSpeed Ops
run: |
- TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 pip3 install .
+ TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install .
- name: DS Report
run: |
ds_report
diff --git a/README.md b/README.md
index 4999a485f4ce..459bf6d82433 100755
--- a/README.md
+++ b/README.md
@@ -15,11 +15,11 @@
## Latest News
DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat).
+* [2023/09] Announcing the DeepSpeed4Science Initiative: Enabling large-scale scientific discovery through sophisticated AI system technologies [[DeepSpeed4Science website](https://deepspeed4science.ai/)] [[Tutorials](https://www.deepspeed.ai/deepspeed4science/)]
* [2023/08] [DeepSpeed ZeRO-Inference: 20X faster inference through weight quantization and KV cache offloading](https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/README.md)
* [2023/08] [DeepSpeed-Chat: Llama/Llama-2 system support, efficiency boost, and training stability improvements](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/ds-chat-release-8-31/README.md)
* [2023/08] [DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ulysses)
* [2023/06] [ZeRO++: A leap in speed for LLM and chat model training with 4X less communication](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/)[[English](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/zeropp/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/zeropp/japanese/README.md)]
-* [2023/04] 🚀 [DeepSpeed Chat: Easy, Fast and Affordable RLHF Training of ChatGPT-like Models at All Scales](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/japanese/README.md)]🚀
---
@@ -35,9 +35,9 @@
---
-# DeepSpeed's three innovation pillars
+# DeepSpeed's four innovation pillars
-
+
## DeepSpeed-Training
@@ -53,6 +53,10 @@ DeepSpeed brings together innovations in parallelism technology such as tensor,
To further increase the inference efficiency, DeepSpeed offers easy-to-use and flexible-to-compose compression techniques for researchers and practitioners to compress their models while delivering faster speed, smaller model size, and significantly reduced compression cost. Moreover, SoTA innovations on compression like ZeroQuant and XTC are included under the compression pillar. Learn more: [DeepSpeed-Compression](https://www.deepspeed.ai/compression)
+## DeepSpeed4Science
+
+In line with Microsoft's mission to solve humanity's most pressing challenges, the DeepSpeed team at Microsoft is launching a new initiative called *DeepSpeed4Science*, aiming to build unique capabilities through AI system technology innovations to help domain experts unlock today's biggest science mysteries. Learn more: [DeepSpeed4Science website](https://deepspeed4science.ai/) and [tutorials](https://www.deepspeed.ai/deepspeed4science/)
+
---
# DeepSpeed Software Suite
diff --git a/csrc/deepspeed4science/evoformer_attn/attention.cpp b/csrc/deepspeed4science/evoformer_attn/attention.cpp
new file mode 100644
index 000000000000..ac3364539ff1
--- /dev/null
+++ b/csrc/deepspeed4science/evoformer_attn/attention.cpp
@@ -0,0 +1,62 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#include <torch/extension.h>
+
+void attention_impl(torch::Tensor& q,
+ torch::Tensor& k,
+ torch::Tensor& v,
+ torch::Tensor& bias1,
+ torch::Tensor& bias2,
+ torch::Tensor& o,
+ torch::Tensor& lse);
+void attention(torch::Tensor& q,
+ torch::Tensor& k,
+ torch::Tensor& v,
+ torch::Tensor& bias1,
+ torch::Tensor& bias2,
+ torch::Tensor& o,
+ torch::Tensor& lse)
+{
+ attention_impl(q, k, v, bias1, bias2, o, lse);
+}
+
+void attention_back_impl(torch::Tensor& go,
+ torch::Tensor& q,
+ torch::Tensor& k,
+ torch::Tensor& v,
+ torch::Tensor& o,
+ torch::Tensor& lse,
+ torch::Tensor& delta,
+ torch::Tensor& bias1,
+ torch::Tensor& bias2,
+ torch::Tensor& gq,
+ torch::Tensor& gk,
+ torch::Tensor& gv,
+ torch::Tensor& gb1,
+ torch::Tensor& gb2);
+void attention_bwd(torch::Tensor& go,
+ torch::Tensor& q,
+ torch::Tensor& k,
+ torch::Tensor& v,
+ torch::Tensor& o,
+ torch::Tensor& lse,
+ torch::Tensor& delta,
+ torch::Tensor& bias1,
+ torch::Tensor& bias2,
+ torch::Tensor& gq,
+ torch::Tensor& gk,
+ torch::Tensor& gv,
+ torch::Tensor& gb1,
+ torch::Tensor& gb2)
+{
+ attention_back_impl(go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("attention", &attention, "");
+ m.def("attention_bwd", &attention_bwd, "");
+}
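
Because the bindings above simply forward their arguments to the CUDA implementations, the calling convention is easiest to read from Python. The sketch below is not part of this PR: the loaded-module handle (`ext`, e.g. built with `torch.utils.cpp_extension.load`), the 5-D `[batch, n, seq_len, heads, head_dim]` layout, and the `lse` shape are assumptions inferred from how attention.cu reads `q.size(0)`, `q.size(1)` and the last three dimensions.

```python
# Hedged usage sketch for the bound ops; shapes and names marked below are assumptions.
import torch

def evoformer_attn_forward(ext, q, k, v, bias1=None, bias2=None, need_lse=False):
    """Call the bound `attention` op; empty tensors stand in for absent biases."""
    o = torch.empty_like(q)
    # lse is only consumed by a later backward pass; float32 per the kernel signature.
    lse_shape = (q.shape[0], q.shape[1], q.shape[3], q.shape[2])  # assumed [B, N, H, S]
    lse = torch.empty(lse_shape if need_lse else (0,), dtype=torch.float32, device=q.device)
    empty = torch.empty((0,), dtype=q.dtype, device=q.device)
    ext.attention(q, k, v,
                  bias1 if bias1 is not None else empty,
                  bias2 if bias2 is not None else empty,
                  o, lse)
    return o, lse
```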
diff --git a/csrc/deepspeed4science/evoformer_attn/attention.cu b/csrc/deepspeed4science/evoformer_attn/attention.cu
new file mode 100644
index 000000000000..b409dd2e401d
--- /dev/null
+++ b/csrc/deepspeed4science/evoformer_attn/attention.cu
@@ -0,0 +1,160 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+#include "gemm_kernel_utils.h"
+#include "kernel_forward.h"
+#include "transform/bias_broadcast.h"
+
+template
+ class Broadcast1_,
+ template
+ class Broadcast2_>
+typename std::enable_if::value>::type attention_impl_template(
+ torch::Tensor& q,
+ torch::Tensor& k,
+ torch::Tensor& v,
+ torch::Tensor& bias1,
+ torch::Tensor& bias2,
+ torch::Tensor& o,
+ float* lse_ptr)
+{
+ EVOFORMER_CHECK(false, "Unsupported GPU and data type combination")
+}
+
+template
+ class Broadcast1_,
+ template
+ class Broadcast2_>
+typename std::enable_if::value>::type attention_impl_template(
+ torch::Tensor& q,
+ torch::Tensor& k,
+ torch::Tensor& v,
+ torch::Tensor& bias1,
+ torch::Tensor& bias2,
+ torch::Tensor& o,
+ float* lse_ptr)
+{
+ // Attention definition goes here, replaced with BroadcastType1 and
+ // BroadcastType2
+ using Attention = AttentionKernel;
+
+ static_assert(!Attention::kNeedsOutputAccumulatorBuffer,
+ "This test does not support output accumulator buffer");
+ int head_size = q.size(-1);
+ int head_number = q.size(-2);
+ int seq_length = q.size(-3);
+ auto q_view = q.view({-1, seq_length, head_number, head_size});
+ auto k_view = k.view({-1, seq_length, head_number, head_size});
+ auto v_view = v.view({-1, seq_length, head_number, head_size});
+ auto o_view = o.view({-1, seq_length, head_number, head_size});
+ int batch_size = q_view.size(0);
+ auto q_ptr = reinterpret_cast(q.data_ptr());
+ auto k_ptr = reinterpret_cast(k.data_ptr());
+ auto v_ptr = reinterpret_cast(v.data_ptr());
+ auto o_ptr = reinterpret_cast(o.data_ptr());
+
+ auto bias1_ptr = reinterpret_cast(bias1.data_ptr());
+ auto bias2_ptr = reinterpret_cast(bias2.data_ptr());
+
+ typename Attention::Params p;
+ { // set parameters
+ p.query_ptr = q_ptr;
+ p.key_ptr = k_ptr;
+ p.value_ptr = v_ptr;
+ p.logsumexp_ptr = lse_ptr; // Only needed for bw
+ p.output_accum_ptr = nullptr;
+ p.output_ptr = o_ptr;
+ p.scale = 1.0f / sqrt(float(head_size));
+
+ p.bias1_ptr = bias1_ptr;
+ p.bias2_ptr = bias2_ptr;
+ p.B = q.size(0);
+ p.N = q.size(1);
+
+ p.num_heads = head_number;
+ p.num_batches = batch_size;
+ p.head_dim = head_size;
+ p.head_dim_value = head_size;
+ p.num_queries = seq_length;
+ p.num_keys = seq_length;
+
+ // All tensors are in BMHK shapes
+ p.q_strideH = q_view.stride(-2);
+ p.k_strideH = k_view.stride(-2);
+ p.v_strideH = v_view.stride(-2);
+ p.q_strideM = q_view.stride(-3);
+ p.k_strideM = k_view.stride(-3);
+ p.v_strideM = v_view.stride(-3);
+ p.o_strideM = o_view.stride(-3);
+ p.q_strideB = q_view.stride(-4);
+ p.k_strideB = k_view.stride(-4);
+ p.v_strideB = v_view.stride(-4);
+ }
+
+ constexpr auto kernel_fn = attention_kernel_batched_impl;
+ int smem_bytes = sizeof(typename Attention::SharedStorage);
+ if (smem_bytes > 0xc000) {
+ cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
+ }
+ if (!Attention::check_supported(p)) { throw std::runtime_error("Parameters not supported"); }
+ kernel_fn<<>>(p);
+}
+
+#define CODE(scalar_t, torch_scalar_t) \
+ do { \
+ if (bias1.size(0) == 0 && bias2.size(0) == 0) { \
+ attention_impl_template(q, k, v, bias1, bias2, o, lse_ptr); \
+ } else if (bias1.size(0) == 0) { \
+ attention_impl_template(q, k, v, bias1, bias2, o, lse_ptr); \
+ } else if (bias2.size(0) == 0) { \
+ attention_impl_template(q, k, v, bias1, bias2, o, lse_ptr); \
+ } else { \
+ attention_impl_template( \
+ q, k, v, bias1, bias2, o, lse_ptr); \
+ } \
+ } while (0)
+
+// Function to select and call the correct template based on biases sizes
+void attention_impl(torch::Tensor& q,
+ torch::Tensor& k,
+ torch::Tensor& v,
+ torch::Tensor& bias1,
+ torch::Tensor& bias2,
+ torch::Tensor& o,
+ torch::Tensor& lse)
+{
+ auto lse_ptr = lse.size(0) == 0 ? nullptr : reinterpret_cast<float*>(lse.data_ptr());
+ cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
+ DISPATCH_ARCHTAG(prop->major * 10 + prop->minor,
+ DISPATCH_TYPES(q, ([&]() { CODE(scalar_t, torch_scalar_t); })));
+}
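
The `CODE` macro above only selects the bias-broadcast specialization; the math the fused kernel appears to evaluate is easier to see in a plain PyTorch reference. The broadcast shapes used here ([B, N, 1, 1, S] for `bias1`, [B, 1, H, S, S] for `bias2`) are assumptions based on the `Broadcast1_`/`Broadcast2_` hooks and typical Evoformer usage; they are not stated in this file.

```python
# Hedged, unfused reference of the forward computation (shapes are assumptions).
import torch

def evoformer_attention_reference(q, k, v, bias1=None, bias2=None):
    # q, k, v: [B, N, S, H, D]
    scale = q.shape[-1] ** -0.5
    # Move heads ahead of the sequence dim for the matmuls: [B, N, H, S, D].
    q, k, v = [t.permute(0, 1, 3, 2, 4) for t in (q, k, v)]
    scores = torch.matmul(q, k.transpose(-1, -2)) * scale   # [B, N, H, S, S]
    if bias1 is not None:
        scores = scores + bias1   # assumed [B, N, 1, 1, S]; broadcasts over heads and queries
    if bias2 is not None:
        scores = scores + bias2   # assumed [B, 1, H, S, S]; broadcasts over the N dim
    probs = torch.softmax(scores, dim=-1)
    out = torch.matmul(probs, v)                             # [B, N, H, S, D]
    return out.permute(0, 1, 3, 2, 4)                        # back to [B, N, S, H, D]
```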
diff --git a/csrc/deepspeed4science/evoformer_attn/attention_back.cu b/csrc/deepspeed4science/evoformer_attn/attention_back.cu
new file mode 100644
index 000000000000..1be647b5e552
--- /dev/null
+++ b/csrc/deepspeed4science/evoformer_attn/attention_back.cu
@@ -0,0 +1,218 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+#include <cmath>
+#include "gemm_kernel_utils.h"
+#include "kernel_backward.h"
+#include "transform/bias_broadcast.h"
+
+constexpr auto kBlockSizeI = 64;
+constexpr auto kBlockSizeJ = 64;
+
+template
+ class Broadcast1_,
+ template
+ class Broadcast2_>
+typename std::enable_if::value>::type attention_back_impl_template(
+ torch::Tensor& go,
+ torch::Tensor& q,
+ torch::Tensor& k,
+ torch::Tensor& v,
+ torch::Tensor& o,
+ torch::Tensor& lse,
+ torch::Tensor& delta,
+ torch::Tensor& bias1,
+ torch::Tensor& bias2,
+ torch::Tensor& gq,
+ torch::Tensor& gk,
+ torch::Tensor& gv,
+ torch::Tensor& gb1,
+ torch::Tensor& gb2)
+{
+ EVOFORMER_CHECK(false, "Unsupported GPU and data type combination")
+}
+
+template
+ class Broadcast1_,
+ template
+ class Broadcast2_>
+typename std::enable_if::value>::type attention_back_impl_template(
+ torch::Tensor& go,
+ torch::Tensor& q,
+ torch::Tensor& k,
+ torch::Tensor& v,
+ torch::Tensor& o,
+ torch::Tensor& lse,
+ torch::Tensor& delta,
+ torch::Tensor& bias1,
+ torch::Tensor& bias2,
+ torch::Tensor& gq,
+ torch::Tensor& gk,
+ torch::Tensor& gv,
+ torch::Tensor& gb1,
+ torch::Tensor& gb2)
+{
+ constexpr bool kPreload_ = arch::kMinComputeCapability >= 80;
+ using Kernel = AttentionBackwardKernel;
+ int head_size = q.size(-1);
+ int head_number = q.size(-2);
+ int seq_length = q.size(-3);
+ auto q_view = q.view({-1, seq_length, head_number, head_size});
+ auto k_view = k.view({-1, seq_length, head_number, head_size});
+ auto v_view = v.view({-1, seq_length, head_number, head_size});
+ auto o_view = o.view({-1, seq_length, head_number, head_size});
+ auto do_view = go.view({-1, seq_length, head_number, head_size});
+ auto dk_view = gk.view({-1, seq_length, head_number, head_size});
+ auto dv_view = gv.view({-1, seq_length, head_number, head_size});
+ auto dq_view = gq.view({-1, seq_length, head_number, head_size});
+ auto q_ptr = reinterpret_cast(q.data_ptr());
+ auto k_ptr = reinterpret_cast(k.data_ptr());
+ auto v_ptr = reinterpret_cast(v.data_ptr());
+ auto o_ptr = reinterpret_cast(o.data_ptr());
+ auto do_ptr = reinterpret_cast(go.data_ptr());
+ auto dk_ptr = reinterpret_cast(gk.data_ptr());
+ auto dv_ptr = reinterpret_cast(gv.data_ptr());
+ auto dq_ptr = reinterpret_cast(gq.data_ptr());
+ auto db1_ptr = gb1.size(0) > 0 ? reinterpret_cast(gb1.data_ptr()) : nullptr;
+ auto db2_ptr = gb2.size(0) > 0 ? reinterpret_cast(gb2.data_ptr()) : nullptr;
+ auto lse_ptr = reinterpret_cast(lse.data_ptr());
+ auto delta_ptr = reinterpret_cast(delta.data_ptr());
+ auto bias1_ptr = reinterpret_cast(bias1.data_ptr());
+ auto bias2_ptr = reinterpret_cast(bias2.data_ptr());
+ static_assert(Kernel::kKernelComputesDelta, "Kernel must compute delta");
+
+ typename Kernel::Params p;
+ p.query_ptr = q_ptr;
+ p.key_ptr = k_ptr;
+ p.value_ptr = v_ptr;
+ p.logsumexp_ptr = lse_ptr;
+ p.output_ptr = o_ptr;
+ p.grad_output_ptr = do_ptr;
+ p.delta_ptr = delta_ptr;
+ p.grad_query_ptr = dq_ptr;
+ p.grad_key_ptr = dk_ptr;
+ p.grad_value_ptr = dv_ptr;
+
+ p.grad_bias1_ptr = db1_ptr;
+ p.grad_bias2_ptr = db2_ptr;
+ p.B = q.size(0);
+ p.N = q.size(1);
+ p.bias1_ptr = bias1.size(0) ? bias1_ptr : nullptr;
+ p.bias2_ptr = bias2.size(0) ? bias2_ptr : nullptr;
+
+ p.scale = 1.0f / sqrtf(head_size);
+
+ p.head_dim = head_size;
+ p.head_dim_value = head_size;
+ p.num_queries = seq_length;
+ p.num_keys = seq_length;
+ p.num_heads = head_number;
+
+ p.q_strideM = q_view.stride(-3);
+ p.k_strideM = k_view.stride(-3);
+ p.v_strideM = v_view.stride(-3);
+ p.gO_strideM = do_view.stride(-3);
+ p.o_strideH = o_view.stride(-2);
+ p.q_strideH = q_view.stride(-2);
+ p.k_strideH = k_view.stride(-2);
+ p.v_strideH = v_view.stride(-2);
+ p.o_strideB = o_view.stride(-4);
+ p.q_strideB = q_view.stride(-4);
+ p.k_strideB = k_view.stride(-4);
+ p.v_strideB = v_view.stride(-4);
+ p.lse_strideB = lse.stride(-3);
+ p.lse_strideH = lse.stride(-2);
+ p.delta_strideB = delta.stride(-3);
+ p.delta_strideH = delta.stride(-2);
+ p.num_batches = q_view.size(-4);
+
+ p.gO_strideB = do_view.stride(-4);
+ p.gQ_strideB = dq_view.stride(-4);
+ p.gK_strideB = dk_view.stride(-4);
+ p.gV_strideB = dv_view.stride(-4);
+ p.gO_strideH = do_view.stride(-2);
+ p.gQ_strideH = dq_view.stride(-2);
+ p.gK_strideH = dk_view.stride(-2);
+ p.gV_strideH = dv_view.stride(-2);
+
+ torch::Tensor workspace = torch::empty(p.workspace_size() / 4, lse.options());
+ p.workspace = workspace.data_ptr();
+
+ auto kernel_fn = attention_kernel_backward_batched_impl;
+ size_t smem_bytes = sizeof(typename Kernel::SharedStorage);
+ cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, int(smem_bytes));
+ if (!Kernel::check_supported(p)) { throw std::runtime_error("Unsupported parameters"); }
+ kernel_fn<<>>(p);
+}
+
+#define CODE(scalar_t, torch_scalar_t) \
+ do { \
+ if (bias1.size(0) == 0 && bias2.size(0) == 0) { \
+ attention_back_impl_template( \
+ go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); \
+ } else if (bias1.size(0) > 0 && bias2.size(0) > 0) { \
+ attention_back_impl_template( \
+ go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); \
+ } else if (bias1.size(0) > 0) { \
+ attention_back_impl_template( \
+ go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); \
+ } else { \
+ attention_back_impl_template( \
+ go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); \
+ } \
+ } while (0)
+
+void attention_back_impl(torch::Tensor& go,
+ torch::Tensor& q,
+ torch::Tensor& k,
+ torch::Tensor& v,
+ torch::Tensor& o,
+ torch::Tensor& lse,
+ torch::Tensor& delta,
+ torch::Tensor& bias1,
+ torch::Tensor& bias2,
+ torch::Tensor& gq,
+ torch::Tensor& gk,
+ torch::Tensor& gv,
+ torch::Tensor& gb1,
+ torch::Tensor& gb2)
+{
+ cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
+ DISPATCH_ARCHTAG(prop->major * 10 + prop->minor,
+ DISPATCH_TYPES(q, ([&]() { CODE(scalar_t, torch_scalar_t); })));
+}
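
The backward entry point is handed the forward output `o`, its logsumexp `lse`, and a preallocated `delta` buffer (the `static_assert` on `Kernel::kKernelComputesDelta` above notes the kernel fills `delta` itself). For orientation, here is a hedged sketch of what these auxiliary tensors conventionally hold in memory-efficient attention backward passes; `scores` stands for the scaled, biased pre-softmax logits, which the CUDA kernel never materializes.

```python
# Hedged reference for the auxiliary backward inputs (FlashAttention-style formulation).
import torch

def backward_aux_reference(scores, o, grad_o):
    lse = torch.logsumexp(scores, dim=-1)   # stored by the forward kernel
    delta = (o * grad_o).sum(dim=-1)        # row-wise <dO, O>; computed in-kernel here
    return lse, delta
```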
diff --git a/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_grad_bias.h b/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_grad_bias.h
new file mode 100644
index 000000000000..17b6479ed8c5
--- /dev/null
+++ b/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_grad_bias.h
@@ -0,0 +1,250 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#pragma once
+#include <cutlass/epilogue/threadblock/default_epilogue_tensor_op.h>
+#include <cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h>
+#include "../iterators/predicated_tile_iterator_atomic.h"
+#include "cutlass/epilogue/threadblock/epilogue.h"
+
+namespace cutlass {
+namespace epilogue {
+namespace threadblock {
+template
+struct EpilogueTensorOpAffineRankN : public DefaultEpilogueTensorOpAffineRankN {
+ using Base = DefaultEpilogueTensorOpAffineRankN;
+ using OutputTileIterator =
+ cutlass::epilogue::threadblock::PredicatedTileIteratorAffineRankNAtomic<
+ typename Base::OutputTileThreadMap,
+ typename Base::ElementOutput,
+ Rank>;
+
+ using Epilogue =
+ cutlass::epilogue::threadblock::Epilogue;
+};
+
+template
+struct EpilogueVoltaTensorOpAffineRankN
+ : public DefaultEpilogueVoltaTensorOpAffineRankN {
+ using Base = DefaultEpilogueVoltaTensorOpAffineRankN;
+ using OutputTileIterator =
+ cutlass::epilogue::threadblock::PredicatedTileIteratorAffineRankNAtomic<
+ typename Base::OutputTileThreadMap,
+ typename Base::ElementOutput,
+ Rank>;
+
+ using Epilogue =
+ cutlass::epilogue::threadblock::Epilogue;
+};
+
+template
+struct EpilogueTensorOp : public DefaultEpilogueTensorOp {
+ using Base = DefaultEpilogueTensorOp;
+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorAtomic<
+ typename Base::OutputTileThreadMap,
+ typename Base::ElementOutput,
+ ScatterD,
+ PermuteDLayout>;
+ using Epilogue =
+ cutlass::epilogue::threadblock::Epilogue;
+};
+
+template
+struct EpilogueVoltaTensorOp : public DefaultEpilogueVoltaTensorOp {
+ using Base = DefaultEpilogueVoltaTensorOp;
+ using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorAtomic<
+ typename Base::OutputTileThreadMap,
+ typename Base::ElementOutput,
+ ScatterD,
+ PermuteDLayout>;
+ using Epilogue =
+ cutlass::epilogue::threadblock::Epilogue;
+};
+} // namespace threadblock
+} // namespace epilogue
+} // namespace cutlass
+
+template
+struct BiasGradEpilogue {
+ using Epilogue =
+ typename cutlass::epilogue::threadblock::EpilogueTensorOp::Epilogue;
+};
+
+template
+struct BiasGradEpilogue {
+ using Epilogue =
+ typename cutlass::epilogue::threadblock::EpilogueVoltaTensorOp::Epilogue;
+};
+
+template
+struct BiasGradEpilogueAffineRankN {
+ using Epilogue = typename cutlass::epilogue::threadblock::EpilogueTensorOpAffineRankN<
+ Rank,
+ Shape_,
+ WarpMmaTensorOp_,
+ PartitionsK,
+ OutputOp_,
+ ElementsPerAccess>::Epilogue;
+};
+
+template
+struct BiasGradEpilogueAffineRankN {
+ using Epilogue = typename cutlass::epilogue::threadblock::EpilogueVoltaTensorOpAffineRankN<
+ Rank,
+ Shape_,
+ WarpMmaTensorOp_,
+ PartitionsK,
+ OutputOp_,
+ ElementsPerAccess>::Epilogue;
+};
diff --git a/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_pipelined.h b/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_pipelined.h
new file mode 100644
index 000000000000..3b7b32d61452
--- /dev/null
+++ b/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_pipelined.h
@@ -0,0 +1,592 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+/*! \file
+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
+
+ File copied from "cutlass/epilogue/threadblock/epilogue.h"
+ then modified to:
+ (1) load 2 source fragments at the same time (pipelining)
+ (2) support reading from a different dtype
+ (3) pass the row id to the OutputOp if it takes it
+ (see MemoryEfficientAttentionNormalize)
+ Note that in general the fragment passed to the OutputOp could
+ span multiple rows but it does not happen with the configurations we have
+*/
+
+#pragma once
+
+#if defined(__CUDACC_RTC__)
+#include <cuda/std/cassert>
+#else
+#include <assert.h>
+#endif
+
+#include "cutlass/aligned_buffer.h"
+#include "cutlass/array.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/functional.h"
+#include "cutlass/layout/tensor.h"
+#include "cutlass/layout/vector.h"
+#include "cutlass/numeric_types.h"
+#include "cutlass/tensor_coord.h"
+
+#include "cutlass/gemm/gemm.h"
+
+#include "cutlass/transform/pitch_linear_thread_map.h"
+#include "cutlass/transform/threadblock/regular_tile_iterator.h"
+
+#include "cutlass/epilogue/threadblock/epilogue_base.h"
+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
+#include "cutlass/numeric_types.h"
+
+////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace epilogue {
+namespace threadblock {
+
+template
+struct ApplyEpilogueOp {
+ static CUTLASS_DEVICE typename Op::FragmentOutput apply(
+ Op const& output_op,
+ int row_id,
+ typename Op::FragmentAccumulator const& accum,
+ typename Op::FragmentOutput const& source)
+ {
+ return output_op(accum, source);
+ }
+ static CUTLASS_DEVICE typename Op::FragmentOutput
+ apply(Op const& output_op, int row_id, typename Op::FragmentAccumulator const& accum)
+ {
+ return output_op(accum);
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+/// Epilogue operator
+template ::value),
+ typename OutputTileSourceIterator_ =
+ OutputTileIterator_ ///< Tile iterator reading tensors
+ >
+class EpiloguePipelined : public EpilogueBase {
+public:
+ using Base = EpilogueBase;
+
+ using Shape = Shape_;
+ using WarpMmaOperator = WarpMmaOperator_;
+ static int const kPartitionsK = PartitionsK;
+ using OutputTileIterator = OutputTileIterator_;
+ using OutputTileSourceIterator = OutputTileSourceIterator_;
+ using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
+ using WarpTileIterator = WarpTileIterator_;
+ using SharedLoadIterator = SharedLoadIterator_;
+ using OutputOp = OutputOp_;
+ using Padding = Padding_;
+
+ using Layout = layout::RowMajor;
+ using LongIndex = typename Layout::LongIndex;
+
+ /// The complete warp-level accumulator tile
+ using AccumulatorTile = typename Base::AccumulatorTile;
+
+ /// Accumulator element
+ using ElementAccumulator = typename WarpTileIterator::Element;
+
+ /// Output element
+ using ElementOutput = typename OutputTileIterator::Element;
+ using ElementSource = typename OutputTileSourceIterator::Element;
+
+ /// Output access size
+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
+
+ /// Tensor reference to destination tensor
+ using TensorRef = typename OutputTileIterator::TensorRef;
+
+ /// Tensor reference to sync tensor
+ using SyncTensorRef = typename cutlass::TensorRef;
+
+ /// Const tensor reference to source tensor
+ using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
+
+ /// Array type used to output
+ using OutputAccessType =
+ Array;
+ using SourceAccessType = Array;
+
+ /// Array type used by output functor
+ using AccumulatorAccessType =
+ Array;
+
+ /// Number of warps
+ using WarpCount = typename Base::WarpCount;
+
+ static int constexpr kSmemTiles =
+ Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK;
+ static int constexpr kSmemPointerOffset =
+ Base::SharedStorage::StorageShape::kCount / kSmemTiles;
+
+public:
+ static_assert(OutputTileSourceIterator::Fragment::kElements ==
+ OutputTileIterator::Fragment::kElements,
+ "Mismatch between input tile and output tile iterator (kElements)");
+ static_assert(OutputTileSourceIterator::kIterations == OutputTileIterator::kIterations,
+ "Mismatch between input tile and output tile iterator (kIterations)");
+ static_assert(SharedLoadIterator::Fragment::kElements ==
+ OutputTileIterator::Fragment::kElements,
+ "Mismatch between shared load iterator and output tile iterator.");
+
+ static_assert(OutputTileIterator::kElementsPerAccess,
+ "OutputTileIterator::kElementsPerAccess must not be zero.");
+
+ static_assert(!(OutputTileIterator::Fragment::kElements %
+ OutputTileIterator::kElementsPerAccess),
+ "Divisibility");
+
+private:
+ /// Loads fragment from shared memory aligned with output tensor
+ SharedLoadIterator shared_load_iterator_;
+
+public:
+ /// Constructor
+ CUTLASS_DEVICE
+ EpiloguePipelined(typename Base::SharedStorage& shared_storage, ///< Shared storage object
+ int thread_idx, ///< ID of a thread within the threadblock
+ int warp_idx, ///< ID of warp within threadblock
+ int lane_idx ///< Id of thread within warp
+ )
+ : Base(shared_storage, thread_idx, warp_idx, lane_idx),
+ shared_load_iterator_(shared_storage.reference(), thread_idx)
+ {
+ }
+
+ /// Streams the result to global memory
+ CUTLASS_DEVICE
+ void operator()(OutputOp const& output_op, ///< Output operator
+ OutputTileIterator destination_iterator, ///< Tile iterator for destination
+ AccumulatorTile const& accumulators, ///< Complete warp-level accumulator tile
+ OutputTileSourceIterator source_iterator)
+ { ///< Threadblock tile coordinate in GEMM (in units
+ ///< of threadblock tiles)
+
+ if (!output_op.is_source_needed()) {
+ compute_source_not_needed_(output_op, destination_iterator, accumulators);
+ } else {
+ compute_source_needed_(output_op, destination_iterator, accumulators, source_iterator);
+ }
+ }
+ CUTLASS_DEVICE
+ void operator()(OutputOp const& output_op, ///< Output operator
+ OutputTileIterator destination_iterator, ///< Tile iterator for destination
+ AccumulatorTile const& accumulators)
+ { ///< Complete warp-level accumulator tile
+ compute_source_not_needed_(output_op, destination_iterator, accumulators);
+ }
+
+private:
+ template
+ struct acc2smem_source_not_needed;
+
+ template
+ struct acc2smem_source_not_needed> {
+ template
+ CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
+ WarpTileIterator& warp_tile_iterator)
+ {
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < Advance; i++) { ++accum_fragment_iterator; }
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
+ typename AccumulatorFragmentIterator::Fragment accum_fragment;
+
+ accum_fragment_iterator.load(accum_fragment);
+ ++accum_fragment_iterator;
+
+ warp_tile_iterator.store(accum_fragment);
+ if (p < Base::kFragmentsPerIteration - 1) {
+ warp_tile_iterator.add_pointer_offset(kSmemPointerOffset);
+ }
+ }
+
+ if (Base::kFragmentsPerIteration > 1) {
+ warp_tile_iterator.add_pointer_offset(kSmemPointerOffset *
+ (1 - Base::kFragmentsPerIteration));
+ }
+ }
+
+ CUTLASS_DEVICE
+ static void push(size_t pos,
+ AccumulatorFragmentIterator const& iterator_begin,
+ WarpTileIterator& warp_tile_iterator)
+ {
+ int dummy[] = {
+ (pos == (Seq * Base::kFragmentsPerIteration)) &&
+ (helper(iterator_begin, warp_tile_iterator),
+ 0)...};
+
+ CUTLASS_UNUSED(dummy[0]);
+ }
+ };
+
+ static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1,
+ "One of these must be exactly 1.");
+
+ /// Streams the result to global memory
+ CUTLASS_DEVICE
+ void compute_source_not_needed_(
+ OutputOp const& output_op, ///< Output operator
+ OutputTileIterator destination_iterator, ///< Tile iterator for destination
+ AccumulatorTile const& accumulators ///< Complete warp-level accumulator tile
+ )
+ {
+ //
+ // Iterator over warp-level accumulator fragment
+ //
+
+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
+
+ //
+ // Iterate over accumulator tile
+ //
+
+#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration \
+ : 1)
+ for (int iter = 0; iter < OutputTileIterator::kIterations;
+ iter += Base::kFragmentsPerIteration) {
+ //
+ // Convert and store fragment
+ //
+
+ __syncthreads();
+
+ acc2smem_source_not_needed>::
+ push(iter, accum_fragment_iterator, this->warp_tile_iterator_);
+
+ __syncthreads();
+
+ //
+ // Load fragments from shared memory
+ //
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
+ typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
+
+ shared_load_iterator_.load(aligned_accum_fragment[0]);
+
+ if (p < Base::kFragmentsPerIteration - 1) {
+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
+ } else if (kPartitionsK > 1) {
+ plus add_fragments;
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 1; i < kPartitionsK; ++i) {
+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
+ shared_load_iterator_.load(aligned_accum_fragment[i]);
+ aligned_accum_fragment[0] =
+ add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
+ }
+
+ shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) *
+ kSmemPointerOffset);
+ }
+
+ //
+ // Compute the output result
+ //
+
+ typename OutputTileIterator::Fragment output_fragment;
+
+ apply_output_operator_source_not_needed_(destination_iterator.thread_start_row(),
+ output_fragment,
+ output_op,
+ aligned_accum_fragment[0]);
+
+ //
+ // Store the final result
+ //
+
+ destination_iterator.store(output_fragment);
+ ++destination_iterator;
+ }
+
+ if (Base::kFragmentsPerIteration > 1) {
+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset *
+ (1 - Base::kFragmentsPerIteration));
+ }
+ }
+ }
+
+ template
+ struct acc2smem_source_needed;
+
+ template
+ struct acc2smem_source_needed> {
+ template
+ CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
+ WarpTileIterator& warp_tile_iterator)
+ {
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < Advance; i++) { ++accum_fragment_iterator; }
+
+ typename AccumulatorFragmentIterator::Fragment accum_fragment;
+ accum_fragment_iterator.load(accum_fragment);
+ warp_tile_iterator.store(accum_fragment);
+ }
+
+ CUTLASS_DEVICE
+ static void push(size_t pos,
+ AccumulatorFragmentIterator const& iterator_begin,
+ WarpTileIterator& warp_tile_iterator)
+ {
+ int dummy[] = {(pos == Seq) && (helper(iterator_begin, warp_tile_iterator), 0)...};
+ }
+ };
+
+ /// Streams the result to global memory
+ CUTLASS_DEVICE
+ void compute_source_needed_(
+ OutputOp const& output_op, ///< Output operator
+ OutputTileIterator destination_iterator, ///< Tile iterator for destination
+ AccumulatorTile const& accumulators, ///< Complete warp-level accumulator tile
+ OutputTileSourceIterator source_iterator ///< Threadblock tile coordinate in GEMM (in units
+ ///< of threadblock tiles)
+ )
+ {
+ typename OutputTileSourceIterator::Fragment source_fragment[2];
+
+ source_fragment[0].clear();
+ source_iterator.load(source_fragment[0]);
+ ++source_iterator;
+ source_fragment[1].clear();
+
+ //
+ // Iterator over warp-level accumulator fragment
+ //
+
+ AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
+
+ //
+ // Iterate over accumulator tile
+ //
+
+#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
+ for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
+ if (iter > 0) { __syncthreads(); }
+ //
+ // Load the source for next iteration (pipelining)
+ //
+
+ if (iter + 1 < OutputTileIterator::kIterations) {
+ source_iterator.load(source_fragment[(iter + 1) % 2]);
+ }
+ ++source_iterator;
+ acc2smem_source_needed>::
+ push(iter, accum_fragment_iterator, this->warp_tile_iterator_);
+
+ __syncthreads();
+
+ //
+ // Load fragments from shared memory
+ //
+
+ typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
+
+ shared_load_iterator_.load(aligned_accum_fragment[0]);
+
+ // If the number of k-slices is > 1 - perform a reduction amongst the
+ // k-slices
+ if (kPartitionsK > 1) {
+ plus add_fragments;
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 1; i < kPartitionsK; ++i) {
+ shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
+ shared_load_iterator_.load(aligned_accum_fragment[i]);
+ aligned_accum_fragment[0] =
+ add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
+ }
+
+ shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
+ }
+
+ //
+ // Compute the output result
+ //
+
+ typename OutputTileIterator::Fragment output_fragment;
+
+ apply_output_operator_(destination_iterator.thread_start_row(),
+ output_fragment,
+ output_op,
+ aligned_accum_fragment[0],
+ source_fragment[iter % 2]);
+
+ //
+ // Store the final result
+ //
+
+ destination_iterator.store(output_fragment);
+ ++destination_iterator;
+ }
+ }
+
+ /// Helper to invoke the output functor over each vector of output
+ CUTLASS_DEVICE
+ void apply_output_operator_(int begin_row,
+ typename OutputTileIterator::Fragment& output_fragment,
+ OutputOp const& output_op, ///< Output operator
+ typename SharedLoadIterator::Fragment const& aligned_accum_fragment,
+ typename OutputTileSourceIterator::Fragment const& source_fragment)
+ {
+ OutputAccessType* output_frag_ptr = reinterpret_cast(&output_fragment);
+
+ AccumulatorAccessType const* compute_frag_ptr =
+ reinterpret_cast(&aligned_accum_fragment);
+
+ SourceAccessType const* source_frag_ptr =
+ reinterpret_cast(&source_fragment);
+
+ int const kOutputOpIterations =
+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < kOutputOpIterations; ++i) {
+ // Call the output operator
+ output_frag_ptr[i] = ApplyEpilogueOp::apply(
+ output_op,
+ begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess),
+ compute_frag_ptr[i],
+ source_frag_ptr[i]);
+ }
+ }
+
+ /// Helper to invoke the output functor over each vector of output
+ CUTLASS_DEVICE
+ void apply_output_operator_source_not_needed_(
+ int begin_row,
+ typename OutputTileIterator::Fragment& output_fragment,
+ OutputOp const& output_op, ///< Output operator
+ typename SharedLoadIterator::Fragment const& aligned_accum_fragment)
+ {
+ OutputAccessType* output_frag_ptr = reinterpret_cast(&output_fragment);
+
+ AccumulatorAccessType const* compute_frag_ptr =
+ reinterpret_cast(&aligned_accum_fragment);
+
+ int const kOutputOpIterations =
+ OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < kOutputOpIterations; ++i) {
+ // Call the output operator
+ output_frag_ptr[i] = ApplyEpilogueOp::apply(
+ output_op,
+ begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess),
+ compute_frag_ptr[i]);
+ }
+ }
+
+ // This should be constexpr, but it's only supported on c++14
+ static int CUTLASS_HOST_DEVICE getRowOffset(int i)
+ {
+ using ThreadMap = typename OutputTileIterator::ThreadMap;
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {
+ CUTLASS_PRAGMA_UNROLL
+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
+ CUTLASS_PRAGMA_UNROLL
+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
+ int row_offset = row * ThreadMap::Delta::kRow +
+ group * ThreadMap::Delta::kGroup +
+ cluster * ThreadMap::Delta::kCluster;
+ int frag_row_idx =
+ (row + ThreadMap::Iterations::kRow *
+ (group + ThreadMap::Iterations::kGroup * cluster));
+ CUTLASS_PRAGMA_UNROLL
+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
+ int frag_idx = ThreadMap::kElementsPerAccess *
+ (frag_row_idx * ThreadMap::Iterations::kColumn + column);
+ if (i < frag_idx + ThreadMap::kElementsPerAccess) { return row_offset; }
+ }
+ }
+ }
+ }
+ return -1;
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace threadblock
+} // namespace epilogue
+} // namespace cutlass
+
+////////////////////////////////////////////////////////////////////////////////
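
As the file comment notes, the main change relative to the stock CUTLASS epilogue is that two source fragments are kept in flight: while fragment `i` is being combined with the accumulators, fragment `i + 1` is already being loaded. A minimal Python sketch of that double-buffering pattern, mirroring `compute_source_needed_` above:

```python
# Two-slot buffer: slot (it % 2) is consumed while slot ((it + 1) % 2) is being filled.
def pipelined_epilogue(num_iters, load_source, combine_and_store):
    buf = [load_source(0), None]
    for it in range(num_iters):
        if it + 1 < num_iters:
            buf[(it + 1) % 2] = load_source(it + 1)   # prefetch the next source tile
        combine_and_store(it, buf[it % 2])            # use the tile loaded one step earlier
```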
diff --git a/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_rescale_output.h b/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_rescale_output.h
new file mode 100644
index 000000000000..f81a09f74f1e
--- /dev/null
+++ b/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_rescale_output.h
@@ -0,0 +1,251 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+/*! \file
+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
+
+ The epilogue rearranges the result of a matrix product through shared memory
+ to match canonical tensor layouts in global memory. Epilogues support
+ conversion and reduction operations.
+
+ This is a copy of cutlass/epilogue/threadblock/epilogue.h that can
+ handle "row_id" as a first argument, and uses it to get the corresponding
+ `m_prime` / `s_prime` to rescale the output.
+*/
+
+#pragma once
+
+#if defined(__CUDACC_RTC__)
+#include <cuda/std/cassert>
+#else
+#include <assert.h>
+#endif
+
+#include "cutlass/aligned_buffer.h"
+#include "cutlass/array.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/functional.h"
+#include "cutlass/layout/tensor.h"
+#include "cutlass/layout/vector.h"
+#include "cutlass/numeric_types.h"
+#include "cutlass/tensor_coord.h"
+
+#include "cutlass/gemm/gemm.h"
+
+#include "cutlass/transform/pitch_linear_thread_map.h"
+#include "cutlass/transform/threadblock/regular_tile_iterator.h"
+
+#include "cutlass/epilogue/threadblock/epilogue_base.h"
+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
+#include "cutlass/numeric_types.h"
+
+#include "cutlass/array.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/epilogue/thread/scale_type.h"
+#include "cutlass/functional.h"
+#include "cutlass/numeric_conversion.h"
+#include "cutlass/numeric_types.h"
+#include "epilogue_pipelined.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace epilogue {
+namespace thread {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Applies a linear combination operator to an array of elements.
+// output <- alpha * accumulator + beta * source
+// with:
+// alpha = 1 / s_prime (to normalize when isLast=True, 1 otherwise)
+// beta = alpha * m_prime (renormalize the output when the max changes)
+// source is the current output
+template ,
+ ///< but we use 64 or 32 sometimes when there are not enough data
+ ///< to store
+ typename ElementAccumulator_, ///< Accumulator data type
+ typename ElementCompute_, ///< Data type used to compute linear combination
+ bool isFirst,
+ bool isLast,
+ typename FragmentAlphaBeta_,
+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest>
+class MemoryEfficientAttentionNormalize {
+public:
+ using ElementOutput = ElementOutput_;
+ using ElementSource = ElementSource_;
+ using ElementAccumulator = ElementAccumulator_;
+ using ElementCompute = ElementCompute_;
+
+ static int const kCount = Count;
+
+ using FragmentOutput = Array;
+ using FragmentSource = Array;
+ using FragmentAccumulator = Array;
+ using ComputeFragment = Array;
+ using FragmentAlphaBeta = FragmentAlphaBeta_;
+
+ static FloatRoundStyle const kRound = Round;
+
+private:
+ //
+ // Data members
+ //
+
+ FragmentAlphaBeta const& s_prime_;
+ FragmentAlphaBeta const& m_prime_;
+
+public:
+ /// Constructs the function object, possibly loading from pointers in host
+ /// memory
+ CUTLASS_HOST_DEVICE
+ MemoryEfficientAttentionNormalize(FragmentAlphaBeta const& s_prime,
+ FragmentAlphaBeta const& m_prime)
+ : s_prime_(s_prime), m_prime_(m_prime)
+ {
+ }
+
+ /// Returns true if source is needed
+ CUTLASS_HOST_DEVICE
+ bool is_source_needed() const { return !isFirst; }
+
+ /// Functionally required for serial reduction in the epilogue
+ CUTLASS_HOST_DEVICE
+ void set_k_partition(int k_partition, int k_partition_count) {}
+
+ /// Computes linear scaling: D = alpha * accumulator + beta * source
+ CUTLASS_HOST_DEVICE
+ FragmentOutput operator()(int row,
+ FragmentAccumulator const& accumulator,
+ FragmentSource const& source) const
+ {
+ assert(!isFirst);
+
+ // Convert source to internal compute numeric type
+ NumericArrayConverter source_converter;
+ NumericArrayConverter
+ accumulator_converter;
+
+ // Convert to destination numeric type
+ NumericArrayConverter destination_converter;
+
+ ComputeFragment converted_source = source_converter(source);
+ ComputeFragment converted_accumulator = accumulator_converter(accumulator);
+
+ // Perform binary operations
+ ComputeFragment intermediate;
+
+ multiplies mul_add_source;
+ multiply_add mul_add_accumulator;
+
+ ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1;
+ ElementCompute beta = alpha * m_prime_[row];
+
+ intermediate = mul_add_source(beta, converted_source); // X = beta * C
+
+ intermediate = mul_add_accumulator(
+ alpha, converted_accumulator, intermediate); // D = alpha * Accum + X
+
+ return destination_converter(intermediate);
+ }
+
+ /// Computes linear scaling: D = alpha * accumulator
+ CUTLASS_HOST_DEVICE
+ FragmentOutput operator()(int row, FragmentAccumulator const& accumulator) const
+ {
+ assert(isFirst);
+
+ // Convert source to internal compute numeric type
+ NumericArrayConverter
+ accumulator_converter;
+
+ // Convert to destination numeric type
+ NumericArrayConverter destination_converter;
+
+ ComputeFragment converted_accumulator = accumulator_converter(accumulator);
+
+ ComputeFragment intermediate;
+ multiplies mul_accumulator;
+
+ ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1;
+
+ intermediate = mul_accumulator(alpha, converted_accumulator); // X = alpha * C + uniform
+
+ return destination_converter(intermediate);
+ }
+};
+
+} // namespace thread
+
+namespace threadblock {
+template
+struct ApplyEpilogueOp<
+ thread::MemoryEfficientAttentionNormalize> {
+ using Op = thread::MemoryEfficientAttentionNormalize;
+ static CUTLASS_DEVICE typename Op::FragmentOutput apply(
+ Op const& output_op,
+ int row_id,
+ typename Op::FragmentAccumulator const& accum,
+ typename Op::FragmentSource const& source)
+ {
+ return output_op(row_id, accum, source);
+ }
+ static CUTLASS_DEVICE typename Op::FragmentOutput
+ apply(Op const& output_op, int row_id, typename Op::FragmentAccumulator const& accum)
+ {
+ return output_op(row_id, accum);
+ }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace threadblock
+} // namespace epilogue
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
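
The comment at the top of `MemoryEfficientAttentionNormalize` compresses the whole rescaling trick into two constants. A hedged sketch makes the roles concrete: `s_prime` is the running softmax denominator and `m_prime` is the correction factor produced when the running row-max changes, exactly as the `alpha`/`beta` computation in the functor above uses them.

```python
# Per-row rescaling performed by the epilogue while the softmax statistics are still running.
def rescale_step(partial_accum, prev_output, s_prime, m_prime, is_first, is_last):
    # alpha normalizes by the softmax denominator only on the last partial result.
    alpha = 1.0 / s_prime if is_last else 1.0
    if is_first:
        return alpha * partial_accum          # nothing written yet, so no source term
    # beta folds the previously written partial output back in, rescaled because the
    # running max (and hence m_prime) may have changed since it was written.
    beta = alpha * m_prime
    return alpha * partial_accum + beta * prev_output
```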
diff --git a/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_thread_apply_logsumexp.h b/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_thread_apply_logsumexp.h
new file mode 100644
index 000000000000..46fb2bf17c1c
--- /dev/null
+++ b/csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_thread_apply_logsumexp.h
@@ -0,0 +1,168 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
+ *reserved. SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice,
+ *this list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ *POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+/*! \file
+ \brief Functor performing linear combination operations used by epilogues.
+*/
+
+#pragma once
+
+#include <cuda_fp16.h>
+
+#include "cutlass/array.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/epilogue/thread/activation.h"
+#include "cutlass/functional.h"
+#include "cutlass/numeric_conversion.h"
+#include "cutlass/numeric_types.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace epilogue {
+namespace thread {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace detail {
+
+template
+struct ArrayExponential {
+ CUTLASS_HOST_DEVICE
+ Array operator()(
+ Array const& input) const
+ {
+ Array result;
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < ElementsPerAccess; ++i) { result[i] = expf(input[i]); }
+
+ return result;
+ }
+};
+
+template
+struct ArrayExponential {
+ CUTLASS_DEVICE
+ Array operator()(Array const& input) const
+ {
+ Array result;
+
+ int const kVectorCount = ElementsPerAccess / 2;
+
+ __half2 const* input_ptr = reinterpret_cast<__half2 const*>(input.raw_data());
+ __half2* res_ptr = reinterpret_cast<__half2*>(result.raw_data());
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < kVectorCount; ++i) { res_ptr[i] = h2exp(input_ptr[i]); }
+
+ return result;
+ }
+};
+} // namespace detail
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Applies:
+/// output <- (input - lse).exp()
+template
+class ApplyLogSumExp {
+public:
+ using ElementOutput = ElementOutput_;
+ using ElementAccumulator = ElementAccumulator_;
+ using ElementCompute = ElementCompute_;
+ using ElementLSE = ElementLSE_;
+
+ static int const kElementsPerAccess = ElementsPerAccess;
+ static int const kCount = kElementsPerAccess;
+ static const ScaleType::Kind kScale = cutlass::epilogue::thread::ScaleType::NoBetaScaling;
+
+ using FragmentOutput = Array;
+ using FragmentAccumulator = Array;
+ using FragmentCompute = Array;
+ using FragmentLSE = Array;
+ using FragmentScaleBias = FragmentLSE; // Used by epilogue_smem_accumulator.h
+
+public:
+ //
+ // Methods
+ //
+
+ CUTLASS_HOST_DEVICE
+ ApplyLogSumExp() {}
+
+ /// Returns true if source is needed
+ CUTLASS_HOST_DEVICE
+ bool is_source_needed() const { return true; }
+
+ /// Functionally required for serial reduction in the epilogue
+ CUTLASS_HOST_DEVICE
+ void set_k_partition(int k_partition, int k_partition_count) {}
+
+ CUTLASS_HOST_DEVICE
+ FragmentOutput operator()(FragmentAccumulator const& AB,
+ FragmentLSE const& scale_unused,
+ // bias used as LSE
+ FragmentLSE const& bias) const
+ {
+ FragmentCompute frag_AB =
+ NumericArrayConverter()(AB);
+ FragmentCompute frag_lse_compute =
+ NumericArrayConverter()(bias);
+ FragmentCompute frag_compute;
+
+ minus minus_lse;
+ detail::ArrayExponential apply_exp;
+ frag_compute = minus_lse(frag_AB, frag_lse_compute);
+ frag_compute = apply_exp(frag_compute);
+
+ return NumericArrayConverter()(
+ frag_compute);
+ }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace thread
+} // namespace epilogue
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
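
The `ApplyLogSumExp` functor turns raw attention scores back into probabilities using the stored log-sum-exp, which is how the backward pass recomputes softmax rows without keeping the full probability matrix. A short NumPy check of the identity it relies on:

```python
# exp(scores - logsumexp(scores)) equals the row-wise softmax.
import numpy as np

scores = np.random.randn(4, 8).astype(np.float32)
lse = np.log(np.exp(scores).sum(axis=-1, keepdims=True))
probs = np.exp(scores - lse)                    # what ApplyLogSumExp computes per row
assert np.allclose(probs, np.exp(scores) / np.exp(scores).sum(-1, keepdims=True), atol=1e-6)
```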
diff --git a/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma.h b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma.h
new file mode 100644
index 000000000000..75833bbfe7d2
--- /dev/null
+++ b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma.h
@@ -0,0 +1,119 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#pragma once
+
+#include "custom_mma_multistage.h"
+#include "custom_mma_pipelined.h"
+#include "cutlass/gemm/threadblock/mma_multistage.h"
+#include "cutlass/gemm/threadblock/mma_pipelined.h"
+
+template
+struct MakeCustomMma;
+
+template
+struct MakeCustomMma,
+ kMaxK> {
+ // Reduce the number of stages if we don't need that many
+ static int constexpr kStages =
+ kMaxK == cutlass::platform::numeric_limits::max()
+ ? Stages
+ : cutlass::const_min(Stages, (kMaxK + int(Shape::kK) - 1) / int(Shape::kK));
+ using Mma = cutlass::gemm::threadblock::CustomMmaMultistage;
+};
+
+template
+struct MakeCustomMma,
+ kMaxK> {
+ using Mma = cutlass::gemm::threadblock::CustomMmaPipelined;
+};
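
The multistage specialization above trims the software-pipeline depth when the K extent is known to be small: there is no point prefetching more k-tiles than the problem has. A minimal sketch of that stage computation:

```python
# Mirrors kStages = min(Stages, ceil(kMaxK / Shape::kK)), with the "unbounded K" escape hatch.
def effective_stages(stages, k_max, tile_k, k_unbounded=False):
    if k_unbounded:                  # kMaxK == numeric_limits<int>::max() in the C++ code
        return stages
    tiles_along_k = (k_max + tile_k - 1) // tile_k   # ceiling division
    return min(stages, tiles_along_k)

assert effective_stages(5, 64, 32) == 2   # a 64-wide K only needs two 32-wide tiles
```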
diff --git a/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_base.h b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_base.h
new file mode 100644
index 000000000000..bbf91240b900
--- /dev/null
+++ b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_base.h
@@ -0,0 +1,181 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
+ *reserved. SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice,
+ *this list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ *POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+/*! \file
+ \brief Template for a double-buffered threadblock-scoped GEMM kernel.
+*/
+
+#pragma once
+
+#include "cutlass/aligned_buffer.h"
+#include "cutlass/arch/memory.h"
+#include "cutlass/array.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/gemm/threadblock/mma_base.h"
+#include "cutlass/matrix_shape.h"
+#include "cutlass/numeric_types.h"
+
+////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+namespace threadblock {
+
+////////////////////////////////////////////////////////////////////////////////
+
+/// Structure to compute the matrix product targeting CUDA cores and SIMT math
+/// instructions.
+template <
+ /// Size of the Gemm problem - concept: gemm::GemmShape<>
+ typename Shape_,
+ /// Policy describing tuning details (concept: MmaPolicy)
+ typename Policy_,
+ /// Number of stages,
+ int Stages,
+ /// Used for partial specialization
+ typename Enable = bool>
+class CustomMmaBase {
+public:
+ ///< Size of the Gemm problem - concept: gemm::GemmShape<>
+ using Shape = Shape_;
+
+ ///< Policy describing tuning details
+ using Policy = Policy_;
+
+ //
+ // Dependent types
+ //
+
+ /// Warp-level Mma
+ using Operator = typename Policy::Operator;
+
+ /// Shape describing the overall GEMM computed from shared memory
+ /// by each warp.
+ using WarpGemm = typename Policy::Operator::Shape;
+
+ /// Shape describing the number of warps filling the CTA
+ using WarpCount =
+     GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>;
+
+ /// Number of warp-level GEMM operations
+ static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK);
+
+ /// Number of stages
+ static int const kStages = Stages;
+
+ //
+ // Nested structs
+ //
+
+ /// Shared storage object needed by threadblock-scoped GEMM
+ template <typename Element, typename OperandShape, typename OperandLayout>
+ struct OperandSharedStorage {
+ AlignedBuffer<Element, OperandShape::kCount> buffer;
+ using TensorRef = TensorRef<Element, OperandLayout>;
+
+ CUTLASS_DEVICE
+ static OperandLayout Layout()
+ {
+ return OperandLayout::packed({OperandShape::kRow, OperandShape::kColumn});
+ }
+
+ /// Returns a TensorRef to the operand
+ CUTLASS_HOST_DEVICE
+ TensorRef ref() { return TensorRef{buffer.data(), Layout()}; }
+ };
+
+ /// Shape of the A matrix operand in shared memory
+ using ShapeA = MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
+                            Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;
+
+ /// Shape of the B matrix operand in shared memory
+ using ShapeB = MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
+                            Shape::kN + Policy::SmemPaddingB::kColumn>;
+
+ using SharedStorageA =
+     OperandSharedStorage<typename Operator::ElementA, ShapeA, typename Operator::LayoutA>;
+ using SharedStorageB =
+     OperandSharedStorage<typename Operator::ElementB, ShapeB, typename Operator::LayoutB>;
+ using TensorRefA = typename SharedStorageA::TensorRef;
+ using TensorRefB = typename SharedStorageB::TensorRef;
+
+ struct SharedStorage {
+ /// Buffer for A operand
+ SharedStorageA operand_A;
+
+ /// Buffer for B operand
+ SharedStorageB operand_B;
+ };
+
+protected:
+ //
+ // Data members
+ //
+
+ /// Iterator to load a warp-scoped tile of A operand from shared memory
+ typename Operator::IteratorA warp_tile_iterator_A_;
+
+ /// Iterator to load a warp-scoped tile of B operand from shared memory
+ typename Operator::IteratorB warp_tile_iterator_B_;
+
+public:
+ /// Construct from tensor references
+ CUTLASS_DEVICE
+ CustomMmaBase(
+ ///< Shared storage needed for internal use by threadblock-scoped GEMM
+ SharedStorageA& shared_storageA,
+ SharedStorageB& shared_storageB,
+ ///< ID within the threadblock
+ int thread_idx,
+ ///< ID of warp
+ int warp_idx,
+ ///< ID of each thread within a warp
+ int lane_idx)
+ : warp_tile_iterator_A_(shared_storageA.ref(), lane_idx),
+ warp_tile_iterator_B_(shared_storageB.ref(), lane_idx)
+ {
+ }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace threadblock
+} // namespace gemm
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
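
`CustomMmaBase` derives its warp grid by dividing the threadblock tile by the warp-level GEMM shape, and sizes the A/B shared-memory buffers to hold `kStages` K-slices. A small host-side sketch of the warp-count arithmetic (hypothetical tile sizes, for illustration only):

```cpp
#include <cstdio>

// Sketch of the WarpCount computation in CustomMmaBase: the threadblock
// tile is split evenly among warps along M, N and K.
struct Shape3 {
    int m, n, k;
};

constexpr Shape3 warp_count(Shape3 threadblock, Shape3 warp)
{
    return {threadblock.m / warp.m, threadblock.n / warp.n, threadblock.k / warp.k};
}

int main()
{
    // Hypothetical 128x128x32 threadblock tile with 64x64x32 warp tiles
    // gives a 2x2x1 warp grid, i.e. 4 warps (128 threads) per block.
    constexpr Shape3 wc = warp_count({128, 128, 32}, {64, 64, 32});
    std::printf("%d x %d x %d warps\n", wc.m, wc.n, wc.k);
    return 0;
}
```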
diff --git a/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_multistage.h b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_multistage.h
new file mode 100644
index 000000000000..50ba58b1d1dd
--- /dev/null
+++ b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_multistage.h
@@ -0,0 +1,706 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
+ *reserved. SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice,
+ *this list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ *POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+/*! \file
+ \brief Template for a double-buffered threadblock-scoped GEMM kernel.
+*/
+
+#pragma once
+
+#include "cutlass/aligned_buffer.h"
+#include "cutlass/arch/cache_operation.h"
+#include "cutlass/arch/memory.h"
+#include "cutlass/array.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/matrix_shape.h"
+#include "cutlass/numeric_types.h"
+
+#include "custom_mma_base.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+namespace threadblock {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Structure to compute the matrix product targeting CUDA cores and SIMT math
+/// instructions.
+template <
+ /// Size of the Gemm problem - concept: gemm::GemmShape<>
+ typename Shape_,
+ /// Iterates over tiles of A operand in global memory
+ // (concept: ReadableTileIterator | ForwardTileIterator |
+ // MaskedTileIterator)
+ typename IteratorA_,
+ /// Iterates over tiles of A operand in shared memory
+ /// (concept: WriteableTileIterator | RandomAccessTileIterator)
+ typename SmemIteratorA_,
+ /// Cache operation for operand A
+ cutlass::arch::CacheOperation::Kind CacheOpA,
+ /// Iterates over tiles of B operand in global memory
+ // (concept: ReadableTileIterator | ForwardTileIterator |
+ // MaskedTileIterator)
+ typename IteratorB_,
+ /// Iterates over tiles of B operand in shared memory
+ /// (concept: WriteableTileIterator | RandomAccessTileIterator)
+ typename SmemIteratorB_,
+ /// Cache operation for operand B
+ cutlass::arch::CacheOperation::Kind CacheOpB,
+ /// Data type of accumulator matrix
+ typename ElementC_,
+ /// Data type of accumulator matrix
+ typename LayoutC_,
+ /// Policy describing tuning details (concept: MmaPolicy)
+ typename Policy_,
+ /// Number of stages,
+ int Stages,
+ /// Use zfill or predicate for out-of-bound cp.async
+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
+ /// Upper bound on the K dimension
+ int kMaxK = cutlass::platform::numeric_limits<int>::max(),
+ /// Used for partial specialization
+ typename Enable = bool>
+class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
+public:
+ ///< Base class
+ using Base = CustomMmaBase<Shape_, Policy_, Stages>;
+ ///< Size of the Gemm problem - concept: gemm::GemmShape<>
+ using Shape = Shape_;
+ ///< Iterates over tiles of A operand in global memory
+ using IteratorA = IteratorA_;
+ ///< Iterates over tiles of B operand in global memory
+ using IteratorB = IteratorB_;
+ ///< Data type of accumulator matrix
+ using ElementC = ElementC_;
+ ///< Layout of accumulator matrix
+ using LayoutC = LayoutC_;
+ ///< Policy describing tuning details
+ using Policy = Policy_;
+
+ using SmemIteratorA = SmemIteratorA_;
+ using SmemIteratorB = SmemIteratorB_;
+
+ static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
+ static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
+
+ //
+ // Dependent types
+ //
+
+ /// Fragment of accumulator tile
+ using FragmentC = typename Policy::Operator::FragmentC;
+
+ /// Warp-level Mma
+ using Operator = typename Policy::Operator;
+
+ /// Minimum architecture is Sm80 to support cp.async
+ using ArchTag = arch::Sm80;
+
+ /// Complex transform on A operand
+ static ComplexTransform const kTransformA = Operator::kTransformA;
+
+ /// Complex transform on B operand
+ static ComplexTransform const kTransformB = Operator::kTransformB;
+
+ /// Internal structure exposed for introspection.
+ struct Detail {
+ static_assert(Base::kWarpGemmIterations > 1,
+ "The pipelined structure requires at least two warp-level "
+ "GEMM operations.");
+
+ /// Number of cp.async instructions to load one stage of operand A
+ static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount;
+
+ /// Number of cp.async instructions to load one stage of operand B
+ static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount;
+
+ /// Number of stages
+ static int const kStages = Stages;
+
+ /// Number of cp.async instructions to load one group of operand A
+ static int const kAccessesPerGroupA =
+ (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) /
+ Base::kWarpGemmIterations;
+
+ /// Number of cp.async instructions to load one group of operand B
+ static int const kAccessesPerGroupB =
+ (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) /
+ Base::kWarpGemmIterations;
+ };
+
+ static bool const kSmemContainsEntireMat = kMaxK <= Shape::kK * Stages;
+ static constexpr int kNumStagesConcurrentLoad = kSmemContainsEntireMat ? Stages : Stages - 1;
+
+private:
+ using WarpLoadedFragmentA = typename Operator::FragmentA;
+ using WarpLoadedFragmentB = typename Operator::FragmentB;
+ using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
+ using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
+
+private:
+ //
+ // Data members
+ //
+
+ /// Iterator to write threadblock-scoped tile of A operand to shared memory
+ SmemIteratorA smem_iterator_A_;
+
+ /// Iterator to write threadblock-scoped tile of B operand to shared memory
+ SmemIteratorB smem_iterator_B_;
+
+ bool prologue_done_;
+
+ // Set to `True` to ensure the accumulator will be zero outside the GEMM
+ // footprint
+ bool zero_outside_bounds_;
+
+public:
+ /// Construct from tensor references
+ CUTLASS_DEVICE
+ CustomMmaMultistage(
+ ///< Shared storage needed for internal use by threadblock-scoped GEMM
+ typename Base::SharedStorageA& shared_storageA,
+ typename Base::SharedStorageB& shared_storageB,
+ ///< ID within the threadblock
+ int thread_idx,
+ ///< ID of warp
+ int warp_idx,
+ ///< ID of each thread within a warp
+ int lane_idx)
+ : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx),
+ smem_iterator_A_(shared_storageA.ref(), thread_idx),
+ smem_iterator_B_(shared_storageB.ref(), thread_idx),
+ prologue_done_(false),
+ zero_outside_bounds_(false)
+ {
+ // Compute warp location within threadblock tile by mapping the warp_id to
+ // three coordinates:
+ // _m: the warp's position within the threadblock along the M dimension
+ // _n: the warp's position within the threadblock along the N dimension
+ // _k: the warp's position within the threadblock along the K dimension
+
+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
+
+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
+
+ // Add per-warp offsets in units of warp-level tiles
+ this->warp_tile_iterator_A_.add_tile_offset(
+ {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
+ this->warp_tile_iterator_B_.add_tile_offset(
+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
+ }
+ CUTLASS_DEVICE
+ CustomMmaMultistage(
+ ///< Shared storage needed for internal use by threadblock-scoped GEMM
+ typename Base::SharedStorage& st,
+ ///< ID within the threadblock
+ int thread_idx,
+ ///< ID of warp
+ int warp_idx,
+ ///< ID of each thread within a warp
+ int lane_idx)
+ : CustomMmaMultistage(st.operand_A, st.operand_B, thread_idx, warp_idx, lane_idx)
+ {
+ }
+
+ CUTLASS_DEVICE
+ void set_prologue_done(bool value) { prologue_done_ = value; }
+
+ CUTLASS_DEVICE
+ void set_zero_outside_bounds(bool value) { zero_outside_bounds_ = value; }
+
+ template <bool kLoadA = true, bool kLoadB = true>
+ CUTLASS_DEVICE static void prologue(typename Base::SharedStorage& shared_storage,
+ ///< iterator over A operand in global memory
+ IteratorA iterator_A,
+ ///< iterator over B operand in global memory
+ IteratorB iterator_B,
+ int thread_idx,
+ int problem_size_k)
+ {
+ prologue<kLoadA, kLoadB>(shared_storage.operand_A,
+ shared_storage.operand_B,
+ iterator_A,
+ iterator_B,
+ thread_idx,
+ problem_size_k);
+ }
+
+ template <bool kLoadA = true, bool kLoadB = true>
+ CUTLASS_DEVICE static void prologue(typename Base::SharedStorageA& shared_storageA,
+ typename Base::SharedStorageB& shared_storageB,
+ ///< iterator over A operand in global memory
+ IteratorA iterator_A,
+ ///< iterator over B operand in global memory
+ IteratorB iterator_B,
+ int thread_idx,
+ int problem_size_k)
+ {
+ SmemIteratorA smem_iterator_A(shared_storageA.ref(), thread_idx);
+ SmemIteratorB smem_iterator_B(shared_storageB.ref(), thread_idx);
+ int32_t iter = (problem_size_k + Base::Shape::kK - 1) / Base::Shape::kK;
+ _prologue<kLoadA, kLoadB>(iterator_A, iterator_B, iter, smem_iterator_A, smem_iterator_B);
+ }
+
+ CUTLASS_DEVICE
+ void copy_tiles_and_advance(IteratorA& iterator_A,
+ IteratorB& iterator_B,
+ int group_start_A = 0,
+ int group_start_B = 0)
+ {
+ iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector);
+ this->smem_iterator_A_.set_iteration_index(group_start_A);
+
+ // Async Copy for operand A
+ CUTLASS_PRAGMA_UNROLL
+ for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) {
+ if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) {
+ typename IteratorA::AccessType* dst_ptr =
+ reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());
+
+ int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
+ IteratorA::ThreadMap::kElementsPerAccess /
+ IteratorA::kAccessesPerVector / 8;
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
+ auto gmem_ptr = iterator_A.get();
+
+ if (zero_outside_bounds_ ||
+ SharedMemoryClear == SharedMemoryClearOption::kZfill) {
+ cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
+ dst_ptr + v, gmem_ptr, iterator_A.valid());
+ } else {
+ cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(
+ dst_ptr + v, gmem_ptr, iterator_A.valid());
+ }
+
+ ++iterator_A;
+ }
+
+ ++this->smem_iterator_A_;
+ }
+ }
+
+ iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector);
+ this->smem_iterator_B_.set_iteration_index(group_start_B);
+
+ // Async Copy for operand B
+ CUTLASS_PRAGMA_UNROLL
+ for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
+ if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
+ typename IteratorB::AccessType* dst_ptr =
+ reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());
+
+ int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
+ IteratorB::ThreadMap::kElementsPerAccess /
+ IteratorB::kAccessesPerVector / 8;
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
+ auto gmem_ptr = iterator_B.get();
+
+ if (zero_outside_bounds_ ||
+ SharedMemoryClear == SharedMemoryClearOption::kZfill) {
+ cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
+ dst_ptr + v, gmem_ptr, iterator_B.valid());
+ } else {
+ cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(
+ dst_ptr + v, gmem_ptr, iterator_B.valid());
+ }
+
+ ++iterator_B;
+ }
+ ++this->smem_iterator_B_;
+ }
+ }
+ }
+
+ template <bool kLoadA = true, bool kLoadB = true>
+ CUTLASS_DEVICE static void _prologue(IteratorA& iterator_A,
+ IteratorB& iterator_B,
+ int32_t& gemm_k_iterations,
+ SmemIteratorA& smem_iterator_A_,
+ SmemIteratorB& smem_iterator_B_)
+ {
+ // Issue several complete stages
+ CUTLASS_PRAGMA_UNROLL
+ for (int stage = 0; stage < kNumStagesConcurrentLoad; ++stage, --gemm_k_iterations) {
+ iterator_A.clear_mask(gemm_k_iterations == 0);
+ iterator_B.clear_mask(gemm_k_iterations == 0);
+
+ iterator_A.set_iteration_index(0);
+ smem_iterator_A_.set_iteration_index(0);
+
+ // Async Copy for operand A
+ CUTLASS_PRAGMA_UNROLL
+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
+ typename IteratorA::AccessType* dst_ptr =
+ reinterpret_cast<typename IteratorA::AccessType*>(smem_iterator_A_.get());
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
+ int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
+ IteratorA::ThreadMap::kElementsPerAccess /
+ IteratorA::kAccessesPerVector / 8;
+
+ int src_bytes = (iterator_A.valid() ? kSrcBytes : 0);
+
+ if (kLoadA) {
+ cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
+ dst_ptr + v, iterator_A.get(), iterator_A.valid());
+ }
+
+ ++iterator_A;
+ }
+
+ ++smem_iterator_A_;
+ }
+
+ iterator_B.set_iteration_index(0);
+ smem_iterator_B_.set_iteration_index(0);
+
+ // Async Copy for operand B
+ CUTLASS_PRAGMA_UNROLL
+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
+ typename IteratorB::AccessType* dst_ptr =
+ reinterpret_cast<typename IteratorB::AccessType*>(smem_iterator_B_.get());
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
+ int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
+ IteratorB::ThreadMap::kElementsPerAccess /
+ IteratorB::kAccessesPerVector / 8;
+
+ if (kLoadB) {
+ cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
+ dst_ptr + v, iterator_B.get(), iterator_B.valid());
+ }
+
+ ++iterator_B;
+ }
+
+ ++smem_iterator_B_;
+ }
+
+ // Move to the next stage
+ iterator_A.add_tile_offset({0, 1});
+ iterator_B.add_tile_offset({1, 0});
+
+ smem_iterator_A_.add_tile_offset({0, 1});
+ smem_iterator_B_.add_tile_offset({1, 0});
+
+ // Defines the boundary of a stage of cp.async.
+ cutlass::arch::cp_async_fence();
+ }
+ }
+
+ /// Perform a threadblock-scoped matrix multiply-accumulate
+ CUTLASS_DEVICE
+ void operator()(
+ ///< problem size of GEMM
+ int gemm_k_iterations,
+ ///< destination accumulator tile
+ FragmentC& accum,
+ ///< iterator over A operand in global memory
+ IteratorA iterator_A,
+ ///< iterator over B operand in global memory
+ IteratorB iterator_B,
+ ///< initial value of accumulator
+ FragmentC const& src_accum)
+ {
+ //
+ // Prologue
+ //
+
+ if (!prologue_done_) {
+ _prologue<true, true>(
+ iterator_A, iterator_B, gemm_k_iterations, smem_iterator_A_, smem_iterator_B_);
+ } else if (!kSmemContainsEntireMat) {
+ _prologue<false, false>(
+ iterator_A, iterator_B, gemm_k_iterations, smem_iterator_A_, smem_iterator_B_);
+ } else {
+ gemm_k_iterations -= kNumStagesConcurrentLoad;
+ }
+
+ // Perform accumulation in the 'd' output operand
+ accum = src_accum;
+
+ //
+ // Clear the remaining tiles of SMEM. This is a functional requirement for
+ // some kernels so that all accumulator elements outside the GEMM footprint
+ // are zero.
+ //
+
+ if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) {
+ /// Iterator to write threadblock-scoped tile of A operand to shared
+ /// memory
+ SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);
+
+ typename IteratorA::AccessType zero_A;
+ zero_A.clear();
+
+ last_smem_iterator_A.set_iteration_index(0);
+
+ // Async Copy for operand A
+ CUTLASS_PRAGMA_UNROLL
+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
+ typename IteratorA::AccessType* dst_ptr =
+ reinterpret_cast<typename IteratorA::AccessType*>(last_smem_iterator_A.get());
+
+ *dst_ptr = zero_A;
+
+ ++last_smem_iterator_A;
+ }
+
+ /// Iterator to write threadblock-scoped tile of B operand to shared
+ /// memory
+ SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
+ typename IteratorB::AccessType zero_B;
+
+ zero_B.clear();
+ last_smem_iterator_B.set_iteration_index(0);
+
+ // Async Copy for operand B
+ CUTLASS_PRAGMA_UNROLL
+ for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
+ typename IteratorB::AccessType* dst_ptr =
+ reinterpret_cast<typename IteratorB::AccessType*>(last_smem_iterator_B.get());
+
+ *dst_ptr = zero_B;
+
+ ++last_smem_iterator_B;
+ }
+ }
+
+ // Waits until kStages-2 stages have committed.
+ cutlass::arch::cp_async_wait<kNumStagesConcurrentLoad - 1>();
+ __syncthreads();
+
+ // Pair of fragments used to overlap shared memory loads and math
+ // instructions
+ WarpLoadedFragmentA warp_loaded_frag_A[2];
+ WarpLoadedFragmentB warp_loaded_frag_B[2];
+ WarpTransformedFragmentA warp_transformed_frag_A[2];
+ WarpTransformedFragmentB warp_transformed_frag_B[2];
+
+ Operator warp_mma;
+
+ this->warp_tile_iterator_A_.set_kgroup_index(0);
+ this->warp_tile_iterator_B_.set_kgroup_index(0);
+
+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]);
+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]);
+
+ ++this->warp_tile_iterator_A_;
+ ++this->warp_tile_iterator_B_;
+
+ iterator_A.clear_mask(gemm_k_iterations == 0);
+ iterator_B.clear_mask(gemm_k_iterations == 0);
+
+ int smem_write_stage_idx = Base::kStages - 1;
+ int smem_read_stage_idx = 0;
+
+ warp_mma.transform(warp_transformed_frag_A[0],
+ warp_transformed_frag_B[0],
+ warp_loaded_frag_A[0],
+ warp_loaded_frag_B[0]);
+
+ // tf32x3 kernels use staging accumulation. warp_mma uses a temporary
+ // accumulator and this temporary accumulator is added to the final
+ // accumulator once in every mainloop iteration.
+ plus<FragmentC> plus_accum;
+
+ FragmentC tmp_accum;
+
+ if (platform::is_same<typename Operator::MathOperator, arch::OpMultiplyAddFastF32>::value ||
+     platform::is_same<typename Operator::MathOperator, arch::OpMultiplyAddComplexFastF32>::value) {
+ tmp_accum.clear();
+ }
+
+ //
+ // Mainloop
+ //
+
+ CUTLASS_GEMM_LOOP
+ for (; gemm_k_iterations > (-kNumStagesConcurrentLoad);) {
+ //
+ // Loop over GEMM K dimension
+ //
+
+ // Computes a warp-level GEMM on data held in shared memory
+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
+ CUTLASS_PRAGMA_UNROLL
+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
+ // Load warp-level tiles from shared memory, wrapping to k offset if
+ // this is the last group as the case may be.
+
+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) %
+ Base::kWarpGemmIterations);
+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) %
+ Base::kWarpGemmIterations);
+
+ // In case of a non-circular buffer ("kSmemContainsEntireMat")
+ // make sure we don't load out of bounds data.
+ if (!kSmemContainsEntireMat || gemm_k_iterations > (-kNumStagesConcurrentLoad) ||
+ warp_mma_k < Base::kWarpGemmIterations - 1) {
+ this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]);
+ this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]);
+ }
+
+ ++this->warp_tile_iterator_A_;
+ ++this->warp_tile_iterator_B_;
+
+ if (warp_mma_k > 0)
+ warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2],
+ warp_transformed_frag_B[warp_mma_k % 2],
+ warp_loaded_frag_A[warp_mma_k % 2],
+ warp_loaded_frag_B[warp_mma_k % 2]);
+
+ if (platform::is_same<typename Operator::MathOperator, arch::OpMultiplyAddFastF32>::value ||
+     platform::is_same<typename Operator::MathOperator, arch::OpMultiplyAddComplexFastF32>::value) {
+ warp_mma(tmp_accum,
+ warp_transformed_frag_A[warp_mma_k % 2],
+ warp_transformed_frag_B[warp_mma_k % 2],
+ tmp_accum);
+
+ if (warp_mma_k == 0) {
+ accum = plus_accum(accum, tmp_accum);
+ tmp_accum.clear();
+ }
+ } else {
+ warp_mma(accum,
+ warp_transformed_frag_A[warp_mma_k % 2],
+ warp_transformed_frag_B[warp_mma_k % 2],
+ accum);
+ }
+
+ // Issue global->shared copies for this stage
+ if (!kSmemContainsEntireMat && warp_mma_k < Base::kWarpGemmIterations - 1) {
+ int group_start_iteration_A, group_start_iteration_B;
+
+ group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
+ group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
+
+ copy_tiles_and_advance(
+ iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
+ }
+
+ if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
+ if (!kSmemContainsEntireMat) {
+ int group_start_iteration_A, group_start_iteration_B;
+ group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
+ group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
+
+ copy_tiles_and_advance(iterator_A,
+ iterator_B,
+ group_start_iteration_A,
+ group_start_iteration_B);
+ }
+
+ // Inserts a memory fence between stages of cp.async instructions.
+ cutlass::arch::cp_async_fence();
+
+ // Waits until kStages-2 stages have committed.
+ cutlass::arch::cp_async_wait<kNumStagesConcurrentLoad - 1>();
+ __syncthreads();
+
+ // Move to the next stage
+ iterator_A.add_tile_offset({0, 1});
+ iterator_B.add_tile_offset({1, 0});
+
+ this->smem_iterator_A_.add_tile_offset({0, 1});
+ this->smem_iterator_B_.add_tile_offset({1, 0});
+
+ // Add negative offsets to return iterators to the 'start' of the
+ // circular buffer in shared memory
+ if (smem_write_stage_idx == (Base::kStages - 1)) {
+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
+ smem_write_stage_idx = 0;
+ } else {
+ ++smem_write_stage_idx;
+ }
+
+ if (!kSmemContainsEntireMat && smem_read_stage_idx == (Base::kStages - 1)) {
+ this->warp_tile_iterator_A_.add_tile_offset(
+ {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
+ this->warp_tile_iterator_B_.add_tile_offset(
+ {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
+ smem_read_stage_idx = 0;
+ } else {
+ ++smem_read_stage_idx;
+ }
+
+ --gemm_k_iterations;
+ iterator_A.clear_mask(gemm_k_iterations == 0);
+ iterator_B.clear_mask(gemm_k_iterations == 0);
+ }
+
+ // Do any conversions feeding the first stage at the end of the loop so
+ // we can start right away on mma instructions
+ if (warp_mma_k + 1 == Base::kWarpGemmIterations)
+ warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2],
+ warp_transformed_frag_B[(warp_mma_k + 1) % 2],
+ warp_loaded_frag_A[(warp_mma_k + 1) % 2],
+ warp_loaded_frag_B[(warp_mma_k + 1) % 2]);
+ }
+ }
+
+ if (platform::is_same<typename Operator::MathOperator, arch::OpMultiplyAddFastF32>::value ||
+     platform::is_same<typename Operator::MathOperator, arch::OpMultiplyAddComplexFastF32>::value) {
+ accum = plus_accum(accum, tmp_accum);
+ }
+ }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace threadblock
+} // namespace gemm
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
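
`CustomMmaMultistage` keeps a `kStages`-deep circular buffer in shared memory: the prologue issues `cp.async` copies for the first `kNumStagesConcurrentLoad` stages, and the mainloop then keeps the write index `kStages - 1` tiles ahead of the read index, wrapping both back to the start of the buffer. A host-side sketch of just that index bookkeeping (illustrative only, not the kernel code):

```cpp
#include <cstdio>

// Sketch of the circular staging indices in CustomMmaMultistage's mainloop:
// smem_write_stage_idx starts at kStages - 1 (the prologue filled the earlier
// stages) while smem_read_stage_idx starts at 0; both wrap modulo kStages,
// mirroring the add_tile_offset({0, -kStages}) resets in the code above.
int main()
{
    const int kStages = 3;  // hypothetical stage count
    int smem_write_stage_idx = kStages - 1;
    int smem_read_stage_idx = 0;

    for (int k_iter = 0; k_iter < 6; ++k_iter) {
        std::printf("k-iter %d: prefetch into stage %d, compute from stage %d\n",
                    k_iter, smem_write_stage_idx, smem_read_stage_idx);
        smem_write_stage_idx = (smem_write_stage_idx + 1) % kStages;
        smem_read_stage_idx = (smem_read_stage_idx + 1) % kStages;
    }
    return 0;
}
```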
diff --git a/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_pipelined.h b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_pipelined.h
new file mode 100644
index 000000000000..07b26ca31299
--- /dev/null
+++ b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_pipelined.h
@@ -0,0 +1,388 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
+ *reserved. SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice,
+ *this list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ *POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+/*! \file
+ \brief Template for a double-buffered threadblock-scoped GEMM kernel.
+*/
+
+#pragma once
+
+#include "cutlass/aligned_buffer.h"
+#include "cutlass/array.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/numeric_conversion.h"
+
+#include "cutlass/matrix_shape.h"
+#include "cutlass/numeric_types.h"
+
+#include "custom_mma_base.h"
+#include "cutlass/gemm/gemm.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+namespace threadblock {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Structure to compute the matrix product targeting CUDA cores and SIMT math
+/// instructions.
+template <
+ /// Size of the Gemm problem - concept: gemm::GemmShape<>
+ typename Shape_,
+ /// Iterates over tiles of A operand in global memory
+ // (concept: ReadableTileIterator | ForwardTileIterator |
+ // MaskedTileIterator)
+ typename IteratorA_,
+ /// Iterates over tiles of A operand in shared memory
+ /// (concept: WriteableTileIterator | RandomAccessTileIterator)
+ typename SmemIteratorA_,
+ /// Iterates over tiles of B operand in global memory
+ // (concept: ReadableTileIterator | ForwardTileIterator |
+ // MaskedTileIterator)
+ typename IteratorB_,
+ /// Iterates over tiles of B operand in shared memory
+ /// (concept: WriteableTileIterator | RandomAccessTileIterator)
+ typename SmemIteratorB_,
+ /// Data type of accumulator matrix
+ typename ElementC_,
+ /// Data type of accumulator matrix
+ typename LayoutC_,
+ /// Policy describing tuning details (concept: MmaPolicy)
+ typename Policy_,
+ /// Transformation applied to A operand
+ typename TransformA_ = NumericArrayConverter<typename SmemIteratorA_::Element,
+                                              typename IteratorA_::Element,
+                                              IteratorA_::Fragment::kElements>,
+ ///
+ /// Transformation applied to B operand
+ typename TransformB_ = NumericArrayConverter<typename SmemIteratorB_::Element,
+                                              typename IteratorB_::Element,
+                                              IteratorB_::Fragment::kElements>,
+ /// Used for partial specialization
+ typename Enable = bool>
+class CustomMmaPipelined : public CustomMmaBase<Shape_, Policy_, 2> {
+public:
+ ///< Base class
+ using Base = CustomMmaBase<Shape_, Policy_, 2>;
+
+ using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
+ using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory
+ using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory
+ using ElementC = ElementC_; ///< Data type of accumulator matrix
+ using LayoutC = LayoutC_; ///< Layout of accumulator matrix
+ using Policy = Policy_; ///< Policy describing tuning details
+
+ using SmemIteratorA = SmemIteratorA_;
+ using SmemIteratorB = SmemIteratorB_;
+
+ using TransformA = TransformA_;
+ using TransformB = TransformB_;
+
+ //
+ // Dependent types
+ //
+
+ /// Fragment of operand A loaded from global memory
+ using FragmentA = typename IteratorA::Fragment;
+
+ /// Fragment of operand B loaded from global memory
+ using FragmentB = typename IteratorB::Fragment;
+
+ /// Fragment of accumulator tile
+ using FragmentC = typename Policy::Operator::FragmentC;
+
+ /// Warp-level Mma
+ using Operator = typename Policy::Operator;
+
+ /// Obtain the arch tag from the warp-level operator
+ using ArchTag = typename Policy::Operator::ArchTag;
+
+ /// Complex transform on A operand
+ static ComplexTransform const kTransformA = Operator::kTransformA;
+
+ /// Complex transform on B operand
+ static ComplexTransform const kTransformB = Operator::kTransformB;
+
+ // statically assert that kStages for MmaPipelined is two (double-buffered pipeline)
+ static_assert((Base::kStages == 2), "MmaPipelined requires kStages set to value 2");
+
+ static bool const kSmemContainsEntireMat = false;
+
+private:
+ using WarpFragmentA = typename Operator::FragmentA;
+ using WarpFragmentB = typename Operator::FragmentB;
+
+protected:
+ /// Iterator to write threadblock-scoped tile of A operand to shared memory
+ SmemIteratorA smem_iterator_A_;
+
+ /// Iterator to write threadblock-scoped tile of B operand to shared memory
+ SmemIteratorB smem_iterator_B_;
+
+public:
+ /// Construct from tensor references
+ CUTLASS_DEVICE
+ CustomMmaPipelined(typename Base::SharedStorageA& shared_storageA,
+ typename Base::SharedStorageB& shared_storageB,
+ int thread_idx, ///< ID within the threadblock
+ int warp_idx, ///< ID of warp
+ int lane_idx ///< ID of each thread within a warp
+ )
+ : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx),
+ smem_iterator_A_(shared_storageA.ref(), thread_idx),
+ smem_iterator_B_(shared_storageB.ref(), thread_idx)
+ {
+ // Compute warp location within threadblock tile by mapping the warp_id to
+ // three coordinates:
+ // _m: the warp's position within the threadblock along the M dimension
+ // _n: the warp's position within the threadblock along the N dimension
+ // _k: the warp's position within the threadblock along the K dimension
+
+ int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
+ int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
+
+ int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
+ int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
+
+ // Add per-warp offsets in units of warp-level tiles
+ this->warp_tile_iterator_A_.add_tile_offset(
+ {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
+ this->warp_tile_iterator_B_.add_tile_offset(
+ {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
+ }
+ CUTLASS_DEVICE
+ CustomMmaPipelined(
+ ///< Shared storage needed for internal use by threadblock-scoped GEMM
+ typename Base::SharedStorage& st,
+ ///< ID within the threadblock
+ int thread_idx,
+ ///< ID of warp
+ int warp_idx,
+ ///< ID of each thread within a warp
+ int lane_idx)
+ : CustomMmaPipelined(st.operand_A, st.operand_B, thread_idx, warp_idx, lane_idx)
+ {
+ }
+
+ CUTLASS_DEVICE
+ void set_prologue_done(bool value)
+ {
+ // NOT IMPLEMENTED FOR PIPELINED
+ }
+
+ CUTLASS_DEVICE
+ void set_zero_outside_bounds(bool value)
+ {
+ // NOT NEEDED FOR PIPELINED
+ // shared memory will always be zero-filled
+ }
+
+ template <bool kLoadA = true, bool kLoadB = true>
+ CUTLASS_DEVICE static void prologue(typename Base::SharedStorage& shared_storage,
+ ///< iterator over A operand in global memory
+ IteratorA iterator_A,
+ ///< iterator over B operand in global memory
+ IteratorB iterator_B,
+ int thread_idx,
+ int problem_size_k)
+ {
+ prologue<kLoadA, kLoadB>(shared_storage.operand_A,
+ shared_storage.operand_B,
+ iterator_A,
+ iterator_B,
+ thread_idx,
+ problem_size_k);
+ }
+
+ template <bool kLoadA = true, bool kLoadB = true>
+ CUTLASS_DEVICE static void prologue(typename Base::SharedStorageA& shared_storageA,
+ typename Base::SharedStorageB& shared_storageB,
+ ///< iterator over A operand in global memory
+ IteratorA iterator_A,
+ ///< iterator over B operand in global memory
+ IteratorB iterator_B,
+ int thread_idx,
+ int problem_size_k)
+ {
+ // NOT IMPLEMENTED FOR PIPELINED
+ }
+
+ /// Perform a threadblock-scoped matrix multiply-accumulate
+ CUTLASS_DEVICE
+ void operator()(
+ int gemm_k_iterations, ///< number of iterations of the mainloop
+ FragmentC& accum, ///< destination accumulator tile
+ IteratorA iterator_A, ///< iterator over A operand in global memory
+ IteratorB iterator_B, ///< iterator over B operand in global memory
+ FragmentC const& src_accum, ///< source accumulator tile
+ TransformA transform_A = TransformA(), ///< transformation applied to A fragment
+ TransformB transform_B = TransformB())
+ { ///< transformation applied to B fragment
+
+ //
+ // Prologue
+ //
+
+ // Perform accumulation in the 'd' output operand
+ accum = src_accum;
+
+ FragmentA tb_frag_A;
+ FragmentB tb_frag_B;
+
+ tb_frag_A.clear();
+ tb_frag_B.clear();
+
+ // The last kblock is loaded in the prolog
+ iterator_A.load(tb_frag_A);
+ iterator_B.load(tb_frag_B);
+
+ ++iterator_A;
+ ++iterator_B;
+
+ this->smem_iterator_A_.store(transform_A(tb_frag_A));
+ this->smem_iterator_B_.store(transform_B(tb_frag_B));
+
+ ++this->smem_iterator_A_;
+ ++this->smem_iterator_B_;
+
+ __syncthreads();
+
+ // Pair of fragments used to overlap shared memory loads and math
+ // instructions
+ WarpFragmentA warp_frag_A[2];
+ WarpFragmentB warp_frag_B[2];
+
+ this->warp_tile_iterator_A_.set_kgroup_index(0);
+ this->warp_tile_iterator_B_.set_kgroup_index(0);
+
+ this->warp_tile_iterator_A_.load(warp_frag_A[0]);
+ this->warp_tile_iterator_B_.load(warp_frag_B[0]);
+
+ ++this->warp_tile_iterator_A_;
+ ++this->warp_tile_iterator_B_;
+
+ Operator warp_mma;
+
+ int smem_write_stage_idx = 1;
+
+ // Avoid reading out of bounds
+ iterator_A.clear_mask(gemm_k_iterations <= 1);
+ iterator_B.clear_mask(gemm_k_iterations <= 1);
+
+ // Issue loads during the first warp-level matrix multiply-add *AFTER*
+ // issuing shared memory loads (which have the tightest latency requirement).
+
+ //
+ // Mainloop
+ //
+
+ // Note: The main loop does not support Base::kWarpGemmIterations == 2.
+ CUTLASS_GEMM_LOOP
+ for (; gemm_k_iterations > 0; --gemm_k_iterations) {
+ //
+ // Loop over GEMM K dimension
+ //
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
+ // Load warp-level tiles from shared memory, wrapping to k offset if
+ // this is the last group as the case may be.
+
+ if (warp_mma_k == Base::kWarpGemmIterations - 1) {
+ // Write fragments to shared memory
+ this->smem_iterator_A_.store(transform_A(tb_frag_A));
+
+ this->smem_iterator_B_.store(transform_B(tb_frag_B));
+
+ __syncthreads();
+
+ ++this->smem_iterator_A_;
+ ++this->smem_iterator_B_;
+
+ // Add negative offsets to return iterators to the 'start' of the
+ // circular buffer in shared memory
+ if (smem_write_stage_idx == 1) {
+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
+ } else {
+ this->warp_tile_iterator_A_.add_tile_offset(
+ {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
+ this->warp_tile_iterator_B_.add_tile_offset(
+ {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
+ }
+
+ smem_write_stage_idx ^= 1;
+ }
+
+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) %
+ Base::kWarpGemmIterations);
+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) %
+ Base::kWarpGemmIterations);
+
+ this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
+ this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]);
+
+ ++this->warp_tile_iterator_A_;
+ ++this->warp_tile_iterator_B_;
+
+ if (warp_mma_k == 0) {
+ iterator_A.load(tb_frag_A);
+ iterator_B.load(tb_frag_B);
+
+ ++iterator_A;
+ ++iterator_B;
+
+ // Avoid reading out of bounds if this was the last loop iteration
+ iterator_A.clear_mask(gemm_k_iterations <= 2);
+ iterator_B.clear_mask(gemm_k_iterations <= 2);
+ }
+
+ warp_mma(accum, warp_frag_A[warp_mma_k % 2], warp_frag_B[warp_mma_k % 2], accum);
+ }
+ }
+ }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace threadblock
+} // namespace gemm
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
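
`CustomMmaPipelined` is the two-stage special case: stage 0 is filled in the prologue, and the mainloop toggles `smem_write_stage_idx` with an XOR so that global loads for the next K iteration always land in the buffer the math units are not currently reading. A tiny sketch of that ping-pong (illustrative, not the kernel code):

```cpp
#include <cstdio>

// Sketch of the double-buffer ping-pong in CustomMmaPipelined: with kStages
// fixed at 2, `smem_write_stage_idx ^= 1` swaps the roles of the two shared
// memory buffers every K iteration.
int main()
{
    int smem_write_stage_idx = 1;  // stage 0 was filled by the prologue
    for (int k_iter = 0; k_iter < 4; ++k_iter) {
        int read_stage = smem_write_stage_idx ^ 1;
        std::printf("k-iter %d: compute from stage %d, prefetch into stage %d\n",
                    k_iter, read_stage, smem_write_stage_idx);
        smem_write_stage_idx ^= 1;  // swap buffers for the next iteration
    }
    return 0;
}
```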
diff --git a/csrc/deepspeed4science/evoformer_attn/gemm/find_default_mma.h b/csrc/deepspeed4science/evoformer_attn/gemm/find_default_mma.h
new file mode 100644
index 000000000000..163dcbf85259
--- /dev/null
+++ b/csrc/deepspeed4science/evoformer_attn/gemm/find_default_mma.h
@@ -0,0 +1,191 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+/*! \file
+ \brief CUTLASS provides helper templates to figure out the right data
+ structures to instantiate for a GEMM with a given set of parameters (see
+ `cutlass/gemm/threadblock/default_mma.h`). However, due to template
+ instantiation priority rules, it will only create an MmaMultistage with
+ kStages=3 (otherwise it creates an MmaPipelined, which is not compatible with
+ FastF32). kStages=3 uses too much shared memory and we want to use kStages=2,
+ so we copied some code from the `default_mma.h` and `default_mma_core.h`
+ files and wrapped this template to support our use case.
+
+ This is really only for the FastF32 case, i.e. using TensorCores with fp32.
+*/
+
+#pragma once
+
+#include "cutlass/gemm/threadblock/default_mma.h"
+#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
+
+namespace cutlass {
+namespace gemm {
+namespace threadblock {
+
+template <
+ /// Element type for A matrix operand
+ typename ElementA,
+ /// Layout type for A matrix operand
+ typename LayoutA,
+ /// Access granularity of A matrix in units of elements
+ int kAlignmentA,
+ /// Element type for B matrix operand
+ typename ElementB,
+ /// Layout type for B matrix operand
+ typename LayoutB,
+ /// Access granularity of B matrix in units of elements
+ int kAlignmentB,
+ /// Element type for internal accumulation
+ typename ElementAccumulator,
+ /// Layout type for C and D matrix operand
+ typename LayoutC,
+ /// Operator class tag
+ typename OperatorClass,
+ /// Tag indicating architecture to tune for
+ typename ArchTag,
+ /// Threadblock-level tile size (concept: GemmShape)
+ typename ThreadblockShape,
+ /// Warp-level tile size (concept: GemmShape)
+ typename WarpShape,
+ /// Instruction-level tile size (concept: GemmShape)
+ typename InstructionShape,
+ /// Number of stages used in the pipelined mainloop
+ int Stages,
+ /// Operation performed by GEMM
+ typename Operator,
+ typename Enable_ = void>
+struct FindDefaultMma {
+ static constexpr bool AccumulatorsInRowMajor = false;
+ static constexpr SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone;
+ using DefaultMma = cutlass::gemm::threadblock::DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB,
+                                                           LayoutB, kAlignmentB, ElementAccumulator,
+                                                           LayoutC, OperatorClass, ArchTag,
+                                                           ThreadblockShape, WarpShape,
+                                                           InstructionShape, Stages, Operator,
+                                                           AccumulatorsInRowMajor, SharedMemoryClear>;
+};
+
+/// Specialization for sm80 / FastF32 / multistage with kStages=2
+template <typename ElementA_, typename LayoutA_, int kAlignmentA, typename ElementB_,
+          typename LayoutB_, int kAlignmentB, typename ElementAccumulator, typename ThreadblockShape,
+          typename WarpShape, typename InstructionShape, int kStages, typename Operator>
+struct FindDefaultMma<ElementA_, LayoutA_, kAlignmentA, ElementB_, LayoutB_, kAlignmentB,
+                      ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
+                      ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
+                      typename cutlass::platform::enable_if<(kAlignmentA > 1)>::type> {
+ using LayoutC = layout::RowMajor;
+ using OperatorClass = arch::OpClassTensorOp;
+ using ArchTag = arch::Sm80;
+
+ using DefaultMma_ = cutlass::gemm::threadblock::DefaultMma<ElementA_, LayoutA_, kAlignmentA,
+                                                            ElementB_, LayoutB_, kAlignmentB,
+                                                            ElementAccumulator, LayoutC,
+                                                            OperatorClass, ArchTag, ThreadblockShape,
+                                                            WarpShape, InstructionShape, 3, Operator>;
+ struct DefaultMma : DefaultMma_ {
+ using MmaCore_ = typename DefaultMma_::MmaCore;
+ // Define the threadblock-scoped multistage matrix multiply
+ using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage<
+     typename MmaCore_::Shape, typename DefaultMma_::IteratorA, typename MmaCore_::SmemIteratorA,
+     MmaCore_::kCacheOpA, typename DefaultMma_::IteratorB, typename MmaCore_::SmemIteratorB,
+     MmaCore_::kCacheOpB, ElementAccumulator, LayoutC, typename MmaCore_::MmaPolicy, kStages>;
+ };
+};
+
+} // namespace threadblock
+} // namespace gemm
+} // namespace cutlass
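
The wrapper above keeps everything `DefaultMma` would have chosen and only swaps the threadblock-scoped mainloop type. A generic sketch of that wrapping pattern (the names below are illustrative placeholders, not the CUTLASS types):

```cpp
#include <type_traits>

// Sketch of the trait-wrapping pattern used by FindDefaultMma: inherit the
// default trait wholesale and override only the nested alias we care about.
struct DefaultTrait {
    struct PipelinedMainloop {};         // what the default would pick
    using Mainloop = PipelinedMainloop;  // the alias we want to replace
};

struct MultistageMainloop {};  // the variant we actually want

struct CustomTrait : DefaultTrait {
    // Everything else is inherited; only the mainloop type is swapped.
    using Mainloop = MultistageMainloop;
};

static_assert(std::is_same<CustomTrait::Mainloop, MultistageMainloop>::value,
              "the wrapper exposes the overridden mainloop type");

int main() { return 0; }
```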
diff --git a/csrc/deepspeed4science/evoformer_attn/gemm/mma_accum_lambda_iterator.h b/csrc/deepspeed4science/evoformer_attn/gemm/mma_accum_lambda_iterator.h
new file mode 100644
index 000000000000..5e2f0cf681bf
--- /dev/null
+++ b/csrc/deepspeed4science/evoformer_attn/gemm/mma_accum_lambda_iterator.h
@@ -0,0 +1,347 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#pragma once
+
+#include "cutlass/functional.h"
+#include "cutlass/gemm/warp/mma_simt_tile_iterator.h"
+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h"
+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h"
+#include "cutlass/matrix_shape.h"
+
+/*
+TensorCores have different accumulator layouts.
+This file provides a class to easily map the i-th accumulator
+element to the corresponding matrix row/column.
+*/
+
+template <typename T, typename accum_t, int kWarpSize>
+struct AccumLambdaIteratorSm80 {
+ static_assert(cutlass::platform::is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
+ "only RowMajor is supported");
+
+ using Policy = typename T::Policy;
+ using InstructionShape = typename T::InstructionShape;
+ using OpDelta = typename T::OpDelta;
+ using Shape = typename T::Shape;
+ static int const kElementsPerAccess = InstructionShape::kN / 4;
+ static int const kRowsPerTile = 8;
+ static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile;
+
+ static cutlass::MatrixCoord CUTLASS_DEVICE
+ get_lane_offset(int8_t lane_id, int8_t warp_id, typename T::TensorCoord const& tile_offset)
+ {
+ int quad = (lane_id >> 2);
+ int lane_in_quad = (lane_id & 3);
+ return cutlass::MatrixCoord(
+ quad + tile_offset.row() * Shape::kRow,
+ lane_in_quad * kElementsPerAccess + tile_offset.column() * Shape::kColumn);
+ }
+
+ template <typename FA, typename FB, typename FC>
+ CUTLASS_DEVICE static void iterateRows(cutlass::MatrixCoord& lane_offset,
+ FA beginRow,
+ FB op,
+ FC endRow)
+ {
+ // See cutlass/gemm/warp/mma_tensor_op_tile_iterator.h
+ CUTLASS_PRAGMA_UNROLL
+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
+ CUTLASS_PRAGMA_UNROLL
+ for (int row = 0; row < kAccumulatorRows; ++row) {
+ int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + row * kRowsPerTile +
+ lane_offset.row();
+ beginRow(accum_m);
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) {
+ int mma_accum_start = kAccumulatorRows * kElementsPerAccess *
+ (mma_n * Policy::MmaIterations::kRow + mma_m);
+ CUTLASS_PRAGMA_UNROLL
+ for (int col = 0; col < kElementsPerAccess; ++col) {
+ int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col +
+ lane_offset.column();
+ int idx = mma_accum_start + row * kElementsPerAccess + col;
+ op(accum_m, accum_n, idx);
+ }
+ }
+
+ endRow(accum_m);
+ }
+ }
+ }
+
+ template <typename DT, typename F>
+ CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn)
+ {
+ // In each warp, 4 threads will work on the same row
+ // - the ones with the same `quad`
+ auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1);
+ myValue = fn(myValue, otherV);
+ otherV = __shfl_xor_sync(0xffffffff, myValue, 2);
+ myValue = fn(myValue, otherV);
+ int lane_in_quad = (lane_id & 3);
+ return lane_in_quad == 0;
+ }
+};
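
In the Sm80 accumulator layout the four lanes of a quad hold elements of the same output row, which is why `reduceSameRow` above needs only two XOR shuffles (masks 1 and 2) before the first lane of each quad owns the row-wise result. A standalone CUDA sketch of that quad reduction, using a plain sum as the combine function (hypothetical kernel, for illustration only):

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Sketch of the quad reduction in AccumLambdaIteratorSm80::reduceSameRow:
// two butterfly shuffles combine the four lanes of each quad, and the first
// lane of the quad ((lane_id & 3) == 0) keeps the reduced value.
__global__ void quad_reduce_demo(float* out)
{
    int lane_id = threadIdx.x & 31;
    float my_value = static_cast<float>(lane_id);  // hypothetical per-lane partial

    float other = __shfl_xor_sync(0xffffffff, my_value, 1);
    my_value += other;
    other = __shfl_xor_sync(0xffffffff, my_value, 2);
    my_value += other;

    if ((lane_id & 3) == 0) {          // one writer per quad
        out[lane_id >> 2] = my_value;  // sum of the quad's four lanes
    }
}

int main()
{
    float* out = nullptr;
    cudaMalloc(&out, 8 * sizeof(float));
    quad_reduce_demo<<<1, 32>>>(out);
    float host[8];
    cudaMemcpy(host, out, sizeof(host), cudaMemcpyDeviceToHost);
    for (int i = 0; i < 8; ++i) { std::printf("quad %d -> %.0f\n", i, host[i]); }
    cudaFree(out);
    return 0;
}
```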
+
+template <typename T, typename accum_t, int kWarpSize>
+struct AccumLambdaIteratorSm70 {
+ static_assert(cutlass::platform::is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
+ "only RowMajor is supported");
+
+ using Policy = typename T::Policy;
+ using InstructionShape = typename T::InstructionShape;
+ using OpDelta = typename T::OpDelta;
+ using Shape = typename T::Shape;
+ using Element = accum_t;
+
+ static int const kElementsPerPartial = 4;
+ using EleShapePerPatial =
+ typename cutlass::platform::conditional<cutlass::platform::is_same<Element, float>::value,
+ cutlass::MatrixShape<2, 2>,
+ cutlass::MatrixShape<1, 4>>::type;
+ static int const kElementsPerMma = 8;
+ static int const kAccumulatorPatials = 2;
+ using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>;
+
+ static cutlass::MatrixCoord CUTLASS_DEVICE
+ get_lane_offset(int8_t lane_id, int8_t warp_id, typename T::TensorCoord const& tile_offset)
+ {
+ int quad = (lane_id >> 2);
+ int lane_in_quad = (lane_id & 3);
+ int accum_m, accum_n;
+
+ if (cutlass::platform::is_same<Element, float>::value) {
+ // (quad[2],quad[0])+lane_in_quad[0]
+ accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1);
+ // (quad[1])+lane_in_quad[1]
+ accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials +
+ (lane_in_quad & 2);
+ } else {
+ accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + lane_in_quad; // (quad[2],quad[0])
+ accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials;
+ }
+ return cutlass::MatrixCoord(accum_m + tile_offset.row() * Shape::kRow,
+ accum_n + tile_offset.column() * Shape::kColumn);
+ }
+
+ template <typename DT, typename F>
+ CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn)
+ {
+ static_assert(cutlass::platform::is_same<Element, float>::value,
+ "update to support non-float accum");
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
+ // T0 & T2 share same line within a quad
+ auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 1);
+ myValue = fn(myValue, otherV);
+ // quad 0 and quad 2 are on the same lines
+ otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 3);
+ myValue = fn(myValue, otherV);
+ return (lane_id & ((1 << 1) | (1 << 3))) == 0;
+ }
+
+ template <typename FA, typename FB, typename FC>
+ CUTLASS_DEVICE static void iterateRows(cutlass::MatrixCoord& lane_offset,
+ FA beginRow,
+ FB op,
+ FC endRow)
+ {
+ CUTLASS_PRAGMA_UNROLL
+ for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) {
+ CUTLASS_PRAGMA_UNROLL
+ for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
+ CUTLASS_PRAGMA_UNROLL
+ for (int m = 0; m < EleShapePerPatial::kRow; ++m) {
+ int accum_m = tile_m * Policy::InterleavedTile::kRow +
+ mma_m * QuadShapePerPatialMma::kRow + m * 2 + lane_offset.row();
+ beginRow(accum_m);
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) {
+ CUTLASS_PRAGMA_UNROLL
+ for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) {
+ CUTLASS_PRAGMA_UNROLL
+ for (int p = 0; p < kAccumulatorPatials; ++p) {
+ CUTLASS_PRAGMA_UNROLL
+ for (int n = 0; n < EleShapePerPatial::kColumn; ++n) {
+ int mma_accum_start =
+ (((tile_n * Policy::TileIterations::kRow + tile_m) *
+ Policy::MmaIterations::kColumn +
+ mma_n) *
+ Policy::MmaIterations::kRow +
+ mma_m) *
+ kElementsPerMma;
+ int accum_n = tile_n * Policy::InterleavedTile::kColumn +
+ mma_n * QuadShapePerPatialMma::kColumn +
+ p * Policy::InterleavedTile::kColumn / 2 + n +
+ lane_offset.column();
+ int idx = mma_accum_start + p * kElementsPerPartial +
+ m * EleShapePerPatial::kColumn + n;
+ op(accum_m, accum_n, idx);
+ }
+ }
+ }
+ }
+ endRow(accum_m);
+ }
+ }
+ }
+ }
+};
+
+template <typename T, typename accum_t, int kWarpSize>
+struct AccumLambdaIteratorSimt {
+ using Policy = typename T::Policy;
+ using Iterations = typename T::Iterations;
+ using Element = typename T::Element;
+ using Delta = typename T::Delta;
+ using Shape = typename T::Shape;
+ static_assert(cutlass::platform::is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
+ "only RowMajor is supported");
+
+ template <typename DT, typename F>
+ CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn)
+ {
+ CUTLASS_PRAGMA_UNROLL
+ for (int bit = 1; bit < Policy::WarpShape::kColumn; bit *= 2) {
+ auto otherV = __shfl_xor_sync(0xffffffff, myValue, bit);
+ myValue = fn(myValue, otherV);
+ }
+ return (lane_id & (Policy::WarpShape::kColumn - 1)) == 0;
+ }
+
+ template <typename FA, typename FB, typename FC>
+ CUTLASS_DEVICE static void iterateRows(cutlass::MatrixCoord& lane_offset,
+ FA beginRow,
+ FB op,
+ FC endRow)
+ {
+ CUTLASS_PRAGMA_UNROLL
+ for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {
+ CUTLASS_PRAGMA_UNROLL
+ for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) {
+ int accum_m = mma_m * Delta::kRow + m + lane_offset.row();
+ beginRow(accum_m);
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {
+ int accum_n = mma_n * Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN +
+ lane_offset.column();
+ CUTLASS_PRAGMA_UNROLL
+ for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) {
+ int idx = n + Policy::LaneMmaShape::kN *
+ (mma_n + Iterations::kColumn *
+ (m + mma_m * Policy::LaneMmaShape::kM));
+ op(accum_m, accum_n + n, idx);
+ }
+ }
+ endRow(accum_m);
+ }
+ }
+ }
+
+ static cutlass::MatrixCoord CUTLASS_DEVICE
+ get_lane_offset(int8_t lane_id, int8_t warp_id, typename T::TensorCoord const& tile_offset)
+ {
+ static_assert(cutlass::platform::is_same<typename Policy::LaneLayout,
+                                          cutlass::layout::RowMajorInterleaved<1>>::value,
+ "");
+ typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
+
+ cutlass::MatrixCoord lane_offset =
+ lane_layout.inverse(lane_id) *
+ cutlass::MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN);
+ return lane_offset + tile_offset * cutlass::MatrixCoord(Shape::kRow, Shape::kColumn);
+ }
+};
+
+template <typename T, typename accum_t, int kWarpSize>
+struct DefaultMmaAccumLambdaIterator;
+
+// Simt
+template <typename S, typename P, typename accum_t, int kWarpSize>
+struct DefaultMmaAccumLambdaIterator<
+ cutlass::gemm::warp::
+     MmaSimtTileIterator<S, cutlass::gemm::Operand::kC, accum_t, cutlass::layout::RowMajor, P, 1, 1>,
+ accum_t,
+ kWarpSize> {
+ using WarpIterator = typename cutlass::gemm::warp::
+     MmaSimtTileIterator<S, cutlass::gemm::Operand::kC, accum_t, cutlass::layout::RowMajor, P, 1, 1>;
+ using Iterator = AccumLambdaIteratorSimt<WarpIterator, accum_t, kWarpSize>;
+};
+
+// TensorOp - Volta
+template <typename S1, typename S2, typename accum_t, int kWarpSize>
+struct DefaultMmaAccumLambdaIterator<
+ cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator<S1,
+                                                              accum_t,
+                                                              cutlass::layout::RowMajor,
+                                                              S2,
+                                                              cutlass::MatrixShape<1, 1>>,
+ accum_t,
+ kWarpSize> {
+ using WarpIterator = typename cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator<
+ S1,
+ accum_t,
+ cutlass::layout::RowMajor,
+ S2,
+ cutlass::MatrixShape<1, 1>>;
+ using Iterator = AccumLambdaIteratorSm70