Skip to content

Commit

Permalink
Add a way to hook custom TD strategy for specific ops
Browse files Browse the repository at this point in the history
This patch adds a transform dialect interpreter pass that can be used to
annotate specific operations with specific strategies. This patch relies on
iree-org#14788 to actually "link" the strategy
within the related module.

The intended use case, as demonstrated in the added test cases, is to:
1. specify the matcher in a dedicated file (in the transform dialect format)
   that is passed to the compiler through
   `--iree-llvmcpu-transform-dialect-select-strategy`.
2. provide the strategy as a named sequence through the library option
   `--iree-codegen-transform-library-file-name`.

If the matcher from step 1 applies, then the transform dialect pipeline will pick
up the proper strategy from the library given in step 2 and apply it to the
annotated operations.
  • Loading branch information
qcolombet committed Sep 28, 2023
1 parent b05da0c commit c5a190a
Show file tree
Hide file tree
Showing 11 changed files with 275 additions and 0 deletions.
16 changes: 16 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,17 @@ extern llvm::cl::opt<std::string> clCPUCodegenTransformDialectFileName;
extern llvm::cl::opt<std::string> clCPUCodegenTransformDialectDebugPayloadTag;
extern llvm::cl::opt<std::string> clCPUCodegenTransformDialectDebugTransformTag;

// Command-line option `--iree-llvmcpu-transform-dialect-select-strategy`.
// Holds the path of an MLIR file containing a transform dialect script that
// selects a strategy for the operations it matches. Defaults to the empty
// string; buildLLVMCPUCodegenPassPipeline only adds the transform dialect
// interpreter pass for this script when the value is non-empty.
llvm::cl::opt<std::string> clCPUCodegenSelectTransformDialectConfigFileName(
"iree-llvmcpu-transform-dialect-select-strategy",
llvm::cl::desc(
"MLIR file with transform dialect script used to select a strategy for "
"the matched operations. The expectation is for this script to set the "
"proper `iree_codegen.compilation_info` attribute with the related "
"`TransformDialectCodegen codegen_spec` to select a specific strategy. "
"Strategies are expected to live in a library file provided via the "
"`--iree-codegen-transform-library-file-name` option."),
llvm::cl::init(""));

//===---------------------------------------------------------------------===//
// Default allocation functions for CPU backend
//===---------------------------------------------------------------------===//
Expand Down Expand Up @@ -775,6 +786,11 @@ void buildLLVMCPUCodegenPassPipeline(OpPassManager &passManager) {
modulePassManager.addPass(createEraseHALDescriptorTypeFromMemRefPass());
}

// Set config attributes on selected functions.
if (clCPUCodegenSelectTransformDialectConfigFileName != "")
passManager.addPass(
mlir::iree_compiler::createTransformDialectInterpreterPass(
clCPUCodegenSelectTransformDialectConfigFileName));
passManager.addPass(createLLVMCPULowerExecutableTargetPass());
OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();
addLowerToLLVMPasses(nestedModulePM);
Expand Down
7 changes: 7 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ iree_lit_test_suite(
"peel.mlir",
"peel_and_vectorize.mlir",
"pipeline_tests.mlir",
"set_transform_strategy_from_file.mlir",
"split_reduction.mlir",
"split_reduction_pipeline_tests.mlir",
"synchronize_symbol_visibility.mlir",
Expand All @@ -66,8 +67,14 @@ iree_lit_test_suite(
"verify_linalg_transform_legality.mlir",
],
include = ["*.mlir"],
exclude = [
"transform_dialect_dummy_spec.mlir",
],
),
cfg = "//compiler:lit.cfg.py",
data = [
"transform_dialect_dummy_spec.mlir",
],
tools = [
"//tools:iree-compile",
"//tools:iree-opt",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ iree_lit_test_suite(
"peel.mlir"
"peel_and_vectorize.mlir"
"pipeline_tests.mlir"
"set_transform_strategy_from_file.mlir"
"split_reduction.mlir"
"split_reduction_pipeline_tests.mlir"
"synchronize_symbol_visibility.mlir"
Expand All @@ -63,6 +64,8 @@ iree_lit_test_suite(
FileCheck
iree-compile
iree-opt
DATA
transform_dialect_dummy_spec.mlir
)

### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// RUN: iree-opt --split-input-file %s --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target)))" --iree-codegen-llvmcpu-use-transform-dialect=%p/transform_dialect_dummy_spec.mlir | FileCheck %s
// RUN: iree-opt --split-input-file %s --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target)))" --iree-codegen-transform-library-file-name=%p/transform_dialect_dummy_spec.mlir | FileCheck %s --check-prefix=CONFIG

// If we set the config on the command line, it takes precedence.
// CHECK: IR printer: from_flag

// When we include the library, we should honor the config we set in the
// attribute.
// CONFIG: IR printer: from_config

#blank_config = #iree_codegen.lowering_config<tile_sizes = []>
#translation = #iree_codegen.translation_info<TransformDialectCodegen codegen_spec=@print_config>
#config = #iree_codegen.compilation_info<lowering_config = #blank_config, translation_info = #translation, workgroup_size = []>

#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu = "broadwell", cpu_features = "+cmov,+mmx,+popcnt,+sse,+sse2,+sse3,+ssse3,+sse4.1,+sse4.2,+avx,+avx2,+fma,+bmi,+bmi2,+pclmul,+adx,+cx16,+cx8,+crc32,+f16c,+fsgsbase,+fxsr,+invpcid,+lzcnt,+movbe,+prfchw,+rdrnd,+rdseed,+sahf,+x87,+xsave,+xsaveopt", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 32 : index, target_triple = "x86_64-unknown-unknown-eabi-elf", ukernels = false}>
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>

#device_target_llvm_cpu = #hal.device.target<"llvm-cpu", {executable_targets = [#executable_target_embedded_elf_x86_64_]}>
module attributes {hal.device.targets = [#device_target_llvm_cpu]} {
hal.executable private @matmul_4x2304x768_f32_dispatch_0 {
hal.executable.variant public @embedded_elf_x86_64, target = #executable_target_embedded_elf_x86_64_ attributes {transform.target_tag = "payload_root"} {
hal.executable.export public @matmul_4x2304x768_f32_dispatch_0_generic_4x2304x768_f32 ordinal(0) layout(#pipeline_layout) {
^bb0(%arg0: !hal.device):
%x, %y, %z = flow.dispatch.workgroup_count_from_slice
hal.return %x, %y, %z : index, index, index
}

builtin.module {
// Payload dispatch function for this test. It loads the 4x768 and 768x2304
// f32 inputs plus the 4x2304 accumulator from the three bindings, combines
// them with a multiply/add linalg.generic, and stores the result back into
// binding 2.
func.func @matmul_4x2304x768_f32_dispatch_0_generic_4x2304x768_f32() {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4x768xf32>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<768x2304xf32>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readwrite:tensor<4x2304xf32>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [4, 768], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4x768xf32>> -> tensor<4x768xf32>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [768, 2304], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<768x2304xf32>> -> tensor<768x2304xf32>
%5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [4, 2304], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<4x2304xf32>> -> tensor<4x2304xf32>
// The op under test: it carries `compilation_info = #config`, whose
// translation_info requests TransformDialectCodegen with
// codegen_spec=@print_config.
%6 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%3, %4 : tensor<4x768xf32>, tensor<768x2304xf32>) outs(%5 : tensor<4x2304xf32>)
attrs = {compilation_info = #config} {
^bb0(%in: f32, %in_0: f32, %out: f32):
%7 = arith.mulf %in, %in_0 fastmath<fast> : f32
%8 = arith.addf %out, %7 fastmath<fast> : f32
linalg.yield %8 : f32
} -> tensor<4x2304xf32>
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [4, 2304], strides = [1, 1] : tensor<4x2304xf32> -> !flow.dispatch.tensor<readwrite:tensor<4x2304xf32>>
return
}
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: iree-opt %s

// Dummy transform dialect spec/library used as test data. Neither entry point
// lowers anything: each one only prints the op it receives under a distinct
// name, so a test can tell from the output which entry point ran.
module attributes { transform.with_named_sequence } {
// Named strategy that payload ops select through
// `codegen_spec=@print_config`. Prints its argument under the name
// "from_config".
transform.named_sequence @print_config(%variant_op: !transform.any_op {transform.consumed}) {
transform.print %variant_op {name = "from_config"} : !transform.any_op
transform.yield
}

// Unnamed top-level sequence. Prints its argument under the name "from_flag".
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
print %arg0 {name = "from_flag"} : !transform.any_op
transform.yield
}
}
4 changes: 4 additions & 0 deletions tests/transform_dialect/cpu/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ iree_lit_test_suite(
"eltwise_reduction_eltwise.mlir",
"fold_tensor_slice_into_transfer.mlir",
"matmul.mlir",
"select_strategy_matvec4.mlir",
"select_strategy_matvec6.mlir"
],
cfg = "//tests:lit.cfg.py",
# transform dialect spec files are MLIR files that specify a transformation,
Expand All @@ -30,6 +32,8 @@ iree_lit_test_suite(
"attention_codegen_spec.mlir",
"matmul_codegen_custom_dispatch_formation_spec.mlir",
"matmul_codegen_default_spec.mlir",
"transform_dialect_dummy_select.mlir",
"transform_dialect_dummy_spec.mlir",
],
tags = [
"noasan",
Expand Down
4 changes: 4 additions & 0 deletions tests/transform_dialect/cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ iree_lit_test_suite(
"eltwise_reduction_eltwise.mlir"
"fold_tensor_slice_into_transfer.mlir"
"matmul.mlir"
"select_strategy_matvec4.mlir"
"select_strategy_matvec6.mlir"
TOOLS
${IREE_LLD_TARGET}
FileCheck
Expand All @@ -31,6 +33,8 @@ iree_lit_test_suite(
attention_codegen_spec.mlir
matmul_codegen_custom_dispatch_formation_spec.mlir
matmul_codegen_default_spec.mlir
transform_dialect_dummy_select.mlir
transform_dialect_dummy_spec.mlir
LABELS
"noasan"
"nomsan"
Expand Down
50 changes: 50 additions & 0 deletions tests/transform_dialect/cpu/select_strategy_matvec4.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// RUN: not iree-compile \
// RUN: --iree-codegen-transform-library-file-name=%p/transform_dialect_dummy_spec.mlir \
// RUN: --iree-hal-target-backends=llvm-cpu \
// RUN: %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=DEFAULT

// RUN: not iree-compile \
// RUN: --iree-codegen-transform-library-file-name=%p/transform_dialect_dummy_spec.mlir \
// RUN: --iree-llvmcpu-transform-dialect-select-strategy=%p/transform_dialect_dummy_select.mlir \
// RUN: --iree-hal-target-backends=llvm-cpu \
// RUN: %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=MATCHER

// Check that we can select the transform dialect strategy used for the lowering
// by passing a "select file" to the compiler.
// We know the compilation will fail here because we use a dummy strategy, i.e.,
// the code is not actually lowered. Hence the `not` in the command line.

// By default, we don't select a strategy and just use what is set in the
// attribute: print_config.
// DEFAULT: IR printer: from_config

// When using the matcher, check that we override what is already set in the
// attribute. I.e., use print_selected4 instead of print_config.
// MATCHER: IR printer: from_selected4

!tlhs = tensor<4x768xf32>
!trhs = tensor<768x2304xf32>
!tres = tensor<4x2304xf32>

#blank_config = #iree_codegen.lowering_config<tile_sizes = []>
#translation = #iree_codegen.translation_info<TransformDialectCodegen codegen_spec=@print_config>
#config = #iree_codegen.compilation_info<lowering_config = #blank_config, translation_info = #translation, workgroup_size = []>

// Payload under test: multiplies a 4x768 f32 tensor by a 768x2304 f32 tensor
// and accumulates into the 4x2304 f32 output, written as a linalg.generic.
// The op is pre-annotated with `compilation_info = #config`, whose
// translation_info names the @print_config strategy.
func.func @matmul_4x2304x768_f32(
%a: !tlhs,
%b: !trhs,
%c: !tres) -> !tres attributes { llvm.emit_c_interface } {
%result = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
affine_map<(d0, d1, d2) -> (d2, d1)>,
affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"]
} ins(%a, %b: !tlhs, !trhs) outs(%c: !tres)
attrs = {compilation_info = #config} {
^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
%0 = arith.mulf %arg0, %arg1 {fastmath = #arith.fastmath<fast>} : f32
%1 = arith.addf %arg2, %0 {fastmath = #arith.fastmath<fast>} : f32
linalg.yield %1 : f32
} -> !tres
return %result : !tres
}
40 changes: 40 additions & 0 deletions tests/transform_dialect/cpu/select_strategy_matvec6.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// RUN: not iree-compile \
// RUN: --iree-codegen-transform-library-file-name=%p/transform_dialect_dummy_spec.mlir \
// RUN: --iree-llvmcpu-transform-dialect-select-strategy=%p/transform_dialect_dummy_select.mlir \
// RUN: --iree-hal-target-backends=llvm-cpu \
// RUN: %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=MATCHER

// Check that we can select the transform dialect strategy used for the lowering
// by passing a "select file" to the compiler.
// We know the compilation will fail here because we use a dummy strategy, i.e.,
// the code is not actually lowered. Hence the `not` in the command line.

// When using the matcher, check that we can run the right transform dialect
// strategy, even if the matched instruction didn't have the
// "TransformDialectCodegen codegen_spec" attribute before.
// I.e., make sure the attribute gets added properly by observing that the
// expected transform gets called. In this case we expect print_selected6 to
// be the strategy that runs.
// MATCHER: IR printer: from_selected6

!tlhs = tensor<6x768xf32>
!trhs = tensor<768x2304xf32>
!tres = tensor<6x2304xf32>

// Payload under test: multiplies a 6x768 f32 tensor by a 768x2304 f32 tensor
// and accumulates into the 6x2304 f32 output, written as a linalg.generic.
// Unlike the 4-row variant of this test, the op carries no `compilation_info`
// attribute, so any strategy selection has to come from the select script.
func.func @matmul_6x2304x768_f32(
%a: !tlhs,
%b: !trhs,
%c: !tres) -> !tres attributes { llvm.emit_c_interface } {
%result = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
affine_map<(d0, d1, d2) -> (d2, d1)>,
affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"]
} ins(%a, %b: !tlhs, !trhs) outs(%c: !tres) {
^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
%0 = arith.mulf %arg0, %arg1 {fastmath = #arith.fastmath<fast>} : f32
%1 = arith.addf %arg2, %0 {fastmath = #arith.fastmath<fast>} : f32
linalg.yield %1 : f32
} -> !tres
return %result : !tres
}
60 changes: 60 additions & 0 deletions tests/transform_dialect/cpu/transform_dialect_dummy_select.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#transform = #iree_codegen.translation_info<TransformDialectCodegen>
#blank_config = #iree_codegen.lowering_config<tile_sizes = []>
#translation4 = #iree_codegen.translation_info<TransformDialectCodegen codegen_spec=@print_selected4>
#matvec4_config = #iree_codegen.compilation_info<lowering_config = #blank_config, translation_info = #translation4, workgroup_size = []>
#translation6 = #iree_codegen.translation_info<TransformDialectCodegen codegen_spec=@print_selected6>
#matvec6_config = #iree_codegen.compilation_info<lowering_config = #blank_config, translation_info = #translation6, workgroup_size = []>

// Strategy-selection script. It matches structured ops by the size of their
// dimension 0 and tags every match with a `compilation_info` attribute that
// names the transform dialect strategy to use for that op.
module attributes { transform.with_named_sequence } {


//===------------------------------------------------------===
// Matvec
//===------------------------------------------------------===
// Matcher: succeeds when %arg0 is a structured op whose dimension 0 is equal
// to 4. Yields the matched op together with #matvec4_config (which names the
// @print_selected4 strategy) as a parameter.
transform.named_sequence @match_matvec4(%arg0: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
%0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op {
^bb1(%arg1: !transform.any_op):
%c4 = transform.param.constant 4 : i64 -> !transform.param<i64>

// Read the size of dimension 0 and require it to be exactly 4.
%dim = transform.match.structured.dim %arg1[0] : (!transform.any_op) -> !transform.param<i64>
transform.match.param.cmpi eq %c4, %dim : !transform.param<i64>
transform.match.structured.yield %arg1 : !transform.any_op
}

%config = transform.param.constant #matvec4_config -> !transform.any_param
transform.yield %0, %config : !transform.any_op, !transform.any_param
}

// Matcher: same as @match_matvec4 but requires dimension 0 to be equal to 6
// and yields #matvec6_config (which names the @print_selected6 strategy).
transform.named_sequence @match_matvec6(%arg0: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
%0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op {
^bb1(%arg1: !transform.any_op):
%c6 = transform.param.constant 6 : i64 -> !transform.param<i64>

// Read the size of dimension 0 and require it to be exactly 6.
%dim = transform.match.structured.dim %arg1[0] : (!transform.any_op) -> !transform.param<i64>
transform.match.param.cmpi eq %c6, %dim : !transform.param<i64>
transform.match.structured.yield %arg1 : !transform.any_op
}

%config = transform.param.constant #matvec6_config -> !transform.any_param
transform.yield %0, %config : !transform.any_op, !transform.any_param
}

//===------------------------------------------------------===
// Annotation and Application
//===------------------------------------------------------===

// Action: sets the "compilation_info" attribute of %target to the value
// carried by %config.
transform.named_sequence @annotate_op(%target: !transform.any_op {transform.readonly}, %config: !transform.any_param {transform.readonly}) {
transform.annotate %target "compilation_info" = %config : !transform.any_op, !transform.any_param
transform.yield
}


// Entry point: collects the func.func ops under the root op, then runs
// foreach_match over them with the two matcher -> action pairs below, so that
// ops accepted by a matcher are annotated with that matcher's config.
transform.sequence failures(propagate) {
^bb0(%dispatch: !transform.any_op):
%dispatch_func = transform.structured.match ops{["func.func"]} in %dispatch : (!transform.any_op) -> !transform.any_op
transform.foreach_match in %dispatch_func
@match_matvec4 -> @annotate_op,
@match_matvec6 -> @annotate_op
: (!transform.any_op) -> (!transform.any_op)
}
}
24 changes: 24 additions & 0 deletions tests/transform_dialect/cpu/transform_dialect_dummy_spec.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: iree-opt %s

// Dummy transform dialect strategy library used as test data. None of the
// sequences lowers anything: each one only prints the op it receives under a
// distinct name, so a test can tell from the output which strategy was picked.
module attributes { transform.with_named_sequence } {
// Strategy selected through `codegen_spec=@print_config`. Prints its
// argument under the name "from_config".
transform.named_sequence @print_config(%variant_op: !transform.any_op {transform.consumed}) {
transform.print %variant_op {name = "from_config"} : !transform.any_op
transform.yield
}

// Strategy selected through `codegen_spec=@print_selected4`. Prints its
// argument under the name "from_selected4".
transform.named_sequence @print_selected4(%variant_op: !transform.any_op {transform.consumed}) {
transform.print %variant_op {name = "from_selected4"} : !transform.any_op
transform.yield
}

// Strategy selected through `codegen_spec=@print_selected6`. Prints its
// argument under the name "from_selected6".
transform.named_sequence @print_selected6(%variant_op: !transform.any_op {transform.consumed}) {
transform.print %variant_op {name = "from_selected6"} : !transform.any_op
transform.yield
}

// Unnamed top-level sequence. Prints its argument under the name "from_flag".
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
print %arg0 {name = "from_flag"} : !transform.any_op
transform.yield
}
}

0 comments on commit c5a190a

Please sign in to comment.