[CP] AMD Performance cherry picks (#682)
* [AMD] Emit vectorized 16-bit float LLVM atomic ops (triton-lang#4925)

For 16-bit float operands of tt::AtomicRMWOp, construct only one
LLVM::AtomicRMWOp and have it operate on a vector of elements. This approach
generates packed intrinsics and processes two elements at once.
Added a lit test for the vectorized f16 case.

(cherry picked from commit 78c8054)
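
A minimal C++ model of the packing idea (illustrative only, not the generated
LLVM IR: integer adds stand in for the f16 adds and a CAS loop stands in for
the hardware's packed atomic). It shows two adjacent 16-bit elements being
updated with a single 32-bit atomic word operation instead of two separate
16-bit ones.

#include <atomic>
#include <cstdint>
#include <cstdio>

// Update two adjacent 16-bit lanes with one 32-bit atomic word operation.
void packed_add(std::atomic<uint32_t> &word, uint16_t lo, uint16_t hi) {
  uint32_t expected = word.load();
  uint32_t desired;
  do {
    uint16_t new_lo = static_cast<uint16_t>((expected & 0xFFFFu) + lo);
    uint16_t new_hi = static_cast<uint16_t>((expected >> 16) + hi);
    desired = new_lo | (static_cast<uint32_t>(new_hi) << 16);
  } while (!word.compare_exchange_weak(expected, desired));
}

int main() {
  std::atomic<uint32_t> word{0};
  packed_add(word, 1, 2); // one atomic covers both 16-bit elements
  std::printf("%08x\n", word.load()); // prints 00020001
}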

* [AMD] Restructure ReorderInstructions pass (triton-lang#4998)

(cherry picked from commit 86a2ac7)

* [AMD] Support warp-level reduction with DPP (triton-lang#5019)

This commit adds support for warp-level reduction
with DPP instructions, which can improve performance.

See https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/

(cherry picked from commit 21119e3)
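
A minimal C++ model of the reduction tree exercised by the new lit tests,
assuming the standard GCN DPP encodings (280/276/274/273 encode row_shr:8/4/2/1
and 322/323 encode row_bcast:15/31). Lane exchange is simulated on plain arrays
rather than real wavefront lanes, and boundary handling is simplified relative
to the actual bound_ctrl behavior: within each 16-lane row, row_shr shifts of
8/4/2/1 combined with the reduce op collapse the row into its last lane, and
the real lowering then merges the rows with row_bcast plus a readlane.

#include <algorithm>
#include <array>
#include <cstdio>

constexpr int kWarp = 64, kRow = 16;

// Simulated DPP row_shr:n within 16-lane rows: lane i reads lane i-n of its
// row, or a fallback value (here the max identity) when the source lane falls
// outside the row.
std::array<float, kWarp> row_shr(const std::array<float, kWarp> &v, int n,
                                 float identity) {
  std::array<float, kWarp> out{};
  for (int i = 0; i < kWarp; ++i) {
    int src = i - n;
    out[i] = (src >= (i / kRow) * kRow) ? v[src] : identity;
  }
  return out;
}

int main() {
  std::array<float, kWarp> val{};
  for (int i = 0; i < kWarp; ++i) val[i] = float((i * 37) % 64); // demo input

  // Four row_shr + max steps collapse each 16-lane row into its last lane.
  for (int n : {8, 4, 2, 1}) {
    auto shifted = row_shr(val, n, -1e30f);
    for (int i = 0; i < kWarp; ++i) val[i] = std::max(val[i], shifted[i]);
  }
  // The lowering then combines lanes 15/31/47/63 via row_bcast + readlane;
  // here the four row results are simply read and merged directly.
  float m = std::max({val[15], val[31], val[47], val[63]});
  std::printf("warp max = %g\n", m);
}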

* [AMD] Add missing dependency to TritonAMDGPUIR (triton-lang#5053)

TritonAMDGPUTransforms now depends on it.

(cherry picked from commit 0b443ce)

* [AMD] Support warp-level reduction with DPP (triton-lang#5019)

This commit adds support for warp-level reduction
with DPP instructions, which can improve performance.

See https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/

(cherry picked from commit 21119e3)

* [AMD] Use DPP to accelerate 16-bit floats (triton-lang#5072)

In the case of unpaired f16 elements, utilize DPP instructions to accelerate
atomics. The lowering algorithm for `tt::atomicRmwOp(%ptr, %val, %mask)` is as
follows (a C++ sketch follows this entry):

0. Group threads into pairs; the master thread is the one with (tid % 2 == 0).
1. Every thread sends `%val` to thread `(tid - 1)` via `dppUpdateOp shl`, so
each master receives the value from its secondary thread.
2. Fold the pair parity into the `%mask` value and build the control-flow
structure accordingly.
3. Generate `llvm::atomicRmwOp` in the threads enabled by the `%mask` value.
4. Every thread sends the result of the generated operation to thread
`(tid + 1)` via `dppUpdateOp shl`, so each secondary thread also receives its
result.

The DPP approach gives a ~5% performance improvement, so it is used when the
target architecture supports DPP.

Signed-off-by: Ilya Veselov <[email protected]>
(cherry picked from commit bab3470)
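
A minimal C++ model of steps 0-4 above (the lane exchange is simulated with
plain arrays; this is a sketch of the data flow, not the generated IR). Each
lane owns one element at mem[tid]; the master of each pair performs one packed
update covering both elements, standing in for a single llvm.atomicrmw on
vector<2xf16>, and hands the second half of the returned value back to its
partner.

#include <array>
#include <cstdio>

constexpr int kWarp = 64;

int main() {
  std::array<float, kWarp> val{}, mem{}, result{};
  std::array<bool, kWarp> mask{};
  for (int t = 0; t < kWarp; ++t) {
    val[t] = 0.5f * t;       // per-lane %val
    mem[t] = 100.0f + t;     // destination element for lane t
    mask[t] = (t % 3 != 0);  // per-lane %mask (demo pattern)
  }

  for (int t = 0; t < kWarp; t += 2) {           // t is the master lane
    float from_partner = val[t + 1];             // step 1: dppUpdateOp shl of %val
    float old_lo = mem[t], old_hi = mem[t + 1];  // step 3: one packed read-modify-write,
    if (mask[t])     mem[t]     += val[t];       //         gated by the pair's %mask bits
    if (mask[t + 1]) mem[t + 1] += from_partner; //         (step 2: parity folded into mask)
    result[t]     = old_lo;                      // master keeps the first half
    result[t + 1] = old_hi;                      // step 4: dppUpdateOp shl of the result
  }
  std::printf("result[2]=%g result[3]=%g\n", result[2], result[3]);
}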

* [AMD] Reland sinking the 2nd tt.load after local_load's (triton-lang#4935)

This PR adds more restrictions on when to apply the sched-load optimization
and un-reverts triton-lang#4823.

We only apply the optimization when all of the following are satisfied (a
sketch of the gating check follows this entry):
1. It is a pure matmul problem (`pureMatmulProblem`), i.e. there is exactly one `tt.dot` in the main loop.
2. There are two `tt.load`s in the main loop.
3. The 2nd `tt.load` comes before the `tt.dot`.
4. The first user of the 2nd `tt.load` comes after the `tt.dot`.
5. The tile size is large enough, i.e. nonKDim >= 128 and kDim >= 64.

(cherry picked from commit 4f6f768)
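
A minimal sketch of the gating predicate implied by the list above. The field
and function names here are hypothetical; the pass inspects the actual IR of
the main loop rather than a summary struct like this.

// Summary of one main-loop candidate; all fields are hypothetical.
struct LoopSummary {
  int numDots = 0;   // tt.dot ops in the main loop
  int numLoads = 0;  // tt.load ops in the main loop
  bool secondLoadBeforeDot = false;
  bool firstUserOfSecondLoadAfterDot = false;
  int nonKDim = 0, kDim = 0;  // tile sizes
};

bool shouldSinkSecondLoad(const LoopSummary &s) {
  return s.numDots == 1 &&                   // 1. pure matmul problem
         s.numLoads == 2 &&                  // 2. two tt.loads
         s.secondLoadBeforeDot &&            // 3. 2nd load ahead of the dot
         s.firstUserOfSecondLoadAfterDot &&  // 4. its first user after the dot
         s.nonKDim >= 128 && s.kDim >= 64;   // 5. tile size large enough
}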

---------

Co-authored-by: Ilya V <[email protected]>
Co-authored-by: Lei Zhang <[email protected]>
Co-authored-by: Kyle Wang <[email protected]>
Co-authored-by: Lixun Zhang <[email protected]>
5 people authored Dec 13, 2024
1 parent 2de5803 commit deee2b1
Showing 14 changed files with 1,172 additions and 664 deletions.
4 changes: 4 additions & 0 deletions include/triton/Dialect/Triton/IR/TritonOps.td
@@ -727,6 +727,10 @@ def TT_ReduceOp: TT_Op<"reduce",
llvm::SmallVector<RankedTensorType> getInputTypes();
llvm::SmallVector<Type> getElementTypes();
unsigned getNumOperands();

// Returns the CombineOp iff this ReduceOp's region contains only
// one CombineOp other than the return, or nullptr if not applicable.
::mlir::Operation *getSingleCombiner();
}];
}

16 changes: 16 additions & 0 deletions lib/Dialect/Triton/IR/Ops.cpp
@@ -503,6 +503,22 @@ llvm::SmallVector<Type> ReduceOp::getElementTypes() {
return getElementTypesImpl(this->getOperands());
}

::mlir::Operation *ReduceOp::getSingleCombiner() {
if (getNumOperands() != 1 || getNumResults() != 1)
return nullptr;
Block *block = &(*getCombineOp().begin());
Operation *yield = block->getTerminator();
Operation *reduceOp = yield->getOperand(0).getDefiningOp();
if (!reduceOp || reduceOp->getNumOperands() != 2 ||
reduceOp->getNumResults() != 1)
return nullptr;
if (reduceOp->getOperand(0) != block->getArgument(0) ||
reduceOp->getOperand(1) != block->getArgument(1))
return nullptr;

return reduceOp;
}

unsigned ReduceOp::getNumOperands() { return this->getOperands().size(); }

//-- ScanOp --
224 changes: 224 additions & 0 deletions test/Conversion/amd/tritongpu_to_llvm.mlir
@@ -62,3 +62,227 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
tt.return
}
}

// -----

#blocked1 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: atomic_add_f16x2
tt.func @atomic_add_f16x2(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) {
%range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1>
%base_ptr = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x!tt.ptr<f16>, #blocked1>
%ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr<f16>, #blocked1>, tensor<256xi32, #blocked1>
// CHECK: llvm.cond_br
// CHECK-NOT: rocdl.update.dpp
// CHECK: llvm.atomicrmw fadd {{.*}} vector<2xf16>
// CHECK-NOT: rocdl.update.dpp
%0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr<f16>, #blocked1>, tensor<256xf16, #blocked1>, tensor<256xi1, #blocked1>) -> tensor<256xf16, #blocked1>
tt.return
}
}

// -----

#blocked2 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: atomic_add_bf16x2
tt.func @atomic_add_bf16x2(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) {
%range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked2>
%base_ptr = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<256x!tt.ptr<bf16>, #blocked2>
%ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr<bf16>, #blocked2>, tensor<256xi32, #blocked2>
// CHECK: llvm.cond_br
// CHECK-NOT: rocdl.update.dpp
// CHECK: llvm.atomicrmw fadd {{.*}} vector<2xbf16>
// CHECK-NOT: rocdl.update.dpp
%0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr<bf16>, #blocked2>, tensor<256xbf16, #blocked2>, tensor<256xi1, #blocked2>) -> tensor<256xbf16, #blocked2>
tt.return
}
}

// -----

#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: atomic_add_f16_dpp
tt.func @atomic_add_f16_dpp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) {
%range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1>
%base_ptr = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x!tt.ptr<f16>, #blocked1>
%ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr<f16>, #blocked1>, tensor<256xi32, #blocked1>
// CHECK: llvm.cond_br
// CHECK: rocdl.update.dpp
// CHECK: llvm.atomicrmw fadd {{.*}} vector<2xf16>
// CHECK: rocdl.update.dpp
%0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr<f16>, #blocked1>, tensor<256xf16, #blocked1>, tensor<256xi1, #blocked1>) -> tensor<256xf16, #blocked1>
tt.return
}
}

// -----

#blocked2 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: atomic_add_bf16_dpp
tt.func @atomic_add_bf16_dpp(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) {
%range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked2>
%base_ptr = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<256x!tt.ptr<bf16>, #blocked2>
%ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr<bf16>, #blocked2>, tensor<256xi32, #blocked2>
// CHECK: llvm.cond_br
// CHECK: rocdl.update.dpp
// CHECK: llvm.atomicrmw fadd {{.*}} vector<2xbf16>
// CHECK: rocdl.update.dpp
%0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr<bf16>, #blocked2>, tensor<256xbf16, #blocked2>, tensor<256xi1, #blocked2>) -> tensor<256xbf16, #blocked2>
tt.return
}
}

// -----

#blocked3 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: reduce_dpp_max
tt.func @reduce_dpp_max(%arg0: tensor<64xf32, #blocked3>) {
// CHECK: rocdl.update.dpp
// CHECK-SAME: with 280, 15, 15, true : f32
// CHECK-NEXT: llvm.intr.maxnum

// CHECK-NEXT: rocdl.update.dpp
// CHECK-SAME: with 276, 15, 15, true : f32
// CHECK-NEXT: llvm.intr.maxnum

// CHECK-NEXT: rocdl.update.dpp
// CHECK-SAME: with 274, 15, 15, true : f32
// CHECK-NEXT: llvm.intr.maxnum

// CHECK-NEXT: rocdl.update.dpp
// CHECK-SAME: with 273, 15, 15, true : f32
// CHECK-NEXT: llvm.intr.maxnum

// CHECK-NEXT: rocdl.update.dpp
// CHECK-SAME: with 322, 10, 15, true : f32
// CHECK-NEXT: llvm.intr.maxnum

// CHECK-NEXT: rocdl.update.dpp
// CHECK-SAME: with 323, 15, 15, true : f32
// CHECK-NEXT: llvm.intr.maxnum

// CHECK: llvm.amdgcn.readlane
%0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
^bb0(%arg1: f32, %arg2: f32):
%1 = arith.maxnumf %arg1, %arg2 : f32
tt.reduce.return %1 : f32
}) : (tensor<64xf32, #blocked3>) -> f32
tt.return
}
}

// -----

#blocked4 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: reduce_xor_max
tt.func @reduce_xor_max(%arg0: tensor<32xf32, #blocked4>) {
// CHECK: rocdl.ds_swizzle
// CHECK: llvm.intr.maxnum

// CHECK: rocdl.update.dpp
// CHECK-SAME: with 280, 15, 12, false : i32
// CHECK: rocdl.update.dpp
// CHECK-SAME: with 264, 15, 3, false : i32
// CHECK: llvm.intr.maxnum

// CHECK: rocdl.update.dpp
// CHECK-SAME: with 276, 15, 10, false : i32
// CHECK: rocdl.update.dpp
// CHECK-SAME: with 260, 15, 5, false : i32
// CHECK: llvm.intr.maxnum

// CHECK: rocdl.update.dpp
// CHECK-SAME: with 78, 15, 15, false : i32
// CHECK: llvm.intr.maxnum

// CHECK: rocdl.update.dpp
// CHECK-SAME: with 177, 15, 15, false : i32
%0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
^bb0(%arg1: f32, %arg2: f32):
%1 = arith.maxnumf %arg1, %arg2 : f32
tt.reduce.return %1 : f32
}) : (tensor<32xf32, #blocked4>) -> f32
tt.return
}
}

// -----

#blocked3 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: reduce_dpp_max
tt.func @reduce_dpp_max(%arg0: tensor<64xf32, #blocked3>) {
// CHECK: rocdl.update.dpp
// CHECK-SAME: with 280, 15, 15, true : f32
// CHECK-NEXT: llvm.intr.maxnum

// CHECK-NEXT: rocdl.update.dpp
// CHECK-SAME: with 276, 15, 15, true : f32
// CHECK-NEXT: llvm.intr.maxnum

// CHECK-NEXT: rocdl.update.dpp
// CHECK-SAME: with 274, 15, 15, true : f32
// CHECK-NEXT: llvm.intr.maxnum

// CHECK-NEXT: rocdl.update.dpp
// CHECK-SAME: with 273, 15, 15, true : f32
// CHECK-NEXT: llvm.intr.maxnum

// CHECK-NEXT: rocdl.update.dpp
// CHECK-SAME: with 322, 10, 15, true : f32
// CHECK-NEXT: llvm.intr.maxnum

// CHECK-NEXT: rocdl.update.dpp
// CHECK-SAME: with 323, 15, 15, true : f32
// CHECK-NEXT: llvm.intr.maxnum

// CHECK: llvm.amdgcn.readlane
%0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
^bb0(%arg1: f32, %arg2: f32):
%1 = arith.maxnumf %arg1, %arg2 : f32
tt.reduce.return %1 : f32
}) : (tensor<64xf32, #blocked3>) -> f32
tt.return
}
}

// -----

#blocked4 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: reduce_xor_max
tt.func @reduce_xor_max(%arg0: tensor<32xf32, #blocked4>) {
// CHECK: rocdl.ds_swizzle
// CHECK: llvm.intr.maxnum

// CHECK: rocdl.update.dpp
// CHECK-SAME: with 280, 15, 12, false : i32
// CHECK: rocdl.update.dpp
// CHECK-SAME: with 264, 15, 3, false : i32
// CHECK: llvm.intr.maxnum

// CHECK: rocdl.update.dpp
// CHECK-SAME: with 276, 15, 10, false : i32
// CHECK: rocdl.update.dpp
// CHECK-SAME: with 260, 15, 5, false : i32
// CHECK: llvm.intr.maxnum

// CHECK: rocdl.update.dpp
// CHECK-SAME: with 78, 15, 15, false : i32
// CHECK: llvm.intr.maxnum

// CHECK: rocdl.update.dpp
// CHECK-SAME: with 177, 15, 15, false : i32
%0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
^bb0(%arg1: f32, %arg2: f32):
%1 = arith.maxnumf %arg1, %arg2 : f32
tt.reduce.return %1 : f32
}) : (tensor<32xf32, #blocked4>) -> f32
tt.return
}
}