[VnniTransform][Util]Relocate vnni utils (#1000)
chencha3 authored Jan 14, 2025
1 parent 5dbeec7 commit 48a375a
Showing 3 changed files with 91 additions and 67 deletions.
18 changes: 18 additions & 0 deletions include/imex/Utils/XeCommon.h
@@ -31,6 +31,24 @@ using namespace mlir::xegpu;

namespace imex {

// This function computes the VNNI factor for the given element type.
// It returns 1 by default for types that do not need the VNNI transformation.
int getVnniFactor(mlir::Type elemTy);

// A helper function to get the vector type after applying the VNNI
// transformation, e.g., vector<4x4xf16> -> vector<2x4x2xf16>.
mlir::VectorType getPackedType(mlir::VectorType vecTy);

// Apply the VNNI transformation to the given value using VectorShuffle
// and ShapeCast operations. Since this appends extra operations to the
// given value, the function also returns the first operation applied to
// the value for convenience, so that the user can replace all uses of
// the current value except for that first appended operation.
std::pair<mlir::Value, mlir::Operation *>
applyVnniTransform(mlir::OpBuilder &builder,
mlir::TypedValue<mlir::VectorType> value);
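
For illustration, a minimal caller sketch (an assumption, not part of this commit): it assumes a rewrite-pattern context, uses the hypothetical names packOperandForDpas and val, and relies on the standard mlir::RewriterBase::replaceAllUsesExcept helper.

#include "imex/Utils/XeCommon.h"
#include <mlir/IR/PatternMatch.h>

// Hypothetical helper showing the intended usage of applyVnniTransform.
static void packOperandForDpas(mlir::PatternRewriter &rewriter,
                               mlir::TypedValue<mlir::VectorType> val) {
  // Place the appended ops right after the value being packed.
  rewriter.setInsertionPointAfterValue(val);
  // Appends shape_cast -> shuffle -> shape_cast and returns the packed
  // result together with the first appended operation.
  auto [packed, root] = imex::applyVnniTransform(rewriter, val);
  // Redirect all other uses of val to the packed value, keeping the
  // first appended op (the initial shape cast) reading the original value.
  rewriter.replaceAllUsesExcept(val, packed, root);
}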

// Valid chunk sizes are 1, 2, 3, 4, and 8 if simdLanes > 1;
// 16, 32, and 64 are only available if simdLanes == 1.
llvm::SmallVector<int> getSupportedChunkSizes(int simdlanes);
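
A small usage sketch (an assumption for illustration; the validation helper isValidChunkSize is hypothetical and not part of this commit):

#include "imex/Utils/XeCommon.h"
#include <llvm/ADT/STLExtras.h>

// Hypothetical check: is `chunk` legal for the given number of SIMD lanes?
static bool isValidChunkSize(int simdLanes, int chunk) {
  return llvm::is_contained(imex::getSupportedChunkSizes(simdLanes), chunk);
}
// e.g., isValidChunkSize(16, 8) returns true; isValidChunkSize(16, 64) returns false.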
70 changes: 3 additions & 67 deletions lib/Transforms/VnniTransformation.cpp
@@ -20,6 +20,7 @@
#include <mlir/IR/BuiltinTypes.h>

#include "imex/Transforms/Passes.h"
#include "imex/Utils/XeCommon.h"

#include <optional>

@@ -28,6 +29,8 @@ namespace imex {
#include "imex/Transforms/Passes.h.inc"
} // namespace imex

using namespace imex;

namespace {
// Struct describing the current layout per mlir::Value.
// It has 3 possible states:
@@ -102,11 +105,6 @@ class LayoutLattice : public mlir::dataflow::Lattice<Layout> {
}
};

static int getVnniFactor(mlir::Type elemTy) {
assert(elemTy.isIntOrFloat() && "Only integer and float types supported");
return 32 / elemTy.getIntOrFloatBitWidth();
}

static bool isVNNIApplicable(mlir::Type type) {
auto vecTy = mlir::dyn_cast<mlir::VectorType>(type);

@@ -267,68 +265,6 @@ class LayoutAnalysis {
};
} // namespace

static mlir::VectorType getPackedType(mlir::VectorType vecTy) {
auto shape = vecTy.getShape().vec();
auto factor = getVnniFactor(vecTy.getElementType());
unsigned axis = shape.size() == 3 ? 1 : 0;

// Only 2D/3D vectors are supported, and the vector size
// must be divisible by the factor.
if ((shape.size() != 2 && shape.size() != 3) || !factor ||
shape[axis] % factor != 0)
return nullptr;

shape.emplace_back(factor);
shape[axis] /= factor;
return mlir::VectorType::get(shape, vecTy.getElementType());
}

static llvm::SmallVector<int64_t>
getVNNIShuffleIndices(mlir::VectorType srcType) {
auto numElements = srcType.getNumElements();
llvm::SmallVector<int64_t> ret(numElements, 0);
auto dstType = getPackedType(srcType);
auto dstShape = dstType.getShape();
// Convert from contiguous layout to VNNI packed, e.g. from
// `vector<16x16xf16>` to `vector<8x16x2xf16>`.
// To arrange the data in VNNI format, the shuffle indices must satisfy
// the following mapping:
// [i, j, k] => i * dstShape[1] * dstShape[2] + j + k * dstShape[1]
int shuffleIndex = 0;
for (unsigned i = 0; i < dstShape[0]; ++i) {
for (unsigned j = 0; j < dstShape[1]; ++j) {
for (unsigned k = 0; k < dstShape[2]; ++k) {
ret[shuffleIndex++] =
i * dstShape[1] * dstShape[2] + j + k * dstShape[1];
}
}
}
return ret;
}

static std::pair<mlir::Value, mlir::Operation *>
applyVnniTransform(mlir::OpBuilder &builder,
mlir::TypedValue<mlir::VectorType> src) {
assert(src && "value must be non-null");
auto loc = src.getLoc();
auto srcTy = src.getType();
auto elems = srcTy.getNumElements();
auto elemTy = srcTy.getElementType();
auto linearVecTy = mlir::VectorType::get(elems, elemTy);
auto root = builder.create<mlir::vector::ShapeCastOp>(loc, linearVecTy, src);
auto mask = getVNNIShuffleIndices(srcTy);
auto shuffle = builder.create<mlir::vector::ShuffleOp>(loc, root, root, mask);
auto packedTy = getPackedType(srcTy);
auto cast = builder.create<mlir::vector::ShapeCastOp>(loc, packedTy, shuffle);
// For the convenience of the load+transpose optimization, add the "packed"
// attribute to indicate that these ops perform the VNNI transform.
root.getOperation()->setAttr("packed", builder.getUnitAttr());
shuffle.getOperation()->setAttr("packed", builder.getUnitAttr());
cast.getOperation()->setAttr("packed", builder.getUnitAttr());

return {cast, root};
}

static void applyVnniTransformOnResults(mlir::OpBuilder &builder,
mlir::Operation *op,
LayoutAnalysis &analysis) {
70 changes: 70 additions & 0 deletions lib/Utils/XeCommon.cpp
@@ -21,6 +21,76 @@
#include "llvm/Support/FormatVariadic.h"

namespace imex {

static llvm::SmallVector<int64_t>
getVNNIShuffleIndices(mlir::VectorType srcType) {
auto numElements = srcType.getNumElements();
llvm::SmallVector<int64_t> ret(numElements, 0);
auto dstType = getPackedType(srcType);
auto dstShape = dstType.getShape();
// Convert from contiguous layout to VNNI packed, e.g. from
// `vector<16x16xf16>` to `vector<8x16x2xf16>`.
// To arrange the data in VNNI format, the shuffle indices must satisfy
// the following mapping:
// [i, j, k] => i * dstShape[1] * dstShape[2] + j + k * dstShape[1]
int shuffleIndex = 0;
for (unsigned i = 0; i < dstShape[0]; ++i) {
for (unsigned j = 0; j < dstShape[1]; ++j) {
for (unsigned k = 0; k < dstShape[2]; ++k) {
ret[shuffleIndex++] =
i * dstShape[1] * dstShape[2] + j + k * dstShape[1];
}
}
}
return ret;
}
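
For intuition, a standalone sketch of the mapping above (plain C++, no MLIR dependency; the vector<4x2xf16> shape is a made-up example whose packed type would be vector<2x2x2xf16>):

#include <cstdio>

int main() {
  // Packed shape for a hypothetical vector<4x2xf16>: VNNI factor 2 -> [2, 2, 2].
  const int dstShape[3] = {2, 2, 2};
  // Same mapping as getVNNIShuffleIndices:
  //   [i, j, k] => i * dstShape[1] * dstShape[2] + j + k * dstShape[1]
  for (int i = 0; i < dstShape[0]; ++i)
    for (int j = 0; j < dstShape[1]; ++j)
      for (int k = 0; k < dstShape[2]; ++k)
        std::printf("%d ", i * dstShape[1] * dstShape[2] + j + k * dstShape[1]);
  std::printf("\n"); // prints: 0 2 1 3 4 6 5 7
}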

int getVnniFactor(mlir::Type elemTy) {
int vnni = 1;
if (elemTy.isIntOrFloat())
vnni = std::max<int>(32 / elemTy.getIntOrFloatBitWidth(), 1);
return vnni;
}
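
A hedged sanity-check sketch of the factors this produces (the function name checkVnniFactors and the MLIRContext handle are assumptions for illustration):

#include "imex/Utils/XeCommon.h"
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/MLIRContext.h>
#include <cassert>

// Hypothetical sanity check illustrating the factors computed above.
static void checkVnniFactors(mlir::MLIRContext &ctx) {
  assert(imex::getVnniFactor(mlir::Float16Type::get(&ctx)) == 2);    // 32 / 16
  assert(imex::getVnniFactor(mlir::IntegerType::get(&ctx, 8)) == 4); // 32 / 8
  assert(imex::getVnniFactor(mlir::Float32Type::get(&ctx)) == 1);    // no packing needed
  assert(imex::getVnniFactor(mlir::Float64Type::get(&ctx)) == 1);    // clamped to 1
}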

mlir::VectorType getPackedType(mlir::VectorType vecTy) {
auto shape = vecTy.getShape().vec();
auto factor = getVnniFactor(vecTy.getElementType());
unsigned axis = shape.size() == 3 ? 1 : 0;

// Only 2D/3D vectors are supported, and the vector size
// must be divisible by the factor.
if ((shape.size() != 2 && shape.size() != 3) || !factor ||
shape[axis] % factor != 0)
return nullptr;

shape.emplace_back(factor);
shape[axis] /= factor;
return mlir::VectorType::get(shape, vecTy.getElementType());
}
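
A hedged usage sketch of the shapes this helper produces (the function name packedTypeExamples and the MLIRContext are assumptions; the 2D shape matches the doc comment in the header):

#include "imex/Utils/XeCommon.h"
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/MLIRContext.h>

// Hypothetical illustration of getPackedType results.
static void packedTypeExamples(mlir::MLIRContext &ctx) {
  auto f16 = mlir::Float16Type::get(&ctx);
  // 2D: vector<16x16xf16> -> vector<8x16x2xf16> (factor 2 applied on axis 0).
  auto packed2d = imex::getPackedType(mlir::VectorType::get({16, 16}, f16));
  // 1D vectors are unsupported: the helper returns a null type.
  auto invalid = imex::getPackedType(mlir::VectorType::get({16}, f16));
  (void)packed2d;
  (void)invalid;
}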

std::pair<mlir::Value, mlir::Operation *>
applyVnniTransform(mlir::OpBuilder &builder,
mlir::TypedValue<mlir::VectorType> src) {
assert(src && "value must be non-null");
auto loc = src.getLoc();
auto srcTy = src.getType();
auto elems = srcTy.getNumElements();
auto elemTy = srcTy.getElementType();
auto linearVecTy = mlir::VectorType::get(elems, elemTy);
auto root = builder.create<mlir::vector::ShapeCastOp>(loc, linearVecTy, src);
auto mask = getVNNIShuffleIndices(srcTy);
auto shuffle = builder.create<mlir::vector::ShuffleOp>(loc, root, root, mask);
auto packedTy = getPackedType(srcTy);
auto cast = builder.create<mlir::vector::ShapeCastOp>(loc, packedTy, shuffle);
// For the convenience of the load+transpose optimization, add the "packed"
// attribute to indicate that these ops perform the VNNI transform.
root.getOperation()->setAttr("packed", builder.getUnitAttr());
shuffle.getOperation()->setAttr("packed", builder.getUnitAttr());
cast.getOperation()->setAttr("packed", builder.getUnitAttr());

return {cast, root};
}

llvm::SmallVector<int> getSupportedChunkSizes(int simdlanes) {
if (simdlanes == 1)
return {64, 32, 16, 8, 4, 3, 2, 1};