Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
qedawkins committed Aug 6, 2024
1 parent 0b9028d commit 85588aa
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ SmallVector<Value> getTileSizes(OpBuilder &b, Operation *op, unsigned level);
void setLoweringConfig(Operation *op, Attribute config);

/// Convenience function that sets the lowering configuration on the operation
/// and translation info on for a generic lowering config, lowering pipeline,
/// and translation info for a generic lowering config, lowering pipeline,
/// and optional workgroup/subgroup size.
inline LogicalResult setOpConfigAndEntryPointFnTranslation(
mlir::FunctionOpInterface entryPointFn, Operation *op,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "iree/compiler/Codegen/Common/GPU/GPUHeuristics.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "llvm/Support/Casting.h"
Expand Down Expand Up @@ -46,8 +47,8 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
return failure();
}

if (contractionDims->k.size() < 1 || contractionDims->m.size() < 1 ||
contractionDims->n.size() < 1) {
if (contractionDims->k.empty() || contractionDims->m.empty() ||
contractionDims->n.empty()) {
return failure();
}

Expand All @@ -65,11 +66,6 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
return failure();
}

// Bail out on matvec-like cases.
if (bounds[mDim] == 1 || bounds[nDim] == 1) {
return failure();
}

Value lhs = linalgOp.getDpsInputOperand(0)->get();
Value rhs = linalgOp.getDpsInputOperand(1)->get();
Value init = linalgOp.getDpsInitOperand(0)->get();
Expand All @@ -82,8 +78,16 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
lhsElemType, rhsElemType, initElemType};

SmallVector<GPUMatmulShapeType> intrinsics;
intrinsics.reserve(target.getWgp().getMma().size());
SmallVector<IREE::GPU::MmaInterfaceAttr> supportedMmas;
for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
IREE::GPU::MMAIntrinsic type = mma.getIntrinsic().getValue();
// TODO: Drop this once all intrinsics are supported.
if (type != IREE::GPU::MMAIntrinsic::MFMA_F16_16x16x16_F32 &&
type != IREE::GPU::MMAIntrinsic::MFMA_I8_16x16x32_I32) {
continue;
}
supportedMmas.push_back(mma);

auto [mSize, nSize, kSize] = mma.getMNKShape();
auto [aType, bType, cType] = mma.getABCElementTypes();
if (mma.getSubgroupSize() != targetSubgroupSize)
Expand Down Expand Up @@ -186,8 +190,7 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
// Similarly the reduction tile size is just the post-packing tile count.
reductionTileSizes[kDim] = schedule->kTileCount;

IREE::GPU::MmaInterfaceAttr mmaKind =
target.getWgp().getMma()[schedule->index];
IREE::GPU::MmaInterfaceAttr mmaKind = supportedMmas[schedule->index];

// Attach the MMA schedule as an attribute to the entry point export function
// for later access in the pipeline.
Expand Down

0 comments on commit 85588aa

Please sign in to comment.