-
Notifications
You must be signed in to change notification settings - Fork 54
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduce TritonStructured dialect and triton-to-structured pass (#82)
* Introduce TritonStructured dialect * Updated mask analysis * Update OpFoldResultUtils * triton-to-structured pass * LIT tests * Address review comments * Revert header name change * Remove copied version of mask analysis
- Loading branch information
1 parent
9317ee0
commit 7a1657e
Showing
65 changed files
with
4,665 additions
and
101 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,208 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Copyright (c) Microsoft Corporation, Meta Platforms. | ||
// Licensed under the MIT license. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef TRITON_ANALYSISSTRUCTURED_PTRANALYSIS_H | ||
#define TRITON_ANALYSISSTRUCTURED_PTRANALYSIS_H | ||
|
||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/MemRef/IR/MemRef.h" | ||
#include "mlir/Dialect/SCF/IR/SCF.h" | ||
|
||
#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" | ||
#include "triton/Dialect/Triton/IR/Dialect.h" | ||
|
||
#include <set> | ||
|
||
namespace mlir { | ||
|
||
class OpBuilder; | ||
|
||
namespace tts { | ||
|
||
const extern std::string ptrAnalysisAttr; | ||
|
||
// Data structure used to decode pointer arithmetics. offsets, sizes, and | ||
// strides are in unit of elements in a linearly laid-out memory, which is the | ||
// same as pointer arithmetic operations in Triton language. scalar is a | ||
// shortcut used when the entire state describes a single scalar value. source | ||
// is the base pointer. modulos describes how address wraps around; a constant 0 | ||
// indicates no modulo for the dimension. | ||
class PtrState { | ||
|
||
public: | ||
SmallVector<OpFoldResult> offsets; | ||
SmallVector<OpFoldResult> sizes; | ||
SmallVector<OpFoldResult> strides; | ||
SmallVector<OpFoldResult> modulos; | ||
|
||
Value source; | ||
Value scalar; | ||
|
||
int32_t getRank() const; | ||
|
||
bool isEmpty() const; | ||
|
||
bool hasModulo() const; | ||
|
||
bool dimHasModulo(uint32_t dim) const; | ||
|
||
// Process addition of two PtrStates. | ||
LogicalResult addState(const PtrState &lhsState, const PtrState &rhsState, | ||
Operation *op, OpBuilder &builder); | ||
|
||
// Process multiplication of two PtrStates | ||
LogicalResult mulState(const PtrState &lhsState, const PtrState &rhsState, | ||
Operation *op, OpBuilder &builder); | ||
|
||
tts::MakeTensorPtrOp createTTSMakeTensorPtrOp(OpBuilder &builder, | ||
Location loc); | ||
}; | ||
|
||
struct PtrAnalysis { | ||
using IndexMapSet = std::map<int, std::set<int>>; | ||
|
||
IndexMapSet levelToBlockArgIndex; | ||
int level = 0; | ||
|
||
llvm::SmallDenseMap<Value, PtrState> knownPtrs; | ||
|
||
IRMapping ptrMap; | ||
|
||
// Recursively parse a Value; call the corresponding | ||
// function based on the defining operation and argument type. | ||
LogicalResult visitOperand(Value operand, PtrState &state, const Location loc, | ||
OpBuilder &builder); | ||
|
||
// Operand is the result of arith.addi. Process both arguments and insert any | ||
// arith.addi instruction as needed. | ||
// Main assumptions: | ||
// Only one of lhsState and rhsState has source field set | ||
// Current PtrState should be empty | ||
// Expected result: | ||
// source = lhsState.source ? lhsState.source : rhsState.source | ||
// sizes[i] = lhsState.sizes[i] (which should match rhsState.sizes[i]) | ||
// offsets[i] = lhsState.offsets[i] + rhsState.offsets[i] | ||
// strides[i] = lhsState.strides[i] + rhsState.strides[i] | ||
LogicalResult visitOperandAdd(arith::AddIOp addOp, PtrState &state, | ||
const Location loc, OpBuilder &builder); | ||
|
||
// Operand is the result of arith.muli. Process both arguments and insert any | ||
// arith.muli instruction as needed. | ||
// Main assumptions: | ||
// Neither lhsState nor rhsState has source field set | ||
// Current PtrState should be empty | ||
// Currently only support one of the operand is a scalar index | ||
// Expected result (scalar and tensorState represent the two operands): | ||
// source = null | ||
// sizes[i] = tensorState.sizes[i] | ||
// offsets[i] = tensorState.offsets[i] * scalar | ||
// strides[i] = tensorState.strides[i] * scalar | ||
LogicalResult visitOperandMul(arith::MulIOp mulOp, PtrState &state, | ||
const Location loc, OpBuilder &builder); | ||
|
||
LogicalResult visitOperandRem(arith::RemSIOp mulOp, PtrState &state, | ||
const Location loc, OpBuilder &builder); | ||
|
||
// Operand is the result of make_range. | ||
// Main assumptions: | ||
// start, end, and shape are all statically known | ||
// The output of make_range is 1-dimensional | ||
// Does not check validity of inputs (e.g., stride > 0) | ||
// Expected result: | ||
// source = null | ||
// sizes[0] = shape[0] | ||
// offset[0] = start | ||
// strides[0] = ceiling( (end - start) / shape[0] ) | ||
LogicalResult visitOperandMakeRange(triton::MakeRangeOp rangeOp, | ||
PtrState &state, Location loc, | ||
OpBuilder &builder); | ||
|
||
// Operand is the result of expand_dims | ||
// Main assumptions: | ||
// Only 1 dimension changes for each invocation of reshape | ||
// The changed dimension must have size of 1 | ||
// Expected result: | ||
// Insert a dimension of size 1, stride 0, and offset 0 | ||
LogicalResult visitOperandExpandDims(triton::ExpandDimsOp expandDimsOp, | ||
PtrState &state, const Location loc, | ||
OpBuilder &builder); | ||
|
||
// Operand is the result of broadcast | ||
// Main assumptions: | ||
// Rank of soure and result is the same | ||
// Expected result: | ||
// Update sizes[i] only, no changes to other fields | ||
LogicalResult visitOperandBroadcast(triton::BroadcastOp broadcastOp, | ||
PtrState &state, const Location loc, | ||
OpBuilder &builder); | ||
|
||
// Operand is the result of splat | ||
// Main assumptions: | ||
// Source is a scalar value (i.e., an integer or a pointer, not a tensor) | ||
// Expected result: | ||
// sizes[i] reflect the shape of the result, strides[i] = 0, offsets[i] = 0 | ||
// if source is an integer, offset[0] = scalar = source | ||
LogicalResult visitOperandSplat(triton::SplatOp splatOp, PtrState &state, | ||
const Location loc, OpBuilder &builder); | ||
|
||
// Operand is the result of arith.constant that is a splat | ||
// Main assumptions: | ||
// Source is a constant op that produces a constant dense tensor where all | ||
// elements are the same (i.e.: a constant that is splatted) | ||
// Expected result: | ||
// sizes[i] reflect the shape of the result, strides[i] = 0, offsets[i] = | ||
// splat value if i == 0, otherwise 0 | ||
LogicalResult visitOperandConstSplat(arith::ConstantOp op, PtrState &state, | ||
const Location loc, OpBuilder &builder); | ||
|
||
// Operand is the result of addptr. | ||
// Main assumptions: | ||
// The ptr field should populate the source field | ||
// ptr and offset fields should result in same rank | ||
// Expected result: | ||
// The resulting state for ptr and offset wil be added | ||
LogicalResult visitOperandAddptr(triton::AddPtrOp addptrOp, PtrState &state, | ||
const Location loc, OpBuilder &builder); | ||
|
||
// Operand is the result of make_tptr. | ||
// Main assumptions: | ||
// This function is only called when rewriting a loop | ||
// Expected result: | ||
// Directly grab all corresponding fields from make_tptr. | ||
LogicalResult visitOperandMakeTPtr(tts::MakeTensorPtrOp makeTPtrOp, | ||
PtrState &state, const Location loc, | ||
OpBuilder &builder); | ||
|
||
// Parse the state of AddPtrOp, insert any instruction needed to | ||
// calculate strides and offsets, build PtrState for this operand, and record | ||
// PtrState for knownPtrs. | ||
LogicalResult rewriteAddptrOp(triton::AddPtrOp op); | ||
|
||
// Parse the state of YieldOp, insert any instruction needed to calculate | ||
// strides and offsets, build PtrState for this operand, and record PtrState | ||
// in knownPtrs. | ||
LogicalResult | ||
rewriteYieldOp(scf::YieldOp op, | ||
llvm::SmallDenseMap<int, PtrState> &knownPtrsFor); | ||
|
||
// Rewrite eligible tt.addptr in loop init args so loop can update the such | ||
// pointers over iterations. Insert any instruction needed to calculate | ||
// strides, offsets, and modulos. | ||
LogicalResult rewriteForOp(scf::ForOp op); | ||
|
||
LogicalResult rewriteLoadOp(triton::LoadOp op); | ||
|
||
LogicalResult rewriteStoreOp(triton::StoreOp op); | ||
|
||
LogicalResult rewriteOp(Operation *op); | ||
}; | ||
|
||
} // namespace triton | ||
|
||
} // namespace mlir | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
add_subdirectory(TritonToLinalg) | ||
add_subdirectory(TritonToStructured) |
3 changes: 3 additions & 0 deletions
3
include/triton-shared/Conversion/TritonToStructured/CMakeLists.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
set(LLVM_TARGET_DEFINITIONS Passes.td) | ||
mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToStructured) | ||
add_public_tablegen_target(TritonToStructuredConversionPassIncGen) |
15 changes: 15 additions & 0 deletions
15
include/triton-shared/Conversion/TritonToStructured/Passes.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
#ifndef TRITON_TO_STRUCTURED_CONVERSION_PASSES_H | ||
#define TRITON_TO_STRUCTURED_CONVERSION_PASSES_H | ||
|
||
#include "triton-shared/Conversion/TritonToStructured/TritonToStructured.h" | ||
|
||
namespace mlir { | ||
namespace triton { | ||
|
||
#define GEN_PASS_REGISTRATION | ||
#include "triton-shared/Conversion/TritonToStructured/Passes.h.inc" | ||
|
||
} // namespace triton | ||
} // namespace mlir | ||
|
||
#endif |
11 changes: 11 additions & 0 deletions
11
include/triton-shared/Conversion/TritonToStructured/Passes.td
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
#ifndef TRITON_TO_STRUCTURED_CONVERSION_PASSES | ||
#define TRITON_TO_STRUCTURED_CONVERSION_PASSES | ||
|
||
include "mlir/Pass/PassBase.td" | ||
|
||
def TritonToStructured : Pass<"triton-to-structured", "mlir::ModuleOp"> { | ||
let summary = "Convert Triton non-block pointer to TritonStructured dialect"; | ||
let constructor = "triton::createTritonToStructuredPass()"; | ||
} | ||
|
||
#endif |
Oops, something went wrong.