From 095d41957077b6affd9c57c8488b927fbb7d3293 Mon Sep 17 00:00:00 2001 From: Haishan Zhu Date: Tue, 9 Jan 2024 20:21:51 -0800 Subject: [PATCH 1/8] introduce TritonStructured dialect --- include/triton-shared/Dialect/CMakeLists.txt | 1 + .../Dialect/TritonStructured/CMakeLists.txt | 1 + .../TritonStructured/IR/CMakeLists.txt | 8 + .../IR/TritonStructuredDialect.h | 27 ++++ .../IR/TritonStructuredDialect.td | 137 ++++++++++++++++++ lib/Dialect/CMakeLists.txt | 1 + lib/Dialect/TritonStructured/CMakeLists.txt | 1 + .../TritonStructured/IR/CMakeLists.txt | 11 ++ .../IR/TritonStructuredDialect.cpp | 22 +++ .../IR/TritonStructuredOps.cpp | 71 +++++++++ 10 files changed, 280 insertions(+) create mode 100644 include/triton-shared/Dialect/TritonStructured/CMakeLists.txt create mode 100644 include/triton-shared/Dialect/TritonStructured/IR/CMakeLists.txt create mode 100644 include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h create mode 100644 include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td create mode 100644 lib/Dialect/TritonStructured/CMakeLists.txt create mode 100644 lib/Dialect/TritonStructured/IR/CMakeLists.txt create mode 100644 lib/Dialect/TritonStructured/IR/TritonStructuredDialect.cpp create mode 100644 lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp diff --git a/include/triton-shared/Dialect/CMakeLists.txt b/include/triton-shared/Dialect/CMakeLists.txt index 34718811..68066ab6 100644 --- a/include/triton-shared/Dialect/CMakeLists.txt +++ b/include/triton-shared/Dialect/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(TritonTilingExt) +add_subdirectory(TritonStructured) diff --git a/include/triton-shared/Dialect/TritonStructured/CMakeLists.txt b/include/triton-shared/Dialect/TritonStructured/CMakeLists.txt new file mode 100644 index 00000000..f33061b2 --- /dev/null +++ b/include/triton-shared/Dialect/TritonStructured/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/include/triton-shared/Dialect/TritonStructured/IR/CMakeLists.txt b/include/triton-shared/Dialect/TritonStructured/IR/CMakeLists.txt new file mode 100644 index 00000000..9c32c97c --- /dev/null +++ b/include/triton-shared/Dialect/TritonStructured/IR/CMakeLists.txt @@ -0,0 +1,8 @@ +set(LLVM_TARGET_DEFINITIONS TritonStructuredDialect.td) +mlir_tablegen(TritonStructuredDialect.h.inc -gen-dialect-decls -dialect=tts) +mlir_tablegen(TritonStructuredDialect.cpp.inc -gen-dialect-defs -dialect=tts) +mlir_tablegen(TritonStructuredOps.h.inc -gen-op-decls) +mlir_tablegen(TritonStructuredOps.cpp.inc -gen-op-defs) + + +add_public_tablegen_target(TritonStructuredTableGen) diff --git a/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h b/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h new file mode 100644 index 00000000..bd01afd0 --- /dev/null +++ b/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h @@ -0,0 +1,27 @@ +#ifndef MLIR_DIALECT_TRITON_STRUCTURED_IR_TRITON_STRUCTURED_DIALECT_H_ +#define MLIR_DIALECT_TRITON_STRUCTURED_IR_TRITON_STRUCTURED_DIALECT_H_ + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/IR/Dialect.h" + 
+//===----------------------------------------------------------------------===// +// TritonStructured Operations +//===----------------------------------------------------------------------===// +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h.inc" + +// Include the auto-generated header file containing the declarations of the +// TritonStructured operations. +#define GET_OP_CLASSES +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredOps.h.inc" + +#endif diff --git a/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td b/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td new file mode 100644 index 00000000..0d8ae99f --- /dev/null +++ b/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td @@ -0,0 +1,137 @@ +#ifndef TRITON_STRUCTURED_DIALECT +#define TRITON_STRUCTURED_DIALECT + +include "mlir/IR/OpBase.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure + +def Triton_Structured_Dialect : Dialect { + let name = "tts"; + + let cppNamespace = "::mlir::tts"; + + let summary = "Structured Triton operations"; + + let description = [{ + Triton Structured Dialect. + }]; + + let dependentDialects = [ + "triton::TritonDialect" + ]; + + let usePropertiesForAttributes = 1; +} + +// +// Op Base +// +class TTS_Op traits = []> : + Op { +} + +def TTS_MakeTensorPtrOp + : TTS_Op<"make_tptr", [ AttrSizedOperandSegments, Pure]> { + let summary = "create a pointer that points to a tensor in memory"; + + let arguments = (ins TT_Ptr:$base, + DenseI64ArrayAttr:$sizes, + Variadic:$strides, + Variadic:$offsets, + Variadic:$parent_sizes, + DenseI64ArrayAttr:$static_strides, + DenseI64ArrayAttr:$static_offsets, + DenseI64ArrayAttr:$static_parent_sizes); + + let results = (outs TT_PtrTensor:$result); + + let assemblyFormat = [{ + $base `to` `sizes` `` `:` $sizes + `` `,` `strides` `` `:` + custom($strides, $static_strides) + `` `,` `offsets` `` `:` + custom($offsets, $static_offsets) + `` `,` `parent_sizes` `` `:` + custom($parent_sizes, $static_parent_sizes) + attr-dict `:` type($base) `to` type($result) + }]; + + + let builders = [ + // Build with mixed static and dynamic entries. 
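+    // Illustrative use only (names such as basePtr, rowStride, rowOffset are
+    // hypothetical): a 128x64 view with a dynamic leading stride and a
+    // dynamic row offset could be created through this builder as
+    //   builder.create<tts::MakeTensorPtrOp>(loc, basePtr,
+    //       /*sizes=*/ArrayRef<int64_t>{128, 64},
+    //       /*strides=*/{rowStride, builder.getIndexAttr(1)},
+    //       /*offsets=*/{rowOffset, builder.getIndexAttr(0)},
+    //       /*parent_sizes=*/{builder.getIndexAttr(0), builder.getIndexAttr(0)});
+    // A parent_size of 0 means the dimension has no wrap-around (modulo).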
+    OpBuilder<(ins
+      "Value":$base,
+      "ArrayRef<int64_t>":$sizes,
+      "ArrayRef<OpFoldResult>":$strides,
+      "ArrayRef<OpFoldResult>":$offsets,
+      "ArrayRef<OpFoldResult>":$parent_sizes)>,
+  ];
+
+  let extraClassDeclaration = [{
+    /// Return a vector of all the static or dynamic fields
+    SmallVector<OpFoldResult> getMixedSizes() {
+      Builder b(getContext());
+      SmallVector<Value> dynSizes; // sizes are always static
+      return ::mlir::getMixedValues(getSizes(), dynSizes, b);
+    }
+    SmallVector<OpFoldResult> getMixedStrides() {
+      Builder b(getContext());
+      return ::mlir::getMixedValues(getStaticStrides(), getStrides(), b);
+    }
+    SmallVector<OpFoldResult> getMixedOffsets() {
+      Builder b(getContext());
+      return ::mlir::getMixedValues(getStaticOffsets(), getOffsets(), b);
+    }
+    SmallVector<OpFoldResult> getMixedParentSizes() {
+      Builder b(getContext());
+      return ::mlir::getMixedValues(getStaticParentSizes(), getParentSizes(), b);
+    }
+  }];
+
+  // TODO
+  //let hasVerifier = 1;
+  //let hasCanonicalizer = 1;
+}
+
+def TTS_LoadOp : TTS_Op<"load", [
+  MemoryEffects<[MemRead]>,
+  AttrSizedOperandSegments
+]> {
+  let summary = "optionally load data from memory to fill a portion of the tensor";
+
+  let arguments = (ins TT_PtrTensor:$ptr,
+                       Variadic<Index>:$dims,
+                       DenseI64ArrayAttr:$static_dims,
+                       Optional>:$other);
+
+  let results = (outs TT_Tensor:$result);
+
+  let builders = [
+    OpBuilder<(ins "Value":$ptr, "ArrayRef<OpFoldResult>":$dims, "Value":$other)>,
+  ];
+
+  // TODO
+  //let hasCustomAssemblyFormat = 1;
+  //let hasVerifier = 1;
+}
+
+def TTS_StoreOp : TTS_Op<"store", [
+  MemoryEffects<[MemWrite]>
+]> {
+  let summary = "optionally store data from a portion of the tensor to memory";
+
+  let arguments = (ins TT_PtrTensor:$ptr,
+                       TT_Tensor:$value,
+                       Variadic<Index>:$dims,
+                       DenseI64ArrayAttr:$static_dims);
+
+  let builders = [
+    OpBuilder<(ins "Value":$ptr, "Value":$value, "ArrayRef<OpFoldResult>":$dims)>,
+  ];
+
+  // TODO
+  //let hasCustomAssemblyFormat = 1;
+  //let hasVerifier = 1;
+}
+
+#endif // TRITON_STRUCTURED_DIALECT
diff --git a/lib/Dialect/CMakeLists.txt b/lib/Dialect/CMakeLists.txt
index 34718811..68066ab6 100644
--- a/lib/Dialect/CMakeLists.txt
+++ b/lib/Dialect/CMakeLists.txt
@@ -1 +1,2 @@
 add_subdirectory(TritonTilingExt)
+add_subdirectory(TritonStructured)
diff --git a/lib/Dialect/TritonStructured/CMakeLists.txt b/lib/Dialect/TritonStructured/CMakeLists.txt
new file mode 100644
index 00000000..f33061b2
--- /dev/null
+++ b/lib/Dialect/TritonStructured/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/lib/Dialect/TritonStructured/IR/CMakeLists.txt b/lib/Dialect/TritonStructured/IR/CMakeLists.txt
new file mode 100644
index 00000000..752363c7
--- /dev/null
+++ b/lib/Dialect/TritonStructured/IR/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_mlir_dialect_library(TritonStructuredIR
+  TritonStructuredOps.cpp
+  TritonStructuredDialect.cpp
+
+  DEPENDS
+  TritonStructuredTableGen
+
+  LINK_LIBS PUBLIC
+  TritonIR
+  MLIRIR
+  )
diff --git a/lib/Dialect/TritonStructured/IR/TritonStructuredDialect.cpp b/lib/Dialect/TritonStructured/IR/TritonStructuredDialect.cpp
new file mode 100644
index 00000000..2af19b8a
--- /dev/null
+++ b/lib/Dialect/TritonStructured/IR/TritonStructuredDialect.cpp
@@ -0,0 +1,22 @@
+#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"
+
+using namespace mlir;
+using namespace mlir::tts;
+
+/// Dialect creation, the instance will be owned by the context. This is the
+/// point of registration of custom types and operations for the dialect.
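+/// Only operations are registered here for now; the dialect does not yet
+/// define any custom types or attributes.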
+void TritonStructuredDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredOps.cpp.inc" + >(); +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredOps.cpp.inc" + +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.cpp.inc" diff --git a/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp b/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp new file mode 100644 index 00000000..0298d49c --- /dev/null +++ b/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp @@ -0,0 +1,71 @@ +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "llvm/ADT/STLExtras.h" + +#define GET_OP_CLASSES +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredOps.h.inc" + +using namespace mlir; +using namespace mlir::tts; + +namespace mlir { +namespace tts { + +void MakeTensorPtrOp::build(OpBuilder &b, OperationState &state, Value base, + ArrayRef sizes, + ArrayRef strides, + ArrayRef offsets, + ArrayRef parentSizes) { + SmallVector staticStrides, staticOffsets, staticParentSizes; + SmallVector dynamicStrides, dynamicOffsets, dynamicParentSizes; + + dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); + dispatchIndexOpFoldResults(parentSizes, dynamicParentSizes, + staticParentSizes); + + auto basePtr = base.getType().cast(); + auto elemType = basePtr.getPointeeType(); + auto resType = RankedTensorType::get(sizes, basePtr); + + build(b, state, resType, base, sizes, dynamicStrides, dynamicOffsets, + dynamicParentSizes, b.getDenseI64ArrayAttr(staticStrides), + b.getDenseI64ArrayAttr(staticOffsets), + b.getDenseI64ArrayAttr(staticParentSizes)); +} + +void LoadOp::build(OpBuilder &b, OperationState &state, Value ptr, + ArrayRef dims, Value other) { + SmallVector staticDims; + SmallVector dynamicDims; + + dispatchIndexOpFoldResults(dims, dynamicDims, staticDims); + + auto ptrTensorType = ptr.getType().cast(); + auto elemType = ptrTensorType.getElementType() + .cast() + .getPointeeType(); + auto resType = RankedTensorType::get(ptrTensorType.getShape(), elemType); + + build(b, state, resType, ptr, dynamicDims, b.getDenseI64ArrayAttr(staticDims), + other); +} + +void StoreOp::build(OpBuilder &b, OperationState &state, Value ptr, Value value, + ArrayRef dims) { + SmallVector staticDims; + SmallVector dynamicDims; + + dispatchIndexOpFoldResults(dims, dynamicDims, staticDims); + + build(b, state, ptr, value, dynamicDims, b.getDenseI64ArrayAttr(staticDims)); +} + +} // namespace tts +} // namespace mlir From c697bfc41d41d5fde84f2b4205963814e89fd84a Mon Sep 17 00:00:00 2001 From: Haishan Zhu Date: Tue, 9 Jan 2024 20:24:58 -0800 Subject: [PATCH 2/8] Updated mask analysis --- .../AnalysisStructured/MaskAnalysis.h | 128 ++++++++ lib/AnalysisStructured/MaskAnalysis.cpp | 303 ++++++++++++++++++ 2 files changed, 431 insertions(+) create mode 100644 
include/triton-shared/AnalysisStructured/MaskAnalysis.h
 create mode 100644 lib/AnalysisStructured/MaskAnalysis.cpp

diff --git a/include/triton-shared/AnalysisStructured/MaskAnalysis.h b/include/triton-shared/AnalysisStructured/MaskAnalysis.h
new file mode 100644
index 00000000..daf34047
--- /dev/null
+++ b/include/triton-shared/AnalysisStructured/MaskAnalysis.h
@@ -0,0 +1,128 @@
+//===----------------------------------------------------------------------===//
+//
+// Copyright (c) Microsoft Corporation, Meta Platforms.
+// Licensed under the MIT license.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TRITON_ANALYSISSTRUCTURED_MASKANALYSIS_H
+#define TRITON_ANALYSISSTRUCTURED_MASKANALYSIS_H
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+#include
+
+namespace mlir {
+
+class OpBuilder;
+
+namespace triton {
+// Data structure used to decode the pattern in a mask used for load and store.
+// The start and end fields represent the start and end index of a range
+// (produced by make_range, addi, etc.). While multi-dimensional data is
+// possible, we assume a range comparison can only be done on 1 dimension at a
+// time (and results of range comparisons across dimensions can be combined),
+// hence start and end are not vectors. dims represents the real access size
+// for ld/st (instead of the tensor/memref size specified by the IR). scalar is
+// a shortcut used when the entire state contains a single scalar value.
+//
+// The general lifetime of this data structure is roughly:
+// 1. A range is created by make_range and optionally operated on by addi w/
+// result of splat, expand_dims, etc. During this phase, either (1) both start
+// and end are populated, or (2) scalar is populated. Only one of the
+// dimensions (the one that contains the range) can have dim > 1.
+// 2. Result from step 1 is compared with another MaskSState that represents a
+// scalar value. The resulting state only has dims populated.
+// 3. Optionally, result from step 2 can be broadcasted and anded with other
+// results from step 2. The resulting state only has dims populated.
+//
+// Example of creating a 2D mask:
+// mask = (rows[:, None] < M) & (cols[None, :] < N)
+struct MaskSState {
+  OpFoldResult start;
+  OpFoldResult end;
+  SmallVector<OpFoldResult> dims;
+  OpFoldResult scalar;
+
+  int64_t getRank() const { return dims.size(); }
+
+  bool isEmpty() const { return getRank() == 0 && !scalar && !start && !end; }
+
+  bool isMask() const { return !start && !end && !scalar && dims.size() != 0; }
+
+  // Recursively parse a Value; call the corresponding function based on the
+  // defining operation and Value type
+  LogicalResult parse(Value operand, const Location loc, OpBuilder &builder);
+
+private:
+  // -------
+  // Utility functions to operate on MaskSState
+  // -------
+  LogicalResult addStateScalar(const MaskSState &state,
+                               const OpFoldResult scalar, Location loc,
+                               OpBuilder &builder);
+
+  LogicalResult addStates(const MaskSState &lhsState,
+                          const MaskSState &rhsState, Location loc,
+                          OpBuilder &builder);
+
+  LogicalResult minStates(const MaskSState &lhsState,
+                          const MaskSState &rhsState, Location loc,
+                          OpBuilder &builder);
+  // -------
+  // Helper functions to parse values to populate MaskSState
+  // -------
+
+  // Operand is the result of a constant
+  // Get the value of the constant and assign it to scalar.
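+  // For example (illustrative), an operand defined by
+  //   %c = arith.constant dense<128> : tensor<1xi32>
+  // produces a state with scalar = 128 (as an index attribute) and empty dims.
+  // Non-splat constants are rejected.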
+  LogicalResult parseConstant(arith::ConstantOp constOp, const Location loc,
+                              OpBuilder &builder);
+
+  // Operand is an integer scalar
+  LogicalResult parseIntScalar(Value scalar, const Location loc,
+                               OpBuilder &builder);
+
+  // Operand is the result of addi
+  // One and only one of the operands should be a scalar. Increment both start
+  // and end; dims remain unchanged, and scalar is empty.
+  LogicalResult parseAdd(arith::AddIOp addOp, const Location loc,
+                         OpBuilder &builder);
+  // Operand is the result of andi
+  // Each dim of the result state is the smaller of the two operands' dims.
+  // Insert instructions if needed to compute the new dims.
+  LogicalResult parseAnd(arith::AndIOp andOp, const Location loc,
+                         OpBuilder &builder);
+
+  // Operand is the result of cmpi
+  // Assume only one of the dimensions has size > 1. Only support slt for now.
+  // For that dimension, calculate the new dim as: dim = min(end, value) -
+  // start
+  LogicalResult parseCmp(arith::CmpIOp cmpOp, const Location loc,
+                         OpBuilder &builder);
+  // Operand is the result of make_range
+  // Set start and end accordingly; step size must be 1.
+  LogicalResult parseMakeRange(triton::MakeRangeOp rangeOp, const Location loc,
+                               OpBuilder &builder);
+  // Operand is the result of broadcast
+  // Change dims only; assume this only applies to tensors.
+  LogicalResult parseBroadcast(triton::BroadcastOp broadcastOp,
+                               const Location loc, OpBuilder &builder);
+  // Operand is the result of splat
+  // Assume this only applies to a scalar source. start and end are left empty;
+  // scalar will be assigned, and dims will be updated.
+  LogicalResult parseSplat(triton::SplatOp splatOp, const Location loc,
+                           OpBuilder &builder);
+  // Operand is the result of expand_dims
+  // Insert additional dims; start and end do not change and correspond to the
+  // dimension that contains the range.
+  LogicalResult parseExpandDims(triton::ExpandDimsOp expandDimsOp,
+                                const Location loc, OpBuilder &builder);
+};
+
+} // namespace triton
+
+} // namespace mlir
+
+#endif
diff --git a/lib/AnalysisStructured/MaskAnalysis.cpp b/lib/AnalysisStructured/MaskAnalysis.cpp
new file mode 100644
index 00000000..ed218257
--- /dev/null
+++ b/lib/AnalysisStructured/MaskAnalysis.cpp
@@ -0,0 +1,303 @@
+//===----------------------------------------------------------------------===//
+//
+// Copyright (c) Microsoft Corporation, Meta Platforms.
+// Licensed under the MIT license.
+// +//===----------------------------------------------------------------------===// + +#include "triton-shared/AnalysisStructured/MaskAnalysis.h" +#include "triton-shared/Analysis/OpFoldResultUtils.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { + +namespace triton { + +LogicalResult MaskSState::parse(Value operand, const Location loc, + OpBuilder &builder) { + if (auto op = operand.getDefiningOp()) { + return this->parseConstant(op, loc, builder); + } else if (operand.getType().isa()) { + return this->parseIntScalar(operand, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return this->parseAdd(op, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return this->parseAnd(op, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return this->parseCmp(op, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return this->parseMakeRange(op, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return this->parseBroadcast(op, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return this->parseSplat(op, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return this->parseExpandDims(op, loc, builder); + } else { + return failure(); + } +} + +LogicalResult MaskSState::addStateScalar(const MaskSState &state, + const OpFoldResult scalar, Location loc, + OpBuilder &builder) { + start = addOFRs(state.start, scalar, loc, builder); + end = addOFRs(state.end, scalar, loc, builder); + dims = state.dims; + return success(); +} + +LogicalResult MaskSState::addStates(const MaskSState &lhsState, + const MaskSState &rhsState, Location loc, + OpBuilder &builder) { + if (lhsState.scalar && rhsState.scalar) { + InFlightDiagnostic diag = + emitError(loc) << "Unexpected case where both lhs and rhs are scalars"; + return failure(); + } + + if (!lhsState.scalar && !rhsState.scalar) { + InFlightDiagnostic diag = + emitError(loc) + << "Unsupported scenario where neither lhs nor rhs is a scalar"; + return failure(); + } + + if (lhsState.scalar) + return addStateScalar(rhsState, lhsState.scalar, loc, builder); + else + return addStateScalar(lhsState, rhsState.scalar, loc, builder); +} + +LogicalResult MaskSState::minStates(const MaskSState &lhsState, + const MaskSState &rhsState, Location loc, + OpBuilder &builder) { + if (lhsState.getRank() != rhsState.getRank()) { + InFlightDiagnostic diag = + emitError(loc) + << "Unexpected case where lhs and rhs have different ranks"; + return failure(); + } + + for (uint32_t i = 0; i < lhsState.getRank(); i++) { + auto lhsDim = lhsState.dims[i]; + auto rhsDim = rhsState.dims[i]; + dims.push_back(minOFRs(lhsDim, rhsDim, loc, builder)); + } + return success(); +} + +LogicalResult MaskSState::parseConstant(arith::ConstantOp constOp, + const Location loc, OpBuilder &builder) { + assert(this->isEmpty()); + + if (isa(constOp.getValue())) { + auto attr = cast(constOp.getValue()); + auto elementType = attr.getElementType(); + assert(attr.isSplat() && elementType.isa() && + "All elements must share a single integer constant value"); + auto values = attr.getValues(); + auto value = values[0].getValue(); + auto constAttr = builder.getIndexAttr(value.getSExtValue()); + auto op = arith::ConstantOp::materialize(builder, constAttr, + builder.getIndexType(), loc); + this->scalar = op.getValue(); + } else { + auto value = constOp.getValue().cast().getInt(); + this->scalar = builder.getIndexAttr(value); + } + + return 
success(); +} + +LogicalResult MaskSState::parseIntScalar(Value scalar, const Location loc, + OpBuilder &builder) { + assert(this->isEmpty()); + auto castOp = + builder.create(loc, builder.getIndexType(), scalar); + this->scalar = castOp.getResult(); + return success(); +} + +LogicalResult MaskSState::parseAdd(arith::AddIOp addOp, const Location loc, + OpBuilder &builder) { + assert(this->isEmpty()); + + MaskSState lhsState; + if (failed(lhsState.parse(addOp.getLhs(), loc, builder))) + return failure(); + + MaskSState rhsState; + if (failed(rhsState.parse(addOp.getRhs(), loc, builder))) + return failure(); + + return this->addStates(lhsState, rhsState, loc, builder); +} + +LogicalResult MaskSState::parseAnd(arith::AndIOp andOp, const Location loc, + OpBuilder &builder) { + assert(this->isEmpty()); + + MaskSState lhsState; + if (failed(lhsState.parse(andOp.getLhs(), loc, builder)) || + !lhsState.isMask()) + return failure(); + + MaskSState rhsState; + if (failed(rhsState.parse(andOp.getRhs(), loc, builder)) || + !rhsState.isMask()) + return failure(); + + return this->minStates(lhsState, rhsState, loc, builder); +} + +LogicalResult MaskSState::parseCmp(arith::CmpIOp cmpOp, const Location loc, + OpBuilder &builder) { + assert(this->isEmpty()); + + if (cmpOp.getPredicate() != arith::CmpIPredicate::slt) { + InFlightDiagnostic diag = emitError(loc) << "Unsupported cmpi predicate"; + return failure(); + } + + MaskSState lhsState; + if (failed(lhsState.parse(cmpOp.getLhs(), loc, builder))) + return failure(); + + MaskSState rhsState; + if (failed(rhsState.parse(cmpOp.getRhs(), loc, builder))) + return failure(); + + assert((!lhsState.scalar && rhsState.scalar) && "Unsupported cmpi scenario"); + + int32_t cmpDim = -1; + for (int32_t i = 0; i < lhsState.getRank(); i++) { + auto dimIntAttr = getIntAttr(lhsState.dims[i]); + if (!dimIntAttr || dimIntAttr.value() != 1) { + if (cmpDim != -1) { + InFlightDiagnostic diag = emitError(loc) + << "Unsupported cmpi with more than one " + "dimension with size larger than 1"; + return failure(); + } + cmpDim = i; + } + } + assert(cmpDim != -1 && + "Unexpected case where no dimension has size larger than 1"); + + auto newEnd = minOFRs(lhsState.end, rhsState.scalar, loc, builder); + auto newDim = subOFRs(newEnd, lhsState.start, loc, builder); + + for (int32_t i = 0; i < lhsState.getRank(); i++) { + if (i == cmpDim) + this->dims.push_back(newDim); + else + this->dims.push_back(lhsState.dims[i]); + } + + return success(); +} + +LogicalResult MaskSState::parseMakeRange(triton::MakeRangeOp rangeOp, + const Location loc, + OpBuilder &builder) { + assert(this->isEmpty()); + + auto shape = rangeOp.getType().cast().getShape(); + auto start = rangeOp.getStart(); + auto end = rangeOp.getEnd(); + auto stride = (end - start + shape[0] - 1) / shape[0]; + + if (stride != 1) { + InFlightDiagnostic diag = + emitError(loc) + << "stride must be 1 for make_range whose result is used " + "as load or store masks"; + return failure(); + } + + this->start = builder.getIndexAttr(start); + this->end = builder.getIndexAttr(end); + this->dims.push_back(builder.getIndexAttr(shape[0])); + + return success(); +} + +LogicalResult MaskSState::parseBroadcast(triton::BroadcastOp broadcastOp, + const Location loc, + OpBuilder &builder) { + assert(this->isEmpty()); + + auto src = broadcastOp.getSrc(); + auto dst = broadcastOp.getResult(); + assert(src.getType().isa() && + "input to tt.broadcast should be a tensor"); + + auto srcShape = src.getType().cast().getShape(); + auto dstShape = 
dst.getType().cast().getShape(); + assert(srcShape.size() == dstShape.size() && + "rank of source and destination should match"); + + if (failed(parse(src, loc, builder))) + return failure(); + + for (size_t i = 0; i < srcShape.size(); i++) { + if (srcShape[i] == dstShape[i]) + continue; + else if (srcShape[i] < dstShape[i]) + this->dims[i] = builder.getIndexAttr(dstShape[i]); + else + llvm_unreachable("unexpected dimensions used in broadcast"); + } + + return success(); +} + +LogicalResult MaskSState::parseSplat(triton::SplatOp splatOp, const Location loc, + OpBuilder &builder) { + assert(this->isEmpty()); + + auto src = splatOp.getSrc(); + auto dst = splatOp.getResult(); + auto dstShape = dst.getType().cast().getShape(); + + if (!src.getType().isa()) { + InFlightDiagnostic diag = + emitError(loc) + << "splat source must be an integer scalar for load/store masks"; + return failure(); + } + + if (failed(this->parse(src, loc, builder))) + return failure(); + + for (auto s : dstShape) + this->dims.push_back(builder.getIndexAttr(s)); + + return success(); +} + +LogicalResult MaskSState::parseExpandDims(triton::ExpandDimsOp expandDimsOp, + const Location loc, + OpBuilder &builder) { + assert(this->isEmpty()); + + if (failed(this->parse(expandDimsOp.getSrc(), loc, builder))) + return failure(); + + auto dstShape = + expandDimsOp.getResult().getType().cast().getShape(); + auto axis = expandDimsOp.getAxis(); + assert(dstShape[axis] == 1 && + "expect changed dimension to be 1 in expand_dims"); + this->dims.insert(this->dims.begin() + axis, builder.getIndexAttr(1)); + + return success(); +} + +} // namespace triton +} // namespace mlir From db396c619c7cd6701343e2798a5722b30e364a75 Mon Sep 17 00:00:00 2001 From: Haishan Zhu Date: Tue, 9 Jan 2024 20:25:40 -0800 Subject: [PATCH 3/8] Update OpFoldResultUtils --- .../Analysis/OpFoldResultUtils.h | 10 +++++-- lib/Analysis/OpFoldResultUtils.cpp | 29 ++++++++++++++++++- 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/include/triton-shared/Analysis/OpFoldResultUtils.h b/include/triton-shared/Analysis/OpFoldResultUtils.h index 3cd82ddf..2a3e485a 100644 --- a/include/triton-shared/Analysis/OpFoldResultUtils.h +++ b/include/triton-shared/Analysis/OpFoldResultUtils.h @@ -1,12 +1,12 @@ //===----------------------------------------------------------------------===// // -// Copyright (c) Microsoft Corporation. +// Copyright (c) Microsoft Corporation, Meta Platforms. // Licensed under the MIT license. // //===----------------------------------------------------------------------===// -#ifndef TRITON_ANALYSIS_OPFOLDRESULT_UTILS_H -#define TRITON_ANALYSIS_OPFOLDRESULT_UTILS_H +#ifndef TRITON_ANALYSISSTRUCTURED_OPFOLDRESULT_UTILS_H +#define TRITON_ANALYSISSTRUCTURED_OPFOLDRESULT_UTILS_H #include "mlir/IR/Location.h" #include "mlir/IR/OpDefinition.h" @@ -22,6 +22,10 @@ class OpBuilder; // result of an operation too. std::optional getIntAttr(const OpFoldResult ofr); +// Return if ofr contains a constant zero, either represented by an integer +// attribute or a constant value. +bool hasConstZero(const OpFoldResult ofr); + // Create a value of index type if necessary from an OpFoldResult. 
Value ofrToIndexValue(const OpFoldResult ofr, const Location loc, OpBuilder &b); diff --git a/lib/Analysis/OpFoldResultUtils.cpp b/lib/Analysis/OpFoldResultUtils.cpp index c98657d3..6d4ea20a 100644 --- a/lib/Analysis/OpFoldResultUtils.cpp +++ b/lib/Analysis/OpFoldResultUtils.cpp @@ -1,6 +1,6 @@ //===----------------------------------------------------------------------===// // -// Copyright (c) Microsoft Corporation. +// Copyright (c) Microsoft Corporation, Meta Platforms. // Licensed under the MIT license. // //===----------------------------------------------------------------------===// @@ -8,6 +8,7 @@ #include "triton-shared/Analysis/OpFoldResultUtils.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { @@ -19,6 +20,32 @@ std::optional getIntAttr(const OpFoldResult ofr) { return std::nullopt; } +bool hasConstZero(const OpFoldResult ofr) { + auto intAttr = getIntAttr(ofr); + if (intAttr.has_value()) { + if (intAttr.value() == 0) { + return true; + } + return false; + } + + auto val = ofr.dyn_cast(); + assert(val); + auto constOp = val.getDefiningOp(); + if (!constOp) + return false; + + intAttr = getIntAttr(constOp.getValue()); + if (intAttr.has_value()) { + if (intAttr.value() == 0) { + return true; + } + return false; + } + + return false; +} + Value ofrToIndexValue(const OpFoldResult ofr, const Location loc, OpBuilder &b) { if (Value val = ofr.dyn_cast()) { From 7df79304373022b0676f0e6c44e8c4873132fbc2 Mon Sep 17 00:00:00 2001 From: Haishan Zhu Date: Tue, 9 Jan 2024 20:27:40 -0800 Subject: [PATCH 4/8] triton-to-structured pass --- .../AnalysisStructured/PtrAnalysis.h | 209 +++ .../triton-shared/Conversion/CMakeLists.txt | 1 + .../TritonToStructured/CMakeLists.txt | 3 + .../Conversion/TritonToStructured/Passes.h | 15 + .../Conversion/TritonToStructured/Passes.td | 11 + .../TritonToStructured/TritonToStructured.h | 17 + lib/AnalysisStructured/CMakeLists.txt | 14 + lib/AnalysisStructured/PtrAnalysis.cpp | 1128 +++++++++++++++++ lib/CMakeLists.txt | 1 + lib/Conversion/CMakeLists.txt | 1 + .../TritonToStructured/CMakeLists.txt | 24 + .../TritonToStructuredPass.cpp | 58 + tools/RegisterTritonSharedDialects.h | 5 +- 13 files changed, 1486 insertions(+), 1 deletion(-) create mode 100644 include/triton-shared/AnalysisStructured/PtrAnalysis.h create mode 100644 include/triton-shared/Conversion/TritonToStructured/CMakeLists.txt create mode 100644 include/triton-shared/Conversion/TritonToStructured/Passes.h create mode 100644 include/triton-shared/Conversion/TritonToStructured/Passes.td create mode 100644 include/triton-shared/Conversion/TritonToStructured/TritonToStructured.h create mode 100644 lib/AnalysisStructured/CMakeLists.txt create mode 100644 lib/AnalysisStructured/PtrAnalysis.cpp create mode 100644 lib/Conversion/TritonToStructured/CMakeLists.txt create mode 100644 lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp diff --git a/include/triton-shared/AnalysisStructured/PtrAnalysis.h b/include/triton-shared/AnalysisStructured/PtrAnalysis.h new file mode 100644 index 00000000..55af7a4c --- /dev/null +++ b/include/triton-shared/AnalysisStructured/PtrAnalysis.h @@ -0,0 +1,209 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation, Meta Platforms. +// Licensed under the MIT license. 
+// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_ANALYSISSTRUCTURED_PTRANALYSIS_H +#define TRITON_ANALYSISSTRUCTURED_PTRANALYSIS_H + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" + +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include + +namespace mlir { + +class OpBuilder; + +namespace triton { + +const extern std::string ptrAnalysisAttr; + +// Data structure used to decode pointer arithmetics. offsets, sizes, and +// strides are in unit of elements in a linearly laid-out memory, which is the +// same as pointer arithmetic operations in Triton language. scalar is a +// shortcut used when the entire state describes a single scalar value. source +// is the base pointer. +class PtrSState { + +public: + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + SmallVector modulos; + + Value source; + Value scalar; + + int32_t getRank() const; + + bool isEmpty() const; + + bool hasModulo() const; + + bool dimHasModulo(uint32_t dim) const; + + // Process addition of two PtrSStates. + LogicalResult addState(const PtrSState &lhsState, const PtrSState &rhsState, + Operation *op, OpBuilder &builder); + + // Process multiplication of two PtrSStates + LogicalResult mulState(const PtrSState &lhsState, const PtrSState &rhsState, + Operation *op, OpBuilder &builder); + + tts::MakeTensorPtrOp createTTSMakeTensorPtrOp(OpBuilder &builder, + Location loc); + + static void swap(PtrSState &&a, PtrSState &&b); +}; + +struct PtrAnalysis { + using IndexMapSet = std::map>; + + IndexMapSet levelToBlockArgIndex; + int level = 0; + + llvm::SmallDenseMap knownPtrs; + + IRMapping map; + + // Recursively parse a Value; call the corresponding + // function based on the defining operation and argument type. + LogicalResult visitOperand(Value operand, PtrSState &state, const Location loc, + OpBuilder &builder); + + // Operand is the result of arith.addi. Process both arguments and insert any + // arith.addi instruction as needed. + // Main assumptions: + // Only one of lhsState and rhsState has source field set + // Current PtrSState should be empty + // Expected result: + // source = lhsState.source ? lhsState.source : rhsState.source + // sizes[i] = lhsState.sizes[i] (which should match rhsState.sizes[i]) + // offsets[i] = lhsState.offsets[i] + rhsState.offsets[i] + // strides[i] = lhsState.strides[i] + rhsState.strides[i] + LogicalResult visitOperandAdd(arith::AddIOp addOp, PtrSState &state, + const Location loc, OpBuilder &builder); + + // Operand is the result of arith.muli. Process both arguments and insert any + // arith.muli instruction as needed. + // Main assumptions: + // Neither lhsState nor rhsState has source field set + // Current PtrSState should be empty + // Currently only support one of the operand is a scalar index + // Expected result (scalar and tensorState represent the two operands): + // source = null + // sizes[i] = tensorState.sizes[i] + // offsets[i] = tensorState.offsets[i] * scalar + // strides[i] = tensorState.strides[i] * scalar + LogicalResult visitOperandMul(arith::MulIOp mulOp, PtrSState &state, + const Location loc, OpBuilder &builder); + + LogicalResult visitOperandRem(arith::RemSIOp mulOp, PtrSState &state, + const Location loc, OpBuilder &builder); + + // Operand is the result of make_range. 
+  // Main assumptions:
+  //   start, end, and shape are all statically known
+  //   The output of make_range is 1-dimensional
+  //   Does not check validity of inputs (e.g., stride > 0)
+  // Expected result:
+  //   source = null
+  //   sizes[0] = shape[0]
+  //   offset[0] = start
+  //   strides[0] = ceiling( (end - start) / shape[0] )
+  LogicalResult visitOperandMakeRange(triton::MakeRangeOp rangeOp,
+                                      PtrSState &state, Location loc,
+                                      OpBuilder &builder);
+
+  // Operand is the result of expand_dims
+  // Main assumptions:
+  //   Only 1 dimension changes for each invocation of reshape
+  //   The changed dimension must have size of 1
+  // Expected result:
+  //   Insert a dimension of size 1, stride 0, and offset 0
+  LogicalResult visitOperandExpandDims(triton::ExpandDimsOp expandDimsOp,
+                                       PtrSState &state, const Location loc,
+                                       OpBuilder &builder);
+
+  // Operand is the result of broadcast
+  // Main assumptions:
+  //   Rank of source and result is the same
+  // Expected result:
+  //   Update sizes[i] only, no changes to other fields
+  LogicalResult visitOperandBroadcast(triton::BroadcastOp broadcastOp,
+                                      PtrSState &state, const Location loc,
+                                      OpBuilder &builder);
+
+  // Operand is the result of splat
+  // Main assumptions:
+  //   Source is a scalar value (i.e., an integer or a pointer, not a tensor)
+  // Expected result:
+  //   sizes[i] reflect the shape of the result, strides[i] = 0, offsets[i] = 0;
+  //   if source is an integer, offset[0] = scalar = source
+  LogicalResult visitOperandSplat(triton::SplatOp splatOp, PtrSState &state,
+                                  const Location loc, OpBuilder &builder);
+
+  // Operand is the result of arith.constant that is a splat
+  // Main assumptions:
+  //   Source is a constant op that produces a constant dense tensor where all
+  //   elements are the same (i.e.: a constant that is splatted)
+  // Expected result:
+  //   sizes[i] reflect the shape of the result, strides[i] = 0, offsets[i] =
+  //   splat value if i == 0, otherwise 0
+  LogicalResult visitOperandConstSplat(arith::ConstantOp op, PtrSState &state,
+                                       const Location loc, OpBuilder &builder);
+
+  // Operand is the result of addptr.
+  // Main assumptions:
+  //   The ptr field should populate the source field
+  //   ptr and offset fields should have the same rank
+  // Expected result:
+  //   The resulting states for ptr and offset will be added
+  LogicalResult visitOperandAddptr(triton::AddPtrOp addptrOp, PtrSState &state,
+                                   const Location loc, OpBuilder &builder);
+
+  // Operand is the result of make_tptr.
+  // Main assumptions:
+  //   This function is only called when rewriting a loop
+  // Expected result:
+  //   Directly grab all corresponding fields from make_tptr.
+  LogicalResult visitOperandMakeTPtr(tts::MakeTensorPtrOp makeTPtrOp,
+                                     PtrSState &state, const Location loc,
+                                     OpBuilder &builder);
+
+  // Parse the state of AddPtrOp, insert any instructions needed to
+  // calculate strides and offsets, build the PtrSState for this operand, and
+  // record it in knownPtrs.
+  LogicalResult rewriteAddptrOp(triton::AddPtrOp op);
+
+  // Parse the state of YieldOp, insert any instructions needed to calculate
+  // strides and offsets, build the PtrSState for this operand, and record it
+  // in knownPtrs.
+  LogicalResult
+  rewriteYieldOp(scf::YieldOp op,
+                 llvm::SmallDenseMap &knownPtrsFor);
+
+  // Rewrite eligible tt.addptr ops in loop init args so the loop can update
+  // such pointers across iterations. Insert any instructions needed to
+  // calculate strides, offsets, and modulos.
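+  // For example (illustrative), for a loop such as
+  //   scf.for %i = %lb to %ub step %s iter_args(%ptr = %arg0)
+  // where %arg0 maps to a tts.make_tptr result, the rewritten loop carries the
+  // pointer's offsets and strides as extra index-typed iter_args so the loop
+  // body can advance them on every iteration.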
+ LogicalResult rewriteForOp(scf::ForOp op); + + LogicalResult rewriteLoadOp(triton::LoadOp op); + + LogicalResult rewriteStoreOp(triton::StoreOp op); + + LogicalResult rewriteOp(Operation *op); +}; + +} // namespace triton + +} // namespace mlir + +#endif diff --git a/include/triton-shared/Conversion/CMakeLists.txt b/include/triton-shared/Conversion/CMakeLists.txt index 64ac1576..600de5fc 100644 --- a/include/triton-shared/Conversion/CMakeLists.txt +++ b/include/triton-shared/Conversion/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(TritonToLinalg) +add_subdirectory(TritonToStructured) diff --git a/include/triton-shared/Conversion/TritonToStructured/CMakeLists.txt b/include/triton-shared/Conversion/TritonToStructured/CMakeLists.txt new file mode 100644 index 00000000..5762c1f6 --- /dev/null +++ b/include/triton-shared/Conversion/TritonToStructured/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToStructured) +add_public_tablegen_target(TritonToStructuredConversionPassIncGen) diff --git a/include/triton-shared/Conversion/TritonToStructured/Passes.h b/include/triton-shared/Conversion/TritonToStructured/Passes.h new file mode 100644 index 00000000..3c3b81ca --- /dev/null +++ b/include/triton-shared/Conversion/TritonToStructured/Passes.h @@ -0,0 +1,15 @@ +#ifndef TRITON_TO_STRUCTURED_CONVERSION_PASSES_H +#define TRITON_TO_STRUCTURED_CONVERSION_PASSES_H + +#include "triton-shared/Conversion/TritonToStructured/TritonToStructured.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_REGISTRATION +#include "triton-shared/Conversion/TritonToStructured/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif diff --git a/include/triton-shared/Conversion/TritonToStructured/Passes.td b/include/triton-shared/Conversion/TritonToStructured/Passes.td new file mode 100644 index 00000000..e5cb050f --- /dev/null +++ b/include/triton-shared/Conversion/TritonToStructured/Passes.td @@ -0,0 +1,11 @@ +#ifndef TRITON_TO_STRUCTURED_CONVERSION_PASSES +#define TRITON_TO_STRUCTURED_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonToStructured : Pass<"triton-to-structured", "mlir::ModuleOp"> { + let summary = "Convert Triton non-block pointer to TritonStructured dialect"; + let constructor = "triton::createTritonToStructuredPass()"; +} + +#endif diff --git a/include/triton-shared/Conversion/TritonToStructured/TritonToStructured.h b/include/triton-shared/Conversion/TritonToStructured/TritonToStructured.h new file mode 100644 index 00000000..0ee1a6d5 --- /dev/null +++ b/include/triton-shared/Conversion/TritonToStructured/TritonToStructured.h @@ -0,0 +1,17 @@ +#ifndef TRITON_CONVERSION_TRITONTOSTRUCTURED_TRITONTOSTRUCTURED_H +#define TRITON_CONVERSION_TRITONTOSTRUCTURED_TRITONTOSTRUCTURED_H + +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace mlir { +namespace triton { + +std::unique_ptr> createTritonToStructuredPass(); + +} // namespace triton +} // namespace mlir + +#endif // TRITON_CONVERSION_TRITONTOSTRUCTURED_TRITONTOSTRUCTURED_H diff --git a/lib/AnalysisStructured/CMakeLists.txt b/lib/AnalysisStructured/CMakeLists.txt new file mode 100644 index 00000000..d0e31aa8 --- /dev/null +++ b/lib/AnalysisStructured/CMakeLists.txt @@ -0,0 +1,14 @@ +add_mlir_library(TritonSharedAnalysisStructured + MaskAnalysis.cpp + PtrAnalysis.cpp + + DEPENDS + TritonAnalysis + TritonTableGen + TritonStructuredTableGen + 
TritonGPUAttrDefsIncGen + + LINK_LIBS PUBLIC + TritonStructuredIR + MLIRAnalysis +) diff --git a/lib/AnalysisStructured/PtrAnalysis.cpp b/lib/AnalysisStructured/PtrAnalysis.cpp new file mode 100644 index 00000000..419a88c1 --- /dev/null +++ b/lib/AnalysisStructured/PtrAnalysis.cpp @@ -0,0 +1,1128 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation, Meta Platforms. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#include "triton-shared/AnalysisStructured/PtrAnalysis.h" +#include "triton-shared/AnalysisStructured/MaskAnalysis.h" +#include "triton-shared/Analysis/OpFoldResultUtils.h" + +#include "mlir/IR/IRMapping.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "triton-ptr-analysis" + +namespace mlir { + +// Extract a scalar value from v. +// If v is a scalar, return that directly. Otherwise, parse through operations +// (currently only support splat, sitofp, and truncf) that produce it to +// extract the underlying scalar value. We then reconstruct the chain of +// operations that can produce this constant with the original type. If no +// scalar value can be extracted, a nullptr is returned. +static Value getScalarValue(Value operand, Location loc, OpBuilder &builder) { + SmallVector ops; + + auto reconstructScalarValue = [&](Value src) { + for (auto op = ops.rbegin(); op != ops.rend(); ++op) { + src = TypeSwitch(*op) + .Case([&](Operation *op) { + auto resType = op->getResults()[0].getType(); + if (auto shapedType = dyn_cast(resType)) { + resType = shapedType.getElementType(); + } + return builder.create(loc, resType, src); + }) + .Case([&](Operation *op) { + auto resType = op->getResults()[0].getType(); + if (auto shapedType = dyn_cast(resType)) { + resType = shapedType.getElementType(); + } + return builder.create(loc, resType, src); + }) + .Default([](Operation *op) { + llvm_unreachable("unsupported op in generating "); + return nullptr; + }); + } + return src; + }; + + while (true) { + if (!operand.getType().dyn_cast()) { + return reconstructScalarValue(operand); + } else if (auto op = operand.getDefiningOp()) { + if (auto attr = op.getValue().dyn_cast()) { + if (!attr.isSplat()) { + InFlightDiagnostic diag = emitError(loc) + << "other value used in masked load " + "produced by unsupported instruction"; + return nullptr; + } + auto elemValue = attr.getSplatValue(); + auto constOp = arith::ConstantOp::materialize( + builder, elemValue, attr.getElementType(), op.getLoc()); + return reconstructScalarValue(constOp.getResult()); + } + } else if (auto op = operand.getDefiningOp()) { + operand = op.getSrc(); + } else if (auto op = operand.getDefiningOp()) { + ops.push_back(op.getOperation()); + operand = op.getIn(); + } else if (auto op = operand.getDefiningOp()) { + ops.push_back(op.getOperation()); + operand = op.getIn(); + } else { + InFlightDiagnostic diag = emitError(loc) + << "other value used in masked load produced " + "by unsupported instruction"; + return nullptr; + } + } + return nullptr; +} + +namespace triton { + +int32_t PtrSState::getRank() const { + assert(offsets.size() == sizes.size() && offsets.size() == strides.size() && + modulos.size() == offsets.size()); + return offsets.size(); +} + +bool 
PtrSState::isEmpty() const { + return (getRank() == 0 && !source && !scalar); +} + +bool PtrSState::hasModulo() const { + for (int32_t i = 0; i < getRank(); i++) { + if (dimHasModulo(i)) { + return true; + } + } + return false; +} + +bool PtrSState::dimHasModulo(uint32_t dim) const { + assert(dim < getRank()); + + auto intAttr = getIntAttr(modulos[dim]); + if (!intAttr.has_value()) { + return true; + } + + return intAttr.value() != 0; +} + +LogicalResult PtrSState::addState(const PtrSState &lhsState, + const PtrSState &rhsState, Operation *op, + OpBuilder &builder) { + assert(isEmpty() && lhsState.getRank() == rhsState.getRank()); + auto loc = op->getLoc(); + + if (lhsState.source && rhsState.source) { + op->emitRemark( + "PtrAnalysis: do not support adding two pointer states that both " + "have base pointers"); + return failure(); + } + + source = lhsState.source ? lhsState.source : rhsState.source; + + // AddPtr where both lhs and rhs containing modulo operators not supported + if (lhsState.hasModulo() && rhsState.hasModulo()) { + op->emitRemark("PtrAnalysis: do not support adding two pointer states " + "that both have modulo"); + return failure(); + } + + if (lhsState.hasModulo() || rhsState.hasModulo()) { + // visitOperandSplat and visitOperandExpandDims should enforce below + assert(lhsState.getRank() <= 2); + } + + if (lhsState.scalar && rhsState.scalar) { + auto addOp = + builder.create(loc, lhsState.scalar, rhsState.scalar); + scalar = addOp.getResult(); + } else if (lhsState.getRank() == 0) { // both lhs and rhs are scalars + scalar = lhsState.scalar ? lhsState.scalar : rhsState.scalar; + } + + for (uint64_t i = 0; i < lhsState.getRank(); i++) { + auto newOffset = + addOFRs(lhsState.offsets[i], rhsState.offsets[i], loc, builder); + offsets.push_back(newOffset); + + auto newStride = + addOFRs(lhsState.strides[i], rhsState.strides[i], loc, builder); + strides.push_back(newStride); + + sizes.push_back(lhsState.sizes[i]); + } + + // dealing with modulo: + // - If lhs has no modulo, skip + // - If rhs has zero offset on dim i, we can just use lhs's modulo + // - If i == 0 and rhs is the result of a splat, we will allow the add. This + // is because the user may be trying to express adding a constant offset to + // increment dim1, but pointer analysis cannot differentiate dim1 vs dim0 in + // this case. + // - Else, the analysis fail + + // Note that this is not bullet-proof. E.g., broken IR can actually increment + // dim0 while dim0 already has modulo, since Triton offsets are element-wise + // and not in unit of lower dimensions. However, this is highly unlikely but + // the analysis will provide wrong result. Hence we provide a warning in this + // case. + PtrSState const *lhs = &lhsState; + PtrSState const *rhs = &rhsState; + + if (rhs->hasModulo()) { + std::swap(lhs, rhs); + } + + for (uint64_t i = 0; i < lhs->getRank(); i++) { + if (!lhs->dimHasModulo(i)) { + modulos.push_back(lhs->modulos[i]); + } else if (hasConstZero(rhs->offsets[i])) { + modulos.push_back(lhs->modulos[i]); + } else if (i == 0 && lhs->getRank() == 2 && rhs->scalar) { + modulos.push_back(lhs->modulos[1]); + modulos.push_back(lhs->modulos[0]); + op->emitWarning( + "PtrAnalysis: allowing adding pointer state with modulo in dim 0 to " + "another pointer state with offset in dim 0.\nPlease verify the " + "operand that contains a scalar is meant to increment pointers in " + "dim1. 
If that is not the case it WILL LEAD TO WRONG COMPILATION " + "RESULTS.\n\nTo avoid this warning, use expand_dims (instead of " + "splat) to explicitly specify which dimension contains the scalar."); + break; + } else { + op->emitRemark( + "PtrAnalysis: do not support adding to operand with modulo"); + return failure(); + } + } + + return success(); +} + +LogicalResult PtrSState::mulState(const PtrSState &lhsState, + const PtrSState &rhsState, Operation *op, + OpBuilder &builder) { + assert(isEmpty() && lhsState.getRank() == rhsState.getRank()); + + auto loc = op->getLoc(); + + // neither lhs nor rhs should have source, since multiplying base pointer + // does not make sense + if (lhsState.source && rhsState.source) { + op->emitRemark("PtrAnalysis: do not support multiplying base pointer"); + return failure(); + } + + assert(!(lhsState.scalar && rhsState.scalar) && + "do not expect to see both lhs and rhs are scalars"); + + // currently do not support both tensors are effectively non-scalar + if (!lhsState.scalar && !rhsState.scalar) { + op->emitRemark( + "PtrAnalysis: only support multiplying pointer states when one of " + "them represent a scalar"); + return failure(); + } + + PtrSState const *lhs = &lhsState; + PtrSState const *rhs = &rhsState; + + if (!rhs->scalar && lhs->scalar) { + std::swap(lhs, rhs); + } + + for (uint64_t i = 0; i < lhs->sizes.size(); i++) { + OpFoldResult newOffset = + mulOFRValue(lhs->offsets[i], rhs->scalar, loc, builder); + OpFoldResult newStride = + mulOFRValue(lhs->strides[i], rhs->scalar, loc, builder); + OpFoldResult newModulo = + mulOFRValue(lhs->modulos[i], rhs->scalar, loc, builder); + offsets.push_back(newOffset); + strides.push_back(newStride); + modulos.push_back(newModulo); + sizes.push_back(lhs->sizes[i]); + } + + if (rhs->hasModulo()) { + op->emitRemark( + "PtrAnalysis: do not support multiplying pointer states that has " + "modulos"); + return failure(); + } + + return success(); +} + +tts::MakeTensorPtrOp PtrSState::createTTSMakeTensorPtrOp(OpBuilder &builder, + Location loc) { + SmallVector staticSizes; + for (size_t i = 0; i < getRank(); i++) { + auto s = getIntAttr(sizes[i]); + assert(s.has_value()); + staticSizes.push_back(s.value()); + } + + auto op = builder.create( + loc, source, staticSizes, strides, offsets, modulos); + LLVM_DEBUG({ + llvm::dbgs() << "creating tts::make_tensor_ptr:\n"; + op->dump(); + }); + + return op; +} + +LogicalResult PtrAnalysis::visitOperandAdd(arith::AddIOp addOp, PtrSState &state, + const Location loc, + OpBuilder &builder) { + PtrSState lhsState; + if (visitOperand(addOp.getLhs(), lhsState, loc, builder).failed()) { + return failure(); + } + + PtrSState rhsState; + if (visitOperand(addOp.getRhs(), rhsState, loc, builder).failed()) + return failure(); + + // Checking for higher dimension is done in addState below + if ((lhsState.getRank() == 1 && lhsState.hasModulo()) || + (rhsState.getRank() == 1 && rhsState.hasModulo())) { + addOp->emitRemark( + "PtrAnalysis: do not support this pattern: a + arange(0, K) % M"); + return failure(); + } + + return state.addState(lhsState, rhsState, addOp, builder); +} + +LogicalResult PtrAnalysis::visitOperandMul(arith::MulIOp mulOp, PtrSState &state, + const Location loc, + OpBuilder &builder) { + PtrSState lhsState; + if (visitOperand(mulOp.getLhs(), lhsState, loc, builder).failed()) { + return failure(); + } + + PtrSState rhsState; + if (visitOperand(mulOp.getRhs(), rhsState, loc, builder).failed()) { + return failure(); + } + + return state.mulState(lhsState, rhsState, 
mulOp, builder); +} + +LogicalResult PtrAnalysis::visitOperandRem(arith::RemSIOp remOp, + PtrSState &state, const Location loc, + OpBuilder &builder) { + assert(state.isEmpty()); + + PtrSState rhsState; + if (visitOperand(remOp.getRhs(), rhsState, loc, builder).failed()) { + return failure(); + } + + if (!rhsState.scalar) { + remOp->emitRemark("PtrAnalysis: only support cases when rhs of remainder " + "contains scalar"); + return failure(); + } + + if (visitOperand(remOp.getLhs(), state, loc, builder).failed()) { + return failure(); + } + + // If there are multiple modulo ops on an expression (e.g.: (a % b) % c), we + // would have already populated the modulo states after visiting the lhs. + // Assert that all the modulo states are empty. + if (state.hasModulo()) { + remOp->emitRemark( + "PtrAnalysis: do not support multiple modulo within an expression"); + return failure(); + } + + if (state.getRank() == 1) { + // Apply the modulo before expanding shape, the common pattern is + // offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + // a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * + // stride_ak) + state.modulos.back() = rhsState.scalar; + } else if (state.getRank() == 2) { + // torch inductor expands the tensor shape before applying the modulo. + // + // We only support either: + // - (tl.arange(0, end)[:, None] % mod), or + // - (tl.arange(0, end)[None, :] % mod) + // + // In both cases, we apply the modulo to the non-singleton dimension. + auto shape = cast(remOp.getResult().getType()).getShape(); + if (shape[0] == 1) { + state.modulos[1] = rhsState.scalar; + } else if (shape[1] == 1) { + state.modulos[0] = rhsState.scalar; + } else { + remOp->emitRemark( + "PtrAnalysis: taking modulo on a 2D tensor with no singleton " + "dimension not supported"); + return failure(); + } + } else { + remOp->emitRemark("PtrAnalysis: unsupported modulo pattern"); + return failure(); + } + return success(); +} + +LogicalResult PtrAnalysis::visitOperandMakeRange(triton::MakeRangeOp rangeOp, + PtrSState &state, Location loc, + OpBuilder &builder) { + assert(state.isEmpty()); + + auto shape = rangeOp.getType().cast().getShape(); + + auto start = rangeOp.getStart(); + auto end = rangeOp.getEnd(); + auto stride = (end - start + shape[0] - 1) / shape[0]; + assert(stride == 1 && + "Expect make_range op to always return tensor of stride 1"); + + state.offsets.push_back(builder.getIndexAttr(start)); + state.sizes.push_back(builder.getIndexAttr(shape[0])); + state.strides.push_back(builder.getIndexAttr(stride)); + state.modulos.push_back(builder.getIndexAttr(0)); + return success(); +} + +LogicalResult +PtrAnalysis::visitOperandExpandDims(triton::ExpandDimsOp expandDimsOp, + PtrSState &state, const Location loc, + OpBuilder &builder) { + assert(state.isEmpty()); + + if (visitOperand(expandDimsOp.getSrc(), state, loc, builder).failed()) { + return failure(); + } + + auto dstShape = + expandDimsOp.getResult().getType().cast().getShape(); + auto axis = expandDimsOp.getAxis(); + + assert(dstShape[axis] == 1 && + "expect changed dimension to be 1 in expand_dims"); + + // insert dimension info + state.offsets.insert(state.offsets.begin() + axis, builder.getIndexAttr(0)); + state.sizes.insert(state.sizes.begin() + axis, builder.getIndexAttr(1)); + state.strides.insert(state.strides.begin() + axis, builder.getIndexAttr(0)); + state.modulos.insert(state.modulos.begin() + axis, builder.getIndexAttr(0)); + + if (state.hasModulo() && state.getRank() > 2) { + expandDimsOp->emitRemark( + 
"PtrAnalysis: unsupported scenario where expand_dims result " + "has modulo and rank > 2"); + return failure(); + } + + return success(); +} + +LogicalResult +PtrAnalysis::visitOperandBroadcast(triton::BroadcastOp broadcastOp, + PtrSState &state, const Location loc, + OpBuilder &builder) { + assert(state.isEmpty()); + + auto src = broadcastOp.getSrc(); + auto dst = broadcastOp.getResult(); + + SmallVector srcShape; + auto dstShape = dst.getType().cast().getShape(); + + if (src.getType().isa()) { + srcShape = + SmallVector(src.getType().cast().getShape()); + assert(srcShape.size() == dstShape.size() && + "rank of source and destination should match"); + + if (visitOperand(src, state, loc, builder).failed()) { + return failure(); + } + + for (size_t i = 0; i < dstShape.size(); i++) { + if (srcShape[i] == dstShape[i]) { + continue; + } else if (srcShape[i] < dstShape[i]) { + state.sizes[i] = builder.getIndexAttr(dstShape[i]); + } else { + llvm_unreachable("unexpected dimensions used in broadcast"); + } + } + return success(); + } + + broadcastOp->emitRemark("PtrAnalysis: Unsupported broadcast source type"); + return failure(); +} + +LogicalResult PtrAnalysis::visitOperandSplat(triton::SplatOp splatOp, + PtrSState &state, + const Location loc, + OpBuilder &builder) { + assert(state.isEmpty()); + + auto src = splatOp.getSrc(); + auto dst = splatOp.getResult(); + auto dstShape = dst.getType().cast().getShape(); + + if (visitOperand(src, state, loc, builder).failed()) { + return failure(); + } + + if (src.getType().isa()) { + for (auto s : dstShape) { + state.offsets.push_back(builder.getIndexAttr(0)); + state.sizes.push_back(builder.getIndexAttr(s)); + state.strides.push_back(builder.getIndexAttr(0)); + state.modulos.push_back(builder.getIndexAttr(0)); + } + } else { + splatOp->emitRemark("PtrAnalysis: unsupported splat pattern"); + return failure(); + } + + // If we splat a integer value, scalar should become the offset of the outer + // most dimension + if (state.scalar) + state.offsets[0] = state.scalar; + + if (state.hasModulo() && state.getRank() > 2) { + splatOp->emitRemark("PtrAnalysis: unsupported scenario where splat result " + "has modulo and rank > 2"); + return failure(); + } + + return success(); +} + +LogicalResult PtrAnalysis::visitOperandAddptr(triton::AddPtrOp addptrOp, + PtrSState &state, + const Location loc, + OpBuilder &builder) { + assert(state.isEmpty()); + + PtrSState ptrState; + if (visitOperand(addptrOp.getPtr(), ptrState, addptrOp.getLoc(), builder) + .failed()) { + return failure(); + } + + PtrSState offsetState; + if (visitOperand(addptrOp.getOffset(), offsetState, addptrOp.getLoc(), + builder) + .failed()) { + return failure(); + } + + assert(ptrState.source && "ptr field should provide source / base pointer"); + + assert(ptrState.getRank() == offsetState.getRank() && + "ptr and offset field should have the same rank"); + + return state.addState(ptrState, offsetState, addptrOp, builder); +} + +LogicalResult PtrAnalysis::visitOperandConstSplat(arith::ConstantOp op, + PtrSState &state, + const Location loc, + OpBuilder &builder) { + assert(state.isEmpty()); + // this condition is to handle cases where tt.broadcast and tt.splat are + // folded + auto attr = cast(op.getValue()); + auto elementType = attr.getElementType(); + assert(attr.isSplat() && elementType.isa()); + auto values = attr.getValues(); + auto value = values[0].getValue(); + auto constAttr = builder.getIndexAttr(value.getSExtValue()); + auto constOp = arith::ConstantOp::materialize(builder, constAttr, + 
builder.getIndexType(), loc); + + state.scalar = constOp; + + auto resultType = cast(op.getResult().getType()); + for (size_t i = 0; i < resultType.getShape().size(); i++) { + if (i == 0) { + state.offsets.push_back(constOp.getResult()); + } else { + state.offsets.push_back(builder.getIndexAttr(0)); + } + + state.sizes.push_back(builder.getIndexAttr(resultType.getShape()[i])); + state.strides.push_back(builder.getIndexAttr(0)); + state.modulos.push_back(builder.getIndexAttr(0)); + } + + return success(); +} + +LogicalResult PtrAnalysis::visitOperandMakeTPtr(tts::MakeTensorPtrOp makeTPtrOp, + PtrSState &state, + const Location loc, + OpBuilder &builder) { + + assert(state.isEmpty()); + state.source = makeTPtrOp.getBase(); + state.offsets = makeTPtrOp.getMixedOffsets(); + state.sizes = makeTPtrOp.getMixedSizes(); + state.strides = makeTPtrOp.getMixedStrides(); + state.modulos = makeTPtrOp.getMixedParentSizes(); + + return success(); +} + +LogicalResult PtrAnalysis::visitOperand(Value operand, PtrSState &state, + const Location loc, + OpBuilder &builder) { + + if (knownPtrs.find(operand) != knownPtrs.end()) { + state = knownPtrs.lookup(operand); + return success(); + } + + if (operand.getType().isa()) { + OpBuilder::InsertionGuard guard(builder); + if (!operand.isa() && operand.getDefiningOp()) { + builder.setInsertionPointAfter(operand.getDefiningOp()); + } + auto castOp = builder.create( + loc, builder.getIndexType(), operand); + state.scalar = castOp.getResult(); + return success(); + } else if (operand.getType().isa()) { + state.scalar = operand; + return success(); + } + + if (operand.getType().isa()) { + // A scalar pointer can either be produced by AddPtrOp or a block + // argument + if (auto op = operand.getDefiningOp()) { + if (auto addPtrOp = dyn_cast(op)) { + return visitOperandAddptr(cast(op), state, loc, + builder); + } else if (auto makeTensorOp = dyn_cast(op)) { + llvm_unreachable("NYI"); + } else { + llvm_unreachable("Unexpected operand defining operation"); + } + } else { + state.source = operand; + return success(); + } + } + + if (auto op = operand.getDefiningOp()) { + return visitOperandAdd(op, state, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return visitOperandMul(op, state, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return visitOperandMakeRange(op, state, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return visitOperandBroadcast(op, state, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return visitOperandSplat(op, state, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return visitOperandExpandDims(op, state, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return visitOperandAddptr(op, state, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return visitOperandConstSplat(op, state, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return visitOperandRem(op, state, loc, builder); + } else { + llvm::dbgs() << "PtrAnalysis: encountered addptr operand produced by an " + "unsupported operation\n"; + operand.dump(); + return failure(); + } +} + +LogicalResult PtrAnalysis::rewriteAddptrOp(triton::AddPtrOp op) { + OpBuilder builder(op); + + PtrSState state; + if (visitOperandAddptr(op, state, op.getLoc(), builder).failed()) { + return failure(); + } + + knownPtrs[op.getResult()] = state; + + if (op.getPtr().getType().isa()) { + auto maketptrOp = state.createTTSMakeTensorPtrOp(builder, op.getLoc()); + 
map.map(op.getResult(), maketptrOp.getResult()); + } else { + map.map(op.getResult(), op.getResult()); + } + return success(); +} + +LogicalResult PtrAnalysis::rewriteForOp(scf::ForOp op) { + SmallVector newInitArgs; + + SmallVector, 5> initArgIndexState; + SmallVector, 5> knownPtrsTmp; + + llvm::SmallDenseMap initArgIndexMap; + + OpBuilder builder(op); + + // Create a new list of init args + for (auto [i, arg] : llvm::enumerate(op.getInitArgs())) { + auto mappedV = map.lookupOrNull(arg); + PtrSState state; + + if (mappedV) { + if (auto makeTPtrOp = mappedV.getDefiningOp()) { + if (visitOperandMakeTPtr(makeTPtrOp, state, op.getLoc(), builder) + .failed()) { + newInitArgs.push_back(arg); + continue; + } + newInitArgs.push_back(mappedV); + } else if (auto addptrOp = mappedV.getDefiningOp()) { + assert(!addptrOp.getResult().getType().isa()); + if (visitOperandAddptr(addptrOp, state, op.getLoc(), builder) + .failed()) { + newInitArgs.push_back(arg); + continue; + } + newInitArgs.push_back(mappedV); + } + } + // Init arg is not pointer related or prior rewrite has failed. Pass as is + else { + newInitArgs.push_back(arg); + continue; + } + // Record the PtrSState for later processing + initArgIndexState.push_back(std::make_pair(i, state)); + } + + // For each of the PtrSState recorded in the last step, insert new instructions + // to describe offset and stride for each dimension and append them to init + // args + for (auto [i, state] : initArgIndexState) { + // For each dimension, if the corresponding offset and stride is an + // integer attribute, create a constant value and append them at the + // end of init arg list. + for (auto [j, s] : llvm::enumerate(state.offsets)) { + auto sIntAttr = getIntAttr(s); + if (sIntAttr) { + auto constOp = builder.create( + op.getLoc(), builder.getIndexAttr(sIntAttr.value())); + newInitArgs.push_back(constOp.getResult()); + state.offsets[j] = constOp.getResult(); + } else { + newInitArgs.push_back(s.get()); + } + } + + for (auto [j, s] : llvm::enumerate(state.strides)) { + auto sIntAttr = getIntAttr(s); + if (sIntAttr) { + auto constOp = builder.create( + op.getLoc(), builder.getIndexAttr(sIntAttr.value())); + newInitArgs.push_back(constOp.getResult()); + state.strides[j] = constOp.getResult(); + } else { + newInitArgs.push_back(s.get()); + } + } + + if (state.getRank() == 0) { + assert(state.scalar); + newInitArgs.push_back(state.scalar); + } + + // Note that we want the knownPtrs to be indexed by block arg, but we + // only have index for now. Also, the state we record is the init + // arg, but want to to use newly created block arg. These block args + // are not created yet. We will translate this mapping later. + knownPtrsTmp.push_back(std::make_pair(i, state)); + levelToBlockArgIndex[level].insert(i); + } + + // Create a new scf::ForOp that uses updated init args and same loop body + auto newOp = builder.create( + op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep(), + newInitArgs, [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { + IRMapping cloneMap; + cloneMap.map(op.getInductionVar(), iv); + cloneMap.map(op.getInitArgs(), newInitArgs); + cloneMap.map(op.getRegionIterArgs(), args); + + for (auto &bodyOp : op.getRegion().getOps()) { + b.clone(bodyOp, cloneMap); + } + }); + + // Convert the book-keeping data structure to use the correct key and value. 
+ // Key is converted from init arg index to newly created block arg, and + // Value's PtrSState fields are converted from init arg to newly created block + // arg + int cnt = op.getRegionIterArgs().size(); + for (auto [i, state] : knownPtrsTmp) { + for (auto it = state.offsets.begin(); it != state.offsets.end(); it++) { + *it = newOp.getRegionIterArgs()[cnt]; + cnt++; + } + + for (auto it = state.strides.begin(); it != state.strides.end(); it++) { + *it = newOp.getRegionIterArgs()[cnt]; + cnt++; + } + + if (state.getRank() == 0) { + assert(state.scalar); + state.scalar = newOp.getRegionIterArgs()[cnt]; + cnt++; + } + + // Record the PtrSState for this pointer + auto key = newOp.getRegionIterArgs()[i]; + knownPtrs[key] = state; + initArgIndexMap[i] = state; + + // Create a tts.make_tptr at the beginning of the loop body that correspond + // to this region iter arg. In case it is used by tt.load/tt.store in the + // loop body, this will make sure rewriteLoadOp/rewriteStoreOp can use the + // analysis result. + if (state.getRank() != 0) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(&newOp.getRegion().front()); + auto maketptrOp = state.createTTSMakeTensorPtrOp(builder, op.getLoc()); + map.map(key, maketptrOp.getResult()); + } else { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(&newOp.getRegion().front()); + + auto offset = state.scalar; + if (offset.getType().isa()) { + offset = builder.create( + op.getLoc(), builder.getI32Type(), offset); + } + auto addptrOp = builder.create( + op.getLoc(), state.source.getType(), state.source, offset); + map.map(key, addptrOp.getResult()); + } + } + + for (auto &bodyOp : newOp.getRegion().getOps()) { + if (auto forOp = dyn_cast(bodyOp)) { + assert(0 && "nested loops currently not supported"); + } + } + + // Update the loop body. + if (rewriteOp(newOp).failed()) { + newOp->erase(); + op->emitRemark( + "PtrAnalysis: update loop body failed when rewriting for op"); + return failure(); + } + + if (op.getNumRegionIterArgs()) { + auto yieldOp = cast(newOp.getBody()->getTerminator()); + if (rewriteYieldOp(yieldOp, initArgIndexMap).failed()) { + newOp->erase(); + return failure(); + }; + } + + levelToBlockArgIndex.erase(level); + + // Replace only the results that correspond to the original scf.for + auto resultsToReplaceWith = ResultRange( + newOp.result_begin(), newOp.result_begin() + op.getNumResults()); + op->replaceAllUsesWith(resultsToReplaceWith); + op->erase(); + + LLVM_DEBUG({ + llvm::dbgs() << "new for\n"; + newOp->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + + llvm::dbgs() << "old for\n"; + op->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + }); + + return success(); +} + +LogicalResult +PtrAnalysis::rewriteYieldOp(scf::YieldOp op, + llvm::SmallDenseMap &knownPtrsFor) { + if (levelToBlockArgIndex.find(level) == levelToBlockArgIndex.end()) { + // no need to rewrite this op + return success(); + } + + OpBuilder builder(op); + + // For each of the init arg that we added additional Values in for loop, we + // need to add corresponding Values as yield operands. The loop below gathers + // PtrSState for those values. 
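+  // As an illustration (hypothetical shapes, not taken from any specific
+  // test): a loop whose terminator originally was
+  //   scf.yield %ptr : tensor<4x256x!tt.ptr<bf16>>
+  // is rewritten so that the yield also carries the decomposed pointer state,
+  //   scf.yield %ptr, %offset0, %offset1, %stride0, %stride1 : ...
+  // in the same order as the extra iter_args appended by rewriteForOp.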
+  SmallVector<PtrSState> initArgState;
+  for (auto [i, v] : llvm::enumerate(op->getOperands())) {
+    // If this operand is not rewritten by forOp, skip
+    auto thisSet = levelToBlockArgIndex.find(level)->second;
+    if (thisSet.find(i) == thisSet.end())
+      continue;
+
+    auto mappedV = map.lookupOrNull(v);
+    if (!mappedV) {
+      op->emitRemark("Prior rewrite failure led to yield rewrite failure");
+      return failure();
+    }
+
+    PtrSState state;
+    LogicalResult ret = failure();
+    if (auto makeTPtrOp = mappedV.getDefiningOp<tts::MakeTensorPtrOp>()) {
+      ret = visitOperandMakeTPtr(makeTPtrOp, state, op.getLoc(), builder);
+    } else if (auto addptrOp = mappedV.getDefiningOp<triton::AddPtrOp>()) {
+      ret = visitOperandAddptr(addptrOp, state, op.getLoc(), builder);
+    }
+    if (ret.failed()) {
+      op->emitRemark("Failed to rewrite yield op");
+      return failure();
+    }
+    initArgState.push_back(state);
+
+    // Verify that the modulo state is not updated during the for loop
+    auto forState = knownPtrsFor[i];
+    for (auto i = 0; i < forState.getRank(); ++i) {
+      if (forState.modulos[i] != state.modulos[i]) {
+        // Special case, see comments in addState on dealing with modulos
+        if (i == 0 && forState.getRank() == 2) {
+          if (forState.modulos[1] == state.modulos[0] &&
+              forState.modulos[0] == state.modulos[1]) {
+            break;
+          }
+        }
+        assert(0);
+        op->emitRemark(
+            "PtrAnalysis: operand's modulo state changed within loop body");
+        return failure();
+      }
+    }
+  }
+
+  SmallVector<Value> operands;
+  for (auto opnd : op->getOperands()) {
+    auto mappedV = map.lookupOrNull(opnd);
+    if (mappedV) {
+      operands.push_back(mappedV);
+    } else {
+      operands.push_back(opnd);
+    }
+  }
+
+  // For each of the PtrSState recorded in the last step, extract the values
+  // that correspond to the offset and stride of each dimension and append
+  // them to the yield operands.
+  for (auto state : initArgState) {
+    for (auto s : state.offsets) {
+      if (auto sIntAttr = getIntAttr(s)) {
+        auto constOp = builder.create<arith::ConstantOp>(
+            op.getLoc(), builder.getIndexAttr(sIntAttr.value()));
+        operands.push_back(constOp.getResult());
+      } else {
+        operands.push_back(s.get<Value>());
+      }
+    }
+
+    for (auto s : state.strides) {
+      assert(!getIntAttr(s) && "PtrSState strides for yield within for "
+                               "loop not expected to be an attribute.");
+      operands.push_back(s.get<Value>());
+    }
+
+    if (state.getRank() == 0) {
+      operands.push_back(state.scalar);
+    }
+  }
+
+  auto newOp = builder.create<scf::YieldOp>(op->getLoc(), operands);
+
+  LLVM_DEBUG({
+    llvm::dbgs() << "new yield:";
+    newOp.getOperation()->print(llvm::dbgs(),
+                                OpPrintingFlags().printGenericOpForm());
+    llvm::dbgs() << "\n";
+  });
+
+  op->erase();
+  return success();
+}
+
+LogicalResult PtrAnalysis::rewriteLoadOp(triton::LoadOp op) {
+  auto ptr = map.lookupOrNull(op.getPtr());
+  auto mask = op.getMask();
+  auto other = op.getOther();
+  auto loc = op.getLoc();
+
+  if (!ptr) {
+    op->emitRemark("PtrAnalysis: pointer is not replaced with tts.make_tptr "
+                   "so loadOp cannot be rewritten");
+    return failure();
+  }
+
+  if (!ptr.getType().isa<RankedTensorType>()) {
+    op->emitRemark("PtrAnalysis: scalar loadOp will not be rewritten");
+    return failure();
+  }
+
+  ArrayRef<OpFoldResult> dims;
+  MaskSState mstate;
+  Value scalarOther;
+
+  OpBuilder builder(op);
+  // Analyze the mask operand to determine at runtime the size of the data we
+  // are moving.
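+  // Note: when parsing succeeds, MaskAnalysis records one extent per
+  // dimension in mstate.dims (as OpFoldResult); those extents become the
+  // dynamic `dims` of the tts.load created below, and `other` (reduced to a
+  // scalar) supplies the value for masked-off elements. The exact mask
+  // patterns accepted are determined by MaskAnalysis, not by this rewrite.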
+  if (mask) {
+    if (mstate.parse(mask, loc, builder).failed()) {
+      op->emitRemark("MaskAnalysis failed");
+      return failure();
+    }
+    dims = mstate.dims;
+  }
+
+  if (other) {
+    assert(mask && "other value used while no masks are specified");
+
+    scalarOther = getScalarValue(other, loc, builder);
+    if (!scalarOther) {
+      op->emitRemark("other value used in masked load produced by "
+                     "unsupported instruction");
+      return failure();
+    }
+  }
+
+  auto loadOp = builder.create<tts::LoadOp>(loc, ptr, dims, scalarOther);
+
+  LLVM_DEBUG({
+    llvm::dbgs() << "creating tts::load:\n";
+    loadOp->dump();
+  });
+
+  op.replaceAllUsesWith(loadOp.getResult());
+  op->erase();
+  return success();
+}
+
+LogicalResult PtrAnalysis::rewriteStoreOp(triton::StoreOp op) {
+  auto ptr = map.lookupOrNull(op.getPtr());
+  auto val = op.getValue();
+  auto mask = op.getMask();
+  auto loc = op.getLoc();
+
+  if (!ptr) {
+    op->emitRemark("PtrAnalysis: pointer is not replaced with tts.make_tptr "
+                   "so storeOp cannot be rewritten");
+    return failure();
+  }
+
+  if (!ptr.getType().isa<RankedTensorType>()) {
+    op->emitRemark("PtrAnalysis: scalar storeOp will not be rewritten");
+    return failure();
+  }
+
+  ArrayRef<OpFoldResult> dims;
+  MaskSState mstate;
+
+  OpBuilder builder(op);
+
+  // Analyze the mask operand to determine at runtime the size of the data we
+  // are moving.
+  if (mask) {
+    if (mstate.parse(mask, loc, builder).failed()) {
+      op->emitRemark("MaskAnalysis failed");
+      return failure();
+    }
+    dims = mstate.dims;
+  }
+
+  auto storeOp = builder.create<tts::StoreOp>(loc, ptr, val, dims);
+
+  LLVM_DEBUG({
+    llvm::dbgs() << "creating tts::store:\n";
+    storeOp->dump();
+  });
+
+  op->erase();
+  return success();
+}
+
+LogicalResult PtrAnalysis::rewriteOp(Operation *rootOp) {
+  LLVM_DEBUG({
+    llvm::dbgs() << "rewriting rootOp\n";
+    rootOp->dump();
+  });
+
+  rootOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
+    if (op == rootOp) {
+      return WalkResult::advance();
+    }
+    return TypeSwitch<Operation *, WalkResult>(op)
+        .Case<triton::AddPtrOp>([&](auto addptr) {
+          if (rewriteAddptrOp(addptr).failed()) {
+            addptr->emitRemark("PtrAnalysis: Failed to rewrite AddPtrOp");
+          }
+          return WalkResult::advance();
+        })
+        .Case<triton::LoadOp>([&](auto load) {
+          if (rewriteLoadOp(load).failed()) {
+            load->emitRemark("PtrAnalysis: Failed to rewrite LoadOp");
+            return WalkResult::advance();
+          }
+          return WalkResult::skip();
+        })
+        .Case<triton::StoreOp>([&](auto store) {
+          if (rewriteStoreOp(store).failed()) {
+            store->emitRemark("PtrAnalysis: Failed to rewrite StoreOp");
+            return WalkResult::advance();
+          }
+          return WalkResult::skip();
+        })
+        .Case<scf::ForOp>([&](auto forOp) {
+          if (rewriteForOp(forOp).failed()) {
+            forOp->emitRemark("PtrAnalysis: Failed to rewrite ForOp");
+            return WalkResult::advance();
+          }
+          return WalkResult::skip();
+        })
+        .Default([&](auto) { return WalkResult::advance(); });
+  });
+
+  return success();
+}
+
+} // namespace triton
+} // namespace mlir
diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt
index 433b407c..eff85b20 100644
--- a/lib/CMakeLists.txt
+++ b/lib/CMakeLists.txt
@@ -1,3 +1,4 @@
 add_subdirectory(Analysis)
+add_subdirectory(AnalysisStructured)
 add_subdirectory(Conversion)
 add_subdirectory(Dialect)
diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt
index 64ac1576..600de5fc 100644
--- a/lib/Conversion/CMakeLists.txt
+++ b/lib/Conversion/CMakeLists.txt
@@ -1 +1,2 @@
 add_subdirectory(TritonToLinalg)
+add_subdirectory(TritonToStructured)
diff --git a/lib/Conversion/TritonToStructured/CMakeLists.txt b/lib/Conversion/TritonToStructured/CMakeLists.txt
new file mode 100644
index 00000000..2c9a84eb
--- /dev/null
+++ 
b/lib/Conversion/TritonToStructured/CMakeLists.txt @@ -0,0 +1,24 @@ +add_mlir_conversion_library(TritonToStructured + TritonToStructuredPass.cpp + + DEPENDS + TritonStructuredTableGen + TritonToStructuredConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRDialectUtils + MLIRIR + MLIRMathDialect + MLIRPass + MLIRTensorDialect + MLIRTransforms + MLIRSupport + TritonIR + TritonTransforms + TritonSharedAnalysisStructured + TritonStructuredIR +) diff --git a/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp b/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp new file mode 100644 index 00000000..70e41162 --- /dev/null +++ b/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp @@ -0,0 +1,58 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation, Meta Platforms. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#include "triton-shared/AnalysisStructured/PtrAnalysis.h" +#include "triton-shared/Conversion/TritonToStructured/TritonToStructured.h" +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Dialect/Triton/IR/Types.h" + +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "triton-to-structured" + +using namespace mlir; +using namespace triton; + +#define GEN_PASS_CLASSES +#include "triton-shared/Conversion/TritonToStructured/Passes.h.inc" + +namespace { + +class TritonToStructuredPass + : public TritonToStructuredBase { + +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert(); + } + + void runOnOperation() override { + auto moduleOp = getOperation(); + + PtrAnalysis ptrAnalysis; + if (ptrAnalysis.rewriteOp(moduleOp).failed()) { + moduleOp->emitWarning("PtrAnalysis failed"); + } + } +}; +} // namespace + +std::unique_ptr> +triton::createTritonToStructuredPass() { + return std::make_unique(); +} diff --git a/tools/RegisterTritonSharedDialects.h b/tools/RegisterTritonSharedDialects.h index cf9b006e..47884a47 100644 --- a/tools/RegisterTritonSharedDialects.h +++ b/tools/RegisterTritonSharedDialects.h @@ -9,6 +9,8 @@ #include "triton/Conversion/TritonToTritonGPU/Passes.h" #include "triton-shared/Conversion/TritonToLinalg/Passes.h" +#include "triton-shared/Conversion/TritonToStructured/Passes.h" +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" #include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h" #include "mlir/InitAllPasses.h" @@ -31,12 +33,13 @@ inline void registerTritonSharedDialects(mlir::DialectRegistry ®istry) { mlir::test::registerTestAllocationPass(); mlir::test::registerTestMembarPass(); mlir::triton::registerTritonToLinalgPass(); + mlir::triton::registerTritonToStructuredPass(); mlir::triton::registerConvertTritonToTritonGPUPass(); mlir::triton::registerConvertTritonGPUToLLVMPass(); // TODO: register Triton & TritonGPU passes registry - .insert(); From e5a8b0cdd53de1538865992a3a92902d66e4ffbb Mon Sep 17 00:00:00 2001 From: Haishan Zhu Date: Tue, 9 Jan 2024 20:27:53 -0800 Subject: [PATCH 5/8] LIT tests --- 
.../TritonToStructured/addptr_2d_example.mlir | 62 ++++++ .../TritonToStructured/addptr_add_value.mlir | 67 ++++++ .../TritonToStructured/addptr_dim1.mlir | 100 +++++++++ .../addptr_for_accumulation.mlir | 85 +++++++ .../addptr_for_expand_ptr.mlir | 71 ++++++ .../addptr_for_more_init_args.mlir | 69 ++++++ .../addptr_for_used_after_update.mlir | 98 +++++++++ .../addptr_for_used_before_update.mlir | 54 +++++ .../TritonToStructured/addptr_loopback.mlir | 54 +++++ .../addptr_mul_const_const.mlir | 50 +++++ .../addptr_mul_value_const.mlir | 53 +++++ .../TritonToStructured/addptr_nested.mlir | 62 ++++++ .../addptr_reshape_broadcast.mlir | 43 ++++ .../addptr_scalar_broadcast.mlir | 61 +++++ .../TritonToStructured/addptr_scalar_for.mlir | 59 +++++ .../addptr_scalar_for_2d.mlir | 81 +++++++ .../addptr_scalar_loopback.mlir | 23 ++ .../addptr_scalar_nested.mlir | 53 +++++ .../addptr_scalar_splat.mlir | 41 ++++ .../addptr_scalar_splat_2d.mlir | 52 +++++ .../arith_not_ptr_arith.mlir | 33 +++ .../TritonToStructured/bitcast.mlir | 34 +++ .../TritonToStructured/block_ptr_advance.mlir | 62 ++++++ test/Conversion/TritonToStructured/dot.mlir | 66 ++++++ .../kernel-01-vector-add.mlir | 62 ++++++ .../kernel-02-fused-softmax.mlir | 74 +++++++ .../kernel-03-matrix-multiplication.mlir | 190 ++++++++++++++++ .../kernel-05-layer-norm-dwdb.mlir | 145 ++++++++++++ .../kernel-05-layer-norm-fwd.mlir | 208 ++++++++++++++++++ .../TritonToStructured/masked_ldst_1d.mlir | 36 +++ .../TritonToStructured/masked_ldst_2d.mlir | 97 ++++++++ .../masked_ldst_sitofp_other.mlir | 38 ++++ .../TritonToStructured/use_dot_opc.mlir | 68 ++++++ .../TritonToStructured/use_end_chain.mlir | 56 +++++ .../TritonToStructured/use_mid_chain.mlir | 52 +++++ .../wraparound_side_by_side.mlir | 92 ++++++++ .../wraparound_stacked.mlir | 88 ++++++++ .../wraparound_unsupported_add_offset.mlir | 110 +++++++++ 38 files changed, 2749 insertions(+) create mode 100644 test/Conversion/TritonToStructured/addptr_2d_example.mlir create mode 100644 test/Conversion/TritonToStructured/addptr_add_value.mlir create mode 100644 test/Conversion/TritonToStructured/addptr_dim1.mlir create mode 100644 test/Conversion/TritonToStructured/addptr_for_accumulation.mlir create mode 100644 test/Conversion/TritonToStructured/addptr_for_expand_ptr.mlir create mode 100644 test/Conversion/TritonToStructured/addptr_for_more_init_args.mlir create mode 100644 test/Conversion/TritonToStructured/addptr_for_used_after_update.mlir create mode 100644 test/Conversion/TritonToStructured/addptr_for_used_before_update.mlir create mode 100644 test/Conversion/TritonToStructured/addptr_loopback.mlir create mode 100644 test/Conversion/TritonToStructured/addptr_mul_const_const.mlir create mode 100644 test/Conversion/TritonToStructured/addptr_mul_value_const.mlir create mode 100644 test/Conversion/TritonToStructured/addptr_nested.mlir create mode 100644 test/Conversion/TritonToStructured/addptr_reshape_broadcast.mlir create mode 100644 test/Conversion/TritonToStructured/addptr_scalar_broadcast.mlir create mode 100644 test/Conversion/TritonToStructured/addptr_scalar_for.mlir create mode 100644 test/Conversion/TritonToStructured/addptr_scalar_for_2d.mlir create mode 100644 test/Conversion/TritonToStructured/addptr_scalar_loopback.mlir create mode 100644 test/Conversion/TritonToStructured/addptr_scalar_nested.mlir create mode 100644 test/Conversion/TritonToStructured/addptr_scalar_splat.mlir create mode 100644 test/Conversion/TritonToStructured/addptr_scalar_splat_2d.mlir create mode 100644 
test/Conversion/TritonToStructured/arith_not_ptr_arith.mlir create mode 100644 test/Conversion/TritonToStructured/bitcast.mlir create mode 100644 test/Conversion/TritonToStructured/block_ptr_advance.mlir create mode 100644 test/Conversion/TritonToStructured/dot.mlir create mode 100644 test/Conversion/TritonToStructured/kernel-01-vector-add.mlir create mode 100644 test/Conversion/TritonToStructured/kernel-02-fused-softmax.mlir create mode 100644 test/Conversion/TritonToStructured/kernel-03-matrix-multiplication.mlir create mode 100644 test/Conversion/TritonToStructured/kernel-05-layer-norm-dwdb.mlir create mode 100644 test/Conversion/TritonToStructured/kernel-05-layer-norm-fwd.mlir create mode 100644 test/Conversion/TritonToStructured/masked_ldst_1d.mlir create mode 100644 test/Conversion/TritonToStructured/masked_ldst_2d.mlir create mode 100644 test/Conversion/TritonToStructured/masked_ldst_sitofp_other.mlir create mode 100644 test/Conversion/TritonToStructured/use_dot_opc.mlir create mode 100644 test/Conversion/TritonToStructured/use_end_chain.mlir create mode 100644 test/Conversion/TritonToStructured/use_mid_chain.mlir create mode 100644 test/Conversion/TritonToStructured/wraparound_side_by_side.mlir create mode 100644 test/Conversion/TritonToStructured/wraparound_stacked.mlir create mode 100644 test/Conversion/TritonToStructured/wraparound_unsupported_add_offset.mlir diff --git a/test/Conversion/TritonToStructured/addptr_2d_example.mlir b/test/Conversion/TritonToStructured/addptr_2d_example.mlir new file mode 100644 index 00000000..ce4fa196 --- /dev/null +++ b/test/Conversion/TritonToStructured/addptr_2d_example.mlir @@ -0,0 +1,62 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr, + %arg2 : !tt.ptr, + %arg3 : i32 + ) + { + %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + // offset = 0, size = 4, stride = 1 + %1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<4xi32>) -> tensor<4x1xi32> + // offset = [0,0], size = [4,1], stride = [1,0] + %2 = tt.broadcast %1 : (tensor<4x1xi32>) -> tensor<4x256xi32> + // offset = [0,0], size = [4,256], stride = [1,0] + %arg3splat = tt.splat %arg3 : (i32) -> tensor<4x256xi32> + %offset3 = arith.addi %2, %arg3splat : tensor<4x256xi32> + // offset = [%arg3,0], size = [4,256], stride = [1,0] + %3 = tt.make_range {end = 256 : i32, start = 0 : i32}: tensor<256xi32> + // offset = 0, size = 256, stride = 1 + %4 = tt.expand_dims %3 {axis = 0 : i32} : (tensor<256xi32>) -> tensor<1x256xi32> + // offset = [0,0], size = [1,256], stride = [0,1] + %5 = tt.broadcast %4 : (tensor<1x256xi32>) -> tensor<4x256xi32> + // offset = [0,0], size = [4,256], stride = [0,1] + %6 = arith.constant 5 : i32 + %splat6 = tt.splat %6 : (i32) -> tensor<4x256xi32> + %scale5 = arith.muli %5, %splat6 : tensor<4x256xi32> + // offset = [0,0], size = [4,256], stride = [0,5] + %7 = arith.addi %offset3, %scale5: tensor<4x256xi32> + // offset = [%arg3, 0], size = [4, 256], stride = [1, 5] + %8 = tt.splat %arg0 : (!tt.ptr) -> tensor<4x256x!tt.ptr> + %9 = tt.addptr %8, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + // source: %arg0, offset = [%arg3, 0], size = [4, 256], stride = [1, 5] + %10 = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<4x256xbf16> + %11 = tt.splat %arg1 : (!tt.ptr) -> tensor<4x256x!tt.ptr> + %12 = tt.addptr %11, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + // source: %arg1, offset = [%arg3, 0], size = [4, 
256], stride = [1, 5] + %13 = tt.load %12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<4x256xbf16> + %14 = arith.addf %10, %13 : tensor<4x256xbf16> + %15 = tt.splat %arg2 : (!tt.ptr) -> tensor<4x256x!tt.ptr> + %16 = tt.addptr %15, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + // source: %arg2, offset = [%arg3, 0], size = [4, 256], stride = [1, 5] + tt.store %16, %14 : tensor<4x256xbf16> + tt.return + } +} + +// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: !tt.ptr, [[PARAM_3_:%.+]]: i32) { +// CHECK-DAG: [[CST_5_:%.+]] = arith.constant 5 : index +// CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index +// CHECK: [[VAR_1_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [4, 256], strides: [1, [[CST_5_]]{{.}}, offsets: {{.}}[[VAR_0_]], 0], parent_sizes: [0, 0] : to tensor<4x256x!tt.ptr> +// CHECK-DAG: [[VAR_2_:%.+]] = "tts.load"([[VAR_1_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<4x256x!tt.ptr>) -> tensor<4x256xbf16> +// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index +// CHECK: [[VAR_4_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [4, 256], strides: [1, [[CST_5_]]{{.}}, offsets: {{.}}[[VAR_3_]], 0], parent_sizes: [0, 0] : to tensor<4x256x!tt.ptr> +// CHECK: [[VAR_5_:%.+]] = "tts.load"([[VAR_4_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<4x256x!tt.ptr>) -> tensor<4x256xbf16> +// CHECK-DAG: [[VAR_6_:%.+]] = arith.addf [[VAR_2_]], [[VAR_5_]] : tensor<4x256xbf16> +// CHECK-DAG: [[VAR_7_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index +// CHECK: [[VAR_8_:%.+]] = tts.make_tptr [[PARAM_2_]] to sizes: [4, 256], strides: [1, [[CST_5_]]{{.}}, offsets: {{.}}[[VAR_7_]], 0], parent_sizes: [0, 0] : to tensor<4x256x!tt.ptr> +// CHECK: "tts.store"([[VAR_8_]], [[VAR_6_]]) <{static_dims = array}> : (tensor<4x256x!tt.ptr>, tensor<4x256xbf16>) -> () +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/addptr_add_value.mlir b/test/Conversion/TritonToStructured/addptr_add_value.mlir new file mode 100644 index 00000000..9242fe10 --- /dev/null +++ b/test/Conversion/TritonToStructured/addptr_add_value.mlir @@ -0,0 +1,67 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr, + %arg2 : i32, + %arg3 : i32 + ) + { + %0 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32> + // offset = 0, size = 4, stride = 1 + %1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<4xi32>) -> tensor<4x1xi32> + // offset = [0,0], size = [4,1], stride = [1,0] + %2 = tt.broadcast %1 : (tensor<4x1xi32>) -> tensor<4x256xi32> + // offset = [0,0], size = [4,256], stride = [1,0] + %arg2splat = tt.splat %arg2 : (i32) -> tensor<4x256xi32> + %offset2 = arith.addi %2, %arg2splat : tensor<4x256xi32> + // offset = [%arg2,0], size = [4,256], stride = [1,0] + %arg3splat = tt.splat %arg3 : (i32) -> tensor<4x256xi32> + %offset3 = arith.addi %offset2, %arg3splat : tensor<4x256xi32> + // offset = [%arg2+%arg3,0], size = [4,256], stride = [1,0] + %c10 = arith.constant 10 : i32 + %c10splat = tt.splat %c10 : (i32) -> tensor<4x256xi32> + %offset4 = arith.addi %offset3, %c10splat : tensor<4x256xi32> + // offset = [%arg2+%arg3+10,0], size = [4,256], stride = [1,0] + %3 = tt.make_range {end = 256 : i32, start = 0 : i32}:tensor<256xi32> + // offset = 0, size = 256, stride = 1 + %4 = tt.expand_dims %3 {axis = 0 : i32} : (tensor<256xi32>) -> 
tensor<1x256xi32> + // offset = [0,0], size = [1,256], stride = [0,1] + %5 = tt.broadcast %4 : (tensor<1x256xi32>) -> tensor<4x256xi32> + // offset = [0,0], size = [4,256], stride = [0,1] + %c6 = arith.constant 6 : i32 + %splat6 = tt.splat %c6 : (i32) -> tensor<4x256xi32> + %scale5 = arith.muli %5, %splat6 : tensor<4x256xi32> + // offset = [0,0], size = [4,256], stride = [0,6] + %7 = arith.addi %offset4, %scale5: tensor<4x256xi32> + // offset = [%arg2+%arg3+10, 0], size = [4, 256], stride = [1, 6] + %8 = tt.splat %arg0 : (!tt.ptr) -> tensor<4x256x!tt.ptr> + %9 = tt.addptr %8, %7 : tensor<4x256x!tt.ptr>,tensor<4x256xi32> + // source = %arg0, offset = [%arg2+%arg3+10, 0], size = [4, 256], stride = [1, 6] + %10 = tt.splat %arg1 : (!tt.ptr) -> tensor<4x256x!tt.ptr> + %11 = tt.addptr %10, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + // source = %arg1, offset = [%arg2+%arg3+10, 0], size = [4, 256], stride = [1, 6] + %12 = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<4x256xbf16> + tt.store %11, %12 : tensor<4x256xbf16> + tt.return + } +} + +// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32) { +// CHECK-DAG: [[CST_6_:%.+]] = arith.constant 6 : index +// CHECK-DAG: [[CST_10_:%.+]] = arith.constant 10 : index +// CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index +// CHECK-DAG: [[VAR_1_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index +// CHECK: [[VAR_2_:%.+]] = arith.addi [[VAR_0_]], [[VAR_1_]] : index +// CHECK: [[VAR_3_:%.+]] = arith.addi [[VAR_2_]], [[CST_10_]] : index +// CHECK-DAG: [[VAR_4_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [4, 256], strides: [1, [[CST_6_]]{{.}}, offsets: {{.}}[[VAR_3_]], 0], parent_sizes: [0, 0] : to tensor<4x256x!tt.ptr> +// CHECK-DAG: [[VAR_5_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index +// CHECK-DAG: [[VAR_6_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index +// CHECK: [[VAR_7_:%.+]] = arith.addi [[VAR_5_]], [[VAR_6_]] : index +// CHECK: [[VAR_8_:%.+]] = arith.addi [[VAR_7_]], [[CST_10_]] : index +// CHECK-DAG: [[VAR_9_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [4, 256], strides: [1, [[CST_6_]]{{.}}, offsets: {{.}}[[VAR_8_]], 0], parent_sizes: [0, 0] : to tensor<4x256x!tt.ptr> +// CHECK-DAG: [[VAR_10_:%.+]] = "tts.load"([[VAR_4_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<4x256x!tt.ptr>) -> tensor<4x256xbf16> +// CHECK: "tts.store"([[VAR_9_]], [[VAR_10_]]) <{static_dims = array}> : (tensor<4x256x!tt.ptr>, tensor<4x256xbf16>) -> () +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/addptr_dim1.mlir b/test/Conversion/TritonToStructured/addptr_dim1.mlir new file mode 100644 index 00000000..fd865932 --- /dev/null +++ b/test/Conversion/TritonToStructured/addptr_dim1.mlir @@ -0,0 +1,100 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : i32 + ) + { + %0 = tt.make_range {end = 256 : i32, start = 0 : i32}:tensor<256xi32> + %1 = tt.expand_dims %0 {axis = 0 : i32} : (tensor<256xi32>) -> tensor<1x256xi32> + + %splat_arg0 = tt.splat %arg0 : (!tt.ptr) -> tensor<1x256x!tt.ptr> + %2 = tt.addptr %splat_arg0, %1 : tensor<1x256x!tt.ptr>, tensor<1x256xi32> + // source = %arg0, offset = [0, 0], size = [1, 256], stride = [0, 1] + + // 1x256 pointer should have meaningful stride in outer dimension + %3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: 
tensor<1x256xbf16> + + %4 = tt.splat %arg1 : (i32) -> tensor<1x256xi32> + // 1x256 pointer should have meaningful stride in outer dimension + %5 = tt.addptr %2, %4 : tensor<1x256x!tt.ptr>, tensor<1x256xi32> + // source = %arg0, offset = [%arg1, 0], size = [1, 256], stride = [0, 1] + + tt.store %5, %3 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1x256x!tt.ptr>, tensor<1x256xbf16> + + %10 = arith.constant 0.0 : bf16 + %11 = tt.splat %10 : (bf16) -> tensor<4x256xbf16> + + %c0 = arith.constant 0 : index + %c12 = arith.constant 12 : index + %c3 = arith.constant 3 : index + %i_c3 = arith.constant 3 : i32 + %c256 = arith.constant 256 : i32 + %sum_out, %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%sum_iter = %11, %ptr = %2) -> (tensor<4x256xbf16>, tensor<1x256x!tt.ptr>) { + %bptr = tt.broadcast %ptr : (tensor<1x256x!tt.ptr>) -> tensor<4x256x!tt.ptr> + // source = %arg0, offset = [0, 0], size = [4, 256], stride = [0, 1] + + %20 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32> + %i_i32 = arith.index_cast %i : index to i32 + %21 = arith.muli %c256, %i_i32 : i32 + %22 = tt.splat %21 : (i32) -> tensor<4xi32> + %23 = arith.muli %20, %22 : tensor<4xi32> + %24 = tt.expand_dims %23 {axis = 1 : i32} : (tensor<4xi32>) -> tensor<4x1xi32> + %25 = tt.broadcast %24 : (tensor<4x1xi32>) -> tensor<4x256xi32> + // offset = [0, 0], size = [4, 256], stride = [i*256, 1] + + // %bptr should have zero stride and %30 should have correct stride + %30 = tt.addptr %bptr, %25 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + // source = %arg0, offset = [0, 0], size = [4, 256], stride = [i*256, 1] + + %31 = tt.load %30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<4x256xbf16> + %32 = arith.addf %sum_iter, %31 : tensor<4x256xbf16> + + %40 = tt.splat %c256 : (i32) -> tensor<1x256xi32> + %41 = tt.addptr %ptr, %40 : tensor<1x256x!tt.ptr>, tensor<1x256xi32> + // source = %arg0, offset = [i*256, 0], size = [4, 256], stride = [i*256, 1] + + scf.yield %32, %41 : tensor<4x256xbf16>, tensor<1x256x!tt.ptr> + } + + %31 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32> + %splat_c256 = tt.splat %c256 : (i32) -> tensor<4xi32> + %32 = arith.muli %31, %splat_c256 : tensor<4xi32> + %33 = tt.expand_dims %32 {axis = 1 : i32} : (tensor<4xi32>) -> tensor<4x1xi32> + %34 = tt.broadcast %33 : (tensor<4x1xi32>) -> tensor<4x256xi32> + %35 = tt.broadcast %2 : (tensor<1x256x!tt.ptr>) -> tensor<4x256x!tt.ptr> + %36 = tt.addptr %35, %34 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + tt.store %36, %sum_out {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<4x256x!tt.ptr>, tensor<4x256xbf16> + tt.return + } +} + +// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: i32) { +// CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<0.000000e+00> : tensor<4x256xbf16> +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[CST_256_1_:%.+]] = arith.constant 256 : i32 +// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[VAR_0_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [1, 256], strides: [0, 1], offsets: [0, 0], parent_sizes: [0, 0] : to tensor<1x256x!tt.ptr> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_1_:%.+]] = "tts.load"([[VAR_0_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<1x256x!tt.ptr>) -> tensor<1x256xbf16> +// 
CHECK-DAG: [[VAR_2_:%.+]] = arith.index_cast [[PARAM_1_]] : i32 to index +// CHECK: [[VAR_3_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [1, 256], strides: [0, 1], offsets: {{.}}[[VAR_2_]], 0], parent_sizes: [0, 0] : to tensor<1x256x!tt.ptr> +// CHECK: "tts.store"([[VAR_3_]], [[VAR_1_]]) <{static_dims = array}> : (tensor<1x256x!tt.ptr>, tensor<1x256xbf16>) -> () +// CHECK-DAG: [[VAR_4_:%.+]]:2 = scf.for [[VAR_arg2_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg3_:%.+]] = [[VAR_cst_]], [[VAR_arg4_:%.+]] = [[CST_0_]]) -> (tensor<4x256xbf16>, index) { +// CHECK-DAG: [[VAR_6_:%.+]] = arith.index_cast [[VAR_arg2_]] : index to i32 +// CHECK: [[VAR_7_:%.+]] = arith.muli [[VAR_6_]], [[CST_256_1_]] : i32 +// CHECK: [[VAR_8_:%.+]] = arith.index_cast [[VAR_7_]] : i32 to index +// CHECK: [[VAR_9_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [4, 256], strides: {{.}}[[VAR_8_]], [[CST_1_]]{{.}}, offsets: {{.}}[[VAR_arg4_]], [[CST_0_]]{{.}}, parent_sizes: [0, 0] : to tensor<4x256x!tt.ptr> +// CHECK: [[VAR_10_:%.+]] = "tts.load"([[VAR_9_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<4x256x!tt.ptr>) -> tensor<4x256xbf16> +// CHECK-DAG: [[VAR_11_:%.+]] = arith.addf [[VAR_arg3_]], [[VAR_10_]] : tensor<4x256xbf16> +// CHECK-DAG: [[VAR_12_:%.+]] = arith.addi [[VAR_arg4_]], [[CST_256_]] : index +// CHECK: scf.yield [[VAR_11_]], [[VAR_12_]] : tensor<4x256xbf16>, index +// CHECK: } +// CHECK: [[VAR_5_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [4, 256], strides: {{.}}[[CST_256_]], 1], offsets: [0, 0], parent_sizes: [0, 0] : to tensor<4x256x!tt.ptr> +// CHECK: "tts.store"([[VAR_5_]], [[VAR_4_]]#0) <{static_dims = array}> : (tensor<4x256x!tt.ptr>, tensor<4x256xbf16>) -> () +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/addptr_for_accumulation.mlir b/test/Conversion/TritonToStructured/addptr_for_accumulation.mlir new file mode 100644 index 00000000..25ca8aa2 --- /dev/null +++ b/test/Conversion/TritonToStructured/addptr_for_accumulation.mlir @@ -0,0 +1,85 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr, + %arg2 : !tt.ptr, + %arg3 : i32, + %arg4 : i32 + ) + { + %0 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32> + // offset = 0, size = 4, stride = 1 + %1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<4xi32>) -> tensor<4x1xi32> + // offset = [0,0], size = [4,1], stride = [1,0] + %2 = tt.broadcast %1 : (tensor<4x1xi32>) -> tensor<4x256xi32> + // offset = [0,0], size = [4,256], stride = [1,0] + %arg3splat = tt.splat %arg3 : (i32) -> tensor<4x256xi32> + %offset3 = arith.addi %2, %arg3splat : tensor<4x256xi32> + // offset = [%arg3,0], size = [4,256], stride = [1,0] + %3 = tt.make_range {end = 256 : i32, start = 0 : i32}:tensor<256xi32> + // offset = 0, size = 256, stride = 1 + %4 = tt.expand_dims %3 {axis = 0 : i32} : (tensor<256xi32>) -> tensor<1x256xi32> + // offset = [0,0], size = [1,256], stride = [0,1] + %5 = tt.broadcast %4 : (tensor<1x256xi32>) -> tensor<4x256xi32> + // offset = [0,0], size = [4,256], stride = [0,1] + %c5 = arith.constant 5 : i32 + %splat6 = tt.splat %c5 : (i32) -> tensor<4x256xi32> + // scalar = 5 + %scale5 = arith.muli %5, %splat6 : tensor<4x256xi32> // Why we never called the conversion function for the inputs here? 
+ // offset = [0,0], size = [4,256], stride = [0,5] + %7 = arith.addi %offset3, %scale5: tensor<4x256xi32> // Why we never called the conversion function for the inputs here? + // offset = [%arg3, 0], size = [4, 256], stride = [1, 5] + %8 = tt.splat %arg0 : (!tt.ptr) -> tensor<4x256x!tt.ptr> // Why is the input unknown + %9 = tt.addptr %8, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + // source: %arg0, offset = [%arg3, 0], size = [4, 256], stride = [1, 5] + %19 = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<4x256xbf16> // this will be replaced with a memref.copy + %11 = tt.splat %arg1 : (!tt.ptr) -> tensor<4x256x!tt.ptr> + %12 = tt.addptr %11, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + // source: %arg1, offset = [%arg3, 0], size = [4, 256], stride = [1, 5] + %c0 = arith.constant 0 : index + %c12 = arith.constant 12 : index + %c3 = arith.constant 3 : index + %i_c3 = arith.constant 3 : i32 + %sum_out, %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%sum_iter = %19, %ptr_iter = %12) -> (tensor<4x256xbf16>, tensor<4x256x!tt.ptr>) { + %20 = tt.load %ptr_iter {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<4x256xbf16> + %sum = arith.addf %sum_iter, %20 : tensor<4x256xbf16> + // pointer updates + %17 = tt.splat %i_c3 : (i32) -> tensor<4x256xi32> + // offset: [3, 0], size = [4, 256], stride [0, 0] + %ptr = tt.addptr %ptr_iter, %17 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + // source: %arg1, offset = [%arg3+%i, 0], size = [4, 256], stride = [1, 5] + scf.yield %sum, %ptr : tensor<4x256xbf16>, tensor<4x256x!tt.ptr> + } + %15 = tt.splat %arg2 : (!tt.ptr) -> tensor<4x256x!tt.ptr> + %16 = tt.addptr %15, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + // source: %arg2, offset = [%arg3, 0], size = [4, 256], stride = [1, 5] + tt.store %16, %sum_out : tensor<4x256xbf16> + tt.return + } +} + +// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: !tt.ptr, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32) { +// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[CST_5_:%.+]] = arith.constant 5 : index +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index +// CHECK: [[VAR_1_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [4, 256], strides: [1, [[CST_5_]]{{.}}, offsets: {{.}}[[VAR_0_]], 0], parent_sizes: [0, 0] : to tensor<4x256x!tt.ptr> +// CHECK-DAG: [[VAR_2_:%.+]] = "tts.load"([[VAR_1_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<4x256x!tt.ptr>) -> tensor<4x256xbf16> +// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_4_:%.+]]:2 = scf.for [[VAR_arg5_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg6_:%.+]] = [[VAR_2_]], [[VAR_arg7_:%.+]] = [[VAR_3_]]) -> (tensor<4x256xbf16>, index) { +// CHECK-DAG: [[VAR_7_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [4, 256], strides: {{.}}[[CST_1_]], [[CST_5_]]{{.}}, offsets: {{.}}[[VAR_arg7_]], [[CST_0_]]{{.}}, parent_sizes: [0, 0] : to tensor<4x256x!tt.ptr> +// CHECK: [[VAR_8_:%.+]] = "tts.load"([[VAR_7_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<4x256x!tt.ptr>) -> tensor<4x256xbf16> +// CHECK-DAG: [[VAR_9_:%.+]] = arith.addf [[VAR_arg6_]], [[VAR_8_]] : tensor<4x256xbf16> +// CHECK-DAG: [[VAR_10_:%.+]] = arith.addi 
[[VAR_arg7_]], [[CST_3_]] : index +// CHECK: scf.yield [[VAR_9_]], [[VAR_10_]] : tensor<4x256xbf16>, index +// CHECK: } +// CHECK: [[VAR_5_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index +// CHECK: [[VAR_6_:%.+]] = tts.make_tptr [[PARAM_2_]] to sizes: [4, 256], strides: [1, [[CST_5_]]{{.}}, offsets: {{.}}[[VAR_5_]], 0], parent_sizes: [0, 0] : to tensor<4x256x!tt.ptr> +// CHECK: "tts.store"([[VAR_6_]], [[VAR_4_]]#0) <{static_dims = array}> : (tensor<4x256x!tt.ptr>, tensor<4x256xbf16>) -> () +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/addptr_for_expand_ptr.mlir b/test/Conversion/TritonToStructured/addptr_for_expand_ptr.mlir new file mode 100644 index 00000000..43918a4a --- /dev/null +++ b/test/Conversion/TritonToStructured/addptr_for_expand_ptr.mlir @@ -0,0 +1,71 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func @kernel( + %arg0 : !tt.ptr + ) + { + %c0 = arith.constant 0 : index + %c12 = arith.constant 12 : index + %c3 = arith.constant 3 : index + %i_c3 = arith.constant 3 : i32 + %0 = tt.splat %arg0 : (!tt.ptr) -> tensor<256x!tt.ptr> + %1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32> + // source: null, sizes: 256, offsets: 1024, strides: 1 + + %2 = tt.addptr %0, %1 : tensor<256x!tt.ptr>, tensor<256xi32> + // source: arg0, sizes: 256, offsets: 1024, strides: 1 + + // gep operand is another gep' output, which is passed into the loop as varible, used after update + %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %2) -> (tensor<256x!tt.ptr>) { + %6 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %7 = tt.expand_dims %6 {axis = 1 : i32} : (tensor<256xi32>) -> tensor<256x1xi32> + + %8 = tt.broadcast %7 : (tensor<256x1xi32>) -> tensor<256x256xi32> + // sizes: [256, 256], offsets: [0, 0], strides: [1, 0] + + %9 = tt.make_range {end = 512 : i32, start = 256 : i32} : tensor<256xi32> + %10 = tt.expand_dims %9 {axis = 0 : i32} : (tensor<256xi32>) -> tensor<1x256xi32> + + %11 = tt.broadcast %10 : (tensor<1x256xi32>) -> tensor<256x256xi32> + // sizes: [256, 256], offsets: [0, 256], strides: [0, 1] + + %12 = arith.addi %8, %11 : tensor<256x256xi32> + // sizes: [256, 256], offsets: [0, 256], strides: [1, 1] + + %13 = tt.expand_dims %ptr {axis = 1 : i32} : (tensor<256x!tt.ptr>) -> tensor<256x1x!tt.ptr> + %14 = tt.broadcast %13 : (tensor<256x1x!tt.ptr>) -> tensor<256x256x!tt.ptr> + + %15 = tt.addptr %14, %12 : tensor<256x256x!tt.ptr>, tensor<256x256xi32> + // source: arg0, sizes: [256, 256], offsets: [1024 + i*3, 256], strides: [2, 1] + + // perform load + %16 = tt.load %15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x256xbf16> + tt.store %15, %16 : tensor<256x256xbf16> + // pointer updates + %17 = tt.splat %i_c3 : (i32) -> tensor<256xi32> + // sizes: 256, offsets: 3, strides: 0 + %ptr_iter = tt.addptr %ptr, %17 : tensor<256x!tt.ptr>, tensor<256xi32> + // source: arg0, sizes: 256, offsets: 1024 + i*3, strides: 4 + scf.yield %ptr_iter : tensor<256x!tt.ptr> + } + tt.return + } +} + +// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr) { +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[CST_1024_:%.+]] = arith.constant 1024 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_0_:%.+]] = scf.for 
[[VAR_arg1_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg2_:%.+]] = [[CST_1024_]]) -> (index) { +// CHECK-DAG: [[VAR_1_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [256, 256], strides: {{.}}[[CST_2_]], 1], offsets: {{.}}[[VAR_arg2_]], 256], parent_sizes: [0, 0] : to tensor<256x256x!tt.ptr> +// CHECK: [[VAR_2_:%.+]] = "tts.load"([[VAR_1_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<256x256x!tt.ptr>) -> tensor<256x256xbf16> +// CHECK: "tts.store"([[VAR_1_]], [[VAR_2_]]) <{static_dims = array}> : (tensor<256x256x!tt.ptr>, tensor<256x256xbf16>) -> () +// CHECK: [[VAR_3_:%.+]] = arith.addi [[VAR_arg2_]], [[CST_3_]] : index +// CHECK: scf.yield [[VAR_3_]] : index +// CHECK: } +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/addptr_for_more_init_args.mlir b/test/Conversion/TritonToStructured/addptr_for_more_init_args.mlir new file mode 100644 index 00000000..2fce23d6 --- /dev/null +++ b/test/Conversion/TritonToStructured/addptr_for_more_init_args.mlir @@ -0,0 +1,69 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr + ) + { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c12 = arith.constant 12 : index + %0 = tt.splat %arg0 : (!tt.ptr) -> tensor<256x!tt.ptr> + %1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32> + // source: null, sizes: 256, offsets: 1024, strides: 1 + %2 = tt.addptr %0, %1 : tensor<256x!tt.ptr>, tensor<256xi32> + // source: arg0, sizes: 256, offsets: 1024, strides: 1 + %3 = tt.splat %arg1 : (!tt.ptr) -> tensor<256x!tt.ptr> + %4 = tt.addptr %3, %1 : tensor<256x!tt.ptr>, tensor<256xi32> + // source: arg1, sizes: 256, offsets: 1024, strides: 1 + %_arg2, %_ptr_ld, %_arg3, %_ptr_st, %_arg4 = scf.for %i = %c0 to %c12 step %c3 iter_args(%arg2 = %c1, %ptr_ld = %2, %arg3 = %c2, %ptr_st = %4, %arg4 = %c3) -> (index, tensor<256x!tt.ptr>, index, tensor<256x!tt.ptr>, index) { + // perform load + %5 = tt.load %ptr_ld {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256xbf16> + tt.store %ptr_st, %5 : tensor<256xbf16> + // pointer updates + %cast3 = arith.index_cast %c3 : index to i32 + %6 = tt.splat %cast3 : (i32) -> tensor<256xi32> + %ptr_ld_iter = tt.addptr %ptr_ld, %6 : tensor<256x!tt.ptr>, tensor<256xi32> + // source: arg0, sizes: 256, offsets: 1024 + i*3, strides: 1 + %arg2_iter = arith.addi %arg2, %c3 : index + %arg3_iter = arith.addi %arg3, %c3 : index + %arg4_iter = arith.addi %arg4, %c3 : index + %7 = arith.addi %arg2_iter, %arg3_iter : index + %8 = arith.addi %7, %arg4_iter : index + %cast8 = arith.index_cast %8 : index to i32 + %9 = tt.splat %cast8 : (i32) -> tensor<256xi32> + %ptr_st_iter = tt.addptr %ptr_st, %9 : tensor<256x!tt.ptr>, tensor<256xi32> + // source: arg1, sizes: 256, offsets: 1024 + loop-carry variable*i, strides: 1 + scf.yield %arg2_iter, %ptr_ld_iter, %arg3_iter, %ptr_st_iter, %arg4_iter : index, tensor<256x!tt.ptr>, index, tensor<256x!tt.ptr>, index + } + tt.return + } +} + +// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr) { +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index +// 
CHECK-DAG: [[CST_1024_:%.+]] = arith.constant 1024 : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_0_:%.+]]:5 = scf.for [[VAR_arg2_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg3_:%.+]] = [[CST_1_]], [[VAR_arg4_:%.+]] = [[CST_2_]], [[VAR_arg5_:%.+]] = [[CST_3_]], [[VAR_arg6_:%.+]] = [[CST_1_]]024, [[VAR_arg7_:%.+]] = [[CST_1_]]024) -> (index, index, index, index, index) { +// CHECK-DAG: [[VAR_1_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [256], strides: {{.}}[[CST_1_]]{{.}}, offsets: {{.}}[[VAR_arg7_]]{{.}}, parent_sizes: [0] : to tensor<256x!tt.ptr> +// CHECK-DAG: [[VAR_2_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [256], strides: {{.}}[[CST_1_]]{{.}}, offsets: {{.}}[[VAR_arg6_]]{{.}}, parent_sizes: [0] : to tensor<256x!tt.ptr> +// CHECK: [[VAR_3_:%.+]] = "tts.load"([[VAR_2_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<256x!tt.ptr>) -> tensor<256xbf16> +// CHECK: "tts.store"([[VAR_1_]], [[VAR_3_]]) <{static_dims = array}> : (tensor<256x!tt.ptr>, tensor<256xbf16>) -> () +// CHECK-DAG: [[VAR_4_:%.+]] = arith.addi [[VAR_arg6_]], [[CST_3_]] : index +// CHECK-DAG: [[VAR_5_:%.+]] = arith.addi [[VAR_arg3_]], [[CST_3_]] : index +// CHECK-DAG: [[VAR_6_:%.+]] = arith.addi [[VAR_arg4_]], [[CST_3_]] : index +// CHECK-DAG: [[VAR_7_:%.+]] = arith.addi [[VAR_arg5_]], [[CST_3_]] : index +// CHECK: [[VAR_8_:%.+]] = arith.addi [[VAR_5_]], [[VAR_6_]] : index +// CHECK: [[VAR_9_:%.+]] = arith.addi [[VAR_8_]], [[VAR_7_]] : index +// CHECK: [[VAR_10_:%.+]] = arith.addi [[VAR_arg7_]], [[VAR_9_]] : index +// CHECK: scf.yield [[VAR_5_]], [[VAR_6_]], [[VAR_7_]], [[VAR_4_]], [[VAR_10_]] : index, index, index, index, index +// CHECK: } +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/addptr_for_used_after_update.mlir b/test/Conversion/TritonToStructured/addptr_for_used_after_update.mlir new file mode 100644 index 00000000..f3b2f8b3 --- /dev/null +++ b/test/Conversion/TritonToStructured/addptr_for_used_after_update.mlir @@ -0,0 +1,98 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func @kernel( + %arg0 : !tt.ptr + ) + { + %c0 = arith.constant 0 : index + %c12 = arith.constant 12 : index + %c3 = arith.constant 3 : index + %i_c3 = arith.constant 3 : i32 + %0 = tt.splat %arg0 : (!tt.ptr) -> tensor<256x!tt.ptr> + %1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32> + // source: null, sizes: 256, offsets: 1024, strides: 1 + %2 = tt.addptr %0, %1 : tensor<256x!tt.ptr>, tensor<256xi32> + // source: arg0, sizes: 256, offsets: 1024, strides: 1 + // gep operand is another gep' output, which is passed into the loop as varible, used after update + %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %2) -> (tensor<256x!tt.ptr>) { + // pointer updates + %4 = tt.splat %i_c3 : (i32) -> tensor<256xi32> + // sizes: 256, offsets: 3, strides: 0 + %ptr_iter = tt.addptr %ptr, %4 : tensor<256x!tt.ptr>, tensor<256xi32> + // source: arg0, sizes: 256, offsets: 1024 + i, strides: 1 + // perform load + %3 = tt.load %ptr_iter {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256xbf16> + tt.store %ptr_iter, %3 : tensor<256xbf16> + scf.yield %ptr_iter : tensor<256x!tt.ptr> + } + // Expected output + // %offset_dim0 = arith.constant 1024 <- insert instructions to initialize init arg(new) + // for iter_args (%offset_dim0_iter = %offset_dim0) { <- replace varibles passed in as init arg (new) + // %4 = %offset_dim0_iter + %c3 <- 
replace gep of splat with add (already done) + // %subview = memref.subview %arg0, [%4][256][4] : memref<> -> memref<> <- generate subview on getelementptr (already done) + // ... + // scf.yield %4 <- replace yielding an gep output with the corresponding dim variable (new) + // } + // TODO: examples below are not supported since scf.for does not support returning a tensor type + // Example 3, gep operand is a vector of i32, which is passed into the loop as variable, pointer updated using step, used after update + //%_ptr3 = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %1) -> (tensor<256xi32>) { + // // offset update + // %3 = tt.splat %c3 : (i32) -> tensor<256xi32> + // %ptr_iter = arith.addi %3, %ptr : tensor<256xi32> + // // generate pointer + // %gep_ptr = tt.addptr %0, %ptr_iter : tensor<256x!tt.ptr> + // // perform load + // %4 = tt.load %gep_ptr {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256xbf16> + // tt.store %gep_ptr, %4 : tensor<256xbf16> + // scf.yield %ptr_iter : tensor<256xi32> + //} + // Expected output + // %offset_dim0 = arith.constant 1024 <- insert instructions to initialize init arg(new) + // for iter_args (%offset_dim0_iter = %offset_dim0) { <- replace varibles passed in as init arg (new) + // %4 = %offset_dim0_iter + %c3 <- replace gep of splat with add (already done) + // %subview = memref.subview %arg0, [%offset_dim0_iter][256][4] : memref<> -> memref<> <- generate subview on load (new) + // ... + // scf.yield %4 <- replace yielding an gep output with the corresponding dim variable (new) + // } + //// Example 4, gep operand is a vector of i32, which is passed into the loop as variable, pointer updated using step, used before update + //%_ptr4 = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %1) -> (tensor<256xi32>) { + // // generate pointer + // %gep_ptr = tt.addptr %0, %ptr : tensor<256x!tt.ptr> + // + // // perform load + // %4 = tt.load %gep_ptr {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256xbf16> + // tt.store %gep_ptr, %4 : tensor<256xbf16> + // // offset update + // %3 = tt.splat %c3 : (i32) -> tensor<256xi32> + // %ptr_iter = arith.addi %3, %ptr : tensor<256xi32> + // scf.yield %ptr_iter : tensor<256xi32> + //} + // Expected output + // %offset_dim0 = arith.constant 1024 <- insert instructions to initialize init arg(new) + // for iter_args (%offset_dim0_iter = %offset_dim0) { <- replace varibles passed in as init arg (new) + // %subview = memref.subview %arg0, [%offset_dim0_iter][256][4] : memref<> -> memref<> <- generate subview on load (new) + // ... 
+ // %4 = %offset_dim0_iter + %c3 <- replace gep of splat with add (already done) + // scf.yield %4 <- replace yielding an gep output with the corresponding dim variable (new) + // } + tt.return + } +} + +// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr) { +// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[CST_1024_:%.+]] = arith.constant 1024 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_0_:%.+]] = scf.for [[VAR_arg1_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg2_:%.+]] = [[CST_1024_]]) -> (index) { +// CHECK-DAG: [[VAR_1_:%.+]] = arith.addi [[VAR_arg2_]], [[CST_3_]] : index +// CHECK: [[VAR_2_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [256], strides: {{.}}[[CST_1_]]{{.}}, offsets: {{.}}[[VAR_1_]]{{.}}, parent_sizes: [0] : to tensor<256x!tt.ptr> +// CHECK: [[VAR_3_:%.+]] = "tts.load"([[VAR_2_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<256x!tt.ptr>) -> tensor<256xbf16> +// CHECK: "tts.store"([[VAR_2_]], [[VAR_3_]]) <{static_dims = array}> : (tensor<256x!tt.ptr>, tensor<256xbf16>) -> () +// CHECK: scf.yield [[VAR_1_]] : index +// CHECK: } +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/addptr_for_used_before_update.mlir b/test/Conversion/TritonToStructured/addptr_for_used_before_update.mlir new file mode 100644 index 00000000..cec2b69c --- /dev/null +++ b/test/Conversion/TritonToStructured/addptr_for_used_before_update.mlir @@ -0,0 +1,54 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func @kernel( + %arg0 : !tt.ptr + ) + { + %c0 = arith.constant 0 : index + %c12 = arith.constant 12 : index + %c3 = arith.constant 3 : index + %i_c3 = arith.constant 3 : i32 + %0 = tt.splat %arg0 : (!tt.ptr) -> tensor<256x!tt.ptr> + %1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32> + // source: null, sizes: 256, offsets: 1024, strides: 1 + %2 = tt.addptr %0, %1 : tensor<256x!tt.ptr>, tensor<256xi32> + // source: arg0, sizes: 256, offsets: 1024, strides: 1 + // Example 2, gep operand is another gep's output, which is passed into the loop as varible, used before update + %_ptr2 = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %2) -> (tensor<256x!tt.ptr>) { + // perform load + %3 = tt.load %ptr {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256xbf16> + tt.store %ptr, %3 : tensor<256xbf16> + // pointer updates + %4 = tt.splat %i_c3 : (i32) -> tensor<256xi32> + %ptr_iter = tt.addptr %ptr, %4 : tensor<256x!tt.ptr>, tensor<256xi32> + scf.yield %ptr_iter : tensor<256x!tt.ptr> + } + // Expected output + // %offset_dim0 = arith.constant 1024 <- insert instructions to initialize init arg(new) + // for iter_args (%offset_dim0_iter = %offset_dim0) { <- replace varibles passed in as init arg (new) + // %subview = memref.subview %arg0, [%offset_dim0_iter][256][4] : memref<> -> memref<> <- generate subview on load (new) + // ... 
+ // %4 = %offset_dim0_iter + %c3 <- replace gep of splat with add (already done) + // scf.yield %4 <- replace yielding an gep output with the corresponding dim variable (new) + // } + tt.return + } +} + +// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr) { +// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[CST_1024_:%.+]] = arith.constant 1024 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_0_:%.+]] = scf.for [[VAR_arg1_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg2_:%.+]] = [[CST_1024_]]) -> (index) { +// CHECK-DAG: [[VAR_1_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [256], strides: {{.}}[[CST_1_]]{{.}}, offsets: {{.}}[[VAR_arg2_]]{{.}}, parent_sizes: [0] : to tensor<256x!tt.ptr> +// CHECK: [[VAR_2_:%.+]] = "tts.load"([[VAR_1_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<256x!tt.ptr>) -> tensor<256xbf16> +// CHECK: "tts.store"([[VAR_1_]], [[VAR_2_]]) <{static_dims = array}> : (tensor<256x!tt.ptr>, tensor<256xbf16>) -> () +// CHECK: [[VAR_3_:%.+]] = arith.addi [[VAR_arg2_]], [[CST_3_]] : index +// CHECK: scf.yield [[VAR_3_]] : index +// CHECK: } +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/addptr_loopback.mlir b/test/Conversion/TritonToStructured/addptr_loopback.mlir new file mode 100644 index 00000000..82f7be28 --- /dev/null +++ b/test/Conversion/TritonToStructured/addptr_loopback.mlir @@ -0,0 +1,54 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr, + %arg2 : i32 + ) + { + %0 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32> + // offset = 0, size = 4, stride = 1 + %1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<4xi32>) -> tensor<4x1xi32> + // offset = [0,0], size = [4,1], stride = [1,0] + %2 = tt.broadcast %1 : (tensor<4x1xi32>) -> tensor<4x256xi32> + // offset = [0,0], size = [4,256], stride = [1,0] + %arg2splat = tt.splat %arg2 : (i32) -> tensor<4x256xi32> + %offset2 = arith.addi %2, %arg2splat : tensor<4x256xi32> + // offset = [%arg2,0], size = [4,256], stride = [1,0] + %3 = tt.make_range {end = 256 : i32, start = 0 : i32}:tensor<256xi32> + // offset = 0, size = 256, stride = 1 + %4 = tt.expand_dims %3 {axis = 0 : i32} : (tensor<256xi32>) -> tensor<1x256xi32> + // offset = [0,0], size = [1,256], stride = [0,1] + %5 = tt.broadcast %4 : (tensor<1x256xi32>) -> tensor<4x256xi32> + // offset = [0,0], size = [4,256], stride = [0,1] + %c6 = arith.constant 6 : i32 + %splat6 = tt.splat %c6 : (i32) -> tensor<4x256xi32> + %scale5 = arith.muli %5, %splat6 : tensor<4x256xi32> + // offset = [0,0], size = [4,256], stride = [0,6] + %7 = arith.addi %offset2, %scale5: tensor<4x256xi32> + // offset = [%arg2, 0], size = [4, 256], stride = [1, 6] + %8 = tt.splat %arg0 : (!tt.ptr) -> tensor<4x256x!tt.ptr> + %9 = tt.addptr %8, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + // source: arg0, offset = [%arg2, 0], size = [4, 256], stride = [1, 6] + %10 = tt.splat %arg1 : (!tt.ptr) -> tensor<4x256x!tt.ptr> + %11 = tt.addptr %10, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + // source: arg1, offset = [%arg2, 0], size = [4, 256], stride = [1, 6] + %12 = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<4x256xbf16> + tt.store %11, %12 : 
tensor<4x256xbf16> + tt.return + } +} + +// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32) { +// CHECK-DAG: [[CST_6_:%.+]] = arith.constant 6 : index +// CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_1_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [4, 256], strides: [1, [[CST_6_]]{{.}}, offsets: {{.}}[[VAR_0_]], 0], parent_sizes: [0, 0] : to tensor<4x256x!tt.ptr> +// CHECK-DAG: [[VAR_2_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_3_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [4, 256], strides: [1, [[CST_6_]]{{.}}, offsets: {{.}}[[VAR_2_]], 0], parent_sizes: [0, 0] : to tensor<4x256x!tt.ptr> +// CHECK-DAG: [[VAR_4_:%.+]] = "tts.load"([[VAR_1_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<4x256x!tt.ptr>) -> tensor<4x256xbf16> +// CHECK: "tts.store"([[VAR_3_]], [[VAR_4_]]) <{static_dims = array}> : (tensor<4x256x!tt.ptr>, tensor<4x256xbf16>) -> () +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/addptr_mul_const_const.mlir b/test/Conversion/TritonToStructured/addptr_mul_const_const.mlir new file mode 100644 index 00000000..f9320289 --- /dev/null +++ b/test/Conversion/TritonToStructured/addptr_mul_const_const.mlir @@ -0,0 +1,50 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr, + %arg2 : i32 + ) + { + %0 = tt.get_program_id x : i32 + %1 = tt.make_range {end = 1024 : i32, start = 0 : i32}:tensor<1024xi32> + %2 = tt.splat %0 : (i32) -> tensor<1024xi32> + %3 = arith.addi %2, %1 : tensor<1024xi32> + //%3: splat(%0) + range(0, 1024) + //%3: offset = %0, size = 1024, stride = 1 + // vector and scalar are both constant + %4 = tt.make_range {end = 3072 : i32, start = 2048 : i32}:tensor<1024xi32> + %c10 = arith.constant 10 : i32 + %5 = tt.splat %c10 : (i32) -> tensor<1024xi32> + %6 = arith.muli %5, %4 : tensor<1024xi32> + //%6: splat(%c10)*range(2048, 4096); + //%6: offset = %c10*2048, size = 1024, stride = %c10*1 + %7 = arith.addi %3, %6 : tensor<1024xi32> + //%7: offset = %c10*2048 + %0, size = 1024, stride = %c10*1+1 + %8 = tt.splat %arg0 : (!tt.ptr) -> tensor<1024x!tt.ptr> + %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr>, tensor<1024xi32> + //source=%arg0 offset = %c10*2048 + pid0, size = 1024, stride = %c10*1+1 + %10 = tt.splat %arg1 : (!tt.ptr) -> tensor<1024x!tt.ptr> + %11 = tt.addptr %10, %3 : tensor<1024x!tt.ptr>, tensor<1024xi32> + //source=%arg1, offset = pid0, size = 1024, stride = 1 + %16 = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xbf16> + tt.store %11, %16 : tensor<1024xbf16> + tt.return + } +} + +// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32) { +// CHECK-DAG: [[CST_11_:%.+]] = arith.constant 11 : index +// CHECK-DAG: [[CST_20480_:%.+]] = arith.constant 20480 : index +// CHECK-DAG: [[VAR_0_:%.+]] = tt.get_program_id x : i32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_1_:%.+]] = arith.index_cast [[VAR_0_]] : i32 to index +// CHECK-DAG: [[VAR_2_:%.+]] = arith.index_cast [[VAR_0_]] : i32 to index +// CHECK: [[VAR_3_:%.+]] = arith.addi [[VAR_2_]], [[CST_20480_]] : index +// CHECK-DAG: [[VAR_4_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [1024], strides: {{.}}[[CST_11_]]{{.}}, offsets: 
{{.}}[[VAR_3_]]{{.}}, parent_sizes: [0] : to tensor<1024x!tt.ptr> +// CHECK-DAG: [[VAR_5_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [1024], strides: [1], offsets: {{.}}[[VAR_1_]]{{.}}, parent_sizes: [0] : to tensor<1024x!tt.ptr> +// CHECK: [[VAR_6_:%.+]] = "tts.load"([[VAR_4_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<1024x!tt.ptr>) -> tensor<1024xbf16> +// CHECK: "tts.store"([[VAR_5_]], [[VAR_6_]]) <{static_dims = array}> : (tensor<1024x!tt.ptr>, tensor<1024xbf16>) -> () +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/addptr_mul_value_const.mlir b/test/Conversion/TritonToStructured/addptr_mul_value_const.mlir new file mode 100644 index 00000000..1cdce25a --- /dev/null +++ b/test/Conversion/TritonToStructured/addptr_mul_value_const.mlir @@ -0,0 +1,53 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr, + %arg2 : i32 + ) + { + %0 = tt.get_program_id x : i32 + %1 = tt.make_range {end = 1024 : i32, start = 0 : i32}:tensor<1024xi32> + %2 = tt.splat %0 : (i32) -> tensor<1024xi32> + %3 = arith.addi %2, %1 : tensor<1024xi32> + //%3: splat(%0) + range(0, 1024) + //%3: offset = %0, size = 1024, stride = 1 + // vector is constant, scalar is value + %4 = tt.make_range {end = 3072 : i32, start = 2048 : i32}:tensor<1024xi32> + %5 = tt.splat %arg2 : (i32) -> tensor<1024xi32> + %6 = arith.muli %5, %4 : tensor<1024xi32> + //%6: splat(%arg2)*range(2048, 3072); + //%6: offset = %arg2*2048, size = 1024, stride = %arg2*1 + %7 = arith.addi %3, %6 : tensor<1024xi32> + //%7: offset = %arg2*2048 + %0, size = 1024, stride = %arg2*1+1 + %8 = tt.splat %arg0 : (!tt.ptr) -> tensor<1024x!tt.ptr> + %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr>, tensor<1024xi32> + //source=%arg0: offset = %arg2*2048 + pid0, size = 1024, stride = %arg2*1+1 + %10 = tt.splat %arg1 : (!tt.ptr) -> tensor<1024x!tt.ptr> + %11 = tt.addptr %10, %3 : tensor<1024x!tt.ptr>, tensor<1024xi32> + //source=%arg1: offset = pid0, size = 1024, stride = 1 + %16 = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xbf16> + tt.store %11, %16 : tensor<1024xbf16> + tt.return + } +} + +// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32) { +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[CST_2048_:%.+]] = arith.constant 2048 : index +// CHECK-DAG: [[VAR_0_:%.+]] = tt.get_program_id x : i32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_1_:%.+]] = arith.index_cast [[VAR_0_]] : i32 to index +// CHECK-DAG: [[VAR_2_:%.+]] = arith.index_cast [[VAR_0_]] : i32 to index +// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index +// CHECK: [[VAR_4_:%.+]] = arith.muli [[VAR_3_]], [[CST_2048_]] : index +// CHECK-DAG: [[VAR_5_:%.+]] = arith.addi [[VAR_2_]], [[VAR_4_]] : index +// CHECK-DAG: [[VAR_6_:%.+]] = arith.addi [[VAR_3_]], [[CST_1_]] : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_7_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [1024], strides: {{.}}[[VAR_6_]]{{.}}, offsets: {{.}}[[VAR_5_]]{{.}}, parent_sizes: [0] : to tensor<1024x!tt.ptr> +// CHECK-DAG: [[VAR_8_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [1024], strides: [1], offsets: {{.}}[[VAR_1_]]{{.}}, parent_sizes: [0] : to tensor<1024x!tt.ptr> +// CHECK: [[VAR_9_:%.+]] = "tts.load"([[VAR_7_]]) <{operandSegmentSizes = array, static_dims = array}> : 
(tensor<1024x!tt.ptr>) -> tensor<1024xbf16> +// CHECK: "tts.store"([[VAR_8_]], [[VAR_9_]]) <{static_dims = array}> : (tensor<1024x!tt.ptr>, tensor<1024xbf16>) -> () +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/addptr_nested.mlir b/test/Conversion/TritonToStructured/addptr_nested.mlir new file mode 100644 index 00000000..fecdbe28 --- /dev/null +++ b/test/Conversion/TritonToStructured/addptr_nested.mlir @@ -0,0 +1,62 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : i32 + ) + { + %0 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32> + // offset = 0, size = 4, stride = 1 + %1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<4xi32>) -> tensor<4x1xi32> + // offset = [0,0], size = [4,1], stride = [1,0] + %2 = tt.broadcast %1 : (tensor<4x1xi32>) -> tensor<4x256xi32> + // offset = [0,0], size = [4,256], stride = [1,0] + %arg1splat = tt.splat %arg1 : (i32) -> tensor<4x256xi32> + %offset3 = arith.addi %2, %arg1splat : tensor<4x256xi32> + // offset = [%arg1,0], size = [4,256], stride = [1,0] + %3 = tt.make_range {end = 256 : i32, start = 0 : i32}:tensor<256xi32> + // offset = 0, size = 256, stride = 1 + %4 = tt.expand_dims %3 {axis = 0 : i32} : (tensor<256xi32>) -> tensor<1x256xi32> + // offset = [0,0], size = [1,256], stride = [0,1] + %5 = tt.broadcast %4 : (tensor<1x256xi32>) -> tensor<4x256xi32> + // offset = [0,0], size = [4,256], stride = [0,1] + %6 = arith.constant 5 : i32 + %splat6 = tt.splat %6 : (i32) -> tensor<4x256xi32> + %scale5 = arith.muli %5, %splat6 : tensor<4x256xi32> + // offset = [0,0], size = [4,256], stride = [0,5] + %7 = arith.addi %offset3, %scale5: tensor<4x256xi32> + // offset = [%arg1, 0], size = [4, 256], stride = [1, 5] + %8 = tt.splat %arg0 : (!tt.ptr) -> tensor<4x256x!tt.ptr> + %9 = tt.addptr %8, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + // source: %arg0, offset = [%arg1, 0], size = [4, 256], stride = [1, 5] + %10 = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<4x256xbf16> + %12 = tt.addptr %9, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + // source: %arg0, offset = [%arg1+%arg1, 0], size = [4, 256], stride = [2, 10] + %13 = tt.load %12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<4x256xbf16> + %14 = arith.addf %10, %13 : tensor<4x256xbf16> + %16 = tt.addptr %12, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + // source: %arg0, offset = [%arg1+%arg1+%arg1, 0], size = [4, 256], stride = [3, 15] + tt.store %16, %14 : tensor<4x256xbf16> + tt.return + } +} + +// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: i32) { +// CHECK-DAG: [[CST_15_:%.+]] = arith.constant 15 : index +// CHECK-DAG: [[CST_10_:%.+]] = arith.constant 10 : index +// CHECK-DAG: [[CST_5_:%.+]] = arith.constant 5 : index +// CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_1_]] : i32 to index +// CHECK: [[VAR_1_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [4, 256], strides: [1, [[CST_5_]]{{.}}, offsets: {{.}}[[VAR_0_]], 0], parent_sizes: [0, 0] : to tensor<4x256x!tt.ptr> +// CHECK-DAG: [[VAR_2_:%.+]] = "tts.load"([[VAR_1_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<4x256x!tt.ptr>) -> tensor<4x256xbf16> +// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_1_]] : i32 to index +// CHECK: [[VAR_4_:%.+]] = arith.addi [[VAR_0_]], [[VAR_3_]] : index +// CHECK: [[VAR_5_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [4, 256], strides: [2, 
[[CST_10_]]{{.}}, offsets: {{.}}[[VAR_4_]], 0], parent_sizes: [0, 0] : to tensor<4x256x!tt.ptr> +// CHECK: [[VAR_6_:%.+]] = "tts.load"([[VAR_5_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<4x256x!tt.ptr>) -> tensor<4x256xbf16> +// CHECK-DAG: [[VAR_7_:%.+]] = arith.addf [[VAR_2_]], [[VAR_6_]] : tensor<4x256xbf16> +// CHECK-DAG: [[VAR_8_:%.+]] = arith.index_cast [[PARAM_1_]] : i32 to index +// CHECK: [[VAR_9_:%.+]] = arith.addi [[VAR_4_]], [[VAR_8_]] : index +// CHECK: [[VAR_10_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [4, 256], strides: [3, [[CST_15_]]{{.}}, offsets: {{.}}[[VAR_9_]], 0], parent_sizes: [0, 0] : to tensor<4x256x!tt.ptr> +// CHECK: "tts.store"([[VAR_10_]], [[VAR_7_]]) <{static_dims = array}> : (tensor<4x256x!tt.ptr>, tensor<4x256xbf16>) -> () +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/addptr_reshape_broadcast.mlir b/test/Conversion/TritonToStructured/addptr_reshape_broadcast.mlir new file mode 100644 index 00000000..a5cbd85b --- /dev/null +++ b/test/Conversion/TritonToStructured/addptr_reshape_broadcast.mlir @@ -0,0 +1,43 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +// TODO: expand this example to 3D +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr + ) + { + %0 = tt.make_range {end = 768 : i32, start = 512 : i32}:tensor<256xi32> + // offset = [512] size = 256, stride = 1 + %1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<256xi32>) -> tensor<256x1xi32> + // offset = [512,0], size = [256,1], stride = [1,0] + %2 = tt.broadcast %1 : (tensor<256x1xi32>) -> tensor<256x128xi32> + // offset = [512,0], size = [256,128], stride = [1,0] + %5 = tt.make_range {end = 1152 : i32, start = 1024 : i32}:tensor<128xi32> + // offset = 1024, size = 128, stride = 1 + %6 = tt.expand_dims %5 {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32> + // offset = [0,1024], size = [1,128], stride = [0,1] + %7 = tt.broadcast %6 : (tensor<1x128xi32>) -> tensor<256x128xi32> + // offset = [0,1024], size = [256,128], stride = [0,1] + %c6 = arith.constant 6 : i32 + %splat6 = tt.splat %c6 : (i32) -> tensor<256x128xi32> + %scale7 = arith.muli %7, %splat6 : tensor<256x128xi32> + // offset = [0,6144], size = [256,128], stride = [0,6] + %14 = arith.addi %2, %scale7 : tensor<256x128xi32> + // offset = [512,6144], size = [256,128], stride = [1,6] + %17 = tt.splat %arg1 : (!tt.ptr) -> tensor<256x128x!tt.ptr> + %18 = tt.addptr %17, %14 : tensor<256x128x!tt.ptr>, tensor<256x128xi32> + %19 = tt.load %18 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x128xbf16> + tt.store %18, %19 : tensor<256x128xbf16> + tt.return + } +} + +// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr) { +// CHECK-DAG: [[CST_6144_:%.+]] = arith.constant 6144 : index +// CHECK-DAG: [[CST_6_:%.+]] = arith.constant 6 : index +// CHECK: [[VAR_0_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [256, 128], strides: [1, [[CST_6_]]{{.}}, offsets: [512, [[CST_6144_]]{{.}}, parent_sizes: [0, 0] : to tensor<256x128x!tt.ptr> +// CHECK: [[VAR_1_:%.+]] = "tts.load"([[VAR_0_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<256x128x!tt.ptr>) -> tensor<256x128xbf16> +// CHECK: "tts.store"([[VAR_0_]], [[VAR_1_]]) <{static_dims = array}> : (tensor<256x128x!tt.ptr>, tensor<256x128xbf16>) -> () +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/addptr_scalar_broadcast.mlir 
b/test/Conversion/TritonToStructured/addptr_scalar_broadcast.mlir new file mode 100644 index 00000000..113faefb --- /dev/null +++ b/test/Conversion/TritonToStructured/addptr_scalar_broadcast.mlir @@ -0,0 +1,61 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func @kernel (%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i32) { + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %arg2 : i32 + %2 = tt.addptr %arg1, %1 : !tt.ptr, i32 + // source = arg1, offset = %1, size = 1, strides = 0 + %3 = tt.splat %2 : (!tt.ptr) -> tensor<1024x!tt.ptr> + // source = arg1, offset = %1, size = 1024, strides = 0 + %4 = tt.expand_dims %3 {axis = 1 : i32} : (tensor<1024x!tt.ptr>) -> tensor<1024x1x!tt.ptr> + // source = arg1, offset = [%1, 0], size = [1024, 1], strides = [0, 0] + %5 = tt.broadcast %4 : (tensor<1024x1x!tt.ptr>) -> tensor<1024x1024x!tt.ptr> + // source = arg1, offset = [%1, 0], size = [1024, 1024], strides = [0, 0] + %6 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + // offset = 0, size = 1024, strides = 1 + %7 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<1024xi32>) -> tensor<1x1024xi32> + // offset = [0, 0], size = [1, 1024], strides = [0, 1] + %8 = tt.broadcast %7 : (tensor<1x1024xi32>) -> tensor<1024x1024xi32> + // offset = [0, 0], size = [1024, 1024], strides = [0, 1] + %9 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + // offset = 0, size = 1024, strides = 1 + %10 = tt.expand_dims %9 {axis = 1 : i32} : (tensor<1024xi32>) -> tensor<1024x1xi32> + // offset = [0, 0], size = [1024, 1], strides = [1, 0] + %11 = tt.broadcast %10 : (tensor<1024x1xi32>) -> tensor<1024x1024xi32> + // offset = [0, 0], size = [1024, 1024], strides = [1, 0] + %12 = arith.addi %8, %11 : tensor<1024x1024xi32> + // offset = [0, 0], size = [1024, 1024], strides = [1, 1] + %13 = tt.addptr %5, %12 : tensor<1024x1024x!tt.ptr>, tensor<1024x1024xi32> + // source = arg1, offset = [pid * %arg2, 0], size = [1024, 1024], strides = [1, 1] + %14 = tt.load %13 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024x1024xf32> + %17 = math.exp %14 : tensor<1024x1024xf32> + %18 = arith.muli %0, %arg3 : i32 + %19 = tt.addptr %arg0, %18 : !tt.ptr, i32 + // source = arg0, offset = pid+arg3, size = 1, strides = 0 + %20 = tt.splat %19 : (!tt.ptr) -> tensor<1024x!tt.ptr> + // source = arg0, offset = pid+arg3, size = 1024, strides = 0 + %21 = tt.expand_dims %20 {axis = 1 : i32} : (tensor<1024x!tt.ptr>) -> tensor<1024x1x!tt.ptr> + // source = arg0, offset = [pid+arg3, 0], size = [1024, 1], strides = [0, 0] + %22 = tt.broadcast %21 : (tensor<1024x1x!tt.ptr>) -> tensor<1024x1024x!tt.ptr> + // source = arg0, offset = [pid+arg3, 0], size = [1024, 1024], strides = [0, 0] + %23 = tt.addptr %22, %12 : tensor<1024x1024x!tt.ptr>, tensor<1024x1024xi32> + // source = arg0, offset = [pid+arg3, 0], size = [1024, 1024], strides = [1, 1] + tt.store %23, %17 : tensor<1024x1024xf32> + tt.return + } +} + +// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr {tt.divisibility = 16 : i32}, [[PARAM_1_:%.+]]: !tt.ptr {tt.divisibility = 16 : i32}, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32) { +// CHECK: [[VAR_0_:%.+]] = tt.get_program_id x : i32 +// CHECK: [[VAR_1_:%.+]] = arith.muli [[VAR_0_]], [[PARAM_2_]] : i32 +// CHECK: [[VAR_2_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index +// CHECK: [[VAR_3_:%.+]] = tts.make_tptr 
[[PARAM_1_]] to sizes: [1024, 1024], strides: [1, 1], offsets: {{.}}[[VAR_2_]], 0], parent_sizes: [0, 0] : to tensor<1024x1024x!tt.ptr> +// CHECK: [[VAR_4_:%.+]] = "tts.load"([[VAR_3_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<1024x1024x!tt.ptr>) -> tensor<1024x1024xf32> +// CHECK-DAG: [[VAR_5_:%.+]] = math.exp [[VAR_4_]] : tensor<1024x1024xf32> +// CHECK-DAG: [[VAR_6_:%.+]] = arith.muli [[VAR_0_]], [[PARAM_3_]] : i32 +// CHECK: [[VAR_7_:%.+]] = arith.index_cast [[VAR_6_]] : i32 to index +// CHECK: [[VAR_8_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [1024, 1024], strides: [1, 1], offsets: {{.}}[[VAR_7_]], 0], parent_sizes: [0, 0] : to tensor<1024x1024x!tt.ptr> +// CHECK: "tts.store"([[VAR_8_]], [[VAR_5_]]) <{static_dims = array}> : (tensor<1024x1024x!tt.ptr>, tensor<1024x1024xf32>) -> () +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/addptr_scalar_for.mlir b/test/Conversion/TritonToStructured/addptr_scalar_for.mlir new file mode 100644 index 00000000..e88ada11 --- /dev/null +++ b/test/Conversion/TritonToStructured/addptr_scalar_for.mlir @@ -0,0 +1,59 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func @kernel (%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i32) { + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %arg2 : i32 + %2 = tt.addptr %arg1, %1 : !tt.ptr, i32 + // source = %arg1, offset = %1, size = 1, strides = 0 + %cf0 = arith.constant 0.000000e+00 : f32 + %tensor_cf0 = tt.splat %cf0 : (f32) -> tensor<1024xf32> + %c0 = arith.constant 0 : index + %c12 = arith.constant 12 : index + %c3 = arith.constant 3 : index + %_ptr, %sum_out = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr_iter = %2, %sum_iter = %tensor_cf0) -> (!tt.ptr, tensor<1024xf32>) { + %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + // offset = 0, size = 1024, strides = 1 + %4 = tt.splat %ptr_iter : (!tt.ptr) -> tensor<1024x!tt.ptr> + // source = %arg1, offset = %1, size = 1024, strides = 0 + %5 = tt.addptr %4, %3 : tensor<1024x!tt.ptr>, tensor<1024xi32> + // source = %arg1, offset = %1, size = 1024, strides = 1 + %8 = tt.load %5 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32> + %9 = math.exp %8 : tensor<1024xf32> + %sum_next = arith.addf %sum_iter, %9 : tensor<1024xf32> + %cast_i = arith.index_cast %i : index to i32 + %ptr_next = tt.addptr %ptr_iter, %cast_i : !tt.ptr, i32 + // source = %arg1, offset = %1 + %i, size = 1, strides = 0 + scf.yield %ptr_next, %sum_next : !tt.ptr, tensor<1024xf32> + } + %10 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %18 = arith.muli %0, %arg3 : i32 + %19 = tt.addptr %arg0, %18 : !tt.ptr, i32 + %20 = tt.splat %19 : (!tt.ptr) -> tensor<1024x!tt.ptr> + %21 = tt.addptr %20, %10 : tensor<1024x!tt.ptr>, tensor<1024xi32> + tt.store %21, %sum_out : tensor<1024xf32> + tt.return + } +} +// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr {tt.divisibility = 16 : i32}, [[PARAM_1_:%.+]]: !tt.ptr {tt.divisibility = 16 : i32}, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32) { +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<0.000000e+00> : tensor<1024xf32> +// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[VAR_0_:%.+]] = tt.get_program_id x : i32 
+// CHECK: [[VAR_1_:%.+]] = arith.muli [[VAR_0_]], [[PARAM_2_]] : i32 +// CHECK: [[VAR_2_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index +// CHECK-DAG: [[VAR_3_:%.+]]:2 = scf.for [[VAR_arg5_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg6_:%.+]] = [[VAR_cst_]], [[VAR_arg7_:%.+]] = [[VAR_2_]]) -> (tensor<1024xf32>, index) { +// CHECK-DAG: [[VAR_7_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [1024], strides: [1], offsets: {{.}}[[VAR_arg7_]]{{.}}, parent_sizes: [0] : to tensor<1024x!tt.ptr> +// CHECK: [[VAR_8_:%.+]] = "tts.load"([[VAR_7_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<1024x!tt.ptr>) -> tensor<1024xf32> +// CHECK: [[VAR_9_:%.+]] = math.exp [[VAR_8_]] : tensor<1024xf32> +// CHECK-DAG: [[VAR_10_:%.+]] = arith.addf [[VAR_arg6_]], [[VAR_9_]] : tensor<1024xf32> +// CHECK-DAG: [[VAR_11_:%.+]] = arith.addi [[VAR_arg7_]], [[VAR_arg5_]] : index +// CHECK: scf.yield [[VAR_10_]], [[VAR_11_]] : tensor<1024xf32>, index +// CHECK: } +// CHECK: [[VAR_4_:%.+]] = arith.muli [[VAR_0_]], [[PARAM_3_]] : i32 +// CHECK: [[VAR_5_:%.+]] = arith.index_cast [[VAR_4_]] : i32 to index +// CHECK: [[VAR_6_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [1024], strides: [1], offsets: {{.}}[[VAR_5_]]{{.}}, parent_sizes: [0] : to tensor<1024x!tt.ptr> +// CHECK: "tts.store"([[VAR_6_]], [[VAR_3_]]#0) <{static_dims = array}> : (tensor<1024x!tt.ptr>, tensor<1024xf32>) -> () +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/addptr_scalar_for_2d.mlir b/test/Conversion/TritonToStructured/addptr_scalar_for_2d.mlir new file mode 100644 index 00000000..a44494b3 --- /dev/null +++ b/test/Conversion/TritonToStructured/addptr_scalar_for_2d.mlir @@ -0,0 +1,81 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func @kernel (%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i32) { + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %arg2 : i32 + %2 = tt.addptr %arg1, %1 : !tt.ptr, i32 + %cf0 = arith.constant 0.000000e+00 : f32 + %tensor_cf0 = tt.splat %cf0 : (f32) -> tensor<128x128xf32> + %c0 = arith.constant 0 : index + %c12 = arith.constant 12 : index + %c3 = arith.constant 3 : index + %sum_out, %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%sum_iter = %tensor_cf0, %ptr_iter = %2) -> (tensor<128x128xf32>, !tt.ptr ) { + %3 = tt.splat %ptr_iter : (!tt.ptr) -> tensor<128x128x!tt.ptr> + // source = %arg1, offset = [%1, 0], size = [128, 128], strides = [0, 0] + %4 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %5 = tt.expand_dims %4 {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32> + %6 = tt.broadcast %5 : (tensor<1x128xi32>) -> tensor<128x128xi32> + // offset = [0, 0], size = [128, 128], strides = [0, 1] + %7 = tt.make_range {end = 256 : i32, start = 128 : i32} : tensor<128xi32> + %8 = tt.expand_dims %7 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32> + %9 = tt.broadcast %8 : (tensor<128x1xi32>) -> tensor<128x128xi32> + // offset = [128, 0], size = [128, 128], strides = [1, 0] + %10 = arith.addi %6, %9 : tensor<128x128xi32> + // offset = [128, 0], size = [128, 128], strides = [1, 1] + %11 = tt.addptr %3, %10 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> + // source = %arg1, offset = [%1 + 128, 0], size = [128, 128], strides = [1, 1] + %12 = tt.load %11 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32> + %17 = math.exp %12 : 
tensor<128x128xf32> + %sum_next = arith.addf %sum_iter, %17 : tensor<128x128xf32> + %cast_i = arith.index_cast %i : index to i32 + %ptr_next = tt.addptr %ptr_iter, %cast_i : !tt.ptr, i32 + // source = %arg1, offset = %1 + %i, size = 1, strides = 0 + scf.yield %sum_next, %ptr_next : tensor<128x128xf32>, !tt.ptr + } + %4 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %5 = tt.expand_dims %4 {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32> + %6 = tt.broadcast %5 : (tensor<1x128xi32>) -> tensor<128x128xi32> + // offset = [0, 0], size = [128, 128], strides = [0, 1] + %7 = tt.make_range {end = 256 : i32, start = 128 : i32} : tensor<128xi32> + %8 = tt.expand_dims %7 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32> + %9 = tt.broadcast %8 : (tensor<128x1xi32>) -> tensor<128x128xi32> + // offset = [128, 0], size = [128, 128], strides = [1, 0] + %10 = arith.addi %6, %9 : tensor<128x128xi32> + // offset = [128, 0], size = [128, 128], strides = [1, 1] + %18 = arith.muli %0, %arg3 : i32 + %19 = tt.addptr %arg0, %18 : !tt.ptr, i32 + // source = arg0, offset = %18, size = 1, strides = 0 + %20 = tt.splat %19 : (!tt.ptr) -> tensor<128x128x!tt.ptr> + // source = arg0, offset = [%18, 0], size = [128, 128], strides = [0, 0] + %21 = tt.addptr %20, %10 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> + // source = %arg0, offset = [%18 + 128, 0], size = [128, 128], strides = [1, 1] + tt.store %21, %sum_out : tensor<128x128xf32> + tt.return + } +} +// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr {tt.divisibility = 16 : i32}, [[PARAM_1_:%.+]]: !tt.ptr {tt.divisibility = 16 : i32}, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32) { +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<0.000000e+00> : tensor<128x128xf32> +// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index +// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[VAR_0_:%.+]] = tt.get_program_id x : i32 +// CHECK: [[VAR_1_:%.+]] = arith.muli [[VAR_0_]], [[PARAM_2_]] : i32 +// CHECK: [[VAR_2_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index +// CHECK-DAG: [[VAR_3_:%.+]]:2 = scf.for [[VAR_arg5_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg6_:%.+]] = [[VAR_cst_]], [[VAR_arg7_:%.+]] = [[VAR_2_]]) -> (tensor<128x128xf32>, index) { +// CHECK-DAG: [[VAR_8_:%.+]] = arith.addi [[VAR_arg7_]], [[CST_128_]] : index +// CHECK: [[VAR_9_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [128, 128], strides: [1, 1], offsets: {{.}}[[VAR_8_]], 0], parent_sizes: [0, 0] : to tensor<128x128x!tt.ptr> +// CHECK: [[VAR_10_:%.+]] = "tts.load"([[VAR_9_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<128x128x!tt.ptr>) -> tensor<128x128xf32> +// CHECK: [[VAR_11_:%.+]] = math.exp [[VAR_10_]] : tensor<128x128xf32> +// CHECK-DAG: [[VAR_12_:%.+]] = arith.addf [[VAR_arg6_]], [[VAR_11_]] : tensor<128x128xf32> +// CHECK-DAG: [[VAR_13_:%.+]] = arith.addi [[VAR_arg7_]], [[VAR_arg5_]] : index +// CHECK: scf.yield [[VAR_12_]], [[VAR_13_]] : tensor<128x128xf32>, index +// CHECK: } +// CHECK: [[VAR_4_:%.+]] = arith.muli [[VAR_0_]], [[PARAM_3_]] : i32 +// CHECK: [[VAR_5_:%.+]] = arith.index_cast [[VAR_4_]] : i32 to index +// CHECK: [[VAR_6_:%.+]] = arith.addi [[VAR_5_]], [[CST_128_]] : index +// CHECK: [[VAR_7_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [128, 128], strides: [1, 1], offsets: {{.}}[[VAR_6_]], 0], parent_sizes: [0, 0] : to 
tensor<128x128x!tt.ptr> +// CHECK: "tts.store"([[VAR_7_]], [[VAR_3_]]#0) <{static_dims = array}> : (tensor<128x128x!tt.ptr>, tensor<128x128xf32>) -> () +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/addptr_scalar_loopback.mlir b/test/Conversion/TritonToStructured/addptr_scalar_loopback.mlir new file mode 100644 index 00000000..de5248aa --- /dev/null +++ b/test/Conversion/TritonToStructured/addptr_scalar_loopback.mlir @@ -0,0 +1,23 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr, + %arg2 : i32 + ) { + %0 = tt.addptr %arg0, %arg2 : !tt.ptr, i32 + %1 = tt.addptr %arg1, %arg2 : !tt.ptr, i32 + %10 = tt.load %0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: bf16 + tt.store %1, %10 : bf16 + tt.return + } +} + +// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32) { +// CHECK-DAG: [[VAR_0_:%.+]] = tt.addptr [[PARAM_0_]], [[PARAM_2_]] : !tt.ptr, i32 +// CHECK-DAG: [[VAR_1_:%.+]] = tt.addptr [[PARAM_1_]], [[PARAM_2_]] : !tt.ptr, i32 +// CHECK: [[LOAD_VAR_0_MEM_:%.+]] = tt.load [[VAR_0_]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : bf16 +// CHECK: tt.store [[VAR_1_]], [[LOAD_VAR_0_MEM_]] {cache = 1 : i32, evict = 1 : i32} : bf16 +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/addptr_scalar_nested.mlir b/test/Conversion/TritonToStructured/addptr_scalar_nested.mlir new file mode 100644 index 00000000..3d166be7 --- /dev/null +++ b/test/Conversion/TritonToStructured/addptr_scalar_nested.mlir @@ -0,0 +1,53 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func @kernel (%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i32) { + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %arg2 : i32 + %2 = tt.addptr %arg1, %1 : !tt.ptr, i32 + // source = arg1, offset = %1, size = 1, strides = 0 + %3 = arith.muli %0, %arg3 : i32 + %4 = tt.addptr %2, %3 : !tt.ptr, i32 + // source = arg1, offset = %1+%3, size = 1, strides = 0 + %5 = arith.muli %0, %arg4 : i32 + %6 = tt.addptr %4, %5 : !tt.ptr, i32 + // source = arg1, offset = %1+%3+%5, size = 1, strides = 0 + %7 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + // offset = 0, size = 1024, strides = 1 + %8 = tt.splat %6 : (!tt.ptr) -> tensor<1024x!tt.ptr> + // source = arg1, offset = %1, size = 1024, strides = 0 + %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr>, tensor<1024xi32> + // source = arg1, offset = %1+%3+%5, size = 1024, strides = 1 + %10 = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32> + %17 = math.exp %10 : tensor<1024xf32> + %18 = arith.muli %0, %arg3 : i32 + %19 = tt.addptr %arg0, %18 : !tt.ptr, i32 + // source = arg0, offset = %18, size = 1, strides = 0 + %20 = tt.splat %19 : (!tt.ptr) -> tensor<1024x!tt.ptr> + // source = arg0, offset = %18, size = 1024, strides = 0 + %21 = tt.addptr %20, %7 : tensor<1024x!tt.ptr>, tensor<1024xi32> + // source = arg0, offset = %18, size = 1024, strides = 1 + tt.store %21, %17 : tensor<1024xf32> + tt.return + } +} + +// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr {tt.divisibility = 16 : i32}, [[PARAM_1_:%.+]]: !tt.ptr {tt.divisibility = 16 : i32}, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32) { +// CHECK: [[VAR_0_:%.+]] = 
tt.get_program_id x : i32 +// CHECK: [[VAR_1_:%.+]] = arith.muli [[VAR_0_]], [[PARAM_2_]] : i32 +// CHECK-DAG: [[VAR_2_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index +// CHECK-DAG: [[VAR_3_:%.+]] = arith.muli [[VAR_0_]], [[PARAM_3_]] : i32 +// CHECK: [[VAR_4_:%.+]] = arith.index_cast [[VAR_3_]] : i32 to index +// CHECK-DAG: [[VAR_5_:%.+]] = arith.addi [[VAR_2_]], [[VAR_4_]] : index +// CHECK-DAG: [[VAR_6_:%.+]] = arith.muli [[VAR_0_]], [[PARAM_4_]] : i32 +// CHECK: [[VAR_7_:%.+]] = arith.index_cast [[VAR_6_]] : i32 to index +// CHECK: [[VAR_8_:%.+]] = arith.addi [[VAR_5_]], [[VAR_7_]] : index +// CHECK: [[VAR_9_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [1024], strides: [1], offsets: {{.}}[[VAR_8_]]{{.}}, parent_sizes: [0] : to tensor<1024x!tt.ptr> +// CHECK: [[VAR_10_:%.+]] = "tts.load"([[VAR_9_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<1024x!tt.ptr>) -> tensor<1024xf32> +// CHECK-DAG: [[VAR_11_:%.+]] = math.exp [[VAR_10_]] : tensor<1024xf32> +// CHECK-DAG: [[VAR_12_:%.+]] = arith.muli [[VAR_0_]], [[PARAM_3_]] : i32 +// CHECK: [[VAR_13_:%.+]] = arith.index_cast [[VAR_12_]] : i32 to index +// CHECK: [[VAR_14_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [1024], strides: [1], offsets: {{.}}[[VAR_13_]]{{.}}, parent_sizes: [0] : to tensor<1024x!tt.ptr> +// CHECK: "tts.store"([[VAR_14_]], [[VAR_11_]]) <{static_dims = array}> : (tensor<1024x!tt.ptr>, tensor<1024xf32>) -> () +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/addptr_scalar_splat.mlir b/test/Conversion/TritonToStructured/addptr_scalar_splat.mlir new file mode 100644 index 00000000..a405b8d0 --- /dev/null +++ b/test/Conversion/TritonToStructured/addptr_scalar_splat.mlir @@ -0,0 +1,41 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func @kernel (%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i32) { + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %arg2 : i32 + %2 = tt.addptr %arg1, %1 : !tt.ptr, i32 + // source = %arg1, offset = %1, size = 1, strides = 0 + %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + // offset = 0, size = 1024, strides = 1 + %4 = tt.splat %2 : (!tt.ptr) -> tensor<1024x!tt.ptr> + // source = %arg1, offset = %1, size = 1024, strides = 0 + %5 = tt.addptr %4, %3 : tensor<1024x!tt.ptr>, tensor<1024xi32> + // source = %arg1, offset = %1, size = 1024, strides = 1 + %8 = tt.load %5 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32> + %17 = math.exp %8 : tensor<1024xf32> + %18 = arith.muli %0, %arg3 : i32 + %19 = tt.addptr %arg0, %18 : !tt.ptr, i32 + // source = %arg0, offset = %18, size = 1, strides = 0 + %20 = tt.splat %19 : (!tt.ptr) -> tensor<1024x!tt.ptr> + // source = %arg0, offset = %18, size = 1024, strides = 0 + %21 = tt.addptr %20, %3 : tensor<1024x!tt.ptr>, tensor<1024xi32> + // source = %arg0, offset = %18, size = 1024, strides = 1 + tt.store %21, %17 : tensor<1024xf32> + tt.return + } +} + +// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr {tt.divisibility = 16 : i32}, [[PARAM_1_:%.+]]: !tt.ptr {tt.divisibility = 16 : i32}, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32) { +// CHECK: [[VAR_0_:%.+]] = tt.get_program_id x : i32 +// CHECK: [[VAR_1_:%.+]] = arith.muli [[VAR_0_]], [[PARAM_2_]] : i32 +// CHECK: [[VAR_2_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index +// CHECK: [[VAR_3_:%.+]] = tts.make_tptr [[PARAM_1_]] to 
sizes: [1024], strides: [1], offsets: {{.}}[[VAR_2_]]{{.}}, parent_sizes: [0] : to tensor<1024x!tt.ptr> +// CHECK: [[VAR_4_:%.+]] = "tts.load"([[VAR_3_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<1024x!tt.ptr>) -> tensor<1024xf32> +// CHECK-DAG: [[VAR_5_:%.+]] = math.exp [[VAR_4_]] : tensor<1024xf32> +// CHECK-DAG: [[VAR_6_:%.+]] = arith.muli [[VAR_0_]], [[PARAM_3_]] : i32 +// CHECK: [[VAR_7_:%.+]] = arith.index_cast [[VAR_6_]] : i32 to index +// CHECK: [[VAR_8_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [1024], strides: [1], offsets: {{.}}[[VAR_7_]]{{.}}, parent_sizes: [0] : to tensor<1024x!tt.ptr> +// CHECK: "tts.store"([[VAR_8_]], [[VAR_5_]]) <{static_dims = array}> : (tensor<1024x!tt.ptr>, tensor<1024xf32>) -> () +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/addptr_scalar_splat_2d.mlir b/test/Conversion/TritonToStructured/addptr_scalar_splat_2d.mlir new file mode 100644 index 00000000..8eedeb1d --- /dev/null +++ b/test/Conversion/TritonToStructured/addptr_scalar_splat_2d.mlir @@ -0,0 +1,52 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func @kernel (%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i32) { + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %arg2 : i32 + %2 = tt.addptr %arg1, %1 : !tt.ptr, i32 + %3 = tt.splat %2 : (!tt.ptr) -> tensor<128x128x!tt.ptr> + // source = %arg1, offset = [%1, 0], size = [128, 128], strides = [0, 0] + %4 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %5 = tt.expand_dims %4 {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32> + %6 = tt.broadcast %5 : (tensor<1x128xi32>) -> tensor<128x128xi32> + // offset = [0, 0], size = [128, 128], strides = [0, 1] + %7 = tt.make_range {end = 256 : i32, start = 128 : i32} : tensor<128xi32> + // offset = 128, size = 128, strides = 1 + %8 = tt.expand_dims %7 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32> + %9 = tt.broadcast %8 : (tensor<128x1xi32>) -> tensor<128x128xi32> + // offset = [128, 0], size = [128, 128], strides = [1, 0] + %10 = arith.addi %6, %9 : tensor<128x128xi32> + // offset = [128, 0], size = [128, 128], strides = [1, 1] + %11 = tt.addptr %3, %10 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> + // source = %arg1, offset = [%1 + 128, 0], size = [128, 128], strides = [1, 1] + %12 = tt.load %11 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32> + %17 = math.exp %12 : tensor<128x128xf32> + %18 = arith.muli %0, %arg3 : i32 + %19 = tt.addptr %arg0, %18 : !tt.ptr, i32 + // source = arg0, offset = %18, size = 1, strides = 0 + %20 = tt.splat %19 : (!tt.ptr) -> tensor<128x128x!tt.ptr> + // source = arg0, offset = [%18, 0], size = [128, 128], strides = [0, 0] + %21 = tt.addptr %20, %10 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> + // source = %arg0, offset = [%18 + 128, 0], size = [128, 128], strides = [1, 1] + tt.store %21, %17 : tensor<128x128xf32> + tt.return + } +} + +// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr {tt.divisibility = 16 : i32}, [[PARAM_1_:%.+]]: !tt.ptr {tt.divisibility = 16 : i32}, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32) { +// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index +// CHECK-DAG: [[VAR_0_:%.+]] = tt.get_program_id x : i32 +// CHECK: [[VAR_1_:%.+]] = arith.muli [[VAR_0_]], [[PARAM_2_]] : i32 +// CHECK: [[VAR_2_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index +// 
CHECK: [[VAR_3_:%.+]] = arith.addi [[VAR_2_]], [[CST_128_]] : index +// CHECK: [[VAR_4_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [128, 128], strides: [1, 1], offsets: {{.}}[[VAR_3_]], 0], parent_sizes: [0, 0] : to tensor<128x128x!tt.ptr> +// CHECK: [[VAR_5_:%.+]] = "tts.load"([[VAR_4_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<128x128x!tt.ptr>) -> tensor<128x128xf32> +// CHECK-DAG: [[VAR_6_:%.+]] = math.exp [[VAR_5_]] : tensor<128x128xf32> +// CHECK-DAG: [[VAR_7_:%.+]] = arith.muli [[VAR_0_]], [[PARAM_3_]] : i32 +// CHECK: [[VAR_8_:%.+]] = arith.index_cast [[VAR_7_]] : i32 to index +// CHECK: [[VAR_9_:%.+]] = arith.addi [[VAR_8_]], [[CST_128_]] : index +// CHECK: [[VAR_10_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [128, 128], strides: [1, 1], offsets: {{.}}[[VAR_9_]], 0], parent_sizes: [0, 0] : to tensor<128x128x!tt.ptr> +// CHECK: "tts.store"([[VAR_10_]], [[VAR_6_]]) <{static_dims = array}> : (tensor<128x128x!tt.ptr>, tensor<128x128xf32>) -> () +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/arith_not_ptr_arith.mlir b/test/Conversion/TritonToStructured/arith_not_ptr_arith.mlir new file mode 100644 index 00000000..5ae126e3 --- /dev/null +++ b/test/Conversion/TritonToStructured/arith_not_ptr_arith.mlir @@ -0,0 +1,33 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func @kernel( + %a : !tt.ptr, + %b : !tt.ptr + ) -> () { + // offset calculations + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + // a pointer + %8 = tt.splat %a : (!tt.ptr) -> tensor<1024x!tt.ptr> + %9 = tt.addptr %8, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + // b pointer + %18 = tt.splat %b : (!tt.ptr) -> tensor<1024x!tt.ptr> + %19 = tt.addptr %18, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %am = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xi32> + %bm = tt.load %19 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xi32> + %5 = arith.addi %am, %bm : tensor<1024xi32> + tt.store %19, %5 : tensor<1024xi32> + tt.return + } +} + +// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr) { +// CHECK-DAG: [[VAR_0_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [1024], strides: [1], offsets: [0], parent_sizes: [0] : to tensor<1024x!tt.ptr> +// CHECK-DAG: [[VAR_1_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [1024], strides: [1], offsets: [0], parent_sizes: [0] : to tensor<1024x!tt.ptr> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_2_:%.+]] = "tts.load"([[VAR_0_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<1024x!tt.ptr>) -> tensor<1024xi32> +// CHECK-DAG: [[VAR_3_:%.+]] = "tts.load"([[VAR_1_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<1024x!tt.ptr>) -> tensor<1024xi32> +// CHECK: [[VAR_4_:%.+]] = arith.addi [[VAR_2_]], [[VAR_3_]] : tensor<1024xi32> +// CHECK: "tts.store"([[VAR_1_]], [[VAR_4_]]) <{static_dims = array}> : (tensor<1024x!tt.ptr>, tensor<1024xi32>) -> () +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/bitcast.mlir b/test/Conversion/TritonToStructured/bitcast.mlir new file mode 100644 index 00000000..cfec5d33 --- /dev/null +++ b/test/Conversion/TritonToStructured/bitcast.mlir @@ -0,0 +1,34 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func @kernel(%a : !tt.ptr, %b : !tt.ptr) -> () { + // offset calculations + %0 
= tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + + // a pointer + %8 = tt.splat %a : (!tt.ptr) -> tensor<1024x!tt.ptr> + %9 = tt.addptr %8, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + + // b pointer + %18 = tt.splat %b : (!tt.ptr) -> tensor<1024x!tt.ptr> + %19 = tt.addptr %18, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + + %am = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xi32> + + // cast result before doing float add + %am_bitcast = tt.bitcast %am : tensor<1024xi32> -> tensor<1024xf32> + + + tt.store %19, %am_bitcast : tensor<1024xf32> + tt.return + } +} + +// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr) { +// CHECK-DAG: [[VAR_0_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [1024], strides: [1], offsets: [0], parent_sizes: [0] : to tensor<1024x!tt.ptr> +// CHECK-DAG: [[VAR_1_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [1024], strides: [1], offsets: [0], parent_sizes: [0] : to tensor<1024x!tt.ptr> +// CHECK: [[VAR_2_:%.+]] = "tts.load"([[VAR_0_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<1024x!tt.ptr>) -> tensor<1024xi32> +// CHECK: [[VAR_3_:%.+]] = tt.bitcast [[VAR_2_]] : tensor<1024xi32> -> tensor<1024xf32> +// CHECK: "tts.store"([[VAR_1_]], [[VAR_3_]]) <{static_dims = array}> : (tensor<1024x!tt.ptr>, tensor<1024xf32>) -> () +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/block_ptr_advance.mlir b/test/Conversion/TritonToStructured/block_ptr_advance.mlir new file mode 100644 index 00000000..b2be3cd9 --- /dev/null +++ b/test/Conversion/TritonToStructured/block_ptr_advance.mlir @@ -0,0 +1,62 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func public @matmul_kernel_with_block_pointers_01234567891011(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32) { + %c64_i32 = arith.constant 64 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant 0.000000e+00 : bf16 + %c256_i32 = arith.constant 256 : i32 + %0 = arith.extsi %arg3 : i32 to i64 + %1 = arith.extsi %arg5 : i32 to i64 + %2 = arith.extsi %arg6 : i32 to i64 + %3 = arith.extsi %arg7 : i32 to i64 + %4 = tt.make_tensor_ptr %arg0, [%0, %1], [%2, %3], [%arg12, %c0_i32] {order = array} : > + %5 = tt.advance %4, [%c0_i32, %c64_i32] : > + %6 = tt.splat %cst : (bf16) -> tensor<128x64xbf16> + %7:3 = scf.for %arg14 = %c0_i32 to %arg5 step %c64_i32 iter_args(%arg15 = %6, %arg16 = %5, %arg17 = %4) -> (tensor<128x64xbf16>, !tt.ptr>, !tt.ptr>) : i32 { + %13 = tt.load %arg16 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr> -> tensor<128x64xbf16> + %14 = tt.load %arg17 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr> -> tensor<128x64xbf16> + %15 = arith.addf %13, %14 : tensor<128x64xbf16> + %16 = arith.addf %arg15, %15 : tensor<128x64xbf16> + %17 = tt.advance %arg16, [%c0_i32, %c64_i32] : > + %18 = tt.advance %arg17, [%c64_i32, %c0_i32] : > + scf.yield %16, %17, %18 : tensor<128x64xbf16>, !tt.ptr>, !tt.ptr> + } + %8 = arith.extsi %arg10 : i32 to i64 + %9 = arith.extsi %arg11 : i32 to i64 + %10 = arith.extsi %arg4 : i32 to i64 + %11 = arith.muli %arg13, %c256_i32 : i32 + %12 = tt.make_tensor_ptr %arg2, [%0, %10], [%8, %9], [%arg12, %11] {order = array} : > + tt.store %12, %7#0 {boundaryCheck = array, cache = 1 : i32, 
evict = 1 : i32} : !tt.ptr>, tensor<128x64xbf16> + tt.return + } +} + +// CHECK: tt.func public @matmul_kernel_with_block_pointers_01234567891011([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: !tt.ptr, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32, [[PARAM_11_:%.+]]: i32, [[PARAM_12_:%.+]]: i32, [[PARAM_13_:%.+]]: i32) { +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<0.000000e+00> : tensor<128x64xbf16> +// CHECK-DAG: [[CST_64_:%.+]] = arith.constant 64 : i32 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : i32 +// CHECK-DAG: [[VAR_0_:%.+]] = arith.extsi [[PARAM_3_]] : i32 to i64 +// CHECK-DAG: [[VAR_1_:%.+]] = arith.extsi [[PARAM_5_]] : i32 to i64 +// CHECK-DAG: [[VAR_2_:%.+]] = arith.extsi [[PARAM_6_]] : i32 to i64 +// CHECK-DAG: [[VAR_3_:%.+]] = arith.extsi [[PARAM_7_]] : i32 to i64 +// CHECK: [[VAR_4_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{.}}[[VAR_0_]], [[VAR_1_]]{{.}}, {{.}}[[VAR_2_]], [[VAR_3_]]{{.}}, {{.}}[[PARAM_12_]], [[CST_0_]]{{.}} {order = array} : , 1> +// CHECK: [[VAR_5_:%.+]] = tt.advance [[VAR_4_]], {{.}}[[CST_0_]], [[CST_64_]]{{.}} : , 1> +// CHECK-DAG: [[VAR_6_:%.+]]:3 = scf.for [[VAR_arg14_:%.+]] = [[CST_0_]] to [[PARAM_5_]] step [[CST_64_]] iter_args([[VAR_arg15_:%.+]] = [[VAR_cst_]], [[VAR_arg16_:%.+]] = [[VAR_5_]], [[VAR_arg17_:%.+]] = [[VAR_4_]]) -> (tensor<128x64xbf16>, !tt.ptr, 1>, !tt.ptr, 1>) : i32 { +// CHECK-DAG: [[LOAD_VAR_arg16_MEM_:%.+]] = tt.load [[VAR_arg16_]] {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<128x64xbf16> +// CHECK-DAG: [[LOAD_VAR_arg17_MEM_:%.+]] = tt.load [[VAR_arg17_]] {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<128x64xbf16> +// CHECK: [[VAR_14_:%.+]] = arith.addf [[LOAD_VAR_arg16_MEM_]], [[LOAD_VAR_arg17_MEM_]] : tensor<128x64xbf16> +// CHECK-DAG: [[VAR_15_:%.+]] = arith.addf [[VAR_arg15_]], [[VAR_14_]] : tensor<128x64xbf16> +// CHECK-DAG: [[VAR_16_:%.+]] = tt.advance [[VAR_arg16_]], {{.}}[[CST_0_]], [[CST_64_]]{{.}} : , 1> +// CHECK-DAG: [[VAR_17_:%.+]] = tt.advance [[VAR_arg17_]], {{.}}[[CST_64_]], [[CST_0_]]{{.}} : , 1> +// CHECK: scf.yield [[VAR_15_]], [[VAR_16_]], [[VAR_17_]] : tensor<128x64xbf16>, !tt.ptr, 1>, !tt.ptr, 1> +// CHECK: } +// CHECK-DAG: [[VAR_7_:%.+]] = arith.extsi [[PARAM_10_]] : i32 to i64 +// CHECK-DAG: [[VAR_8_:%.+]] = arith.extsi [[PARAM_11_]] : i32 to i64 +// CHECK-DAG: [[VAR_9_:%.+]] = arith.extsi [[PARAM_4_]] : i32 to i64 +// CHECK-DAG: [[VAR_10_:%.+]] = arith.muli [[PARAM_13_]], [[CST_256_]] : i32 +// CHECK: [[VAR_11_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{.}}[[VAR_0_]], [[VAR_9_]]{{.}}, {{.}}[[VAR_7_]], [[VAR_8_]]{{.}}, {{.}}[[PARAM_12_]], [[VAR_10_]]{{.}} {order = array} : , 1> +// CHECK: tt.store [[VAR_11_]], [[VAR_6_]]#0 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32} : !tt.ptr, 1>, tensor<128x64xbf16> +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/dot.mlir b/test/Conversion/TritonToStructured/dot.mlir new file mode 100644 index 00000000..79b96da1 --- /dev/null +++ b/test/Conversion/TritonToStructured/dot.mlir @@ -0,0 +1,66 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr, + %arg2 : !tt.ptr + 
) + { + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %c64 = arith.constant 128 : i32 + %1 = tt.splat %c64 : (i32) -> tensor<128xi32> + %2 = arith.muli %0, %1 : tensor<128xi32> + %3 = tt.expand_dims %2 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32> + %4 = tt.broadcast %3 : (tensor<128x1xi32>) -> tensor<128x64xi32> + %5 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %6 = tt.expand_dims %5 {axis = 0 : i32} : (tensor<64xi32>) -> tensor<1x64xi32> + %7 = tt.broadcast %6 : (tensor<1x64xi32>) -> tensor<128x64xi32> + %8 = arith.addi %4, %7 : tensor<128x64xi32> + %10 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %11 = tt.expand_dims %10 {axis = 1 : i32} : (tensor<256xi32>) -> tensor<256x1xi32> + %12 = tt.broadcast %11 : (tensor<256x1xi32>) -> tensor<256x64xi32> + %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %c256 = arith.constant 256 : i32 + %14 = tt.splat %c256 : (i32) -> tensor<64xi32> + %15 = arith.muli %13, %14 : tensor<64xi32> + %16 = tt.expand_dims %15 {axis = 0 : i32} : (tensor<64xi32>) -> tensor<1x64xi32> + %17 = tt.broadcast %16 : (tensor<1x64xi32>) -> tensor<256x64xi32> + %18 = arith.addi %12, %17 : tensor<256x64xi32> + %20 = tt.splat %c256 : (i32) -> tensor<128xi32> + %21 = arith.muli %0, %20 : tensor<128xi32> + %22 = tt.expand_dims %21 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32> + %23 = tt.broadcast %22 : (tensor<128x1xi32>) -> tensor<128x256xi32> + %24 = tt.expand_dims %10 {axis = 0 : i32} : (tensor<256xi32>) -> tensor<1x256xi32> + %25 = tt.broadcast %24 {axis = 0 : i32} : (tensor<1x256xi32>) -> tensor<128x256xi32> + %26 = arith.addi %23, %25 : tensor<128x256xi32> + %30 = tt.splat %arg0 : (!tt.ptr) -> tensor<128x64x!tt.ptr> + %31 = tt.addptr %30, %8 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> + %32 = tt.load %31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<128x64xbf16> + %40 = tt.splat %arg1 : (!tt.ptr) -> tensor<256x64x!tt.ptr> + %41 = tt.addptr %40, %18 : tensor<256x64x!tt.ptr>, tensor<256x64xi32> + %42 = tt.load %41 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x64xbf16> + %43 = tt.trans %42 : (tensor<256x64xbf16>) -> tensor<64x256xbf16> + %50 = tt.splat %arg2 : (!tt.ptr) -> tensor<128x256x!tt.ptr> + %51 = tt.addptr %50, %26 : tensor<128x256x!tt.ptr>, tensor<128x256xi32> + %52 = tt.load %51 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<128x256xbf16> + %60 = tt.dot %32, %43, %52 {allowTF32 = false, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xbf16> * tensor<64x256xbf16> -> tensor<128x256xbf16> + tt.store %51, %60 : tensor<128x256xbf16> + tt.return + } +} + +// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: !tt.ptr) { +// CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index +// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index +// CHECK: [[VAR_0_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [128, 64], strides: {{.}}[[CST_128_]], 1], offsets: [0, 0], parent_sizes: [0, 0] : to tensor<128x64x!tt.ptr> +// CHECK-DAG: [[VAR_1_:%.+]] = "tts.load"([[VAR_0_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<128x64x!tt.ptr>) -> tensor<128x64xbf16> +// CHECK-DAG: [[VAR_2_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [256, 64], strides: [1, [[CST_256_]]{{.}}, offsets: [0, 0], parent_sizes: [0, 0] : to tensor<256x64x!tt.ptr> +// CHECK: [[VAR_3_:%.+]] = "tts.load"([[VAR_2_]]) <{operandSegmentSizes = array, static_dims = array}> : 
(tensor<256x64x!tt.ptr>) -> tensor<256x64xbf16> +// CHECK-DAG: [[VAR_4_:%.+]] = tt.trans [[VAR_3_]] : (tensor<256x64xbf16>) -> tensor<64x256xbf16> +// CHECK-DAG: [[VAR_5_:%.+]] = tts.make_tptr [[PARAM_2_]] to sizes: [128, 256], strides: {{.}}[[CST_256_]], 1], offsets: [0, 0], parent_sizes: [0, 0] : to tensor<128x256x!tt.ptr> +// CHECK: [[VAR_6_:%.+]] = "tts.load"([[VAR_5_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<128x256x!tt.ptr>) -> tensor<128x256xbf16> +// CHECK: [[VAR_7_:%.+]] = tt.dot [[VAR_1_]], [[VAR_4_]], [[VAR_6_]] {allowTF32 = false, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xbf16> * tensor<64x256xbf16> -> tensor<128x256xbf16> +// CHECK: "tts.store"([[VAR_5_]], [[VAR_7_]]) <{static_dims = array}> : (tensor<128x256x!tt.ptr>, tensor<128x256xbf16>) -> () +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/kernel-01-vector-add.mlir b/test/Conversion/TritonToStructured/kernel-01-vector-add.mlir new file mode 100644 index 00000000..4aa5cf31 --- /dev/null +++ b/test/Conversion/TritonToStructured/kernel-01-vector-add.mlir @@ -0,0 +1,62 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func public @add_kernel_01234(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32) { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.splat %1 : (i32) -> tensor<1024xi32> + %4 = arith.addi %3, %2 : tensor<1024xi32> + %5 = tt.splat %arg3 : (i32) -> tensor<1024xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32> + %7 = tt.splat %arg0 : (!tt.ptr) -> tensor<1024x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %9 = tt.load %8, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32> + %10 = tt.splat %arg1 : (!tt.ptr) -> tensor<1024x!tt.ptr> + %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %12 = tt.load %11, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32> + %13 = arith.addf %9, %12 : tensor<1024xf32> + %14 = tt.splat %arg2 : (!tt.ptr) -> tensor<1024x!tt.ptr> + %15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + tt.store %15, %13, %6 {cache = 1 : i32, evict = 1 : i32} : tensor<1024xf32> + tt.return + } +} + +// CHECK: tt.func public @add_kernel_01234([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: !tt.ptr, [[PARAM_3_:%.+]]: i32) { +// CHECK-DAG: [[CST_1024_:%.+]] = arith.constant 1024 : index +// CHECK-DAG: [[CST_1024_1_:%.+]] = arith.constant 1024 : i32 +// CHECK-DAG: [[VAR_0_:%.+]] = tt.get_program_id x : i32 +// CHECK: [[VAR_1_:%.+]] = arith.muli [[VAR_0_]], [[CST_1024_1_]] : i32 +// CHECK-DAG: [[VAR_2_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index +// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index +// CHECK-DAG: [[VAR_4_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_5_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [1024], strides: [1], offsets: {{.}}[[VAR_4_]]{{.}}, parent_sizes: [0] : to tensor<1024x!tt.ptr> +// CHECK-DAG: [[VAR_6_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_7_:%.+]] = arith.addi [[VAR_6_]], [[CST_1024_]] : index +// CHECK-DAG: [[VAR_8_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index +// CHECK: 
[[VAR_9_:%.+]] = arith.minsi [[VAR_7_]], [[VAR_8_]] : index +// CHECK: [[VAR_10_:%.+]] = arith.subi [[VAR_9_]], [[VAR_6_]] : index +// CHECK-DAG: [[VAR_11_:%.+]] = "tts.load"([[VAR_5_]], [[VAR_10_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<1024x!tt.ptr>, index) -> tensor<1024xf32> +// CHECK-DAG: [[VAR_12_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [1024], strides: [1], offsets: {{.}}[[VAR_3_]]{{.}}, parent_sizes: [0] : to tensor<1024x!tt.ptr> +// CHECK-DAG: [[VAR_13_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_14_:%.+]] = arith.addi [[VAR_13_]], [[CST_1024_]] : index +// CHECK-DAG: [[VAR_15_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index +// CHECK: [[VAR_16_:%.+]] = arith.minsi [[VAR_14_]], [[VAR_15_]] : index +// CHECK: [[VAR_17_:%.+]] = arith.subi [[VAR_16_]], [[VAR_13_]] : index +// CHECK: [[VAR_18_:%.+]] = "tts.load"([[VAR_12_]], [[VAR_17_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<1024x!tt.ptr>, index) -> tensor<1024xf32> +// CHECK-DAG: [[VAR_19_:%.+]] = arith.addf [[VAR_11_]], [[VAR_18_]] : tensor<1024xf32> +// CHECK-DAG: [[VAR_20_:%.+]] = tts.make_tptr [[PARAM_2_]] to sizes: [1024], strides: [1], offsets: {{.}}[[VAR_2_]]{{.}}, parent_sizes: [0] : to tensor<1024x!tt.ptr> +// CHECK-DAG: [[VAR_21_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_22_:%.+]] = arith.addi [[VAR_21_]], [[CST_1024_]] : index +// CHECK-DAG: [[VAR_23_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index +// CHECK: [[VAR_24_:%.+]] = arith.minsi [[VAR_22_]], [[VAR_23_]] : index +// CHECK: [[VAR_25_:%.+]] = arith.subi [[VAR_24_]], [[VAR_21_]] : index +// CHECK: "tts.store"([[VAR_20_]], [[VAR_19_]], [[VAR_25_]]) <{static_dims = array}> : (tensor<1024x!tt.ptr>, tensor<1024xf32>, index) -> () +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/kernel-02-fused-softmax.mlir b/test/Conversion/TritonToStructured/kernel-02-fused-softmax.mlir new file mode 100644 index 00000000..c58f4f3b --- /dev/null +++ b/test/Conversion/TritonToStructured/kernel-02-fused-softmax.mlir @@ -0,0 +1,74 @@ +// RUN: triton-shared-opt --triton-to-structured --canonicalize %s | FileCheck %s + +module { + tt.func public @softmax_kernel_012345(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32, %arg4: i32) { + %cst = arith.constant 0xFF800000 : f32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %arg2 : i32 + %2 = tt.addptr %arg1, %1 : !tt.ptr, i32 + %3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %4 = tt.splat %2 : (!tt.ptr) -> tensor<128x!tt.ptr> + %5 = tt.addptr %4, %3 : tensor<128x!tt.ptr>, tensor<128xi32> + %6 = tt.splat %arg4 : (i32) -> tensor<128xi32> + %7 = arith.cmpi slt, %3, %6 : tensor<128xi32> + %8 = tt.splat %cst : (f32) -> tensor<128xf32> + %9 = tt.load %5, %7, %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32> + %10 = "tt.reduce"(%9) ({ + ^bb0(%arg5: f32, %arg6: f32): + %21 = arith.cmpf ogt, %arg5, %arg6 : f32 + %22 = arith.select %21, %arg5, %arg6 : f32 + tt.reduce.return %22 : f32 + }) {axis = 0 : i32} : (tensor<128xf32>) -> f32 + %11 = tt.splat %10 : (f32) -> tensor<128xf32> + %12 = arith.subf %9, %11 : tensor<128xf32> + %13 = math.exp %12 : tensor<128xf32> + %14 = "tt.reduce"(%13) ({ + ^bb0(%arg5: f32, %arg6: f32): + %21 = arith.addf %arg5, %arg6 : f32 + tt.reduce.return %21 : f32 + }) {axis = 0 : i32} : (tensor<128xf32>) -> f32 + %15 = tt.splat 
%14 : (f32) -> tensor<128xf32> + %16 = arith.divf %13, %15 : tensor<128xf32> + %17 = arith.muli %0, %arg3 : i32 + %18 = tt.addptr %arg0, %17 : !tt.ptr, i32 + %19 = tt.splat %18 : (!tt.ptr) -> tensor<128x!tt.ptr> + %20 = tt.addptr %19, %3 : tensor<128x!tt.ptr>, tensor<128xi32> + tt.store %20, %16, %7 {cache = 1 : i32, evict = 1 : i32} : tensor<128xf32> + tt.return + } +} + +// CHECK: tt.func public @softmax_kernel_012345([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32) { +// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF800000 : f32 +// CHECK-DAG: [[VAR_0_:%.+]] = tt.get_program_id x : i32 +// CHECK: [[VAR_1_:%.+]] = arith.muli [[VAR_0_]], [[PARAM_2_]] : i32 +// CHECK: [[VAR_2_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index +// CHECK-DAG: [[VAR_3_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [128], strides: [1], offsets: {{.}}[[VAR_2_]]{{.}}, parent_sizes: [0] : to tensor<128x!tt.ptr> +// CHECK-DAG: [[VAR_4_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index +// CHECK: [[VAR_5_:%.+]] = arith.minsi [[VAR_4_]], [[CST_128_]] : index +// CHECK: [[VAR_6_:%.+]] = "tts.load"([[VAR_3_]], [[VAR_5_]], [[CST_0_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<128x!tt.ptr>, index, f32) -> tensor<128xf32> +// CHECK: [[VAR_7_:%.+]] = "tt.reduce"([[VAR_6_]]) <{axis = 0 : i32}> ({ +// CHECK: ^bb0([[arg5_:%.+]]: f32, [[arg6_:%.+]]: f32): +// CHECK: [[VAR_19_:%.+]] = arith.cmpf ogt, [[arg5_]], [[arg6_]] : f32 +// CHECK: [[VAR_20_:%.+]] = arith.select [[VAR_19_]], [[arg5_]], [[arg6_]] : f32 +// CHECK: tt.reduce.return [[VAR_20_]] : f32 +// CHECK: }) : (tensor<128xf32>) -> f32 +// CHECK: [[VAR_8_:%.+]] = tt.splat [[VAR_7_]] : (f32) -> tensor<128xf32> +// CHECK: [[VAR_9_:%.+]] = arith.subf [[VAR_6_]], [[VAR_8_]] : tensor<128xf32> +// CHECK: [[VAR_10_:%.+]] = math.exp [[VAR_9_]] : tensor<128xf32> +// CHECK: [[VAR_11_:%.+]] = "tt.reduce"([[VAR_10_]]) <{axis = 0 : i32}> ({ +// CHECK: ^bb0([[arg5_]]: f32, [[arg6_]]: f32): +// CHECK: [[VAR_19_1_:%.+]] = arith.addf [[arg5_]], [[arg6_]] : f32 +// CHECK: tt.reduce.return [[VAR_19_1_]] : f32 +// CHECK: }) : (tensor<128xf32>) -> f32 +// CHECK: [[VAR_12_:%.+]] = tt.splat [[VAR_11_]] : (f32) -> tensor<128xf32> +// CHECK-DAG: [[VAR_13_:%.+]] = arith.divf [[VAR_10_]], [[VAR_12_]] : tensor<128xf32> +// CHECK-DAG: [[VAR_14_:%.+]] = arith.muli [[VAR_0_]], [[PARAM_3_]] : i32 +// CHECK: [[VAR_15_:%.+]] = arith.index_cast [[VAR_14_]] : i32 to index +// CHECK-DAG: [[VAR_16_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [128], strides: [1], offsets: {{.}}[[VAR_15_]]{{.}}, parent_sizes: [0] : to tensor<128x!tt.ptr> +// CHECK-DAG: [[VAR_17_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index +// CHECK: [[VAR_18_:%.+]] = arith.minsi [[VAR_17_]], [[CST_128_]] : index +// CHECK: "tts.store"([[VAR_16_]], [[VAR_13_]], [[VAR_18_]]) <{static_dims = array}> : (tensor<128x!tt.ptr>, tensor<128xf32>, index) -> () +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/kernel-03-matrix-multiplication.mlir b/test/Conversion/TritonToStructured/kernel-03-matrix-multiplication.mlir new file mode 100644 index 00000000..512c8383 --- /dev/null +++ b/test/Conversion/TritonToStructured/kernel-03-matrix-multiplication.mlir @@ -0,0 +1,190 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func public @matmul_kernel_0123456789101112131415(%arg0: !tt.ptr, 
%arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32) { + %c63_i32 = arith.constant 63 : i32 + %c255_i32 = arith.constant 255 : i32 + %c127_i32 = arith.constant 127 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %cst = arith.constant 0.000000e+00 : f32 + %c256_i32 = arith.constant 256 : i32 + %c128_i32 = arith.constant 128 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.addi %arg4, %c255_i32 : i32 + %4 = arith.divsi %3, %c256_i32 : i32 + %5 = arith.addi %arg5, %c63_i32 : i32 + %6 = arith.divsi %5, %c64_i32 : i32 + %7 = arith.muli %4, %c8_i32 : i32 + %8 = arith.divsi %0, %7 : i32 + %9 = arith.muli %8, %c8_i32 : i32 + %10 = arith.subi %2, %9 : i32 + %11 = arith.cmpi slt, %10, %c8_i32 : i32 + %12 = arith.select %11, %10, %c8_i32 : i32 + %13 = arith.remsi %0, %12 : i32 + %14 = arith.addi %9, %13 : i32 + %15 = arith.remsi %0, %7 : i32 + %16 = arith.divsi %15, %12 : i32 + %17 = arith.muli %14, %c128_i32 : i32 + %18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %19 = tt.splat %17 : (i32) -> tensor<128xi32> + %20 = arith.addi %19, %18 : tensor<128xi32> + %21 = arith.muli %16, %c256_i32 : i32 + %22 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %23 = tt.splat %21 : (i32) -> tensor<256xi32> + %24 = arith.addi %23, %22 : tensor<256xi32> + %25 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %26 = tt.expand_dims %20 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32> + %27 = tt.splat %arg6 : (i32) -> tensor<128x1xi32> + %28 = arith.muli %26, %27 : tensor<128x1xi32> + %29 = tt.expand_dims %25 {axis = 0 : i32} : (tensor<64xi32>) -> tensor<1x64xi32> + %30 = tt.splat %arg7 : (i32) -> tensor<1x64xi32> + %31 = arith.muli %29, %30 : tensor<1x64xi32> + %32 = tt.broadcast %28 : (tensor<128x1xi32>) -> tensor<128x64xi32> + %33 = tt.broadcast %31 : (tensor<1x64xi32>) -> tensor<128x64xi32> + %34 = arith.addi %32, %33 : tensor<128x64xi32> + %35 = tt.splat %arg0 : (!tt.ptr) -> tensor<128x64x!tt.ptr> + %36 = tt.addptr %35, %34 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> + %37 = tt.expand_dims %25 {axis = 1 : i32} : (tensor<64xi32>) -> tensor<64x1xi32> + %38 = tt.splat %arg8 : (i32) -> tensor<64x1xi32> + %39 = arith.muli %37, %38 : tensor<64x1xi32> + %40 = tt.expand_dims %24 {axis = 0 : i32} : (tensor<256xi32>) -> tensor<1x256xi32> + %41 = tt.splat %arg9 : (i32) -> tensor<1x256xi32> + %42 = arith.muli %40, %41 : tensor<1x256xi32> + %43 = tt.broadcast %39 : (tensor<64x1xi32>) -> tensor<64x256xi32> + %44 = tt.broadcast %42 : (tensor<1x256xi32>) -> tensor<64x256xi32> + %45 = arith.addi %43, %44 : tensor<64x256xi32> + %46 = tt.splat %arg1 : (!tt.ptr) -> tensor<64x256x!tt.ptr> + %47 = tt.addptr %46, %45 : tensor<64x256x!tt.ptr>, tensor<64x256xi32> + %48 = tt.splat %cst : (f32) -> tensor<128x256xf32> + %49 = arith.muli %arg7, %c64_i32 : i32 + %50 = tt.splat %49 : (i32) -> tensor<128x64xi32> + %51 = arith.muli %arg8, %c64_i32 : i32 + %52 = tt.splat %51 : (i32) -> tensor<64x256xi32> + %53:3 = scf.for %arg12 = %c0_i32 to %6 step %c1_i32 iter_args(%arg13 = %48, %arg14 = %36, %arg15 = %47) -> (tensor<128x256xf32>, tensor<128x64x!tt.ptr>, tensor<64x256x!tt.ptr>) : i32 { + %71 = tt.load %arg14 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xbf16> + %72 = tt.load 
%arg15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x256xbf16> + %73 = tt.dot %71, %72, %48 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xbf16> * tensor<64x256xbf16> -> tensor<128x256xf32> + %74 = arith.addf %arg13, %73 : tensor<128x256xf32> + %75 = tt.addptr %arg14, %50 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> + %76 = tt.addptr %arg15, %52 : tensor<64x256x!tt.ptr>, tensor<64x256xi32> + scf.yield %74, %75, %76 : tensor<128x256xf32>, tensor<128x64x!tt.ptr>, tensor<64x256x!tt.ptr> + } + %54 = arith.truncf %53#0 : tensor<128x256xf32> to tensor<128x256xbf16> + %55 = tt.splat %arg10 : (i32) -> tensor<128x1xi32> + %56 = arith.muli %55, %26 : tensor<128x1xi32> + %57 = tt.splat %arg2 : (!tt.ptr) -> tensor<128x1x!tt.ptr> + %58 = tt.addptr %57, %56 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> + %59 = tt.splat %arg11 : (i32) -> tensor<1x256xi32> + %60 = arith.muli %59, %40 : tensor<1x256xi32> + %61 = tt.broadcast %58 : (tensor<128x1x!tt.ptr>) -> tensor<128x256x!tt.ptr> + %62 = tt.broadcast %60 : (tensor<1x256xi32>) -> tensor<128x256xi32> + %63 = tt.addptr %61, %62 : tensor<128x256x!tt.ptr>, tensor<128x256xi32> + %64 = tt.splat %arg3 : (i32) -> tensor<128x1xi32> + %65 = arith.cmpi slt, %26, %64 : tensor<128x1xi32> + %66 = tt.splat %arg4 : (i32) -> tensor<1x256xi32> + %67 = arith.cmpi slt, %40, %66 : tensor<1x256xi32> + %68 = tt.broadcast %65 : (tensor<128x1xi1>) -> tensor<128x256xi1> + %69 = tt.broadcast %67 : (tensor<1x256xi1>) -> tensor<128x256xi1> + %70 = arith.andi %68, %69 : tensor<128x256xi1> + tt.store %63, %54, %70 {cache = 1 : i32, evict = 1 : i32} : tensor<128x256xbf16> + tt.return + } +} + +// CHECK: tt.func public @matmul_kernel_0123456789101112131415([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: !tt.ptr, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32, [[PARAM_11_:%.+]]: i32) { +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<0.000000e+00> : tensor<128x256xf32> +// CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index +// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_63_:%.+]] = arith.constant 63 : i32 +// CHECK-DAG: [[CST_255_:%.+]] = arith.constant 255 : i32 +// CHECK-DAG: [[CST_127_:%.+]] = arith.constant 127 : i32 +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32 +// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[CST_64_:%.+]] = arith.constant 64 : i32 +// CHECK-DAG: [[CST_256_1_:%.+]] = arith.constant 256 : i32 +// CHECK-DAG: [[CST_128_1_:%.+]] = arith.constant 128 : i32 +// CHECK-DAG: [[CST_8_:%.+]] = arith.constant 8 : i32 +// CHECK-DAG: [[VAR_0_:%.+]] = tt.get_program_id x : i32 +// CHECK: [[VAR_1_:%.+]] = arith.addi [[PARAM_3_]], [[CST_127_]] : i32 +// CHECK-DAG: [[VAR_2_:%.+]] = arith.divsi [[VAR_1_]], [[CST_128_1_]] : i32 +// CHECK-DAG: [[VAR_3_:%.+]] = arith.addi [[PARAM_4_]], [[CST_255_]] : i32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_4_:%.+]] = arith.divsi [[VAR_3_]], [[CST_256_1_]] : i32 +// CHECK-DAG: [[VAR_5_:%.+]] = arith.addi [[PARAM_5_]], [[CST_63_]] : i32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_6_:%.+]] = arith.divsi [[VAR_5_]], [[CST_64_]] : i32 +// CHECK-DAG: [[VAR_7_:%.+]] = arith.muli [[VAR_4_]], [[CST_8_]] : i32 +// CHECK: [[VAR_8_:%.+]] = arith.divsi [[VAR_0_]], [[VAR_7_]] : i32 +// 
CHECK: [[VAR_9_:%.+]] = arith.muli [[VAR_8_]], [[CST_8_]] : i32 +// CHECK: [[VAR_10_:%.+]] = arith.subi [[VAR_2_]], [[VAR_9_]] : i32 +// CHECK: [[VAR_11_:%.+]] = arith.cmpi slt, [[VAR_10_]], [[CST_8_]] : i32 +// CHECK: [[VAR_12_:%.+]] = arith.select [[VAR_11_]], [[VAR_10_]], [[CST_8_]] : i32 +// CHECK: [[VAR_13_:%.+]] = arith.remsi [[VAR_0_]], [[VAR_12_]] : i32 +// CHECK-DAG: [[VAR_14_:%.+]] = arith.addi [[VAR_9_]], [[VAR_13_]] : i32 +// CHECK-DAG: [[VAR_15_:%.+]] = arith.remsi [[VAR_0_]], [[VAR_7_]] : i32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_16_:%.+]] = arith.divsi [[VAR_15_]], [[VAR_12_]] : i32 +// CHECK-DAG: [[VAR_17_:%.+]] = arith.muli [[VAR_14_]], [[CST_128_1_]] : i32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_18_:%.+]] = arith.index_cast [[VAR_17_]] : i32 to index +// CHECK-DAG: [[VAR_19_:%.+]] = arith.index_cast [[VAR_17_]] : i32 to index +// CHECK-DAG: [[VAR_20_:%.+]] = arith.muli [[VAR_16_]], [[CST_256_1_]] : i32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_21_:%.+]] = arith.index_cast [[VAR_20_]] : i32 to index +// CHECK-DAG: [[VAR_22_:%.+]] = arith.index_cast [[VAR_20_]] : i32 to index +// CHECK-DAG: [[VAR_23_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_24_:%.+]] = arith.muli [[VAR_19_]], [[VAR_23_]] : index +// CHECK-DAG: [[VAR_25_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index +// CHECK-DAG: [[VAR_26_:%.+]] = arith.index_cast [[PARAM_8_]] : i32 to index +// CHECK-DAG: [[VAR_27_:%.+]] = arith.index_cast [[PARAM_9_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_28_:%.+]] = arith.muli [[VAR_22_]], [[VAR_27_]] : index +// CHECK-DAG: [[VAR_29_:%.+]] = arith.muli [[PARAM_7_]], [[CST_64_]] : i32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_30_:%.+]] = arith.index_cast [[VAR_29_]] : i32 to index +// CHECK-DAG: [[VAR_31_:%.+]] = arith.muli [[PARAM_8_]], [[CST_64_]] : i32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_32_:%.+]] = arith.index_cast [[VAR_31_]] : i32 to index +// CHECK-DAG: [[VAR_33_:%.+]]:3 = scf.for [[VAR_arg12_:%.+]] = [[CST_0_1_]] to [[VAR_6_]] step [[CST_1_]] iter_args([[VAR_arg13_:%.+]] = [[VAR_cst_]], [[VAR_arg14_:%.+]] = [[VAR_24_]], [[VAR_arg15_:%.+]] = [[CST_0_]]) -> (tensor<128x256xf32>, index, index) : i32 { +// CHECK-DAG: [[VAR_52_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [64, 256], strides: {{.}}[[VAR_26_]], [[VAR_27_]]{{.}}, offsets: {{.}}[[PARAM_1_]]5, [[VAR_28_]]{{.}}, parent_sizes: [0, 0] : to tensor<64x256x!tt.ptr> +// CHECK-DAG: [[VAR_53_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [128, 64], strides: {{.}}[[VAR_23_]], [[VAR_25_]]{{.}}, offsets: {{.}}[[VAR_arg14_]], [[CST_0_]]{{.}}, parent_sizes: [0, 0] : to tensor<128x64x!tt.ptr> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_54_:%.+]] = "tts.load"([[VAR_53_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<128x64x!tt.ptr>) -> tensor<128x64xbf16> +// CHECK-DAG: [[VAR_55_:%.+]] = "tts.load"([[VAR_52_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<64x256x!tt.ptr>) -> tensor<64x256xbf16> +// CHECK: [[VAR_56_:%.+]] = tt.dot [[VAR_54_]], [[VAR_55_]], [[VAR_cst_]] {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xbf16> * tensor<64x256xbf16> -> tensor<128x256xf32> +// CHECK-DAG: [[VAR_57_:%.+]] = arith.addf [[VAR_arg13_]], [[VAR_56_]] : tensor<128x256xf32> +// CHECK-DAG: [[VAR_58_:%.+]] = arith.addi 
[[VAR_arg14_]], [[VAR_30_]] : index +// CHECK-DAG: [[VAR_59_:%.+]] = arith.addi [[VAR_arg15_]], [[VAR_32_]] : index +// CHECK: scf.yield [[VAR_57_]], [[VAR_58_]], [[VAR_59_]] : tensor<128x256xf32>, index, index +// CHECK: } +// CHECK-DAG: [[VAR_34_:%.+]] = arith.truncf [[VAR_33_]]#0 : tensor<128x256xf32> to tensor<128x256xbf16> +// CHECK-DAG: [[VAR_35_:%.+]] = arith.index_cast [[PARAM_10_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_36_:%.+]] = arith.muli [[VAR_18_]], [[VAR_35_]] : index +// CHECK-DAG: [[VAR_37_:%.+]] = arith.index_cast [[PARAM_11_]] : i32 to index +// CHECK: [[VAR_38_:%.+]] = arith.muli [[VAR_21_]], [[VAR_37_]] : index +// CHECK-DAG: [[VAR_39_:%.+]] = tts.make_tptr [[PARAM_2_]] to sizes: [128, 256], strides: {{.}}[[VAR_35_]], [[VAR_37_]]{{.}}, offsets: {{.}}[[VAR_36_]], [[VAR_38_]]{{.}}, parent_sizes: [0, 0] : to tensor<128x256x!tt.ptr> +// CHECK-DAG: [[VAR_40_:%.+]] = arith.index_cast [[VAR_17_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_41_:%.+]] = arith.addi [[VAR_40_]], [[CST_128_]] : index +// CHECK-DAG: [[VAR_42_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index +// CHECK: [[VAR_43_:%.+]] = arith.minsi [[VAR_41_]], [[VAR_42_]] : index +// CHECK-DAG: [[VAR_44_:%.+]] = arith.subi [[VAR_43_]], [[VAR_40_]] : index +// CHECK-DAG: [[VAR_45_:%.+]] = arith.index_cast [[VAR_20_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_46_:%.+]] = arith.addi [[VAR_45_]], [[CST_256_]] : index +// CHECK-DAG: [[VAR_47_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index +// CHECK: [[VAR_48_:%.+]] = arith.minsi [[VAR_46_]], [[VAR_47_]] : index +// CHECK-DAG: [[VAR_49_:%.+]] = arith.subi [[VAR_48_]], [[VAR_45_]] : index +// CHECK-DAG: [[VAR_50_:%.+]] = arith.minsi [[VAR_44_]], [[CST_128_]] : index +// CHECK: [[VAR_51_:%.+]] = arith.minsi [[VAR_49_]], [[CST_256_]] : index +// CHECK: "tts.store"([[VAR_39_]], [[VAR_34_]], [[VAR_50_]], [[VAR_51_]]) <{static_dims = array}> : (tensor<128x256x!tt.ptr>, tensor<128x256xbf16>, index, index) -> () +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/kernel-05-layer-norm-dwdb.mlir b/test/Conversion/TritonToStructured/kernel-05-layer-norm-dwdb.mlir new file mode 100644 index 00000000..b28bcb4c --- /dev/null +++ b/test/Conversion/TritonToStructured/kernel-05-layer-norm-dwdb.mlir @@ -0,0 +1,145 @@ +// RUN: triton-shared-opt --triton-to-structured --canonicalize %s | FileCheck %s + +module { + tt.func public @_layer_norm_bwd_dwdb_0123456(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: !tt.ptr, %arg4: i32, %arg5: i32) { + %c0_i32 = arith.constant 0 : i32 + %c256_i32 = arith.constant 256 : i32 + %cst = arith.constant 0.000000e+00 : f32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %3 = tt.splat %1 : (i32) -> tensor<256xi32> + %4 = arith.addi %3, %2 : tensor<256xi32> + %5 = tt.splat %cst : (f32) -> tensor<256x256xf32> + %6 = tt.splat %arg4 : (i32) -> tensor<256x1xi32> + %7 = tt.expand_dims %4 {axis = 0 : i32} : (tensor<256xi32>) -> tensor<1x256xi32> + %8 = tt.splat %arg5 : (i32) -> tensor<1x256xi32> + %9 = arith.cmpi slt, %7, %8 : tensor<1x256xi32> + %10 = tt.broadcast %9 : (tensor<1x256xi1>) -> tensor<256x256xi1> + %11 = tt.splat %arg5 : (i32) -> tensor<256x1xi32> + %12 = tt.broadcast %7 : (tensor<1x256xi32>) -> tensor<256x256xi32> + %13 = tt.splat %arg0 : (!tt.ptr) -> tensor<256x256x!tt.ptr> + %14 = tt.splat 
%arg1 : (!tt.ptr) -> tensor<256x256x!tt.ptr> + %15:2 = scf.for %arg6 = %c0_i32 to %arg4 step %c256_i32 iter_args(%arg7 = %5, %arg8 = %5) -> (tensor<256x256xf32>, tensor<256x256xf32>) : i32 { + %24 = tt.splat %arg6 : (i32) -> tensor<256xi32> + %25 = arith.addi %24, %2 : tensor<256xi32> + %26 = tt.expand_dims %25 {axis = 1 : i32} : (tensor<256xi32>) -> tensor<256x1xi32> + %27 = arith.cmpi slt, %26, %6 : tensor<256x1xi32> + %28 = tt.broadcast %27 : (tensor<256x1xi1>) -> tensor<256x256xi1> + %29 = arith.andi %28, %10 : tensor<256x256xi1> + %30 = arith.muli %26, %11 : tensor<256x1xi32> + %31 = tt.broadcast %30 : (tensor<256x1xi32>) -> tensor<256x256xi32> + %32 = arith.addi %31, %12 : tensor<256x256xi32> + %33 = tt.addptr %13, %32 : tensor<256x256x!tt.ptr>, tensor<256x256xi32> + %34 = tt.load %33, %29, %5 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x256xf32> + %35 = arith.addf %arg7, %34 : tensor<256x256xf32> + %36 = tt.addptr %14, %32 : tensor<256x256x!tt.ptr>, tensor<256x256xi32> + %37 = tt.load %36, %29, %5 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x256xf32> + %38 = arith.addf %arg8, %37 : tensor<256x256xf32> + scf.yield %35, %38 : tensor<256x256xf32>, tensor<256x256xf32> + } + %16 = "tt.reduce"(%15#0) ({ + ^bb0(%arg6: f32, %arg7: f32): + %24 = arith.addf %arg6, %arg7 : f32 + tt.reduce.return %24 : f32 + }) {axis = 0 : i32} : (tensor<256x256xf32>) -> tensor<256xf32> + %17 = "tt.reduce"(%15#1) ({ + ^bb0(%arg6: f32, %arg7: f32): + %24 = arith.addf %arg6, %arg7 : f32 + tt.reduce.return %24 : f32 + }) {axis = 0 : i32} : (tensor<256x256xf32>) -> tensor<256xf32> + %18 = tt.splat %arg5 : (i32) -> tensor<256xi32> + %19 = arith.cmpi slt, %4, %18 : tensor<256xi32> + %20 = tt.splat %arg2 : (!tt.ptr) -> tensor<256x!tt.ptr> + %21 = tt.addptr %20, %4 : tensor<256x!tt.ptr>, tensor<256xi32> + tt.store %21, %16, %19 {cache = 1 : i32, evict = 1 : i32} : tensor<256xf32> + %22 = tt.splat %arg3 : (!tt.ptr) -> tensor<256x!tt.ptr> + %23 = tt.addptr %22, %4 : tensor<256x!tt.ptr>, tensor<256xi32> + tt.store %23, %17, %19 {cache = 1 : i32, evict = 1 : i32} : tensor<256xf32> + tt.return + } +} + +// CHECK: tt.func public @_layer_norm_bwd_dwdb_0123456([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: !tt.ptr, [[PARAM_3_:%.+]]: !tt.ptr, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32) { +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<0.000000e+00> : tensor<256x256xf32> +// CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[CST_256_1_:%.+]] = arith.constant 256 : i32 +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: [[VAR_0_:%.+]] = tt.get_program_id x : i32 +// CHECK: [[VAR_1_:%.+]] = arith.muli [[VAR_0_]], [[CST_256_1_]] : i32 +// CHECK-DAG: [[VAR_2_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index +// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index +// CHECK-DAG: [[VAR_4_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index +// CHECK-DAG: [[VAR_5_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index +// CHECK-DAG: [[VAR_6_:%.+]]:2 = scf.for [[VAR_arg6_:%.+]] = [[CST_0_]] to [[PARAM_4_]] step [[CST_256_1_]] iter_args([[VAR_arg7_:%.+]] = [[VAR_cst_]], [[VAR_arg8_:%.+]] = [[VAR_cst_]]) -> (tensor<256x256xf32>, tensor<256x256xf32>) : i32 { +// CHECK-DAG: [[VAR_21_:%.+]] = arith.index_cast [[VAR_arg6_]] : i32 to index +// CHECK-DAG: [[VAR_22_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index +// CHECK: 
[[VAR_23_:%.+]] = arith.muli [[VAR_21_]], [[VAR_22_]] : index +// CHECK-DAG: [[VAR_24_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [256, 256], strides: {{.}}[[VAR_22_]], 1], offsets: {{.}}[[VAR_23_]], [[VAR_5_]]{{.}}, parent_sizes: [0, 0] : to tensor<256x256x!tt.ptr> +// CHECK-DAG: [[VAR_25_:%.+]] = arith.index_cast [[VAR_arg6_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_26_:%.+]] = arith.addi [[VAR_25_]], [[CST_256_]] : index +// CHECK-DAG: [[VAR_27_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index +// CHECK: [[VAR_28_:%.+]] = arith.minsi [[VAR_26_]], [[VAR_27_]] : index +// CHECK-DAG: [[VAR_29_:%.+]] = arith.subi [[VAR_28_]], [[VAR_25_]] : index +// CHECK-DAG: [[VAR_30_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_31_:%.+]] = arith.addi [[VAR_30_]], [[CST_256_]] : index +// CHECK-DAG: [[VAR_32_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index +// CHECK: [[VAR_33_:%.+]] = arith.minsi [[VAR_31_]], [[VAR_32_]] : index +// CHECK-DAG: [[VAR_34_:%.+]] = arith.subi [[VAR_33_]], [[VAR_30_]] : index +// CHECK-DAG: [[VAR_35_:%.+]] = arith.minsi [[VAR_29_]], [[CST_256_]] : index +// CHECK: [[VAR_36_:%.+]] = arith.minsi [[VAR_34_]], [[CST_256_]] : index +// CHECK: [[VAR_37_:%.+]] = "tts.load"([[VAR_24_]], [[VAR_35_]], [[VAR_36_]], [[CST_0_dot_000000_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<256x256x!tt.ptr>, index, index, f32) -> tensor<256x256xf32> +// CHECK-DAG: [[VAR_38_:%.+]] = arith.addf [[VAR_arg7_]], [[VAR_37_]] : tensor<256x256xf32> +// CHECK-DAG: [[VAR_39_:%.+]] = arith.index_cast [[VAR_arg6_]] : i32 to index +// CHECK-DAG: [[VAR_40_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index +// CHECK: [[VAR_41_:%.+]] = arith.muli [[VAR_39_]], [[VAR_40_]] : index +// CHECK-DAG: [[VAR_42_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [256, 256], strides: {{.}}[[VAR_40_]], 1], offsets: {{.}}[[VAR_41_]], [[VAR_4_]]{{.}}, parent_sizes: [0, 0] : to tensor<256x256x!tt.ptr> +// CHECK-DAG: [[VAR_43_:%.+]] = arith.index_cast [[VAR_arg6_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_44_:%.+]] = arith.addi [[VAR_43_]], [[CST_256_]] : index +// CHECK-DAG: [[VAR_45_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index +// CHECK: [[VAR_46_:%.+]] = arith.minsi [[VAR_44_]], [[VAR_45_]] : index +// CHECK-DAG: [[VAR_47_:%.+]] = arith.subi [[VAR_46_]], [[VAR_43_]] : index +// CHECK-DAG: [[VAR_48_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_49_:%.+]] = arith.addi [[VAR_48_]], [[CST_256_]] : index +// CHECK-DAG: [[VAR_50_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index +// CHECK: [[VAR_51_:%.+]] = arith.minsi [[VAR_49_]], [[VAR_50_]] : index +// CHECK-DAG: [[VAR_52_:%.+]] = arith.subi [[VAR_51_]], [[VAR_48_]] : index +// CHECK-DAG: [[VAR_53_:%.+]] = arith.minsi [[VAR_47_]], [[CST_256_]] : index +// CHECK: [[VAR_54_:%.+]] = arith.minsi [[VAR_52_]], [[CST_256_]] : index +// CHECK: [[VAR_55_:%.+]] = "tts.load"([[VAR_42_]], [[VAR_53_]], [[VAR_54_]], [[CST_0_dot_000000_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<256x256x!tt.ptr>, index, index, f32) -> tensor<256x256xf32> +// CHECK: [[VAR_56_:%.+]] = arith.addf [[VAR_arg8_]], [[VAR_55_]] : tensor<256x256xf32> +// CHECK: scf.yield [[VAR_38_]], [[VAR_56_]] : tensor<256x256xf32>, tensor<256x256xf32> +// CHECK: } +// CHECK: [[VAR_7_:%.+]] = "tt.reduce"([[VAR_6_]]#0) <{axis = 0 : i32}> ({ +// CHECK: 
^bb0([[VAR_arg6_]]: f32, [[VAR_arg7_]]: f32): +// CHECK: [[VAR_21_1_:%.+]] = arith.addf [[VAR_arg6_]], [[VAR_arg7_]] : f32 +// CHECK: tt.reduce.return [[VAR_21_1_]] : f32 +// CHECK: }) : (tensor<256x256xf32>) -> tensor<256xf32> +// CHECK: [[VAR_8_:%.+]] = "tt.reduce"([[VAR_6_]]#1) <{axis = 0 : i32}> ({ +// CHECK: ^bb0([[VAR_arg6_]]: f32, [[VAR_arg7_]]: f32): +// CHECK: [[VAR_21_2_:%.+]] = arith.addf [[VAR_arg6_]], [[VAR_arg7_]] : f32 +// CHECK: tt.reduce.return [[VAR_21_2_]] : f32 +// CHECK: }) : (tensor<256x256xf32>) -> tensor<256xf32> +// CHECK-DAG: [[VAR_9_:%.+]] = tts.make_tptr [[PARAM_2_]] to sizes: [256], strides: [1], offsets: {{.}}[[VAR_3_]]{{.}}, parent_sizes: [0] : to tensor<256x!tt.ptr> +// CHECK-DAG: [[VAR_10_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_11_:%.+]] = arith.addi [[VAR_10_]], [[CST_256_]] : index +// CHECK-DAG: [[VAR_12_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index +// CHECK: [[VAR_13_:%.+]] = arith.minsi [[VAR_11_]], [[VAR_12_]] : index +// CHECK: [[VAR_14_:%.+]] = arith.subi [[VAR_13_]], [[VAR_10_]] : index +// CHECK: "tts.store"([[VAR_9_]], [[VAR_7_]], [[VAR_14_]]) <{static_dims = array}> : (tensor<256x!tt.ptr>, tensor<256xf32>, index) -> () +// CHECK-DAG: [[VAR_15_:%.+]] = tts.make_tptr [[PARAM_3_]] to sizes: [256], strides: [1], offsets: {{.}}[[VAR_2_]]{{.}}, parent_sizes: [0] : to tensor<256x!tt.ptr> +// CHECK-DAG: [[VAR_16_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_17_:%.+]] = arith.addi [[VAR_16_]], [[CST_256_]] : index +// CHECK-DAG: [[VAR_18_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index +// CHECK: [[VAR_19_:%.+]] = arith.minsi [[VAR_17_]], [[VAR_18_]] : index +// CHECK: [[VAR_20_:%.+]] = arith.subi [[VAR_19_]], [[VAR_16_]] : index +// CHECK: "tts.store"([[VAR_15_]], [[VAR_8_]], [[VAR_20_]]) <{static_dims = array}> : (tensor<256x!tt.ptr>, tensor<256xf32>, index) -> () +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/kernel-05-layer-norm-fwd.mlir b/test/Conversion/TritonToStructured/kernel-05-layer-norm-fwd.mlir new file mode 100644 index 00000000..5412aefa --- /dev/null +++ b/test/Conversion/TritonToStructured/kernel-05-layer-norm-fwd.mlir @@ -0,0 +1,208 @@ +// RUN: triton-shared-opt --triton-to-structured --canonicalize %s | FileCheck %s + +module { + tt.func public @_layer_norm_fwd_fused_0123456789(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: !tt.ptr, %arg4: !tt.ptr, %arg5: !tt.ptr, %arg6: i32, %arg7: i32, %arg8: f32) { + %c256_i32 = arith.constant 256 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant 1.000000e+00 : f32 + %cst_0 = arith.constant 0.000000e+00 : f32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %arg6 : i32 + %2 = tt.addptr %arg1, %1 : !tt.ptr, i32 + %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %4 = tt.splat %cst_0 : (f32) -> tensor<256xf32> + %5 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %6 = tt.splat %arg7 : (i32) -> tensor<256xi32> + %7 = tt.splat %3 : (!tt.ptr) -> tensor<256x!tt.ptr> + %8 = scf.for %arg9 = %c0_i32 to %arg7 step %c256_i32 iter_args(%arg10 = %4) -> (tensor<256xf32>) : i32 { + %32 = tt.splat %arg9 : (i32) -> tensor<256xi32> + %33 = arith.addi %32, %5 : tensor<256xi32> + %34 = arith.cmpi slt, %33, %6 : tensor<256xi32> + %35 = tt.addptr %7, %33 : tensor<256x!tt.ptr>, tensor<256xi32> + %36 = tt.load %35, %34, %4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : 
tensor<256xf32> + %37 = arith.addf %arg10, %36 : tensor<256xf32> + scf.yield %37 : tensor<256xf32> + } + %9 = "tt.reduce"(%8) ({ + ^bb0(%arg9: f32, %arg10: f32): + %32 = arith.addf %arg9, %arg10 : f32 + tt.reduce.return %32 : f32 + }) {axis = 0 : i32} : (tensor<256xf32>) -> f32 + %10 = arith.sitofp %arg7 : i32 to f32 + %11 = arith.divf %9, %10 : f32 + %12 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %13 = tt.splat %arg7 : (i32) -> tensor<256xi32> + %14 = tt.splat %3 : (!tt.ptr) -> tensor<256x!tt.ptr> + %15 = tt.splat %11 : (f32) -> tensor<256xf32> + %16 = scf.for %arg9 = %c0_i32 to %arg7 step %c256_i32 iter_args(%arg10 = %4) -> (tensor<256xf32>) : i32 { + %32 = tt.splat %arg9 : (i32) -> tensor<256xi32> + %33 = arith.addi %32, %12 : tensor<256xi32> + %34 = arith.cmpi slt, %33, %13 : tensor<256xi32> + %35 = tt.addptr %14, %33 : tensor<256x!tt.ptr>, tensor<256xi32> + %36 = tt.load %35, %34, %4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32> + %37 = arith.subf %36, %15 : tensor<256xf32> + %38 = arith.select %34, %37, %4 : tensor<256xi1>, tensor<256xf32> + %39 = arith.mulf %38, %38 : tensor<256xf32> + %40 = arith.addf %arg10, %39 : tensor<256xf32> + scf.yield %40 : tensor<256xf32> + } + %17 = "tt.reduce"(%16) ({ + ^bb0(%arg9: f32, %arg10: f32): + %32 = arith.addf %arg9, %arg10 : f32 + tt.reduce.return %32 : f32 + }) {axis = 0 : i32} : (tensor<256xf32>) -> f32 + %18 = arith.divf %17, %10 : f32 + %19 = arith.addf %18, %arg8 : f32 + %20 = math.sqrt %19 : f32 + %21 = arith.divf %cst, %20 : f32 + %22 = tt.addptr %arg4, %0 : !tt.ptr, i32 + tt.store %22, %11 {cache = 1 : i32, evict = 1 : i32} : f32 + %23 = tt.addptr %arg5, %0 : !tt.ptr, i32 + tt.store %23, %21 {cache = 1 : i32, evict = 1 : i32} : f32 + %24 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %25 = tt.splat %arg7 : (i32) -> tensor<256xi32> + %26 = tt.splat %arg2 : (!tt.ptr) -> tensor<256x!tt.ptr> + %27 = tt.splat %arg3 : (!tt.ptr) -> tensor<256x!tt.ptr> + %28 = tt.splat %3 : (!tt.ptr) -> tensor<256x!tt.ptr> + %29 = tt.splat %11 : (f32) -> tensor<256xf32> + %30 = tt.splat %21 : (f32) -> tensor<256xf32> + %31 = tt.splat %2 : (!tt.ptr) -> tensor<256x!tt.ptr> + scf.for %arg9 = %c0_i32 to %arg7 step %c256_i32 : i32 { + %32 = tt.splat %arg9 : (i32) -> tensor<256xi32> + %33 = arith.addi %32, %24 : tensor<256xi32> + %34 = arith.cmpi slt, %33, %25 : tensor<256xi32> + %35 = tt.addptr %26, %33 : tensor<256x!tt.ptr>, tensor<256xi32> + %36 = tt.load %35, %34 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32> + %37 = tt.addptr %27, %33 : tensor<256x!tt.ptr>, tensor<256xi32> + %38 = tt.load %37, %34 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32> + %39 = tt.addptr %28, %33 : tensor<256x!tt.ptr>, tensor<256xi32> + %40 = tt.load %39, %34, %4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32> + %41 = arith.subf %40, %29 : tensor<256xf32> + %42 = arith.mulf %41, %30 : tensor<256xf32> + %43 = arith.mulf %42, %36 : tensor<256xf32> + %44 = arith.addf %43, %38 : tensor<256xf32> + %45 = tt.addptr %31, %33 : tensor<256x!tt.ptr>, tensor<256xi32> + tt.store %45, %44, %34 {cache = 1 : i32, evict = 1 : i32} : tensor<256xf32> + } + tt.return + } +} + +// CHECK: tt.func public @_layer_norm_fwd_fused_0123456789([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: !tt.ptr, [[PARAM_3_:%.+]]: !tt.ptr, [[PARAM_4_:%.+]]: !tt.ptr, [[PARAM_5_:%.+]]: !tt.ptr, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, 
[[PARAM_8_:%.+]]: f32) { +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<0.000000e+00> : tensor<256xf32> +// CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index +// CHECK-DAG: [[CST_256_1_:%.+]] = arith.constant 256 : i32 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: [[VAR_0_:%.+]] = tt.get_program_id x : i32 +// CHECK: [[VAR_1_:%.+]] = arith.muli [[VAR_0_]], [[PARAM_6_]] : i32 +// CHECK-DAG: [[VAR_2_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index +// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index +// CHECK-DAG: [[VAR_4_:%.+]] = scf.for [[VAR_arg9_:%.+]] = [[CST_0_]] to [[PARAM_7_]] step [[CST_256_1_]] iter_args([[VAR_arg10_:%.+]] = [[VAR_cst_]]) -> (tensor<256xf32>) : i32 { +// CHECK-DAG: [[VAR_21_:%.+]] = arith.index_cast [[VAR_arg9_]] : i32 to index +// CHECK: [[VAR_22_:%.+]] = arith.addi [[VAR_2_]], [[VAR_2_]]1 : index +// CHECK-DAG: [[VAR_23_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [256], strides: [1], offsets: {{.}}[[VAR_22_]]{{.}}, parent_sizes: [0] : to tensor<256x!tt.ptr> +// CHECK-DAG: [[VAR_24_:%.+]] = arith.index_cast [[VAR_arg9_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_25_:%.+]] = arith.addi [[VAR_24_]], [[CST_256_]] : index +// CHECK-DAG: [[VAR_26_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index +// CHECK: [[VAR_27_:%.+]] = arith.minsi [[VAR_25_]], [[VAR_26_]] : index +// CHECK: [[VAR_28_:%.+]] = arith.subi [[VAR_27_]], [[VAR_24_]] : index +// CHECK: [[VAR_29_:%.+]] = "tts.load"([[VAR_23_]], [[VAR_28_]], [[CST_0_dot_000000_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<256x!tt.ptr>, index, f32) -> tensor<256xf32> +// CHECK: [[VAR_30_:%.+]] = arith.addf [[VAR_arg10_]], [[VAR_29_]] : tensor<256xf32> +// CHECK: scf.yield [[VAR_30_]] : tensor<256xf32> +// CHECK: } +// CHECK: [[VAR_5_:%.+]] = "tt.reduce"([[VAR_4_]]) <{axis = 0 : i32}> ({ +// CHECK: ^bb0([[VAR_arg9_]]: f32, [[VAR_arg10_]]: f32): +// CHECK: [[VAR_21_1_:%.+]] = arith.addf [[VAR_arg9_]], [[VAR_arg10_]] : f32 +// CHECK: tt.reduce.return [[VAR_21_1_]] : f32 +// CHECK: }) : (tensor<256xf32>) -> f32 +// CHECK: [[VAR_6_:%.+]] = arith.sitofp [[PARAM_7_]] : i32 to f32 +// CHECK-DAG: [[VAR_7_:%.+]] = arith.divf [[VAR_5_]], [[VAR_6_]] : f32 +// CHECK-DAG: [[VAR_8_:%.+]] = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> +// CHECK-DAG: [[VAR_9_:%.+]] = tt.splat [[PARAM_7_]] : (i32) -> tensor<256xi32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_10_:%.+]] = tt.splat [[VAR_7_]] : (f32) -> tensor<256xf32> +// CHECK-DAG: [[VAR_11_:%.+]] = scf.for [[VAR_arg9_1_:%.+]] = [[CST_0_]] to [[PARAM_7_]] step [[CST_256_1_]] iter_args([[VAR_arg10_1_:%.+]] = [[VAR_cst_]]) -> (tensor<256xf32>) : i32 { +// CHECK-DAG: [[VAR_21_2_:%.+]] = tt.splat [[VAR_arg9_1_]] : (i32) -> tensor<256xi32> +// CHECK: [[VAR_22_1_:%.+]] = arith.addi [[VAR_21_2_]], [[VAR_8_]] : tensor<256xi32> +// CHECK-DAG: [[VAR_23_1_:%.+]] = arith.cmpi slt, [[VAR_22_1_]], [[VAR_9_]] : tensor<256xi32> +// CHECK-DAG: [[VAR_24_1_:%.+]] = arith.index_cast [[VAR_arg9_1_]] : i32 to index +// CHECK: [[VAR_25_1_:%.+]] = arith.addi [[VAR_2_]], [[VAR_2_]]4 : index +// CHECK-DAG: [[VAR_26_1_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [256], strides: [1], offsets: {{.}}[[VAR_25_1_]]{{.}}, parent_sizes: [0] : to tensor<256x!tt.ptr> +// CHECK-DAG: [[VAR_27_1_:%.+]] = 
arith.index_cast [[VAR_arg9_1_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_28_1_:%.+]] = arith.addi [[VAR_27_1_]], [[CST_256_]] : index +// CHECK-DAG: [[VAR_29_1_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index +// CHECK: [[VAR_30_1_:%.+]] = arith.minsi [[VAR_28_1_]], [[VAR_29_1_]] : index +// CHECK: [[VAR_31_:%.+]] = arith.subi [[VAR_30_1_]], [[VAR_27_1_]] : index +// CHECK: [[VAR_32_:%.+]] = "tts.load"([[VAR_26_1_]], [[VAR_31_]], [[CST_0_dot_000000_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<256x!tt.ptr>, index, f32) -> tensor<256xf32> +// CHECK: [[VAR_33_:%.+]] = arith.subf [[VAR_32_]], [[VAR_10_]] : tensor<256xf32> +// CHECK: [[VAR_34_:%.+]] = arith.select [[VAR_23_1_]], [[VAR_33_]], [[VAR_cst_]] : tensor<256xi1>, tensor<256xf32> +// CHECK: [[VAR_35_:%.+]] = arith.mulf [[VAR_34_]], [[VAR_34_]] : tensor<256xf32> +// CHECK: [[VAR_36_:%.+]] = arith.addf [[VAR_arg10_1_]], [[VAR_35_]] : tensor<256xf32> +// CHECK: scf.yield [[VAR_36_]] : tensor<256xf32> +// CHECK: } +// CHECK: [[VAR_12_:%.+]] = "tt.reduce"([[VAR_11_]]) <{axis = 0 : i32}> ({ +// CHECK: ^bb0([[VAR_arg9_1_]]: f32, [[VAR_arg10_1_]]: f32): +// CHECK: [[VAR_21_3_:%.+]] = arith.addf [[VAR_arg9_1_]], [[VAR_arg10_1_]] : f32 +// CHECK: tt.reduce.return [[VAR_21_3_]] : f32 +// CHECK: }) : (tensor<256xf32>) -> f32 +// CHECK: [[VAR_13_:%.+]] = arith.divf [[VAR_12_]], [[VAR_6_]] : f32 +// CHECK: [[VAR_14_:%.+]] = arith.addf [[VAR_13_]], [[PARAM_8_]] : f32 +// CHECK: [[VAR_15_:%.+]] = math.sqrt [[VAR_14_]] : f32 +// CHECK-DAG: [[VAR_16_:%.+]] = arith.divf [[CST_1_dot_000000_]], [[VAR_15_]] : f32 +// CHECK-DAG: [[VAR_17_:%.+]] = tt.addptr [[PARAM_4_]], [[VAR_0_]] : !tt.ptr, i32 +// CHECK: tt.store [[VAR_17_]], [[VAR_7_]] {cache = 1 : i32, evict = 1 : i32} : f32 +// CHECK: [[VAR_18_:%.+]] = tt.addptr [[PARAM_5_]], [[VAR_0_]] : !tt.ptr, i32 +// CHECK: tt.store [[VAR_18_]], [[VAR_16_]] {cache = 1 : i32, evict = 1 : i32} : f32 +// CHECK-DAG: [[VAR_19_:%.+]] = tt.splat [[VAR_7_]] : (f32) -> tensor<256xf32> +// CHECK-DAG: [[VAR_20_:%.+]] = tt.splat [[VAR_16_]] : (f32) -> tensor<256xf32> +// CHECK: scf.for [[VAR_arg9_1_:%.+]] = [[CST_0_]] to [[PARAM_7_]] step [[CST_256_1_]] : i32 { +// CHECK: [[VAR_21_4_:%.+]] = arith.index_cast [[VAR_arg9_1_]] : i32 to index +// CHECK-DAG: [[VAR_22_2_:%.+]] = tts.make_tptr [[PARAM_2_]] to sizes: [256], strides: [1], offsets: {{.}}[[VAR_21_4_]]{{.}}, parent_sizes: [0] : to tensor<256x!tt.ptr> +// CHECK-DAG: [[VAR_23_2_:%.+]] = arith.index_cast [[VAR_arg9_1_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_24_2_:%.+]] = arith.addi [[VAR_23_2_]], [[CST_256_]] : index +// CHECK-DAG: [[VAR_25_2_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index +// CHECK: [[VAR_26_2_:%.+]] = arith.minsi [[VAR_24_2_]], [[VAR_25_2_]] : index +// CHECK: [[VAR_27_2_:%.+]] = arith.subi [[VAR_26_2_]], [[VAR_23_2_]] : index +// CHECK-DAG: [[VAR_28_2_:%.+]] = "tts.load"([[VAR_22_2_]], [[VAR_27_2_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<256x!tt.ptr>, index) -> tensor<256xf32> +// CHECK-DAG: [[VAR_29_2_:%.+]] = arith.index_cast [[VAR_arg9_1_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_30_2_:%.+]] = tts.make_tptr [[PARAM_3_]] to sizes: [256], strides: [1], offsets: {{.}}[[VAR_29_2_]]{{.}}, parent_sizes: [0] : to tensor<256x!tt.ptr> +// CHECK-DAG: [[VAR_31_1_:%.+]] = arith.index_cast [[VAR_arg9_1_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: 
[[VAR_32_1_:%.+]] = arith.addi [[VAR_31_1_]], [[CST_256_]] : index +// CHECK-DAG: [[VAR_33_1_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index +// CHECK: [[VAR_34_1_:%.+]] = arith.minsi [[VAR_32_1_]], [[VAR_33_1_]] : index +// CHECK: [[VAR_35_1_:%.+]] = arith.subi [[VAR_34_1_]], [[VAR_31_1_]] : index +// CHECK-DAG: [[VAR_36_1_:%.+]] = "tts.load"([[VAR_30_2_]], [[VAR_35_1_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<256x!tt.ptr>, index) -> tensor<256xf32> +// CHECK-DAG: [[VAR_37_:%.+]] = arith.index_cast [[VAR_arg9_1_]] : i32 to index +// CHECK: [[VAR_38_:%.+]] = arith.addi [[VAR_2_]], [[VAR_37_]] : index +// CHECK-DAG: [[VAR_39_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [256], strides: [1], offsets: {{.}}[[VAR_38_]]{{.}}, parent_sizes: [0] : to tensor<256x!tt.ptr> +// CHECK-DAG: [[VAR_40_:%.+]] = arith.index_cast [[VAR_arg9_1_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_41_:%.+]] = arith.addi [[VAR_40_]], [[CST_256_]] : index +// CHECK-DAG: [[VAR_42_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index +// CHECK: [[VAR_43_:%.+]] = arith.minsi [[VAR_41_]], [[VAR_42_]] : index +// CHECK: [[VAR_44_:%.+]] = arith.subi [[VAR_43_]], [[VAR_40_]] : index +// CHECK: [[VAR_45_:%.+]] = "tts.load"([[VAR_39_]], [[VAR_44_]], [[CST_0_dot_000000_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<256x!tt.ptr>, index, f32) -> tensor<256xf32> +// CHECK: [[VAR_46_:%.+]] = arith.subf [[VAR_45_]], [[VAR_19_]] : tensor<256xf32> +// CHECK: [[VAR_47_:%.+]] = arith.mulf [[VAR_46_]], [[VAR_20_]] : tensor<256xf32> +// CHECK: [[VAR_48_:%.+]] = arith.mulf [[VAR_47_]], [[VAR_28_2_]] : tensor<256xf32> +// CHECK-DAG: [[VAR_49_:%.+]] = arith.addf [[VAR_48_]], [[VAR_36_1_]] : tensor<256xf32> +// CHECK-DAG: [[VAR_50_:%.+]] = arith.index_cast [[VAR_arg9_1_]] : i32 to index +// CHECK: [[VAR_51_:%.+]] = arith.addi [[VAR_3_]], [[VAR_50_]] : index +// CHECK-DAG: [[VAR_52_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [256], strides: [1], offsets: {{.}}[[VAR_51_]]{{.}}, parent_sizes: [0] : to tensor<256x!tt.ptr> +// CHECK-DAG: [[VAR_53_:%.+]] = arith.index_cast [[VAR_arg9_1_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_54_:%.+]] = arith.addi [[VAR_53_]], [[CST_256_]] : index +// CHECK-DAG: [[VAR_55_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index +// CHECK: [[VAR_56_:%.+]] = arith.minsi [[VAR_54_]], [[VAR_55_]] : index +// CHECK: [[VAR_57_:%.+]] = arith.subi [[VAR_56_]], [[VAR_53_]] : index +// CHECK: "tts.store"([[VAR_52_]], [[VAR_49_]], [[VAR_57_]]) <{static_dims = array}> : (tensor<256x!tt.ptr>, tensor<256xf32>, index) -> () +// CHECK: } +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/masked_ldst_1d.mlir b/test/Conversion/TritonToStructured/masked_ldst_1d.mlir new file mode 100644 index 00000000..ff9275eb --- /dev/null +++ b/test/Conversion/TritonToStructured/masked_ldst_1d.mlir @@ -0,0 +1,36 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr, + %arg2 : i32 + ) + { + %0 = tt.splat %arg0 : (!tt.ptr) -> tensor<128x!tt.ptr> + %1 = tt.splat %arg1 : (!tt.ptr) -> tensor<128x!tt.ptr> + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %ldptr = tt.addptr %0, %2 : tensor<128x!tt.ptr>, tensor<128xi32> + %stptr = tt.addptr %1, %2 : tensor<128x!tt.ptr>, tensor<128xi32> + %nans = arith.constant dense<0xFF80> : tensor<128xbf16> + %5 = 
tt.splat %arg2 : (i32) -> tensor<128xi32> + %mask = arith.cmpi slt, %2, %5 : tensor<128xi32> + %buff = tt.load %ldptr, %mask, %nans {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xbf16> + tt.store %stptr, %buff, %mask : tensor<128xbf16> + tt.return + } +} + +// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32) { +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF80 : bf16 +// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index +// CHECK-DAG: [[VAR_0_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [128], strides: [1], offsets: [0], parent_sizes: [0] : to tensor<128x!tt.ptr> +// CHECK-DAG: [[VAR_1_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [128], strides: [1], offsets: [0], parent_sizes: [0] : to tensor<128x!tt.ptr> +// CHECK-DAG: [[VAR_2_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index +// CHECK: [[VAR_3_:%.+]] = arith.minsi [[VAR_2_]], [[CST_128_]] : index +// CHECK-DAG: [[VAR_4_:%.+]] = "tts.load"([[VAR_0_]], [[VAR_3_]], [[CST_0_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<128x!tt.ptr>, index, bf16) -> tensor<128xbf16> +// CHECK-DAG: [[VAR_5_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index +// CHECK: [[VAR_6_:%.+]] = arith.minsi [[VAR_5_]], [[CST_128_]] : index +// CHECK: "tts.store"([[VAR_1_]], [[VAR_4_]], [[VAR_6_]]) <{static_dims = array}> : (tensor<128x!tt.ptr>, tensor<128xbf16>, index) -> () +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/masked_ldst_2d.mlir b/test/Conversion/TritonToStructured/masked_ldst_2d.mlir new file mode 100644 index 00000000..ea4696dc --- /dev/null +++ b/test/Conversion/TritonToStructured/masked_ldst_2d.mlir @@ -0,0 +1,97 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr, + %arg2 : i32, + %arg3 : i32 + ) + { + // Mimic a scenario where the raw pointer points to a buffer with dimension (1024, 1024) + // in row-major, but the actual tensor size is (arg2, arg3). + // We are trying to load a 128x256 sub-buffer starting at (2, 3). 
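+    // (A worked derivation of the numbers listed below, using only values from
+    // this test: starting at row 2, column 3 of the row-major (1024, 1024)
+    // buffer gives a flattened start offset of 2 + 3 * 1024 = 3074.)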
+ // The resulting memref: + // offset = 3074 + // size[1] = 128 + // size[0] = 256 + // stride[0] = 1024 + // stride[1] = 1 + %0 = tt.splat %arg0 : (!tt.ptr) -> tensor<128x256x!tt.ptr> + %1 = tt.splat %arg1 : (!tt.ptr) -> tensor<128x256x!tt.ptr> + // horizontal index + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %c2 = arith.constant 2 : i32 + %c2tensor = tt.splat %c2 : (i32) -> tensor<128xi32> + %offset2 = arith.addi %2, %c2tensor : tensor<128xi32> + %3 = tt.expand_dims %offset2 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32> + %4 = tt.broadcast %3 : (tensor<128x1xi32>) -> tensor<128x256xi32> + // vertical index + %5 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %c3 = arith.constant 3 : i32 + %c3tensor = tt.splat %c3 : (i32) -> tensor<256xi32> + %offset5 = arith.addi %5, %c3tensor : tensor<256xi32> + %c1024 = arith.constant 1024 : i32 + %c1024tensor = tt.splat %c1024 : (i32) -> tensor<256xi32> + %scale5 = arith.muli %offset5, %c1024tensor : tensor<256xi32> + %6 = tt.expand_dims %scale5 {axis = 0 : i32} : (tensor<256xi32>) -> tensor<1x256xi32> + %7 = tt.broadcast %6 : (tensor<1x256xi32>) -> tensor<128x256xi32> + // combined index + %index = arith.addi %4, %7 : tensor<128x256xi32> + %ldptr = tt.addptr %0, %index : tensor<128x256x!tt.ptr>, tensor<128x256xi32> + %stptr = tt.addptr %1, %index : tensor<128x256x!tt.ptr>, tensor<128x256xi32> + // other value for masked load + %cnan = arith.constant 0xFF80 : bf16 + %nans = tt.splat %cnan : (bf16) -> tensor<128x256xbf16> + // horizontal mask + %8 = tt.splat %arg2 : (i32) -> tensor<128xi32> + %9 = arith.cmpi slt, %offset2, %8 : tensor<128xi32> + %10 = tt.expand_dims %9 {axis = 1 : i32} : (tensor<128xi1>) -> tensor<128x1xi1> + %11 = tt.broadcast %10 : (tensor<128x1xi1>) -> tensor<128x256xi1> + // vertical mask + %12 = tt.splat %arg3 : (i32) -> tensor<256xi32> + %13 = arith.cmpi slt, %offset5, %12 : tensor<256xi32> + %14 = tt.expand_dims %13 {axis = 0 : i32} : (tensor<256xi1>) -> tensor<1x256xi1> + %15 = tt.broadcast %14 : (tensor<1x256xi1>) -> tensor<128x256xi1> + // combined mask + %mask = arith.andi %11, %15 : tensor<128x256xi1> + // dim0 = min(%arg2, 128), dim1 = min(%arg3, 256) + %buff = tt.load %ldptr, %mask, %nans {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x256xbf16> + tt.store %stptr, %buff, %mask : tensor<128x256xbf16> + tt.return + } +} + +// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32) { +// CHECK-DAG: [[CST_3072_:%.+]] = arith.constant 3072 : index +// CHECK-DAG: [[CST_1024_:%.+]] = arith.constant 1024 : index +// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index +// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index +// CHECK-DAG: [[CST_259_:%.+]] = arith.constant 259 : index +// CHECK-DAG: [[CST_130_:%.+]] = arith.constant 130 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF80 : bf16 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_0_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [128, 256], strides: [1, [[CST_1024_]]{{.}}, offsets: {{.}}[[CST_2_]], [[CST_3072_]]{{.}}, parent_sizes: [0, 0] : to tensor<128x256x!tt.ptr> +// CHECK-DAG: [[VAR_1_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [128, 256], strides: [1, [[CST_1024_]]{{.}}, offsets: {{.}}[[CST_2_]], [[CST_3072_]]{{.}}, parent_sizes: [0, 0] : to 
tensor<128x256x!tt.ptr> +// CHECK-DAG: [[VAR_2_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index +// CHECK: [[VAR_3_:%.+]] = arith.minsi [[VAR_2_]], [[CST_130_]] : index +// CHECK-DAG: [[VAR_4_:%.+]] = arith.subi [[VAR_3_]], [[CST_2_]] : index +// CHECK-DAG: [[VAR_5_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index +// CHECK: [[VAR_6_:%.+]] = arith.minsi [[VAR_5_]], [[CST_259_]] : index +// CHECK-DAG: [[VAR_7_:%.+]] = arith.subi [[VAR_6_]], [[CST_3_]] : index +// CHECK-DAG: [[VAR_8_:%.+]] = arith.minsi [[VAR_4_]], [[CST_128_]] : index +// CHECK: [[VAR_9_:%.+]] = arith.minsi [[VAR_7_]], [[CST_256_]] : index +// CHECK-DAG: [[VAR_10_:%.+]] = "tts.load"([[VAR_0_]], [[VAR_8_]], [[VAR_9_]], [[CST_0_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<128x256x!tt.ptr>, index, index, bf16) -> tensor<128x256xbf16> +// CHECK-DAG: [[VAR_11_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index +// CHECK: [[VAR_12_:%.+]] = arith.minsi [[VAR_11_]], [[CST_130_]] : index +// CHECK-DAG: [[VAR_13_:%.+]] = arith.subi [[VAR_12_]], [[CST_2_]] : index +// CHECK-DAG: [[VAR_14_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index +// CHECK: [[VAR_15_:%.+]] = arith.minsi [[VAR_14_]], [[CST_259_]] : index +// CHECK-DAG: [[VAR_16_:%.+]] = arith.subi [[VAR_15_]], [[CST_3_]] : index +// CHECK-DAG: [[VAR_17_:%.+]] = arith.minsi [[VAR_13_]], [[CST_128_]] : index +// CHECK: [[VAR_18_:%.+]] = arith.minsi [[VAR_16_]], [[CST_256_]] : index +// CHECK: "tts.store"([[VAR_1_]], [[VAR_1_]]0, [[VAR_1_]]7, [[VAR_1_]]8) <{static_dims = array}> : (tensor<128x256x!tt.ptr>, tensor<128x256xbf16>, index, index) -> () +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/masked_ldst_sitofp_other.mlir b/test/Conversion/TritonToStructured/masked_ldst_sitofp_other.mlir new file mode 100644 index 00000000..605b7db4 --- /dev/null +++ b/test/Conversion/TritonToStructured/masked_ldst_sitofp_other.mlir @@ -0,0 +1,38 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr, + %arg2 : i32 + ) + { + %0 = tt.splat %arg0 : (!tt.ptr) -> tensor<128x!tt.ptr> + %1 = tt.splat %arg1 : (!tt.ptr) -> tensor<128x!tt.ptr> + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %ldptr = tt.addptr %0, %2 : tensor<128x!tt.ptr>, tensor<128xi32> + %stptr = tt.addptr %1, %2 : tensor<128x!tt.ptr>, tensor<128xi32> + %c7_i32 = arith.constant 7 : i32 + %splat_c7_i32 = tt.splat %c7_i32 : (i32) -> tensor<128xi32> + %splat_c7_bf16 = arith.sitofp %splat_c7_i32 : tensor<128xi32> to tensor<128xbf16> + %5 = tt.splat %arg2 : (i32) -> tensor<128xi32> + %mask = arith.cmpi slt, %2, %5 : tensor<128xi32> + %buff = tt.load %ldptr, %mask, %splat_c7_bf16 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xbf16> + tt.store %stptr, %buff, %mask : tensor<128xbf16> + tt.return + } +} + +// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32) { +// CHECK-DAG: [[CST_7_dot_000000_:%.+]] = arith.constant 7.000000e+00 : bf16 +// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index +// CHECK-DAG: [[VAR_0_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [128], strides: [1], offsets: [0], parent_sizes: [0] : to tensor<128x!tt.ptr> +// CHECK-DAG: [[VAR_1_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [128], strides: [1], offsets: [0], parent_sizes: [0] : to tensor<128x!tt.ptr> +// CHECK-DAG: [[VAR_2_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index 
+// CHECK: [[VAR_3_:%.+]] = arith.minsi [[VAR_2_]], [[CST_128_]] : index +// CHECK-DAG: [[VAR_4_:%.+]] = "tts.load"([[VAR_0_]], [[VAR_3_]], [[CST_7_dot_000000_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<128x!tt.ptr>, index, bf16) -> tensor<128xbf16> +// CHECK-DAG: [[VAR_5_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index +// CHECK: [[VAR_6_:%.+]] = arith.minsi [[VAR_5_]], [[CST_128_]] : index +// CHECK: "tts.store"([[VAR_1_]], [[VAR_4_]], [[VAR_6_]]) <{static_dims = array}> : (tensor<128x!tt.ptr>, tensor<128xbf16>, index) -> () +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/use_dot_opc.mlir b/test/Conversion/TritonToStructured/use_dot_opc.mlir new file mode 100644 index 00000000..c543957f --- /dev/null +++ b/test/Conversion/TritonToStructured/use_dot_opc.mlir @@ -0,0 +1,68 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr, + %arg2 : !tt.ptr + ) + { + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %c64 = arith.constant 128 : i32 + %1 = tt.splat %c64 : (i32) -> tensor<128xi32> + %2 = arith.muli %0, %1 : tensor<128xi32> + %3 = tt.expand_dims %2 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32> + %4 = tt.broadcast %3 : (tensor<128x1xi32>) -> tensor<128x64xi32> + %5 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %6 = tt.expand_dims %5 {axis = 0 : i32} : (tensor<64xi32>) -> tensor<1x64xi32> + %7 = tt.broadcast %6 : (tensor<1x64xi32>) -> tensor<128x64xi32> + %8 = arith.addi %4, %7 : tensor<128x64xi32> + %10 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %11 = tt.expand_dims %10 {axis = 0 : i32} : (tensor<256xi32>) -> tensor<1x256xi32> + %12 = tt.broadcast %11 : (tensor<1x256xi32>) -> tensor<64x256xi32> + %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %c256 = arith.constant 256 : i32 + %14 = tt.splat %c256 : (i32) -> tensor<64xi32> + %15 = arith.muli %13, %14 : tensor<64xi32> + %16 = tt.expand_dims %15 {axis = 1 : i32} : (tensor<64xi32>) -> tensor<64x1xi32> + %17 = tt.broadcast %16 : (tensor<64x1xi32>) -> tensor<64x256xi32> + %18 = arith.addi %12, %17 : tensor<64x256xi32> + %20 = tt.splat %c256 : (i32) -> tensor<128xi32> + %21 = arith.muli %0, %20 : tensor<128xi32> + %22 = tt.expand_dims %21 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32> + %23 = tt.broadcast %22 : (tensor<128x1xi32>) -> tensor<128x256xi32> + %24 = tt.expand_dims %10 {axis = 0 : i32} : (tensor<256xi32>) -> tensor<1x256xi32> + %25 = tt.broadcast %24 {axis = 0 : i32} : (tensor<1x256xi32>) -> tensor<128x256xi32> + %26 = arith.addi %23, %25 : tensor<128x256xi32> + %30 = tt.splat %arg0 : (!tt.ptr) -> tensor<128x64x!tt.ptr> + %31 = tt.addptr %30, %8 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> + %32 = tt.load %31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<128x64xbf16> + %40 = tt.splat %arg1 : (!tt.ptr) -> tensor<64x256x!tt.ptr> + %41 = tt.addptr %40, %18 : tensor<64x256x!tt.ptr>, tensor<64x256xi32> + %42 = tt.load %41 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<64x256xbf16> + %50 = tt.splat %arg2 : (!tt.ptr) -> tensor<128x256x!tt.ptr> + %51 = tt.addptr %50, %26 : tensor<128x256x!tt.ptr>, tensor<128x256xi32> + %cf0 = arith.constant 0.0 : bf16 + %71 = tt.splat %cf0 : (bf16) -> (tensor<128x256xbf16>) + %60 = tt.dot %32, %42, %71 {allowTF32 = false, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xbf16> * 
tensor<64x256xbf16> -> tensor<128x256xbf16> + tt.store %51, %60 : tensor<128x256xbf16> + tt.store %51, %71 : tensor<128x256xbf16> + tt.return + } +} + +// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: !tt.ptr) { +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<0.000000e+00> : tensor<128x256xbf16> +// CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index +// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index +// CHECK: [[VAR_0_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [128, 64], strides: {{.}}[[CST_128_]], 1], offsets: [0, 0], parent_sizes: [0, 0] : to tensor<128x64x!tt.ptr> +// CHECK-DAG: [[VAR_1_:%.+]] = "tts.load"([[VAR_0_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<128x64x!tt.ptr>) -> tensor<128x64xbf16> +// CHECK-DAG: [[VAR_2_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [64, 256], strides: {{.}}[[CST_256_]], 1], offsets: [0, 0], parent_sizes: [0, 0] : to tensor<64x256x!tt.ptr> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_3_:%.+]] = "tts.load"([[VAR_2_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<64x256x!tt.ptr>) -> tensor<64x256xbf16> +// CHECK-DAG: [[VAR_4_:%.+]] = tts.make_tptr [[PARAM_2_]] to sizes: [128, 256], strides: {{.}}[[CST_256_]], 1], offsets: [0, 0], parent_sizes: [0, 0] : to tensor<128x256x!tt.ptr> +// CHECK: [[VAR_5_:%.+]] = tt.dot [[VAR_1_]], [[VAR_3_]], [[VAR_cst_]] {allowTF32 = false, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xbf16> * tensor<64x256xbf16> -> tensor<128x256xbf16> +// CHECK: "tts.store"([[VAR_4_]], [[VAR_5_]]) <{static_dims = array}> : (tensor<128x256x!tt.ptr>, tensor<128x256xbf16>) -> () +// CHECK: "tts.store"([[VAR_4_]], [[VAR_cst_]]) <{static_dims = array}> : (tensor<128x256x!tt.ptr>, tensor<128x256xbf16>) -> () +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/use_end_chain.mlir b/test/Conversion/TritonToStructured/use_end_chain.mlir new file mode 100644 index 00000000..20d74737 --- /dev/null +++ b/test/Conversion/TritonToStructured/use_end_chain.mlir @@ -0,0 +1,56 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr + ) + { + %0 = tt.make_range {end = 768 : i32, start = 512 : i32}:tensor<256xi32> + // offset = [512] size = 256, stride = 1 + %1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<256xi32>) -> tensor<256x1xi32> + // offset = [512,0], size = [256,1], stride = [1,0] + %2 = tt.broadcast %1 : (tensor<256x1xi32>) -> tensor<256x128xi32> + // offset = [512,0], size = [256,128], stride = [1,0] + %5 = tt.make_range {end = 1152 : i32, start = 1024 : i32}:tensor<128xi32> + // offset = 1024, size = 128, stride = 1 + %6 = tt.expand_dims %5 {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32> + // offset = [0,1024], size = [1,128], stride = [0,1] + %7 = tt.broadcast %6 : (tensor<1x128xi32>) -> tensor<256x128xi32> + // offset = [0,1024], size = [256,128], stride = [0,1] + %c6 = arith.constant 6 : i32 + %splat6 = tt.splat %c6 : (i32) -> tensor<256x128xi32> + %scale7 = arith.muli %7, %splat6 : tensor<256x128xi32> + // offset = [0,6144], size = [256,128], stride = [0,6] + %14 = arith.addi %2, %scale7 : tensor<256x128xi32> + // offset = [512,6144], size = [256,128], stride = [1,6] + // mixed use + %17 = tt.splat %arg1 : (!tt.ptr) -> tensor<256x128x!tt.ptr> + %18 = tt.addptr %17, %14 : tensor<256x128x!tt.ptr>, tensor<256x128xi32> + %19 = tt.load %18 {cache = 1 : i32, 
evict = 1 : i32, isVolatile = false} : tensor<256x128xbf16> + tt.store %18, %19 : tensor<256x128xbf16> + %20 = arith.sitofp %14 : tensor<256x128xi32> to tensor<256x128xbf16> + tt.store %18, %20 : tensor<256x128xbf16> + tt.return + } +} + +// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr) { +// CHECK-DAG: [[CST_6144_:%.+]] = arith.constant 6144 : index +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<6> : tensor<256x128xi32> +// CHECK-DAG: [[CST_6_:%.+]] = arith.constant 6 : index +// CHECK-DAG: [[VAR_0_:%.+]] = tt.make_range {end = 768 : i32, start = 512 : i32} : tensor<256xi32> +// CHECK: [[VAR_1_:%.+]] = tt.expand_dims [[VAR_0_]] {axis = 1 : i32} : (tensor<256xi32>) -> tensor<256x1xi32> +// CHECK-DAG: [[VAR_2_:%.+]] = tt.broadcast [[VAR_1_]] : (tensor<256x1xi32>) -> tensor<256x128xi32> +// CHECK-DAG: [[VAR_3_:%.+]] = tt.make_range {end = 1152 : i32, start = 1024 : i32} : tensor<128xi32> +// CHECK: [[VAR_4_:%.+]] = tt.expand_dims [[VAR_3_]] {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32> +// CHECK: [[VAR_5_:%.+]] = tt.broadcast [[VAR_4_]] : (tensor<1x128xi32>) -> tensor<256x128xi32> +// CHECK: [[VAR_6_:%.+]] = arith.muli [[VAR_5_]], [[VAR_cst_]] : tensor<256x128xi32> +// CHECK-DAG: [[VAR_7_:%.+]] = arith.addi [[VAR_2_]], [[VAR_6_]] : tensor<256x128xi32> +// CHECK-DAG: [[VAR_8_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [256, 128], strides: [1, [[CST_6_]]{{.}}, offsets: [512, [[CST_6144_]]{{.}}, parent_sizes: [0, 0] : to tensor<256x128x!tt.ptr> +// CHECK: [[VAR_9_:%.+]] = "tts.load"([[VAR_8_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<256x128x!tt.ptr>) -> tensor<256x128xbf16> +// CHECK: "tts.store"([[VAR_8_]], [[VAR_9_]]) <{static_dims = array}> : (tensor<256x128x!tt.ptr>, tensor<256x128xbf16>) -> () +// CHECK: [[VAR_10_:%.+]] = arith.sitofp [[VAR_7_]] : tensor<256x128xi32> to tensor<256x128xbf16> +// CHECK: "tts.store"([[VAR_8_]], [[VAR_10_]]) <{static_dims = array}> : (tensor<256x128x!tt.ptr>, tensor<256x128xbf16>) -> () +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/use_mid_chain.mlir b/test/Conversion/TritonToStructured/use_mid_chain.mlir new file mode 100644 index 00000000..d90fc098 --- /dev/null +++ b/test/Conversion/TritonToStructured/use_mid_chain.mlir @@ -0,0 +1,52 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr, + %arg2 : !tt.ptr + ) + { + %0 = tt.make_range {end = 768 : i32, start = 512 : i32}:tensor<256xi32> + // offset = [512] size = 256, stride = 1 + %1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<256xi32>) -> tensor<256x1xi32> + // offset = [512,0], size = [256,1], stride = [1,0] + %2 = tt.broadcast %1 : (tensor<256x1xi32>) -> tensor<256x128xi32> + // offset = [512,0], size = [256,128], stride = [1,0] + // mixed use + %5 = tt.make_range {end = 1152 : i32, start = 1024 : i32}:tensor<128xi32> + // offset = 1024, size = 128, stride = 1 + %6 = tt.expand_dims %5 {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32> + // offset = [0,1024], size = [1,128], stride = [0,1] + %7 = tt.broadcast %6 : (tensor<1x128xi32>) -> tensor<256x128xi32> + // offset = [0,1024], size = [256,128], stride = [0,1] + %c6 = arith.constant 6 : i32 + %splat6 = tt.splat %c6 : (i32) -> tensor<256x128xi32> + %scale7 = arith.muli %7, %splat6 : tensor<256x128xi32> + // offset = [0,6144], size = [256,128], stride = [0,6] + %14 = arith.addi %2, %scale7 : tensor<256x128xi32> + // 
offset = [512,6144], size = [256,128], stride = [1,6] + %17 = tt.splat %arg1 : (!tt.ptr) -> tensor<256x128x!tt.ptr> + %18 = tt.addptr %17, %14 : tensor<256x128x!tt.ptr>, tensor<256x128xi32> + %19 = tt.load %18 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x128xbf16> + tt.store %18, %19 : tensor<256x128xbf16> + %20 = tt.splat %arg2 : (!tt.ptr) -> tensor<256x128x!tt.ptr> + %21 = tt.addptr %20, %14 : tensor<256x128x!tt.ptr>, tensor<256x128xi32> + tt.store %21, %2 : tensor<256x128xi32> + tt.return + } +} + +// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: !tt.ptr) { +// CHECK-DAG: [[CST_6144_:%.+]] = arith.constant 6144 : index +// CHECK-DAG: [[CST_6_:%.+]] = arith.constant 6 : index +// CHECK-DAG: [[VAR_0_:%.+]] = tt.make_range {end = 768 : i32, start = 512 : i32} : tensor<256xi32> +// CHECK: [[VAR_1_:%.+]] = tt.expand_dims [[VAR_0_]] {axis = 1 : i32} : (tensor<256xi32>) -> tensor<256x1xi32> +// CHECK-DAG: [[VAR_2_:%.+]] = tt.broadcast [[VAR_1_]] : (tensor<256x1xi32>) -> tensor<256x128xi32> +// CHECK-DAG: [[VAR_3_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [256, 128], strides: [1, [[CST_6_]]{{.}}, offsets: [512, [[CST_6_]]144], parent_sizes: [0, 0] : to tensor<256x128x!tt.ptr> +// CHECK: [[VAR_4_:%.+]] = "tts.load"([[VAR_3_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<256x128x!tt.ptr>) -> tensor<256x128xbf16> +// CHECK: "tts.store"([[VAR_3_]], [[VAR_4_]]) <{static_dims = array}> : (tensor<256x128x!tt.ptr>, tensor<256x128xbf16>) -> () +// CHECK: [[VAR_5_:%.+]] = tts.make_tptr [[PARAM_2_]] to sizes: [256, 128], strides: [1, [[CST_6_]]{{.}}, offsets: [512, [[CST_6_]]144], parent_sizes: [0, 0] : to tensor<256x128x!tt.ptr> +// CHECK: "tts.store"([[VAR_5_]], [[VAR_2_]]) <{static_dims = array}> : (tensor<256x128x!tt.ptr>, tensor<256x128xi32>) -> () +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/wraparound_side_by_side.mlir b/test/Conversion/TritonToStructured/wraparound_side_by_side.mlir new file mode 100644 index 00000000..b48f69ca --- /dev/null +++ b/test/Conversion/TritonToStructured/wraparound_side_by_side.mlir @@ -0,0 +1,92 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func public @wrap_side_by_side_masked_loop_01234567(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) { + %cst = arith.constant dense<-9.900000e+01> : tensor<4x4xf32> + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c2_i32 = arith.constant 2 : i32 + %cst_0 = arith.constant dense<2> : tensor<4x1xi32> + %cst_1 = arith.constant dense<6> : tensor<4xi32> + %cst_2 = arith.constant dense<2> : tensor<4xi32> + %c4_i32 = arith.constant 4 : i32 + %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %1 = arith.addi %0, %cst_2 : tensor<4xi32> + %2 = arith.addi %0, %cst_1 : tensor<4xi32> + %3 = tt.splat %arg3 : (i32) -> tensor<4xi32> + %4 = arith.remsi %2, %3 : tensor<4xi32> + %5 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<4xi32>) -> tensor<4x1xi32> + %6 = tt.splat %arg4 : (i32) -> tensor<4x1xi32> + %7 = arith.muli %5, %6 : tensor<4x1xi32> + %8 = tt.expand_dims %4 {axis = 0 : i32} : (tensor<4xi32>) -> tensor<1x4xi32> + %9 = tt.splat %arg5 : (i32) -> tensor<1x4xi32> + %10 = arith.muli %8, %9 : tensor<1x4xi32> + %11 = tt.broadcast %7 : (tensor<4x1xi32>) -> tensor<4x4xi32> + %12 = tt.broadcast %10 : (tensor<1x4xi32>) -> tensor<4x4xi32> + %13 = 
arith.addi %11, %12 : tensor<4x4xi32> + %14 = tt.splat %arg0 : (!tt.ptr) -> tensor<4x4x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> + %16 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<4xi32>) -> tensor<4x1xi32> + %17 = tt.splat %arg6 : (i32) -> tensor<4x1xi32> + %18 = arith.muli %17, %16 : tensor<4x1xi32> + %19 = tt.splat %arg1 : (!tt.ptr) -> tensor<4x1x!tt.ptr> + %20 = tt.addptr %19, %18 : tensor<4x1x!tt.ptr>, tensor<4x1xi32> + %21 = tt.expand_dims %0 {axis = 0 : i32} : (tensor<4xi32>) -> tensor<1x4xi32> + %22 = tt.splat %arg7 : (i32) -> tensor<1x4xi32> + %23 = arith.muli %22, %21 : tensor<1x4xi32> + %24 = tt.broadcast %20 : (tensor<4x1x!tt.ptr>) -> tensor<4x4x!tt.ptr> + %25 = tt.broadcast %23 : (tensor<1x4xi32>) -> tensor<4x4xi32> + %26 = tt.addptr %24, %25 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> + %27 = arith.cmpi slt, %16, %cst_0 : tensor<4x1xi32> + %28 = tt.broadcast %27 : (tensor<4x1xi1>) -> tensor<4x4xi1> + %29 = arith.muli %arg4, %c4_i32 : i32 + %30 = tt.splat %29 : (i32) -> tensor<4x4xi32> + %31 = arith.muli %arg5, %c4_i32 : i32 + %32 = tt.splat %31 : (i32) -> tensor<4x4xi32> + %33:2 = scf.for %arg8 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg9 = %15, %arg10 = %26) -> (tensor<4x4x!tt.ptr>, tensor<4x4x!tt.ptr>) : i32 { + %34 = tt.load %arg9, %28, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<4x4xf32> + tt.store %arg10, %34 {cache = 1 : i32, evict = 1 : i32} : tensor<4x4xf32> + %35 = tt.addptr %arg9, %30 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> + %36 = tt.addptr %arg10, %32 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> + scf.yield %35, %36 : tensor<4x4x!tt.ptr>, tensor<4x4x!tt.ptr> + } + tt.return + } +} + +// CHECK: tt.func public @wrap_side_by_side_masked_loop_01234567([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) { +// CHECK-DAG: [[CST_minus_9_dot_900000_:%.+]] = arith.constant -9.900000e+01 : f32 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_6_:%.+]] = arith.constant 6 : index +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32 +// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32 +// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : i32 +// CHECK-DAG: [[CST_2_1_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_1_:%.+]] = arith.muli [[VAR_0_]], [[CST_2_1_]] : index +// CHECK-DAG: [[VAR_2_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index +// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_4_:%.+]] = arith.muli [[VAR_3_]], [[CST_6_]] : index +// CHECK-DAG: [[VAR_5_:%.+]] = arith.muli [[VAR_2_]], [[VAR_3_]] : index +// CHECK-DAG: [[VAR_6_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index +// CHECK-DAG: [[VAR_7_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index +// CHECK-DAG: [[VAR_8_:%.+]] = arith.muli [[PARAM_4_]], [[CST_4_]] : i32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_9_:%.+]] = arith.index_cast [[VAR_8_]] : i32 to index +// CHECK-DAG: [[VAR_10_:%.+]] = arith.muli [[PARAM_5_]], [[CST_4_]] : i32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_11_:%.+]] = arith.index_cast [[VAR_10_]] : i32 to index +// CHECK-DAG: [[VAR_12_:%.+]]:2 = 
scf.for [[VAR_arg8_:%.+]] = [[CST_0_1_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg9_:%.+]] = [[VAR_1_]], [[VAR_arg10_:%.+]] = [[CST_0_]]) -> (index, index) : i32 { +// CHECK-DAG: [[VAR_13_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [4, 4], strides: {{.}}[[VAR_6_]], [[VAR_7_]]{{.}}, offsets: {{.}}[[PARAM_1_]]0, [[CST_0_]]{{.}}, parent_sizes: [0, 0] : to tensor<4x4x!tt.ptr> +// CHECK-DAG: [[VAR_14_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [4, 4], strides: {{.}}[[VAR_0_]], [[VAR_3_]]{{.}}, offsets: {{.}}[[VAR_arg9_]], [[VAR_4_]]{{.}}, parent_sizes: [0, [[VAR_5_]]{{.}} : to tensor<4x4x!tt.ptr> +// CHECK: [[VAR_15_:%.+]] = "tts.load"([[VAR_14_]], [[CST_minus_9_dot_900000_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<4x4x!tt.ptr>, f32) -> tensor<4x4xf32> +// CHECK: "tts.store"([[VAR_13_]], [[VAR_15_]]) <{static_dims = array}> : (tensor<4x4x!tt.ptr>, tensor<4x4xf32>) -> () +// CHECK-DAG: [[VAR_16_:%.+]] = arith.addi [[VAR_arg9_]], [[VAR_9_]] : index +// CHECK-DAG: [[VAR_17_:%.+]] = arith.addi [[VAR_arg10_]], [[VAR_11_]] : index +// CHECK: scf.yield [[VAR_16_]], [[VAR_17_]] : index, index +// CHECK: } +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/wraparound_stacked.mlir b/test/Conversion/TritonToStructured/wraparound_stacked.mlir new file mode 100644 index 00000000..6f3bde51 --- /dev/null +++ b/test/Conversion/TritonToStructured/wraparound_stacked.mlir @@ -0,0 +1,88 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +module { + tt.func public @wrap_stacked_masked_loop_01234567(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) { + %cst = arith.constant dense<-9.900000e+01> : tensor<4x4xf32> + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c2_i32 = arith.constant 2 : i32 + %cst_0 = arith.constant dense<3> : tensor<1x4xi32> + %cst_1 = arith.constant dense<3> : tensor<4xi32> + %cst_2 = arith.constant dense<2> : tensor<4xi32> + %c4_i32 = arith.constant 4 : i32 + %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %1 = arith.addi %0, %cst_2 : tensor<4xi32> + %2 = tt.splat %arg2 : (i32) -> tensor<4xi32> + %3 = arith.remsi %1, %2 : tensor<4xi32> + %4 = arith.addi %0, %cst_1 : tensor<4xi32> + %5 = tt.expand_dims %3 {axis = 1 : i32} : (tensor<4xi32>) -> tensor<4x1xi32> + %6 = tt.splat %arg4 : (i32) -> tensor<4x1xi32> + %7 = arith.muli %5, %6 : tensor<4x1xi32> + %8 = tt.expand_dims %4 {axis = 0 : i32} : (tensor<4xi32>) -> tensor<1x4xi32> + %9 = tt.splat %arg5 : (i32) -> tensor<1x4xi32> + %10 = arith.muli %8, %9 : tensor<1x4xi32> + %11 = tt.broadcast %7 : (tensor<4x1xi32>) -> tensor<4x4xi32> + %12 = tt.broadcast %10 : (tensor<1x4xi32>) -> tensor<4x4xi32> + %13 = arith.addi %11, %12 : tensor<4x4xi32> + %14 = tt.splat %arg0 : (!tt.ptr) -> tensor<4x4x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> + %16 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<4xi32>) -> tensor<4x1xi32> + %17 = tt.splat %arg6 : (i32) -> tensor<4x1xi32> + %18 = arith.muli %17, %16 : tensor<4x1xi32> + %19 = tt.splat %arg1 : (!tt.ptr) -> tensor<4x1x!tt.ptr> + %20 = tt.addptr %19, %18 : tensor<4x1x!tt.ptr>, tensor<4x1xi32> + %21 = tt.expand_dims %0 {axis = 0 : i32} : (tensor<4xi32>) -> tensor<1x4xi32> + %22 = tt.splat %arg7 : (i32) -> tensor<1x4xi32> + %23 = arith.muli %22, %21 : tensor<1x4xi32> + %24 = tt.broadcast %20 : (tensor<4x1x!tt.ptr>) -> tensor<4x4x!tt.ptr> + %25 = tt.broadcast %23 : 
(tensor<1x4xi32>) -> tensor<4x4xi32> + %26 = tt.addptr %24, %25 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> + %27 = arith.cmpi slt, %21, %cst_0 : tensor<1x4xi32> + %28 = tt.broadcast %27 : (tensor<1x4xi1>) -> tensor<4x4xi1> + %29 = arith.muli %arg5, %c4_i32 : i32 + %30 = tt.splat %29 : (i32) -> tensor<4x4xi32> + %31:2 = scf.for %arg8 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg9 = %15, %arg10 = %26) -> (tensor<4x4x!tt.ptr>, tensor<4x4x!tt.ptr>) : i32 { + %32 = tt.load %arg9, %28, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<4x4xf32> + tt.store %arg10, %32 {cache = 1 : i32, evict = 1 : i32} : tensor<4x4xf32> + %33 = tt.addptr %arg9, %30 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> + %34 = tt.addptr %arg10, %30 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> + scf.yield %33, %34 : tensor<4x4x!tt.ptr>, tensor<4x4x!tt.ptr> + } + tt.return + } +} + +// CHECK: tt.func public @wrap_stacked_masked_loop_01234567([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) { +// CHECK-DAG: [[CST_minus_9_dot_900000_:%.+]] = arith.constant -9.900000e+01 : f32 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32 +// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[CST_2_1_:%.+]] = arith.constant 2 : i32 +// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : i32 +// CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index +// CHECK-DAG: [[VAR_1_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_2_:%.+]] = arith.muli [[VAR_1_]], [[CST_2_]] : index +// CHECK-DAG: [[VAR_3_:%.+]] = arith.muli [[VAR_0_]], [[VAR_1_]] : index +// CHECK-DAG: [[VAR_4_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_5_:%.+]] = arith.muli [[VAR_4_]], [[CST_3_]] : index +// CHECK-DAG: [[VAR_6_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index +// CHECK-DAG: [[VAR_7_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index +// CHECK-DAG: [[VAR_8_:%.+]] = arith.muli [[PARAM_5_]], [[CST_4_]] : i32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_9_:%.+]] = arith.index_cast [[VAR_8_]] : i32 to index +// CHECK-DAG: [[VAR_10_:%.+]] = arith.index_cast [[VAR_8_]] : i32 to index +// CHECK-DAG: [[VAR_11_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = [[CST_0_1_]] to [[CST_2_1_]] step [[CST_1_]] iter_args([[VAR_arg9_:%.+]] = [[VAR_2_]], [[VAR_arg10_:%.+]] = [[CST_0_]]) -> (index, index) : i32 { +// CHECK-DAG: [[VAR_12_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [4, 4], strides: {{.}}[[VAR_6_]], [[VAR_7_]]{{.}}, offsets: {{.}}[[PARAM_1_]]0, [[CST_0_]]{{.}}, parent_sizes: [0, 0] : to tensor<4x4x!tt.ptr> +// CHECK-DAG: [[VAR_13_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [4, 4], strides: {{.}}[[VAR_1_]], [[VAR_4_]]{{.}}, offsets: {{.}}[[VAR_arg9_]], [[VAR_5_]]{{.}}, parent_sizes: {{.}}[[VAR_3_]], 0] : to tensor<4x4x!tt.ptr> +// CHECK: [[VAR_14_:%.+]] = "tts.load"([[VAR_13_]], [[CST_minus_9_dot_900000_]]) <{operandSegmentSizes = array, static_dims = array}> : (tensor<4x4x!tt.ptr>, f32) -> tensor<4x4xf32> +// CHECK: "tts.store"([[VAR_12_]], [[VAR_14_]]) <{static_dims = array}> : (tensor<4x4x!tt.ptr>, tensor<4x4xf32>) -> () +// CHECK-DAG: [[VAR_15_:%.+]] = 
arith.addi [[VAR_arg9_]], [[VAR_10_]] : index +// CHECK-DAG: [[VAR_16_:%.+]] = arith.addi [[VAR_arg10_]], [[VAR_9_]] : index +// CHECK: scf.yield [[VAR_15_]], [[VAR_16_]] : index, index +// CHECK: } +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/wraparound_unsupported_add_offset.mlir b/test/Conversion/TritonToStructured/wraparound_unsupported_add_offset.mlir new file mode 100644 index 00000000..79be9196 --- /dev/null +++ b/test/Conversion/TritonToStructured/wraparound_unsupported_add_offset.mlir @@ -0,0 +1,110 @@ +// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize %s | FileCheck %s + +// We currently do not support this kind of modulo pattern: +// (a + arrange(0, K)) % M +// Check verifies that we fail gracefully and keep the original code +module { + tt.func public @wrap_side_by_side_masked_loop_01234567(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) { + %cst = arith.constant dense<-9.900000e+01> : tensor<4x4xf32> + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c2_i32 = arith.constant 2 : i32 + %cst_0 = arith.constant dense<2> : tensor<4x1xi32> + %cst_1 = arith.constant dense<6> : tensor<4xi32> + %cst_2 = arith.constant dense<2> : tensor<4xi32> + %c4_i32 = arith.constant 4 : i32 + %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %1 = arith.addi %0, %cst_2 : tensor<4xi32> + %2 = tt.splat %arg3 : (i32) -> tensor<4xi32> + %3 = arith.remsi %0, %2 : tensor<4xi32> + %4 = arith.addi %3, %cst_1 : tensor<4xi32> + %5 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<4xi32>) -> tensor<4x1xi32> + %6 = tt.splat %arg4 : (i32) -> tensor<4x1xi32> + %7 = arith.muli %5, %6 : tensor<4x1xi32> + %8 = tt.expand_dims %4 {axis = 0 : i32} : (tensor<4xi32>) -> tensor<1x4xi32> + %9 = tt.splat %arg5 : (i32) -> tensor<1x4xi32> + %10 = arith.muli %8, %9 : tensor<1x4xi32> + %11 = tt.broadcast %7 : (tensor<4x1xi32>) -> tensor<4x4xi32> + %12 = tt.broadcast %10 : (tensor<1x4xi32>) -> tensor<4x4xi32> + %13 = arith.addi %11, %12 : tensor<4x4xi32> + %14 = tt.splat %arg0 : (!tt.ptr) -> tensor<4x4x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> + %16 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<4xi32>) -> tensor<4x1xi32> + %17 = tt.splat %arg6 : (i32) -> tensor<4x1xi32> + %18 = arith.muli %17, %16 : tensor<4x1xi32> + %19 = tt.splat %arg1 : (!tt.ptr) -> tensor<4x1x!tt.ptr> + %20 = tt.addptr %19, %18 : tensor<4x1x!tt.ptr>, tensor<4x1xi32> + %21 = tt.expand_dims %0 {axis = 0 : i32} : (tensor<4xi32>) -> tensor<1x4xi32> + %22 = tt.splat %arg7 : (i32) -> tensor<1x4xi32> + %23 = arith.muli %22, %21 : tensor<1x4xi32> + %24 = tt.broadcast %20 : (tensor<4x1x!tt.ptr>) -> tensor<4x4x!tt.ptr> + %25 = tt.broadcast %23 : (tensor<1x4xi32>) -> tensor<4x4xi32> + %26 = tt.addptr %24, %25 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> + %27 = arith.cmpi slt, %16, %cst_0 : tensor<4x1xi32> + %28 = tt.broadcast %27 : (tensor<4x1xi1>) -> tensor<4x4xi1> + %29 = arith.muli %arg4, %c4_i32 : i32 + %30 = tt.splat %29 : (i32) -> tensor<4x4xi32> + %31 = arith.muli %arg5, %c4_i32 : i32 + %32 = tt.splat %31 : (i32) -> tensor<4x4xi32> + %33:2 = scf.for %arg8 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg9 = %15, %arg10 = %26) -> (tensor<4x4x!tt.ptr>, tensor<4x4x!tt.ptr>) : i32 { + %34 = tt.load %arg9, %28, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<4x4xf32> + tt.store %arg10, %34 {cache = 1 : i32, evict = 1 : i32} : tensor<4x4xf32> + %35 = 
tt.addptr %arg9, %30 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> + %36 = tt.addptr %arg10, %32 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> + scf.yield %35, %36 : tensor<4x4x!tt.ptr>, tensor<4x4x!tt.ptr> + } + tt.return + } +} + +// CHECK: tt.func public @wrap_side_by_side_masked_loop_01234567([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) { +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<-9.900000e+01> : tensor<4x4xf32> +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32 +// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32 +// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<2> : tensor<4x1xi32> +// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<6> : tensor<4xi32> +// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<2> : tensor<4xi32> +// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : i32 +// CHECK-DAG: [[VAR_0_:%.+]] = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_1_:%.+]] = arith.addi [[VAR_0_]], [[VAR_cst_2_]] : tensor<4xi32> +// CHECK-DAG: [[VAR_2_:%.+]] = tt.splat [[PARAM_3_]] : (i32) -> tensor<4xi32> +// CHECK: [[VAR_3_:%.+]] = arith.remsi [[VAR_0_]], [[VAR_2_]] : tensor<4xi32> +// CHECK-DAG: [[VAR_4_:%.+]] = arith.addi [[VAR_3_]], [[VAR_cst_1_]] : tensor<4xi32> +// CHECK-DAG: [[VAR_5_:%.+]] = tt.expand_dims [[VAR_1_]] {axis = 1 : i32} : (tensor<4xi32>) -> tensor<4x1xi32> +// CHECK-DAG: [[VAR_6_:%.+]] = tt.splat [[PARAM_4_]] : (i32) -> tensor<4x1xi32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_7_:%.+]] = arith.muli [[VAR_5_]], [[VAR_6_]] : tensor<4x1xi32> +// CHECK-DAG: [[VAR_8_:%.+]] = tt.expand_dims [[VAR_4_]] {axis = 0 : i32} : (tensor<4xi32>) -> tensor<1x4xi32> +// CHECK-DAG: [[VAR_9_:%.+]] = tt.splat [[PARAM_5_]] : (i32) -> tensor<1x4xi32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_10_:%.+]] = arith.muli [[VAR_8_]], [[VAR_9_]] : tensor<1x4xi32> +// CHECK-DAG: [[VAR_11_:%.+]] = tt.broadcast [[VAR_7_]] : (tensor<4x1xi32>) -> tensor<4x4xi32> +// CHECK: [[VAR_12_:%.+]] = tt.broadcast [[VAR_10_]] : (tensor<1x4xi32>) -> tensor<4x4xi32> +// CHECK-DAG: [[VAR_13_:%.+]] = arith.addi [[VAR_11_]], [[VAR_12_]] : tensor<4x4xi32> +// CHECK-DAG: [[VAR_14_:%.+]] = tt.splat [[PARAM_0_]] : (!tt.ptr) -> tensor<4x4x!tt.ptr> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_15_:%.+]] = tt.addptr [[VAR_14_]], [[VAR_13_]] : tensor<4x4x!tt.ptr>, tensor<4x4xi32> +// CHECK-DAG: [[VAR_16_:%.+]] = tt.expand_dims [[VAR_0_]] {axis = 1 : i32} : (tensor<4xi32>) -> tensor<4x1xi32> +// CHECK-DAG: [[VAR_17_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index +// CHECK-DAG: [[VAR_18_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index +// CHECK: [[VAR_19_:%.+]] = arith.cmpi slt, [[VAR_16_]], [[VAR_cst_0_]] : tensor<4x1xi32> +// CHECK-DAG: [[VAR_20_:%.+]] = tt.broadcast [[VAR_19_]] : (tensor<4x1xi1>) -> tensor<4x4xi1> +// CHECK-DAG: [[VAR_21_:%.+]] = arith.muli [[PARAM_4_]], [[CST_4_]] : i32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_22_:%.+]] = tt.splat [[VAR_21_]] : (i32) -> tensor<4x4xi32> +// CHECK-DAG: [[VAR_23_:%.+]] = arith.muli [[PARAM_5_]], [[CST_4_]] : i32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_24_:%.+]] = arith.index_cast [[VAR_23_]] : i32 to 
index +// CHECK-DAG: [[VAR_25_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = [[CST_0_1_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg9_:%.+]] = [[VAR_15_]], [[VAR_arg10_:%.+]] = [[CST_0_]]) -> (tensor<4x4x!tt.ptr>, index) : i32 { +// CHECK-DAG: [[VAR_26_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [4, 4], strides: {{.}}[[VAR_17_]], [[VAR_18_]]{{.}}, offsets: {{.}}[[PARAM_1_]]0, [[CST_0_]]{{.}}, parent_sizes: [0, 0] : to tensor<4x4x!tt.ptr> +// CHECK-DAG: [[LOAD_VAR_arg9_MEM_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_20_]], [[VAR_cst_]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<4x4xf32> +// CHECK: "tts.store"([[VAR_26_]], [[LOAD_VAR_arg9_MEM_]]) <{static_dims = array}> : (tensor<4x4x!tt.ptr>, tensor<4x4xf32>) -> () +// CHECK-DAG: [[VAR_28_:%.+]] = tt.addptr [[VAR_arg9_]], [[VAR_22_]] : tensor<4x4x!tt.ptr>, tensor<4x4xi32> +// CHECK-DAG: [[VAR_29_:%.+]] = arith.addi [[VAR_arg10_]], [[VAR_24_]] : index +// CHECK: scf.yield [[VAR_28_]], [[VAR_29_]] : tensor<4x4x!tt.ptr>, index +// CHECK: } +// CHECK: tt.return +// CHECK: } From 60c6dab02b5fa9a409459091b41ecc25ae435520 Mon Sep 17 00:00:00 2001 From: Haishan Zhu Date: Wed, 17 Jan 2024 14:49:58 -0800 Subject: [PATCH 6/8] Address review comments --- .../AnalysisStructured/MaskAnalysis.h | 16 +- .../AnalysisStructured/PtrAnalysis.h | 55 ++-- .../IR/TritonStructuredDialect.td | 9 +- lib/AnalysisStructured/MaskAnalysis.cpp | 44 +-- lib/AnalysisStructured/PtrAnalysis.cpp | 267 ++++++++++-------- .../TritonToStructuredPass.cpp | 2 +- .../IR/TritonStructuredOps.cpp | 7 +- 7 files changed, 212 insertions(+), 188 deletions(-) diff --git a/include/triton-shared/AnalysisStructured/MaskAnalysis.h b/include/triton-shared/AnalysisStructured/MaskAnalysis.h index daf34047..3f03c686 100644 --- a/include/triton-shared/AnalysisStructured/MaskAnalysis.h +++ b/include/triton-shared/AnalysisStructured/MaskAnalysis.h @@ -20,7 +20,7 @@ namespace mlir { class OpBuilder; -namespace triton { +namespace tts { // Data structure used to decode the pattern in a mask used for load and store. // start and end field represent the start and end index of a range (produced // by make_range, addi, etc.). While multi-dimensional data is possible, we @@ -35,14 +35,14 @@ namespace triton { // result of splat, expand_dims, etc. During this phase, either (1) both start // and end are populated, or (2) scalar is populated. Only one of the dimensions // (that contains the range) can have dim > 1. -// 2. Result from step 1 is compared with a another MaskSState that represents a +// 2. Result from step 1 is compared with a another MaskState that represents a // scalar value. The resulting state only has dims populated. // 3. Optionally, result from step 2 can be broadcasted and anded with other // results from step 2. The resulting state only has dims populated. 
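A rough illustration of where these populated dims end up after conversion, based on the masked load/store tests added earlier in this patch (the names %n, %ptr, %out, %other and the index constant below are placeholders, not part of the pass contract): each dim is clamped against the static tensor extent with arith.minsi and then passed to tts.load/tts.store as a dynamic dimension, with the padding value forwarded for masked-off elements.

    %n   = arith.index_cast %arg2 : i32 to index
    %dim = arith.minsi %n, %c128 : index
    %v   = "tts.load"(%ptr, %dim, %other) <{...}> : (tensor<128x!tt.ptr<bf16>>, index, bf16) -> tensor<128xbf16>
    "tts.store"(%out, %v, %dim) <{...}> : (tensor<128x!tt.ptr<bf16>>, tensor<128xbf16>, index) -> ()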
// // Example of creating 2D mask: // mask = (rows[:, None] < M) & (cols[None, :] < N) -struct MaskSState { +struct MaskState { OpFoldResult start; OpFoldResult end; SmallVector dims; @@ -60,19 +60,19 @@ struct MaskSState { private: // ------- - // Utility functions to operate on MaskSState + // Utility functions to operate on MaskState // ------- - LogicalResult addStateScalar(const MaskSState &state, + LogicalResult addStateScalar(const MaskState &state, const OpFoldResult scalar, Location loc, OpBuilder &builder); - LogicalResult addStates(const MaskSState &lhsState, const MaskSState &rhsState, + LogicalResult addStates(const MaskState &lhsState, const MaskState &rhsState, Location loc, OpBuilder &builder); - LogicalResult minStates(const MaskSState &lhsState, const MaskSState &rhsState, + LogicalResult minStates(const MaskState &lhsState, const MaskState &rhsState, Location loc, OpBuilder &builder); // ------- - // Helper functions to parse values to populate MaskSState + // Helper functions to parse values to populate MaskState // ------- // Operand is the result of a constant diff --git a/include/triton-shared/AnalysisStructured/PtrAnalysis.h b/include/triton-shared/AnalysisStructured/PtrAnalysis.h index 55af7a4c..50e79b71 100644 --- a/include/triton-shared/AnalysisStructured/PtrAnalysis.h +++ b/include/triton-shared/AnalysisStructured/PtrAnalysis.h @@ -21,7 +21,7 @@ namespace mlir { class OpBuilder; -namespace triton { +namespace tts { const extern std::string ptrAnalysisAttr; @@ -29,8 +29,9 @@ const extern std::string ptrAnalysisAttr; // strides are in unit of elements in a linearly laid-out memory, which is the // same as pointer arithmetic operations in Triton language. scalar is a // shortcut used when the entire state describes a single scalar value. source -// is the base pointer. -class PtrSState { +// is the base pointer. modulos describes how address wraps around; a constant 0 +// indicates no modulo for the dimension. +class PtrState { public: SmallVector offsets; @@ -49,18 +50,16 @@ class PtrSState { bool dimHasModulo(uint32_t dim) const; - // Process addition of two PtrSStates. - LogicalResult addState(const PtrSState &lhsState, const PtrSState &rhsState, + // Process addition of two PtrStates. + LogicalResult addState(const PtrState &lhsState, const PtrState &rhsState, Operation *op, OpBuilder &builder); - // Process multiplication of two PtrSStates - LogicalResult mulState(const PtrSState &lhsState, const PtrSState &rhsState, + // Process multiplication of two PtrStates + LogicalResult mulState(const PtrState &lhsState, const PtrState &rhsState, Operation *op, OpBuilder &builder); tts::MakeTensorPtrOp createTTSMakeTensorPtrOp(OpBuilder &builder, Location loc); - - static void swap(PtrSState &&a, PtrSState &&b); }; struct PtrAnalysis { @@ -69,43 +68,43 @@ struct PtrAnalysis { IndexMapSet levelToBlockArgIndex; int level = 0; - llvm::SmallDenseMap knownPtrs; + llvm::SmallDenseMap knownPtrs; - IRMapping map; + IRMapping ptrMap; // Recursively parse a Value; call the corresponding // function based on the defining operation and argument type. - LogicalResult visitOperand(Value operand, PtrSState &state, const Location loc, + LogicalResult visitOperand(Value operand, PtrState &state, const Location loc, OpBuilder &builder); // Operand is the result of arith.addi. Process both arguments and insert any // arith.addi instruction as needed. 
// Main assumptions: // Only one of lhsState and rhsState has source field set - // Current PtrSState should be empty + // Current PtrState should be empty // Expected result: // source = lhsState.source ? lhsState.source : rhsState.source // sizes[i] = lhsState.sizes[i] (which should match rhsState.sizes[i]) // offsets[i] = lhsState.offsets[i] + rhsState.offsets[i] // strides[i] = lhsState.strides[i] + rhsState.strides[i] - LogicalResult visitOperandAdd(arith::AddIOp addOp, PtrSState &state, + LogicalResult visitOperandAdd(arith::AddIOp addOp, PtrState &state, const Location loc, OpBuilder &builder); // Operand is the result of arith.muli. Process both arguments and insert any // arith.muli instruction as needed. // Main assumptions: // Neither lhsState nor rhsState has source field set - // Current PtrSState should be empty + // Current PtrState should be empty // Currently only support one of the operand is a scalar index // Expected result (scalar and tensorState represent the two operands): // source = null // sizes[i] = tensorState.sizes[i] // offsets[i] = tensorState.offsets[i] * scalar // strides[i] = tensorState.strides[i] * scalar - LogicalResult visitOperandMul(arith::MulIOp mulOp, PtrSState &state, + LogicalResult visitOperandMul(arith::MulIOp mulOp, PtrState &state, const Location loc, OpBuilder &builder); - LogicalResult visitOperandRem(arith::RemSIOp mulOp, PtrSState &state, + LogicalResult visitOperandRem(arith::RemSIOp mulOp, PtrState &state, const Location loc, OpBuilder &builder); // Operand is the result of make_range. @@ -119,7 +118,7 @@ struct PtrAnalysis { // offset[0] = start // strides[0] = ceiling( (end - start) / shape[0] ) LogicalResult visitOperandMakeRange(triton::MakeRangeOp rangeOp, - PtrSState &state, Location loc, + PtrState &state, Location loc, OpBuilder &builder); // Operand is the result of expand_dims @@ -129,7 +128,7 @@ struct PtrAnalysis { // Expected result: // Insert a dimension of size 1, stride 0, and offset 0 LogicalResult visitOperandExpandDims(triton::ExpandDimsOp expandDimsOp, - PtrSState &state, const Location loc, + PtrState &state, const Location loc, OpBuilder &builder); // Operand is the result of broadcast @@ -138,7 +137,7 @@ struct PtrAnalysis { // Expected result: // Update sizes[i] only, no changes to other fields LogicalResult visitOperandBroadcast(triton::BroadcastOp broadcastOp, - PtrSState &state, const Location loc, + PtrState &state, const Location loc, OpBuilder &builder); // Operand is the result of splat @@ -147,7 +146,7 @@ struct PtrAnalysis { // Expected result: // sizes[i] reflect the shape of the result, strides[i] = 0, offsets[i] = 0 // if source is an integer, offset[0] = scalar = source - LogicalResult visitOperandSplat(triton::SplatOp splatOp, PtrSState &state, + LogicalResult visitOperandSplat(triton::SplatOp splatOp, PtrState &state, const Location loc, OpBuilder &builder); // Operand is the result of arith.constant that is a splat @@ -157,7 +156,7 @@ struct PtrAnalysis { // Expected result: // sizes[i] reflect the shape of the result, strides[i] = 0, offsets[i] = // splat value if i == 0, otherwise 0 - LogicalResult visitOperandConstSplat(arith::ConstantOp op, PtrSState &state, + LogicalResult visitOperandConstSplat(arith::ConstantOp op, PtrState &state, const Location loc, OpBuilder &builder); // Operand is the result of addptr. 
@@ -166,7 +165,7 @@ struct PtrAnalysis { // ptr and offset fields should result in same rank // Expected result: // The resulting state for ptr and offset wil be added - LogicalResult visitOperandAddptr(triton::AddPtrOp addptrOp, PtrSState &state, + LogicalResult visitOperandAddptr(triton::AddPtrOp addptrOp, PtrState &state, const Location loc, OpBuilder &builder); // Operand is the result of make_tptr. @@ -175,20 +174,20 @@ struct PtrAnalysis { // Expected result: // Directly grab all corresponding fields from make_tptr. LogicalResult visitOperandMakeTPtr(tts::MakeTensorPtrOp makeTPtrOp, - PtrSState &state, const Location loc, + PtrState &state, const Location loc, OpBuilder &builder); // Parse the state of AddPtrOp, insert any instruction needed to - // calculate strides and offsets, build PtrSState for this operand, and record - // PtrSState for knownPtrs. + // calculate strides and offsets, build PtrState for this operand, and record + // PtrState for knownPtrs. LogicalResult rewriteAddptrOp(triton::AddPtrOp op); // Parse the state of YieldOp, insert any instruction needed to calculate - // strides and offsets, build PtrSState for this operand, and record PtrSState + // strides and offsets, build PtrState for this operand, and record PtrState // in knownPtrs. LogicalResult rewriteYieldOp(scf::YieldOp op, - llvm::SmallDenseMap &knownPtrsFor); + llvm::SmallDenseMap &knownPtrsFor); // Rewrite eligible tt.addptr in loop init args so loop can update the such // pointers over iterations. Insert any instruction needed to calculate diff --git a/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td b/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td index 0d8ae99f..f11b25dc 100644 --- a/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td +++ b/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td @@ -31,9 +31,16 @@ class TTS_Op traits = []> : } def TTS_MakeTensorPtrOp - : TTS_Op<"make_tptr", [ AttrSizedOperandSegments, Pure]> { + : TTS_Op<"make_tptr", [AttrSizedOperandSegments, Pure]> { let summary = "create a pointer that points to a tensor in memory"; + + // base: base pointer used to contruct the tensor of pointers + // sizes: size of tensor of pointers + // strides: address increment from one element to the next in the + // corresponding dimension + // parent_sizes: used to represent wrap-around behavior; constant zero + // indicate no wrap-around in the corresponding dimension let arguments = (ins TT_Ptr:$base, DenseI64ArrayAttr:$sizes, Variadic:$strides, diff --git a/lib/AnalysisStructured/MaskAnalysis.cpp b/lib/AnalysisStructured/MaskAnalysis.cpp index ed218257..c513a83b 100644 --- a/lib/AnalysisStructured/MaskAnalysis.cpp +++ b/lib/AnalysisStructured/MaskAnalysis.cpp @@ -14,9 +14,9 @@ namespace mlir { -namespace triton { +namespace tts { -LogicalResult MaskSState::parse(Value operand, const Location loc, +LogicalResult MaskState::parse(Value operand, const Location loc, OpBuilder &builder) { if (auto op = operand.getDefiningOp()) { return this->parseConstant(op, loc, builder); @@ -41,7 +41,7 @@ LogicalResult MaskSState::parse(Value operand, const Location loc, } } -LogicalResult MaskSState::addStateScalar(const MaskSState &state, +LogicalResult MaskState::addStateScalar(const MaskState &state, const OpFoldResult scalar, Location loc, OpBuilder &builder) { start = addOFRs(state.start, scalar, loc, builder); @@ -50,8 +50,8 @@ LogicalResult MaskSState::addStateScalar(const MaskSState &state, 
return success(); } -LogicalResult MaskSState::addStates(const MaskSState &lhsState, - const MaskSState &rhsState, Location loc, +LogicalResult MaskState::addStates(const MaskState &lhsState, + const MaskState &rhsState, Location loc, OpBuilder &builder) { if (lhsState.scalar && rhsState.scalar) { InFlightDiagnostic diag = @@ -72,8 +72,8 @@ LogicalResult MaskSState::addStates(const MaskSState &lhsState, return addStateScalar(lhsState, rhsState.scalar, loc, builder); } -LogicalResult MaskSState::minStates(const MaskSState &lhsState, - const MaskSState &rhsState, Location loc, +LogicalResult MaskState::minStates(const MaskState &lhsState, + const MaskState &rhsState, Location loc, OpBuilder &builder) { if (lhsState.getRank() != rhsState.getRank()) { InFlightDiagnostic diag = @@ -90,7 +90,7 @@ LogicalResult MaskSState::minStates(const MaskSState &lhsState, return success(); } -LogicalResult MaskSState::parseConstant(arith::ConstantOp constOp, +LogicalResult MaskState::parseConstant(arith::ConstantOp constOp, const Location loc, OpBuilder &builder) { assert(this->isEmpty()); @@ -113,7 +113,7 @@ LogicalResult MaskSState::parseConstant(arith::ConstantOp constOp, return success(); } -LogicalResult MaskSState::parseIntScalar(Value scalar, const Location loc, +LogicalResult MaskState::parseIntScalar(Value scalar, const Location loc, OpBuilder &builder) { assert(this->isEmpty()); auto castOp = @@ -122,31 +122,31 @@ LogicalResult MaskSState::parseIntScalar(Value scalar, const Location loc, return success(); } -LogicalResult MaskSState::parseAdd(arith::AddIOp addOp, const Location loc, +LogicalResult MaskState::parseAdd(arith::AddIOp addOp, const Location loc, OpBuilder &builder) { assert(this->isEmpty()); - MaskSState lhsState; + MaskState lhsState; if (failed(lhsState.parse(addOp.getLhs(), loc, builder))) return failure(); - MaskSState rhsState; + MaskState rhsState; if (failed(rhsState.parse(addOp.getRhs(), loc, builder))) return failure(); return this->addStates(lhsState, rhsState, loc, builder); } -LogicalResult MaskSState::parseAnd(arith::AndIOp andOp, const Location loc, +LogicalResult MaskState::parseAnd(arith::AndIOp andOp, const Location loc, OpBuilder &builder) { assert(this->isEmpty()); - MaskSState lhsState; + MaskState lhsState; if (failed(lhsState.parse(andOp.getLhs(), loc, builder)) || !lhsState.isMask()) return failure(); - MaskSState rhsState; + MaskState rhsState; if (failed(rhsState.parse(andOp.getRhs(), loc, builder)) || !rhsState.isMask()) return failure(); @@ -154,7 +154,7 @@ LogicalResult MaskSState::parseAnd(arith::AndIOp andOp, const Location loc, return this->minStates(lhsState, rhsState, loc, builder); } -LogicalResult MaskSState::parseCmp(arith::CmpIOp cmpOp, const Location loc, +LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc, OpBuilder &builder) { assert(this->isEmpty()); @@ -163,11 +163,11 @@ LogicalResult MaskSState::parseCmp(arith::CmpIOp cmpOp, const Location loc, return failure(); } - MaskSState lhsState; + MaskState lhsState; if (failed(lhsState.parse(cmpOp.getLhs(), loc, builder))) return failure(); - MaskSState rhsState; + MaskState rhsState; if (failed(rhsState.parse(cmpOp.getRhs(), loc, builder))) return failure(); @@ -202,7 +202,7 @@ LogicalResult MaskSState::parseCmp(arith::CmpIOp cmpOp, const Location loc, return success(); } -LogicalResult MaskSState::parseMakeRange(triton::MakeRangeOp rangeOp, +LogicalResult MaskState::parseMakeRange(triton::MakeRangeOp rangeOp, const Location loc, OpBuilder &builder) { assert(this->isEmpty()); @@ 
-227,7 +227,7 @@ LogicalResult MaskSState::parseMakeRange(triton::MakeRangeOp rangeOp, return success(); } -LogicalResult MaskSState::parseBroadcast(triton::BroadcastOp broadcastOp, +LogicalResult MaskState::parseBroadcast(triton::BroadcastOp broadcastOp, const Location loc, OpBuilder &builder) { assert(this->isEmpty()); @@ -257,7 +257,7 @@ LogicalResult MaskSState::parseBroadcast(triton::BroadcastOp broadcastOp, return success(); } -LogicalResult MaskSState::parseSplat(triton::SplatOp splatOp, const Location loc, +LogicalResult MaskState::parseSplat(triton::SplatOp splatOp, const Location loc, OpBuilder &builder) { assert(this->isEmpty()); @@ -281,7 +281,7 @@ LogicalResult MaskSState::parseSplat(triton::SplatOp splatOp, const Location loc return success(); } -LogicalResult MaskSState::parseExpandDims(triton::ExpandDimsOp expandDimsOp, +LogicalResult MaskState::parseExpandDims(triton::ExpandDimsOp expandDimsOp, const Location loc, OpBuilder &builder) { assert(this->isEmpty()); diff --git a/lib/AnalysisStructured/PtrAnalysis.cpp b/lib/AnalysisStructured/PtrAnalysis.cpp index 419a88c1..e050c433 100644 --- a/lib/AnalysisStructured/PtrAnalysis.cpp +++ b/lib/AnalysisStructured/PtrAnalysis.cpp @@ -89,19 +89,19 @@ static Value getScalarValue(Value operand, Location loc, OpBuilder &builder) { return nullptr; } -namespace triton { +namespace tts { -int32_t PtrSState::getRank() const { +int32_t PtrState::getRank() const { assert(offsets.size() == sizes.size() && offsets.size() == strides.size() && modulos.size() == offsets.size()); return offsets.size(); } -bool PtrSState::isEmpty() const { +bool PtrState::isEmpty() const { return (getRank() == 0 && !source && !scalar); } -bool PtrSState::hasModulo() const { +bool PtrState::hasModulo() const { for (int32_t i = 0; i < getRank(); i++) { if (dimHasModulo(i)) { return true; @@ -110,7 +110,7 @@ bool PtrSState::hasModulo() const { return false; } -bool PtrSState::dimHasModulo(uint32_t dim) const { +bool PtrState::dimHasModulo(uint32_t dim) const { assert(dim < getRank()); auto intAttr = getIntAttr(modulos[dim]); @@ -121,8 +121,8 @@ bool PtrSState::dimHasModulo(uint32_t dim) const { return intAttr.value() != 0; } -LogicalResult PtrSState::addState(const PtrSState &lhsState, - const PtrSState &rhsState, Operation *op, +LogicalResult PtrState::addState(const PtrState &lhsState, + const PtrState &rhsState, Operation *op, OpBuilder &builder) { assert(isEmpty() && lhsState.getRank() == rhsState.getRank()); auto loc = op->getLoc(); @@ -136,18 +136,6 @@ LogicalResult PtrSState::addState(const PtrSState &lhsState, source = lhsState.source ? 
lhsState.source : rhsState.source; - // AddPtr where both lhs and rhs containing modulo operators not supported - if (lhsState.hasModulo() && rhsState.hasModulo()) { - op->emitRemark("PtrAnalysis: do not support adding two pointer states " - "that both have modulo"); - return failure(); - } - - if (lhsState.hasModulo() || rhsState.hasModulo()) { - // visitOperandSplat and visitOperandExpandDims should enforce below - assert(lhsState.getRank() <= 2); - } - if (lhsState.scalar && rhsState.scalar) { auto addOp = builder.create(loc, lhsState.scalar, rhsState.scalar); @@ -168,6 +156,18 @@ LogicalResult PtrSState::addState(const PtrSState &lhsState, sizes.push_back(lhsState.sizes[i]); } + // AddPtr where both lhs and rhs containing modulo operators not supported + if (lhsState.hasModulo() && rhsState.hasModulo()) { + op->emitRemark("PtrAnalysis: do not support adding two pointer states " + "that both have modulo"); + return failure(); + } + + if (lhsState.hasModulo() || rhsState.hasModulo()) { + // visitOperandSplat and visitOperandExpandDims should enforce below + assert(lhsState.getRank() <= 2); + } + // dealing with modulo: // - If lhs has no modulo, skip // - If rhs has zero offset on dim i, we can just use lhs's modulo @@ -175,15 +175,24 @@ LogicalResult PtrSState::addState(const PtrSState &lhsState, // is because the user may be trying to express adding a constant offset to // increment dim1, but pointer analysis cannot differentiate dim1 vs dim0 in // this case. - // - Else, the analysis fail + // - Else, the analysis fails + + // An example for the 3rd condition above can look like: + // %0 = tt.splat %scalar + // %1 = tt.splat %ptr + // %2 = tt.arange + // %3 = arith.remsi %2, %size + // %4 = tt.addptr %1, %3 + // %5 = tt.addptr %4, %0 + // %5 may also occur in a loop to increment %4 every iteration. // Note that this is not bullet-proof. E.g., broken IR can actually increment // dim0 while dim0 already has modulo, since Triton offsets are element-wise // and not in unit of lower dimensions. However, this is highly unlikely but // the analysis will provide wrong result. Hence we provide a warning in this // case. 
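For the loop-increment case just described, the wraparound tests added earlier in this patch show the intended outcome: the loop-carried tensor of pointers is replaced by a plain index offset in iter_args, tts.make_tptr is re-materialized inside the loop body from that offset, and the per-iteration tt.addptr becomes an arith.addi on the offset. A minimal sketch, with %base, %s0, %s1, %o1, %wrap, %init and %step standing in for values computed outside the loop:

    %res = scf.for %iv = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%off = %init) -> (index) : i32 {
      %p = tts.make_tptr %base to sizes: [4, 4], strides: [%s0, %s1], offsets: [%off, %o1], parent_sizes: [0, %wrap] : !tt.ptr<f32> to tensor<4x4x!tt.ptr<f32>>
      %v = "tts.load"(%p, %cst) <{...}> : (tensor<4x4x!tt.ptr<f32>>, f32) -> tensor<4x4xf32>
      %next = arith.addi %off, %step : index
      scf.yield %next : index
    }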
- PtrSState const *lhs = &lhsState; - PtrSState const *rhs = &rhsState; + PtrState const *lhs = &lhsState; + PtrState const *rhs = &rhsState; if (rhs->hasModulo()) { std::swap(lhs, rhs); @@ -215,8 +224,8 @@ LogicalResult PtrSState::addState(const PtrSState &lhsState, return success(); } -LogicalResult PtrSState::mulState(const PtrSState &lhsState, - const PtrSState &rhsState, Operation *op, +LogicalResult PtrState::mulState(const PtrState &lhsState, + const PtrState &rhsState, Operation *op, OpBuilder &builder) { assert(isEmpty() && lhsState.getRank() == rhsState.getRank()); @@ -225,7 +234,7 @@ LogicalResult PtrSState::mulState(const PtrSState &lhsState, // neither lhs nor rhs should have source, since multiplying base pointer // does not make sense if (lhsState.source && rhsState.source) { - op->emitRemark("PtrAnalysis: do not support multiplying base pointer"); + op->emitRemark("PtrAnalysis: do not support multiplying base pointers"); return failure(); } @@ -240,8 +249,8 @@ LogicalResult PtrSState::mulState(const PtrSState &lhsState, return failure(); } - PtrSState const *lhs = &lhsState; - PtrSState const *rhs = &rhsState; + PtrState const *lhs = &lhsState; + PtrState const *rhs = &rhsState; if (!rhs->scalar && lhs->scalar) { std::swap(lhs, rhs); @@ -270,7 +279,7 @@ LogicalResult PtrSState::mulState(const PtrSState &lhsState, return success(); } -tts::MakeTensorPtrOp PtrSState::createTTSMakeTensorPtrOp(OpBuilder &builder, +tts::MakeTensorPtrOp PtrState::createTTSMakeTensorPtrOp(OpBuilder &builder, Location loc) { SmallVector staticSizes; for (size_t i = 0; i < getRank(); i++) { @@ -289,15 +298,15 @@ tts::MakeTensorPtrOp PtrSState::createTTSMakeTensorPtrOp(OpBuilder &builder, return op; } -LogicalResult PtrAnalysis::visitOperandAdd(arith::AddIOp addOp, PtrSState &state, +LogicalResult PtrAnalysis::visitOperandAdd(arith::AddIOp addOp, PtrState &state, const Location loc, OpBuilder &builder) { - PtrSState lhsState; + PtrState lhsState; if (visitOperand(addOp.getLhs(), lhsState, loc, builder).failed()) { return failure(); } - PtrSState rhsState; + PtrState rhsState; if (visitOperand(addOp.getRhs(), rhsState, loc, builder).failed()) return failure(); @@ -312,15 +321,15 @@ LogicalResult PtrAnalysis::visitOperandAdd(arith::AddIOp addOp, PtrSState &state return state.addState(lhsState, rhsState, addOp, builder); } -LogicalResult PtrAnalysis::visitOperandMul(arith::MulIOp mulOp, PtrSState &state, +LogicalResult PtrAnalysis::visitOperandMul(arith::MulIOp mulOp, PtrState &state, const Location loc, OpBuilder &builder) { - PtrSState lhsState; + PtrState lhsState; if (visitOperand(mulOp.getLhs(), lhsState, loc, builder).failed()) { return failure(); } - PtrSState rhsState; + PtrState rhsState; if (visitOperand(mulOp.getRhs(), rhsState, loc, builder).failed()) { return failure(); } @@ -329,11 +338,11 @@ LogicalResult PtrAnalysis::visitOperandMul(arith::MulIOp mulOp, PtrSState &state } LogicalResult PtrAnalysis::visitOperandRem(arith::RemSIOp remOp, - PtrSState &state, const Location loc, + PtrState &state, const Location loc, OpBuilder &builder) { assert(state.isEmpty()); - PtrSState rhsState; + PtrState rhsState; if (visitOperand(remOp.getRhs(), rhsState, loc, builder).failed()) { return failure(); } @@ -390,7 +399,7 @@ LogicalResult PtrAnalysis::visitOperandRem(arith::RemSIOp remOp, } LogicalResult PtrAnalysis::visitOperandMakeRange(triton::MakeRangeOp rangeOp, - PtrSState &state, Location loc, + PtrState &state, Location loc, OpBuilder &builder) { assert(state.isEmpty()); @@ -411,7 +420,7 @@ 
LogicalResult PtrAnalysis::visitOperandMakeRange(triton::MakeRangeOp rangeOp, LogicalResult PtrAnalysis::visitOperandExpandDims(triton::ExpandDimsOp expandDimsOp, - PtrSState &state, const Location loc, + PtrState &state, const Location loc, OpBuilder &builder) { assert(state.isEmpty()); @@ -444,44 +453,42 @@ PtrAnalysis::visitOperandExpandDims(triton::ExpandDimsOp expandDimsOp, LogicalResult PtrAnalysis::visitOperandBroadcast(triton::BroadcastOp broadcastOp, - PtrSState &state, const Location loc, + PtrState &state, const Location loc, OpBuilder &builder) { assert(state.isEmpty()); auto src = broadcastOp.getSrc(); auto dst = broadcastOp.getResult(); - SmallVector srcShape; - auto dstShape = dst.getType().cast().getShape(); + if (!src.getType().isa()) { + broadcastOp->emitRemark("PtrAnalysis: Unsupported broadcast source type"); + return failure(); + } - if (src.getType().isa()) { - srcShape = - SmallVector(src.getType().cast().getShape()); - assert(srcShape.size() == dstShape.size() && - "rank of source and destination should match"); + auto srcShape = cast(src.getType()).getShape(); + auto dstShape = cast(dst.getType()).getShape(); - if (visitOperand(src, state, loc, builder).failed()) { - return failure(); - } + assert(srcShape.size() == dstShape.size() && + "rank of source and destination should match"); - for (size_t i = 0; i < dstShape.size(); i++) { - if (srcShape[i] == dstShape[i]) { - continue; - } else if (srcShape[i] < dstShape[i]) { - state.sizes[i] = builder.getIndexAttr(dstShape[i]); - } else { - llvm_unreachable("unexpected dimensions used in broadcast"); - } - } - return success(); + if (visitOperand(src, state, loc, builder).failed()) { + return failure(); } - broadcastOp->emitRemark("PtrAnalysis: Unsupported broadcast source type"); - return failure(); + for (size_t i = 0; i < dstShape.size(); i++) { + if (srcShape[i] == dstShape[i]) { + continue; + } else if (srcShape[i] < dstShape[i]) { + state.sizes[i] = builder.getIndexAttr(dstShape[i]); + } else { + llvm_unreachable("unexpected dimensions used in broadcast"); + } + } + return success(); } LogicalResult PtrAnalysis::visitOperandSplat(triton::SplatOp splatOp, - PtrSState &state, + PtrState &state, const Location loc, OpBuilder &builder) { assert(state.isEmpty()); @@ -521,18 +528,18 @@ LogicalResult PtrAnalysis::visitOperandSplat(triton::SplatOp splatOp, } LogicalResult PtrAnalysis::visitOperandAddptr(triton::AddPtrOp addptrOp, - PtrSState &state, + PtrState &state, const Location loc, OpBuilder &builder) { assert(state.isEmpty()); - PtrSState ptrState; + PtrState ptrState; if (visitOperand(addptrOp.getPtr(), ptrState, addptrOp.getLoc(), builder) .failed()) { return failure(); } - PtrSState offsetState; + PtrState offsetState; if (visitOperand(addptrOp.getOffset(), offsetState, addptrOp.getLoc(), builder) .failed()) { @@ -548,7 +555,7 @@ LogicalResult PtrAnalysis::visitOperandAddptr(triton::AddPtrOp addptrOp, } LogicalResult PtrAnalysis::visitOperandConstSplat(arith::ConstantOp op, - PtrSState &state, + PtrState &state, const Location loc, OpBuilder &builder) { assert(state.isEmpty()); @@ -582,7 +589,7 @@ LogicalResult PtrAnalysis::visitOperandConstSplat(arith::ConstantOp op, } LogicalResult PtrAnalysis::visitOperandMakeTPtr(tts::MakeTensorPtrOp makeTPtrOp, - PtrSState &state, + PtrState &state, const Location loc, OpBuilder &builder) { @@ -596,7 +603,7 @@ LogicalResult PtrAnalysis::visitOperandMakeTPtr(tts::MakeTensorPtrOp makeTPtrOp, return success(); } -LogicalResult PtrAnalysis::visitOperand(Value operand, PtrSState 
&state, +LogicalResult PtrAnalysis::visitOperand(Value operand, PtrState &state, const Location loc, OpBuilder &builder) { @@ -627,7 +634,7 @@ LogicalResult PtrAnalysis::visitOperand(Value operand, PtrSState &state, return visitOperandAddptr(cast(op), state, loc, builder); } else if (auto makeTensorOp = dyn_cast(op)) { - llvm_unreachable("NYI"); + llvm_unreachable("Unexpected operand defining operation tts.make_tptr"); } else { llvm_unreachable("Unexpected operand defining operation"); } @@ -666,7 +673,7 @@ LogicalResult PtrAnalysis::visitOperand(Value operand, PtrSState &state, LogicalResult PtrAnalysis::rewriteAddptrOp(triton::AddPtrOp op) { OpBuilder builder(op); - PtrSState state; + PtrState state; if (visitOperandAddptr(op, state, op.getLoc(), builder).failed()) { return failure(); } @@ -675,9 +682,11 @@ LogicalResult PtrAnalysis::rewriteAddptrOp(triton::AddPtrOp op) { if (op.getPtr().getType().isa()) { auto maketptrOp = state.createTTSMakeTensorPtrOp(builder, op.getLoc()); - map.map(op.getResult(), maketptrOp.getResult()); + ptrMap.map(op.getResult(), maketptrOp.getResult()); } else { - map.map(op.getResult(), op.getResult()); + // record the ptr as we have visited and built up the state for this scalar + // pointer, which may be used by rewriteForOp later. + ptrMap.map(op.getResult(), op.getResult()); } return success(); } @@ -685,46 +694,47 @@ LogicalResult PtrAnalysis::rewriteAddptrOp(triton::AddPtrOp op) { LogicalResult PtrAnalysis::rewriteForOp(scf::ForOp op) { SmallVector newInitArgs; - SmallVector, 5> initArgIndexState; - SmallVector, 5> knownPtrsTmp; + SmallVector, 5> initArgIndexState; + SmallVector, 5> knownPtrsTmp; - llvm::SmallDenseMap initArgIndexMap; + llvm::SmallDenseMap initArgIndexMap; OpBuilder builder(op); // Create a new list of init args for (auto [i, arg] : llvm::enumerate(op.getInitArgs())) { - auto mappedV = map.lookupOrNull(arg); - PtrSState state; + auto mappedV = ptrMap.lookupOrNull(arg); + PtrState state; if (mappedV) { if (auto makeTPtrOp = mappedV.getDefiningOp()) { if (visitOperandMakeTPtr(makeTPtrOp, state, op.getLoc(), builder) - .failed()) { - newInitArgs.push_back(arg); + .succeeded()) { + newInitArgs.push_back(mappedV); + // Record the PtrState for later processing + initArgIndexState.push_back(std::make_pair(i, state)); continue; } - newInitArgs.push_back(mappedV); } else if (auto addptrOp = mappedV.getDefiningOp()) { + // We always use tt.addptr for scalar pointers. If the defininig op is + // tt.addptr and we have a non-scalar pointer, something must have gone + // wrong with the pass. assert(!addptrOp.getResult().getType().isa()); if (visitOperandAddptr(addptrOp, state, op.getLoc(), builder) - .failed()) { - newInitArgs.push_back(arg); + .succeeded()) { + newInitArgs.push_back(mappedV); + // Record the PtrState for later processing + initArgIndexState.push_back(std::make_pair(i, state)); continue; } - newInitArgs.push_back(mappedV); } } - // Init arg is not pointer related or prior rewrite has failed. Pass as is - else { - newInitArgs.push_back(arg); - continue; - } - // Record the PtrSState for later processing - initArgIndexState.push_back(std::make_pair(i, state)); + // If any of the analysis failed, or init arg is not pointer related or + // prior rewrite has failed. 
Pass as is + newInitArgs.push_back(arg); } - // For each of the PtrSState recorded in the last step, insert new instructions + // For each of the PtrState recorded in the last step, insert new instructions // to describe offset and stride for each dimension and append them to init // args for (auto [i, state] : initArgIndexState) { @@ -757,6 +767,8 @@ LogicalResult PtrAnalysis::rewriteForOp(scf::ForOp op) { if (state.getRank() == 0) { assert(state.scalar); + // for scalar pointers, the scalar contains the offset and is the only + // relevant state that could be updated by the loop. newInitArgs.push_back(state.scalar); } @@ -784,7 +796,7 @@ LogicalResult PtrAnalysis::rewriteForOp(scf::ForOp op) { // Convert the book-keeping data structure to use the correct key and value. // Key is converted from init arg index to newly created block arg, and - // Value's PtrSState fields are converted from init arg to newly created block + // Value's PtrState fields are converted from init arg to newly created block // arg int cnt = op.getRegionIterArgs().size(); for (auto [i, state] : knownPtrsTmp) { @@ -804,38 +816,45 @@ LogicalResult PtrAnalysis::rewriteForOp(scf::ForOp op) { cnt++; } - // Record the PtrSState for this pointer + // Record the PtrState for this pointer auto key = newOp.getRegionIterArgs()[i]; knownPtrs[key] = state; initArgIndexMap[i] = state; - // Create a tts.make_tptr at the beginning of the loop body that correspond - // to this region iter arg. In case it is used by tt.load/tt.store in the - // loop body, this will make sure rewriteLoadOp/rewriteStoreOp can use the - // analysis result. + // For tensors of pointers, create a tts.make_tptr at the beginning of the + // loop body that correspond to this region iter arg. In case it is used + // by tt.load/tt.store in the loop body before pointer updates, this will + // make sure rewriteLoadOp/rewriteStoreOp can use the analysis result. + // E.g., given the following input (%tensor_of_ptr is a block arg): + // scf.for (%tensor_of_ptr) { + // %data = tt.load %tensor_of_ptr + // // more operations to update %tensor_of_ptr + // } + // We may produce the following output: + // scf.for (%base_ptr, %stride, %offset) { + // %tensor_of_ptr = tts.make_tptr(%base_ptr, %stride, %offset) + // %data = tts.load %tensor_of_ptr + // // more operations to update %offset + // } + // If %tensor_of_ptr is not used (i.e., %tensor_of_ptr is updated before + // used in the original IR), it will simply be removed by canonicalization. + + // For scalar pointers, there is no need to create a tts.addptr at the + // beginning of the loop body. We don't lower tt.load and tt.store on + // scalars in this pass; pointer arithmetics can also just use the + // original pointer. 
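+      // Illustrative sketch only (names hypothetical): a scalar pointer in a
+      // loop ends up carried as the original pointer plus its scalar offset,
+      //   scf.for ... iter_args(%ptr = ..., %off = ...) { ... }
+      // %ptr keeps its pointer type and can be used directly by pointer
+      // arithmetic in the body; only %off may need updating per iteration.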
if (state.getRank() != 0) { OpBuilder::InsertionGuard guard(builder); builder.setInsertionPointToStart(&newOp.getRegion().front()); auto maketptrOp = state.createTTSMakeTensorPtrOp(builder, op.getLoc()); - map.map(key, maketptrOp.getResult()); - } else { - OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPointToStart(&newOp.getRegion().front()); - - auto offset = state.scalar; - if (offset.getType().isa()) { - offset = builder.create( - op.getLoc(), builder.getI32Type(), offset); - } - auto addptrOp = builder.create( - op.getLoc(), state.source.getType(), state.source, offset); - map.map(key, addptrOp.getResult()); + ptrMap.map(key, maketptrOp.getResult()); } } for (auto &bodyOp : newOp.getRegion().getOps()) { if (auto forOp = dyn_cast(bodyOp)) { - assert(0 && "nested loops currently not supported"); + forOp->emitRemark("PtrAnalysis: nested loops currently not supported"); + return failure(); } } @@ -878,7 +897,7 @@ LogicalResult PtrAnalysis::rewriteForOp(scf::ForOp op) { LogicalResult PtrAnalysis::rewriteYieldOp(scf::YieldOp op, - llvm::SmallDenseMap &knownPtrsFor) { + llvm::SmallDenseMap &knownPtrsFor) { if (levelToBlockArgIndex.find(level) == levelToBlockArgIndex.end()) { // no need to rewrite this op return success(); @@ -888,21 +907,21 @@ PtrAnalysis::rewriteYieldOp(scf::YieldOp op, // For each of the init arg that we added additional Values in for loop, we // need to add corresponding Values as yield operands. The loop below gathers - // PtrSState for those values. - SmallVector initArgState; + // PtrState for those values. + SmallVector initArgState; for (auto [i, v] : llvm::enumerate(op->getOperands())) { // If this operand is not rewritten by forOp, skip auto thisSet = levelToBlockArgIndex.find(level)->second; if (thisSet.find(i) == thisSet.end()) continue; - auto mappedV = map.lookupOrNull(v); + auto mappedV = ptrMap.lookupOrNull(v); if (!mappedV) { op->emitRemark("Prior rewrite failure lead to yield rewrite failure"); return failure(); } - PtrSState state; + PtrState state; LogicalResult ret = failure(); if (auto makeTPtrOp = mappedV.getDefiningOp()) { ret = visitOperandMakeTPtr(makeTPtrOp, state, op.getLoc(), builder); @@ -936,7 +955,7 @@ PtrAnalysis::rewriteYieldOp(scf::YieldOp op, SmallVector operands; for (auto opnd : op->getOperands()) { - auto mappedV = map.lookupOrNull(opnd); + auto mappedV = ptrMap.lookupOrNull(opnd); if (mappedV) { operands.push_back(mappedV); } else { @@ -944,7 +963,7 @@ PtrAnalysis::rewriteYieldOp(scf::YieldOp op, } } - // For each of the PtrSState recorded in the last step, extract value + // For each of the PtrState recorded in the last step, extract value // that correspond to offset and stride for each dimension and append // them to yield operands. 
for (auto state : initArgState) { @@ -959,7 +978,7 @@ PtrAnalysis::rewriteYieldOp(scf::YieldOp op, } for (auto s : state.strides) { - assert(!getIntAttr(s) && "PtrSState strides for yield within for " + assert(!getIntAttr(s) && "PtrState strides for yield within for " "loop not expected to be attribute."); operands.push_back(s.get()); } @@ -983,7 +1002,7 @@ PtrAnalysis::rewriteYieldOp(scf::YieldOp op, } LogicalResult PtrAnalysis::rewriteLoadOp(triton::LoadOp op) { - auto ptr = map.lookupOrNull(op.getPtr()); + auto ptr = ptrMap.lookupOrNull(op.getPtr()); auto mask = op.getMask(); auto other = op.getOther(); auto loc = op.getLoc(); @@ -1000,7 +1019,7 @@ LogicalResult PtrAnalysis::rewriteLoadOp(triton::LoadOp op) { } ArrayRef dims; - MaskSState mstate; + MaskState mstate; Value scalarOther; OpBuilder builder(op); @@ -1038,7 +1057,7 @@ LogicalResult PtrAnalysis::rewriteLoadOp(triton::LoadOp op) { } LogicalResult PtrAnalysis::rewriteStoreOp(triton::StoreOp op) { - auto ptr = map.lookupOrNull(op.getPtr()); + auto ptr = ptrMap.lookupOrNull(op.getPtr()); auto val = op.getValue(); auto mask = op.getMask(); auto loc = op.getLoc(); @@ -1055,7 +1074,7 @@ LogicalResult PtrAnalysis::rewriteStoreOp(triton::StoreOp op) { } ArrayRef dims; - MaskSState mstate; + MaskState mstate; OpBuilder builder(op); @@ -1124,5 +1143,5 @@ LogicalResult PtrAnalysis::rewriteOp(Operation *rootOp) { return success(); } -} // namespace triton +} // namespace tts } // namespace mlir diff --git a/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp b/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp index 70e41162..89668ecd 100644 --- a/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp +++ b/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp @@ -44,7 +44,7 @@ class TritonToStructuredPass void runOnOperation() override { auto moduleOp = getOperation(); - PtrAnalysis ptrAnalysis; + mlir::tts::PtrAnalysis ptrAnalysis; if (ptrAnalysis.rewriteOp(moduleOp).failed()) { moduleOp->emitWarning("PtrAnalysis failed"); } diff --git a/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp b/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp index 0298d49c..dbede2f7 100644 --- a/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp +++ b/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp @@ -30,7 +30,7 @@ void MakeTensorPtrOp::build(OpBuilder &b, OperationState &state, Value base, dispatchIndexOpFoldResults(parentSizes, dynamicParentSizes, staticParentSizes); - auto basePtr = base.getType().cast(); + auto basePtr = cast(base.getType()); auto elemType = basePtr.getPointeeType(); auto resType = RankedTensorType::get(sizes, basePtr); @@ -47,9 +47,8 @@ void LoadOp::build(OpBuilder &b, OperationState &state, Value ptr, dispatchIndexOpFoldResults(dims, dynamicDims, staticDims); - auto ptrTensorType = ptr.getType().cast(); - auto elemType = ptrTensorType.getElementType() - .cast() + auto ptrTensorType = cast(ptr.getType()); + auto elemType = cast(ptrTensorType.getElementType()) .getPointeeType(); auto resType = RankedTensorType::get(ptrTensorType.getShape(), elemType); From c4ac4252500a033af2f1a129f33d459f1deae3f0 Mon Sep 17 00:00:00 2001 From: Haishan Zhu Date: Wed, 17 Jan 2024 15:04:32 -0800 Subject: [PATCH 7/8] revert header name change --- include/triton-shared/Analysis/OpFoldResultUtils.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/triton-shared/Analysis/OpFoldResultUtils.h b/include/triton-shared/Analysis/OpFoldResultUtils.h index 2a3e485a..148c52c4 100644 
--- a/include/triton-shared/Analysis/OpFoldResultUtils.h +++ b/include/triton-shared/Analysis/OpFoldResultUtils.h @@ -5,8 +5,8 @@ // //===----------------------------------------------------------------------===// -#ifndef TRITON_ANALYSISSTRUCTURED_OPFOLDRESULT_UTILS_H -#define TRITON_ANALYSISSTRUCTURED_OPFOLDRESULT_UTILS_H +#ifndef TRITON_ANALYSIS_OPFOLDRESULT_UTILS_H +#define TRITON_ANALYSIS_OPFOLDRESULT_UTILS_H #include "mlir/IR/Location.h" #include "mlir/IR/OpDefinition.h" From 99f9412dfe2aaab7b8f845571f78dbe98e2326cb Mon Sep 17 00:00:00 2001 From: Haishan Zhu Date: Thu, 18 Jan 2024 14:34:19 -0800 Subject: [PATCH 8/8] remove copied version of mask analysis --- include/triton-shared/Analysis/MaskAnalysis.h | 41 ++- .../AnalysisStructured/MaskAnalysis.h | 128 -------- lib/Analysis/MaskAnalysis.cpp | 151 +++++---- lib/AnalysisStructured/CMakeLists.txt | 1 - lib/AnalysisStructured/MaskAnalysis.cpp | 303 ------------------ lib/AnalysisStructured/PtrAnalysis.cpp | 6 +- 6 files changed, 97 insertions(+), 533 deletions(-) delete mode 100644 include/triton-shared/AnalysisStructured/MaskAnalysis.h delete mode 100644 lib/AnalysisStructured/MaskAnalysis.cpp diff --git a/include/triton-shared/Analysis/MaskAnalysis.h b/include/triton-shared/Analysis/MaskAnalysis.h index 531598e6..c0381f69 100644 --- a/include/triton-shared/Analysis/MaskAnalysis.h +++ b/include/triton-shared/Analysis/MaskAnalysis.h @@ -1,6 +1,6 @@ //===----------------------------------------------------------------------===// // -// Copyright (c) Microsoft Corporation. +// Copyright (c) Microsoft Corporation, Meta Platforms. // Licensed under the MIT license. // //===----------------------------------------------------------------------===// @@ -18,7 +18,7 @@ namespace mlir { -class ConversionPatternRewriter; +class OpBuilder; namespace triton { // Data structure used to decode the pattern in a mask used for load and store. 
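 // As an illustration only (names hypothetical), a typical 1-D mask this
 // analysis decodes looks like:
 //   %range = tt.make_range {start = 0 : i32, end = 128 : i32}
 //   %bound = tt.splat %n
 //   %mask  = arith.cmpi slt, %range, %bound
 // which parse() reduces to a single access dimension of size min(128, %n).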
@@ -56,23 +56,22 @@ struct MaskState { // Recursively parse a Value; call the coresponding function based on the // defining operation and Value type - LogicalResult parse(Value operand, const Location loc, - ConversionPatternRewriter &rewriter); + LogicalResult parse(Value operand, const Location loc, OpBuilder &builder); tensor::ExtractSliceOp getExtractSlice(Value source, const Location loc, - ConversionPatternRewriter &rewriter) const; + OpBuilder &builder) const; memref::SubViewOp getSubview(Value source, const Location loc, - ConversionPatternRewriter &rewriter) const; + OpBuilder &builder) const; std::pair getSideBySideSubviews(Value block1, Value block2, const Location loc, - ConversionPatternRewriter &rewriter) const; + OpBuilder &builder) const; std::pair getStackedSubviews(Value block1, Value block2, const Location loc, - ConversionPatternRewriter &rewriter) const; + OpBuilder &builder) const; private: // ------- @@ -80,13 +79,13 @@ struct MaskState { // ------- LogicalResult addStateScalar(const MaskState &state, const OpFoldResult scalar, Location loc, - ConversionPatternRewriter &rewriter); + OpBuilder &builder); LogicalResult addStates(const MaskState &lhsState, const MaskState &rhsState, - Location loc, ConversionPatternRewriter &rewriter); + Location loc, OpBuilder &builder); LogicalResult minStates(const MaskState &lhsState, const MaskState &rhsState, - Location loc, ConversionPatternRewriter &rewriter); + Location loc, OpBuilder &builder); // ------- // Helper functions to parse values to populate MaskState // ------- @@ -94,49 +93,47 @@ struct MaskState { // Operand is the result of a constant // Get the value of the constant and assign it to scalar. LogicalResult parseConstant(arith::ConstantOp constOp, const Location loc, - ConversionPatternRewriter &rewriter); + OpBuilder &builder); // Operand is an integer scalar LogicalResult parseIntScalar(Value scalar, const Location loc, - ConversionPatternRewriter &rewriter); + OpBuilder &builder); // Operand is the result of addi // One and only one of the operands should be a scalar. Increment both start // and end, dims remains unchanged, and scalar is empty. LogicalResult parseAdd(arith::AddIOp addOp, const Location loc, - ConversionPatternRewriter &rewriter); + OpBuilder &builder); // Operand is the result of andi // Each of the result state dims is smaller of the two operands' dims. // Insert instruction if needed to get new dims. LogicalResult parseAnd(arith::AndIOp andOp, const Location loc, - ConversionPatternRewriter &rewriter); + OpBuilder &builder); // Operand is the result of cmpi // Assume only of the dimensions have size > 1. Only support slt for now. // For that dimension, calculate this new dim as: dim = min(end, value) - // start LogicalResult parseCmp(arith::CmpIOp cmpOp, const Location loc, - ConversionPatternRewriter &rewriter); + OpBuilder &builder); // Operand is the result of make_range // Set start and end accordingly; step size must be 1. LogicalResult parseMakeRange(triton::MakeRangeOp rangeOp, const Location loc, - ConversionPatternRewriter &rewriter); + OpBuilder &builder); // Operand is the result of broadcast // Change dims only; assume only applies to tensors. LogicalResult parseBroadcast(triton::BroadcastOp broadcastOp, - const Location loc, - ConversionPatternRewriter &rewriter); + const Location loc, OpBuilder &builder); // Operand is the result of splat // Assume only applies to scalar. start and end are left empty; scalar will // be assigned, and dims will be updated. 
LogicalResult parseSplat(triton::SplatOp splatOp, const Location loc, - ConversionPatternRewriter &rewriter); + OpBuilder &builder); // Operand is the result of expand_dims // Insert additional dims; start and end do not change and correspond to the // dimension that contains the range. LogicalResult parseExpandDims(triton::ExpandDimsOp expandDimsOp, - const Location loc, - ConversionPatternRewriter &rewriter); + const Location loc, OpBuilder &builder); }; } // namespace triton diff --git a/include/triton-shared/AnalysisStructured/MaskAnalysis.h b/include/triton-shared/AnalysisStructured/MaskAnalysis.h deleted file mode 100644 index 3f03c686..00000000 --- a/include/triton-shared/AnalysisStructured/MaskAnalysis.h +++ /dev/null @@ -1,128 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Copyright (c) Microsoft Corporation, Meta Platforms. -// Licensed under the MIT license. -// -//===----------------------------------------------------------------------===// - -#ifndef TRITON_ANALYSISSTRUCTURED_MASKANALYSIS_H -#define TRITON_ANALYSISSTRUCTURED_MASKANALYSIS_H - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" - -#include "triton/Dialect/Triton/IR/Dialect.h" - -#include - -namespace mlir { - -class OpBuilder; - -namespace tts { -// Data structure used to decode the pattern in a mask used for load and store. -// start and end field represent the start and end index of a range (produced -// by make_range, addi, etc.). While multi-dimensional data is possible, we -// assume range comparison can only be done on 1 dimension at a time (and -// results of range comparions across dimensions can be combined), hence start -// and end are not vectors. dims represents the real access size for ld/st -// (instead of the tensor/memref size specified by the IR). scalar is a shortcut -// used when the entire state contains a single scalar value. -// -// The general lifetime of this data structure is roughly: -// 1. A range is created by make_range and optionally operated on by addi w/ -// result of splat, expand_dims, etc. During this phase, either (1) both start -// and end are populated, or (2) scalar is populated. Only one of the dimensions -// (that contains the range) can have dim > 1. -// 2. Result from step 1 is compared with a another MaskState that represents a -// scalar value. The resulting state only has dims populated. -// 3. Optionally, result from step 2 can be broadcasted and anded with other -// results from step 2. The resulting state only has dims populated. 
-// -// Example of creating 2D mask: -// mask = (rows[:, None] < M) & (cols[None, :] < N) -struct MaskState { - OpFoldResult start; - OpFoldResult end; - SmallVector dims; - OpFoldResult scalar; - - int64_t getRank() const { return dims.size(); } - - bool isEmpty() const { return getRank() == 0 && !scalar && !start && !end; } - - bool isMask() const { return !start && !end && !scalar && dims.size() != 0; } - - // Recursively parse a Value; call the coresponding function based on the - // defining operation and Value type - LogicalResult parse(Value operand, const Location loc, OpBuilder &builder); - -private: - // ------- - // Utility functions to operate on MaskState - // ------- - LogicalResult addStateScalar(const MaskState &state, - const OpFoldResult scalar, Location loc, - OpBuilder &builder); - - LogicalResult addStates(const MaskState &lhsState, const MaskState &rhsState, - Location loc, OpBuilder &builder); - - LogicalResult minStates(const MaskState &lhsState, const MaskState &rhsState, - Location loc, OpBuilder &builder); - // ------- - // Helper functions to parse values to populate MaskState - // ------- - - // Operand is the result of a constant - // Get the value of the constant and assign it to scalar. - LogicalResult parseConstant(arith::ConstantOp constOp, const Location loc, - OpBuilder &builder); - - // Operand is an integer scalar - LogicalResult parseIntScalar(Value scalar, const Location loc, - OpBuilder &builder); - - // Operand is the result of addi - // One and only one of the operands should be a scalar. Increment both start - // and end, dims remains unchanged, and scalar is empty. - LogicalResult parseAdd(arith::AddIOp addOp, const Location loc, - OpBuilder &builder); - // Operand is the result of andi - // Each of the result state dims is smaller of the two operands' dims. - // Insert instruction if needed to get new dims. - LogicalResult parseAnd(arith::AndIOp andOp, const Location loc, - OpBuilder &builder); - - // Operand is the result of cmpi - // Assume only of the dimensions have size > 1. Only support slt for now. - // For that dimension, calculate this new dim as: dim = min(end, value) - - // start - LogicalResult parseCmp(arith::CmpIOp cmpOp, const Location loc, - OpBuilder &builder); - // Operand is the result of make_range - // Set start and end accordingly; step size must be 1. - LogicalResult parseMakeRange(triton::MakeRangeOp rangeOp, const Location loc, - OpBuilder &builder); - // Operand is the result of broadcast - // Change dims only; assume only applies to tensors. - LogicalResult parseBroadcast(triton::BroadcastOp broadcastOp, - const Location loc, OpBuilder &builder); - // Operand is the result of splat - // Assume only applies to scalar. start and end are left empty; scalar will - // be assigned, and dims will be updated. - LogicalResult parseSplat(triton::SplatOp splatOp, const Location loc, - OpBuilder &builder); - // Operand is the result of expand_dims - // Insert additional dims; start and end do not change and correspond to the - // dimension that contains the range. 
- LogicalResult parseExpandDims(triton::ExpandDimsOp expandDimsOp, - const Location loc, OpBuilder &builder); -}; - -} // namespace triton - -} // namespace mlir - -#endif diff --git a/lib/Analysis/MaskAnalysis.cpp b/lib/Analysis/MaskAnalysis.cpp index 18af487a..c09d0c68 100644 --- a/lib/Analysis/MaskAnalysis.cpp +++ b/lib/Analysis/MaskAnalysis.cpp @@ -17,25 +17,25 @@ namespace mlir { namespace triton { LogicalResult MaskState::parse(Value operand, const Location loc, - ConversionPatternRewriter &rewriter) { + OpBuilder &builder) { if (auto op = operand.getDefiningOp()) { - return this->parseConstant(op, loc, rewriter); + return this->parseConstant(op, loc, builder); } else if (operand.getType().isa()) { - return this->parseIntScalar(operand, loc, rewriter); + return this->parseIntScalar(operand, loc, builder); } else if (auto op = operand.getDefiningOp()) { - return this->parseAdd(op, loc, rewriter); + return this->parseAdd(op, loc, builder); } else if (auto op = operand.getDefiningOp()) { - return this->parseAnd(op, loc, rewriter); + return this->parseAnd(op, loc, builder); } else if (auto op = operand.getDefiningOp()) { - return this->parseCmp(op, loc, rewriter); + return this->parseCmp(op, loc, builder); } else if (auto op = operand.getDefiningOp()) { - return this->parseMakeRange(op, loc, rewriter); + return this->parseMakeRange(op, loc, builder); } else if (auto op = operand.getDefiningOp()) { - return this->parseBroadcast(op, loc, rewriter); + return this->parseBroadcast(op, loc, builder); } else if (auto op = operand.getDefiningOp()) { - return this->parseSplat(op, loc, rewriter); + return this->parseSplat(op, loc, builder); } else if (auto op = operand.getDefiningOp()) { - return this->parseExpandDims(op, loc, rewriter); + return this->parseExpandDims(op, loc, builder); } else { return failure(); } @@ -43,28 +43,28 @@ LogicalResult MaskState::parse(Value operand, const Location loc, tensor::ExtractSliceOp MaskState::getExtractSlice(Value source, const Location loc, - ConversionPatternRewriter &rewriter) const { + OpBuilder &builder) const { auto sourceType = source.getType().cast(); - SmallVector offsets(getRank(), rewriter.getIndexAttr(0)); - SmallVector strides(getRank(), rewriter.getIndexAttr(1)); + SmallVector offsets(getRank(), builder.getIndexAttr(0)); + SmallVector strides(getRank(), builder.getIndexAttr(1)); auto dstType = tensor::ExtractSliceOp::inferResultType(sourceType, offsets, dims, strides); - return rewriter.create(loc, dstType, source, offsets, + return builder.create(loc, dstType, source, offsets, dims, strides); } memref::SubViewOp MaskState::getSubview(Value source, const Location loc, - ConversionPatternRewriter &rewriter) const { + OpBuilder &builder) const { auto sourceType = source.getType().cast(); - SmallVector offsets(getRank(), rewriter.getIndexAttr(0)); - SmallVector strides(getRank(), rewriter.getIndexAttr(1)); + SmallVector offsets(getRank(), builder.getIndexAttr(0)); + SmallVector strides(getRank(), builder.getIndexAttr(1)); auto dstType = memref::SubViewOp::inferResultType(sourceType, offsets, dims, strides); - return rewriter.create(loc, dstType.cast(), + return builder.create(loc, dstType.cast(), source, offsets, dims, strides); } @@ -136,20 +136,20 @@ static memref::SubViewOp createSubview(Value src, Location loc, OpBuilder &b, // + rowView1 = rowView2 = row = rowFull std::pair MaskState::getSideBySideSubviews(Value block1, Value block2, const Location loc, - ConversionPatternRewriter &rewriter) const { + OpBuilder &builder) const { OpFoldResult 
subviewRowFull = dims[0]; OpFoldResult subviewColFull = dims[1]; OpFoldResult col1 = - rewriter.create(loc, block1, 1).getResult(); - OpFoldResult subviewCol1 = minOFRs(col1, subviewColFull, loc, rewriter); + builder.create(loc, block1, 1).getResult(); + OpFoldResult subviewCol1 = minOFRs(col1, subviewColFull, loc, builder); OpFoldResult subviewCol2 = - subOFRs(subviewColFull, subviewCol1, loc, rewriter); + subOFRs(subviewColFull, subviewCol1, loc, builder); - SmallVector offsets(getRank(), rewriter.getIndexAttr(0)); - SmallVector strides(getRank(), rewriter.getIndexAttr(1)); - auto sv1 = createSubview(block1, loc, rewriter, offsets, + SmallVector offsets(getRank(), builder.getIndexAttr(0)); + SmallVector strides(getRank(), builder.getIndexAttr(1)); + auto sv1 = createSubview(block1, loc, builder, offsets, {subviewRowFull, subviewCol1}, strides); - auto sv2 = createSubview(block2, loc, rewriter, offsets, + auto sv2 = createSubview(block2, loc, builder, offsets, {subviewRowFull, subviewCol2}, strides); return {sv1, sv2}; @@ -157,36 +157,36 @@ MaskState::getSideBySideSubviews(Value block1, Value block2, const Location loc, std::pair MaskState::getStackedSubviews(Value block1, Value block2, const Location loc, - ConversionPatternRewriter &rewriter) const { + OpBuilder &builder) const { OpFoldResult subviewRowFull = dims[0]; OpFoldResult subviewColFull = dims[1]; OpFoldResult row1 = - rewriter.create(loc, block1, 0).getResult(); - OpFoldResult subviewRow1 = minOFRs(row1, subviewRowFull, loc, rewriter); + builder.create(loc, block1, 0).getResult(); + OpFoldResult subviewRow1 = minOFRs(row1, subviewRowFull, loc, builder); OpFoldResult subviewRow2 = - subOFRs(subviewRowFull, subviewRow1, loc, rewriter); + subOFRs(subviewRowFull, subviewRow1, loc, builder); - SmallVector offsets(getRank(), rewriter.getIndexAttr(0)); - SmallVector strides(getRank(), rewriter.getIndexAttr(1)); - auto sv1 = createSubview(block1, loc, rewriter, offsets, + SmallVector offsets(getRank(), builder.getIndexAttr(0)); + SmallVector strides(getRank(), builder.getIndexAttr(1)); + auto sv1 = createSubview(block1, loc, builder, offsets, {subviewRow1, subviewColFull}, strides); - auto sv2 = createSubview(block2, loc, rewriter, offsets, + auto sv2 = createSubview(block2, loc, builder, offsets, {subviewRow2, subviewColFull}, strides); return {sv1, sv2}; } LogicalResult MaskState::addStateScalar(const MaskState &state, const OpFoldResult scalar, Location loc, - ConversionPatternRewriter &rewriter) { - start = addOFRs(state.start, scalar, loc, rewriter); - end = addOFRs(state.end, scalar, loc, rewriter); + OpBuilder &builder) { + start = addOFRs(state.start, scalar, loc, builder); + end = addOFRs(state.end, scalar, loc, builder); dims = state.dims; return success(); } LogicalResult MaskState::addStates(const MaskState &lhsState, const MaskState &rhsState, Location loc, - ConversionPatternRewriter &rewriter) { + OpBuilder &builder) { if (lhsState.scalar && rhsState.scalar) { InFlightDiagnostic diag = emitError(loc) << "Unexpected case where both lhs and rhs are scalars"; @@ -201,14 +201,14 @@ LogicalResult MaskState::addStates(const MaskState &lhsState, } if (lhsState.scalar) - return addStateScalar(rhsState, lhsState.scalar, loc, rewriter); + return addStateScalar(rhsState, lhsState.scalar, loc, builder); else - return addStateScalar(lhsState, rhsState.scalar, loc, rewriter); + return addStateScalar(lhsState, rhsState.scalar, loc, builder); } LogicalResult MaskState::minStates(const MaskState &lhsState, const MaskState &rhsState, 
Location loc, - ConversionPatternRewriter &rewriter) { + OpBuilder &builder) { if (lhsState.getRank() != rhsState.getRank()) { InFlightDiagnostic diag = emitError(loc) @@ -219,14 +219,13 @@ LogicalResult MaskState::minStates(const MaskState &lhsState, for (uint32_t i = 0; i < lhsState.getRank(); i++) { auto lhsDim = lhsState.dims[i]; auto rhsDim = rhsState.dims[i]; - dims.push_back(minOFRs(lhsDim, rhsDim, loc, rewriter)); + dims.push_back(minOFRs(lhsDim, rhsDim, loc, builder)); } return success(); } LogicalResult MaskState::parseConstant(arith::ConstantOp constOp, - const Location loc, - ConversionPatternRewriter &rewriter) { + const Location loc, OpBuilder &builder) { assert(this->isEmpty()); if (isa(constOp.getValue())) { @@ -236,61 +235,61 @@ LogicalResult MaskState::parseConstant(arith::ConstantOp constOp, "All elements must share a single integer constant value"); auto values = attr.getValues(); auto value = values[0].getValue(); - auto constAttr = rewriter.getIndexAttr(value.getSExtValue()); - auto op = arith::ConstantOp::materialize(rewriter, constAttr, - rewriter.getIndexType(), loc); + auto constAttr = builder.getIndexAttr(value.getSExtValue()); + auto op = arith::ConstantOp::materialize(builder, constAttr, + builder.getIndexType(), loc); this->scalar = op.getValue(); } else { auto value = constOp.getValue().cast().getInt(); - this->scalar = rewriter.getIndexAttr(value); + this->scalar = builder.getIndexAttr(value); } return success(); } LogicalResult MaskState::parseIntScalar(Value scalar, const Location loc, - ConversionPatternRewriter &rewriter) { + OpBuilder &builder) { assert(this->isEmpty()); auto castOp = - rewriter.create(loc, rewriter.getIndexType(), scalar); + builder.create(loc, builder.getIndexType(), scalar); this->scalar = castOp.getResult(); return success(); } LogicalResult MaskState::parseAdd(arith::AddIOp addOp, const Location loc, - ConversionPatternRewriter &rewriter) { + OpBuilder &builder) { assert(this->isEmpty()); MaskState lhsState; - if (failed(lhsState.parse(addOp.getLhs(), loc, rewriter))) + if (failed(lhsState.parse(addOp.getLhs(), loc, builder))) return failure(); MaskState rhsState; - if (failed(rhsState.parse(addOp.getRhs(), loc, rewriter))) + if (failed(rhsState.parse(addOp.getRhs(), loc, builder))) return failure(); - return this->addStates(lhsState, rhsState, loc, rewriter); + return this->addStates(lhsState, rhsState, loc, builder); } LogicalResult MaskState::parseAnd(arith::AndIOp andOp, const Location loc, - ConversionPatternRewriter &rewriter) { + OpBuilder &builder) { assert(this->isEmpty()); MaskState lhsState; - if (failed(lhsState.parse(andOp.getLhs(), loc, rewriter)) || + if (failed(lhsState.parse(andOp.getLhs(), loc, builder)) || !lhsState.isMask()) return failure(); MaskState rhsState; - if (failed(rhsState.parse(andOp.getRhs(), loc, rewriter)) || + if (failed(rhsState.parse(andOp.getRhs(), loc, builder)) || !rhsState.isMask()) return failure(); - return this->minStates(lhsState, rhsState, loc, rewriter); + return this->minStates(lhsState, rhsState, loc, builder); } LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc, - ConversionPatternRewriter &rewriter) { + OpBuilder &builder) { assert(this->isEmpty()); if (cmpOp.getPredicate() != arith::CmpIPredicate::slt) { @@ -299,11 +298,11 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc, } MaskState lhsState; - if (failed(lhsState.parse(cmpOp.getLhs(), loc, rewriter))) + if (failed(lhsState.parse(cmpOp.getLhs(), loc, builder))) return 
failure(); MaskState rhsState; - if (failed(rhsState.parse(cmpOp.getRhs(), loc, rewriter))) + if (failed(rhsState.parse(cmpOp.getRhs(), loc, builder))) return failure(); assert((!lhsState.scalar && rhsState.scalar) && "Unsupported cmpi scenario"); @@ -324,8 +323,8 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc, assert(cmpDim != -1 && "Unexpected case where no dimension has size larger than 1"); - auto newEnd = minOFRs(lhsState.end, rhsState.scalar, loc, rewriter); - auto newDim = subOFRs(newEnd, lhsState.start, loc, rewriter); + auto newEnd = minOFRs(lhsState.end, rhsState.scalar, loc, builder); + auto newDim = subOFRs(newEnd, lhsState.start, loc, builder); for (int32_t i = 0; i < lhsState.getRank(); i++) { if (i == cmpDim) @@ -339,7 +338,7 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc, LogicalResult MaskState::parseMakeRange(triton::MakeRangeOp rangeOp, const Location loc, - ConversionPatternRewriter &rewriter) { + OpBuilder &builder) { assert(this->isEmpty()); auto shape = rangeOp.getType().cast().getShape(); @@ -355,16 +354,16 @@ LogicalResult MaskState::parseMakeRange(triton::MakeRangeOp rangeOp, return failure(); } - this->start = rewriter.getIndexAttr(start); - this->end = rewriter.getIndexAttr(end); - this->dims.push_back(rewriter.getIndexAttr(shape[0])); + this->start = builder.getIndexAttr(start); + this->end = builder.getIndexAttr(end); + this->dims.push_back(builder.getIndexAttr(shape[0])); return success(); } LogicalResult MaskState::parseBroadcast(triton::BroadcastOp broadcastOp, const Location loc, - ConversionPatternRewriter &rewriter) { + OpBuilder &builder) { assert(this->isEmpty()); auto src = broadcastOp.getSrc(); @@ -377,14 +376,14 @@ LogicalResult MaskState::parseBroadcast(triton::BroadcastOp broadcastOp, assert(srcShape.size() == dstShape.size() && "rank of source and destination should match"); - if (failed(parse(src, loc, rewriter))) + if (failed(parse(src, loc, builder))) return failure(); for (size_t i = 0; i < srcShape.size(); i++) { if (srcShape[i] == dstShape[i]) continue; else if (srcShape[i] < dstShape[i]) - this->dims[i] = rewriter.getIndexAttr(dstShape[i]); + this->dims[i] = builder.getIndexAttr(dstShape[i]); else llvm_unreachable("unexpected dimensions used in broadcast"); } @@ -393,7 +392,7 @@ LogicalResult MaskState::parseBroadcast(triton::BroadcastOp broadcastOp, } LogicalResult MaskState::parseSplat(triton::SplatOp splatOp, const Location loc, - ConversionPatternRewriter &rewriter) { + OpBuilder &builder) { assert(this->isEmpty()); auto src = splatOp.getSrc(); @@ -407,21 +406,21 @@ LogicalResult MaskState::parseSplat(triton::SplatOp splatOp, const Location loc, return failure(); } - if (failed(this->parse(src, loc, rewriter))) + if (failed(this->parse(src, loc, builder))) return failure(); for (auto s : dstShape) - this->dims.push_back(rewriter.getIndexAttr(s)); + this->dims.push_back(builder.getIndexAttr(s)); return success(); } LogicalResult MaskState::parseExpandDims(triton::ExpandDimsOp expandDimsOp, const Location loc, - ConversionPatternRewriter &rewriter) { + OpBuilder &builder) { assert(this->isEmpty()); - if (failed(this->parse(expandDimsOp.getSrc(), loc, rewriter))) + if (failed(this->parse(expandDimsOp.getSrc(), loc, builder))) return failure(); auto dstShape = @@ -429,7 +428,7 @@ LogicalResult MaskState::parseExpandDims(triton::ExpandDimsOp expandDimsOp, auto axis = expandDimsOp.getAxis(); assert(dstShape[axis] == 1 && "expect changed dimension to be 1 in expand_dims"); - 
this->dims.insert(this->dims.begin() + axis, rewriter.getIndexAttr(1)); + this->dims.insert(this->dims.begin() + axis, builder.getIndexAttr(1)); return success(); } diff --git a/lib/AnalysisStructured/CMakeLists.txt b/lib/AnalysisStructured/CMakeLists.txt index d0e31aa8..54c462a4 100644 --- a/lib/AnalysisStructured/CMakeLists.txt +++ b/lib/AnalysisStructured/CMakeLists.txt @@ -1,5 +1,4 @@ add_mlir_library(TritonSharedAnalysisStructured - MaskAnalysis.cpp PtrAnalysis.cpp DEPENDS diff --git a/lib/AnalysisStructured/MaskAnalysis.cpp b/lib/AnalysisStructured/MaskAnalysis.cpp deleted file mode 100644 index c513a83b..00000000 --- a/lib/AnalysisStructured/MaskAnalysis.cpp +++ /dev/null @@ -1,303 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Copyright (c) Microsoft Corporation, Meta Platforms. -// Licensed under the MIT license. -// -//===----------------------------------------------------------------------===// - -#include "triton-shared/AnalysisStructured/MaskAnalysis.h" -#include "triton-shared/Analysis/OpFoldResultUtils.h" - -#include "triton/Dialect/Triton/IR/Dialect.h" - -#include "mlir/Transforms/DialectConversion.h" - -namespace mlir { - -namespace tts { - -LogicalResult MaskState::parse(Value operand, const Location loc, - OpBuilder &builder) { - if (auto op = operand.getDefiningOp()) { - return this->parseConstant(op, loc, builder); - } else if (operand.getType().isa()) { - return this->parseIntScalar(operand, loc, builder); - } else if (auto op = operand.getDefiningOp()) { - return this->parseAdd(op, loc, builder); - } else if (auto op = operand.getDefiningOp()) { - return this->parseAnd(op, loc, builder); - } else if (auto op = operand.getDefiningOp()) { - return this->parseCmp(op, loc, builder); - } else if (auto op = operand.getDefiningOp()) { - return this->parseMakeRange(op, loc, builder); - } else if (auto op = operand.getDefiningOp()) { - return this->parseBroadcast(op, loc, builder); - } else if (auto op = operand.getDefiningOp()) { - return this->parseSplat(op, loc, builder); - } else if (auto op = operand.getDefiningOp()) { - return this->parseExpandDims(op, loc, builder); - } else { - return failure(); - } -} - -LogicalResult MaskState::addStateScalar(const MaskState &state, - const OpFoldResult scalar, Location loc, - OpBuilder &builder) { - start = addOFRs(state.start, scalar, loc, builder); - end = addOFRs(state.end, scalar, loc, builder); - dims = state.dims; - return success(); -} - -LogicalResult MaskState::addStates(const MaskState &lhsState, - const MaskState &rhsState, Location loc, - OpBuilder &builder) { - if (lhsState.scalar && rhsState.scalar) { - InFlightDiagnostic diag = - emitError(loc) << "Unexpected case where both lhs and rhs are scalars"; - return failure(); - } - - if (!lhsState.scalar && !rhsState.scalar) { - InFlightDiagnostic diag = - emitError(loc) - << "Unsupported scenario where neither lhs nor rhs is a scalar"; - return failure(); - } - - if (lhsState.scalar) - return addStateScalar(rhsState, lhsState.scalar, loc, builder); - else - return addStateScalar(lhsState, rhsState.scalar, loc, builder); -} - -LogicalResult MaskState::minStates(const MaskState &lhsState, - const MaskState &rhsState, Location loc, - OpBuilder &builder) { - if (lhsState.getRank() != rhsState.getRank()) { - InFlightDiagnostic diag = - emitError(loc) - << "Unexpected case where lhs and rhs have different ranks"; - return failure(); - } - - for (uint32_t i = 0; i < lhsState.getRank(); i++) { - auto lhsDim = lhsState.dims[i]; - 
auto rhsDim = rhsState.dims[i]; - dims.push_back(minOFRs(lhsDim, rhsDim, loc, builder)); - } - return success(); -} - -LogicalResult MaskState::parseConstant(arith::ConstantOp constOp, - const Location loc, OpBuilder &builder) { - assert(this->isEmpty()); - - if (isa(constOp.getValue())) { - auto attr = cast(constOp.getValue()); - auto elementType = attr.getElementType(); - assert(attr.isSplat() && elementType.isa() && - "All elements must share a single integer constant value"); - auto values = attr.getValues(); - auto value = values[0].getValue(); - auto constAttr = builder.getIndexAttr(value.getSExtValue()); - auto op = arith::ConstantOp::materialize(builder, constAttr, - builder.getIndexType(), loc); - this->scalar = op.getValue(); - } else { - auto value = constOp.getValue().cast().getInt(); - this->scalar = builder.getIndexAttr(value); - } - - return success(); -} - -LogicalResult MaskState::parseIntScalar(Value scalar, const Location loc, - OpBuilder &builder) { - assert(this->isEmpty()); - auto castOp = - builder.create(loc, builder.getIndexType(), scalar); - this->scalar = castOp.getResult(); - return success(); -} - -LogicalResult MaskState::parseAdd(arith::AddIOp addOp, const Location loc, - OpBuilder &builder) { - assert(this->isEmpty()); - - MaskState lhsState; - if (failed(lhsState.parse(addOp.getLhs(), loc, builder))) - return failure(); - - MaskState rhsState; - if (failed(rhsState.parse(addOp.getRhs(), loc, builder))) - return failure(); - - return this->addStates(lhsState, rhsState, loc, builder); -} - -LogicalResult MaskState::parseAnd(arith::AndIOp andOp, const Location loc, - OpBuilder &builder) { - assert(this->isEmpty()); - - MaskState lhsState; - if (failed(lhsState.parse(andOp.getLhs(), loc, builder)) || - !lhsState.isMask()) - return failure(); - - MaskState rhsState; - if (failed(rhsState.parse(andOp.getRhs(), loc, builder)) || - !rhsState.isMask()) - return failure(); - - return this->minStates(lhsState, rhsState, loc, builder); -} - -LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc, - OpBuilder &builder) { - assert(this->isEmpty()); - - if (cmpOp.getPredicate() != arith::CmpIPredicate::slt) { - InFlightDiagnostic diag = emitError(loc) << "Unsupported cmpi predicate"; - return failure(); - } - - MaskState lhsState; - if (failed(lhsState.parse(cmpOp.getLhs(), loc, builder))) - return failure(); - - MaskState rhsState; - if (failed(rhsState.parse(cmpOp.getRhs(), loc, builder))) - return failure(); - - assert((!lhsState.scalar && rhsState.scalar) && "Unsupported cmpi scenario"); - - int32_t cmpDim = -1; - for (int32_t i = 0; i < lhsState.getRank(); i++) { - auto dimIntAttr = getIntAttr(lhsState.dims[i]); - if (!dimIntAttr || dimIntAttr.value() != 1) { - if (cmpDim != -1) { - InFlightDiagnostic diag = emitError(loc) - << "Unsupported cmpi with more than one " - "dimension with size larger than 1"; - return failure(); - } - cmpDim = i; - } - } - assert(cmpDim != -1 && - "Unexpected case where no dimension has size larger than 1"); - - auto newEnd = minOFRs(lhsState.end, rhsState.scalar, loc, builder); - auto newDim = subOFRs(newEnd, lhsState.start, loc, builder); - - for (int32_t i = 0; i < lhsState.getRank(); i++) { - if (i == cmpDim) - this->dims.push_back(newDim); - else - this->dims.push_back(lhsState.dims[i]); - } - - return success(); -} - -LogicalResult MaskState::parseMakeRange(triton::MakeRangeOp rangeOp, - const Location loc, - OpBuilder &builder) { - assert(this->isEmpty()); - - auto shape = rangeOp.getType().cast().getShape(); - 
auto start = rangeOp.getStart(); - auto end = rangeOp.getEnd(); - auto stride = (end - start + shape[0] - 1) / shape[0]; - - if (stride != 1) { - InFlightDiagnostic diag = - emitError(loc) - << "stride must be 1 for make_range whose result is used " - "as load or store masks"; - return failure(); - } - - this->start = builder.getIndexAttr(start); - this->end = builder.getIndexAttr(end); - this->dims.push_back(builder.getIndexAttr(shape[0])); - - return success(); -} - -LogicalResult MaskState::parseBroadcast(triton::BroadcastOp broadcastOp, - const Location loc, - OpBuilder &builder) { - assert(this->isEmpty()); - - auto src = broadcastOp.getSrc(); - auto dst = broadcastOp.getResult(); - assert(src.getType().isa() && - "input to tt.broadcast should be a tensor"); - - auto srcShape = src.getType().cast().getShape(); - auto dstShape = dst.getType().cast().getShape(); - assert(srcShape.size() == dstShape.size() && - "rank of source and destination should match"); - - if (failed(parse(src, loc, builder))) - return failure(); - - for (size_t i = 0; i < srcShape.size(); i++) { - if (srcShape[i] == dstShape[i]) - continue; - else if (srcShape[i] < dstShape[i]) - this->dims[i] = builder.getIndexAttr(dstShape[i]); - else - llvm_unreachable("unexpected dimensions used in broadcast"); - } - - return success(); -} - -LogicalResult MaskState::parseSplat(triton::SplatOp splatOp, const Location loc, - OpBuilder &builder) { - assert(this->isEmpty()); - - auto src = splatOp.getSrc(); - auto dst = splatOp.getResult(); - auto dstShape = dst.getType().cast().getShape(); - - if (!src.getType().isa()) { - InFlightDiagnostic diag = - emitError(loc) - << "splat source must be an integer scalar for load/store masks"; - return failure(); - } - - if (failed(this->parse(src, loc, builder))) - return failure(); - - for (auto s : dstShape) - this->dims.push_back(builder.getIndexAttr(s)); - - return success(); -} - -LogicalResult MaskState::parseExpandDims(triton::ExpandDimsOp expandDimsOp, - const Location loc, - OpBuilder &builder) { - assert(this->isEmpty()); - - if (failed(this->parse(expandDimsOp.getSrc(), loc, builder))) - return failure(); - - auto dstShape = - expandDimsOp.getResult().getType().cast().getShape(); - auto axis = expandDimsOp.getAxis(); - assert(dstShape[axis] == 1 && - "expect changed dimension to be 1 in expand_dims"); - this->dims.insert(this->dims.begin() + axis, builder.getIndexAttr(1)); - - return success(); -} - -} // namespace triton -} // namespace mlir diff --git a/lib/AnalysisStructured/PtrAnalysis.cpp b/lib/AnalysisStructured/PtrAnalysis.cpp index e050c433..02c41977 100644 --- a/lib/AnalysisStructured/PtrAnalysis.cpp +++ b/lib/AnalysisStructured/PtrAnalysis.cpp @@ -6,7 +6,7 @@ //===----------------------------------------------------------------------===// #include "triton-shared/AnalysisStructured/PtrAnalysis.h" -#include "triton-shared/AnalysisStructured/MaskAnalysis.h" +#include "triton-shared/Analysis/MaskAnalysis.h" #include "triton-shared/Analysis/OpFoldResultUtils.h" #include "mlir/IR/IRMapping.h" @@ -1019,7 +1019,7 @@ LogicalResult PtrAnalysis::rewriteLoadOp(triton::LoadOp op) { } ArrayRef dims; - MaskState mstate; + mlir::triton::MaskState mstate; Value scalarOther; OpBuilder builder(op); @@ -1074,7 +1074,7 @@ LogicalResult PtrAnalysis::rewriteStoreOp(triton::StoreOp op) { } ArrayRef dims; - MaskState mstate; + mlir::triton::MaskState mstate; OpBuilder builder(op);