Commit
* Integrating evoformer attention
* add cutlass version check
* Update error message
* add benchmark
* Update
* Update evoformer_attn.py
* Update run_evoformer_test.py
* Update evoformer_attn.py
* Update run_evoformer_test.py
* support more GPU archs
* add copyright
* add tests
* Fix bugs
* Update benchmark
* update
* Fix nvcc macro
* clean code
* fix formatting
* fix yaml import
* skip unit test when not compatible
* fix yaml requirement
* revert changes
* update tutorial
* update
* fix formatting
* fix format
* skip evoformer attn in pre-compile-ops
* revert changes
* update tutorial
* fix cutlass check
* update tutorial
* refactor tutorial
* revise
* Updated the Megatron-DS section (#565)
* Updated the Megatron-DS section
* minor fix
* minor fix
* minor fix
* separate evoformer tutorial
* Revised the ds4science landing page (#566)
* Updated the Megatron-DS section
* minor fix
* minor fix
* minor fix
* Revised the landing page
* Revised the landing page
* Removing unused file
* fix links image position
* modify main page
* fix doc

---------

Co-authored-by: Shiyang Chen <[email protected]>
Co-authored-by: Minjia Zhang <[email protected]>
1 parent 00dfab9 · commit a5552a6 · 42 changed files with 15,421 additions and 7 deletions.
@@ -0,0 +1,62 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <torch/extension.h>

void attention_impl(torch::Tensor& q,
                    torch::Tensor& k,
                    torch::Tensor& v,
                    torch::Tensor& bias1,
                    torch::Tensor& bias2,
                    torch::Tensor& o,
                    torch::Tensor& lse);
void attention(torch::Tensor& q,
               torch::Tensor& k,
               torch::Tensor& v,
               torch::Tensor& bias1,
               torch::Tensor& bias2,
               torch::Tensor& o,
               torch::Tensor& lse)
{
    attention_impl(q, k, v, bias1, bias2, o, lse);
}

void attention_back_impl(torch::Tensor& go,
                         torch::Tensor& q,
                         torch::Tensor& k,
                         torch::Tensor& v,
                         torch::Tensor& o,
                         torch::Tensor& lse,
                         torch::Tensor& delta,
                         torch::Tensor& bias1,
                         torch::Tensor& bias2,
                         torch::Tensor& gq,
                         torch::Tensor& gk,
                         torch::Tensor& gv,
                         torch::Tensor& gb1,
                         torch::Tensor& gb2);
void attention_bwd(torch::Tensor& go,
                   torch::Tensor& q,
                   torch::Tensor& k,
                   torch::Tensor& v,
                   torch::Tensor& o,
                   torch::Tensor& lse,
                   torch::Tensor& delta,
                   torch::Tensor& bias1,
                   torch::Tensor& bias2,
                   torch::Tensor& gq,
                   torch::Tensor& gk,
                   torch::Tensor& gv,
                   torch::Tensor& gb1,
                   torch::Tensor& gb2)
{
    attention_back_impl(go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("attention", &attention, "");
    m.def("attention_bwd", &attention_bwd, "");
}
@@ -0,0 +1,160 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "gemm_kernel_utils.h"
#include "kernel_forward.h"
#include "transform/bias_broadcast.h"

template <typename arch,
          typename scalar_t,
          typename torch_scalar_t,
          template <typename, typename, typename>
          class Broadcast1_,
          template <typename, typename, typename>
          class Broadcast2_>
typename std::enable_if<!CheckArch<arch, scalar_t>::value>::type attention_impl_template(
    torch::Tensor& q,
    torch::Tensor& k,
    torch::Tensor& v,
    torch::Tensor& bias1,
    torch::Tensor& bias2,
    torch::Tensor& o,
    float* lse_ptr)
{
    EVOFORMER_CHECK(false, "Unsupported GPU and data type combination")
}

template <typename arch,
          typename scalar_t,
          typename torch_scalar_t,
          template <typename, typename, typename>
          class Broadcast1_,
          template <typename, typename, typename>
          class Broadcast2_>
typename std::enable_if<CheckArch<arch, scalar_t>::value>::type attention_impl_template(
    torch::Tensor& q,
    torch::Tensor& k,
    torch::Tensor& v,
    torch::Tensor& bias1,
    torch::Tensor& bias2,
    torch::Tensor& o,
    float* lse_ptr)
{
    // Attention definition goes here, replaced with BroadcastType1 and
    // BroadcastType2
    using Attention = AttentionKernel<scalar_t, /* scalar_t */
                                      arch,     /* ArchTag */
                                      true,     /* Memory is aligned */
                                      64,
                                      64,
                                      true,
                                      true, /* Supports bias */
                                      Broadcast1_,
                                      Broadcast2_>;

    static_assert(!Attention::kNeedsOutputAccumulatorBuffer,
                  "This test does not support output accumulator buffer");
    int head_size = q.size(-1);
    int head_number = q.size(-2);
    int seq_length = q.size(-3);
    auto q_view = q.view({-1, seq_length, head_number, head_size});
    auto k_view = k.view({-1, seq_length, head_number, head_size});
    auto v_view = v.view({-1, seq_length, head_number, head_size});
    auto o_view = o.view({-1, seq_length, head_number, head_size});
    int batch_size = q_view.size(0);
    auto q_ptr = reinterpret_cast<scalar_t*>(q.data_ptr<torch_scalar_t>());
    auto k_ptr = reinterpret_cast<scalar_t*>(k.data_ptr<torch_scalar_t>());
    auto v_ptr = reinterpret_cast<scalar_t*>(v.data_ptr<torch_scalar_t>());
    auto o_ptr = reinterpret_cast<scalar_t*>(o.data_ptr<torch_scalar_t>());

    auto bias1_ptr = reinterpret_cast<scalar_t*>(bias1.data_ptr<torch_scalar_t>());
    auto bias2_ptr = reinterpret_cast<scalar_t*>(bias2.data_ptr<torch_scalar_t>());

    typename Attention::Params p;
    {  // set parameters
        p.query_ptr = q_ptr;
        p.key_ptr = k_ptr;
        p.value_ptr = v_ptr;
        p.logsumexp_ptr = lse_ptr;  // Only needed for bw
        p.output_accum_ptr = nullptr;
        p.output_ptr = o_ptr;
        p.scale = 1.0f / sqrt(float(head_size));

        p.bias1_ptr = bias1_ptr;
        p.bias2_ptr = bias2_ptr;
        p.B = q.size(0);
        p.N = q.size(1);

        p.num_heads = head_number;
        p.num_batches = batch_size;
        p.head_dim = head_size;
        p.head_dim_value = head_size;
        p.num_queries = seq_length;
        p.num_keys = seq_length;

        // All tensors are in BMHK shapes
        p.q_strideH = q_view.stride(-2);
        p.k_strideH = k_view.stride(-2);
        p.v_strideH = v_view.stride(-2);
        p.q_strideM = q_view.stride(-3);
        p.k_strideM = k_view.stride(-3);
        p.v_strideM = v_view.stride(-3);
        p.o_strideM = o_view.stride(-3);
        p.q_strideB = q_view.stride(-4);
        p.k_strideB = k_view.stride(-4);
        p.v_strideB = v_view.stride(-4);
    }

    constexpr auto kernel_fn = attention_kernel_batched_impl<Attention>;
    int smem_bytes = sizeof(typename Attention::SharedStorage);
    if (smem_bytes > 0xc000) {
        cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
    }
    if (!Attention::check_supported(p)) { throw std::runtime_error("Parameters not supported"); }
    kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
}

#define CODE(scalar_t, torch_scalar_t)                                                          \
    do {                                                                                        \
        if (bias1.size(0) == 0 && bias2.size(0) == 0) {                                         \
            attention_impl_template<ArchTag,                                                    \
                                    scalar_t,                                                   \
                                    torch_scalar_t,                                             \
                                    BroadcastNoLoad,                                            \
                                    BroadcastNoLoad>(q, k, v, bias1, bias2, o, lse_ptr);        \
        } else if (bias1.size(0) == 0) {                                                        \
            attention_impl_template<ArchTag,                                                    \
                                    scalar_t,                                                   \
                                    torch_scalar_t,                                             \
                                    BroadcastNoLoad,                                            \
                                    BroadcastB>(q, k, v, bias1, bias2, o, lse_ptr);             \
        } else if (bias2.size(0) == 0) {                                                        \
            attention_impl_template<ArchTag,                                                    \
                                    scalar_t,                                                   \
                                    torch_scalar_t,                                             \
                                    BroadcastA,                                                 \
                                    BroadcastNoLoad>(q, k, v, bias1, bias2, o, lse_ptr);        \
        } else {                                                                                \
            attention_impl_template<ArchTag, scalar_t, torch_scalar_t, BroadcastA, BroadcastB>( \
                q, k, v, bias1, bias2, o, lse_ptr);                                             \
        }                                                                                       \
    } while (0)

// Function to select and call the correct template based on biases sizes
void attention_impl(torch::Tensor& q,
                    torch::Tensor& k,
                    torch::Tensor& v,
                    torch::Tensor& bias1,
                    torch::Tensor& bias2,
                    torch::Tensor& o,
                    torch::Tensor& lse)
{
    auto lse_ptr = lse.size(0) == 0 ? nullptr : reinterpret_cast<float*>(lse.data_ptr<float>());
    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
    DISPATCH_ARCHTAG(prop->major * 10 + prop->minor,
                     DISPATCH_TYPES(q, ([&]() { CODE(scalar_t, torch_scalar_t); })));
}