From 1cf06c5e1982eba8f17062e1c6c3d3fa458597b2 Mon Sep 17 00:00:00 2001 From: peterbell10 Date: Fri, 15 Nov 2024 02:49:13 +0000 Subject: [PATCH] [IR] Add typing for tensor descriptor types (#5147) Currently tensor descriptors are just typed as `!tt.ptr` which is exposing the assumption it's using a TMA descriptor. This changes it to a custom type `!tt.tensordesc>` which is lowered to a pointer type in the LLVM IR. I also add two new IR Ops which are used to cast between pointers and tensordesc objects. ```mlir tt.reinterpret_tensor_descriptor %ptr : !tt.ptr to !tt.tensordesc<...> triton_nvidia_gpu.tensor_desc_to_tma_ptr %desc : !tt.tensordesc<...> -> !tt.ptr ``` Really both of these should be nvidia-specific but the first is exposed in the triton IR to keep support for the by-value TMA descriptor API around while we figure out if it's possible to update to the new style. --- include/triton/Dialect/Triton/IR/TritonOps.td | 64 ++++++++++------ .../triton/Dialect/Triton/IR/TritonTypes.td | 11 +++ .../Dialect/TritonGPU/IR/TritonGPUOps.td | 2 +- .../TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td | 20 +++++ .../TritonGPUToLLVM/TypeConverter.cpp | 4 + lib/Dialect/Triton/IR/Ops.cpp | 16 ++-- .../Pipeliner/MatmulLoopPipeline.cpp | 10 ++- .../Pipeliner/TMAStoresPipeline.cpp | 4 +- lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp | 12 +++ .../Transforms/TMALowering.cpp | 12 ++- python/src/ir.cc | 20 +++-- python/triton/language/__init__.py | 2 + python/triton/language/core.py | 76 ++++++++++++------- python/triton/language/semantic.py | 18 +++-- test/Triton/ops.mlir | 6 +- test/TritonGPU/global_scratch_alloc.mlir | 8 +- test/TritonGPU/loop-pipeline-cuda.mlir | 6 +- test/TritonGPU/loop-pipeline-hopper.mlir | 13 ++-- test/TritonNvidiaGPU/membar.mlir | 8 +- test/TritonNvidiaGPU/tma_lowering.mlir | 19 +++-- .../lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp | 37 ++++++++- 21 files changed, 263 insertions(+), 105 deletions(-) diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 43e4ac027105..31456f23ae41 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -956,9 +956,10 @@ def TT_MakeTensorPtrOp : TT_Op<"make_tensor_ptr", // // Make Tensor Descriptor Op // -def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor", - [Pure, - SameVariadicOperandSize]> { +def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor", [ + Pure, + SameVariadicOperandSize, +]> { let summary = "Make a tensor descriptor type with meta information of the parent tensor and block size"; let description = [{ @@ -969,23 +970,38 @@ def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor", let arguments = (ins TT_Ptr:$base, Variadic:$shape, - Variadic:$strides, - DenseI32ArrayAttr:$tensorShape + Variadic:$strides ); - // TODO(peterbell10): define a custom IR type to represent descriptors - let results = (outs TT_Ptr:$result); + let results = (outs TT_TensorDescType:$result); let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` attr-dict `:` type($base) `,` type($result)"; let builders = [ - OpBuilder<(ins - "Value":$base, - "ValueRange":$shape, - "ValueRange":$strides, - "ArrayRef":$tensorShape - )> + OpBuilder<(ins "Value":$base, "ValueRange":$shape, "ValueRange":$strides, "ArrayRef":$blockShape)> ]; + + let extraClassDeclaration = [{ + ArrayRef getTensorShape() { + return getType().getBlockType().getShape(); + } + }]; +} + +def ReinterpretTensorDescOp : TT_Op<"reinterpret_tensor_descriptor", [Pure]> { + let summary = "Reinterpret a pointer as a tensor descriptor"; + + let description = [{ + This Op exists to help the transition from untyped raw TMA objects to typed Tensor descriptor objects. + Ideally, we can remove this once the APIs are fully fleshed out. + }]; + + let arguments = (ins TT_Ptr:$rawDesc); + let results = (outs TT_TensorDescType:$result); + + let assemblyFormat = [{ + $rawDesc attr-dict `:` qualified(type($rawDesc)) `to` qualified(type($result)) + }]; } // The following ops, including `call`, `func`, and `return` are copied and modified from @@ -1195,12 +1211,11 @@ def ReturnOp : TT_Op<"return", [Pure, HasParent<"FuncOp">, /*MemRefsNormalizable } -def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [ - MemoryEffects<[MemRead]>]> { +def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [MemoryEffects<[MemRead]>]> { let summary = "Load from descriptor"; let description = [{ This operation will be lowered to Nvidia TMA load operation on targets supporting it. - `desc_ptr` is a pointer to the TMA descriptor allocated in global memory. + `desc` is a tensor descriptor object. The destination tensor type and shape must match the descriptor otherwise the result is undefined. This is an escape hatch and is only there for testing/experimenting. @@ -1208,7 +1223,7 @@ def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [ }]; let arguments = ( ins - TT_PtrType:$desc_ptr, + TT_TensorDescType:$desc, Variadic:$indices, DefaultValuedAttr:$cache, DefaultValuedAttr:$evict @@ -1217,21 +1232,22 @@ def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [ let results = (outs TT_Tensor:$result); let assemblyFormat = [{ - $desc_ptr `[` $indices `]` + $desc `[` $indices `]` oilist( `cacheModifier` `=` $cache | `evictionPolicy` `=` $evict ) - attr-dict `:` qualified(type($desc_ptr)) `->` type($result) + attr-dict `:` qualified(type($desc)) `->` type($result) }]; } def TT_ExperimentalDescriptorStoreOp : TT_Op<"experimental_descriptor_store", [ - MemoryEffects<[MemRead, MemWrite]>]> { + MemoryEffects<[MemRead, MemWrite]>, +]> { let summary = "store value based on descriptor"; let description = [{ This operation will be lowered to Nvidia TMA store operation on targets supporting it. - `desc_ptr` is a pointer to the TMA descriptor allocated in global memory. + `desc` is a tensor descriptor object. The shape and types of `src` must match the descriptor otherwise the result is undefined. This is an escape hatch and is only there for testing/experimenting. @@ -1239,14 +1255,14 @@ def TT_ExperimentalDescriptorStoreOp : TT_Op<"experimental_descriptor_store", [ }]; let arguments = ( ins - TT_PtrType:$desc_ptr, + TT_TensorDescType:$desc, TT_Tensor:$src, Variadic:$indices ); let assemblyFormat = [{ - $desc_ptr `[` $indices `]` `,` $src - attr-dict `:` qualified(type($desc_ptr)) `,` type($src) + $desc `[` $indices `]` `,` $src + attr-dict `:` qualified(type($desc)) `,` type($src) }]; } diff --git a/include/triton/Dialect/Triton/IR/TritonTypes.td b/include/triton/Dialect/Triton/IR/TritonTypes.td index 4c709cd4420b..98f8e570a9d7 100644 --- a/include/triton/Dialect/Triton/IR/TritonTypes.td +++ b/include/triton/Dialect/Triton/IR/TritonTypes.td @@ -140,5 +140,16 @@ def TT_MemDescType : TritonTypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]> let hasCustomAssemblyFormat = 1; } +// Result type of ExperimentalMakeTensorDescriptor +def TT_TensorDescType : TritonTypeDef<"TensorDesc", "tensordesc", []> { + let summary = "Tensor descriptor type (`::mlir::triton::TensorDescType`) in Triton IR type system"; + + let description = [{ + A portable abstraction for nvidia-TMA descriptors. + }]; + + let parameters = (ins "RankedTensorType":$blockType); + let assemblyFormat = "`<` $blockType `>`"; +} #endif diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index 81e90bab62d9..8983fae24da1 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -295,7 +295,7 @@ def TTG_GlobalScratchAllocOp : TTG_Op<"global_scratch_alloc", [MemoryEffects<[Me $_builder.getI32IntegerAttr(nbytes), $_builder.getI32IntegerAttr(alignment)); }]> ]; - let assemblyFormat = [{attr-dict `:` type($result)}]; + let assemblyFormat = [{attr-dict `:` qualified(type($result))}]; } #endif diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td index 243b934367ad..e257e8feadb7 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td @@ -185,6 +185,26 @@ def TTNG_WaitBarrierOp : TTNG_Op<"wait_barrier", [DeclareOpInterfaceMethods { + let summary = "Convert tensor descriptor to pointer to tma descriptor"; + + let arguments = (ins TT_TensorDescType:$desc); + let results = (outs TT_Ptr:$ptr); + + let assemblyFormat = [{ + $desc attr-dict `:` qualified(type($desc)) `to` qualified(type($ptr)) + }]; + + let builders = [ + OpBuilder<(ins "Value":$desc), [{ + auto ptrTy = triton::PointerType::get($_builder.getI8Type(), 1); + build($_builder, $_state, ptrTy, desc); + }]> + ]; + + let hasCanonicalizeMethod = 1; +} + def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local", [DeclareOpInterfaceMethods]> { let summary = "copy data based on descriptor from global memory to local memory asynchronously"; diff --git a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp index 8cac1efbff8b..fee10296c89e 100644 --- a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp @@ -28,6 +28,10 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter( addConversion([&](MemDescType type) -> std::optional { return convertMemDescType(type, targetInfo); }); + addConversion([](TensorDescType type) -> std::optional { + auto ctx = type.getContext(); + return LLVM::LLVMPointerType::get(ctx, 1); + }); addConversion([&](triton::gpu::AsyncTokenType type) -> std::optional { return convertAsyncToken(type); }); diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index a5ef8a487e09..1ac2d8cb53f8 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -8,6 +8,7 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/Triton/IR/Utility.h" +#include "llvm/Support/ErrorHandling.h" namespace mlir { namespace triton { @@ -863,12 +864,17 @@ OpFoldResult AdvanceOp::fold(FoldAdaptor adaptor) { //-- MakeTensorDescOp -- void MakeTensorDescOp::build(OpBuilder &builder, OperationState &state, Value base, ValueRange shape, ValueRange strides, - ArrayRef tensorShape) { - auto resultTy = getPointerType(builder.getI8Type()); - assert(resultTy.getContext()); + ArrayRef blockShape) { + auto ptrTy = dyn_cast(base.getType()); + if (!ptrTy) { + llvm::report_fatal_error("Expected pointer type"); + } + auto elemTy = ptrTy.getPointeeType(); - return build(builder, state, resultTy, base, shape, strides, - builder.getDenseI32ArrayAttr(tensorShape)); + SmallVector blockShape64(blockShape); + auto blockTy = RankedTensorType::get(blockShape64, elemTy); + auto descTy = TensorDescType::get(builder.getContext(), blockTy); + return build(builder, state, descTy, base, shape, strides); } // The following ops, including `call`, `func`, and `return` are copied and diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index 7029ae6afe7d..9f5bec98503c 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -34,13 +34,13 @@ namespace tt = mlir::triton; namespace ttg = mlir::triton::gpu; namespace ttng = mlir::triton::nvidia_gpu; -// TODO: We can extra some helpers into common utilities once we add more +// TODO: We can extract some helpers into common utilities once we add more // schedules. namespace { struct LoadInfo { - // Layout of the data in the shared memory. + // Layout of the data in shared memory. ttg::SharedEncodingAttr sharedEncoding = nullptr; // Blocked encoding is used for loads not used by the dot. ttg::BlockedEncodingAttr blockedEncoding = nullptr; @@ -239,9 +239,11 @@ createTMAAsyncCopy(scf::ForOp &forOp, tt::ExperimentalDescriptorLoadOp loadOp, Value pred = builder.createWithStage(loc, stage, clusterId, 1, 1); + Value tmaPtr = + builder.createWithStage( + loc, stage, clusterId, loadOp.getDesc()); Operation *copy = builder.createWithStage( - loc, stage, clusterId, loadOp.getDescPtr(), loadOp.getIndices(), barrier, - view, pred); + loc, stage, clusterId, tmaPtr, loadOp.getIndices(), barrier, view, pred); bool isMMV3Load = loadToInfo[loadOp].loadIsMMAV3; diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp index 7985d25b9097..1cc3df7ec39b 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp @@ -63,8 +63,10 @@ static void createTMAAsyncCopy(scf::ForOp &forOp, builder.create(loc, 0); builder.create(loc, storeOp.getSrc(), alloc); builder.create(loc, false); + Value tmaPtr = builder.create( + loc, storeOp.getDesc()); builder.create( - loc, storeOp.getDescPtr(), storeOp.getIndices(), alloc); + loc, tmaPtr, storeOp.getIndices(), alloc); storeOp->erase(); } diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp index 37c69eef8adb..92d9b589a280 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -160,6 +160,18 @@ void WaitBarrierOp::getEffects( mlir::triton::gpu::SharedMemory::get()); } +// -- TensorDescToTMAPtrOp -- +LogicalResult TensorDescToTMAPtrOp::canonicalize(TensorDescToTMAPtrOp op, + PatternRewriter &rewriter) { + // tensor_desc_to_tma_ptr(reinterpret_tensor_desc(ptr)) -> ptr + if (auto reinterpret = + op.getDesc().getDefiningOp()) { + rewriter.replaceOp(op, reinterpret.getRawDesc()); + return success(); + } + return failure(); +} + // -- AsyncTMACopyGlobalToLocalOp -- LogicalResult AsyncTMACopyGlobalToLocalOp::verify() { if (failed(verifyBarrierType(*this, getBarrier().getType()))) diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp index 4f928dcaf82f..cb9ae9dd0f3c 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp @@ -60,8 +60,10 @@ class TMALoadLowering : public OpRewritePattern { Value pred = rewriter.create(loc, 1, 1); rewriter.create(loc, barrierAlloc, sizeInBytes, pred); + Value tmaPtr = rewriter.create( + loc, op.getDesc()); rewriter.create( - loc, op.getDescPtr(), op.getIndices(), barrierAlloc, alloc, pred); + loc, tmaPtr, op.getIndices(), barrierAlloc, alloc, pred); Value phase = rewriter.create(loc, 0, 32); rewriter.create(loc, barrierAlloc, phase); rewriter.create(loc, barrierAlloc); @@ -95,8 +97,10 @@ class TMAStoreLowering encoding, sharedMemorySpace, /*mutableMemory=*/true); Value alloc = rewriter.create(loc, memDescType, op.getSrc()); rewriter.create(loc, false); + Value tmaPtr = rewriter.create( + loc, op.getDesc()); rewriter.create( - loc, op.getDescPtr(), op.getIndices(), alloc); + loc, tmaPtr, op.getIndices(), alloc); rewriter.create(loc, 0); rewriter.eraseOp(op); return success(); @@ -194,7 +198,9 @@ class TMACreateDescLowering : public OpRewritePattern { /*fill_mode=*/rewriter.getI32IntegerAttr(0)); rewriter.create( loc, alloc.getResult()); - rewriter.replaceOp(op, alloc); + auto newDesc = rewriter.create( + loc, op.getType(), alloc.getResult()); + rewriter.replaceOp(op, newDesc); return success(); } }; diff --git a/python/src/ir.cc b/python/src/ir.cc index e7322c4fd232..27c35910e97a 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -47,6 +47,7 @@ class TritonOpBuilder { } OpBuilder &getBuilder() { return *builder; } + MLIRContext *getContext() { return builder->getContext(); } bool isLineInfoEnabled() { return lineInfoEnabled; } @@ -1318,19 +1319,26 @@ void init_triton_ir(py::module &&m) { self.create(ptrs, val, mask, cacheModifier, evictionPolicy); }) + .def("create_reinterpret_tensor_descriptor", + [](TritonOpBuilder &self, Value desc_ptr, Type blockTy) -> Value { + auto ctx = self.getContext(); + auto resultTy = triton::TensorDescType::get( + ctx, cast(blockTy)); + return self.create(resultTy, desc_ptr); + }) .def("create_descriptor_load", - [](TritonOpBuilder &self, Value desc_ptr, - std::vector &indices, Type type, + [](TritonOpBuilder &self, Value desc, std::vector &indices, CacheModifier cacheModifier, EvictionPolicy evictionPolicy) -> Value { + auto descTy = cast(desc.getType()); + auto resTy = descTy.getBlockType(); return self.create( - type, desc_ptr, indices, cacheModifier, evictionPolicy); + resTy, desc, indices, cacheModifier, evictionPolicy); }) .def("create_descriptor_store", - [](TritonOpBuilder &self, Value desc_ptr, Value value, + [](TritonOpBuilder &self, Value desc, Value value, std::vector &indices) -> void { - self.create(desc_ptr, value, - indices); + self.create(desc, value, indices); }) .def("create_tensormap_create", [](TritonOpBuilder &self, Value desc_ptr, Value global_address, diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index e3804dcc4a5f..737ff06e6aed 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -29,6 +29,7 @@ _experimental_descriptor_load, _experimental_descriptor_store, _experimental_make_tensor_descriptor, + _experimental_reinterpret_tensor_descriptor, _experimental_tensor_descriptor, add, advance, @@ -129,6 +130,7 @@ "_experimental_descriptor_load", "_experimental_descriptor_store", "_experimental_make_tensor_descriptor", + "_experimental_reinterpret_tensor_descriptor", "_experimental_tensor_descriptor", "abs", "add", diff --git a/python/triton/language/core.py b/python/triton/language/core.py index d3b5269b4461..a2e8e36f60c6 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1144,39 +1144,26 @@ def flip(self, dim=None) -> tensor: ... -class _experimental_tensor_descriptor(_value): - """A descriptor representing a tensor in global memory. +class _experimental_tensor_descriptor_base(_value): + """" + A tensor descriptor with unknown shape and strides """ - def __init__(self, handle, shape: List[tensor], strides: List[tensor], type: block_type): + def __init__(self, handle, type: block_type): """Not called by user code.""" # IR handle super().__init__(handle) - # Global shape - self.shape = shape - self.strides = strides self.type = type # Tensor type (block_type) # Following the practice in pytorch, dtype is scalar type self.dtype = type.scalar def _flatten_ir(self): - handles = [self.handle] - handles.extend(s.handle for s in self.shape) - handles.extend(s.handle for s in self.strides) - return handles + return [self.handle] def _unflatten_ir(self, handles): - ndim = len(self.shape) - assert len(handles) == 2 * ndim + 1 - handle = handles[0] - shape = [tensor(handle, s.type) for handle, s in zip(handles[1:1 + ndim], self.shape)] - strides = [tensor(handle, s.type) for handle, s in zip(handles[1 + ndim:], self.strides)] - return _experimental_tensor_descriptor(handle, shape, strides, self.type) - - @builtin - def _as_ptr(self, _builder): - return tensor(self.handle, pointer_type(int8)) + assert len(handles) == 1 + return _experimental_tensor_descriptor_base(handles[0], self.type) @property def block_shape(self): @@ -1194,7 +1181,7 @@ def load(self, offsets: List[tensor], _builder=None) -> tensor: :note: Offset must be a multiple of 16-bytes """ - return _experimental_descriptor_load(self, offsets, self.type.shape, self.type.element_ty, _builder=_builder) + return semantic.descriptor_load(self, offsets, "", "", _builder) @builtin def store(self, offsets: List[tensor], value: tensor, _builder=None) -> tensor: @@ -1204,8 +1191,34 @@ def store(self, offsets: List[tensor], value: tensor, _builder=None) -> tensor: :note: Offset must be a multiple of 16-bytes """ - value = cast(value, self.type, _builder=_builder) - return _experimental_descriptor_store(self, value, offsets, _builder=_builder) + return semantic.descriptor_store(self, value, offsets, _builder) + + +class _experimental_tensor_descriptor(_experimental_tensor_descriptor_base): + """A descriptor representing a tensor in global memory. + """ + + def __init__(self, handle, shape: List[tensor], strides: List[tensor], type: block_type): + """Not called by user code.""" + # IR handle + super().__init__(handle, type) + # Global shape + self.shape = shape + self.strides = strides + + def _flatten_ir(self): + handles = [self.handle] + handles.extend(s.handle for s in self.shape) + handles.extend(s.handle for s in self.strides) + return handles + + def _unflatten_ir(self, handles): + ndim = len(self.shape) + assert len(handles) == 2 * ndim + 1 + handle = handles[0] + shape = [tensor(handle, s.type) for handle, s in zip(handles[1:1 + ndim], self.shape)] + strides = [tensor(handle, s.type) for handle, s in zip(handles[1 + ndim:], self.strides)] + return _experimental_tensor_descriptor(handle, shape, strides, self.type) def get_bool_env_var(var_name): @@ -1717,6 +1730,16 @@ def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", c volatile, _builder) +@builtin +def _experimental_reinterpret_tensor_descriptor(desc_ptr, block_shape, dtype, + _builder=None) -> _experimental_tensor_descriptor_base: + """ + Reinterpret a generic pointer as a TMA-backed tensor descriptor object. + """ + block_ty = block_type(_constexpr_to_value(dtype), block_shape) + return semantic.reinterpret_tensor_descriptor(desc_ptr, block_ty, _builder) + + @builtin def _experimental_descriptor_load(desc_pointer, offsets, shape, dtype, _builder=None): """ @@ -1725,8 +1748,8 @@ def _experimental_descriptor_load(desc_pointer, offsets, shape, dtype, _builder= This loads a tensor of data based on the descriptor and offsets. """ - type = block_type(_constexpr_to_value(dtype), shape) - return semantic.descriptor_load(desc_pointer, offsets, "", "", type, _builder) + desc = _experimental_reinterpret_tensor_descriptor(desc_pointer, shape, dtype, _builder=_builder) + return desc.load(offsets, _builder=_builder) @builtin @@ -1737,7 +1760,8 @@ def _experimental_descriptor_store(desc_pointer, value, offsets, _builder=None): This stores a tensor of data based on the descriptor and offsets. """ - return semantic.descriptor_store(desc_pointer, value, offsets, _builder) + desc = _experimental_reinterpret_tensor_descriptor(desc_pointer, value.shape, value.dtype, _builder=_builder) + return desc.store(offsets, value, _builder=_builder) @_tensor_member_fn diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 61f6d3948d97..4b27700b00c4 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1138,18 +1138,24 @@ def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor], return _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder) -def descriptor_load(desc_ptr: tl.tensor, offsets, cache_modifier: str, eviction_policy: str, type, +def reinterpret_tensor_descriptor(desc_ptr: tl.tensor, block_ty: tl.block_type, builder: ir.builder): + handle = builder.create_reinterpret_tensor_descriptor(desc_ptr.handle, block_ty.to_ir(builder)) + return tl._experimental_tensor_descriptor_base(handle, block_ty) + + +def descriptor_load(desc: tl.tensor, offsets, cache_modifier: str, eviction_policy: str, builder: ir.builder) -> tl.tensor: + assert isinstance(desc, tl._experimental_tensor_descriptor_base) offsets = _convert_to_ir_values(builder, offsets, require_i64=False) - x = builder.create_descriptor_load(desc_ptr.handle, offsets, type.to_ir(builder), - _str_to_load_cache_modifier(cache_modifier), + x = builder.create_descriptor_load(desc.handle, offsets, _str_to_load_cache_modifier(cache_modifier), _str_to_eviction_policy(eviction_policy)) - return tl.tensor(x, type) + return tl.tensor(x, desc.type) -def descriptor_store(desc_ptr: tl.tensor, value: tl.tensor, offsets, builder: ir.builder) -> tl.tensor: +def descriptor_store(desc: tl.tensor, value: tl.tensor, offsets, builder: ir.builder) -> tl.tensor: + assert isinstance(desc, tl._experimental_tensor_descriptor_base) offsets = _convert_to_ir_values(builder, offsets, require_i64=False) - return tl.tensor(builder.create_descriptor_store(desc_ptr.handle, value.handle, offsets), tl.void) + return tl.tensor(builder.create_descriptor_store(desc.handle, value.handle, offsets), tl.void) def tensormap_create( diff --git a/test/Triton/ops.mlir b/test/Triton/ops.mlir index c3b92b7ee403..9dec1e9c481e 100644 --- a/test/Triton/ops.mlir +++ b/test/Triton/ops.mlir @@ -244,9 +244,9 @@ tt.func @histogram(%0: tensor<512xi32>) { } // CHECK-LABEL: experimental_descriptor_load -tt.func @experimental_descriptor_load(%0: !tt.ptr) { - // CHECK: tt.experimental_descriptor_load %{{.+}}[%{{.+}}] : !tt.ptr -> tensor<128xf32> +tt.func @experimental_descriptor_load(%0: !tt.tensordesc>) { + // CHECK: tt.experimental_descriptor_load %{{.+}}[%{{.+}}] : !tt.tensordesc> -> tensor<128xf32> %c0_i32 = arith.constant 0 : i32 - %1 = tt.experimental_descriptor_load %0[%c0_i32] : !tt.ptr -> tensor<128xf32> + %1 = tt.experimental_descriptor_load %0[%c0_i32] : !tt.tensordesc> -> tensor<128xf32> tt.return } diff --git a/test/TritonGPU/global_scratch_alloc.mlir b/test/TritonGPU/global_scratch_alloc.mlir index 47c580db84c8..a715b30d6131 100644 --- a/test/TritonGPU/global_scratch_alloc.mlir +++ b/test/TritonGPU/global_scratch_alloc.mlir @@ -5,9 +5,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: @test_alloc{{.*}}triton_gpu.global_scratch_memory_alignment = 128 : i32, triton_gpu.global_scratch_memory_size = 256 : i32 tt.func public @test_alloc() -> (!tt.ptr, !tt.ptr) { // CHECK: triton_gpu.global_scratch_memory_offset = 0 - %0 = triton_gpu.global_scratch_alloc {alignment = 8 : i32, nbytes = 100 : i32} : + %0 = triton_gpu.global_scratch_alloc {alignment = 8 : i32, nbytes = 100 : i32} : !tt.ptr // CHECK: triton_gpu.global_scratch_memory_offset = 128 - %1 = triton_gpu.global_scratch_alloc {alignment = 128 : i32, nbytes = 128 : i32} : + %1 = triton_gpu.global_scratch_alloc {alignment = 128 : i32, nbytes = 128 : i32} : !tt.ptr tt.return %0, %1 : !tt.ptr, !tt.ptr } } @@ -19,14 +19,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: @helper1{{.*}}triton_gpu.global_scratch_memory_alignment = 128 : i32, triton_gpu.global_scratch_memory_size = 128 : i32 tt.func private @helper1() -> (!tt.ptr) { // CHECK: triton_gpu.global_scratch_memory_offset = 0 - %0 = triton_gpu.global_scratch_alloc {alignment = 128 : i32, nbytes = 128 : i32} : + %0 = triton_gpu.global_scratch_alloc {alignment = 128 : i32, nbytes = 128 : i32} : !tt.ptr tt.return %0 : !tt.ptr } // CHECK: @test_function{{.*}}triton_gpu.global_scratch_memory_alignment = 128 : i32, triton_gpu.global_scratch_memory_size = 256 : i32 tt.func public @test_function() -> (!tt.ptr, !tt.ptr) { // CHECK: triton_gpu.global_scratch_memory_offset = 0 - %0 = triton_gpu.global_scratch_alloc {alignment = 8 : i32, nbytes = 100 : i32} : + %0 = triton_gpu.global_scratch_alloc {alignment = 8 : i32, nbytes = 100 : i32} : !tt.ptr // CHECK: triton_gpu.global_scratch_memory_offset = 128 %1 = tt.call @helper1() : () -> (!tt.ptr) tt.return %0, %1 : !tt.ptr, !tt.ptr diff --git a/test/TritonGPU/loop-pipeline-cuda.mlir b/test/TritonGPU/loop-pipeline-cuda.mlir index ceadbb7e1453..7b8bed9a18f4 100644 --- a/test/TritonGPU/loop-pipeline-cuda.mlir +++ b/test/TritonGPU/loop-pipeline-cuda.mlir @@ -179,16 +179,16 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NOT: triton_nvidia_gpu.wait_barrier // CHECK-COUNT-2: triton_nvidia_gpu.async_tma_copy_global_to_local // CHECK: scf.yield - tt.func public @matmul_tma(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x256xf32, #mma> { + tt.func public @matmul_tma(%arg0: !tt.tensordesc>, %arg1: !tt.tensordesc>) -> tensor<128x256xf32, #mma> { %c256_i32 = arith.constant 256 : i32 %c0_i32 = arith.constant 0 : i32 %c64_i32 = arith.constant 64 : i32 %c1_i32 = arith.constant 1 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> %0:2 = scf.for %arg3 = %c0_i32 to %c256_i32 step %c1_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32) -> (tensor<128x256xf32, #mma>, i32) : i32 { - %1 = tt.experimental_descriptor_load %arg0[%c0_i32, %arg5] : !tt.ptr -> tensor<128x64xf16, #blocked> + %1 = tt.experimental_descriptor_load %arg0[%c0_i32, %arg5] : !tt.tensordesc> -> tensor<128x64xf16, #blocked> %2 = triton_gpu.local_alloc %1 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %3 = tt.experimental_descriptor_load %arg1[%arg5, %c0_i32] : !tt.ptr -> tensor<64x256xf16, #blocked1> + %3 = tt.experimental_descriptor_load %arg1[%arg5, %c0_i32] : !tt.tensordesc> -> tensor<64x256xf16, #blocked1> %4 = triton_gpu.local_alloc %3 : (tensor<64x256xf16, #blocked1>) -> !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> %5 = triton_nvidia_gpu.warp_group_dot %2, %4, %arg4 { inputPrecision = 0 : i32 } : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x256xf32, #mma> %6 = arith.addi %arg5, %c64_i32 : i32 diff --git a/test/TritonGPU/loop-pipeline-hopper.mlir b/test/TritonGPU/loop-pipeline-hopper.mlir index 524cf69c7cfe..1f0ecaee439c 100644 --- a/test/TritonGPU/loop-pipeline-hopper.mlir +++ b/test/TritonGPU/loop-pipeline-hopper.mlir @@ -698,15 +698,16 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : #blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { // CHECK-LABEL: tma_store_pipeline - tt.func public @tma_store_pipeline(%arg0: tensor<1xf32, #blocked>, %arg1: !tt.ptr, %arg2: i32, %arg3: i32) attributes {noinline = false} { + tt.func public @tma_store_pipeline(%arg0: tensor<1xf32, #blocked>, %arg1: !tt.tensordesc>, %arg2: i32, %arg3: i32) attributes {noinline = false} { %c0_i32 = arith.constant 0 : i32 scf.for %arg4 = %c0_i32 to %arg3 step %arg2 : i32 { %1 = arith.divsi %arg4, %arg2 : i32 // CHECK: triton_nvidia_gpu.async_tma_store_wait {pendings = 0 : i32} // CHECK-NEXT: triton_gpu.local_store // CHECK-NEXT: triton_nvidia_gpu.fence_async_shared + // CHECK-NEXT: triton_nvidia_gpu.tensor_desc_to_tma_ptr // CHECK-NEXT: triton_nvidia_gpu.async_tma_copy_local_to_global - tt.experimental_descriptor_store %arg1[%1], %arg0 : !tt.ptr, tensor<1xf32, #blocked> + tt.experimental_descriptor_store %arg1[%1], %arg0 : !tt.tensordesc>, tensor<1xf32, #blocked> } tt.return } @@ -716,7 +717,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : #blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { // CHECK-LABEL: tma_multiple_store_pipeline - tt.func public @tma_multiple_store_pipeline(%arg0: tensor<1xf32, #blocked>, %arg1: !tt.ptr, %arg2: i32, %arg3: i32) attributes {noinline = false} { + tt.func public @tma_multiple_store_pipeline(%arg0: tensor<1xf32, #blocked>, %arg1: !tt.tensordesc>, %arg2: i32, %arg3: i32) attributes {noinline = false} { %c0_i32 = arith.constant 0 : i32 // CHECK: %[[ALLOC:.+]] = triton_gpu.local_alloc : () -> !tt.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable> // CHECK: scf.for @@ -726,13 +727,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: triton_nvidia_gpu.async_tma_store_wait {pendings = 0 : i32} // CHECK-NEXT: triton_gpu.local_store %{{.+}}, %[[ALLOC]] // CHECK-NEXT: triton_nvidia_gpu.fence_async_shared + // CHECK-NEXT: triton_nvidia_gpu.tensor_desc_to_tma_ptr // CHECK-NEXT: triton_nvidia_gpu.async_tma_copy_local_to_global %{{.*}} %[[ALLOC]] // CHECK: triton_nvidia_gpu.async_tma_store_wait {pendings = 0 : i32} // CHECK-NEXT: triton_gpu.local_store %{{.+}}, %[[ALLOC]] // CHECK-NEXT: triton_nvidia_gpu.fence_async_shared + // CHECK-NEXT: triton_nvidia_gpu.tensor_desc_to_tma_ptr // CHECK-NEXT: triton_nvidia_gpu.async_tma_copy_local_to_global %{{.*}} %[[ALLOC]] - tt.experimental_descriptor_store %arg1[%1], %arg0 : !tt.ptr, tensor<1xf32, #blocked> - tt.experimental_descriptor_store %arg1[%2], %arg0 : !tt.ptr, tensor<1xf32, #blocked> + tt.experimental_descriptor_store %arg1[%1], %arg0 : !tt.tensordesc>, tensor<1xf32, #blocked> + tt.experimental_descriptor_store %arg1[%2], %arg0 : !tt.tensordesc>, tensor<1xf32, #blocked> } tt.return } diff --git a/test/TritonNvidiaGPU/membar.mlir b/test/TritonNvidiaGPU/membar.mlir index 358f53fd7cd6..6d9c16650861 100644 --- a/test/TritonNvidiaGPU/membar.mlir +++ b/test/TritonNvidiaGPU/membar.mlir @@ -81,7 +81,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : #shared0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { - tt.func public @tma_load(%arg0: !tt.ptr, %arg1: i32) -> tensor<128x64xf16, #blocked0> { + tt.func public @tma_load(%arg0: !tt.tensordesc>, %arg1: i32) -> tensor<128x64xf16, #blocked0> { // CHECK-LABEL: tma_load // CHECK: local_dealloc // CHECK-NEXT: local_alloc @@ -91,7 +91,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %cst = arith.constant dense<0> : tensor<128x64xi64, #blocked0> %alloc = triton_gpu.local_alloc %cst : (tensor<128x64xi64, #blocked0>) -> !tt.memdesc<128x64xi64, #shared0, #triton_gpu.shared_memory, mutable> triton_gpu.local_dealloc %alloc : !tt.memdesc<128x64xi64, #shared0, #triton_gpu.shared_memory, mutable> - %l = tt.experimental_descriptor_load %arg0[%arg1, %arg1] : !tt.ptr -> tensor<128x64xf16, #blocked0> + %l = tt.experimental_descriptor_load %arg0[%arg1, %arg1] : !tt.tensordesc> -> tensor<128x64xf16, #blocked0> tt.return %l : tensor<128x64xf16, #blocked0> } } @@ -106,11 +106,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NEXT: triton_gpu.local_dealloc // CHECK-NEXT: gpu.barrier // CHECK-NEXT: triton_gpu.local_alloc - tt.func public @tma_store(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: tensor<128x256xf32, #blocked0>) { + tt.func public @tma_store(%arg0: !tt.tensordesc>, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: tensor<128x256xf32, #blocked0>) { %cst = arith.constant dense<0> : tensor<128x64xi64, #blocked0> %alloc = triton_gpu.local_alloc %cst : (tensor<128x64xi64, #blocked0>) -> !tt.memdesc<128x64xi64, #shared0, #triton_gpu.shared_memory, mutable> triton_gpu.local_dealloc %alloc : !tt.memdesc<128x64xi64, #shared0, #triton_gpu.shared_memory, mutable> - tt.experimental_descriptor_store %arg0[%arg1, %arg1], %arg2 : !tt.ptr, tensor<128x256xf32, #blocked0> + tt.experimental_descriptor_store %arg0[%arg1, %arg1], %arg2 : !tt.tensordesc>, tensor<128x256xf32, #blocked0> tt.return } } diff --git a/test/TritonNvidiaGPU/tma_lowering.mlir b/test/TritonNvidiaGPU/tma_lowering.mlir index 5bc357f1fb94..dc8113ca8a72 100644 --- a/test/TritonNvidiaGPU/tma_lowering.mlir +++ b/test/TritonNvidiaGPU/tma_lowering.mlir @@ -6,12 +6,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: triton_gpu.local_alloc : () // CHECK: triton_gpu.local_alloc : () // CHECK: triton_nvidia_gpu.init_barrier +// CHECK: triton_nvidia_gpu.tensor_desc_to_tma_ptr // CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local // CHECK: triton_nvidia_gpu.wait_barrier // CHECK: triton_nvidia_gpu.inval_barrier // CHECK: triton_gpu.local_load - tt.func public @tma_load(%arg0: !tt.ptr, %arg1: i32) -> tensor<128x64xf16, #blocked> { - %l = tt.experimental_descriptor_load %arg0[%arg1, %arg1] : !tt.ptr -> tensor<128x64xf16, #blocked> + tt.func public @tma_load(%arg0: !tt.tensordesc>, %arg1: i32) -> tensor<128x64xf16, #blocked> { + %l = tt.experimental_descriptor_load %arg0[%arg1, %arg1] : !tt.tensordesc> -> tensor<128x64xf16, #blocked> tt.return %l : tensor<128x64xf16, #blocked> } } @@ -23,9 +24,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-LABEL: tma_store // CHECK: triton_gpu.local_alloc // CHECK: triton_nvidia_gpu.fence_async_shared {bCluster = false} +// CHECK: triton_nvidia_gpu.tensor_desc_to_tma_ptr // CHECK: triton_nvidia_gpu.async_tma_copy_local_to_global - tt.func public @tma_store(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: tensor<128x256xf32, #blocked>) { - tt.experimental_descriptor_store %arg0[%arg1, %arg1], %arg2 : !tt.ptr, tensor<128x256xf32, #blocked> + tt.func public @tma_store(%arg0: !tt.tensordesc>, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: tensor<128x256xf32, #blocked>) { + tt.experimental_descriptor_store %arg0[%arg1, %arg1], %arg2 : !tt.tensordesc>, tensor<128x256xf32, #blocked> tt.return } } @@ -35,17 +37,18 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { // CHECK-LABEL: make_tensor_descriptor // CHECK: %0 = arith.extsi %arg2 : i32 to i64 - // CHECK: %1 = triton_gpu.global_scratch_alloc {alignment = 128 : i32, nbytes = 128 : i32} : + // CHECK: %1 = triton_gpu.global_scratch_alloc {alignment = 128 : i32, nbytes = 128 : i32} : !tt.ptr // CHECK: %2 = arith.shrsi %0, %c4_i64 : i64 // CHECK: tt.experimental_tensormap_create %1, %arg0, [%c32_i32, %c8_i32], [%arg2, %arg1], [%2], [%c1_i32, %c1_i32] {elem_type = 0 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 1 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () // CHECK: tt.experimental_tensormap_fenceproxy_acquire %1 : !tt.ptr - tt.func public @make_tensor_descriptor(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32} ) -> !tt.ptr { + // CHECK: tt.reinterpret_tensor_descriptor %1 : !tt.ptr to !tt.tensordesc> + tt.func public @make_tensor_descriptor(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32} ) -> !tt.tensordesc> { %c1_i64 = arith.constant 1 : i64 %cst = arith.constant dense<32> : tensor<8x1xi32> %c64_i32 = arith.constant 64 : i32 %c8_i32 = arith.constant 8 : i32 %0 = arith.extsi %arg2 : i32 to i64 - %1 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] {tensorShape = array} : , - tt.return %1 : !tt.ptr + %1 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : !tt.ptr, !tt.tensordesc> + tt.return %1 : !tt.tensordesc> } } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp index c64ba1915ded..459a00c1a142 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp @@ -282,6 +282,37 @@ struct ExperimentalTensormapCreateOpConversion } }; +struct ReinterpretTensorDescOpConversion + : public ConvertOpToLLVMPattern { + + ReinterpretTensorDescOpConversion(LLVMTypeConverter &converter, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + + LogicalResult + matchAndRewrite(triton::ReinterpretTensorDescOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOp(op, adaptor.getRawDesc()); + return success(); + } +}; + +struct TensorDescToTMAPtrOpConversion + : public ConvertOpToLLVMPattern { + + TensorDescToTMAPtrOpConversion(LLVMTypeConverter &converter, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::TensorDescToTMAPtrOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOp(op, adaptor.getDesc()); + return success(); + } +}; + } // namespace void mlir::triton::NVIDIA::populateTMAToLLVMPatterns( @@ -289,6 +320,8 @@ void mlir::triton::NVIDIA::populateTMAToLLVMPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(typeConverter, targetInfo, benefit); - patterns.add( - typeConverter, benefit); + patterns + .add( + typeConverter, benefit); }