Skip to content

Commit

Permalink
[LLVMGPU] Add basic lowering pipeline without tiling and distribution (
Browse files Browse the repository at this point in the history
  • Loading branch information
jinchen62 authored Feb 22, 2024
1 parent 599a3d1 commit 7881ed9
Show file tree
Hide file tree
Showing 11 changed files with 143 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,24 +28,26 @@ def CPU_DataTiling

// LLVMGPU CodeGen pipeline cases. Each case carries exactly one
// I32EnumAttrCase with its own value in the 1xx range; LLVMGPU_BaseLowering
// takes 101 and the cases after it are numbered consecutively up to 110.
def LLVMGPU_Default
    : I32EnumAttrCase<"LLVMGPUDefault", 100>;
def LLVMGPU_BaseLowering
    : I32EnumAttrCase<"LLVMGPUBaseLowering", 101>;
def LLVMGPU_SimpleDistribute
    : I32EnumAttrCase<"LLVMGPUDistribute", 102>;
def LLVMGPU_Vectorize
    : I32EnumAttrCase<"LLVMGPUVectorize", 103>;
def LLVMGPU_MatmulSimt
    : I32EnumAttrCase<"LLVMGPUMatmulSimt", 104>;
def LLVMGPU_MatmulTensorCore
    : I32EnumAttrCase<"LLVMGPUMatmulTensorCore", 105>;
def LLVMGPU_TransposeSharedMem
    : I32EnumAttrCase<"LLVMGPUTransposeSharedMem", 106>;
def LLVMGPU_WarpReduction
    : I32EnumAttrCase<"LLVMGPUWarpReduction", 107>;
def LLVMGPU_PackUnPack
    : I32EnumAttrCase<"LLVMGPUPackUnPack", 108>;
def LLVMGPU_MatmulTensorCoreMmaSync
    : I32EnumAttrCase<"LLVMGPUMatmulTensorCoreMmaSync", 109>;
def LLVMGPU_VectorDistribute
    : I32EnumAttrCase<"LLVMGPUVectorDistribute", 110>;

def SPIRV_BaseLowering
: I32EnumAttrCase<"SPIRVBaseLowering", 200>;
Expand Down Expand Up @@ -82,8 +84,8 @@ def DispatchLoweringPassPipelineEnum : I32EnumAttr<
CPU_DataTiling,

// LLVMGPU CodeGen pipelines
LLVMGPU_Default, LLVMGPU_SimpleDistribute, LLVMGPU_Vectorize,
LLVMGPU_MatmulSimt, LLVMGPU_MatmulTensorCore,
LLVMGPU_Default, LLVMGPU_BaseLowering, LLVMGPU_SimpleDistribute,
LLVMGPU_Vectorize, LLVMGPU_MatmulSimt, LLVMGPU_MatmulTensorCore,
LLVMGPU_TransposeSharedMem, LLVMGPU_WarpReduction, LLVMGPU_PackUnPack,
LLVMGPU_MatmulTensorCoreMmaSync, LLVMGPU_VectorDistribute,

Expand Down
26 changes: 26 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1497,6 +1497,10 @@ static void propagateLoweringConfig(Operation *rootOperation,
}
}

//===----------------------------------------------------------------------===//
// Entry Point
//===----------------------------------------------------------------------===//

LogicalResult initGPULaunchConfig(ModuleOp moduleOp) {
llvm::StringMap<IREE::HAL::ExecutableExportOp> exportOps =
getAllEntryPoints(moduleOp);
Expand All @@ -1505,6 +1509,28 @@ LogicalResult initGPULaunchConfig(ModuleOp moduleOp) {
auto exportOp = exportOps.lookup(funcOp.getName());
if (!exportOp)
continue;

// First check whether we already have workgroup count set--it's a
// "contract" to indicate that we should bypass all tiling and
// distribution to go down just the most basic lowering flow.
if (Block *body = exportOp.getWorkgroupCountBody()) {
auto retOp = cast<IREE::HAL::ReturnOp>(body->getTerminator());
// For scalar dispatch cases--using just one thread of one workgroup.
auto isOne = [](Value value) { return matchPattern(value, m_One()); };
if (llvm::all_of(retOp.getOperands(), isOne)) {
std::array<int64_t, 3> workgroupSize = {1, 1, 1};
if (failed(setDispatchConfig(funcOp, workgroupSize, std::nullopt)))
return failure();
auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
funcOp.getContext(),
IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUBaseLowering);
if (failed(setTranslationInfo(funcOp, translationInfo))) {
return failure();
}
continue;
}
}

SmallVector<Operation *> computeOps = getComputeOps(funcOp);
if (getTranslationInfo(exportOp)) {
// Currently LLVMGPU requires propagation of user lowering configs.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ void LLVMGPULowerExecutableTargetPass::runOnOperation() {
case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUDefault:
addGPUDefaultPassPipeline(pipeline, enableMicrokernels);
break;
case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUBaseLowering:
addGPUBaseLoweringPassPipeline(pipeline);
break;
case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUDistribute:
addGPUSimpleDistributePassPipeline(pipeline);
break;
Expand Down
23 changes: 23 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,29 @@ void addGPUDefaultPassPipeline(OpPassManager &pm, bool enableMicrokernels) {
createRemoveSingleIterationLoopPass());
}

void addGPUBaseLoweringPassPipeline(OpPassManager &pm) {
  // Every pass below is scheduled on the nested builtin.module.
  auto &modulePM = pm.nest<ModuleOp>();

  // Canonicalization followed by CSE is re-run after each stage.
  auto addCleanupPasses = [&modulePM]() {
    modulePM.addPass(createCanonicalizerPass());
    modulePM.addPass(createCSEPass());
  };

  // Stage 1: convert to destination-passing style.
  modulePM.addNestedPass<func::FuncOp>(
      createConvertToDestinationPassingStylePass(
          /*useWARForCooperativeMatrixCodegen=*/false));
  addCleanupPasses();

  // Stage 2: bufferization.
  addBufferizePasses(modulePM);
  addCleanupPasses();

  // Stage 3: lower LinalgExt ops, memref copies and Linalg ops to loops, then
  // drop loops that only execute a single iteration.
  modulePM.addNestedPass<func::FuncOp>(
      IREE::LinalgExt::createLinalgExtToLoopsPass());
  modulePM.addNestedPass<func::FuncOp>(createMemrefCopyToLinalgPass());
  modulePM.addNestedPass<func::FuncOp>(createConvertLinalgToLoopsPass());
  modulePM.addNestedPass<func::FuncOp>(createRemoveSingleIterationLoopPass());
  addCleanupPasses();
}

// Add passes to make the address computation more explicit and optimize them.
//
// The idea here is to be less dependent on what the LLVM backend is able to do,
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ void addGPUWarpReductionPassPipeline(OpPassManager &pm);
/// Default pass pipeline on GPU, currently used only for the ukernel path.
void addGPUDefaultPassPipeline(OpPassManager &pm, bool enableMicrokernels);

/// Pass pipeline to lower IREE HAL executables without tiling and distribution.
///
/// Run for dispatches whose translation info is `LLVMGPUBaseLowering`. The
/// pipeline converts to destination-passing style, bufferizes, and lowers
/// LinalgExt/Linalg ops to loops, with canonicalization and CSE in between.
/// All passes are added to the `ModuleOp` nested under `pm`.
void addGPUBaseLoweringPassPipeline(OpPassManager &pm);

/// Populates passes needed to preprocess and select the translation strategy.
void buildLLVMGPUCodegenConfigurationPassPipeline(OpPassManager &pm);

Expand Down
21 changes: 21 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLKernelConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,27 @@ LogicalResult initROCDLLaunchConfig(ModuleOp moduleOp) {
if (!exportOp)
continue;

// First check whether we already have workgroup count set--it's a
// "contract" to indicate that we should bypass all tiling and
// distribution to go down just the most basic lowering flow.
if (Block *body = exportOp.getWorkgroupCountBody()) {
auto retOp = cast<IREE::HAL::ReturnOp>(body->getTerminator());
// For scalar dispatch cases--using just one thread of one workgroup.
auto isOne = [](Value value) { return matchPattern(value, m_One()); };
if (llvm::all_of(retOp.getOperands(), isOne)) {
std::array<int64_t, 3> workgroupSize = {1, 1, 1};
if (failed(setDispatchConfig(funcOp, workgroupSize, std::nullopt)))
return failure();
auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
funcOp.getContext(),
IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUBaseLowering);
if (failed(setTranslationInfo(funcOp, translationInfo))) {
return failure();
}
continue;
}
}

SmallVector<Operation *> computeOps = getComputeOps(funcOp);
if (getTranslationInfo(exportOp)) {
// Currently ROCDL requires propagation of user lowering configs.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ class ROCDLLowerExecutableTargetPass
OpPassManager pipeline(variantOp.getOperationName());

switch (translationInfo.value().getDispatchLoweringPassPipeline()) {
case CodeGenPipeline::LLVMGPUBaseLowering:
addGPUBaseLoweringPassPipeline(pipeline);
break;
case CodeGenPipeline::LLVMGPUWarpReduction:
addGPUWarpReductionPassPipeline(pipeline);
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "iree/compiler/Codegen/LLVMGPU/ROCDLPassDetail.h"
#include "iree/compiler/Codegen/LLVMGPU/ROCDLPasses.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"

Expand All @@ -21,7 +22,11 @@ class ROCDLSelectLoweringStrategyPass
: public ROCDLSelectLoweringStrategyBase<ROCDLSelectLoweringStrategyPass> {
public:
/// Registers the dialects this pass depends on: the IREE codegen dialect and
/// the bufferization dialect. Both go through a single `insert` call so that
/// no dialect is registered twice.
void getDependentDialects(DialectRegistry &registry) const override {
  // clang-format off
  registry
      .insert<IREE::Codegen::IREECodegenDialect,
              bufferization::BufferizationDialect>();
  // clang-format on
}

void runOnOperation() override {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ iree_lit_test_suite(
srcs = enforce_glob(
[
"config_vector_distribute.mlir",
"lowering_scalar_dispatch.mlir",
"pipeline_vector_distribute.mlir",
"pipeline_warp_reduction.mlir",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ iree_lit_test_suite(
lit
SRCS
"config_vector_distribute.mlir"
"lowering_scalar_dispatch.mlir"
"pipeline_vector_distribute.mlir"
"pipeline_warp_reduction.mlir"
TOOLS
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-rocdl-select-lowering-strategy, iree-rocdl-lower-executable-target)))' -mlir-print-local-scope %s | FileCheck %s

#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {target_arch = "gfx90a", ukernels = "none"}>

#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>

// Test input: a scalar dispatch whose export region already provides its
// workgroup count. The count returned is 1x1x1, and the checks at the bottom of
// this file expect the LLVMGPUBaseLowering pipeline with a workgroup size of
// [1, 1, 1] to be selected for it.
hal.executable @scalar_dispatch {
hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm_hsaco_fb) {
hal.executable.export public @scalar_dispatch ordinal(0) layout(#pipeline_layout) {
^bb0(%arg0: !hal.device):
// All three workgroup count results are the constant 1.
%c1 = arith.constant 1 : index
hal.return %c1, %c1, %c1 : index, index, index
}
builtin.module {
// Loads a rank-0 i64 tensor from binding 0, computes x * C0 + C1 on the
// extracted scalar, and stores the result as a rank-0 tensor to binding 1.
func.func @scalar_dispatch() {
%c0 = arith.constant 0 : index
%c6364136223846793005_i64 = arith.constant 6364136223846793005 : i64
%c1442695040888963407_i64 = arith.constant 1442695040888963407 : i64
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<i64>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<i64>>
%2 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:tensor<i64>> -> tensor<i64>
%extracted = tensor.extract %2[] : tensor<i64>
%3 = arith.muli %extracted, %c6364136223846793005_i64 : i64
%4 = arith.addi %3, %c1442695040888963407_i64 : i64
%inserted = tensor.insert %4 into %2[] : tensor<i64>
flow.dispatch.tensor.store %inserted, %1, offsets = [], sizes = [], strides = [] : tensor<i64> -> !flow.dispatch.tensor<writeonly:tensor<i64>>
return
}
}
}
}

// CHECK-LABEL: hal.executable.export public @scalar_dispatch
// CHECK-SAME: translation_info = #iree_codegen.translation_info<LLVMGPUBaseLowering>
// CHECK-SAME: workgroup_size = [1 : index, 1 : index, 1 : index]

// CHECK: func.func @scalar_dispatch()
// CHECK: %[[SPAN0:.+]] = hal.interface.binding.subspan set(0) binding(0)
// CHECK: %[[SPAN1:.+]] = hal.interface.binding.subspan set(0) binding(1)
// CHECK: memref.load %[[SPAN0]][] : memref<i64, #hal.descriptor_type<storage_buffer>>
// CHECK: arith.muli {{.+}} : i64
// CHECK: arith.addi {{.+}} : i64
// CHECK: memref.store %{{.+}}, %[[SPAN1]][] : memref<i64, #hal.descriptor_type<storage_buffer>>

0 comments on commit 7881ed9

Please sign in to comment.