[OPTIMIZER] Improved flash attention forward pass performance (#1075)
- Fixed typo in instruction reordering pass
- Minor additional optimizations for shared memory allocator
- Optimized flash attention tutorial forward pass kernel
ptillet authored Jan 19, 2023
1 parent b2c522a commit 408d1d7
Showing 9 changed files with 61 additions and 57 deletions.
3 changes: 3 additions & 0 deletions include/triton/Analysis/Utility.h
@@ -77,6 +77,9 @@ SmallVector<RES_T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) {
return result;
}

bool isMmaToDotShortcut(triton::gpu::MmaEncodingAttr &mmaLayout,
triton::gpu::DotOperandEncodingAttr &dotOperandLayout);

} // namespace mlir

#endif // TRITON_ANALYSIS_UTILITY_H
7 changes: 7 additions & 0 deletions lib/Analysis/Allocation.cpp
@@ -58,6 +58,13 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
auto dstTy = op.result().getType().cast<RankedTensorType>();
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();

// MmaToDotShortcut doesn't use shared mem
if (auto mmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>())
if (auto dotOperandLayout = dstLayout.dyn_cast<DotOperandEncodingAttr>())
if (isMmaToDotShortcut(mmaLayout, dotOperandLayout))
return {};

assert(srcLayout && dstLayout &&
"Unexpect layout in getScratchConfigForCvtLayout()");
auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout);
16 changes: 16 additions & 0 deletions lib/Analysis/Utility.cpp
@@ -48,6 +48,12 @@ SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
auto axis = op.axis();
SmallVector<SmallVector<unsigned>> smemShapes(3);

auto argLayout = srcTy.getEncoding();
auto argLayoutMma = argLayout.dyn_cast<triton::gpu::MmaEncodingAttr>();
if (argLayoutMma && argLayoutMma.getVersionMajor() == 2 &&
triton::gpu::getWarpsPerCTA(argLayout)[axis] == 1)
return {{1, 1}, {1, 1}};

/// shared memory block0
smemShapes[0] = convertType<unsigned>(getSrcShape());
smemShapes[0][axis] = getInterWarpSize();
@@ -148,4 +154,14 @@ std::string getValueOperandName(Value value, AsmState &state) {
return opName;
}

bool isMmaToDotShortcut(triton::gpu::MmaEncodingAttr &mmaLayout,
triton::gpu::DotOperandEncodingAttr &dotOperandLayout) {
// dot_op<opIdx=0, parent=#mma> = #mma
// when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
return mmaLayout.getVersionMajor() == 2 &&
mmaLayout.getWarpsPerCTA()[1] == 1 &&
dotOperandLayout.getOpIdx() == 0 &&
dotOperandLayout.getParent() == mmaLayout;
}

} // namespace mlir
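The new predicate, and the allocator early-return it enables in Allocation.cpp above, can be paraphrased outside MLIR. The sketch below is a hypothetical plain-Python restatement (MmaLayout, DotOperandLayout and their fields are illustrative stand-ins, not Triton's C++ attributes or Python bindings): when an MMA v2 result feeds dot operand A and warps are not split along the second dimension, the operand can be repacked in registers, so the convert_layout gets an empty shared-memory scratch config.

# Hypothetical plain-Python paraphrase of isMmaToDotShortcut; the classes and
# field names are illustrative stand-ins, not Triton's real layout attributes.
from dataclasses import dataclass
from typing import List

@dataclass
class MmaLayout:
    version_major: int        # MMA encoding version (2 on Ampere)
    warps_per_cta: List[int]  # warp tiling of the CTA, e.g. [4, 1]

@dataclass
class DotOperandLayout:
    op_idx: int        # 0 = operand A, 1 = operand B
    parent: MmaLayout  # the mma layout the operand is sliced from

def is_mma_to_dot_shortcut(mma: MmaLayout, dot: DotOperandLayout) -> bool:
    # dot_op<opIdx=0, parent=#mma> can reuse the #mma registers directly
    # when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>.
    return (mma.version_major == 2
            and mma.warps_per_cta[1] == 1
            and dot.op_idx == 0
            and dot.parent == mma)

mma = MmaLayout(version_major=2, warps_per_cta=[4, 1])
assert is_mma_to_dot_shortcut(mma, DotOperandLayout(op_idx=0, parent=mma))      # no scratch buffer needed
assert not is_mma_to_dot_shortcut(mma, DotOperandLayout(op_idx=1, parent=mma))  # operand B still converts via shared memory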
10 changes: 1 addition & 9 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -1,5 +1,6 @@
#include "ConvertLayoutOpToLLVM.h"
#include "DotOpHelpers.h"
#include "Utility.h"

using ::mlir::LLVM::DotOpFMAConversionHelper;
using ::mlir::LLVM::DotOpMmaV1ConversionHelper;
@@ -17,15 +18,6 @@ using ::mlir::triton::gpu::getSizePerThread;
using ::mlir::triton::gpu::isaDistributedLayout;
using ::mlir::triton::gpu::SharedEncodingAttr;

bool isMmaToDotShortcut(MmaEncodingAttr &mmaLayout,
DotOperandEncodingAttr &dotOperandLayout) {
// dot_op<opIdx=0, parent=#mma> = #mma
// when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
return mmaLayout.getWarpsPerCTA()[1] == 1 &&
dotOperandLayout.getOpIdx() == 0 &&
dotOperandLayout.getParent() == mmaLayout;
}

struct ConvertLayoutOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ConvertLayoutOp> {
public:
3 changes: 0 additions & 3 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.h
@@ -8,9 +8,6 @@ using namespace mlir::triton;

using ::mlir::triton::gpu::DotOperandEncodingAttr;

bool isMmaToDotShortcut(MmaEncodingAttr &mmaLayout,
DotOperandEncodingAttr &dotOperandLayout);

void populateConvertLayoutOpToLLVMPatterns(
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
10 changes: 0 additions & 10 deletions lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
@@ -330,16 +330,6 @@ class ConvertTritonGPUOpToLLVMPatternBase {
return ret;
}

bool isMmaToDotShortcut(
MmaEncodingAttr &mmaLayout,
triton::gpu::DotOperandEncodingAttr &dotOperandLayout) const {
// dot_op<opIdx=0, parent=#mma> = #mma
// when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
return mmaLayout.getWarpsPerCTA()[1] == 1 &&
dotOperandLayout.getOpIdx() == 0 &&
dotOperandLayout.getParent() == mmaLayout;
}

void storeDistributedToShared(Value src, Value llSrc,
ArrayRef<Value> dstStrides,
ArrayRef<SmallVector<Value>> srcIndices,
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp
@@ -88,7 +88,7 @@ class TritonGPUReorderInstructionsPass
if (!dstEncoding)
return;
int opIdx = dstEncoding.getOpIdx();
if (opIdx != 1)
if (opIdx != 0)
return;
if (op->getUsers().empty())
return;
66 changes: 32 additions & 34 deletions python/tutorials/06-fused-attention.py
@@ -15,7 +15,7 @@
@triton.jit
def _fwd_kernel(
Q, K, V, sm_scale,
TMP, L, M, # NOTE: TMP is a scratchpad buffer to work around a compiler bug
L, M,
Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
@@ -32,58 +32,55 @@ def _fwd_kernel(
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
off_k = off_hz * stride_qh + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kk
off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
# Initialize pointers to Q, K, V
q_ptrs = Q + off_q
k_ptrs = K + off_k
v_ptrs = V + off_v
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# load q: it will stay in SRAM throughout
q = tl.load(q_ptrs)
# loop over k, v and update accumulator
for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
# start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs + start_n * stride_kn)
k = tl.load(k_ptrs)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, tl.trans(k))
qk += tl.dot(q, k)
qk *= sm_scale
qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
# compute new m
m_curr = tl.maximum(tl.max(qk, 1), m_prev)
# correct old l
l_prev *= tl.exp(m_prev - m_curr)
# attention weights
p = tl.exp(qk - m_curr[:, None])
l_curr = tl.sum(p, 1) + l_prev
# rescale operands of matmuls
l_rcp = 1. / l_curr
p *= l_rcp
acc *= (l_prev * l_rcp)[:, None]
# update acc
v = tl.load(v_ptrs + start_n * stride_vk)
p = p.to(tl.float16)
v = tl.load(v_ptrs)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
l_prev = l_curr
m_prev = m_curr
# update pointers
k_ptrs += BLOCK_N * stride_kn
v_ptrs += BLOCK_N * stride_vk
# rematerialize offsets to save registers
start_m = tl.program_id(0)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# write back l and m
l_ptrs = L + off_hz * N_CTX + offs_m
m_ptrs = M + off_hz * N_CTX + offs_m
tl.store(l_ptrs, l_i)
tl.store(m_ptrs, m_i)
tl.store(l_ptrs, l_prev)
tl.store(m_ptrs, m_prev)
# initialize pointers to output
offs_n = tl.arange(0, BLOCK_DMODEL)
off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
@@ -209,14 +206,13 @@ def forward(ctx, q, k, v, sm_scale):
assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q)
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
num_warps = 4 if Lk <= 64 else 8

_fwd_kernel[grid](
q, k, v, sm_scale,
tmp, L, m,
L, m,
o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
@@ -316,15 +312,15 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
# vary seq length for fixed head and batch=4
configs = [triton.testing.Benchmark(
x_names=['N_CTX'],
x_vals=[2**i for i in range(10, 15)],
x_vals=[2**i for i in range(10, 14)],
line_arg='provider',
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
) for mode in ['fwd']]
) for mode in ['fwd', 'bwd']]


@triton.testing.perf_report(configs)
@@ -357,4 +353,6 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
return ms

# bench_flash_attention.run(save_path='.', print_data=True)

# only works on post-Ampere GPUs right now
bench_flash_attention.run(save_path='.', print_data=True)
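
The rewritten loop above is the streaming (online) softmax recurrence: keep a running row maximum m_prev and running denominator l_prev, rescale both the new attention weights and the previously accumulated output by 1/l_curr on every iteration, and the accumulator stays fully normalized. The kernel also now loads K pre-transposed, so tl.dot(q, k) needs no tl.trans, and the TMP scratchpad argument is gone. Below is a minimal NumPy sketch of the recurrence, checked against a direct causally masked softmax(QK^T)V; the block size and tensor shapes are illustrative, and this is a reference for the math only, not the Triton kernel:

# Minimal NumPy sketch of the streaming-softmax recurrence used by the new
# forward kernel above; shapes and block size are illustrative.
import numpy as np

def attention_reference(q, k, v, sm_scale):
    # Direct causal attention: softmax(sm_scale * Q @ K^T, masked) @ V.
    n = q.shape[0]
    s = sm_scale * (q @ k.T)
    s = np.where(np.tril(np.ones((n, n), dtype=bool)), s, -np.inf)
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v

def attention_streaming(q, k, v, sm_scale, block_n=16):
    # Walk over K/V in blocks of block_n rows, as the kernel does per BLOCK_N.
    n_ctx, d = q.shape
    m_prev = np.full(n_ctx, -np.inf)   # running row maximum
    l_prev = np.zeros(n_ctx)           # running softmax denominator
    acc = np.zeros((n_ctx, d))         # running, already-normalized output
    offs_m = np.arange(n_ctx)
    for start_n in range(0, n_ctx, block_n):
        offs_n = start_n + np.arange(block_n)
        qk = sm_scale * (q @ k[offs_n].T)
        qk = np.where(offs_m[:, None] >= offs_n[None, :], qk, -np.inf)  # causal mask
        m_curr = np.maximum(qk.max(axis=1), m_prev)   # compute new m
        l_prev = l_prev * np.exp(m_prev - m_curr)     # correct old l
        p = np.exp(qk - m_curr[:, None])              # attention weights
        l_curr = p.sum(axis=1) + l_prev
        l_rcp = 1.0 / l_curr                          # rescale operands of the matmuls
        p = p * l_rcp[:, None]
        acc = acc * (l_prev * l_rcp)[:, None]
        acc = acc + p @ v[offs_n]                     # update acc
        l_prev, m_prev = l_curr, m_curr               # carry state to the next block
    return acc

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((64, 32)) for _ in range(3))
assert np.allclose(attention_streaming(q, k, v, 0.5), attention_reference(q, k, v, 0.5))

As in the kernel, the running l_prev and m_prev are exactly the values written back to the L and M buffers once the loop ends.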
1 change: 1 addition & 0 deletions test/lib/Analysis/CMakeLists.txt
@@ -6,4 +6,5 @@ add_mlir_library(TritonTestAnalysis

LINK_LIBS PUBLIC
TritonAnalysis
${dialect_libs}
)
