Skip to content

Commit

Permalink
[AMD] Extended local-prefetch to global_load ops
Browse files Browse the repository at this point in the history
  • Loading branch information
ravil-mobile committed Dec 9, 2024
1 parent 9743ec0 commit 5914795
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 41 deletions.
84 changes: 79 additions & 5 deletions test/TritonGPU/amd/amd-instruction-sched.mlir
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=llvm-iglp-0' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP0
// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=llvm-iglp-1' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=1' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS2
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints='variant=local-prefetch' -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_LOCAL_PREFETCH_GLOBAL_LOAD
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=1' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS2
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=16 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints='variant=local-prefetch arch=gfx942 num_stages=2' -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_LOCAL_PREFETCH_GLOBAL_LOAD
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=1' | FileCheck %s -check-prefix=LABELING_PS_1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' | FileCheck %s -check-prefix=LABELING_PS_2

Expand All @@ -11,6 +11,7 @@ module {
// INSERT_IGLP1-LABEL: @test_dot_op
// INSTR_COUNT_NS1-LABEL: @test_dot_op
// INSTR_COUNT_NS2-LABEL: @test_dot_op
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: @test_dot_op
// LABELING_PS_1-LABEL: @test_dot_op
// LABELING_PS_2-LABEL: @test_dot_op
tt.func @test_dot_op(%lb : index, %ub : index, %step : index,
Expand Down Expand Up @@ -68,8 +69,81 @@ module {
// INSTR_COUNT_NS2-SAME: numGlobalLoadsB = #amdgpu.InstCounter<4, vector<4xf16>>
// INSTR_COUNT_NS2-SAME: numMMAs = #amdgpu.InstCounter<16, tensor<32x32x8xf16>>

// USE_LOCAL_PREFETCH_GLOBAL_LOAD: [lower-insert-instruction-sched-hints]
// USE_LOCAL_PREFETCH_GLOBAL_LOAD-SAME: skipping `local-prefetch` scheduling given it needs `buffer_load` instructions
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.barrier [[SCHED_GUARD:.+]]
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE:512]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA:8]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VALU:2]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ:32]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VALU]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VALU]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VALU]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VALU]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VALU]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VALU]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VALU]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ:256]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.barrier [[SCHED_GUARD]]


// LABELING_PS_1: scf.for
// LABELING_PS_1: %[[REG0_OP0:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<0>}
Expand Down
4 changes: 2 additions & 2 deletions third_party/amd/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,11 +234,10 @@ def make_ttgir(mod, metadata, options):
passes.ttgpuir.add_optimize_dot_operands(pm, True)

stream_prefetch = os.getenv("TRITON_HIP_STREAM_PREFETCH", "0") == "1"
use_buffer_ops = os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1"

# The `local-prefetch` scheduling variant requires turning on stream prefetch.
if options.instruction_sched_variant == "local-prefetch":
stream_prefetch = use_buffer_ops = True
stream_prefetch = True

if amd.has_matrix_core_feature(options.arch):
assert options.num_stages != 0, ("Triton AMD backend pipeliner has been updated. "
Expand All @@ -255,6 +254,7 @@ def make_ttgir(mod, metadata, options):
if amd.has_matrix_core_feature(options.arch):
amd.passes.ttgpuir.add_reorder_instructions(pm)

use_buffer_ops = os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1"
if use_buffer_ops:
amd.passes.ttgpuir.add_canonicalize_pointers(pm)
passes.common.add_canonicalizer(pm)
Expand Down
Loading

0 comments on commit 5914795

Please sign in to comment.