Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce TritonStructured dialect and triton-to-structured pass #82

Merged
merged 10 commits into from
Jan 19, 2024
10 changes: 7 additions & 3 deletions include/triton-shared/Analysis/OpFoldResultUtils.h
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
//===----------------------------------------------------------------------===//
//
// Copyright (c) Microsoft Corporation.
// Copyright (c) Microsoft Corporation, Meta Platforms.
// Licensed under the MIT license.
//
//===----------------------------------------------------------------------===//

#ifndef TRITON_ANALYSIS_OPFOLDRESULT_UTILS_H
#define TRITON_ANALYSIS_OPFOLDRESULT_UTILS_H
#ifndef TRITON_ANALYSISSTRUCTURED_OPFOLDRESULT_UTILS_H
#define TRITON_ANALYSISSTRUCTURED_OPFOLDRESULT_UTILS_H

#include "mlir/IR/Location.h"
#include "mlir/IR/OpDefinition.h"
Expand All @@ -22,6 +22,10 @@ class OpBuilder;
// result of an operation too.
std::optional<int64_t> getIntAttr(const OpFoldResult ofr);

// Return true if ofr contains a constant zero, either represented by an integer
// attribute or a constant value.
bool hasConstZero(const OpFoldResult ofr);

// Create a value of index type if necessary from an OpFoldResult.
Value ofrToIndexValue(const OpFoldResult ofr, const Location loc, OpBuilder &b);

Expand Down
128 changes: 128 additions & 0 deletions include/triton-shared/AnalysisStructured/MaskAnalysis.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
//===----------------------------------------------------------------------===//
//
// Copyright (c) Microsoft Corporation, Meta Platforms.
// Licensed under the MIT license.
//
//===----------------------------------------------------------------------===//

#ifndef TRITON_ANALYSISSTRUCTURED_MASKANALYSIS_H
#define TRITON_ANALYSISSTRUCTURED_MASKANALYSIS_H

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

#include "triton/Dialect/Triton/IR/Dialect.h"

#include <utility>

namespace mlir {

class OpBuilder;

namespace triton {
// Data structure used to decode the pattern in a mask used for load and store.
// start and end field represent the start and end index of a range (produced
// by make_range, addi, etc.). While multi-dimensional data is possible, we
// assume range comparison can only be done on 1 dimension at a time (and
// results of range comparisons across dimensions can be combined), hence start
// and end are not vectors. dims represents the real access size for ld/st
// (instead of the tensor/memref size specified by the IR). scalar is a shortcut
// used when the entire state contains a single scalar value.
//
// The general lifetime of this data structure is roughly:
// 1. A range is created by make_range and optionally operated on by addi w/
// result of splat, expand_dims, etc. During this phase, either (1) both start
// and end are populated, or (2) scalar is populated. Only one of the dimensions
// (that contains the range) can have dim > 1.
// 2. Result from step 1 is compared with another MaskSState that represents a
// scalar value. The resulting state only has dims populated.
// 3. Optionally, result from step 2 can be broadcasted and anded with other
// results from step 2. The resulting state only has dims populated.
//
// Example of creating 2D mask:
// mask = (rows[:, None] < M) & (cols[None, :] < N)
struct MaskSState {
  // Start and end index of the 1-D range being tracked.
  OpFoldResult start;
  OpFoldResult end;
  // Real access size per dimension for the load/store.
  SmallVector<OpFoldResult> dims;
  // Shortcut used when the whole state is a single scalar value.
  OpFoldResult scalar;

  // Number of dimensions currently recorded in `dims`.
  int64_t getRank() const { return static_cast<int64_t>(dims.size()); }

  // True when nothing has been populated yet: no dims, no scalar, no range.
  bool isEmpty() const { return dims.empty() && !(scalar || start || end); }

  // True when only `dims` is populated, i.e. the state describes a mask.
  bool isMask() const {
    const bool hasRangeOrScalar = start || end || scalar;
    return !hasRangeOrScalar && !dims.empty();
  }

  // Recursively parse a Value, dispatching to the corresponding helper below
  // based on the defining operation and the Value type.
  LogicalResult parse(Value operand, const Location loc, OpBuilder &builder);

private:
  // -------
  // Utility functions to operate on MaskSState
  // -------
  LogicalResult addStateScalar(const MaskSState &state,
                               const OpFoldResult scalar, Location loc,
                               OpBuilder &builder);

  LogicalResult addStates(const MaskSState &lhsState,
                          const MaskSState &rhsState, Location loc,
                          OpBuilder &builder);

  LogicalResult minStates(const MaskSState &lhsState,
                          const MaskSState &rhsState, Location loc,
                          OpBuilder &builder);

  // -------
  // Helper functions to parse values to populate MaskSState
  // -------

  // Operand is the result of a constant.
  // The constant's value is read and assigned to scalar.
  LogicalResult parseConstant(arith::ConstantOp constOp, const Location loc,
                              OpBuilder &builder);

  // Operand is an integer scalar.
  LogicalResult parseIntScalar(Value scalar, const Location loc,
                               OpBuilder &builder);

  // Operand is the result of addi.
  // Exactly one of the operands should be a scalar. Both start and end are
  // incremented, dims remains unchanged, and scalar is empty.
  LogicalResult parseAdd(arith::AddIOp addOp, const Location loc,
                         OpBuilder &builder);

  // Operand is the result of andi.
  // Each dim of the resulting state is the smaller of the two operands' dims.
  // Instructions are inserted if needed to compute the new dims.
  LogicalResult parseAnd(arith::AndIOp andOp, const Location loc,
                         OpBuilder &builder);

  // Operand is the result of cmpi.
  // Assumes only one of the dimensions has size > 1. Only slt is supported for
  // now. For that dimension the new dim is computed as:
  //   dim = min(end, value) - start
  LogicalResult parseCmp(arith::CmpIOp cmpOp, const Location loc,
                         OpBuilder &builder);

  // Operand is the result of make_range.
  // start and end are set accordingly; the step size must be 1.
  LogicalResult parseMakeRange(triton::MakeRangeOp rangeOp, const Location loc,
                               OpBuilder &builder);

  // Operand is the result of broadcast.
  // Only dims changes; assumed to apply to tensors only.
  LogicalResult parseBroadcast(triton::BroadcastOp broadcastOp,
                               const Location loc, OpBuilder &builder);

  // Operand is the result of splat.
  // Assumed to apply to a scalar only. start and end are left empty; scalar is
  // assigned and dims is updated.
  LogicalResult parseSplat(triton::SplatOp splatOp, const Location loc,
                           OpBuilder &builder);

  // Operand is the result of expand_dims.
  // Additional dims are inserted; start and end do not change and keep
  // corresponding to the dimension that contains the range.
  LogicalResult parseExpandDims(triton::ExpandDimsOp expandDimsOp,
                                const Location loc, OpBuilder &builder);
};

} // namespace triton

} // namespace mlir

#endif
1 change: 1 addition & 0 deletions include/triton-shared/Dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
add_subdirectory(TritonTilingExt)
add_subdirectory(TritonStructured)
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_subdirectory(IR)
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# TableGen input for the rules below: the TritonStructured dialect definition.
set(LLVM_TARGET_DEFINITIONS TritonStructuredDialect.td)
# Dialect declarations/definitions, restricted to the dialect named "tts".
mlir_tablegen(TritonStructuredDialect.h.inc -gen-dialect-decls -dialect=tts)
mlir_tablegen(TritonStructuredDialect.cpp.inc -gen-dialect-defs -dialect=tts)
# Operation class declarations/definitions generated from the same .td file.
mlir_tablegen(TritonStructuredOps.h.inc -gen-op-decls)
mlir_tablegen(TritonStructuredOps.cpp.inc -gen-op-defs)


# Public target grouping the generated .inc files listed above.
add_public_tablegen_target(TritonStructuredTableGen)
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#ifndef MLIR_DIALECT_TRITON_STRUCTURED_IR_TRITON_STRUCTURED_DIALECT_H_
#define MLIR_DIALECT_TRITON_STRUCTURED_IR_TRITON_STRUCTURED_DIALECT_H_

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

#include "mlir/IR/Dialect.h"

//===----------------------------------------------------------------------===//
// TritonStructured Operations
//===----------------------------------------------------------------------===//
#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h.inc"

// Include the auto-generated header file containing the declarations of the
// TritonStructured operations.
#define GET_OP_CLASSES
#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredOps.h.inc"

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
#ifndef TRITON_STRUCTURED_DIALECT
#define TRITON_STRUCTURED_DIALECT

include "mlir/IR/OpBase.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure

// Definition of the TritonStructured dialect. Ops of this dialect use the
// "tts" prefix and the generated C++ classes live in ::mlir::tts.
def Triton_Structured_Dialect : Dialect {
let name = "tts";

let cppNamespace = "::mlir::tts";

let summary = "Structured Triton operations";

let description = [{
Triton Structured Dialect.
}];

// The Triton dialect is listed as a dependency; the ops below use Triton
// types (TT_Ptr, TT_PtrTensor, TT_Tensor, ...).
let dependentDialects = [
"triton::TritonDialect"
];

// Turns on the usePropertiesForAttributes dialect option.
let usePropertiesForAttributes = 1;
}

//
// Op Base
//
// Common base class for the ops defined in this file: binds each op to
// Triton_Structured_Dialect and forwards its mnemonic and trait list to Op.
class TTS_Op<string mnemonic, list<Trait> traits = []> :
Op<Triton_Structured_Dialect, mnemonic, traits> {
}

// tts.make_tptr: produces a TT_PtrTensor result from a TT_Ptr base plus sizes,
// strides, offsets and parent sizes. The op is marked Pure and carries
// AttrSizedOperandSegments (it has several variadic operand groups).
def TTS_MakeTensorPtrOp
: TTS_Op<"make_tptr", [ AttrSizedOperandSegments, Pure]> {
let summary = "create a pointer that points to a tensor in memory";

// $sizes is attribute-only. $strides, $offsets and $parent_sizes are each
// split into a variadic Index operand list and a matching static_* array
// attribute; the assembly format pairs them through custom<DynamicIndexList>.
// NOTE(review): the exact meaning of $parent_sizes is not visible in this
// file -- confirm against the lowering that consumes this op.
let arguments = (ins TT_Ptr:$base,
DenseI64ArrayAttr:$sizes,
Variadic<Index>:$strides,
Variadic<Index>:$offsets,
Variadic<Index>:$parent_sizes,
DenseI64ArrayAttr:$static_strides,
DenseI64ArrayAttr:$static_offsets,
DenseI64ArrayAttr:$static_parent_sizes);

let results = (outs TT_PtrTensor:$result);

// Textual form: base, then sizes / strides / offsets / parent_sizes lists,
// followed by the attribute dictionary and the base and result types.
let assemblyFormat = [{
$base `to` `sizes` `` `:` $sizes
`` `,` `strides` `` `:`
custom<DynamicIndexList>($strides, $static_strides)
`` `,` `offsets` `` `:`
custom<DynamicIndexList>($offsets, $static_offsets)
`` `,` `parent_sizes` `` `:`
custom<DynamicIndexList>($parent_sizes, $static_parent_sizes)
attr-dict `:` type($base) `to` type($result)
}];


// Custom builder taking strides/offsets/parent_sizes as OpFoldResult lists
// instead of separate operand and attribute arguments.
let builders = [
// Build with mixed static and dynamic entries.
OpBuilder<(ins
"Value":$base,
"ArrayRef<int64_t>":$sizes,
"ArrayRef<OpFoldResult>":$strides,
"ArrayRef<OpFoldResult>":$offsets,
"ArrayRef<OpFoldResult>":$parent_sizes)>,
];

// Accessors that merge each static_* attribute with its dynamic operand list
// into one OpFoldResult vector via ::mlir::getMixedValues.
let extraClassDeclaration = [{
/// Return a vector of all the static or dynamic fields
SmallVector<OpFoldResult> getMixedSizes() {
Builder b(getContext());
SmallVector<Value> dynSizes; // sizes are always static
return ::mlir::getMixedValues(getSizes(), dynSizes, b);
}
SmallVector<OpFoldResult> getMixedStrides() {
Builder b(getContext());
return ::mlir::getMixedValues(getStaticStrides(), getStrides(), b);
}
SmallVector<OpFoldResult> getMixedOffsets() {
Builder b(getContext());
return ::mlir::getMixedValues(getStaticOffsets(), getOffsets(), b);
}
SmallVector<OpFoldResult> getMixedParentSizes() {
Builder b(getContext());
return ::mlir::getMixedValues(getStaticParentSizes(), getParentSizes(), b);
}
}];

// TODO
//let hasVerifier = 1;
//let hasCanonicalizer = 1;
}

// tts.load: reads through a TT_PtrTensor and yields a TT_Tensor. Declared with
// a MemRead effect; AttrSizedOperandSegments is present because the op has a
// variadic operand list ($dims) alongside an optional operand ($other).
def TTS_LoadOp : TTS_Op<"load", [
MemoryEffects<[MemRead]>,
AttrSizedOperandSegments
]> {
let summary = "optionally load data from in memory to fill a portion of the tensor";

// $dims (dynamic Index operands) and $static_dims (array attribute) together
// describe the dims of the access.
// NOTE(review): $other is presumably the value used for elements outside
// $dims -- confirm against the builder/lowering, which are not in this file.
let arguments = (ins TT_PtrTensor:$ptr,
Variadic<Index>:$dims,
DenseI64ArrayAttr:$static_dims,
Optional<AnyTypeOf<[TT_Float, TT_Int, TT_Ptr]>>:$other);

let results = (outs TT_Tensor:$result);

// Custom builder taking the dims as a single OpFoldResult list.
let builders = [
OpBuilder<(ins "Value":$ptr, "ArrayRef<OpFoldResult>":$dims, "Value":$other)>,
];

// TODO
//let hasCustomAssemblyFormat = 1;
//let hasVerifier = 1;
}

// tts.store: writes $value through a TT_PtrTensor. Declared with a MemWrite
// effect and has no results.
def TTS_StoreOp : TTS_Op<"store", [
MemoryEffects<[MemWrite]>
]> {
// The previous summary was copied from tts.load and described a load.
let summary = "optionally store a portion of the tensor to memory";

// $dims (dynamic Index operands) and $static_dims (array attribute) together
// describe the dims of the access.
let arguments = (ins TT_PtrTensor:$ptr,
TT_Tensor:$value,
Variadic<Index>:$dims,
DenseI64ArrayAttr:$static_dims);

// Custom builder taking the dims as a single OpFoldResult list.
let builders = [
OpBuilder<(ins "Value":$ptr, "Value":$value, "ArrayRef<OpFoldResult>":$dims)>,
];

// TODO
//let hasCustomAssemblyFormat = 1;
//let hasVerifier = 1;
}

#endif // TRITON_STRUCTURED_DIALECT
29 changes: 28 additions & 1 deletion lib/Analysis/OpFoldResultUtils.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
//===----------------------------------------------------------------------===//
//
// Copyright (c) Microsoft Corporation.
// Copyright (c) Microsoft Corporation, Meta Platforms.
// Licensed under the MIT license.
//
//===----------------------------------------------------------------------===//

#include "triton-shared/Analysis/OpFoldResultUtils.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
Expand All @@ -19,6 +20,32 @@ std::optional<int64_t> getIntAttr(const OpFoldResult ofr) {
return std::nullopt;
}

bool hasConstZero(const OpFoldResult ofr) {
auto intAttr = getIntAttr(ofr);
if (intAttr.has_value()) {
if (intAttr.value() == 0) {
return true;
}
return false;
}

auto val = ofr.dyn_cast<Value>();
assert(val);
auto constOp = val.getDefiningOp<arith::ConstantOp>();
if (!constOp)
return false;

intAttr = getIntAttr(constOp.getValue());
if (intAttr.has_value()) {
if (intAttr.value() == 0) {
return true;
}
return false;
}

return false;
}

Value ofrToIndexValue(const OpFoldResult ofr, const Location loc,
OpBuilder &b) {
if (Value val = ofr.dyn_cast<Value>()) {
Expand Down
Loading