Skip to content

Commit

Permalink
[AMD] NFC: Drop v2 Suffix from Stream Pipeline (#5251)
Browse files Browse the repository at this point in the history
Since StreamPipelineV2 has been the default for a while, this
commit promoted StreamPipelineV2 to the general
StreamPipeline by removing 'v2' suffix.
  • Loading branch information
knwng authored Nov 25, 2024
1 parent e1ebeed commit 4210274
Show file tree
Hide file tree
Showing 10 changed files with 21 additions and 21 deletions.
2 changes: 1 addition & 1 deletion bin/RegisterTritonDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::registerTritonAMDGPUAccelerateMatmul();
mlir::registerTritonAMDGPUOptimizeEpilogue();
mlir::registerTritonAMDGPUReorderInstructions();
mlir::registerTritonAMDGPUStreamPipelineV2();
mlir::registerTritonAMDGPUStreamPipeline();
mlir::registerTritonAMDGPUCanonicalizePointers();
mlir::registerTritonAMDGPUConvertToBufferOps();
mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints();
Expand Down
10 changes: 5 additions & 5 deletions test/TritonGPU/amd/amd-instruction-sched.mlir
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=llvm-iglp-0' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP0
// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=llvm-iglp-1' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=1' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS2
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints='variant=local-prefetch' -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_LOCAL_PREFETCH_GLOBAL_LOAD
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=1' | FileCheck %s -check-prefix=LABELING_PS_1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' | FileCheck %s -check-prefix=LABELING_PS_2
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=1' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS2
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints='variant=local-prefetch' -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_LOCAL_PREFETCH_GLOBAL_LOAD
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=1' | FileCheck %s -check-prefix=LABELING_PS_1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' | FileCheck %s -check-prefix=LABELING_PS_2

module {
// INSERT_IGLP0-LABEL: @test_dot_op
Expand Down
2 changes: 1 addition & 1 deletion test/TritonGPU/loop-pipeline-hip.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline-v2=num_stages=2 -canonicalize | FileCheck %s
// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline=num_stages=2 -canonicalize | FileCheck %s

#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
Expand Down
4 changes: 2 additions & 2 deletions test/TritonGPU/loop-pipeline.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: triton-opt %s -split-input-file -tritongpu-loop-scheduling=num-stages=3 -tritongpu-pipeline=num-stages=3 -canonicalize | FileCheck %s --check-prefixes=COMMON,CHECK
// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline-v2=num_stages=2 -canonicalize | FileCheck %s --check-prefixes=COMMON,AMD
// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline-v2="num_stages=2 prefetch=1" -canonicalize | FileCheck %s --check-prefixes=COMMON,AMD_PREFETCH
// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline=num_stages=2 -canonicalize | FileCheck %s --check-prefixes=COMMON,AMD
// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline="num_stages=2 prefetch=1" -canonicalize | FileCheck %s --check-prefixes=COMMON,AMD_PREFETCH

// 4 warps
// matmul: 128x32 @ 32x128 -> 128x128
Expand Down
2 changes: 1 addition & 1 deletion third_party/amd/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def make_ttgir(mod, metadata, options):
"num_stages == 0. Now it will not happen anymore; "
"please update to use num_stages == 2 for "
"equivalent behavior in the past.")
amd.passes.ttgpuir.add_stream_pipelinev2(pm, options.num_stages, stream_prefetch)
amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, stream_prefetch)
passes.common.add_canonicalizer(pm)
amd.passes.ttgpuir.insert_instruction_sched_hints(pm)
passes.ttgpuir.add_optimize_dot_operands(pm, True)
Expand Down
4 changes: 2 additions & 2 deletions third_party/amd/include/TritonAMDGPUTransforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

namespace mlir {

std::unique_ptr<Pass> createTritonAMDGPUStreamPipelineV2Pass(int numStages = 2,
int prefetch = 0);
std::unique_ptr<Pass> createTritonAMDGPUStreamPipelinePass(int numStages = 2,
int prefetch = 0);

std::unique_ptr<Pass>
createTritonAMDGPUAccelerateMatmulPass(std::string archGenName = std::string(),
Expand Down
4 changes: 2 additions & 2 deletions third_party/amd/include/TritonAMDGPUTransforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@

include "mlir/Pass/PassBase.td"

def TritonAMDGPUStreamPipelineV2 : Pass<"tritonamdgpu-stream-pipeline-v2", "mlir::ModuleOp"> {
def TritonAMDGPUStreamPipeline : Pass<"tritonamdgpu-stream-pipeline", "mlir::ModuleOp"> {
let summary = "pipeline";

let description = [{
Pipeline global loads through registers to shared memory while computing on previous
tile
}];

let constructor = "mlir::createTritonAMDGPUStreamPipelineV2Pass()";
let constructor = "mlir::createTritonAMDGPUStreamPipelinePass()";

let dependentDialects = ["mlir::triton::amdgpu::TritonAMDGPUDialect"];

Expand Down
2 changes: 1 addition & 1 deletion third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ add_triton_library(TritonAMDGPUTransforms
ConvertToBufferOps.cpp
OptimizeEpilogue.cpp
ReorderInstructions.cpp
StreamPipelineV2.cpp
StreamPipeline.cpp
MfmaGroup.cpp

DEPENDS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
#define GEN_PASS_CLASSES
#include "TritonAMDGPUTransforms/Passes.h.inc"

#define DEBUG_TYPE "tritonamdgpu-stream-pipeline-v2"
#define DEBUG_TYPE "tritonamdgpu-stream-pipeline"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

Expand Down Expand Up @@ -857,7 +857,7 @@ void labelLoadOpsForTritonDot(scf::ForOp forOp) {
}
}

struct PipelinePass : public TritonAMDGPUStreamPipelineV2Base<PipelinePass> {
struct PipelinePass : public TritonAMDGPUStreamPipelineBase<PipelinePass> {
PipelinePass() = default;
PipelinePass(int32_t numStages, int32_t prefetch) {
this->numStages = numStages;
Expand Down Expand Up @@ -893,7 +893,7 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineV2Base<PipelinePass> {
};
} // anonymous namespace

std::unique_ptr<Pass>
mlir::createTritonAMDGPUStreamPipelineV2Pass(int numStages, int prefetch) {
std::unique_ptr<Pass> mlir::createTritonAMDGPUStreamPipelinePass(int numStages,
int prefetch) {
return std::make_unique<PipelinePass>(numStages, prefetch);
}
4 changes: 2 additions & 2 deletions third_party/amd/python/triton_amd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) {
mlir::createTritonAMDGPUConvertToBufferOpsPass);
ADD_PASS_WRAPPER_0("add_reorder_instructions",
mlir::createTritonAMDGPUReorderInstructionsPass);
ADD_PASS_WRAPPER_2("add_stream_pipelinev2",
mlir::createTritonAMDGPUStreamPipelineV2Pass, int, int);
ADD_PASS_WRAPPER_2("add_stream_pipeline",
mlir::createTritonAMDGPUStreamPipelinePass, int, int);
}

void addControlConstant(llvm::Module *module, const char *name,
Expand Down

0 comments on commit 4210274

Please sign in to comment.