diff --git a/docs/rfcs/XeTile.md b/docs/rfcs/XeTile.md
index 1a6af158a..9ac574498 100644
--- a/docs/rfcs/XeTile.md
+++ b/docs/rfcs/XeTile.md
@@ -25,15 +25,12 @@ XeTile provides a middle-level abstraction for matmul operation and sits between
|update_tile_offset | operation ::=xetile.update_tile_offset $tile, $delta0, $delta1: type($tile) | %tile_updated = xetile.update_tile_offset %tile, %offset_x, %offset_y : tile<32x64xbf16> |
|prefetch_tile | operation ::=xetile.prefetch_tile $tile, attr-dict: type($tile) | xetile.prefetch_tile %coop_tile: tile<16x32xbf16> |
|tile_mma | operation ::=xetile.tile_mma $matA, $matB, $matC attr_dict: type($matC), type($matA), type($matB)-> type($res) | %vector_c = xetile.tile_mma %vector_a, %vector_b, %vector_c : vector<64x32xbf16>, vector<32x128xbf16>, vector<64x128xfloat> into vector<64x128xfloat> |
-|atomic_rmw_tile| operation ::=xetile.atomic_rmw_tile \<$kind\>, $vec, $tile: type($vec), type($tile) -> type($res) | %vector_a = xetile.atomic_rmw_tile \ %value, %tile: vector<8x16xbf16>, tile<8x16xbf16> to vector<8x16xbf16> |
-|tile_transpose | operation ::=xetile.tile_transpose $vec $permuation_dims attr_dict: type($vec) -> type($res) | %vector_a = xetile.tile_transpose %vector_b [1, 0]: vector<64x32xfloat> into vector<32x64xfloat> |
-|tile_reduce | operation ::=xetile.tile_reduce \<$kind\> $src $reduction_dims attr_dict: type($value) -> type($res) | %vector_a = xetile.tile_reduce \ %vector_b [1]: vector<64x32xfloat> into vector<64x1xfloat> |
-|tile_broadcast | operation ::=xetile.tile_broadcast $src $broadcast_dims attr_dict: type($value) -> type($res) | %vector_a = xetile.tile_broadcast %vector_b[0]: vector<1x32xfloat> into vector<64x32xfloat> |
+|atomic_rmw| operation ::=xetile.atomic_rmw \<$kind\>, $vec, $tile: type($vec), type($tile) -> type($res) | %vector_a = xetile.atomic_rmw \<addf\> %value, %tile: vector<8x16xbf16>, tile<8x16xbf16> to vector<8x16xbf16> |
+|transpose | operation ::=xetile.transpose $vec attr_dict: type($vec) -> type($res) | %vector_a = xetile.transpose %vector_b [1, 0]: vector<64x32xfloat> into vector<32x64xfloat> |
+|reduction | operation ::=xetile.reduction \<$kind\> $src attr_dict: type($value) -> type($res) | %vector_a = xetile.reduction \<add\> %vector_b [1]: vector<64x32xfloat> into vector<64x1xfloat> |
+|broadcast | operation ::=xetile.broadcast $src attr_dict: type($value) -> type($res) | %vector_a = xetile.broadcast %vector_b [0]: vector<1x32xfloat> into vector<64x32xfloat> |
-
-*Operations only used to support internal lowering.
-
-**OP name convention: init_tile, load_tile, prefetch_tile, store_tile, and update_offset operates on the tile type and involves memory access. tile_xxx operates on vector data type only.
+**init_tile, load_tile, prefetch_tile, store_tile, atomic_rmw, and update_tile_offset operate on the tile type and involve memory access. The other operations work on vector types only.

To create a 2D Tile memory descriptor, the user needs to set up a tile (init_tile) describing a 2D region within the global memory. Setting up a tile requires the shape of the parent tile and the size of the underlying physical memory buffer, known as the base matrix. The base matrix must be 2D and must be contiguous. The XeTile takes the base matrix address pointer, shape, and strides, plus the tile's offsets and shape. Offsets, strides, and shapes are given for two dimensions, in numbers of elements. base_strides[0] describes the number of elements between two rows (the width of the underlying physical memory buffer), and base_strides[1] must be 1, since the innermost dimension of the base matrix must be contiguous. The current version only supports 2D memref with a row-major layout.
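For illustration, a minimal sketch of creating and loading such a tile, written in the same pseudo-syntax as the appendix examples later in this document (the buffer `A` and all SSA names here are hypothetical):

```
// A[M][K] is a row-major bf16 base matrix; the tile starts at element [i, k],
// the base shape is [M, K], and the strides are [K, 1] (the innermost stride must be 1).
%a  = init_tile &A, [i, k], [M, K], [K, 1] : tile<64x32xbf16>
%va = load_tile %a : vector<64x32xbf16>
```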
@@ -132,7 +129,7 @@ A `tile_mma` variant without vector_c initialization.
```

-`atomic_rmw_tile` atomically reads, modifies, and writes back data to the memory specified by the tile.
+`atomic_rmw` atomically reads, modifies, and writes back data to the memory specified by the tile.
```mlir
  %ret_value = xetile.atomic_rmw %value, %tile:
@@ -141,17 +138,29 @@ A `tile_mma` variant without vector_c initialization.
xetile.atomic_rmw reuses the arith dialect attribute, mlir::arith::AtomicRMWKindAttr.
-`tile_transpose` transpose a 2D vector. It has the same semantics as the vector.transpose, but restricts the vector dimension to 2D.
+`transpose` transposes a 2D vector. It has the same semantics as vector.transpose, but restricts the vector to 2D.
+```mlir
+ %vector_a = xetile.transpose [1, 0] %vector_b: vector<64x32xfloat> into vector<32x64xfloat>
+```
+`reduction` performs a reduction operation over a 2D vector. The result is a 2D vector with the reduction dimension reduced to 1. It has the same semantics as vector.multi_reduction, but restricts the vector to 2D. The reduction kinds are the same as those defined by the vector dialect's multi_reduction: add/mul/minsi/minui/maxsi/maxui/and/or/xor for integers, and add/mul/minnumf/maxnumf/minimumf/maximumf for floats.
```mlir
- %vector_a = xetile.tile_transpose [1, 0] %vector_b: vector<64x32xfloat> into vector<32x64xfloat>
+ %vector_a = xetile.reduction %vector_b [1]: vector<64x32xfloat> into vector<64x1xfloat>
```
-`tile_reduce` performs a reduction operation over a 2D vector. The result is a 2D vector with the size of reduced axis being 1. It has the same semantics as the vector.multi_dimesnion, but restricts the vector dimension to 2D. The reduce operation are the same as vector.multi_dimension:add/mul/minsi/minui/maxsi/maxui /and/or/xor for integers, and add/mul/minnumf/maxnumf/minimumf /maximumf for floats.
+The `reduction_size` attribute supports reduction over a given size, which may be any divisor of the size of the dimension being reduced. This allows the user to partially reduce a vector without first reshaping it to a higher rank just for the reduction. With `reduction_size`, the reduction is done over contiguous elements along the reduction dimension.
```mlir
- %vector_a = xetile.tile_reduce %vector_b [1]: vector<64x32xfloat> into vector<64x1xfloat>
+ %vector_a = xetile.reduction %vector_b [0] {$reduction_size=32}: vector<64x64xfloat> into vector<2x64xfloat>
+ %vector_a = xetile.reduction %vector_b [1] {$reduction_size=32}: vector<64x64xfloat> into vector<64x2xfloat>
```
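For intuition, a partial reduction with `reduction_size` combines each contiguous group of `reduction_size` elements along the reduction dimension into one output element. A rough upstream-MLIR equivalent of the second example above — a sketch for illustration only, not the actual lowering, with the `<add>` kind and SSA names chosen arbitrarily — is:

```mlir
// Hypothetical equivalent of:
//   xetile.reduction <add> %vector_b [1] {$reduction_size=32}
//       : vector<64x64xfloat> into vector<64x2xfloat>
// Output element [i][j] accumulates source elements [i][32*j] .. [i][32*j+31].
%grouped = vector.shape_cast %vector_b : vector<64x64xf32> to vector<64x2x32xf32>
%zero    = arith.constant dense<0.0> : vector<64x2xf32>
%partial = vector.multi_reduction <add>, %grouped, %zero [2] : vector<64x2x32xf32> to vector<64x2xf32>
```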
-`tile_broadcast` broadcast from 1D vector to a 2D vector.
+`broadcast` broadcasts a 2D vector to a 2D vector. The source vector's broadcast dimension must be 1.
```mlir
- %vector_a = xetile.tile_broadcast %vector_b [0]: vector<1x32xfloat> into vector<64x32xfloat>
+ %vector_a = xetile.broadcast %vector_b [0]: vector<1x32xfloat> into vector<64x32xfloat>
+```
+
+The `broadcast_size` attribute allows the broadcast dimension to be larger than 1. The broadcast operation "stretches" the input vector to match the output shape defined by `broadcast_size`: along the broadcast dimension, each source element is replicated to fill a contiguous block of `broadcast_size` elements in the output vector.
+
+```mlir
+ %vector_a = xetile.broadcast %vector_b [0] {$broadcast_size=32}: vector<2x64xfloat> into vector<64x64xfloat>
+ %vector_a = xetile.broadcast %vector_b [1] {$broadcast_size=32}: vector<64x2xfloat> into vector<64x64xfloat>
```

## support for load_gather and store_scatter (experimental)
@@ -203,10 +212,10 @@ The proposal is to attach the `xetile.wg_map` attribute to the vector based XeTi
| Ops | Syntax | Example |
| :--- | :---- | :--- |
|tile_mma | operation ::= xetile.tile_mma $matA, $matB, $matC attr_dict: type($matA), type($matB), type($matC)-> type($res) | %vector_c = xetile.tile_mma %vector_a, %vector_b, %vector_c {#mp_c} : vector<64x32xbf16>, vector<32x128xbf16>, vector<64x128xfloat> into vector<64x128xfloat> |
-|tile_transpose | operation ::= xetile.tile_transpose $permuation_dims attr_dict $vec : type($vec) -> type($res) | %vector_a = xetile.tile_transpose %vector_b {#mp_a}: vector<64x32xfloat> into vector<32x64xfloat> |
-|tile_reduce | operation ::= xetile.tile_reduce $kind $src $reduction_dims attr_dict: type($value) -> type($res) | %vector_a = xetile.tile_reduce %vector_b [1] {#mp_a}: vector<64x32xfloat> into vector<64x1xfloat> |
-|tile_broadcast | operation ::= xetile.tile_broadcast $src $broadcast_dims attr_dict : type($value) -> type($res) | %vector_a = xetile.tile_broadcast %vector_b [0] {#mp_a}: vector<1x32xfloat> into vector<64x32xfloat> |
-|tile_conv_layout | operation ::= xetile.conv_layout $src attr_dict: type($value) -> type($res) | %vector_a = xetile.tile_conv_layout %vector_b {#mp_a} : vector<256x256xfloat> into vector<256x256xfloat> |
+|transpose | operation ::= xetile.transpose attr_dict $vec : type($vec) -> type($res) | %vector_a = xetile.transpose %vector_b {#mp_a}: vector<64x32xfloat> into vector<32x64xfloat> |
+|reduction | operation ::= xetile.reduction $kind $src attr_dict: type($value) -> type($res) | %vector_a = xetile.reduction %vector_b [1] {#mp_a}: vector<64x32xfloat> into vector<64x1xfloat> |
+|broadcast | operation ::= xetile.broadcast $src attr_dict : type($value) -> type($res) | %vector_a = xetile.broadcast %vector_b [0] {#mp_a}: vector<1x32xfloat> into vector<64x32xfloat> |
+|convert_layout | operation ::= xetile.convert_layout $src attr_dict: type($value) -> type($res) | %vector_a = xetile.convert_layout %vector_b {#mp_a} : vector<256x256xfloat> into vector<256x256xfloat> |

With the `wg_map` attribute attached for the output vector, `tile_mma` does a matrix multiplication at a work group level vector.
```mlir
@@ -223,61 +232,69 @@ The `wg_map` attribute of input vector operands can be derived from the wg_map_d
#wg_map_c = #xetile.wg_map //wg_map for %vector_c
```
-`tile_reduce` with `wg_map` does the reduction over a workgroup level vector.
+`reduction` with `wg_map` performs the reduction over a workgroup-level vector.
```mlir
#wg_map_a = #xetile.wg_map
- %vector_a = xetile.tile_reduce %vector_b [1] {#wg_map_a}: vector<256x128xfloat> into vector<256x1xfloat>
+ %vector_a = xetile.reduction %vector_b [1] {#wg_map_a}: vector<256x128xfloat> into vector<256x1xfloat>
```
+The `reduction_size` attribute is used to support partial reduction.
+```mlir
+ #wg_map_a = #xetile.wg_map
+ #wg_map_b = #xetile.wg_map
+ %vector_a = math.exp %input {#wg_map_a} : vector<256x128xf32>
+ %vector_b = xetile.reduction %vector_a [0] {$reduction_size = [32]} {#wg_map_b}: vector<256x128xfloat> into vector<8x128xfloat>
+```
+
The `wg_map` attribute of the input vector can be derived from the wg_map_a. sg_layout must be the same, sg_data for the dimension being reduced must be the same as the input vector, and the other dimension must be the same as in wg_map_a. The input vector's wg_map attribute may be retrieved from its producer op, and the retrieved attribute must be consistent with the derived one. Below is the derived wg_map for the input vector in the example above.
```mlir
#wg_map_b = #xetile.wg_map //wg_map for %vector_b
```
-`tile_broadcast` with `wg_map` attribute broadcast at workgroup level.
+`broadcast` with the `wg_map` attribute broadcasts at the workgroup level.
```mlir
#wg_map_a = #xetile.wg_map
- %vector_a = xetile.tile_broadcast %vector_b [1] {#wg_map_a}: vector<256x1xfloat> into vector<256x256xfloat>
+ %vector_a = xetile.broadcast %vector_b [1] {#wg_map_a}: vector<256x1xfloat> into vector<256x256xfloat>
```
The `wg_map` attribute of the input vector can be derived from the wg_map_a. sg_layout must be the same, sg_data for the dimension being broadcast must be "1", and the other dimension must be the same as in wg_map_a. The input vector's wg_map attribute may be retrieved from its producer op, and the retrieved attribute must be consistent with the derived one. Below is the derived wg_map for the input vector in the example above.
```mlir
#wg_map_b = #xetile.wg_map //wg_map for %vector_b
```
-`tile_transpose` with `wg_map` attribute transpose a workgroup level vector.
+`transpose` with the `wg_map` attribute transposes a workgroup-level vector.
```mlir
#wg_map_a = #xetile.wg_map
- %vector_a = xetile.tile_transpose %vector_b {#wg_map_a}: vector<512x128xfloat> into vector<128x512xfloat>
+ %vector_a = xetile.transpose %vector_b {#wg_map_a}: vector<512x128xfloat> into vector<128x512xfloat>
```
The `wg_map` attribute of the input vector can be derived from the wg_map_a. The two dimensions of sg_layout and sg_data must be swapped. The input vector's wg_map attribute may be retrieved from its producer op, and the retrieved attribute must be consistent with the derived one. Below is the derived wg_map for the input vector in the example above.
```mlir
#wg_map_b = #xetile.wg_map //wg_map for %vector_b
```
-The tile_transpose can be implemented by saving and restoring from the shared local memory. It can be conceptually viewed as a composition of two operations: 1) store the vector to to shared memory with the #wg_map_b mapping assuming row_major and 2) use wg_map_a mapping to load the data from shared memory to vector assuming column_major. To support this, we relax the restriction of tile_load and tile_store so that they can load 2D from share local memory.
+The transpose can be implemented by saving to and restoring from shared local memory. It can be conceptually viewed as a composition of two operations: 1) store the vector to shared memory with the #wg_map_b mapping assuming row_major, and 2) load the data from shared memory back into a vector with the wg_map_a mapping assuming column_major. To support this, we relax the restriction on load_tile and store_tile so that they can load 2D data from shared local memory.
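A conceptual sketch of this two-step lowering, written in the pseudo-syntax of the appendix examples (the SLM scratch buffer, the `order` attribute placement, and all names below are illustrative assumptions, not a normative lowering):

```
// Step 1: store the 512x128 vector to a shared-local-memory scratch buffer,
//         distributed across subgroups by #wg_map_b, row_major.
F32 SLM[512][128];                                         // hypothetical SLM scratch buffer
%t_store = init_tile &SLM, [0, 0], [512, 128], [128, 1] : tile<512x128xf32>
store_tile %t_store, %vector_b

// Step 2: re-read the same buffer with the #wg_map_a mapping, column_major,
//         so each subgroup receives its piece of the transposed 128x512 vector.
%t_load  = init_tile &SLM, [0, 0], [512, 128], [128, 1], order=[0, 1] : tile<128x512xf32>
%vector_a = load_tile %t_load : vector<128x512xf32>
```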
An optimization is to analyze the load op which produces %vector_b and carefully arrange its mapping so that each subgroup thread loads its corresponding subgroup tile; the transpose can then either be folded into the load op or done as an in-register transpose.

-`tile_conv_layout` with `wg_map` attributes remaps the workgroup level vector to subgroup threads. The second `wg_map` attribute is optional and describes the input operand. The input vector's wg_map attribute may be retrieved from its producer op, and the retrieved attribute must be consistent with the second `wg_map` attribute if it is present.
+`convert_layout` with `wg_map` attributes remaps the workgroup-level vector to subgroup threads. The second `wg_map` attribute is optional and describes the input operand. The input vector's wg_map attribute may be retrieved from its producer op, and the retrieved attribute must be consistent with the second `wg_map` attribute if it is present.

Example with the wg_map specified for both input and output operands.
```mlir
#wg_map_b = #xetile.wg_map // used for cooperative load/prefetch
#wg_map_a = #xetile.wg_map // used as mma's input matrix A
- %vector_a = xetile.tile_conv_layout %vector_b {#wg_map_a #wg_map_b}: vector<256x256xfloat> into vector<256x256xfloat>
+ %vector_a = xetile.convert_layout %vector_b {#wg_map_a #wg_map_b}: vector<256x256xfloat> into vector<256x256xfloat>
```
Example without the wg_map specified for the input operand.
```mlir
#wg_map_a = #xetile.wg_map // used as mma's input matrix A
- %vector_a = xetile.tile_conv_layout %vector_b {#wg_map_a}: vector<256x256xfloat> into vector<256x256xfloat>
+ %vector_a = xetile.convert_layout %vector_b {#wg_map_a}: vector<256x256xfloat> into vector<256x256xfloat>
```
-The tile_conv_layout could be implemented by saving and restoring from the shared local memory. It can be conceptually viewed as a composition of two operations: 1) store the vector to to shared memory with the #wg_map_b mapping assuming row_major and 2) use wg_map_a mapping to load the data from shared memory to vector assuming same row_major. To support this, we relax the restriction of tile_load and tile_store so that they can load 2D from share local memory.
+The convert_layout could be implemented by saving to and restoring from shared local memory. It can be conceptually viewed as a composition of two operations: 1) store the vector to shared memory with the #wg_map_b mapping assuming row_major, and 2) load the data from shared memory back into a vector with the wg_map_a mapping assuming the same row_major layout. To support this, we relax the restriction on load_tile and store_tile so that they can load 2D data from shared local memory.

## Alternative design considerations

The alternative design of tile data type is to reuse the memref data type. The memref data type needs to be enhanced to allow attributes. So the XeTile's tile data type can be expressed with memref associated with Tile attributes. xetile.wg_map and xetile.sg_map are examples of these attributes.

-## Appendix 1 - use case for xetile.order attribute and tile_transpose
+## Appendix 1 - use case for xetile.order attribute and transpose

xetile.tile describes a 2D block in memory. The default layout of xetile.tile is row-major contiguous, so tile[i][j] refers to the position i*stride_i + j in the associated memory. stride_j must be 1 since the innermost dimension is contiguous. This maps well to the underlying 2D block loader, which loads data in row-major layout only, with no stride in the innermost dimension.
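A quick worked example of this addressing rule (the concrete numbers are arbitrary and for illustration only):

```
// Base matrix row stride: 4096 elements; tile origin at element [64, 128].
// tile[i][j] maps to (64 + i) * 4096 + (128 + j), so
// tile[2][5] -> (64 + 2) * 4096 + (128 + 5) = 270469 elements from the base address.
```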
Below is the example code for the most common use case of xetile.tile.
@@ -309,7 +326,7 @@ This is a use case for the order attribute of xetile.tile. In this use case, the
   %vc = tile_mma %va, %vb : vector<64x32xbf16>, vector<32x64xbf16> into vector<64x64xbf16>;
```

-Alternatively, the user may just writes the program according to the given memory layout but apply a tile_transpose after the code being loaded. This is also an valid code.
+Alternatively, the user may just write the program according to the given memory layout and apply a transpose after the data is loaded. This is also valid code.
```
BF16 A[M][K], BT[N, K], C[M][N]; // C = MM(A, BT)
For i = 0, M-1, M_tile Do
@@ -320,7 +337,7 @@ For i = 0, M-1, M_tile Do
   %c = init_tile &C, [i, j], [M, N], [N, 1] : tile<64x64xbf16>; // M_tile=64, N_tile=64
   %va = load_tile %a : vector<64x32xbf16>;
   %vbt = load_tile %bt : vector<64x32xbf16>;
-   %vb = tile_transpose %vbt: vector<64x32xbf16> into vector<32x64xbf16>;
+   %vb = transpose %vbt: vector<64x32xbf16> into vector<32x64xbf16>;
   %vc = tile_mma %va, %vb : vector<64x32xbf16>, vector<32x64xbf16> into vector<64x64xbf16>;
```

@@ -462,7 +479,7 @@ func.func @test_gemm(%a : memref<4096x4096xf16>,
      scf.for %k = %c0 to %c4096 step %c32 {
          %4 = load_tile %1 : tile<256x32xf16, #mp_a> -> vector<256x32xf16>    // sg_layout=[8,4], sg_data=[32,32]
          %10 = load_tile %2 : tile<256x32xf16, #mp_bt> -> vector<256x32xf16>  // sg_layout=[4,8], sg_data=[64,32]
-           %5 = tile_transpose %10 {#mp_bt #mp_b}: vector<256x32xf16> -> vector<32x256xf16>  // sg_layout=[4,8] -> sg_layout=[8,4]
+           %5 = transpose %10 {#mp_bt #mp_b}: vector<256x32xf16> -> vector<32x256xf16>  // sg_layout=[4,8] -> sg_layout=[8,4]
          prefetch_tile %1 : tile<256x32xf16, #mp_a_pfh>                // sg_layout=[32,1]
          prefetch_tile %2 : tile<256x32xf16, #mp_a_pfh>                // sg_layout=[32,1]
@@ -474,10 +491,10 @@
      }
      %12 = load_tile %7 : tile<1x256xf32, #mp_bcast> -> vector<1x256xf32>     // sg_layout=[8, 4], sg_data=[1,64]
-   %13 = tile_broadcast {#mp_bcast #mp_c} %12 [0]: vector<1x256xf32> => vector<256x256xf32>     // sg_layout=[8, 4]
+   %13 = broadcast {#mp_bcast #mp_c} %12 [0]: vector<1x256xf32> => vector<256x256xf32>     // sg_layout=[8, 4]
      %14 = add %6, %13 : vector<256x256xf32>
-    %15 = tile_conv_layout {#mp_c #mp_reduce2} %14 : vector<256x256xf32>   // sg_layout=[8, 4] -> sg_layout=[32, 1]
-    %16 = tile_reduce {#mp_reduce2 #mp_reduce} %15 [1], vector<256x256xf32> => vector<256x1xf32>  // sg_layout=[32, 1]
+    %15 = convert_layout {#mp_c #mp_reduce2} %14 : vector<256x256xf32>   // sg_layout=[8, 4] -> sg_layout=[32, 1]
+    %16 = reduction {#mp_reduce2 #mp_reduce} %15 [1]: vector<256x256xf32> => vector<256x1xf32>  // sg_layout=[32, 1]
      store_tile %3, %16 : (tile<256x1xf32, #mp_reduce>, vector<256x1xf32>)     // sg_layout=[32, 1]
   }
}

@@ -489,15 +506,15 @@ The transpose in the program above can be optimized to use a slightly different
#mp_b     = #wg_map
#mp_bt    = #wg_map
%10 = load_tile %2 : tile<256x32xf16, #mp_bt> -> vector<256x32xf16> // sg_layout=[4,8], sg_data=[64,32]
-%5 = tile_transpose %10 {#mp_bt #mp_b}: vector<256x32xf16> -> vector<32x256xf16>  // sg_layout=[4,8] -> sg_layout=[8,4]
+%5 = transpose %10 {#mp_bt #mp_b}: vector<256x32xf16> -> vector<32x256xf16>  // sg_layout=[4,8] -> sg_layout=[8,4]
```
-With the optimized mapping, the tile_transpose below could be implemented with in-register transpose.
+With the optimized mapping, the transpose below could be implemented as an in-register transpose.
```mlir
#mp_b = #wg_map
#mp_bt = #wg_map
%10 = load_tile %2 : tile<256x32xf16, #mp_bt> -> vector<256x32xf16> // sg_layout=[32,1], sg_data=[64,32]
-%5 = tile_transpose %10 {#mp_bt #mp_b}: vector<256x32xf16> -> vector<32x256xf16> // sg_layout=[32,1] ->sg_layout=[8,4]
+%5 = transpose %10 {#mp_bt #mp_b}: vector<256x32xf16> -> vector<32x256xf16> // sg_layout=[32,1] -> sg_layout=[8,4]
```

## Appendix 2.4 Gemm implementation using cooperative load through shared local memory