[MLIR][XeGPU] Refine XeGPU definitions (#100763)
This PR makes the following changes/fixes to the XeGPU definitions:
- Fix the type print format for atomic_rmw
- Remove 2D support for MaskType
- Update the LoadNd definition
   - Add 1D TensorDesc support (see the sketch after this list)
   - Replace the vnni_axis attribute with a packed attribute
- Update the DPAS op definition, limiting A to a 2D vector and B to either a 2D or 3D vector
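A minimal sketch of the new 1D path, assuming the `create_nd_tdesc`/`load_nd` assembly shown in the examples in XeGPUOps.td below; the value names and shapes are hypothetical:

```mlir
// Hypothetical: create a 1D block descriptor and load it into a 1D vector.
%src = memref.alloc() : memref<1024xf32>
%c0 = arith.constant 0 : index
%tdesc = xegpu.create_nd_tdesc %src[%c0] : memref<1024xf32> -> !xegpu.tensor_desc<16xf32>
%vec = xegpu.load_nd %tdesc : !xegpu.tensor_desc<16xf32> -> vector<16xf32>
```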
chencha3 authored Aug 2, 2024
1 parent e96687a commit 6c783e1
Showing 5 changed files with 146 additions and 90 deletions.
80 changes: 43 additions & 37 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -53,47 +53,56 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
let summary = "Create nd-tensor descriptor operation";
let description = [{
The "create_nd_tdesc" operation creates a TensorDescType which represents
a sub-view of a 2D memory region (It can be extended to support n-D memory
region if needed in future). Elements in the subview continuous in each
dimension. It encodes the following important information for supporting
Intel hardware features:

* source: an object representing (starting address/pointer of) a 2D memory region.
It can be either a 2D memref object, or simply a pointer represented by uint64_t type.
for the later case, the shape and layout information of the 2D memory region should
be explicitly passed via `shape` and `strides` parameters.
* offsets: two index values represents offsets from the "source" at the each dimension
at which the subview of the target memory will be created. It is encoded via two
variables, including "offsets" and "const_offsets", such that it can
accept various forms, such as, operands (e.g., [%c0, %c]) and attributes (e.g., [2, 4]).
* shape: the shape information of the memory region pointed by the "source". It is
typically encoded via the MemRefType of the source, e.g., memref<4096x4096xf16>.
a sub-view of a 1D/2D memory region inside the one or two innermost dimensions
of the source. (It can be extended to support n-D memory region if needed in
future). Elements in the subview are contiguous in each dimension. It encodes the
following important information for supporting Intel hardware features:

* source: an object representing (starting address/pointer of) a memory region.
It can be either a memref object, or simply a pointer represented by uint64_t type.
For the case of dynamic memrefs or pointer, the shape and layout information of the
memory region should be explicitly passed via `shape` and `strides` parameters.

* offsets: index values representing offsets from the "source" in each dimension
at which the subview of the target memory will be created. It is encoded via
"offsets" and "const_offsets", such that it can accept various forms, such as,
operands (e.g., [%c0, %c]) and attributes (e.g., [2, 4]).

* shape: the shape information of the memory region pointed by the "source". It is
typically encoded via the MemRefType of the source, e.g., memref<4096x4096xf16>.
But if "source" is simply a pointer represented as uint64_t type, or a memref
type without shape information e.g., memref<?x?xf16>, the shape information has
to be explicitly passed via the "shape" and "const_shape" arguments.

* strides: the strides of the memory region pointed by the "source". Similar to shape,
it is typically encoded via the MemRefType of the source too. But if "source" is
simply a pointer represented as uint64_t type, or a memref type without shape
information e.g., memref<?x?xf16>, the strides information has to be explicitly
passed via the "strides" and "const_strides" argument.

Example 1 (suppose the tensor shape inferred by the compiler is 8x16):
```mlir
%0 = memref.alloc() : memref<1024x1024xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%1 = xegpu.create_nd_tdesc %0[%c0, %c0]: memref<1024x1024xf32> -> TensorDesc<8x16xf32>
```

Example 2 (suppose the tensor shape inferred by the compiler is 8x16):
```mlir
%0 = memref.alloc(%h, %w) : memref<?x?xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%1 = xegpu.create_nd_tdesc %0[%c0, %c0], [%h, %w], [%w, %c1]: memref<?x?xf32> -> TensorDesc<8x16xf32>
```

Example 3 (suppose the tensor shape inferred by the compiler is 8x16):
```mlir
%0 = ... : ui64
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%1 = xegpu.create_nd_tdesc %0[%c0, %c0], [%h, %w], [%w, %c1]: ui64 -> TensorDesc<8x16xf32>
```
}];

let arguments = (ins
@@ -219,7 +228,7 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
memory regions to each level of the cache based on their cache policy.

Example:
```
```mlir
xegpu.prefetch_nd %tdesc {l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<cached>,
l3_hint = #xegpu.cache_hint<cached>}
@@ -245,8 +254,7 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
}


def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [AllElementTypesMatch<["value", "TensorDesc"]>,
AllElementCountsMatch<["value", "TensorDesc"]>]> {
def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [AllElementTypesMatch<["value", "TensorDesc"]>]> {
let summary = "loads a n-D block from memory (represented by TensorDesc)"
"to registers (represented by vector)";
let description = [{
@@ -263,7 +271,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [AllElementTypesMatch<["value", "Tensor
same time.

Example:
```
```mlir
xegpu.load_nd %1 {transpose = [1, 0],
l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<uncached>,
@@ -275,7 +283,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [AllElementTypesMatch<["value", "Tensor
}];

let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
OptionalAttr<I64Attr>: $vnni_axis,
OptionalAttr<UnitAttr>: $packed,
OptionalAttr<DenseI64ArrayAttr>: $transpose,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
@@ -309,7 +317,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [AllShapesMatch<["value", "TensorDesc
Corresponding cache hint attribute will be masked.

Example:
```
```mlir
xegpu.store_nd %3, %2 {l1_hint = #xegpu.cache_hint<uncached>,
l2_hint = #xegpu.cache_hint<write_back>,
l3_hint = #xegpu.cache_hint<write_through>}
@@ -407,21 +415,21 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> {
elements accessed for each offset, default is 1.

Example 1. It assumes subgroup size is 4, and accesses a[0], a[16], a[32], a[64]
```
```mlir
%a = memref.alloc() : memref<1024xf32>
%1 = xegpu.create_tdesc %a[0, 16, 32, 64]: memref<1024xf32> -> TensorDesc<4xf32>
```

Example 2. It assumes subgroup size is 4, and each workitem access 8 elements.
It will access totally 32 data elements: a[0:7], a[16:23], a[32:39], a[64:71]
```
```mlir
%0 = memref.alloc() : memref<1024xf32>
%1 = xegpu.create_tdesc %0[0, 16, 32, 64] {chunk_size = 8}: memref<1024xf32> -> TensorDesc<4x8xf32>
```

Example 3. It is similar to Example 2, but there is some overlaps among workitems.
It accesses: a[0:7], a[4:11], a[8:15], a[12:19]
```
```mlir
%0 = memref.alloc() : memref<1024xf32>
%1 = xegpu.create_tdesc %0[0, 4, 8, 12] {chunk_size = 8}: memref<1024xf32> -> TensorDesc<4x8xf32>
```
@@ -480,7 +488,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
it works on scattered TensorDesc instead.

Example:
```
```mlir
xegpu.prefetch %tdesc {l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<cached>,
l3_hint = #xegpu.cache_hint<cached>}
@@ -520,7 +528,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllRanksMatch<["value", "TensorDesc"]
addresses/offsets as long as they are masked. It applies to slots of SIMD lanes.

Example:
```
```mlir
%2 = xegpu.load %1, %0 {transpose = [1, 0],
l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<uncached>,
@@ -572,7 +580,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [AllShapesMatch<["value", "TensorDe
It has similar semantic to `load_gather`.

Example:
```
```mlir
%3 = xegpu.store %0, %1, %2 {l1_hint = #xegpu.cache_hint<uncached>,
l2_hint = #xegpu.cache_hint<write_back>,
l3_hint = #xegpu.cache_hint<write_through>}
@@ -621,7 +629,7 @@ def XeGPU_UpdateOffsetOp: XeGPU_Op<"update_offset",
shifts for each work-item.

Example:
```
```mlir
%2 = xegpu.update_offset %1, [32, 32, 32, 32]
: !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
```
@@ -668,14 +676,12 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>]
data type, the matrices are `A: vector<8x16xf16>`, `B: vector<16x16xf16>`,
and `C/D: vector<8x16xf32>`. Besides the matrix size requirements, DPAS
also requires A and B to be loaded with the required data layout. Specially,
VNNI layout is required for B operand. It is achieved via setting `vnni_axis = 0`
of the corresponding `load_nd` operator. To keep both operands as 3D vector,
operand A is loaded via setting `vnni_axis = 1` without impacting the
physical layouts change in register. Due to the VNNI transformation, A and B operands
are represented as 3D vector, with the last dimension representing the VNNI factor,
which is computed as `32/bit_width_of_elem_type`. Therefore, `A: vector<8x16xf16>`
is represented as `A: vector<8x8x2xf16>`, and `B: vector<16x16xf16>` is
represented as `B: vector<8x16x2xf16>`.

VNNI layout is required for the B operand. It is achieved by adding the `packed`
attribute to the `load_nd` operator. Due to the VNNI transformation, the B operand
can be represented as a 3D vector, with the last dimension representing the VNNI
factor, which is computed as `32/bit_width_of_elem_type`. Thus, `B: vector<16x16xf16>`
can be represented as `B: vector<8x16x2xf16>`.
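
For illustration, a minimal sketch assuming f16 operands and the `load_nd`/`dpas`
assembly forms used elsewhere in this dialect (value names and descriptor shapes
are hypothetical):

```mlir
// A stays 2D; B is loaded with the `packed` attribute and becomes 3D,
// with VNNI factor 2 for f16 (32 / 16).
%a = xegpu.load_nd %ta : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%b = xegpu.load_nd %tb {packed} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16>
%d = xegpu.dpas %a, %b : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
```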

Note: on PVC, the hardware can perform load with VNNI transformation when data
element type is 16-bit or lower precision, taking 2 or 4 elements from
@@ -739,7 +745,7 @@ def XeGPU_AtomicRMWOp: XeGPU_Op<"atomic_rmw", [Pure,

let assemblyFormat = [{
$kind $tensorDesc `,` $mask `,` $value attr-dict `:`
type($tensorDesc) `,` type($mask) `,` type($value) `->` type($result)
qualified(type($tensorDesc)) `,` type($mask) `,` type($value) `->` type($result)
}];
}
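
With `qualified(...)`, the TensorDesc type is now printed in its fully qualified
form. A hypothetical textual instance (the `addf` kind, shapes, and attribute
spelling are assumptions for illustration):

```mlir
%r = xegpu.atomic_rmw addf %desc, %mask, %val
     : !xegpu.tensor_desc<16xf32, #xegpu.tdesc_attr<scattered = true>>, vector<16xi1>, vector<16xf32> -> vector<16xf32>
```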

4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -16,10 +16,10 @@ include "mlir/IR/BuiltinTypes.td"
def XeGPU_IntType: AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI64, UI1, UI8, UI16, UI32, UI64]>;
def XeGPU_FloatType: AnyTypeOf<[F16, F32, F64, BF16, TF32]>;
def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
def XeGPU_BaseAddrType: AnyTypeOf<[MemRefRankOf<[XeGPU_ScalarType], [1, 2]>, UI64, UI32, I64, I32]>;
def XeGPU_BaseAddrType: AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64, UI32, I64, I32]>;
def XeGPU_DpasOpType: VectorOfRankAndType<[2, 3], [XeGPU_ScalarType]>;
def XeGPU_OffsetType: VectorOfRankAndType<[1], [Index]>;
def XeGPU_MaskType: AnyTypeOf<[VectorOfRankAndType<[1,2], [I1]>, I1]>;
def XeGPU_MaskType: AnyTypeOf<[VectorOfRankAndType<[1], [I1]>, I1]>;
def XeGPU_ValueType: AnyTypeOf<[VectorOfRankAndType<[1,2,3,4], [XeGPU_ScalarType]>, XeGPU_ScalarType]>;
def XeGPU_Vector2DType: VectorOfRankAndType<[2], [XeGPU_ScalarType]>;
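
To make the MaskType change concrete, a hypothetical scattered-load sketch with a
1D i1 mask (the `xegpu.load` assembly is assumed from the examples above; 2D masks
are no longer accepted):

```mlir
%mask = arith.constant dense<true> : vector<4xi1>
%v = xegpu.load %tdesc, %mask
     : !xegpu.tensor_desc<4xf32, #xegpu.tdesc_attr<scattered = true>>, vector<4xi1> -> vector<4xf32>
```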

64 changes: 38 additions & 26 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -122,7 +122,7 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,

LogicalResult CreateNdDescOp::verify() {
auto rank = (int64_t)getMixedOffsets().size();
bool invalidRank = (rank != 2);
bool invalidRank = false;
bool invalidElemTy = false;

// check source type matches the rank if it is a memref.
@@ -133,17 +133,21 @@ LogicalResult CreateNdDescOp::verify() {
invalidElemTy |= memrefTy.getElementType() != getElementType();
}

// check result type matches the rank
invalidRank = (getType().getRank() != rank);

// mismatches among shape, strides, and offsets are
// already handled by OffsetSizeAndStrideOpInterface.
// So they are not checked here.
if (invalidRank)
return emitOpError(
"Expecting the rank of shape, strides, offsets, "
"source memref type (if source is a memref) and TensorDesc "
"should match with each other. They currenlty are 2D.");
"Expecting the rank of shape, strides, offsets, and source (if source "
"is a memref) should match with each other.");

// check result TensorDesc rank
invalidRank = (getType().getRank() > 2 || getType().getRank() > rank);

if (invalidRank)
return emitOpError(
"Expecting the TensorDesc rank is up to 2 and not greater than the "
"ranks of shape, strides, offsets or the memref source.");

if (invalidElemTy)
return emitOpError("TensorDesc should have the same element "
@@ -182,8 +186,8 @@ LogicalResult LoadNdOp::verify() {
auto tdescTy = getTensorDescType();
auto valueTy = getType();

if (tdescTy.getRank() != 2)
return emitOpError("Expecting a 2D TensorDesc.\n");
if (tdescTy.getRank() > 2)
return emitOpError("Expecting a 1D/2D TensorDesc.\n");

if (tdescTy.getScattered())
return emitOpError("Expects a non-scattered TensorDesc.\n");
@@ -206,17 +210,28 @@

if (getTranspose()) {
auto trans = getTranspose().value();
if (tdescShape.size() >= trans.size())

// Make sure the transpose value is valid.
bool valid = std::all_of(trans.begin(), trans.end(), [&](int t) {
return t >= 0 && t < tdescTy.getRank();
});

if (valid)
transpose(trans, tdescShape);
else
emitWarning("Invalid transpose attr. It is ignored.");
}

if (getVnniAxis()) {
auto axis = getVnniAxis().value();
auto vnni_factor = valueShape.back();
tdescShape[axis] /= vnni_factor;
tdescShape.push_back(vnni_factor);
if (getPacked()) {
if (tdescTy.getRank() == 2) {
const int axis = 0;
auto vnni_factor = valueShape.back();
tdescShape[axis] /= vnni_factor;
tdescShape.push_back(vnni_factor);
} else {
return emitWarning("Invalid Packed Attr. It is ignored (available for 2D "
"TensorDesc only).");
}
}

if (array_len > 1) {
@@ -239,8 +254,8 @@ LogicalResult StoreNdOp::verify() {
auto dstTy = getTensorDescType(); // Tile
auto valTy = getValueType(); // Vector

if (dstTy.getRank() != 2)
return emitOpError("Expecting a 2D TensorDesc.\n");
if (dstTy.getRank() > 2)
return emitOpError("Expecting a 1D/2D TensorDesc.\n");

if (dstTy.getScattered())
return emitOpError("Expects a non-scattered TensorDesc.\n");
@@ -413,18 +428,15 @@ LogicalResult DpasOp::verify() {
int64_t lhsRank = getLhsType().getRank();
int64_t rhsRank = getRhsType().getRank();

if (lhsRank != rhsRank || lhsRank != 3)
return emitOpError(
"lhs and rhs rank does not match for dpas op, or their rank is not 3.");

if (getAcc() && getAccType() != getResultType())
return emitOpError("Accumulator and Result for dpas op should have the "
"same type (both shape and element type).");
if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3))
return emitOpError("expecting lhs to be a 2D vector, and rhs to be either "
"2D or 3D (packed) vector.");

auto lhsShape = getLhsType().getShape();
auto rhsShape = getRhsType().getShape();
if (lhsShape[1] != rhsShape[0] || lhsShape[2] != rhsShape[2])
return emitOpError("K-dimension or vnni-factor mismatch.");
auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
if (bK != lhsShape[1])
return emitOpError("K-dimension mismatch.");

return success();
}