Commit 4d3471b: Address comments
Groverkss committed Aug 26, 2024
1 parent f1c44e3
Showing 12 changed files with 123 additions and 67 deletions.
@@ -103,7 +103,7 @@ struct GPUVectorAllocPass final

SmallVector<IREE::VectorExt::ToLayoutOp> opsToPromote;
funcOp.walk([&](IREE::VectorExt::ToLayoutOp op) {
-      if (op->hasAttr("shared_memory_conversion")) {
+      if (op.getSharedMemoryConversion()) {
opsToPromote.push_back(op);
}
});
@@ -139,7 +139,7 @@ struct GPUVectorAllocPass final

// Remove the shared_memory_conversion attribute from the to_layout
// operation.
-      op->removeAttr("shared_memory_conversion");
+      op.setSharedMemoryConversion(false);
}
}
};
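
The first file swaps string-keyed discardable-attribute lookups for the op's generated accessors. A minimal sketch of the difference, condensed from the two hunks above (op is a ToLayoutOp and opsToPromote is the pass's worklist):

  // Before: a discardable attribute looked up by spelling; a typo in the
  // string still compiles and silently never matches.
  if (op->hasAttr("shared_memory_conversion")) {
    opsToPromote.push_back(op);
  }
  op->removeAttr("shared_memory_conversion");

  // After: the flag is a declared UnitAttr on to_layout, so ODS generates
  // typed, compile-time-checked accessors.
  if (op.getSharedMemoryConversion()) {
    opsToPromote.push_back(op);
  }
  op.setSharedMemoryConversion(false);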
@@ -214,8 +214,8 @@ ChangeResult DistributionLayout::resolveWithPossibleConflict(
Value input = opOperand.get();
// Create a resolution operation. This conflict should be handled later by
// someone else, not this analysis.
-  Operation *resolveOp = builder.create<IREE::VectorExt::ToLayoutOp>(
-      input.getLoc(), input.getType(), input, rhs);
+  Operation *resolveOp =
+      builder.create<IREE::VectorExt::ToLayoutOp>(input.getLoc(), input, rhs);
Value resolvedValue = resolveOp->getResult(0);
opOperand.set(resolvedValue);

@@ -977,7 +977,7 @@ NestedLayoutAttr createNestedLayout(MLIRContext *context, int64_t rank,
FailureOr<std::tuple<VectorExt::VectorLayoutInterface,
VectorExt::VectorLayoutInterface,
VectorExt::VectorLayoutInterface>>
- MMAScheduleAttr::getContractionLayout(linalg::GenericOp contractOp) const {
+ MMAScheduleAttr::getContractionLayout(linalg::LinalgOp contractOp) const {
auto maybeOpInfo = VectorContractOpInfo::inferFromIndexingMaps(
contractOp.getIndexingMapsArray());
if (failed(maybeOpInfo)) {
@@ -998,6 +998,11 @@ MMAScheduleAttr::getContractionLayout(linalg::GenericOp contractOp) const {
MLIRContext *context = getContext();

SmallVector<int64_t> bounds = contractOp.getStaticLoopRanges();
+  if (llvm::any_of(bounds,
+                   [](int64_t x) { return x == ShapedType::kDynamic; })) {
+    return failure();
+  }
+
int64_t batchCount = opInfo.getBatchCount();
if (batchCount == 1 && bounds[0] != 1) {
LLVM_DEBUG({ llvm::errs() << "non-unit batch dimension\n"; });
@@ -251,7 +251,7 @@ def IREEGPU_MmaScheduleAttr : AttrDef<IREEGPU_Dialect, "MMASchedule"> {
::mlir::FailureOr<::std::tuple<VectorExt::VectorLayoutInterface,
VectorExt::VectorLayoutInterface,
VectorExt::VectorLayoutInterface>>
-      getContractionLayout(::mlir::linalg::GenericOp contractOp) const;
+      getContractionLayout(::mlir::linalg::LinalgOp contractOp) const;
}];
}

@@ -298,5 +298,4 @@ def NestedLayoutAttr : IREEVectorExt_Attr<"NestedLayout",
let genVerifyDecl = 1;
}

-
#endif // IREE_DIALECT_VECTOREXT_ATTRS
@@ -33,14 +33,29 @@ def IREEVectorExt_ToLayoutOp : IREEVectorExt_PureOp<"to_layout", [
let description = [{
The layout conversion operator takes a shaped value and a layout and
transforms the value to have that layout.
+
+    If the "shared_memory_conversion" attribute is set, then this layout
+    change has to be materialized through shared memory.
}];
let arguments = (ins
AnyShaped:$input,
-    VectorLayoutInterface:$layout
+    VectorLayoutInterface:$layout,
+    DefaultValuedAttr<UnitAttr, "false">:$shared_memory_conversion
);
let results = (outs
AnyShaped:$output
);
+  let builders = [
+    OpBuilder<(ins "Value":$input,
+      "VectorLayoutInterface":$layout,
+      CArg<"bool", "false">:$shared_memory_conversion), [{
+      if (shared_memory_conversion) {
+        build($_builder, $_state, input.getType(), input, layout, UnitAttr::get(input.getContext()));
+      } else {
+        build($_builder, $_state, input.getType(), input, layout);
+      }
+    }]>
+  ];
let extraClassDeclaration = [{
bool hasTensorSemantics() {
return isa<RankedTensorType>(getOutput().getType());
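
The added builder makes the flag an optional trailing bool, so callers no longer attach the attribute by hand. A usage sketch against the signature above, assuming builder, loc, value, and layout are in scope:

  // Layout change that must be materialized through shared memory: the
  // builder attaches the unit attribute itself.
  auto promoted = builder.create<IREE::VectorExt::ToLayoutOp>(
      loc, value, layout, /*shared_memory_conversion=*/true);

  // Plain layout change; the flag defaults to false and no attribute is set.
  auto plain = builder.create<IREE::VectorExt::ToLayoutOp>(loc, value, layout);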
@@ -48,9 +48,8 @@ struct VectorizeToLayoutOpPattern final

// Create the toLayout operation but with vector types instead.
auto newLayoutOp = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-        loc, newInput.getType(), newInput, toLayoutOp.getLayout());
-    // Set attributes.
-    newLayoutOp->setAttrs(toLayoutOp->getAttrs());
+        loc, newInput, toLayoutOp.getLayout(),
+        toLayoutOp.getSharedMemoryConversion());

// Create the write back to a tensor.
int64_t rank = inputTy.getRank();
@@ -25,7 +25,7 @@ namespace {

LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule,
RewriterBase &rewriter,
-                                   linalg::GenericOp contract) {
+                                   linalg::LinalgOp contract) {
// TODO: Add SIMT fallback.
if (!schedule) {
return contract->emitError("missing mma schedule for contraction");
@@ -43,25 +43,25 @@ LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule,
auto [aLayout, bLayout, cLayout] = *layouts;
Location loc = contract.getLoc();

-  Value lhs = contract.getOperand(0);
-  Value rhs = contract.getOperand(1);
-  Value acc = contract.getOperand(2);
+  Value lhs = contract->getOperand(0);
+  Value rhs = contract->getOperand(1);
+  Value acc = contract->getOperand(2);

// Set layouts for lhs, rhs and acc.
rewriter.setInsertionPoint(contract);
-  auto layoutedLhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-      loc, lhs.getType(), lhs, aLayout);
-  auto layoutedRhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-      loc, rhs.getType(), rhs, bLayout);
-  auto layoutedAcc = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-      loc, acc.getType(), acc, cLayout);
+  auto layoutedLhs =
+      rewriter.create<IREE::VectorExt::ToLayoutOp>(loc, lhs, aLayout);
+  auto layoutedRhs =
+      rewriter.create<IREE::VectorExt::ToLayoutOp>(loc, rhs, bLayout);
+  auto layoutedAcc =
+      rewriter.create<IREE::VectorExt::ToLayoutOp>(loc, acc, cLayout);

// Promote matmul lhs and rhs.
// TODO: We should read this from the lowering_config on the operation.
// TODO: This is a hack until layout analysis is improved. The layout analysis
// should decide where to put these shared memory conversions.
-  layoutedLhs->setAttr("shared_memory_conversion", rewriter.getUnitAttr());
-  layoutedRhs->setAttr("shared_memory_conversion", rewriter.getUnitAttr());
+  layoutedLhs.setSharedMemoryConversion(true);
+  layoutedRhs.setSharedMemoryConversion(true);

contract->setOperand(0, layoutedLhs.getResult());
contract->setOperand(1, layoutedRhs.getResult());
@@ -70,8 +70,8 @@ LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule,
// Set layout for result.
rewriter.setInsertionPointAfter(contract);
auto toLayout = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-      loc, contract.getResult(0).getType(), contract.getResult(0), cLayout);
-  rewriter.replaceAllUsesExcept(contract.getResult(0), toLayout.getResult(),
+      loc, contract->getResult(0), cLayout);
+  rewriter.replaceAllUsesExcept(contract->getResult(0), toLayout.getResult(),
toLayout);

return success();
@@ -89,51 +89,39 @@ struct LLVMGPUConfigureTensorLayoutsPass final
auto func = getOperation();

std::array<int64_t, 3> workgroupSize;
if (func->hasAttr("workgroup_size")) {
auto tmpSizes =
llvm::cast<ArrayAttr>(func->getAttr("workgroup_size")).getValue();
for (auto [i, size] : llvm::enumerate(tmpSizes)) {
workgroupSize[i] = llvm::cast<IntegerAttr>(size).getInt();
}
} else {
std::optional<SmallVector<int64_t>> maybeWorkgroupSize =
getWorkgroupSize(func);
if (!maybeWorkgroupSize) {
func->emitOpError()
<< "unable to query workgroup_size information from entry point";
return signalPassFailure();
}
for (auto [index, value] : llvm::enumerate(maybeWorkgroupSize.value())) {
workgroupSize[index] = value;
}
for (auto index : llvm::seq<size_t>(maybeWorkgroupSize->size(), 3)) {
workgroupSize[index] = 1;
}
std::optional<SmallVector<int64_t>> maybeWorkgroupSize =
getWorkgroupSize(func);
if (!maybeWorkgroupSize) {
func->emitOpError()
<< "unable to query workgroup_size information from entry point";
return signalPassFailure();
}
for (auto [index, value] : llvm::enumerate(maybeWorkgroupSize.value())) {
workgroupSize[index] = value;
}
for (auto index : llvm::seq<size_t>(maybeWorkgroupSize->size(), 3)) {
workgroupSize[index] = 1;
}

llvm::StringLiteral scheduleAttrName =
IREE::GPU::MMAScheduleAttr::getMnemonic();
-    auto scheduleAttr =
-        func->getAttrOfType<IREE::GPU::MMAScheduleAttr>(scheduleAttrName);
-    if (!scheduleAttr) {
-      DictionaryAttr configDict = getTranslationInfo(func).getConfiguration();
-      scheduleAttr = dyn_cast_or_null<IREE::GPU::MMAScheduleAttr>(
-          configDict.get(scheduleAttrName));
-    }
+    DictionaryAttr configDict = getTranslationInfo(func).getConfiguration();
+    auto scheduleAttr = dyn_cast_or_null<IREE::GPU::MMAScheduleAttr>(
+        configDict.get(scheduleAttrName));

// Vector layout option setter aimed at contractions. For now, layout
// setting for other problems like reductions is TODO.
-    SmallVector<linalg::GenericOp> contracts;
+    SmallVector<linalg::LinalgOp> contracts;

-    func->walk([&](linalg::GenericOp linalgOp) {
+    func->walk([&](linalg::LinalgOp linalgOp) {
if (linalg::isaContractionOpInterface(linalgOp)) {
contracts.push_back(linalgOp);
}
});

IRRewriter rewriter(func);

-    for (linalg::GenericOp contract : contracts) {
+    for (linalg::LinalgOp contract : contracts) {
if (failed(setContractionAnchor(scheduleAttr, rewriter, contract))) {
return signalPassFailure();
}
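
Since linalg::isaContractionOpInterface is checked on any linalg::LinalgOp, the pass now anchors named contractions such as linalg.matmul as well as linalg.generic. A condensed sketch of the resulting flow, assuming rewriter, loc, and the schedule's aLayout/bLayout/cLayout are in scope:

  // Collect every contraction-like linalg op, named or generic.
  SmallVector<linalg::LinalgOp> contracts;
  func->walk([&](linalg::LinalgOp linalgOp) {
    if (linalg::isaContractionOpInterface(linalgOp))
      contracts.push_back(linalgOp);
  });

  // Anchor each one: lhs and rhs are routed through shared memory; the
  // accumulator and the result only receive the C layout.
  for (linalg::LinalgOp contract : contracts) {
    auto lhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
        loc, contract->getOperand(0), aLayout);
    lhs.setSharedMemoryConversion(true);
    // ... likewise rhs with bLayout; acc and the result use cLayout.
  }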
@@ -225,7 +225,7 @@ LogicalResult setTransferReadAnchor(ArrayRef<int64_t> workgroupSize,
Location loc = transfer.getLoc();
rewriter.setInsertionPointAfter(transfer);
auto toLayout = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-      loc, transfer.getResult().getType(), transfer.getResult(), layout);
+      loc, transfer.getResult(), layout);
rewriter.replaceAllUsesExcept(transfer, toLayout.getResult(), toLayout);

return success();
1 change: 0 additions & 1 deletion compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -824,7 +824,6 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager,
// this point.
funcPassManager.addPass(createLinalgGeneralizeNamedOpsPass());
if (!usePadToModelSharedMemcpy) {
-    // Folding unit dims gets confused with padding.
LinalgFoldUnitExtentDimsPassOptions options;
options.useRankReducingSlices = true;
funcPassManager.addPass(mlir::createLinalgFoldUnitExtentDimsPass(options));
@@ -1670,20 +1670,20 @@ transform_dialect::SetContractionLayoutAttributes::apply(
Operation *parentOp = operand.getDefiningOp();
if (!parentOp || (parentOp->getNumResults() != 1))
continue;
-      Value resolvedOperand = rewriter.create<VectorExt::ToLayoutOp>(
-          loc, operand.getType(), operand, readLayout);
+      Value resolvedOperand =
+          rewriter.create<VectorExt::ToLayoutOp>(loc, operand, readLayout);
contract.setOperand(operandIndices[i], resolvedOperand);
}
}

// Set layout anchors.
rewriter.setInsertionPoint(contract);
-  Value newLhs = rewriter.create<VectorExt::ToLayoutOp>(
-      loc, contract.getLhsType(), contract.getLhs(), aLayout);
-  Value newRhs = rewriter.create<VectorExt::ToLayoutOp>(
-      loc, contract.getRhsType(), contract.getRhs(), bLayout);
-  Value newAcc = rewriter.create<VectorExt::ToLayoutOp>(
-      loc, contract.getAccType(), contract.getAcc(), cLayout);
+  Value newLhs =
+      rewriter.create<VectorExt::ToLayoutOp>(loc, contract.getLhs(), aLayout);
+  Value newRhs =
+      rewriter.create<VectorExt::ToLayoutOp>(loc, contract.getRhs(), bLayout);
+  Value newAcc =
+      rewriter.create<VectorExt::ToLayoutOp>(loc, contract.getAcc(), cLayout);
contract.setOperand(0, newLhs);
contract.setOperand(1, newRhs);
contract.setOperand(2, newAcc);
@@ -18,7 +18,7 @@
iterator_types = ["parallel", "parallel", "reduction"]
}

- func.func @matmul_96x64x16(%lhs: tensor<96x16xf16>,
+ func.func @matmul_96x64x16_mfma(%lhs: tensor<96x16xf16>,
%rhs: tensor<64x16xf16>,
%init: tensor<96x64xf32>)
-> tensor<96x64xf32>
@@ -40,7 +40,58 @@ func.func @matmul_96x64x16(%lhs: tensor<96x16xf16>,
// CHECK-DAG: #[[$NESTED1:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4], subgroup_strides = [0, 0], thread_strides = [1, 32]>
// CHECK-DAG: #[[$NESTED2:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [4, 1], threads_per_outer = [2, 32], elements_per_thread = [4, 1], subgroup_strides = [0, 0], thread_strides = [32, 1]>

- // CHECK-LABEL: func.func @matmul_96x64x16
+ // CHECK-LABEL: func.func @matmul_96x64x16_mfma

// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED]] {shared_memory_conversion}
// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED1]] {shared_memory_conversion}
+ // CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED2]]
+ // CHECK: linalg.generic
+ // CHECK-SAME: ins(%[[LHS]], %[[RHS]]
+ // CHECK-SAME: outs(%[[ACC]]
+
+ // -----
+
+ #translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
+   workgroup_size = [64, 1, 1]
+   subgroup_size = 64,
+   {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<WMMA_F32_16x16x16_F16>,
+     subgroup_m_count = 1,
+     subgroup_n_count = 1>}>
+
+ #maps = [
+   affine_map<(m, n, k) -> (m, k)>,
+   affine_map<(m, n, k) -> (n, k)>,
+   affine_map<(m, n, k) -> (m, n)>
+ ]
+
+ #traits = {
+   indexing_maps = #maps,
+   iterator_types = ["parallel", "parallel", "reduction"]
+ }
+
+ func.func @matmul_96x64x16_wmma(%lhs: tensor<96x16xf16>,
+   %rhs: tensor<64x16xf16>,
+   %init: tensor<96x64xf32>)
+   -> tensor<96x64xf32>
+   attributes { translation_info = #translation } {
+   %out = linalg.generic #traits
+     ins(%lhs, %rhs: tensor<96x16xf16>, tensor<64x16xf16>)
+     outs(%init: tensor<96x64xf32>) {
+   ^bb0(%in: f16, %in_1: f16, %out: f32):
+     %ex = arith.extf %in : f16 to f32
+     %ex_1 = arith.extf %in_1 : f16 to f32
+     %mul = arith.mulf %ex, %ex_1 : f32
+     %sum = arith.addf %out, %mul : f32
+     linalg.yield %sum : f32
+   } -> tensor<96x64xf32>
+   return %out : tensor<96x64xf32>
+ }
+
+ // CHECK-DAG: #[[$NESTED:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [6, 1], outers_per_batch = [1, 1], threads_per_outer = [16, 1], elements_per_thread = [1, 16], subgroup_strides = [0, 0], thread_strides = [1, 0]>
+ // CHECK-DAG: #[[$NESTED1:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [4, 1], outers_per_batch = [1, 1], threads_per_outer = [16, 1], elements_per_thread = [1, 16], subgroup_strides = [0, 0], thread_strides = [1, 0]>
+ // CHECK-DAG: #[[$NESTED2:.+]] = #iree_vector_ext.nested_layout<subgroups_per_workgroup = [1, 1], batches_per_subgroup = [6, 4], outers_per_batch = [8, 1], threads_per_outer = [2, 16], elements_per_thread = [1, 1], subgroup_strides = [0, 0], thread_strides = [16, 1]>
+
+ // CHECK-LABEL: func.func @matmul_96x64x16_wmma
+
+ // CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED]] {shared_memory_conversion}
+ // CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to #[[$NESTED1]] {shared_memory_conversion}