Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add blocking and lowering pattern for xetile.atomic_rmw op #1014

Merged
merged 1 commit into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 0 additions & 25 deletions include/imex/Dialect/XeTile/IR/XeTileAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -91,31 +91,6 @@ def XeTile_TileAttr : XeTile_Attr<"XeTile", "tile_attr"> {
];
}

// RMW kind attribute
def ATOMIC_RMW_KIND_ADDF : I64EnumAttrCase<"addf", 0>;
def ATOMIC_RMW_KIND_ADDI : I64EnumAttrCase<"addi", 1>;
def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 2>;
def ATOMIC_RMW_KIND_MAXF : I64EnumAttrCase<"maxf", 3>;
def ATOMIC_RMW_KIND_MAXS : I64EnumAttrCase<"maxs", 4>;
def ATOMIC_RMW_KIND_MAXU : I64EnumAttrCase<"maxu", 5>;
def ATOMIC_RMW_KIND_MINF : I64EnumAttrCase<"minf", 6>;
def ATOMIC_RMW_KIND_MINS : I64EnumAttrCase<"mins", 7>;
def ATOMIC_RMW_KIND_MINU : I64EnumAttrCase<"minu", 8>;
def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>;
def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>;
def ATOMIC_RMW_KIND_ORI : I64EnumAttrCase<"ori", 11>;
def ATOMIC_RMW_KIND_ANDI : I64EnumAttrCase<"andi", 12>;

def XeTile_AtomicRMWKindAttr : I64EnumAttr<
"AtomicRMWKind", "",
[ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN,
ATOMIC_RMW_KIND_MAXF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
ATOMIC_RMW_KIND_MINF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI,
ATOMIC_RMW_KIND_ANDI]> {
let cppNamespace = "::imex::xetile";
}

//TODO: !!!This is target specific information, cache attributes have to be passed transparently
// as custom arguments and handled properly on XeGPU side
//===----------------------------------------------------------------------===//
Expand Down
3 changes: 2 additions & 1 deletion include/imex/Dialect/XeTile/IR/XeTileOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ include "imex/Dialect/XeTile/IR/XeTileDialect.td"
include "imex/Dialect/XeTile/IR/XeTileTypes.td"
include "imex/Dialect/XeTile/IR/XeTileAttrs.td"

include "mlir/Dialect/Arith/IR/ArithBase.td"
include "mlir/Dialect/Vector/IR/VectorAttributes.td"
include "mlir/Interfaces/ViewLikeInterface.td"

Expand Down Expand Up @@ -501,7 +502,7 @@ def XeTile_AtomicRMWOp : XeTile_Op<"atomic_rmw", []> {
vector<8x16xbf16>, tile<8x16xbf16> -> vector<8x16xbf16>
```
}];
let arguments = (ins XeTile_AtomicRMWKindAttr:$kind,
let arguments = (ins AtomicRMWKindAttr:$kind,
XeTile_2DVector:$value,
XeTile:$tile);
let results = (outs XeTile_2DVector:$result);
Expand Down
9 changes: 4 additions & 5 deletions lib/Conversion/XeGPUToVC/LSCPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"

#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
#include <mlir/Dialect/SPIRV/IR/SPIRVDialect.h>
#include <mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h>

#include "LscIntrinsicEnums.h"
#include "imex/Utils/VCUtils.h"
Expand Down Expand Up @@ -1194,6 +1195,7 @@ class AtomicPattern : public OpConversionPattern<AtomicRMWOp> {
} else {
lscVecSize = log2(numDstVal) + 2;
}

auto vecSize = createIntConstant(i8Type, lscVecSize);
auto transposed = createIntConstant(i8Type, 1);
auto mask = adaptor.getMask();
Expand All @@ -1202,9 +1204,7 @@ class AtomicPattern : public OpConversionPattern<AtomicRMWOp> {
Value payLoad = adaptor.getTensorDesc();
// src
auto v16i32Ty = VectorType::get(16, i32Type);
auto i32ZeroAttr = IntegerAttr::get(i32Type, 0);
Value undef = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(v16i32Ty, i32ZeroAttr));
Value undef = rewriter.create<mlir::spirv::UndefOp>(loc, v16i32Ty);
Value src0 = undef;
if (op.getValue()) {
src0 = op.getValue();
Expand All @@ -1222,7 +1222,6 @@ class AtomicPattern : public OpConversionPattern<AtomicRMWOp> {
auto retType = newType;
auto newOp = createFuncCall(rewriter, loc, funcName, TypeRange{retType},
args, false);

auto *converter = this->getTypeConverter();
auto castTy = converter->convertType(op.getType());
auto cast =
Expand Down
5 changes: 4 additions & 1 deletion lib/Conversion/XeGPUToVC/XeGPUToVC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include <mlir/Dialect/SPIRV/IR/SPIRVDialect.h>
#include <mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h>

namespace imex {
#define GEN_PASS_DEF_CONVERTXEGPUTOVC
Expand Down Expand Up @@ -765,7 +767,8 @@ struct XeGPUToVCPass : public imex::impl::ConvertXeGPUToVCBase<XeGPUToVCPass> {
configureArithToVCConversionLegality(target);

target.addLegalDialect<func::FuncDialect, arith::ArithDialect,
memref::MemRefDialect, vector::VectorDialect>();
memref::MemRefDialect, vector::VectorDialect,
spirv::SPIRVDialect>();
target.addIllegalDialect<xegpu::XeGPUDialect>();

target.addDynamicallyLegalDialect<scf::SCFDialect>(
Expand Down
33 changes: 31 additions & 2 deletions lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,35 @@ class ScatterOpPattern
}
};

// Lowers xetile.atomic_rmw to xegpu.atomic_rmw.
//
// XeTile's atomic_rmw operates on a 2D vector and has no mask operand, while
// xegpu.atomic_rmw expects a flat (1D) value plus an explicit lane mask. The
// pattern therefore shape-casts the value to 1D, materializes an all-true
// mask, and shape-casts the result back to the original 2D type.
class AtomicRMWOpPattern
    : public mlir::OpConversionPattern<xetile::AtomicRMWOp> {
public:
  using mlir::OpConversionPattern<xetile::AtomicRMWOp>::OpConversionPattern;
  mlir::LogicalResult
  matchAndRewrite(xetile::AtomicRMWOp op, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = op.getValue().getType();
    auto elemTy = type.getElementType();
    auto value = adaptor.getValue();
    // Flattened value type with the same element count and element type.
    auto valTy = mlir::VectorType::get(type.getNumElements(), elemTy);
    // XeTile atomic_rmw has no mask; every lane is active.
    auto maskTy = mlir::VectorType::get(type.getNumElements(),
                                        rewriter.getIntegerType(1));
    llvm::SmallVector<bool> maskValues(type.getNumElements(), true);
    auto maskAttr = mlir::DenseElementsAttr::get(maskTy, maskValues);
    mlir::Value mask =
        rewriter.create<mlir::arith::ConstantOp>(loc, maskTy, maskAttr);
    // NOTE: was op.getLoc(); use the cached `loc` consistently.
    value = rewriter.create<mlir::vector::ShapeCastOp>(loc, valTy, value);
    auto atomicrmwOp = rewriter.create<mlir::xegpu::AtomicRMWOp>(
        loc, valTy, op.getKind(), adaptor.getTile(), mask, value);
    // Restore the original 2D shape expected by the op's users.
    auto v = rewriter.create<mlir::vector::ShapeCastOp>(loc, op.getType(),
                                                        atomicrmwOp);
    rewriter.replaceOp(op, v);
    return mlir::success();
  }
};

// convert xetile.mma to xegpu::DpasOp.
class MMAOpPattern : public mlir::OpConversionPattern<xetile::TileMMAOp> {
public:
Expand Down Expand Up @@ -691,8 +720,8 @@ void populateXeTileToXeGPUConversionPatterns(
patterns.add<InitOpPattern, UpdateOpPattern, PrefetchOpPattern, LoadOpPattern,
StoreOpPattern, GatherOpPattern, ScatterOpPattern, MMAOpPattern,
BroadcastOpPattern, ReduceOpPattern, TransposeOpPattern,
SCFForOpPattern, SCFYieldOpPattern, MemRefViewOpPattern>(
converter, patterns.getContext());
SCFForOpPattern, SCFYieldOpPattern, MemRefViewOpPattern,
AtomicRMWOpPattern>(converter, patterns.getContext());
}

/// Create a pass that convert XeTile to XeGPU
Expand Down
44 changes: 43 additions & 1 deletion lib/Dialect/XeTile/Transforms/Blocking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,48 @@ class RewriteStoreScatterOp
}
};

// Rewrites a xetile.atomic_rmw on a big tile into multiple atomic_rmw ops on
// smaller blocks, using the block size chosen by BlockingAnalysis. Value and
// tile operands are packed into per-block pieces, one atomic_rmw is emitted
// per piece, and the results are unpacked back to the original type.
class RewriteAtomicRMWOp
    : public RewriteXeTileOp<xetile::AtomicRMWOp, BlockingAnalysis> {
public:
  using RewriteXeTileOp<xetile::AtomicRMWOp, BlockingAnalysis>::RewriteXeTileOp;

  mlir::LogicalResult
  matchAndRewrite(xetile::AtomicRMWOp op,
                  OpPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto value = op.getValue();
    auto valTy = value.getType();
    auto tile = op.getTile();
    auto tileTy = tile.getType();
    auto shape = tileTy.getShape();
    // The tile operand (operand index 1) carries the blocking decision.
    auto blockSize = analysis.getUseBlockSize(tile, op->getOpOperand(1));

    // Nothing to do if no blocking was requested or the op already matches.
    if (!blockSize || shape == blockSize.asArrayRef())
      return mlir::failure();

    auto convertedValTypes = convertTypes(valTy, blockSize.asArrayRef());
    auto convertedValues = addPackOp(value, convertedValTypes,
                                     blockSize.asArrayRef(), loc, rewriter);
    auto convertedTileTypes = convertTypes(tileTy, blockSize.asArrayRef());
    auto convertedTiles = addPackOp(tile, convertedTileTypes,
                                    blockSize.asArrayRef(), loc, rewriter);

    // One atomic_rmw per (value, tile) block. Each new op's result type is
    // simply the blocked value's own vector type (the previous code rebuilt
    // an identical VectorType via a shadowed `valTy`, which was redundant).
    llvm::SmallVector<mlir::Value> newOps;
    for (auto [v, t] : llvm::zip(convertedValues, convertedTiles)) {
      auto blockTy = mlir::cast<mlir::VectorType>(v.getType());
      auto newOp = rewriter.create<xetile::AtomicRMWOp>(loc, blockTy,
                                                        op.getKind(), v, t);
      newOps.push_back(newOp);
    }
    auto castOp = addUnpackOp(newOps, op.getType(), blockSize.asArrayRef(), loc,
                              rewriter);
    rewriter.replaceOp(op, castOp);
    return mlir::success();
  }
};

// rewrite a update_tile_offset op on big tile size into multiple
// update_tile_offset ops on smaller tile size.
class RewriteUpdateTileOffsetOp
Expand Down Expand Up @@ -1563,7 +1605,7 @@ void populateXeTileBlockingPatterns(mlir::RewritePatternSet &patterns,
Blocking::RewriteTileBroadcastOp, Blocking::RewriteTileTransposeOp,
Blocking::RewriteVectorizableOp, Blocking::RewriteSCFForOp,
Blocking::RewriteSCFYieldOp, Blocking::RewriteCreateMaskOp,
Blocking::RewriteCreateMaskOp>(patterns.getContext(), analysis);
Blocking::RewriteAtomicRMWOp>(patterns.getContext(), analysis);
}

// Lowers XeTile to blocked layout with high-dim vector
Expand Down
24 changes: 24 additions & 0 deletions lib/Dialect/XeTile/Transforms/BlockingAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,10 @@ class BlockingAnalysisImpl
mlir::ArrayRef<BlockingLattice *> operands,
mlir::ArrayRef<const BlockingLattice *> results);

void visitAtomicRMWOp(xetile::AtomicRMWOp op,
mlir::ArrayRef<BlockingLattice *> operands,
mlir::ArrayRef<const BlockingLattice *> results);

void visitUpdateTileOp(xetile::UpdateTileOffsetOp op,
mlir::ArrayRef<BlockingLattice *> operands,
mlir::ArrayRef<const BlockingLattice *> results);
Expand Down Expand Up @@ -359,6 +363,9 @@ mlir::LogicalResult BlockingAnalysisImpl::visitOperation(
if (auto scatterOp = mlir::dyn_cast<xetile::StoreScatterOp>(op))
visitStoreScatterOp(scatterOp, operands, results);

if (auto atomicrmwOp = mlir::dyn_cast<xetile::AtomicRMWOp>(op))
visitAtomicRMWOp(atomicrmwOp, operands, results);

if (auto tileMMAOp = mlir::dyn_cast<xetile::TileMMAOp>(op))
visitTileMMAOp(tileMMAOp, operands, results);

Expand Down Expand Up @@ -566,6 +573,23 @@ void BlockingAnalysisImpl::visitStoreScatterOp(
}
}

// Computes the default block size from the tile's element type and shape and
// issues that size as a blocking request on every operand lattice.
void BlockingAnalysisImpl::visitAtomicRMWOp(
    xetile::AtomicRMWOp op, mlir::ArrayRef<BlockingLattice *> operands,
    mlir::ArrayRef<const BlockingLattice *> results) {
  auto ty = op.getTile().getType();
  auto defaultSize = getDefaultSize(ty.getElementType(), ty.getShape());
  if (!defaultSize)
    return;

  // Both operands (value and tile) receive the same blocking request.
  for (size_t idx = 0; idx < operands.size(); ++idx) {
    BlockingLattice *lattice = operands[idx];
    auto request = BlockingRequests(defaultSize, op->getOpOperand(idx));
    propagateIfChanged(lattice, lattice->join(request));
  }
}

void BlockingAnalysisImpl::visitUpdateTileOp(
xetile::UpdateTileOffsetOp op, mlir::ArrayRef<BlockingLattice *> operands,
mlir::ArrayRef<const BlockingLattice *> results) {
Expand Down
4 changes: 2 additions & 2 deletions test/Conversion/XeGPUToVC/atomiclsc.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ module @gemm attributes {gpu.container_module} {
//CHECK: %[[c1_i16:.*]] = arith.constant 1 : i16
//CHECK: %[[c0_i32:.*]] = arith.constant 0 : i32
//CHECK: %[[c3_i8:.*]] = arith.constant 3 : i8
//CHECK: %[[cst_2:.*]] = arith.constant dense<0> : vector<16xi32>
//CHECK: %[[undef:.*]] = spirv.Undef : vector<16xi32>
//CHECK: %[[r3:.*]] = vector.bitcast %[[cst_0]] : vector<16xf32> to vector<16xi32>
//CHECK: %[[r4:.*]] = func.call @llvm.genx.lsc.xatomic.stateless.v16i32.v16i1.v16i64(
//CHECK-SAME: %[[cst]], %[[c19_i8]], %[[c1_i8]], %[[c1_i8]], %[[c1_i16]], %[[c0_i32]], %[[c3_i8]],
//CHECK-SAME: %[[c1_i8]], %[[c1_i8]], %[[cst]], %[[r2]], %[[r3]], %[[cst_2]], %[[c0_i32]], %[[cst_2]])
//CHECK-SAME: %[[c1_i8]], %[[c1_i8]], %[[cst]], %[[r2]], %[[r3]], %[[undef]], %[[c0_i32]], %[[undef]])
//CHECK-SAME: (vector<16xi1>, i8, i8, i8, i16, i32, i8, i8, i8, vector<16xi1>, vector<16xi64>, vector<16xi32>, vector<16xi32>, i32, vector<16xi32>) -> vector<16xi32>
%offsets = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex>
%2 = xegpu.create_tdesc %arg0, %offsets : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
Expand Down
14 changes: 14 additions & 0 deletions test/Dialect/XeTile/Transforms/Blocking/atomic_rmw.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-blocking %s -verify-diagnostics -o -| FileCheck %s

// Checks that --xetile-blocking splits a 32x64 scattered atomic_rmw into
// 128 smaller 1x16 atomic_rmw ops (32 rows x 4 column blocks).
gpu.module @test_kernel {
gpu.func @sg_atomic_rmw(%value: vector<32x64xf32>, %arg2: memref<65536xf32>) {
%c0 = arith.constant 0 : index
// All-true mask for the final store; every lane is written.
%cst = arith.constant dense<true> : vector<32x64xi1>
%cst_0 = arith.constant dense<1> : vector<32x64xindex>
%tile = xetile.init_tile %arg2, %cst_0 : memref<65536xf32>, vector<32x64xindex> -> !xetile.tile<32x64xf32, #xetile.tile_attr<scattered = true>>
//CHECK-COUNT-128: {{.*}} = xetile.atomic_rmw addf {{.*}}, {{.*}} : vector<1x16xf32>, !xetile.tile<1x16xf32, #xetile.tile_attr<scattered = true>> -> vector<1x16xf32>
%rmw = xetile.atomic_rmw addf %value, %tile : vector<32x64xf32>, !xetile.tile<32x64xf32, #xetile.tile_attr<scattered = true>> -> vector<32x64xf32>
xetile.store %rmw, %tile, %cst : vector<32x64xf32>, !xetile.tile<32x64xf32, #xetile.tile_attr<scattered = true>>, vector<32x64xi1>
gpu.return
}
}
51 changes: 51 additions & 0 deletions test/Integration/Dialect/XeGPU/atomic_rmw.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \
// RUN: --runner imex-cpu-runner -e main --entry-point-result=void \
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck
// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \
// RUN: --runner imex-cpu-runner -e main --entry-point-result=void \
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck

#scatter = #xegpu.scatter_tdesc_attr<chunk_size = 1 : i64>

// End-to-end check for xegpu.atomic_rmw addf: a buffer of sixteen 1.0s is
// atomically incremented by [0..15] on the device, so the printed buffer is
// [1..16]. Note the kernel's second argument (%out) is written but the host
// prints the in-place-updated input buffer (%a_gpu).
module @gemm attributes {gpu.container_module} {
func.func @test(%a: memref<16xf32>) -> memref<16xf32> attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%a_gpu = gpu.alloc host_shared () : memref<16xf32>
memref.copy %a, %a_gpu : memref<16xf32> to memref<16xf32>
%out = gpu.alloc host_shared () : memref<16xf32>
gpu.launch_func @test_kernel::@test_atomic_rmw blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%a_gpu: memref<16xf32>, %out : memref<16xf32>)
// Return the input buffer: atomic_rmw updated it in place.
return %a_gpu : memref<16xf32>
}

gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
gpu.func @test_atomic_rmw(%input: memref<16xf32>, %mem: memref<16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
// Per-lane addends 0.0 .. 15.0; all 16 lanes active.
%cst = arith.constant dense<[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0]> : vector<16xf32>
%mask = arith.constant dense<1> : vector<16xi1>
%offsets = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex>
%in_tdesc = xegpu.create_tdesc %input, %offsets : memref<16xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #scatter>
// Atomically add %cst to %input; also yields the updated values.
%atomic_rmw = xegpu.atomic_rmw addf %in_tdesc, %mask, %cst : !xegpu.tensor_desc<16xf32, #scatter>, vector<16xi1>, vector<16xf32> -> vector<16xf32>
%out_tdesc = xegpu.create_tdesc %mem, %offsets : memref<16xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #scatter>
xegpu.store %atomic_rmw, %out_tdesc, %mask : vector<16xf32>, !xegpu.tensor_desc<16xf32, #scatter>, vector<16xi1>
gpu.return
}
}

func.func @main() attributes {llvm.emit_c_interface} {
%a = memref.alloc() : memref<16xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c1_f32 = arith.constant 1.0 : f32
// Initialize the host buffer with sixteen 1.0s.
scf.for %i = %c0 to %c16 step %c1 {
memref.store %c1_f32, %a[%i] : memref<16xf32>
}

%B = call @test(%a) : (memref<16xf32>) -> memref<16xf32>
%cast = memref.cast %B : memref<16xf32> to memref<*xf32>
//CHECK: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
call @printMemrefF32(%cast) : (memref<*xf32>) -> ()
return
}

func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
}
Loading
Loading