[IR] Add typing for tensor descriptor types (#5147)
Currently tensor descriptors are typed as `!tt.ptr<i8>`, which exposes the
assumption that they are backed by a TMA descriptor. This changes them to a
custom type `!tt.tensordesc<tensor<...>>`, which is lowered to a pointer type
in the LLVM IR.
    
I also add two new IR ops that cast between pointers and tensor descriptor
objects:
```mlir
tt.reinterpret_tensor_descriptor %ptr : !tt.ptr<i8> to !tt.tensordesc<...>
triton_nvidia_gpu.tensor_desc_to_tma_ptr %desc : !tt.tensordesc<...> -> !tt.ptr<i8>
```
    
Ideally both of these would be NVIDIA-specific, but the first is exposed in the
Triton IR to keep the by-value TMA descriptor API supported while we figure out
whether it's possible to migrate to the new style.
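
For illustration, the creation site now reads roughly as follows (a sketch based on the assembly format below; the `f16` element type and `128x64` block shape are made up for this example):
```mlir
// Previously this op produced a raw !tt.ptr<i8>; now the block shape and
// element type are carried in the result type itself.
%desc = tt.make_tensor_descriptor %base, [%dim0, %dim1], [%stride0, %stride1]
        : !tt.ptr<f16>, !tt.tensordesc<tensor<128x64xf16>>
```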
peterbell10 authored Nov 15, 2024
1 parent 38c6284 commit 1cf06c5
Showing 21 changed files with 263 additions and 105 deletions.
64 changes: 40 additions & 24 deletions include/triton/Dialect/Triton/IR/TritonOps.td
@@ -956,9 +956,10 @@ def TT_MakeTensorPtrOp : TT_Op<"make_tensor_ptr",
 //
 // Make Tensor Descriptor Op
 //
-def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor",
-                                [Pure,
-                                 SameVariadicOperandSize]> {
+def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor", [
+  Pure,
+  SameVariadicOperandSize,
+]> {
   let summary = "Make a tensor descriptor type with meta information of the parent tensor and block size";
 
   let description = [{
@@ -969,23 +970,38 @@ def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor",
   let arguments = (ins
     TT_Ptr:$base,
     Variadic<I32>:$shape,
-    Variadic<I64>:$strides,
-    DenseI32ArrayAttr:$tensorShape
+    Variadic<I64>:$strides
   );
 
-  // TODO(peterbell10): define a custom IR type to represent descriptors
-  let results = (outs TT_Ptr:$result);
+  let results = (outs TT_TensorDescType:$result);
 
   let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` attr-dict `:` type($base) `,` type($result)";
 
   let builders = [
-    OpBuilder<(ins
-      "Value":$base,
-      "ValueRange":$shape,
-      "ValueRange":$strides,
-      "ArrayRef<int32_t>":$tensorShape
-    )>
+    OpBuilder<(ins "Value":$base, "ValueRange":$shape, "ValueRange":$strides, "ArrayRef<int32_t>":$blockShape)>
   ];
+
+  let extraClassDeclaration = [{
+    ArrayRef<int64_t> getTensorShape() {
+      return getType().getBlockType().getShape();
+    }
+  }];
 }
 
+def ReinterpretTensorDescOp : TT_Op<"reinterpret_tensor_descriptor", [Pure]> {
+  let summary = "Reinterpret a pointer as a tensor descriptor";
+
+  let description = [{
+    This Op exists to help the transition from untyped raw TMA objects to typed Tensor descriptor objects.
+    Ideally, we can remove this once the APIs are fully fleshed out.
+  }];
+
+  let arguments = (ins TT_Ptr:$rawDesc);
+  let results = (outs TT_TensorDescType:$result);
+
+  let assemblyFormat = [{
+    $rawDesc attr-dict `:` qualified(type($rawDesc)) `to` qualified(type($result))
+  }];
+}
+
 // The following ops, including `call`, `func`, and `return` are copied and modified from
@@ -1195,20 +1211,19 @@ def ReturnOp : TT_Op<"return", [Pure, HasParent<"FuncOp">, /*MemRefsNormalizable
 }
 
 
-def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [
-    MemoryEffects<[MemRead<GlobalMemory>]>]> {
+def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [MemoryEffects<[MemRead<GlobalMemory>]>]> {
   let summary = "Load from descriptor";
   let description = [{
     This operation will be lowered to Nvidia TMA load operation on targets supporting it.
-    `desc_ptr` is a pointer to the TMA descriptor allocated in global memory.
+    `desc` is a tensor descriptor object.
    The destination tensor type and shape must match the descriptor otherwise the result is undefined.
 
    This is an escape hatch and is only there for testing/experimenting.
    This op will be removed in the future.
  }];
  let arguments = (
    ins
-    TT_PtrType:$desc_ptr,
+    TT_TensorDescType:$desc,
    Variadic<I32>:$indices,
    DefaultValuedAttr<TT_CacheModifierAttr, "::mlir::triton::CacheModifier::NONE">:$cache,
    DefaultValuedAttr<TT_EvictionPolicyAttr, "::mlir::triton::EvictionPolicy::NORMAL">:$evict
@@ -1217,36 +1232,37 @@ def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [
  let results = (outs TT_Tensor:$result);
 
  let assemblyFormat = [{
-    $desc_ptr `[` $indices `]`
+    $desc `[` $indices `]`
    oilist(
      `cacheModifier` `=` $cache |
      `evictionPolicy` `=` $evict
    )
-    attr-dict `:` qualified(type($desc_ptr)) `->` type($result)
+    attr-dict `:` qualified(type($desc)) `->` type($result)
  }];
 }
 
 def TT_ExperimentalDescriptorStoreOp : TT_Op<"experimental_descriptor_store", [
-    MemoryEffects<[MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>]> {
+  MemoryEffects<[MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>,
+]> {
  let summary = "store value based on descriptor";
  let description = [{
    This operation will be lowered to Nvidia TMA store operation on targets supporting it.
-    `desc_ptr` is a pointer to the TMA descriptor allocated in global memory.
+    `desc` is a tensor descriptor object.
    The shape and types of `src` must match the descriptor otherwise the result is undefined.
 
    This is an escape hatch and is only there for testing/experimenting.
    This op will be removed in the future.
  }];
  let arguments = (
    ins
-    TT_PtrType:$desc_ptr,
+    TT_TensorDescType:$desc,
    TT_Tensor:$src,
    Variadic<I32>:$indices
  );
 
  let assemblyFormat = [{
-    $desc_ptr `[` $indices `]` `,` $src
-    attr-dict `:` qualified(type($desc_ptr)) `,` type($src)
+    $desc `[` $indices `]` `,` $src
+    attr-dict `:` qualified(type($desc)) `,` type($src)
  }];
 }
 
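With the operand retyped from `!tt.ptr<i8>` to `!tt.tensordesc`, the load/store assembly reads roughly as follows (a sketch per the formats above; the cache/eviction clauses are omitted and the shapes are hypothetical):
```mlir
%tile = tt.experimental_descriptor_load %desc[%i, %j]
        : !tt.tensordesc<tensor<128x64xf16>> -> tensor<128x64xf16>
tt.experimental_descriptor_store %desc[%i, %j], %tile
        : !tt.tensordesc<tensor<128x64xf16>>, tensor<128x64xf16>
```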
11 changes: 11 additions & 0 deletions include/triton/Dialect/Triton/IR/TritonTypes.td
@@ -140,5 +140,16 @@ def TT_MemDescType : TritonTypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]>
   let hasCustomAssemblyFormat = 1;
 }
 
+// Result type of ExperimentalMakeTensorDescriptor
+def TT_TensorDescType : TritonTypeDef<"TensorDesc", "tensordesc", []> {
+  let summary = "Tensor descriptor type (`::mlir::triton::TensorDescType`) in Triton IR type system";
+
+  let description = [{
+    A portable abstraction for nvidia-TMA descriptors.
+  }];
+
+  let parameters = (ins "RankedTensorType":$blockType);
+  let assemblyFormat = "`<` $blockType `>`";
+}
 
 #endif
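Given the `blockType` parameter and assembly format above, the printed form should look like this (element type and shape are illustrative):
```mlir
// A descriptor whose block is a 128x64 tile of f16 values.
!desc_ty = !tt.tensordesc<tensor<128x64xf16>>
```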
2 changes: 1 addition & 1 deletion include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -295,7 +295,7 @@ def TTG_GlobalScratchAllocOp : TTG_Op<"global_scratch_alloc", [MemoryEffects<[Me
         $_builder.getI32IntegerAttr(nbytes), $_builder.getI32IntegerAttr(alignment)); }]>
   ];
 
-  let assemblyFormat = [{attr-dict `:` type($result)}];
+  let assemblyFormat = [{attr-dict `:` qualified(type($result))}];
 }
 
 #endif
20 changes: 20 additions & 0 deletions include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
@@ -185,6 +185,26 @@ def TTNG_WaitBarrierOp : TTNG_Op<"wait_barrier", [DeclareOpInterfaceMethods<Memo
   let assemblyFormat = "$alloc `,` $phase attr-dict `:` type($alloc)";
 }
 
+def TTNG_TensorDescToTMAPtrOp : TTNG_Op<"tensor_desc_to_tma_ptr", [Pure]> {
+  let summary = "Convert tensor descriptor to pointer to tma descriptor";
+
+  let arguments = (ins TT_TensorDescType:$desc);
+  let results = (outs TT_Ptr:$ptr);
+
+  let assemblyFormat = [{
+    $desc attr-dict `:` qualified(type($desc)) `to` qualified(type($ptr))
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value":$desc), [{
+      auto ptrTy = triton::PointerType::get($_builder.getI8Type(), 1);
+      build($_builder, $_state, ptrTy, desc);
+    }]>
+  ];
+
+  let hasCanonicalizeMethod = 1;
+}
+
 
 def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
   let summary = "copy data based on descriptor from global memory to local memory asynchronously";
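Per the assembly format above, the cast prints along these lines (types illustrative); note that the convenience builder defaults the result to an `i8` pointer in address space 1:
```mlir
%ptr = triton_nvidia_gpu.tensor_desc_to_tma_ptr %desc
       : !tt.tensordesc<tensor<128x64xf16>> to !tt.ptr<i8>
```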
4 changes: 4 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
@@ -28,6 +28,10 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
   addConversion([&](MemDescType type) -> std::optional<Type> {
     return convertMemDescType(type, targetInfo);
   });
+  addConversion([](TensorDescType type) -> std::optional<Type> {
+    auto ctx = type.getContext();
+    return LLVM::LLVMPointerType::get(ctx, 1);
+  });
   addConversion([&](triton::gpu::AsyncTokenType type) -> std::optional<Type> {
     return convertAsyncToken(type);
   });
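In other words, the LLVM type converter now maps descriptors to opaque global-memory pointers, along these lines (a sketch):
```mlir
// Triton IR type ...
!tt.tensordesc<tensor<128x64xf16>>
// ... becomes an LLVM pointer in address space 1 (global memory).
!llvm.ptr<1>
```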
16 changes: 11 additions & 5 deletions lib/Dialect/Triton/IR/Ops.cpp
@@ -8,6 +8,7 @@
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Types.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
+#include "llvm/Support/ErrorHandling.h"
 
 namespace mlir {
 namespace triton {
@@ -863,12 +864,17 @@ OpFoldResult AdvanceOp::fold(FoldAdaptor adaptor) {
 //-- MakeTensorDescOp --
 void MakeTensorDescOp::build(OpBuilder &builder, OperationState &state,
                              Value base, ValueRange shape, ValueRange strides,
-                             ArrayRef<int32_t> tensorShape) {
-  auto resultTy = getPointerType(builder.getI8Type());
-  assert(resultTy.getContext());
+                             ArrayRef<int32_t> blockShape) {
+  auto ptrTy = dyn_cast<triton::PointerType>(base.getType());
+  if (!ptrTy) {
+    llvm::report_fatal_error("Expected pointer type");
+  }
+  auto elemTy = ptrTy.getPointeeType();
 
-  return build(builder, state, resultTy, base, shape, strides,
-               builder.getDenseI32ArrayAttr(tensorShape));
+  SmallVector<int64_t> blockShape64(blockShape);
+  auto blockTy = RankedTensorType::get(blockShape64, elemTy);
+  auto descTy = TensorDescType::get(builder.getContext(), blockTy);
+  return build(builder, state, descTy, base, shape, strides);
 }
 
 // The following ops, including `call`, `func`, and `return` are copied and
@@ -34,13 +34,13 @@ namespace tt = mlir::triton;
 namespace ttg = mlir::triton::gpu;
 namespace ttng = mlir::triton::nvidia_gpu;
 
-// TODO: We can extra some helpers into common utilities once we add more
+// TODO: We can extract some helpers into common utilities once we add more
 // schedules.
 
 namespace {
 
 struct LoadInfo {
-  // Layout of the data in the shared memory.
+  // Layout of the data in shared memory.
   ttg::SharedEncodingAttr sharedEncoding = nullptr;
   // Blocked encoding is used for loads not used by the dot.
   ttg::BlockedEncodingAttr blockedEncoding = nullptr;
@@ -239,9 +241,11 @@ createTMAAsyncCopy(scf::ForOp &forOp, tt::ExperimentalDescriptorLoadOp loadOp,
 
   Value pred = builder.createWithStage<arith::ConstantIntOp>(loc, stage,
                                                              clusterId, 1, 1);
+  Value tmaPtr =
+      builder.createWithStage<triton::nvidia_gpu::TensorDescToTMAPtrOp>(
+          loc, stage, clusterId, loadOp.getDesc());
   Operation *copy = builder.createWithStage<ttng::AsyncTMACopyGlobalToLocalOp>(
-      loc, stage, clusterId, loadOp.getDescPtr(), loadOp.getIndices(), barrier,
-      view, pred);
+      loc, stage, clusterId, tmaPtr, loadOp.getIndices(), barrier, view, pred);
 
   bool isMMV3Load = loadToInfo[loadOp].loadIsMMAV3;
 
@@ -63,8 +63,10 @@ static void createTMAAsyncCopy(scf::ForOp &forOp,
   builder.create<ttng::TMAStoreWait>(loc, 0);
   builder.create<ttg::LocalStoreOp>(loc, storeOp.getSrc(), alloc);
   builder.create<ttng::FenceAsyncSharedOp>(loc, false);
+  Value tmaPtr = builder.create<triton::nvidia_gpu::TensorDescToTMAPtrOp>(
+      loc, storeOp.getDesc());
   builder.create<ttng::AsyncTMACopyLocalToGlobalOp>(
-      loc, storeOp.getDescPtr(), storeOp.getIndices(), alloc);
+      loc, tmaPtr, storeOp.getIndices(), alloc);
 
   storeOp->erase();
 }
12 changes: 12 additions & 0 deletions lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
@@ -160,6 +160,18 @@ void WaitBarrierOp::getEffects(
                        mlir::triton::gpu::SharedMemory::get());
 }
 
+// -- TensorDescToTMAPtrOp --
+LogicalResult TensorDescToTMAPtrOp::canonicalize(TensorDescToTMAPtrOp op,
+                                                 PatternRewriter &rewriter) {
+  // tensor_desc_to_tma_ptr(reinterpret_tensor_desc(ptr)) -> ptr
+  if (auto reinterpret =
+          op.getDesc().getDefiningOp<triton::ReinterpretTensorDescOp>()) {
+    rewriter.replaceOp(op, reinterpret.getRawDesc());
+    return success();
+  }
+  return failure();
+}
+
 // -- AsyncTMACopyGlobalToLocalOp --
 LogicalResult AsyncTMACopyGlobalToLocalOp::verify() {
   if (failed(verifyBarrierType(*this, getBarrier().getType())))
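The pattern folds a pointer-to-descriptor-to-pointer round trip, sketched below (types hypothetical):
```mlir
// Before canonicalization:
%desc = tt.reinterpret_tensor_descriptor %raw : !tt.ptr<i8> to !tt.tensordesc<tensor<128x64xf16>>
%ptr  = triton_nvidia_gpu.tensor_desc_to_tma_ptr %desc : !tt.tensordesc<tensor<128x64xf16>> to !tt.ptr<i8>
// After canonicalization, all uses of %ptr are replaced by %raw directly.
```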
12 changes: 9 additions & 3 deletions lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp
@@ -60,8 +60,10 @@ class TMALoadLowering : public OpRewritePattern<ExperimentalDescriptorLoadOp> {
     Value pred = rewriter.create<arith::ConstantIntOp>(loc, 1, 1);
     rewriter.create<triton::nvidia_gpu::BarrierExpectOp>(loc, barrierAlloc,
                                                          sizeInBytes, pred);
+    Value tmaPtr = rewriter.create<triton::nvidia_gpu::TensorDescToTMAPtrOp>(
+        loc, op.getDesc());
     rewriter.create<triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp>(
-        loc, op.getDescPtr(), op.getIndices(), barrierAlloc, alloc, pred);
+        loc, tmaPtr, op.getIndices(), barrierAlloc, alloc, pred);
     Value phase = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
     rewriter.create<WaitBarrierOp>(loc, barrierAlloc, phase);
     rewriter.create<InvalBarrierOp>(loc, barrierAlloc);
@@ -95,8 +97,10 @@ class TMAStoreLowering
         encoding, sharedMemorySpace, /*mutableMemory=*/true);
     Value alloc = rewriter.create<LocalAllocOp>(loc, memDescType, op.getSrc());
     rewriter.create<triton::nvidia_gpu::FenceAsyncSharedOp>(loc, false);
+    Value tmaPtr = rewriter.create<triton::nvidia_gpu::TensorDescToTMAPtrOp>(
+        loc, op.getDesc());
     rewriter.create<triton::nvidia_gpu::AsyncTMACopyLocalToGlobalOp>(
-        loc, op.getDescPtr(), op.getIndices(), alloc);
+        loc, tmaPtr, op.getIndices(), alloc);
     rewriter.create<triton::nvidia_gpu::TMAStoreWait>(loc, 0);
     rewriter.eraseOp(op);
     return success();
@@ -194,7 +198,9 @@ class TMACreateDescLowering : public OpRewritePattern<MakeTensorDescOp> {
         /*fill_mode=*/rewriter.getI32IntegerAttr(0));
     rewriter.create<triton::ExperimentalTensormapFenceproxyAcquireOp>(
         loc, alloc.getResult());
-    rewriter.replaceOp(op, alloc);
+    auto newDesc = rewriter.create<triton::ReinterpretTensorDescOp>(
+        loc, op.getType(), alloc.getResult());
+    rewriter.replaceOp(op, newDesc);
     return success();
   }
 };
20 changes: 14 additions & 6 deletions python/src/ir.cc
@@ -47,6 +47,7 @@ class TritonOpBuilder {
   }
 
   OpBuilder &getBuilder() { return *builder; }
+  MLIRContext *getContext() { return builder->getContext(); }
 
   bool isLineInfoEnabled() { return lineInfoEnabled; }
 
@@ -1318,19 +1319,26 @@ void init_triton_ir(py::module &&m) {
             self.create<StoreOp>(ptrs, val, mask, cacheModifier,
                                  evictionPolicy);
           })
+      .def("create_reinterpret_tensor_descriptor",
+           [](TritonOpBuilder &self, Value desc_ptr, Type blockTy) -> Value {
+             auto ctx = self.getContext();
+             auto resultTy = triton::TensorDescType::get(
+                 ctx, cast<RankedTensorType>(blockTy));
+             return self.create<ReinterpretTensorDescOp>(resultTy, desc_ptr);
+           })
      .def("create_descriptor_load",
-           [](TritonOpBuilder &self, Value desc_ptr,
-              std::vector<Value> &indices, Type type,
+           [](TritonOpBuilder &self, Value desc, std::vector<Value> &indices,
              CacheModifier cacheModifier,
              EvictionPolicy evictionPolicy) -> Value {
+             auto descTy = cast<triton::TensorDescType>(desc.getType());
+             auto resTy = descTy.getBlockType();
             return self.create<ExperimentalDescriptorLoadOp>(
-                type, desc_ptr, indices, cacheModifier, evictionPolicy);
+                resTy, desc, indices, cacheModifier, evictionPolicy);
           })
      .def("create_descriptor_store",
-           [](TritonOpBuilder &self, Value desc_ptr, Value value,
+           [](TritonOpBuilder &self, Value desc, Value value,
              std::vector<Value> &indices) -> void {
-             self.create<ExperimentalDescriptorStoreOp>(desc_ptr, value,
-                                                        indices);
+             self.create<ExperimentalDescriptorStoreOp>(desc, value, indices);
           })
      .def("create_tensormap_create",
           [](TritonOpBuilder &self, Value desc_ptr, Value global_address,
2 changes: 2 additions & 0 deletions python/triton/language/__init__.py
@@ -29,6 +29,7 @@
     _experimental_descriptor_load,
     _experimental_descriptor_store,
     _experimental_make_tensor_descriptor,
+    _experimental_reinterpret_tensor_descriptor,
     _experimental_tensor_descriptor,
     add,
     advance,
@@ -129,6 +130,7 @@
     "_experimental_descriptor_load",
     "_experimental_descriptor_store",
     "_experimental_make_tensor_descriptor",
+    "_experimental_reinterpret_tensor_descriptor",
     "_experimental_tensor_descriptor",
     "abs",
     "add",