Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] Add typing for tensor descriptor types #5147

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 40 additions & 24 deletions include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -952,9 +952,10 @@ def TT_MakeTensorPtrOp : TT_Op<"make_tensor_ptr",
//
// Make Tensor Descriptor Op
//
def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor",
[Pure,
SameVariadicOperandSize]> {
def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor", [
Pure,
SameVariadicOperandSize,
]> {
let summary = "Make a tensor descriptor type with meta information of the parent tensor and block size";

let description = [{
Expand All @@ -965,23 +966,38 @@ def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor",
let arguments = (ins
TT_Ptr:$base,
Variadic<I32>:$shape,
Variadic<I64>:$strides,
DenseI32ArrayAttr:$tensorShape
Variadic<I64>:$strides
);

// TODO(peterbell10): define a custom IR type to represent descriptors
let results = (outs TT_Ptr:$result);
let results = (outs TT_TensorDescType:$result);

let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` attr-dict `:` type($base) `,` type($result)";

let builders = [
OpBuilder<(ins
"Value":$base,
"ValueRange":$shape,
"ValueRange":$strides,
"ArrayRef<int32_t>":$tensorShape
)>
OpBuilder<(ins "Value":$base, "ValueRange":$shape, "ValueRange":$strides, "ArrayRef<int32_t>":$blockShape)>
];

let extraClassDeclaration = [{
ArrayRef<int64_t> getTensorShape() {
return getType().getBlockType().getShape();
}
}];
}

def ReinterpretTensorDescOp : TT_Op<"reinterpret_tensor_descriptor", [Pure]> {
let summary = "Reinterpret a pointer as a tensor descriptor";

let description = [{
This Op exists to help the transition from untyped raw TMA objects to typed Tensor descriptor objects.
Ideally, we can remove this once the APIs are fully fleshed out.
}];

let arguments = (ins TT_Ptr:$rawDesc);
let results = (outs TT_TensorDescType:$result);

let assemblyFormat = [{
$rawDesc attr-dict `:` qualified(type($rawDesc)) `to` qualified(type($result))
}];
}

// The following ops, including `call`, `func`, and `return` are copied and modified from
Expand Down Expand Up @@ -1191,20 +1207,19 @@ def ReturnOp : TT_Op<"return", [Pure, HasParent<"FuncOp">, /*MemRefsNormalizable
}


def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [
MemoryEffects<[MemRead<GlobalMemory>]>]> {
def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [MemoryEffects<[MemRead<GlobalMemory>]>]> {
let summary = "Load from descriptor";
let description = [{
This operation will be lowered to Nvidia TMA load operation on targets supporting it.
`desc_ptr` is a pointer to the TMA descriptor allocated in global memory.
`desc` is a tensor descriptor object.
The destination tensor type and shape must match the descriptor otherwise the result is undefined.

This is an escape hatch and is only there for testing/experimenting.
This op will be removed in the future.
}];
let arguments = (
ins
TT_PtrType:$desc_ptr,
TT_TensorDescType:$desc,
Variadic<I32>:$indices,
DefaultValuedAttr<TT_CacheModifierAttr, "::mlir::triton::CacheModifier::NONE">:$cache,
DefaultValuedAttr<TT_EvictionPolicyAttr, "::mlir::triton::EvictionPolicy::NORMAL">:$evict
Expand All @@ -1213,36 +1228,37 @@ def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [
let results = (outs TT_Tensor:$result);

let assemblyFormat = [{
$desc_ptr `[` $indices `]`
$desc `[` $indices `]`
oilist(
`cacheModifier` `=` $cache |
`evictionPolicy` `=` $evict
)
attr-dict `:` qualified(type($desc_ptr)) `->` type($result)
attr-dict `:` qualified(type($desc)) `->` type($result)
}];
}

def TT_ExperimentalDescriptorStoreOp : TT_Op<"experimental_descriptor_store", [
MemoryEffects<[MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>]> {
MemoryEffects<[MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>,
]> {
let summary = "store value based on descriptor";
let description = [{
This operation will be lowered to Nvidia TMA store operation on targets supporting it.
`desc_ptr` is a pointer to the TMA descriptor allocated in global memory.
`desc` is a tensor descriptor object.
The shape and types of `src` must match the descriptor otherwise the result is undefined.

This is an escape hatch and is only there for testing/experimenting.
This op will be removed in the future.
}];
let arguments = (
ins
TT_PtrType:$desc_ptr,
TT_TensorDescType:$desc,
TT_Tensor:$src,
Variadic<I32>:$indices
);

let assemblyFormat = [{
$desc_ptr `[` $indices `]` `,` $src
attr-dict `:` qualified(type($desc_ptr)) `,` type($src)
$desc `[` $indices `]` `,` $src
attr-dict `:` qualified(type($desc)) `,` type($src)
}];
}

Expand Down
11 changes: 11 additions & 0 deletions include/triton/Dialect/Triton/IR/TritonTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -140,5 +140,16 @@ def TT_MemDescType : TritonTypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]>
let hasCustomAssemblyFormat = 1;
}

// Result type of ExperimentalMakeTensorDescriptor
def TT_TensorDescType : TritonTypeDef<"TensorDesc", "tensordesc", []> {
let summary = "Tensor descriptor type (`::mlir::triton::TensorDescType`) in Triton IR type system";

let description = [{
A portable abstraction for nvidia-TMA descriptors.
}];

let parameters = (ins "RankedTensorType":$blockType);
let assemblyFormat = "`<` $blockType `>`";
}

#endif
2 changes: 1 addition & 1 deletion include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def TTG_GlobalScratchAllocOp : TTG_Op<"global_scratch_alloc", [MemoryEffects<[Me
$_builder.getI32IntegerAttr(nbytes), $_builder.getI32IntegerAttr(alignment)); }]>
];

let assemblyFormat = [{attr-dict `:` type($result)}];
let assemblyFormat = [{attr-dict `:` qualified(type($result))}];
}

#endif
20 changes: 20 additions & 0 deletions include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,26 @@ def TTNG_WaitBarrierOp : TTNG_Op<"wait_barrier", [DeclareOpInterfaceMethods<Memo
let assemblyFormat = "$alloc `,` $phase attr-dict `:` type($alloc)";
}

def TTNG_TensorDescToTMAPtrOp : TTNG_Op<"tensor_desc_to_tma_ptr", [Pure]> {
let summary = "Convert tensor descriptor to pointer to tma descriptor";

let arguments = (ins TT_TensorDescType:$desc);
let results = (outs TT_Ptr:$ptr);

let assemblyFormat = [{
$desc attr-dict `:` qualified(type($desc)) `to` qualified(type($ptr))
}];

let builders = [
OpBuilder<(ins "Value":$desc), [{
auto ptrTy = triton::PointerType::get($_builder.getI8Type(), 1);
build($_builder, $_state, ptrTy, desc);
}]>
];

let hasCanonicalizeMethod = 1;
}


def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "copy data based on descriptor from global memory to local memory asynchronously";
Expand Down
4 changes: 4 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
addConversion([&](MemDescType type) -> std::optional<Type> {
return convertMemDescType(type, targetInfo);
});
addConversion([](TensorDescType type) -> std::optional<Type> {
auto ctx = type.getContext();
return LLVM::LLVMPointerType::get(ctx, 1);
});
addConversion([&](triton::gpu::AsyncTokenType type) -> std::optional<Type> {
return convertAsyncToken(type);
});
Expand Down
16 changes: 11 additions & 5 deletions lib/Dialect/Triton/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "llvm/Support/ErrorHandling.h"

namespace mlir {
namespace triton {
Expand Down Expand Up @@ -839,12 +840,17 @@ OpFoldResult AdvanceOp::fold(FoldAdaptor adaptor) {
//-- MakeTensorDescOp --
void MakeTensorDescOp::build(OpBuilder &builder, OperationState &state,
Value base, ValueRange shape, ValueRange strides,
ArrayRef<int32_t> tensorShape) {
auto resultTy = getPointerType(builder.getI8Type());
assert(resultTy.getContext());
ArrayRef<int32_t> blockShape) {
auto ptrTy = dyn_cast<triton::PointerType>(base.getType());
if (!ptrTy) {
llvm::report_fatal_error("Expected pointer type");
}
auto elemTy = ptrTy.getPointeeType();

return build(builder, state, resultTy, base, shape, strides,
builder.getDenseI32ArrayAttr(tensorShape));
SmallVector<int64_t> blockShape64(blockShape);
auto blockTy = RankedTensorType::get(blockShape64, elemTy);
auto descTy = TensorDescType::get(builder.getContext(), blockTy);
return build(builder, state, descTy, base, shape, strides);
}

// The following ops, including `call`, `func`, and `return` are copied and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ namespace tt = mlir::triton;
namespace ttg = mlir::triton::gpu;
namespace ttng = mlir::triton::nvidia_gpu;

// TODO: We can extra some helpers into common utilities once we add more
// TODO: We can extract some helpers into common utilities once we add more
// schedules.

namespace {

struct LoadInfo {
// Layout of the data in the shared memory.
// Layout of the data in shared memory.
ttg::SharedEncodingAttr sharedEncoding = nullptr;
// Blocked encoding is used for loads not used by the dot.
ttg::BlockedEncodingAttr blockedEncoding = nullptr;
Expand Down Expand Up @@ -239,9 +239,11 @@ createTMAAsyncCopy(scf::ForOp &forOp, tt::ExperimentalDescriptorLoadOp loadOp,

Value pred = builder.createWithStage<arith::ConstantIntOp>(loc, stage,
clusterId, 1, 1);
Value tmaPtr =
builder.createWithStage<triton::nvidia_gpu::TensorDescToTMAPtrOp>(
loc, stage, clusterId, loadOp.getDesc());
Operation *copy = builder.createWithStage<ttng::AsyncTMACopyGlobalToLocalOp>(
loc, stage, clusterId, loadOp.getDescPtr(), loadOp.getIndices(), barrier,
view, pred);
loc, stage, clusterId, tmaPtr, loadOp.getIndices(), barrier, view, pred);

bool isMMV3Load = loadToInfo[loadOp].loadIsMMAV3;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,10 @@ static void createTMAAsyncCopy(scf::ForOp &forOp,
builder.create<ttng::TMAStoreWait>(loc, 0);
builder.create<ttg::LocalStoreOp>(loc, storeOp.getSrc(), alloc);
builder.create<ttng::FenceAsyncSharedOp>(loc, false);
Value tmaPtr = builder.create<triton::nvidia_gpu::TensorDescToTMAPtrOp>(
loc, storeOp.getDesc());
builder.create<ttng::AsyncTMACopyLocalToGlobalOp>(
loc, storeOp.getDescPtr(), storeOp.getIndices(), alloc);
loc, tmaPtr, storeOp.getIndices(), alloc);

storeOp->erase();
}
Expand Down
12 changes: 12 additions & 0 deletions lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,18 @@ void WaitBarrierOp::getEffects(
mlir::triton::gpu::SharedMemory::get());
}

// -- TensorDescToTMAPtrOp --
LogicalResult TensorDescToTMAPtrOp::canonicalize(TensorDescToTMAPtrOp op,
PatternRewriter &rewriter) {
// tensor_desc_to_tma_ptr(reinterpret_tensor_desc(ptr)) -> ptr
if (auto reinterpret =
op.getDesc().getDefiningOp<triton::ReinterpretTensorDescOp>()) {
rewriter.replaceOp(op, reinterpret.getRawDesc());
return success();
}
return failure();
}
Comment on lines +164 to +173
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: could be a folder (::fold()

Copy link
Contributor Author

@peterbell10 peterbell10 Nov 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Welp... I tried this out and it revealed there's a bit of a bug here. tensor_desc_to_tma_ptr returns a pointer to global memory, but by-val TMA is represented as a pointer to the generic address space.

On the one hand this is a real concern because the by-val descriptor lives in parameter space, not global memory so the pointer type is wrong. On the other hand we only use the result for TMA operations which are inline assembly anyway so it makes no practical difference.

I'll just leave it as a canonicalizer for now, but in future I should figure out how to return the correct type from tensor_desc_to_tma_ptr.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok that works, isn't there a type mistmatch still? I guess it doesn't show up because the user can take either memory space?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's right, the only users of this op are the AsyncTMA ops, neither of which actually care about the address space.


// -- AsyncTMACopyGlobalToLocalOp --
LogicalResult AsyncTMACopyGlobalToLocalOp::verify() {
if (failed(verifyBarrierType(*this, getBarrier().getType())))
Expand Down
12 changes: 9 additions & 3 deletions lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,10 @@ class TMALoadLowering : public OpRewritePattern<ExperimentalDescriptorLoadOp> {
Value pred = rewriter.create<arith::ConstantIntOp>(loc, 1, 1);
rewriter.create<triton::nvidia_gpu::BarrierExpectOp>(loc, barrierAlloc,
sizeInBytes, pred);
Value tmaPtr = rewriter.create<triton::nvidia_gpu::TensorDescToTMAPtrOp>(
loc, op.getDesc());
rewriter.create<triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp>(
loc, op.getDescPtr(), op.getIndices(), barrierAlloc, alloc, pred);
loc, tmaPtr, op.getIndices(), barrierAlloc, alloc, pred);
Value phase = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
rewriter.create<WaitBarrierOp>(loc, barrierAlloc, phase);
rewriter.create<InvalBarrierOp>(loc, barrierAlloc);
Expand Down Expand Up @@ -95,8 +97,10 @@ class TMAStoreLowering
encoding, sharedMemorySpace, /*mutableMemory=*/true);
Value alloc = rewriter.create<LocalAllocOp>(loc, memDescType, op.getSrc());
rewriter.create<triton::nvidia_gpu::FenceAsyncSharedOp>(loc, false);
Value tmaPtr = rewriter.create<triton::nvidia_gpu::TensorDescToTMAPtrOp>(
loc, op.getDesc());
rewriter.create<triton::nvidia_gpu::AsyncTMACopyLocalToGlobalOp>(
loc, op.getDescPtr(), op.getIndices(), alloc);
loc, tmaPtr, op.getIndices(), alloc);
rewriter.create<triton::nvidia_gpu::TMAStoreWait>(loc, 0);
rewriter.eraseOp(op);
return success();
Expand Down Expand Up @@ -194,7 +198,9 @@ class TMACreateDescLowering : public OpRewritePattern<MakeTensorDescOp> {
/*fill_mode=*/rewriter.getI32IntegerAttr(0));
rewriter.create<triton::ExperimentalTensormapFenceproxyAcquireOp>(
loc, alloc.getResult());
rewriter.replaceOp(op, alloc);
auto newDesc = rewriter.create<triton::ReinterpretTensorDescOp>(
loc, op.getType(), alloc.getResult());
rewriter.replaceOp(op, newDesc);
return success();
}
};
Expand Down
20 changes: 14 additions & 6 deletions python/src/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class TritonOpBuilder {
}

OpBuilder &getBuilder() { return *builder; }
MLIRContext *getContext() { return builder->getContext(); }

bool isLineInfoEnabled() { return lineInfoEnabled; }

Expand Down Expand Up @@ -1318,19 +1319,26 @@ void init_triton_ir(py::module &&m) {
self.create<StoreOp>(ptrs, val, mask, cacheModifier,
evictionPolicy);
})
.def("create_reinterpret_tensor_descriptor",
[](TritonOpBuilder &self, Value desc_ptr, Type blockTy) -> Value {
auto ctx = self.getContext();
auto resultTy = triton::TensorDescType::get(
ctx, cast<RankedTensorType>(blockTy));
return self.create<ReinterpretTensorDescOp>(resultTy, desc_ptr);
})
.def("create_descriptor_load",
[](TritonOpBuilder &self, Value desc_ptr,
std::vector<Value> &indices, Type type,
[](TritonOpBuilder &self, Value desc, std::vector<Value> &indices,
CacheModifier cacheModifier,
EvictionPolicy evictionPolicy) -> Value {
auto descTy = cast<triton::TensorDescType>(desc.getType());
auto resTy = descTy.getBlockType();
return self.create<ExperimentalDescriptorLoadOp>(
type, desc_ptr, indices, cacheModifier, evictionPolicy);
resTy, desc, indices, cacheModifier, evictionPolicy);
})
.def("create_descriptor_store",
[](TritonOpBuilder &self, Value desc_ptr, Value value,
[](TritonOpBuilder &self, Value desc, Value value,
std::vector<Value> &indices) -> void {
self.create<ExperimentalDescriptorStoreOp>(desc_ptr, value,
indices);
self.create<ExperimentalDescriptorStoreOp>(desc, value, indices);
})
.def("create_tensormap_create",
[](TritonOpBuilder &self, Value desc_ptr, Value global_address,
Expand Down
2 changes: 2 additions & 0 deletions python/triton/language/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
_experimental_descriptor_load,
_experimental_descriptor_store,
_experimental_make_tensor_descriptor,
_experimental_reinterpret_tensor_descriptor,
_experimental_tensor_descriptor,
add,
advance,
Expand Down Expand Up @@ -129,6 +130,7 @@
"_experimental_descriptor_load",
"_experimental_descriptor_store",
"_experimental_make_tensor_descriptor",
"_experimental_reinterpret_tensor_descriptor",
"_experimental_tensor_descriptor",
"abs",
"add",
Expand Down
Loading
Loading