Skip to content

Commit

Permalink
select-ukernels
Browse files Browse the repository at this point in the history
Signed-off-by: Benoit Jacob <[email protected]>
  • Loading branch information
bjacob committed Dec 13, 2024
1 parent eae7bfb commit a6e5117
Show file tree
Hide file tree
Showing 19 changed files with 406 additions and 314 deletions.
3 changes: 2 additions & 1 deletion compiler/plugins/target/ROCM/test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ package(
iree_lit_test_suite(
name = "lit",
srcs = [
"config_ukernel_argmax_gfx908.mlir",
"config_ukernel_argmax_gfx942.mlir",
"default_tuning_specs_amdgpu.mlir",
"gpu_lower_to_ukernels.mlir",
"lowering_strategy_from_tuning_spec.mlir",
"ukernel_pipeline_transform.mlir",
],
Expand Down
3 changes: 2 additions & 1 deletion compiler/plugins/target/ROCM/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ iree_lit_test_suite(
NAME
lit
SRCS
"config_ukernel_argmax_gfx908.mlir"
"config_ukernel_argmax_gfx942.mlir"
"default_tuning_specs_amdgpu.mlir"
"gpu_lower_to_ukernels.mlir"
"lowering_strategy_from_tuning_spec.mlir"
"ukernel_pipeline_transform.mlir"
TOOLS
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx908 --pass-pipeline='builtin.module(iree-llvmgpu-select-lowering-strategy)' %s | FileCheck %s

// gfx908 a.k.a. CDNA1 is used here as an example of a GPU target that we don't have ukernels for.
// No need to add many ukernels here, just a quick check that we correctly do not select a ukernel.

func.func @argmax_2d_f32i64(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes {
hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}>
} {
%c0_i64 = arith.constant 0 : i64
%cst = arith.constant 0xFF800000 : f32
%0 = tensor.empty() : tensor<1xi64>
%1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<1xi64>) -> tensor<1xi64>
%2 = tensor.empty() : tensor<1xf32>
%3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1xf32>) -> tensor<1xf32>
%4:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<1x?xf32>) outs(%3, %1 : tensor<1xf32>, tensor<1xi64>) {
^bb0(%in: f32, %out: f32, %out_0: i64):
%5 = linalg.index 1 : index
%6 = arith.index_cast %5 : index to i64
%7 = arith.maximumf %in, %out : f32
%8 = arith.cmpf ogt, %in, %out : f32
%9 = arith.select %8, %6, %out_0 : i64
linalg.yield %7, %9 : f32, i64
} -> (tensor<1xf32>, tensor<1xi64>)
return %4#1 : tensor<1xi64>
}

// CHECK-NOT: lowering_config<{{.*}}ukernel
// CHECK-LABEL: func @argmax_2d_f32i64(
// CHECK: linalg.generic
// CHECK-NOT: hal.executable.objects
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx942 --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-lower-to-ukernels,cse,canonicalize))" %s | FileCheck %s
// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx908 --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-lower-to-ukernels,cse,canonicalize))" %s | FileCheck %s --check-prefix=CDNA1
// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx942 --pass-pipeline='builtin.module(iree-llvmgpu-select-lowering-strategy)' %s | FileCheck %s

func.func @argmax_2d_f32i64(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes {
hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}>
Expand All @@ -22,15 +21,10 @@ func.func @argmax_2d_f32i64(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes
return %4#1 : tensor<1xi64>
}

//CHECK-LABEL: func @argmax_2d_f32i64(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x?xf32>
// CHECK-DAG: %[[C1_index:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C0_i64:.+]] = arith.constant 0
// CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[C0_i64]]
// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic {hal.executable.objects = [{{.*}}]} "iree_uk_amdgpu_argmax_f32i64"
// CHECK-SAME: ins(%[[ARG0]] :
// CHECK-SAME: outs(%[[FILL]] :
// CHECK: return %[[MICRO_KERNEL]]
// CHECK: #iree_codegen.lowering_config<{{.*}}ukernel = [<name = "iree_uk_amdgpu_argmax_f32i64", def_attrs = {vm.import.module = "rocm"}>]
// CHECK-LABEL: func @argmax_2d_f32i64(
// CHECK: linalg.generic
// CHECK-SAME: hal.executable.objects = [#hal.executable.object<{path = "iree_uk_amdgpu_argmax_f32i64.gfx942.bc", data = dense_resource<iree_uk_amdgpu_argmax_f32i64.gfx942.bc> : vector<{{[0-9]+}}xi8>}>]

// -----

Expand All @@ -55,65 +49,10 @@ func.func @argmax_4d_unit_parallel_f32i64(%arg0 : tensor<1x1x1x?xf32>) -> tensor
return %4#1 : tensor<1x1x1xi64>
}

// CHECK-LABEL: func @argmax_4d_unit_parallel_f32i64(
// CHECK: iree_codegen.ukernel.generic
// CHECK-NOT: linalg.generic

// -----

func.func @argmax_2d_non_unit_parallel_f32i64(%arg0 : tensor<4x?xf32>) -> tensor<4xi64> attributes {
hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}>
} {
%c0_i64 = arith.constant 0 : i64
%cst = arith.constant 0xFF800000 : f32
%0 = tensor.empty() : tensor<4xi64>
%1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<4xi64>) -> tensor<4xi64>
%2 = tensor.empty() : tensor<4xf32>
%3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<4xf32>) -> tensor<4xf32>
%4:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<4x?xf32>) outs(%3, %1 : tensor<4xf32>, tensor<4xi64>) {
^bb0(%in: f32, %out: f32, %out_0: i64):
%5 = linalg.index 1 : index
%6 = arith.index_cast %5 : index to i64
%7 = arith.maximumf %in, %out : f32
%8 = arith.cmpf ogt, %in, %out : f32
%9 = arith.select %8, %6, %out_0 : i64
linalg.yield %7, %9 : f32, i64
} -> (tensor<4xf32>, tensor<4xi64>)
return %4#1 : tensor<4xi64>
}

// CHECK-LABEL: func @argmax_2d_non_unit_parallel_f32i64(
// CHECK-NOT: iree_codegen.ukernel.generic
// CHECK: linalg.generic

// -----

func.func @argmax_2d_dyn_parallel_f32i64(%arg0 : tensor<?x?xf32>) -> tensor<?xi64> attributes {
hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}>
} {
%c0 = arith.constant 0 : index
%c0_i64 = arith.constant 0 : i64
%cst = arith.constant 0xFF800000 : f32
%dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%0 = tensor.empty(%dim) : tensor<?xi64>
%1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<?xi64>) -> tensor<?xi64>
%2 = tensor.empty(%dim) : tensor<?xf32>
%3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<?xf32>) -> tensor<?xf32>
%4:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<?x?xf32>) outs(%3, %1 : tensor<?xf32>, tensor<?xi64>) {
^bb0(%in: f32, %out: f32, %out_0: i64):
%5 = linalg.index 1 : index
%6 = arith.index_cast %5 : index to i64
%7 = arith.maximumf %in, %out : f32
%8 = arith.cmpf ogt, %in, %out : f32
%9 = arith.select %8, %6, %out_0 : i64
linalg.yield %7, %9 : f32, i64
} -> (tensor<?xf32>, tensor<?xi64>)
return %4#1 : tensor<?xi64>
}

// CHECK-LABEL: func @argmax_2d_dyn_parallel_f32i64(
// CHECK-NOT: iree_codegen.ukernel.generic
// CHECK: linalg.generic
// CHECK: #iree_codegen.lowering_config<{{.*}}ukernel = [<name = "iree_uk_amdgpu_argmax_f32i64", def_attrs = {vm.import.module = "rocm"}>]
// CHECK-LABEL: func @argmax_4d_unit_parallel_f32i64(
// CHECK: linalg.generic
// CHECK-SAME: hal.executable.objects = [#hal.executable.object<{path = "iree_uk_amdgpu_argmax_f32i64.gfx942.bc", data = dense_resource<iree_uk_amdgpu_argmax_f32i64.gfx942.bc> : vector<{{[0-9]+}}xi8>}>]

// -----

Expand All @@ -138,9 +77,10 @@ func.func @argmax_none_ukernel_enabled(%arg0 : tensor<1x?xf32>) -> tensor<1xi64>
return %4#1 : tensor<1xi64>
}

// CHECK-LABEL: func @argmax_none_ukernel_enabled(
// CHECK-NOT: iree_codegen.ukernel.generic
// CHECK: linalg.generic
// CHECK-NOT: lowering_config<{{.*}}ukernel
// CHECK-LABEL: func @argmax_none_ukernel_enabled(
// CHECK: linalg.generic
// CHECK-NOT: hal.executable.objects

// -----

Expand All @@ -165,9 +105,10 @@ func.func @argmax_only_argmax_ukernel_enabled(%arg0 : tensor<1x?xf32>) -> tensor
return %4#1 : tensor<1xi64>
}

// CDNA2-LABEL: func @argmax_only_argmax_ukernel_enabled(
// CDNA2: iree_codegen.ukernel.generic
// CDNA2-NOT: linalg.generic
// CHECK: #iree_codegen.lowering_config<{{.*}}ukernel = [<name = "iree_uk_amdgpu_argmax_f32i64", def_attrs = {vm.import.module = "rocm"}>]
// CHECK-LABEL: func @argmax_only_argmax_ukernel_enabled(
// CHECK: linalg.generic
// CHECK-SAME: hal.executable.objects = [#hal.executable.object<{path = "iree_uk_amdgpu_argmax_f32i64.gfx942.bc", data = dense_resource<iree_uk_amdgpu_argmax_f32i64.gfx942.bc> : vector<{{[0-9]+}}xi8>}>]

// -----

Expand All @@ -192,11 +133,10 @@ func.func @argmax_only_foo_argmax_bar_ukernel_enabled(%arg0 : tensor<1x?xf32>) -
return %4#1 : tensor<1xi64>
}

// CHECK-LABEL: func @argmax_only_foo_argmax_bar_ukernel_enabled(
// CHECK: iree_codegen.ukernel.generic
// CHECK-NOT: linalg.generic

// CDNA2-LABEL: func @argmax_only_foo_argmax_bar_ukernel_enabled(
// CHECK: #iree_codegen.lowering_config<{{.*}}ukernel = [<name = "iree_uk_amdgpu_argmax_f32i64", def_attrs = {vm.import.module = "rocm"}>]
// CHECK-LABEL: func @argmax_only_foo_argmax_bar_ukernel_enabled(
// CHECK: linalg.generic
// CHECK-SAME: hal.executable.objects = [#hal.executable.object<{path = "iree_uk_amdgpu_argmax_f32i64.gfx942.bc", data = dense_resource<iree_uk_amdgpu_argmax_f32i64.gfx942.bc> : vector<{{[0-9]+}}xi8>}>]

// -----

Expand All @@ -221,9 +161,10 @@ func.func @argmax_only_foo_ukernel_enabled(%arg0 : tensor<1x?xf32>) -> tensor<1x
return %4#1 : tensor<1xi64>
}

// CHECK-LABEL: func @argmax_only_foo_ukernel_enabled(
// CHECK-NOT: iree_codegen.ukernel.generic
// CHECK: linalg.generic
// CHECK-NOT: lowering_config<{{.*}}ukernel
// CHECK-LABEL: func @argmax_only_foo_ukernel_enabled(
// CHECK: linalg.generic
// CHECK-NOT: hal.executable.objects

// -----

Expand All @@ -249,46 +190,16 @@ func.func @argmax_2d_f32i64_not_neg_inf_init(%arg0 : tensor<1x?xf32>) -> tensor<
return %4#1 : tensor<1xi64>
}

// CHECK-LABEL: func @argmax_2d_f32i64_not_neg_inf_init(
// CHECK-NOT: iree_codegen.ukernel.generic
// CHECK: linalg.generic

// -----

// TODO: No technical reason this architecture is not supported.
// Currently just picking out popular chips to support,
// to minimize compile time and space.

func.func @argmax_ukernel_unsupported_arch(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes {
hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}>
} {
%c0_i64 = arith.constant 0 : i64
%cst = arith.constant 0xFF800000 : f32
%0 = tensor.empty() : tensor<1xi64>
%1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<1xi64>) -> tensor<1xi64>
%2 = tensor.empty() : tensor<1xf32>
%3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1xf32>) -> tensor<1xf32>
%4:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<1x?xf32>) outs(%3, %1 : tensor<1xf32>, tensor<1xi64>) {
^bb0(%in: f32, %out: f32, %out_0: i64):
%5 = linalg.index 1 : index
%6 = arith.index_cast %5 : index to i64
%7 = arith.maximumf %in, %out : f32
%8 = arith.cmpf ogt, %in, %out : f32
%9 = arith.select %8, %6, %out_0 : i64
linalg.yield %7, %9 : f32, i64
} -> (tensor<1xf32>, tensor<1xi64>)
return %4#1 : tensor<1xi64>
}

// CDNA1-LABEL: func @argmax_ukernel_unsupported_arch(
// CDNA1-NOT: iree_codegen.ukernel.generic
// CDNA1: linalg.generic
// CHECK-NOT: lowering_config<{{.*}}ukernel
// CHECK-LABEL: func @argmax_2d_f32i64_not_neg_inf_init(
// CHECK: linalg.generic
// CHECK-NOT: hal.executable.objects

// -----

// Test user-provided bitcode in the source IR.

func.func @argmax_2d_f32i64(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes {
func.func @argmax_2d_f32i64_custom_bitcode(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes {
hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}>,
// Dummy bitcode with an unusual length of 12. The first 4 bytes are the .bc file format signature.
hal.executable.objects = [
Expand Down Expand Up @@ -316,18 +227,12 @@ func.func @argmax_2d_f32i64(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes
return %4#1 : tensor<1xi64>
}

//CHECK-LABEL: func @argmax_2d_f32i64(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x?xf32>
// CHECK-DAG: %[[C1_index:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C0_i64:.+]] = arith.constant 0
// CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[C0_i64]]
// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic {
// CHECK-SAME: hal.executable.objects = [
// CHECK-SAME: #hal.executable.object<{
// CHECK-SAME: path = "iree_uk_amdgpu_argmax_f32i64.gfx942.bc",
// CHECK-SAME: data = dense<[66, 67, -64, -34, 1, 35, 69, 103, -119, -85, -51, -17]> : tensor<12xi8>
// CHECK-SAME: }>
// CHECK-SAME: ]} "iree_uk_amdgpu_argmax_f32i64"
// CHECK-SAME: ins(%[[ARG0]] :
// CHECK-SAME: outs(%[[FILL]] :
// CHECK: return %[[MICRO_KERNEL]]
// CHECK: #iree_codegen.lowering_config<{{.*}}ukernel = [<name = "iree_uk_amdgpu_argmax_f32i64", def_attrs = {vm.import.module = "rocm"}>]
// CHECK-LABEL: func @argmax_2d_f32i64_custom_bitcode(
// CHECK: linalg.generic
// CHECK-SAME: hal.executable.objects = [
// CHECK-SAME: #hal.executable.object<{
// CHECK-SAME: path = "iree_uk_amdgpu_argmax_f32i64.gfx942.bc",
// CHECK-SAME: data = dense<[66, 67, -64, -34, 1, 35, 69, 103, -119, -85, -51, -17]> : tensor<12xi8>
// CHECK-SAME: }>
// CHECK-SAME: ]
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func.func @argmax_1d_f16i64() attributes {
// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUDefault workgroup_size = [32, 1, 1]>
// CHECK: func.func @argmax_1d_f16i64()
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
// CHECK: iree_codegen.ukernel.generic {hal.executable.objects = [{{.*}}]} "iree_uk_amdgpu_argmax_f16i64"
// CHECK: iree_codegen.ukernel.generic "iree_uk_amdgpu_argmax_f16i64"

// -----

Expand Down Expand Up @@ -94,7 +94,7 @@ func.func @argmax_2d_f32i64() attributes {
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
// CHECK: %[[SUBVIEW:.*]] = memref.subview{{.*}} memref<16x?xf32
// CHECK-SAME: to memref<1x?xf32
// CHECK: iree_codegen.ukernel.generic {hal.executable.objects = [{{.*}}]} "iree_uk_amdgpu_argmax_f32i64" ins(%[[SUBVIEW]]
// CHECK: iree_codegen.ukernel.generic "iree_uk_amdgpu_argmax_f32i64" ins(%[[SUBVIEW]]

// -----

Expand Down
Loading

0 comments on commit a6e5117

Please sign in to comment.