Introduce TritonStructured dialect and triton-to-structured pass (#82)
* Introduce TritonStructured dialect

* Updated mask analysis

* Update OpFoldResultUtils

* triton-to-structured pass

* LIT tests

* Address review comments

* Revert header name change

* Remove copied version of mask analysis
haishanzzzz authored and nhat-nguyen committed Feb 1, 2024
1 parent 9317ee0 commit 7a1657e
Showing 65 changed files with 4,665 additions and 101 deletions.
41 changes: 19 additions & 22 deletions include/triton-shared/Analysis/MaskAnalysis.h
@@ -1,6 +1,6 @@
//===----------------------------------------------------------------------===//
//
-// Copyright (c) Microsoft Corporation.
+// Copyright (c) Microsoft Corporation, Meta Platforms.
// Licensed under the MIT license.
//
//===----------------------------------------------------------------------===//
@@ -18,7 +18,7 @@

namespace mlir {

-class ConversionPatternRewriter;
+class OpBuilder;

namespace triton {
// Data structure used to decode the pattern in a mask used for load and store.
@@ -56,87 +56,84 @@ struct MaskState {

// Recursively parse a Value; call the corresponding function based on the
// defining operation and Value type
- LogicalResult parse(Value operand, const Location loc,
-   ConversionPatternRewriter &rewriter);
+ LogicalResult parse(Value operand, const Location loc, OpBuilder &builder);

tensor::ExtractSliceOp
getExtractSlice(Value source, const Location loc,
-   ConversionPatternRewriter &rewriter) const;
+   OpBuilder &builder) const;

memref::SubViewOp getSubview(Value source, const Location loc,
-   ConversionPatternRewriter &rewriter) const;
+   OpBuilder &builder) const;

std::pair<memref::SubViewOp, memref::SubViewOp>
getSideBySideSubviews(Value block1, Value block2, const Location loc,
-   ConversionPatternRewriter &rewriter) const;
+   OpBuilder &builder) const;

std::pair<memref::SubViewOp, memref::SubViewOp>
getStackedSubviews(Value block1, Value block2, const Location loc,
-   ConversionPatternRewriter &rewriter) const;
+   OpBuilder &builder) const;

private:
// -------
// Utility functions to operate on MaskState
// -------
LogicalResult addStateScalar(const MaskState &state,
const OpFoldResult scalar, Location loc,
-   ConversionPatternRewriter &rewriter);
+   OpBuilder &builder);

LogicalResult addStates(const MaskState &lhsState, const MaskState &rhsState,
-   Location loc, ConversionPatternRewriter &rewriter);
+   Location loc, OpBuilder &builder);

LogicalResult minStates(const MaskState &lhsState, const MaskState &rhsState,
-   Location loc, ConversionPatternRewriter &rewriter);
+   Location loc, OpBuilder &builder);
// -------
// Helper functions to parse values to populate MaskState
// -------

// Operand is the result of a constant
// Get the value of the constant and assign it to scalar.
LogicalResult parseConstant(arith::ConstantOp constOp, const Location loc,
-   ConversionPatternRewriter &rewriter);
+   OpBuilder &builder);

// Operand is an integer scalar
LogicalResult parseIntScalar(Value scalar, const Location loc,
-   ConversionPatternRewriter &rewriter);
+   OpBuilder &builder);

// Operand is the result of addi
// One and only one of the operands should be a scalar. Increment both start
// and end; dims remain unchanged, and scalar is empty.
LogicalResult parseAdd(arith::AddIOp addOp, const Location loc,
-   ConversionPatternRewriter &rewriter);
+   OpBuilder &builder);
// Operand is the result of andi
// Each of the result state dims is the smaller of the two operands' dims.
// Insert instructions if needed to get the new dims.
LogicalResult parseAnd(arith::AndIOp andOp, const Location loc,
-   ConversionPatternRewriter &rewriter);
+   OpBuilder &builder);

// Operand is the result of cmpi
// Assume only one of the dimensions has size > 1. Only support slt for now.
// For that dimension, calculate this new dim as: dim = min(end, value) -
// start
LogicalResult parseCmp(arith::CmpIOp cmpOp, const Location loc,
-   ConversionPatternRewriter &rewriter);
+   OpBuilder &builder);
// Operand is the result of make_range
// Set start and end accordingly; step size must be 1.
LogicalResult parseMakeRange(triton::MakeRangeOp rangeOp, const Location loc,
-   ConversionPatternRewriter &rewriter);
+   OpBuilder &builder);
// Operand is the result of broadcast
// Change dims only; assume only applies to tensors.
LogicalResult parseBroadcast(triton::BroadcastOp broadcastOp,
-   const Location loc,
-   ConversionPatternRewriter &rewriter);
+   const Location loc, OpBuilder &builder);
// Operand is the result of splat
// Assume only applies to scalar. start and end are left empty; scalar will
// be assigned, and dims will be updated.
LogicalResult parseSplat(triton::SplatOp splatOp, const Location loc,
-   ConversionPatternRewriter &rewriter);
+   OpBuilder &builder);
// Operand is the result of expand_dims
// Insert additional dims; start and end do not change and correspond to the
// dimension that contains the range.
LogicalResult parseExpandDims(triton::ExpandDimsOp expandDimsOp,
-   const Location loc,
-   ConversionPatternRewriter &rewriter);
+   const Location loc, OpBuilder &builder);
};

} // namespace triton
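This header change replaces ConversionPatternRewriter with the more general OpBuilder throughout MaskAnalysis, so the mask decoding can be driven from a plain builder rather than only from inside a dialect-conversion pattern (presumably so the new structured analysis can reuse it). A minimal usage sketch under that assumption; `decodeMask`, `maskValue`, and `op` are hypothetical names, not part of the commit:

```cpp
#include "mlir/IR/Builders.h"
#include "triton-shared/Analysis/MaskAnalysis.h"

// Hypothetical helper: decode a load/store mask into a MaskState using a
// plain OpBuilder (no conversion rewriter required after this change).
mlir::LogicalResult decodeMask(mlir::Value maskValue, mlir::Operation *op) {
  mlir::OpBuilder builder(op); // helper ops get inserted right before `op`
  mlir::triton::MaskState mstate;
  // parse() walks the ops defining `maskValue` (cmpi, andi, splat, ...) and
  // populates the start/end/dims fields on success.
  return mstate.parse(maskValue, op->getLoc(), builder);
}
```

Any arithmetic the analysis needs to compute the dims is now created through whatever builder the caller passes in, at that builder's insertion point.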
6 changes: 5 additions & 1 deletion include/triton-shared/Analysis/OpFoldResultUtils.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//===----------------------------------------------------------------------===//
//
-// Copyright (c) Microsoft Corporation.
+// Copyright (c) Microsoft Corporation, Meta Platforms.
// Licensed under the MIT license.
//
//===----------------------------------------------------------------------===//
@@ -22,6 +22,10 @@ class OpBuilder;
// result of an operation too.
std::optional<int64_t> getIntAttr(const OpFoldResult ofr);

+// Return if ofr contains a constant zero, either represented by an integer
+// attribute or a constant value.
+bool hasConstZero(const OpFoldResult ofr);

// Create a value of index type if necessary from an OpFoldResult.
Value ofrToIndexValue(const OpFoldResult ofr, const Location loc, OpBuilder &b);

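The new hasConstZero helper reports whether an OpFoldResult is a statically known zero, which lets callers skip emitting arithmetic for no-op offsets. A small sketch of that kind of use together with the existing ofrToIndexValue; `addOffsets` is a hypothetical helper, and both free functions are assumed to live directly in namespace mlir as the surrounding declarations suggest:

```cpp
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "triton-shared/Analysis/OpFoldResultUtils.h"

// Hypothetical helper: materialize `lhs + rhs` as an index Value, skipping
// the arith.addi entirely when `rhs` is a statically known zero.
mlir::Value addOffsets(mlir::OpFoldResult lhs, mlir::OpFoldResult rhs,
                       mlir::Location loc, mlir::OpBuilder &b) {
  if (mlir::hasConstZero(rhs))
    return mlir::ofrToIndexValue(lhs, loc, b);
  mlir::Value l = mlir::ofrToIndexValue(lhs, loc, b);
  mlir::Value r = mlir::ofrToIndexValue(rhs, loc, b);
  return b.create<mlir::arith::AddIOp>(loc, l, r);
}
```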
208 changes: 208 additions & 0 deletions include/triton-shared/AnalysisStructured/PtrAnalysis.h
@@ -0,0 +1,208 @@
//===----------------------------------------------------------------------===//
//
// Copyright (c) Microsoft Corporation, Meta Platforms.
// Licensed under the MIT license.
//
//===----------------------------------------------------------------------===//

#ifndef TRITON_ANALYSISSTRUCTURED_PTRANALYSIS_H
#define TRITON_ANALYSISSTRUCTURED_PTRANALYSIS_H

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"

#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

#include <set>

namespace mlir {

class OpBuilder;

namespace tts {

const extern std::string ptrAnalysisAttr;

// Data structure used to decode pointer arithmetic. offsets, sizes, and
// strides are in units of elements in a linearly laid-out memory, which is the
// same as pointer arithmetic operations in the Triton language. scalar is a
// shortcut used when the entire state describes a single scalar value. source
// is the base pointer. modulos describes how addresses wrap around; a constant
// 0 indicates no modulo for the dimension.
class PtrState {

public:
SmallVector<OpFoldResult> offsets;
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides;
SmallVector<OpFoldResult> modulos;

Value source;
Value scalar;

int32_t getRank() const;

bool isEmpty() const;

bool hasModulo() const;

bool dimHasModulo(uint32_t dim) const;

// Process addition of two PtrStates.
LogicalResult addState(const PtrState &lhsState, const PtrState &rhsState,
Operation *op, OpBuilder &builder);

// Process multiplication of two PtrStates
LogicalResult mulState(const PtrState &lhsState, const PtrState &rhsState,
Operation *op, OpBuilder &builder);

tts::MakeTensorPtrOp createTTSMakeTensorPtrOp(OpBuilder &builder,
Location loc);
};

struct PtrAnalysis {
using IndexMapSet = std::map<int, std::set<int>>;

IndexMapSet levelToBlockArgIndex;
int level = 0;

llvm::SmallDenseMap<Value, PtrState> knownPtrs;

IRMapping ptrMap;

// Recursively parse a Value; call the corresponding
// function based on the defining operation and argument type.
LogicalResult visitOperand(Value operand, PtrState &state, const Location loc,
OpBuilder &builder);

// Operand is the result of arith.addi. Process both arguments and insert any
// arith.addi instruction as needed.
// Main assumptions:
// Only one of lhsState and rhsState has source field set
// Current PtrState should be empty
// Expected result:
// source = lhsState.source ? lhsState.source : rhsState.source
// sizes[i] = lhsState.sizes[i] (which should match rhsState.sizes[i])
// offsets[i] = lhsState.offsets[i] + rhsState.offsets[i]
// strides[i] = lhsState.strides[i] + rhsState.strides[i]
LogicalResult visitOperandAdd(arith::AddIOp addOp, PtrState &state,
const Location loc, OpBuilder &builder);

// Operand is the result of arith.muli. Process both arguments and insert any
// arith.muli instruction as needed.
// Main assumptions:
// Neither lhsState nor rhsState has source field set
// Current PtrState should be empty
// Currently only supports the case where one of the operands is a scalar index
// Expected result (scalar and tensorState represent the two operands):
// source = null
// sizes[i] = tensorState.sizes[i]
// offsets[i] = tensorState.offsets[i] * scalar
// strides[i] = tensorState.strides[i] * scalar
LogicalResult visitOperandMul(arith::MulIOp mulOp, PtrState &state,
const Location loc, OpBuilder &builder);

LogicalResult visitOperandRem(arith::RemSIOp mulOp, PtrState &state,
const Location loc, OpBuilder &builder);

// Operand is the result of make_range.
// Main assumptions:
// start, end, and shape are all statically known
// The output of make_range is 1-dimensional
// Does not check validity of inputs (e.g., stride > 0)
// Expected result:
// source = null
// sizes[0] = shape[0]
// offset[0] = start
// strides[0] = ceiling( (end - start) / shape[0] )
LogicalResult visitOperandMakeRange(triton::MakeRangeOp rangeOp,
PtrState &state, Location loc,
OpBuilder &builder);

// Operand is the result of expand_dims
// Main assumptions:
// Only 1 dimension changes for each invocation of reshape
// The changed dimension must have size of 1
// Expected result:
// Insert a dimension of size 1, stride 0, and offset 0
LogicalResult visitOperandExpandDims(triton::ExpandDimsOp expandDimsOp,
PtrState &state, const Location loc,
OpBuilder &builder);

// Operand is the result of broadcast
// Main assumptions:
// Rank of source and result is the same
// Expected result:
// Update sizes[i] only, no changes to other fields
LogicalResult visitOperandBroadcast(triton::BroadcastOp broadcastOp,
PtrState &state, const Location loc,
OpBuilder &builder);

// Operand is the result of splat
// Main assumptions:
// Source is a scalar value (i.e., an integer or a pointer, not a tensor)
// Expected result:
// sizes[i] reflect the shape of the result, strides[i] = 0, offsets[i] = 0
// if source is an integer, offset[0] = scalar = source
LogicalResult visitOperandSplat(triton::SplatOp splatOp, PtrState &state,
const Location loc, OpBuilder &builder);

// Operand is the result of arith.constant that is a splat
// Main assumptions:
// Source is a constant op that produces a constant dense tensor where all
// elements are the same (i.e., a constant that is splatted)
// Expected result:
// sizes[i] reflect the shape of the result, strides[i] = 0, offsets[i] =
// splat value if i == 0, otherwise 0
LogicalResult visitOperandConstSplat(arith::ConstantOp op, PtrState &state,
const Location loc, OpBuilder &builder);

// Operand is the result of addptr.
// Main assumptions:
// The ptr field should populate the source field
// ptr and offset fields should result in the same rank
// Expected result:
// The resulting state for ptr and offset will be added
LogicalResult visitOperandAddptr(triton::AddPtrOp addptrOp, PtrState &state,
const Location loc, OpBuilder &builder);

// Operand is the result of make_tptr.
// Main assumptions:
// This function is only called when rewriting a loop
// Expected result:
// Directly grab all corresponding fields from make_tptr.
LogicalResult visitOperandMakeTPtr(tts::MakeTensorPtrOp makeTPtrOp,
PtrState &state, const Location loc,
OpBuilder &builder);

// Parse the state of AddPtrOp, insert any instruction needed to
// calculate strides and offsets, build PtrState for this operand, and record
// the PtrState in knownPtrs.
LogicalResult rewriteAddptrOp(triton::AddPtrOp op);

// Parse the state of YieldOp, insert any instruction needed to calculate
// strides and offsets, build PtrState for this operand, and record PtrState
// in knownPtrs.
LogicalResult
rewriteYieldOp(scf::YieldOp op,
llvm::SmallDenseMap<int, PtrState> &knownPtrsFor);

// Rewrite eligible tt.addptr in loop init args so the loop can update such
// pointers over iterations. Insert any instruction needed to calculate
// strides, offsets, and modulos.
LogicalResult rewriteForOp(scf::ForOp op);

LogicalResult rewriteLoadOp(triton::LoadOp op);

LogicalResult rewriteStoreOp(triton::StoreOp op);

LogicalResult rewriteOp(Operation *op);
};

} // namespace tts

} // namespace mlir

#endif
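PtrAnalysis decodes Triton pointer arithmetic (tt.addptr chains, make_range, splat, broadcast, expand_dims, loop-carried pointers) into PtrState offsets/sizes/strides/modulos and rewrites the corresponding ops. A minimal sketch of how it might be driven; this is an illustration only, not the actual triton-to-structured pass body, and `rewritePointers` plus the walk-and-remark strategy are assumptions:

```cpp
#include "triton-shared/AnalysisStructured/PtrAnalysis.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

// Hypothetical driver: walk a tt.func and let PtrAnalysis rewrite each
// tt.addptr it can decode. A real pass would presumably also use
// rewriteForOp / rewriteLoadOp / rewriteStoreOp (or rewriteOp) for loops,
// loads, and stores.
void rewritePointers(mlir::triton::FuncOp func) {
  mlir::tts::PtrAnalysis analysis;
  func.walk([&](mlir::triton::AddPtrOp addptr) {
    // On success the decoded PtrState is recorded in analysis.knownPtrs
    // (presumably materializing a tts.make_tptr via createTTSMakeTensorPtrOp).
    if (mlir::failed(analysis.rewriteAddptrOp(addptr)))
      addptr.emitRemark("pointer analysis could not structure this tt.addptr");
  });
}
```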
1 change: 1 addition & 0 deletions include/triton-shared/Conversion/CMakeLists.txt
@@ -1 +1,2 @@
add_subdirectory(TritonToLinalg)
+add_subdirectory(TritonToStructured)
@@ -0,0 +1,3 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToStructured)
add_public_tablegen_target(TritonToStructuredConversionPassIncGen)
15 changes: 15 additions & 0 deletions include/triton-shared/Conversion/TritonToStructured/Passes.h
@@ -0,0 +1,15 @@
#ifndef TRITON_TO_STRUCTURED_CONVERSION_PASSES_H
#define TRITON_TO_STRUCTURED_CONVERSION_PASSES_H

#include "triton-shared/Conversion/TritonToStructured/TritonToStructured.h"

namespace mlir {
namespace triton {

#define GEN_PASS_REGISTRATION
#include "triton-shared/Conversion/TritonToStructured/Passes.h.inc"

} // namespace triton
} // namespace mlir

#endif
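The GEN_PASS_REGISTRATION block expands into registration hooks for tools. A sketch of wiring the new pass into an opt-style driver; the exact name of the generated entry point is an assumption based on the usual register<Name>Passes() tablegen convention:

```cpp
#include "triton-shared/Conversion/TritonToStructured/Passes.h"

// Hypothetical tool setup: make the triton-to-structured flag available on an
// opt-like driver's command line (call this before running the driver's main).
void registerTritonSharedConversionPasses() {
  // Generated by GEN_PASS_REGISTRATION in Passes.h.inc (assumed name).
  mlir::triton::registerTritonToStructuredPasses();
}
```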
11 changes: 11 additions & 0 deletions include/triton-shared/Conversion/TritonToStructured/Passes.td
@@ -0,0 +1,11 @@
#ifndef TRITON_TO_STRUCTURED_CONVERSION_PASSES
#define TRITON_TO_STRUCTURED_CONVERSION_PASSES

include "mlir/Pass/PassBase.td"

def TritonToStructured : Pass<"triton-to-structured", "mlir::ModuleOp"> {
let summary = "Convert Triton non-block pointer to TritonStructured dialect";
let constructor = "triton::createTritonToStructuredPass()";
}

#endif
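Passes.td defines the command-line flag (triton-to-structured), the op the pass anchors on (mlir::ModuleOp), and the constructor. A sketch of adding the pass to a programmatic pipeline, assuming the declared constructor returns a standard std::unique_ptr pass as tablegen-defined passes normally do:

```cpp
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "triton-shared/Conversion/TritonToStructured/Passes.h"

// Hypothetical pipeline: run the new conversion over the module, then
// canonicalize to fold away leftover offset/stride arithmetic.
void buildStructuredPipeline(mlir::PassManager &pm) {
  pm.addPass(mlir::triton::createTritonToStructuredPass()); // -triton-to-structured
  pm.addPass(mlir::createCanonicalizerPass());
}
```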