[ORTModule] ATen Efficient Attention and Triton Flash Attention (#17959)
This PR adds support for efficient attention and flash attention in
ORTModule, including:
- Use ATen to call efficient attention, which requires PyTorch 2.2.0 dev
or newer. Set ORTMODULE_USE_EFFICIENT_ATTENTION=1 to enable.
- Integrate Triton flash attention, which requires
triton==2.0.0.dev20221202 and an A100 or H100 GPU.
Set ORTMODULE_USE_FLASH_ATTENTION=1 to enable.
- A Python graph-transformer tool that matches sub-graphs by config, making it
quick to write new transformers.

The current transformers support an attention mask for both efficient attention
and flash attention, and dropout for efficient attention only. Supporting more
training scenarios (such as the causal mask in GPT-2) will require adding more
transformers.

The feature is guarded by environment variables, so it does not affect any
current behavior when not enabled. Since it requires specific
PyTorch/Triton versions, related tests are not added for now.
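
A minimal usage sketch (not part of this commit) of how these flags might be enabled from a training script; the model class and inputs below are placeholders:

import os

# Opt in before constructing ORTModule; both features are off by default.
os.environ["ORTMODULE_USE_EFFICIENT_ATTENTION"] = "1"  # requires PyTorch 2.2.0 dev or newer
os.environ["ORTMODULE_USE_FLASH_ATTENTION"] = "1"      # requires triton==2.0.0.dev20221202 and an A100/H100

from onnxruntime.training import ORTModule

model = ORTModule(MyTransformerModel().cuda())   # MyTransformerModel is a placeholder user model
loss = model(input_ids, attention_mask).loss     # placeholder forward signature
loss.backward()                                  # attention sub-graphs are rewritten during graph optimization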
centwang authored Oct 27, 2023
1 parent 37873be commit b7408f7
Showing 26 changed files with 2,037 additions and 93 deletions.
7 changes: 7 additions & 0 deletions cmake/onnxruntime_python.cmake
@@ -387,6 +387,9 @@ if (onnxruntime_ENABLE_TRAINING)
file(GLOB onnxruntime_python_ortmodule_torch_cpp_ext_fused_ops_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/*"
)
file(GLOB onnxruntime_python_ortmodule_graph_optimizers_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/graph_optimizers/*"
)
file(GLOB onnxruntime_python_ort_triton_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/python/training/ort_triton/*.py"
)
@@ -741,6 +744,7 @@ if (onnxruntime_ENABLE_TRAINING)
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/torch_gpu_allocator
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/fused_ops
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/graph_optimizers
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ort_triton
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ort_triton/kernel
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/utils
@@ -794,6 +798,9 @@ if (onnxruntime_ENABLE_TRAINING)
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_ortmodule_torch_cpp_ext_fused_ops_srcs}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_ortmodule_graph_optimizers_srcs}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/graph_optimizers/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_ort_triton_srcs}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ort_triton/
6 changes: 4 additions & 2 deletions onnxruntime/contrib_ops/cpu/aten_ops/aten_op.cc
@@ -32,8 +32,10 @@ Status ATen::Compute(OpKernelContext* p_ctx) const {
aten_ops::ATenOperatorExecutor::Instance()(op_name_, overload_name_, input_size, dlpack_inputs.get(), output_size,
dlpack_outputs.get());
for (size_t i = 0; i < output_size; ++i) {
ORT_RETURN_IF_ERROR(
p_ctx_internal->SetOutputMLValue(static_cast<int>(i), dlpack::DlpackToOrtValue(dlpack_outputs[i])));
if (dlpack_outputs[i]) {
ORT_RETURN_IF_ERROR(
p_ctx_internal->SetOutputMLValue(static_cast<int>(i), dlpack::DlpackToOrtValue(dlpack_outputs[i])));
}
}

return Status::OK();
16 changes: 8 additions & 8 deletions onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h
@@ -10,7 +10,7 @@ namespace onnxruntime {
namespace contrib {
namespace aten_ops {

typedef bool (*IsTensorArgumentFunc)(const char* op_name, const char* overload_name, size_t index);
typedef bool (*IsCpuArgumentFunc)(const char* op_name, const char* overload_name, size_t index, bool is_input);
typedef void (*ExecuteATenOperatorFunc)(const char* op_name, const char* overload_name, size_t input_size,
DLManagedTensor** dlpack_inputs, size_t output_size,
DLManagedTensor** dlpack_outputs);
@@ -22,17 +22,17 @@ class ATenOperatorExecutor {
return instance;
}

void Initialize(void* p_is_tensor_argument_func_raw, void* p_execute_aten_op_func_raw) {
ORT_ENFORCE(p_is_tensor_argument_func_raw && p_execute_aten_op_func_raw);
p_is_tensor_argument_func_ = reinterpret_cast<IsTensorArgumentFunc>(p_is_tensor_argument_func_raw);
void Initialize(void* p_is_cpu_argument_func_raw, void* p_execute_aten_op_func_raw) {
ORT_ENFORCE(p_is_cpu_argument_func_raw && p_execute_aten_op_func_raw);
p_is_cpu_argument_func_ = reinterpret_cast<IsCpuArgumentFunc>(p_is_cpu_argument_func_raw);
p_execute_aten_op_func_ = reinterpret_cast<ExecuteATenOperatorFunc>(p_execute_aten_op_func_raw);
}

bool IsInitialized() { return p_execute_aten_op_func_ != nullptr; }

bool IsTensorArgument(const std::string& op_name, const std::string& overload_name, size_t index) {
ORT_ENFORCE(p_is_tensor_argument_func_, "ATenOperatorExecutor is not initialized.");
return p_is_tensor_argument_func_(op_name.c_str(), overload_name.c_str(), index);
bool IsCpuArgument(const std::string& op_name, const std::string& overload_name, size_t index, bool is_input) {
ORT_ENFORCE(p_is_cpu_argument_func_, "ATenOperatorExecutor is not initialized.");
return p_is_cpu_argument_func_(op_name.c_str(), overload_name.c_str(), index, is_input);
}

void operator()(const std::string& op_name, const std::string& overload_name, size_t input_size,
@@ -43,7 +43,7 @@ class ATenOperatorExecutor {
}

private:
IsTensorArgumentFunc p_is_tensor_argument_func_ = nullptr;
IsCpuArgumentFunc p_is_cpu_argument_func_ = nullptr;
ExecuteATenOperatorFunc p_execute_aten_op_func_ = nullptr;
};

3 changes: 2 additions & 1 deletion onnxruntime/core/framework/fallback_cpu_capability.cc
@@ -9,6 +9,7 @@
#include "onnx/defs/data_type_utils.h"

#include "core/framework/op_kernel.h"
#include "core/framework/utils.h"

using namespace ONNX_NAMESPACE::Utils;

@@ -77,7 +78,7 @@ std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewe
ORT_THROW_IF_ERROR(node->ForEachWithIndex(
node->OutputDefs(),
[&](const NodeArg& node_arg, size_t out_index) {
if (kernel_info->kernel_def->IsOutputOnCpu(out_index)) {
if (utils::IsOutputOnCpu(*node, kernel_info, out_index)) {
cpu_output_args.insert(&node_arg);
auto consumer_nodes = graph.GetConsumerNodes(node_arg.Name());
for (auto& consumer_node : consumer_nodes) {
27 changes: 26 additions & 1 deletion onnxruntime/core/framework/utils.cc
@@ -1025,7 +1025,32 @@ bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index)
overload_name = attrs.at("overload_name").s();
}

return !contrib::aten_ops::ATenOperatorExecutor::Instance().IsTensorArgument(op_name, overload_name, index);
return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, true);
}
#else
ORT_UNUSED_PARAMETER(node);
#endif

return false;
}

bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index) {
if (p_kci && p_kci->kernel_def->IsOutputOnCpu(index)) {
return true;
}

#ifdef ENABLE_ATEN
if (node.GetExecutionProviderType() == kCudaExecutionProvider && node.OpType() == "ATen" &&
node.Domain() == kPytorchAtenDomain) {
const auto& attrs = node.GetAttributes();
ORT_ENFORCE(utils::HasString(attrs.at("operator")));
std::string op_name = attrs.at("operator").s();
std::string overload_name = "";
if (attrs.find("overload_name") != attrs.end() && utils::HasString(attrs.at("overload_name"))) {
overload_name = attrs.at("overload_name").s();
}

return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, false);
}
#else
ORT_UNUSED_PARAMETER(node);
1 change: 1 addition & 0 deletions onnxruntime/core/framework/utils.h
@@ -121,6 +121,7 @@ common::Status ExecuteSubgraph(const SessionState& session_state, const FeedsFet
bool sync_subgraph_fetches = false);

bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index);
bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index);

template <typename T>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() {
8 changes: 4 additions & 4 deletions onnxruntime/core/optimizer/transformer_memcpy.cc
@@ -249,7 +249,7 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg
if (!arg->Exists())
continue;

if (kci && kci->kernel_def->IsOutputOnCpu(i))
if (utils::IsOutputOnCpu(node, kci, i))
non_provider_output_defs_.insert(arg);
else
provider_output_defs_.insert(arg);
@@ -308,7 +308,7 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, co
if (!kci || !utils::IsInputOnCpu(it, kci, arg_input_index)) provider_input_nodes_[arg].insert(&it);
}
if (arg_output_index != -1) {
if (!kci || !kci->kernel_def->IsOutputOnCpu(arg_output_index)) provider_output_nodes_[arg].insert(&it);
if (!kci || !utils::IsOutputOnCpu(it, kci, arg_output_index)) provider_output_nodes_[arg].insert(&it);
}
}
}
@@ -404,8 +404,8 @@ bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& ker
// normally initializers are only inputs, but things may change with ops like assign
ORT_THROW_IF_ERROR(Node::ForEachWithIndex(
p_node->OutputDefs(),
[kci, &dup_replacements](const onnxruntime::NodeArg& arg, size_t index) {
if (kci->kernel_def->IsOutputOnCpu(index)) {
[kci, &p_node, &dup_replacements](const onnxruntime::NodeArg& arg, size_t index) {
if (utils::IsOutputOnCpu(*p_node, kci, index)) {
ORT_ENFORCE(dup_replacements.find(&arg) == dup_replacements.end());
}
return Status::OK();
10 changes: 5 additions & 5 deletions onnxruntime/python/onnxruntime_pybind_state.cc
@@ -1214,14 +1214,14 @@ void addGlobalMethods(py::module& m) {

#ifdef ENABLE_ATEN
m.def("register_aten_op_executor",
[](const std::string& is_tensor_argument_address_str, const std::string& aten_op_executor_address_str) -> void {
size_t is_tensor_argument_address_int, aten_op_executor_address_int;
[](const std::string& is_cpu_argument_address_str, const std::string& aten_op_executor_address_str) -> void {
size_t is_cpu_argument_address_int, aten_op_executor_address_int;
ORT_THROW_IF_ERROR(
ParseStringWithClassicLocale(is_tensor_argument_address_str, is_tensor_argument_address_int));
ParseStringWithClassicLocale(is_cpu_argument_address_str, is_cpu_argument_address_int));
ORT_THROW_IF_ERROR(ParseStringWithClassicLocale(aten_op_executor_address_str, aten_op_executor_address_int));
void* p_is_tensor_argument = reinterpret_cast<void*>(is_tensor_argument_address_int);
void* p_is_cpu_argument = reinterpret_cast<void*>(is_cpu_argument_address_int);
void* p_aten_op_executor = reinterpret_cast<void*>(aten_op_executor_address_int);
contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(p_is_tensor_argument, p_aten_op_executor);
contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(p_is_cpu_argument, p_aten_op_executor);
});
#endif
}
@@ -29,5 +29,5 @@ def load_aten_op_executor_cpp_extension():
from onnxruntime.training.ortmodule.torch_cpp_extensions import aten_op_executor

_C.register_aten_op_executor(
str(aten_op_executor.is_tensor_argument_address()), str(aten_op_executor.execute_aten_operator_address())
str(aten_op_executor.is_cpu_argument_address()), str(aten_op_executor.execute_aten_operator_address())
)
@@ -154,11 +154,32 @@ class ATenOperatorCache {
std::unordered_map<std::pair<std::string, std::string>, ATenOperator, PairHash> ops_;
};

// Backend uses this function to check if an argument is CPU input (non-tensor argument) or not.
bool IsTensorArgument(const char* op_name, const char* overload_name, size_t index) {
const auto& aten_op = ATenOperatorCache::Instance().GetOperator(op_name, overload_name);
TORCH_INTERNAL_ASSERT(index < aten_op.argument_size);
return aten_op.elem_kinds[index] == c10::TypeKind::TensorType;
const std::unordered_map<std::string, std::unordered_set<size_t>> kCpuTensorInputsMap = {
{"_efficient_attention_forward", {4, 5, 11, 12}}, {"_efficient_attention_backward", {6, 7, 12, 13}}};

const std::unordered_map<std::string, std::unordered_set<size_t>> kCpuTensorOutputsMap = {
{"_efficient_attention_forward", {2, 3}}};

// Backend uses this function to check if an argument is CPU input or not.
bool IsCpuArgument(const char* op_name, const char* overload_name, size_t index, bool is_input) {
if (is_input) {
// If the argument is non-tensor type, it's CPU argument.
const auto& aten_op = ATenOperatorCache::Instance().GetOperator(op_name, overload_name);
TORCH_INTERNAL_ASSERT(index < aten_op.argument_size);
if (aten_op.elem_kinds[index] != c10::TypeKind::TensorType) {
return true;
}
}

std::string full_name = std::string(op_name);
std::string overload_name_str = std::string(overload_name);
if (overload_name_str != "") {
full_name += ("." + overload_name_str);
}

const auto& cpu_tensors_map = is_input ? kCpuTensorInputsMap : kCpuTensorOutputsMap;
return cpu_tensors_map.find(full_name) != cpu_tensors_map.end() &&
cpu_tensors_map.at(full_name).find(index) != cpu_tensors_map.at(full_name).end();
}

void ExecuteATenOperator(const char* op_name, const char* overload_name, size_t input_size,
@@ -196,14 +217,15 @@ void ExecuteATenOperator(const char* op_name, const char* overload_name, size_t
size_t output_index = 0;
for (const auto& ret : torch::jit::pop(stack, output_size)) {
const auto& tensor = ret.toTensor();
dlpack_outputs[output_index++] = at::toDLPack(tensor.is_contiguous() ? tensor : tensor.contiguous());
dlpack_outputs[output_index++] =
tensor.defined() ? at::toDLPack(tensor.is_contiguous() ? tensor : tensor.contiguous()) : nullptr;
}
}

size_t is_tensor_argument_address() { return reinterpret_cast<size_t>(&IsTensorArgument); }
size_t is_cpu_argument_address() { return reinterpret_cast<size_t>(&IsCpuArgument); }
size_t execute_aten_operator_address() { return reinterpret_cast<size_t>(&ExecuteATenOperator); }

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("is_tensor_argument_address", &is_tensor_argument_address, "Address of tensor argument check.");
m.def("is_cpu_argument_address", &is_cpu_argument_address, "Address of tensor argument check.");
m.def("execute_aten_operator_address", &execute_aten_operator_address, "Address of Aten operator executor");
}
@@ -5,7 +5,7 @@

from onnxruntime.capi import _pybind_state as _C

from .aten_op_executor import execute_aten_operator_address, is_tensor_argument_address
from .aten_op_executor import execute_aten_operator_address, is_cpu_argument_address


def run_once_aten_op_executor(f):
@@ -30,7 +30,7 @@ def aten_op_executor_wrapper(*args, **kwargs):

@run_once_aten_op_executor
def load_aten_op_executor_cpp_extension():
_C.register_aten_op_executor(str(is_tensor_argument_address()), str(execute_aten_operator_address()))
_C.register_aten_op_executor(str(is_cpu_argument_address()), str(execute_aten_operator_address()))


def init_aten_op_executor():
1 change: 1 addition & 0 deletions orttraining/orttraining/core/graph/training_op_defs.cc
@@ -4180,6 +4180,7 @@ Return true if all elements are true and false otherwise.
.Attr("func_name", "Function name of the Python Triton kernel.", AttributeProto::STRING, std::string(""))
.Attr("onnx_key", "The hash key for the ONNX graph.", AttributeProto::INT, static_cast<int64_t>(0))
.Attr("onnx_string", "The onnx string of the triton kernel.", AttributeProto::STRING, std::string(""))
.AllowUncheckedAttributes()
.Input(0, "inputs",
"Input tensors. If to call an existing Python Triton kernel, "
"the input count and order should match the arguments of the function. If to compute an ONNX graph, "
@@ -3,15 +3,28 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from ._mm import triton_gemm, triton_gemm_out, triton_matmul, triton_matmul_out
from ._slice_scel import slice_scel, slice_scel_backward, transform_slice_scel
import os

__all__ = [
from ._mm import triton_gemm, triton_gemm_out, triton_matmul, triton_matmul_out # noqa: F401
from ._slice_scel import optimize_graph_for_slice_scel, slice_scel, slice_scel_backward # noqa: F401

_all_kernels = [
"triton_gemm",
"triton_gemm_out",
"triton_matmul",
"triton_matmul_out",
"slice_scel",
"slice_scel_backward",
"transform_slice_scel",
]

_all_optimizers = [
"optimize_graph_for_slice_scel",
]

if "ORTMODULE_USE_FLASH_ATTENTION" in os.environ and int(os.getenv("ORTMODULE_USE_FLASH_ATTENTION")) == 1:
from ._flash_attn import flash_attn_backward, flash_attn_forward, optimize_graph_for_flash_attention # noqa: F401

_all_kernels.extend(["flash_attn_forward", "flash_attn_backward"])
_all_optimizers.append("optimize_graph_for_flash_attention")

__all__ = _all_kernels + _all_optimizers # noqa: PLE0605
