Skip to content

Commit

Permalink
[PROTON] Introduce the Proton dialect as a third-party plugin for int…
Browse files Browse the repository at this point in the history
…ra-kernel perf tooling (#5119)

This PR introduces the `Proton Dialect` to enable intra-kernel profiling
and tooling for Triton. As a third-party dialect, it serves as a
building block for creating third-party perf tools (e.g., profilers,
analysis, modeling) for Triton compiler developers in a compiler-centric
way, such as an intra-kernel latency profiler for understanding software
pipelining, warp specialization, and fine-grained CTA orchestration
(e.g., CUDA cores, tensor cores, TMA). Future development will integrate
this dialect with the existing Proton backend profiling infrastructure
to make it a powerful and general perf tool utility. As a first step,
this PR adds some basic boilerplate code and mechanics, along with the
`proton.record` op for the `Proton Dialect`.

---------

Co-authored-by: Yuanwei Fang <[email protected]>
Co-authored-by: Keren Zhou <[email protected]>
  • Loading branch information
3 people authored Nov 21, 2024
1 parent ad28e6c commit e9db186
Show file tree
Hide file tree
Showing 19 changed files with 269 additions and 8 deletions.
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,9 @@ if(TRITON_BUILD_PYTHON_MODULE)
if (TRITON_BUILD_PROTON)
add_subdirectory(third_party/proton)
endif()
# We always build proton dialect
list(APPEND TRITON_PLUGIN_NAMES "proton")
add_subdirectory(third_party/proton/dialect)

get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS)
get_property(triton_plugins GLOBAL PROPERTY TRITON_PLUGINS)
Expand Down Expand Up @@ -311,6 +314,7 @@ if(NOT TRITON_BUILD_PYTHON_MODULE)
foreach(CODEGEN_BACKEND ${TRITON_CODEGEN_BACKENDS})
add_subdirectory(third_party/${CODEGEN_BACKEND})
endforeach()
add_subdirectory(third_party/proton/dialect)
endif()

add_subdirectory(third_party/f2reduce)
Expand Down
18 changes: 10 additions & 8 deletions bin/RegisterTritonDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
#include "amd/include/TritonAMDGPUTransforms/Passes.h"
#include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h"
#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
Expand Down Expand Up @@ -68,12 +69,13 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::triton::registerTritonAMDGPULowerInstructionSchedHints();

// TODO: register Triton & TritonGPU passes
registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect,
mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect,
mlir::arith::ArithDialect, mlir::scf::SCFDialect,
mlir::gpu::GPUDialect, mlir::LLVM::LLVMDialect,
mlir::NVVM::NVVMDialect, mlir::triton::nvgpu::NVGPUDialect,
mlir::triton::amdgpu::TritonAMDGPUDialect,
mlir::ROCDL::ROCDLDialect>();
registry
.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect,
mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect,
mlir::arith::ArithDialect, mlir::scf::SCFDialect,
mlir::gpu::GPUDialect, mlir::LLVM::LLVMDialect,
mlir::NVVM::NVVMDialect, mlir::triton::nvgpu::NVGPUDialect,
mlir::triton::amdgpu::TritonAMDGPUDialect,
mlir::triton::proton::ProtonDialect, mlir::ROCDL::ROCDLDialect>();
}
15 changes: 15 additions & 0 deletions test/Proton/ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: triton-opt --split-input-file %s -cse -canonicalize | FileCheck %s

// Verifies that a start/stop pair of `proton.record` ops with the same
// regionId parses and is still printed, in the same order and on consecutive
// lines, after the -cse and -canonicalize passes have run.
module {
// CHECK-LABEL: proton_record
tt.func @proton_record() {
// CHECK: proton.record() {isStart = true, regionId = 1 : i32}
// CHECK-NEXT: proton.record() {isStart = false, regionId = 1 : i32}
// CHECK-NEXT: tt.return
proton.record() {isStart = true, regionId = 1 : i32}
proton.record() {isStart = false, regionId = 1 : i32}
tt.return
}
} // end module

// -----
7 changes: 7 additions & 0 deletions third_party/proton/dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Build configuration for the Proton dialect.
#
# Add the source-tree and build-tree `include` directories to the header
# search path of this directory and of the subdirectories added below.
include_directories(
  ${CMAKE_CURRENT_SOURCE_DIR}/include
  ${CMAKE_CURRENT_BINARY_DIR}/include
)

add_subdirectory(include)
add_subdirectory(lib)

# The TritonProton plugin is only declared when the Python module is built.
if(TRITON_BUILD_PYTHON_MODULE)
  add_triton_plugin(
    TritonProton
    ${CMAKE_CURRENT_SOURCE_DIR}/triton_proton.cc
    LINK_LIBS ProtonIR
  )
endif()
1 change: 1 addition & 0 deletions third_party/proton/dialect/include/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Descend into the Dialect subdirectory.
add_subdirectory(Dialect)
1 change: 1 addition & 0 deletions third_party/proton/dialect/include/Dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Descend into the Proton subdirectory.
add_subdirectory(Proton)
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Descend into the IR subdirectory.
add_subdirectory(IR)
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# TableGen code generation for the Proton dialect.

# Point MLIR_BINARY_DIR at the top-level build directory.
# NOTE(review): presumably read by add_mlir_doc to locate the docs output
# directory -- confirm against the MLIR CMake helpers.
set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR})

# Outputs generated from ProtonOps.td: dialect, op and enum declarations and
# definitions, plus LLVM IR conversion snippets.
set(LLVM_TARGET_DEFINITIONS ProtonOps.td)
mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=proton)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=proton)
mlir_tablegen(OpsConversions.inc -gen-llvmir-conversions)
mlir_tablegen(Ops.h.inc -gen-op-decls)
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
# Dialect and op documentation, requested under the dialects/ directory.
add_mlir_doc(ProtonDialect ProtonDialect dialects/ -gen-dialect-doc)
add_mlir_doc(ProtonOps ProtonOps dialects/ -gen-op-doc)
# Named target for the ProtonOps.td outputs listed above.
add_public_tablegen_target(ProtonTableGen)

# Outputs generated from ProtonAttrDefs.td: attribute declarations and
# definitions, grouped under their own target.
set(LLVM_TARGET_DEFINITIONS ProtonAttrDefs.td)
mlir_tablegen(ProtonAttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(ProtonAttrDefs.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(ProtonAttrDefsIncGen)
23 changes: 23 additions & 0 deletions third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Umbrella header for the Proton dialect. It pulls in the generated dialect,
// enum, attribute and op declarations (the *.h.inc files) together with the
// MLIR headers they rely on.
#ifndef TRITON_DIALECT_PROTON_IR_DIALECT_H_
#define TRITON_DIALECT_PROTON_IR_DIALECT_H_

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/PatternMatch.h"
// Generated dialect class declaration and enum declarations.
#include "proton/dialect/include/Dialect/Proton/IR/Dialect.h.inc"
#include "proton/dialect/include/Dialect/Proton/IR/OpsEnums.h.inc"

// The macro must be defined before the include: it selects the attribute
// class section of the generated file.
#define GET_ATTRDEF_CLASSES
#include "proton/dialect/include/Dialect/Proton/IR/ProtonAttrDefs.h.inc"

// Likewise, selects the op class section of the generated file.
#define GET_OP_CLASSES
#include "proton/dialect/include/Dialect/Proton/IR/Ops.h.inc"

// The namespace is opened here but currently holds no hand-written
// declarations.
namespace mlir {
namespace triton {
namespace proton {} // namespace proton
} // namespace triton
} // namespace mlir

#endif // TRITON_DIALECT_PROTON_IR_DIALECT_H_
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#ifndef PROTON_ATTRDEFS
#define PROTON_ATTRDEFS

include "mlir/IR/AttrTypeBase.td"
include "ProtonDialect.td"

// Base class for attributes of the Proton dialect.
//   name         - attribute name forwarded to AttrDef.
//   traits       - optional trait list (empty by default).
//   baseCppClass - C++ base class, defaulting to ::mlir::Attribute.
// The body is empty: the class only binds AttrDef to Proton_Dialect.
class Proton_Attr<string name, list<Trait> traits = [],
string baseCppClass = "::mlir::Attribute">
: AttrDef<Proton_Dialect, name, traits, baseCppClass> {
}

#endif // PROTON_ATTRDEFS
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef PROTON_DIALECT
#define PROTON_DIALECT

include "mlir/IR/OpBase.td"

// Definition of the `proton` dialect. Its C++ classes live in the
// ::mlir::triton::proton namespace and it declares no dependent dialects.
def Proton_Dialect : Dialect {
let name = "proton";
let cppNamespace = "::mlir::triton::proton";

let description = [{
Proton Dialect provides core ops for building third-party compiler-based
performance profiling and analysis tools.
}];

// No other dialect is loaded automatically alongside this one.
let dependentDialects = [];
}

#endif
65 changes: 65 additions & 0 deletions third_party/proton/dialect/include/Dialect/Proton/IR/ProtonOps.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#ifndef PROTON_OPS
#define PROTON_OPS

include "mlir/IR/OpBase.td"
include "mlir/IR/EnumAttr.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
include "ProtonDialect.td"
include "ProtonAttrDefs.td"

// Common base class for Proton dialect operations: binds an op mnemonic and
// its optional trait list to Proton_Dialect.
class TT_Proton_Op<string mnemonic, list<Trait> traits = []>
    : Op<Proton_Dialect, mnemonic, traits> {
}

// Proton profiling metric.
// 32-bit enum attribute generated into ::mlir::triton::proton. It defines a
// single case: CYCLE (value 0, keyword `cycle`). The summary string is left
// empty.
def MetricAttr : I32EnumAttr<
"Metric", "",
[
I32EnumAttrCase<"CYCLE", 0, "cycle">,
]> {
let cppNamespace = "::mlir::triton::proton";
}

// Proton profiling granularity.
// 32-bit enum attribute generated into ::mlir::triton::proton. It defines two
// cases: WARPGROUP (value 0, keyword `warpgroup`) and WARP (value 1, keyword
// `warp`). The summary string is left empty.
def GranularityAttr : I32EnumAttr<
"Granularity", "",
[
I32EnumAttrCase<"WARPGROUP", 0, "warpgroup">,
I32EnumAttrCase<"WARP", 1, "warp">,
]> {
let cppNamespace = "::mlir::triton::proton";
}

// `proton.record`: an attribute-only op with no operands and no results.
// DeclareOpInterfaceMethods<MemoryEffectsOpInterface> declares the memory
// effects hook on the generated class; its body has to be supplied in C++.
def TT_RecordOp : TT_Proton_Op<"record", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "Record a GPU hardware event";

let description = [{
The operator records GPU events from performance counters.
Currently only cycle counter is supported.

Example:

```mlir
proton.record() {isStart = true, regionId = 4 : i32}
...
proton.record() {isStart = false, regionId = 4 : i32}
...
proton.record() {isStart = true, regionId = 1 : i32, granularity = 1 : i32}
...
proton.record() {isStart = false, regionId = 1 : i32, granularity = 1 : i32}
```
}];
// Attributes:
//   isStart     - boolean; the examples above pair a `true` record with a
//                 `false` record that carries the same regionId.
//   regionId    - 32-bit integer constrained to be non-negative.
//   metric      - MetricAttr, defaults to Metric::CYCLE when omitted.
//   granularity - GranularityAttr, defaults to Granularity::WARPGROUP when
//                 omitted.
let arguments = (
ins BoolAttr: $isStart,
ConfinedAttr<I32Attr, [IntNonNegative]>:$regionId,
DefaultValuedAttr<MetricAttr, "Metric::CYCLE">:$metric,
DefaultValuedAttr<GranularityAttr, "Granularity::WARPGROUP">:$granularity
);
// Textual form: a parenthesised operand list (always empty, since the op
// declares no operands) followed by the attribute dictionary.
let assemblyFormat = " `(` operands `)` attr-dict";
}

#endif // PROTON_OPS
1 change: 1 addition & 0 deletions third_party/proton/dialect/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Descend into the Dialect subdirectory.
add_subdirectory(Dialect)
1 change: 1 addition & 0 deletions third_party/proton/dialect/lib/Dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Descend into the Proton subdirectory.
add_subdirectory(Proton)
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Descend into the IR subdirectory.
add_subdirectory(IR)
13 changes: 13 additions & 0 deletions third_party/proton/dialect/lib/Dialect/Proton/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# ProtonIR: library built from the Proton dialect and op sources.
add_triton_library(ProtonIR
Dialect.cpp
Ops.cpp

# TableGen targets listed as dependencies of this library.
DEPENDS
ProtonTableGen
ProtonAttrDefsIncGen

# Libraries linked with PUBLIC visibility.
LINK_LIBS PUBLIC
MLIRLLVMDialect
TritonIR
TritonGPUIR
)
25 changes: 25 additions & 0 deletions third_party/proton/dialect/lib/Dialect/Proton/IR/Dialect.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"

// clang-format off
#include "Dialect/Proton/IR/Dialect.h"
#include "Dialect/Proton/IR/Dialect.cpp.inc"
// clang-format on

using namespace mlir;
using namespace mlir::triton::proton;

// Registers the dialect's attributes and operations. Each list of classes is
// spliced in from a generated .cpp.inc file: the GET_*_LIST macro defined
// just before each include selects the list section of that file.
void mlir::triton::proton::ProtonDialect::initialize() {
addAttributes<
#define GET_ATTRDEF_LIST
#include "Dialect/Proton/IR/ProtonAttrDefs.cpp.inc"
>();

addOperations<
#define GET_OP_LIST
#include "Dialect/Proton/IR/Ops.cpp.inc"
>();
}

#define GET_ATTRDEF_CLASSES
#include "Dialect/Proton/IR/ProtonAttrDefs.cpp.inc"
33 changes: 33 additions & 0 deletions third_party/proton/dialect/lib/Dialect/Proton/IR/Ops.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#include "Dialect/Proton/IR/Dialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/IR/Utility.h"

#define GET_OP_CLASSES
#include "Dialect/Proton/IR/Ops.cpp.inc"
#include "Dialect/Proton/IR/OpsEnums.cpp.inc"

namespace mlir {
namespace triton {
namespace proton {

// -- RecordOp --
// Reports the memory effects of `proton.record`: a write followed by a read,
// both on the default side-effect resource.
void RecordOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  auto *defaultResource = SideEffects::DefaultResource::get();
  effects.emplace_back(MemoryEffects::Write::get(), defaultResource);
  effects.emplace_back(MemoryEffects::Read::get(), defaultResource);
}

} // namespace proton
} // namespace triton
} // namespace mlir
20 changes: 20 additions & 0 deletions third_party/proton/dialect/triton_proton.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#include "Dialect/Proton/IR/Dialect.h"
#include "mlir/Pass/PassManager.h"
#include "passes.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>

namespace py = pybind11;

// Populates the Python module `m` for the Proton plugin:
//   - creates a `passes` submodule (nothing is added to it here);
//   - exposes `load_dialects(context)`, which registers the Proton dialect
//     with the given MLIRContext and then loads every available dialect.
void init_triton_proton(py::module &&m) {
  auto passesModule = m.def_submodule("passes");

  // load dialects
  auto loadProtonDialects = [](mlir::MLIRContext &ctx) {
    mlir::DialectRegistry protonRegistry;
    protonRegistry.insert<mlir::triton::proton::ProtonDialect>();
    ctx.appendDialectRegistry(protonRegistry);
    ctx.loadAllAvailableDialects();
  };
  m.def("load_dialects", loadProtonDialects);
}

0 comments on commit e9db186

Please sign in to comment.