Reland "[LLVMGPU] Add basic lowering pipeline without tiling and dist…
Browse files Browse the repository at this point in the history
…ribution" (#16566)

This commit cherry-picks 7881ed92
and adds additional fixes for CUDA transform dialect tests. The issue
was that a transform dialect run may already have set a lowering
pipeline before normal kernel configuration deduction runs; we need to
check for existing translation info before setting the pipeline again.

ci-extra: test_a100

---------

Co-authored-by: jinchen <[email protected]>
antiagainst and jinchen62 authored Feb 24, 2024
1 parent 862a031 commit 6596531
Showing 12 changed files with 162 additions and 26 deletions.
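
The core of the fix, distilled: kernel configuration deduction must leave
alone any function that already carries translation info (for example, one
set by an earlier transform dialect run). A minimal sketch of the guard,
assuming IREE's getTranslationInfo helper; isScalarDispatch and
setBaseLoweringPipeline are hypothetical stand-ins for logic the commit
inlines per backend (a standalone isScalarDispatch sketch follows the
KernelConfig.cpp hunk below):

    for (auto funcOp : moduleOp.getOps<mlir::FunctionOpInterface>()) {
      auto exportOp = exportOps.lookup(funcOp.getName());
      if (!exportOp)
        continue;
      // Guard: a transform dialect run may already have chosen a pipeline.
      if (!getTranslationInfo(funcOp) && isScalarDispatch(exportOp)) {
        if (failed(setBaseLoweringPipeline(funcOp))) // LLVMGPUBaseLowering
          return failure();
        continue;
      }
      // ...normal kernel configuration deduction continues here...
    }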
@@ -28,24 +28,26 @@ def CPU_DataTiling

 def LLVMGPU_Default
     : I32EnumAttrCase<"LLVMGPUDefault", 100>;
+def LLVMGPU_BaseLowering
+    : I32EnumAttrCase<"LLVMGPUBaseLowering", 101>;
 def LLVMGPU_SimpleDistribute
-    : I32EnumAttrCase<"LLVMGPUDistribute", 101>;
+    : I32EnumAttrCase<"LLVMGPUDistribute", 102>;
 def LLVMGPU_Vectorize
-    : I32EnumAttrCase<"LLVMGPUVectorize", 102>;
+    : I32EnumAttrCase<"LLVMGPUVectorize", 103>;
 def LLVMGPU_MatmulSimt
-    : I32EnumAttrCase<"LLVMGPUMatmulSimt", 103>;
+    : I32EnumAttrCase<"LLVMGPUMatmulSimt", 104>;
 def LLVMGPU_MatmulTensorCore
-    : I32EnumAttrCase<"LLVMGPUMatmulTensorCore", 104>;
+    : I32EnumAttrCase<"LLVMGPUMatmulTensorCore", 105>;
 def LLVMGPU_TransposeSharedMem
-    : I32EnumAttrCase<"LLVMGPUTransposeSharedMem", 105>;
+    : I32EnumAttrCase<"LLVMGPUTransposeSharedMem", 106>;
 def LLVMGPU_WarpReduction
-    : I32EnumAttrCase<"LLVMGPUWarpReduction", 106>;
+    : I32EnumAttrCase<"LLVMGPUWarpReduction", 107>;
 def LLVMGPU_PackUnPack
-    : I32EnumAttrCase<"LLVMGPUPackUnPack", 107>;
+    : I32EnumAttrCase<"LLVMGPUPackUnPack", 108>;
 def LLVMGPU_MatmulTensorCoreMmaSync
-    : I32EnumAttrCase<"LLVMGPUMatmulTensorCoreMmaSync", 108>;
+    : I32EnumAttrCase<"LLVMGPUMatmulTensorCoreMmaSync", 109>;
 def LLVMGPU_VectorDistribute
-    : I32EnumAttrCase<"LLVMGPUVectorDistribute", 109>;
+    : I32EnumAttrCase<"LLVMGPUVectorDistribute", 110>;
 
 def SPIRV_BaseLowering
     : I32EnumAttrCase<"SPIRVBaseLowering", 200>;
@@ -82,8 +84,8 @@ def DispatchLoweringPassPipelineEnum : I32EnumAttr<
     CPU_DataTiling,
 
     // LLVMGPU CodeGen pipelines
-    LLVMGPU_Default, LLVMGPU_SimpleDistribute, LLVMGPU_Vectorize,
-    LLVMGPU_MatmulSimt, LLVMGPU_MatmulTensorCore,
+    LLVMGPU_Default, LLVMGPU_BaseLowering, LLVMGPU_SimpleDistribute,
+    LLVMGPU_Vectorize, LLVMGPU_MatmulSimt, LLVMGPU_MatmulTensorCore,
     LLVMGPU_TransposeSharedMem, LLVMGPU_WarpReduction, LLVMGPU_PackUnPack,
     LLVMGPU_MatmulTensorCoreMmaSync, LLVMGPU_VectorDistribute,
 
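A note on the renumbering above (101→102 through 109→110): consumers match
these cases by symbolic name, and the attribute prints by name as well (the
new test below checks for <LLVMGPUBaseLowering>, not a number), so shifting
the raw values is safe. A sketch of the consuming pattern, mirroring the
switch statements added later in this commit:

    switch (translationInfo.getDispatchLoweringPassPipeline()) {
    case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUBaseLowering:
      addGPUBaseLoweringPassPipeline(pipeline);
      break;
    // ...remaining pipeline cases...
    default:
      break;
    }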
29 changes: 29 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -1503,6 +1503,10 @@ static void propagateLoweringConfig(Operation *rootOperation,
 }
 }
 
+//===----------------------------------------------------------------------===//
+// Entry Point
+//===----------------------------------------------------------------------===//
+
 LogicalResult initGPULaunchConfig(ModuleOp moduleOp) {
   llvm::StringMap<IREE::HAL::ExecutableExportOp> exportOps =
       getAllEntryPoints(moduleOp);
@@ -1511,6 +1515,31 @@ LogicalResult initGPULaunchConfig(ModuleOp moduleOp) {
     auto exportOp = exportOps.lookup(funcOp.getName());
     if (!exportOp)
       continue;
+
+    if (!getTranslationInfo(funcOp)) {
+      // If no translation info set, first check whether we already have
+      // workgroup count set--it's a "contract" to indicate that we should
+      // bypass all tiling and distribution to go down just the most basic
+      // lowering flow.
+      if (Block *body = exportOp.getWorkgroupCountBody()) {
+        auto retOp = cast<IREE::HAL::ReturnOp>(body->getTerminator());
+        // For scalar dispatch cases--using just one thread of one workgroup.
+        auto isOne = [](Value value) { return matchPattern(value, m_One()); };
+        if (llvm::all_of(retOp.getOperands(), isOne)) {
+          std::array<int64_t, 3> workgroupSize = {1, 1, 1};
+          if (failed(setDispatchConfig(funcOp, workgroupSize, std::nullopt)))
+            return failure();
+          auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
+              funcOp.getContext(),
+              IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUBaseLowering);
+          if (failed(setTranslationInfo(funcOp, translationInfo))) {
+            return failure();
+          }
+          continue;
+        }
+      }
+    }
+
     SmallVector<Operation *> computeOps = getComputeOps(funcOp);
     if (getTranslationInfo(exportOp)) {
       // Currently LLVMGPU requires propagation of user lowering configs.
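The workgroup-count "contract" above reads naturally as a standalone
predicate. A minimal sketch under the same assumptions as the surrounding
code (isScalarDispatch is a hypothetical name; the commit inlines this
logic at each use site):

    #include "mlir/IR/Matchers.h" // for matchPattern, m_One

    // A dispatch is "scalar" when its workgroup count region returns
    // (1, 1, 1), i.e. every returned value matches the m_One() pattern.
    static bool isScalarDispatch(IREE::HAL::ExecutableExportOp exportOp) {
      Block *body = exportOp.getWorkgroupCountBody();
      if (!body)
        return false;
      auto retOp = cast<IREE::HAL::ReturnOp>(body->getTerminator());
      return llvm::all_of(retOp.getOperands(), [](Value value) {
        return matchPattern(value, m_One());
      });
    }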
@@ -83,6 +83,9 @@ void LLVMGPULowerExecutableTargetPass::runOnOperation() {
     case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUDefault:
       addGPUDefaultPassPipeline(pipeline, enableMicrokernels);
       break;
+    case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUBaseLowering:
+      addGPUBaseLoweringPassPipeline(pipeline);
+      break;
     case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUDistribute:
       addGPUSimpleDistributePassPipeline(pipeline);
       break;
23 changes: 23 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -642,6 +642,29 @@ void addGPUDefaultPassPipeline(OpPassManager &pm, bool enableMicrokernels) {
       createRemoveSingleIterationLoopPass());
 }
 
+void addGPUBaseLoweringPassPipeline(OpPassManager &pm) {
+  auto &nestedModulePM = pm.nest<ModuleOp>();
+
+  nestedModulePM.addNestedPass<func::FuncOp>(
+      createConvertToDestinationPassingStylePass(
+          /*useWARForCooperativeMatrixCodegen=*/false));
+  nestedModulePM.addPass(createCanonicalizerPass());
+  nestedModulePM.addPass(createCSEPass());
+
+  addBufferizePasses(nestedModulePM);
+  nestedModulePM.addPass(createCanonicalizerPass());
+  nestedModulePM.addPass(createCSEPass());
+
+  nestedModulePM.addNestedPass<func::FuncOp>(
+      IREE::LinalgExt::createLinalgExtToLoopsPass());
+  nestedModulePM.addNestedPass<func::FuncOp>(createMemrefCopyToLinalgPass());
+  nestedModulePM.addNestedPass<func::FuncOp>(createConvertLinalgToLoopsPass());
+  nestedModulePM.addNestedPass<func::FuncOp>(
+      createRemoveSingleIterationLoopPass());
+  nestedModulePM.addPass(createCanonicalizerPass());
+  nestedModulePM.addPass(createCSEPass());
+}
+
 // Add passes to make the address computation more explicit and optimize them.
 //
 // The idea here is to be less dependent on what the LLVM backend is able to do,
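The new pipeline deliberately skips tiling, distribution, and vectorization:
destination-passing-style conversion, bufferization, then direct lowering of
LinalgExt and Linalg ops to loops, with canonicalize/CSE between stages. A
sketch of how a lower-executable-target pass drives it, mirroring the switch
cases added elsewhere in this commit (variantOp and the enclosing pass are
assumed):

    OpPassManager pipeline(variantOp.getOperationName());
    addGPUBaseLoweringPassPipeline(pipeline);
    if (failed(runPipeline(pipeline, variantOp)))
      return signalPassFailure();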
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
@@ -58,6 +58,9 @@ void addGPUWarpReductionPassPipeline(OpPassManager &pm);
 /// Default pass pipeline on GPU, currently used only for the ukernel path.
 void addGPUDefaultPassPipeline(OpPassManager &pm, bool enableMicrokernels);
 
+/// Pass pipeline to lower IREE HAL executables without tiling and distribution.
+void addGPUBaseLoweringPassPipeline(OpPassManager &pm);
+
 /// Populates passes needed to preprocess and select the translation strategy.
 void buildLLVMGPUCodegenConfigurationPassPipeline(OpPassManager &pm);

21 changes: 21 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLKernelConfig.cpp
@@ -346,6 +346,27 @@ LogicalResult initROCDLLaunchConfig(ModuleOp moduleOp) {
     if (!exportOp)
       continue;
 
+    // First check whether we already have workgroup count set--it's a
+    // "contract" to indicate that we should bypass all tiling and
+    // distribution to go down just the most basic lowering flow.
+    if (Block *body = exportOp.getWorkgroupCountBody()) {
+      auto retOp = cast<IREE::HAL::ReturnOp>(body->getTerminator());
+      // For scalar dispatch cases--using just one thread of one workgroup.
+      auto isOne = [](Value value) { return matchPattern(value, m_One()); };
+      if (llvm::all_of(retOp.getOperands(), isOne)) {
+        std::array<int64_t, 3> workgroupSize = {1, 1, 1};
+        if (failed(setDispatchConfig(funcOp, workgroupSize, std::nullopt)))
+          return failure();
+        auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
+            funcOp.getContext(),
+            IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUBaseLowering);
+        if (failed(setTranslationInfo(funcOp, translationInfo))) {
+          return failure();
+        }
+        continue;
+      }
+    }
+
     SmallVector<Operation *> computeOps = getComputeOps(funcOp);
     if (getTranslationInfo(exportOp)) {
       // Currently ROCDL requires propagation of user lowering configs.
@@ -48,6 +48,9 @@ class ROCDLLowerExecutableTargetPass
     OpPassManager pipeline(variantOp.getOperationName());
 
     switch (translationInfo.value().getDispatchLoweringPassPipeline()) {
+    case CodeGenPipeline::LLVMGPUBaseLowering:
+      addGPUBaseLoweringPassPipeline(pipeline);
+      break;
     case CodeGenPipeline::LLVMGPUWarpReduction:
       addGPUWarpReductionPassPipeline(pipeline);
       break;
@@ -10,6 +10,7 @@
#include "iree/compiler/Codegen/LLVMGPU/ROCDLPassDetail.h"
#include "iree/compiler/Codegen/LLVMGPU/ROCDLPasses.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"

@@ -21,7 +22,11 @@ class ROCDLSelectLoweringStrategyPass
     : public ROCDLSelectLoweringStrategyBase<ROCDLSelectLoweringStrategyPass> {
 public:
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<IREE::Codegen::IREECodegenDialect>();
+    // clang-format off
+    registry
+        .insert<IREE::Codegen::IREECodegenDialect,
+                bufferization::BufferizationDialect>();
+    // clang-format on
   }
 
   void runOnOperation() override {
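The added BufferizationDialect registration is presumably needed because this
pass can now schedule addGPUBaseLoweringPassPipeline, whose bufferization
stage creates ops from that dialect, and MLIR requires a pass to declare up
front every dialect whose ops it may create. The general pattern, as a
hypothetical minimal example:

    // Any pass that may create ops from a dialect must register it in
    // getDependentDialects(); otherwise the context asserts when the op is
    // built with the dialect still unloaded.
    void getDependentDialects(DialectRegistry &registry) const override {
      registry.insert<bufferization::BufferizationDialect>();
    }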
@@ -19,6 +19,7 @@ iree_lit_test_suite(
     srcs = enforce_glob(
         [
             "config_vector_distribute.mlir",
+            "lowering_scalar_dispatch.mlir",
             "pipeline_vector_distribute.mlir",
             "pipeline_warp_reduction.mlir",
         ],
@@ -15,6 +15,7 @@ iree_lit_test_suite(
     lit
   SRCS
     "config_vector_distribute.mlir"
+    "lowering_scalar_dispatch.mlir"
     "pipeline_vector_distribute.mlir"
     "pipeline_warp_reduction.mlir"
   TOOLS
43 changes: 43 additions & 0 deletions lowering_scalar_dispatch.mlir (new file)
@@ -0,0 +1,43 @@
// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-rocdl-select-lowering-strategy, iree-rocdl-lower-executable-target)))' -mlir-print-local-scope %s | FileCheck %s

#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {target_arch = "gfx90a", ukernels = "none"}>

#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>

hal.executable @scalar_dispatch {
hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm_hsaco_fb) {
hal.executable.export public @scalar_dispatch ordinal(0) layout(#pipeline_layout) {
^bb0(%arg0: !hal.device):
%c1 = arith.constant 1 : index
hal.return %c1, %c1, %c1 : index, index, index
}
builtin.module {
func.func @scalar_dispatch() {
%c0 = arith.constant 0 : index
%c6364136223846793005_i64 = arith.constant 6364136223846793005 : i64
%c1442695040888963407_i64 = arith.constant 1442695040888963407 : i64
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<i64>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<i64>>
%2 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:tensor<i64>> -> tensor<i64>
%extracted = tensor.extract %2[] : tensor<i64>
%3 = arith.muli %extracted, %c6364136223846793005_i64 : i64
%4 = arith.addi %3, %c1442695040888963407_i64 : i64
%inserted = tensor.insert %4 into %2[] : tensor<i64>
flow.dispatch.tensor.store %inserted, %1, offsets = [], sizes = [], strides = [] : tensor<i64> -> !flow.dispatch.tensor<writeonly:tensor<i64>>
return
}
}
}
}

// CHECK-LABEL: hal.executable.export public @scalar_dispatch
// CHECK-SAME: translation_info = #iree_codegen.translation_info<LLVMGPUBaseLowering>
// CHECK-SAME: workgroup_size = [1 : index, 1 : index, 1 : index]

// CHECK: func.func @scalar_dispatch()
// CHECK: %[[SPAN0:.+]] = hal.interface.binding.subspan set(0) binding(0)
// CHECK: %[[SPAN1:.+]] = hal.interface.binding.subspan set(0) binding(1)
// CHECK: memref.load %[[SPAN0]][] : memref<i64, #hal.descriptor_type<storage_buffer>>
// CHECK: arith.muli {{.+}} : i64
// CHECK: arith.addi {{.+}} : i64
// CHECK: memref.store %{{.+}}, %[[SPAN1]][] : memref<i64, #hal.descriptor_type<storage_buffer>>
30 changes: 16 additions & 14 deletions compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -1689,20 +1689,22 @@ static LogicalResult setSPIRVOpConfig(const spirv::TargetEnv &targetEnv,
 static LogicalResult setConfigForKernel(const spirv::TargetEnv &targetEnv,
                                         IREE::HAL::ExecutableExportOp exportOp,
                                         mlir::FunctionOpInterface funcOp) {
-  // First check whether we already have workgroup count set--it's a "contract"
-  // to indicate that we should bypass all tiling and distribution to go down
-  // just the most basic lowering flow.
-  if (Block *body = exportOp.getWorkgroupCountBody()) {
-    auto retOp = cast<IREE::HAL::ReturnOp>(body->getTerminator());
-    // For scalar dispatch cases--using just one thread of one workgroup.
-    auto isOne = [](Value value) { return matchPattern(value, m_One()); };
-    if (llvm::all_of(retOp.getOperands(), isOne)) {
-      std::array<int64_t, 3> workgroupSize = {1, 1, 1};
-      if (failed(setDispatchConfig(funcOp, workgroupSize, std::nullopt)))
-        return failure();
-      auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
-          funcOp.getContext(), CodeGenPipeline::SPIRVBaseLowering);
-      return setTranslationInfo(funcOp, translationInfo);
+  if (!getTranslationInfo(funcOp)) {
+    // If no translation info set, first check whether we already have workgroup
+    // count set--it's a "contract" to indicate that we should bypass all tiling
+    // and distribution to go down just the most basic lowering flow.
+    if (Block *body = exportOp.getWorkgroupCountBody()) {
+      auto retOp = cast<IREE::HAL::ReturnOp>(body->getTerminator());
+      // For scalar dispatch cases--using just one thread of one workgroup.
+      auto isOne = [](Value value) { return matchPattern(value, m_One()); };
+      if (llvm::all_of(retOp.getOperands(), isOne)) {
+        std::array<int64_t, 3> workgroupSize = {1, 1, 1};
+        if (failed(setDispatchConfig(funcOp, workgroupSize, std::nullopt)))
+          return failure();
+        auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
+            funcOp.getContext(), CodeGenPipeline::SPIRVBaseLowering);
+        return setTranslationInfo(funcOp, translationInfo);
+      }
     }
   }


