Skip to content

Commit

Permalink
Simplify tests/e2e/tosa_ops. (#17850)
Browse files Browse the repository at this point in the history
Follow-up to #17843, forked from
#17766

Fixes #11828
  • Loading branch information
ScottTodd authored Jul 30, 2024
1 parent 890bdc9 commit 6145b65
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 245 deletions.
2 changes: 1 addition & 1 deletion tests/e2e/stablehlo_ops/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ iree_check_single_backend_test_suite(
)

iree_check_single_backend_test_suite(
name = "check_meta-spirv_metal",
name = "check_metal-spirv_metal",
srcs = enforce_glob(
# keep sorted
[
Expand Down
2 changes: 1 addition & 1 deletion tests/e2e/stablehlo_ops/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ iree_check_single_backend_test_suite(

iree_check_single_backend_test_suite(
NAME
check_meta-spirv_metal
check_metal-spirv_metal
SRCS
"abs.mlir"
"add.mlir"
Expand Down
193 changes: 25 additions & 168 deletions tests/e2e/tosa_ops/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,6 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# Tests of end-to-end IREE support for individual ops in the TOSA dialect.
# Each test file should have a name matching the corresponding TOSA op and test only the
# functionality of that op (though may make use of other ops where necessary). Tests should be
# written using the IREE Check framework.
# See https://iree.dev/developers/general/testing-guide/#iree-core-end-to-end-e2e-tests.

load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
load("//build_tools/bazel:iree_check_test.bzl", "iree_check_single_backend_test_suite")

Expand All @@ -18,7 +12,7 @@ package(
licenses = ["notice"], # Apache 2.0
)

LLVM_SRCS = enforce_glob(
ALL_SRCS = enforce_glob(
[
"abs.mlir",
"add.mlir",
Expand Down Expand Up @@ -67,119 +61,23 @@ LLVM_SRCS = enforce_glob(

iree_check_single_backend_test_suite(
name = "check_llvm-cpu_local-task",
srcs = LLVM_SRCS,
srcs = ALL_SRCS,
driver = "local-task",
input_type = "tosa",
target_backend = "llvm-cpu",
)

VMVX_SRCS = enforce_glob(
[
"abs.mlir",
"add.mlir",
"arithmetic_right_shift.mlir",
"bitwise_and.mlir",
"bitwise_or.mlir",
"bitwise_xor.mlir",
"ceil.mlir",
"clamp.mlir",
"clz.mlir",
"const.mlir",
"equal.mlir",
"exp.mlir",
"floor.mlir",
"fully_connected.mlir",
"gather.mlir",
"greater.mlir",
"greater_equal.mlir",
"if.mlir",
"log.mlir",
"logical_left_shift.mlir",
"logical_right_shift.mlir",
"logical_right_shift_16.mlir",
"matmul.mlir",
"max_pool.mlir",
"maximum.mlir",
"minimum.mlir",
"mul.mlir",
"mul_shift.mlir",
"negate.mlir",
"pad.mlir",
"reciprocal.mlir",
"reshape.mlir",
"rsqrt.mlir",
"select.mlir",
"sigmoid.mlir",
"sub.mlir",
"table.mlir",
"tanh.mlir",
"transpose.mlir",
"while.mlir",
],
include = ["*.mlir"],
exclude = [
"reduce.mlir", # Currently flakey https://github.com/iree-org/iree/issues/5885
],
)

iree_check_single_backend_test_suite(
name = "check_vmvx_local-task",
srcs = VMVX_SRCS,
srcs = ALL_SRCS,
driver = "local-task",
input_type = "tosa",
target_backend = "vmvx",
)

VMVX_MICROKERNELS_SRCS = enforce_glob(
[
"abs.mlir",
"add.mlir",
"arithmetic_right_shift.mlir",
"bitwise_and.mlir",
"bitwise_or.mlir",
"bitwise_xor.mlir",
"ceil.mlir",
"clamp.mlir",
"clz.mlir",
"const.mlir",
"equal.mlir",
"exp.mlir",
"floor.mlir",
"fully_connected.mlir",
"gather.mlir",
"greater.mlir",
"greater_equal.mlir",
"if.mlir",
"log.mlir",
"logical_left_shift.mlir",
"logical_right_shift.mlir",
"logical_right_shift_16.mlir",
"matmul.mlir",
"max_pool.mlir",
"maximum.mlir",
"minimum.mlir",
"mul.mlir",
"mul_shift.mlir",
"negate.mlir",
"pad.mlir",
"reciprocal.mlir",
"reduce.mlir",
"reshape.mlir",
"rsqrt.mlir",
"select.mlir",
"sigmoid.mlir",
"sub.mlir",
"table.mlir",
"tanh.mlir",
"transpose.mlir",
"while.mlir",
],
include = ["*.mlir"],
)

iree_check_single_backend_test_suite(
name = "check_vmvx_local-sync_microkernels",
srcs = VMVX_MICROKERNELS_SRCS,
srcs = ALL_SRCS,
compiler_flags = [
"--iree-vmvx-enable-microkernels",
],
Expand All @@ -189,64 +87,23 @@ iree_check_single_backend_test_suite(
target_backend = "vmvx",
)

VULKAN_SRCS = enforce_glob(
[
"abs.mlir",
"add.mlir",
"arithmetic_right_shift.mlir",
"bitwise_and.mlir",
"bitwise_or.mlir",
"bitwise_xor.mlir",
"ceil.mlir",
"clamp.mlir",
"clz.mlir",
"const.mlir",
"equal.mlir",
"exp.mlir",
"floor.mlir",
"fully_connected.mlir",
"gather.mlir",
"greater.mlir",
"greater_equal.mlir",
"if.mlir",
"log.mlir",
"logical_left_shift.mlir",
"logical_right_shift.mlir",
"matmul.mlir",
"max_pool.mlir",
"maximum.mlir",
"minimum.mlir",
"mul_shift.mlir",
"mul.mlir",
"negate.mlir",
"pad.mlir",
"reciprocal.mlir",
"reduce.mlir",
"reshape.mlir",
"rsqrt.mlir",
"select.mlir",
"sigmoid.mlir",
"sub.mlir",
"table.mlir",
"tanh.mlir",
"transpose.mlir",
"while.mlir",
],
include = ["*.mlir"],
exclude = [
"logical_right_shift_16.mlir", # TODO(#11828)
],
)

iree_check_single_backend_test_suite(
name = "check_vulkan-spirv_vulkan",
srcs = VULKAN_SRCS,
srcs = ALL_SRCS,
driver = "vulkan",
input_type = "tosa",
target_backend = "vulkan-spirv",
)

CUDA_SRCS = enforce_glob(
iree_check_single_backend_test_suite(
name = "check_metal-spirv_metal",
srcs = ALL_SRCS,
driver = "metal",
input_type = "tosa",
target_backend = "metal-spirv",
)

ROCM_AND_CUDA_SRCS = enforce_glob(
[
"abs.mlir",
"add.mlir",
Expand Down Expand Up @@ -291,13 +148,13 @@ CUDA_SRCS = enforce_glob(
],
include = ["*.mlir"],
exclude = [
"mul_shift.mlir",
"mul_shift.mlir", # error: cannot be converted to LLVM IR: missing `LLVMTranslationDialectInterface` registration for dialect for op: tosa.apply_scale
],
)

iree_check_single_backend_test_suite(
name = "check_cuda_graph",
srcs = CUDA_SRCS,
srcs = ROCM_AND_CUDA_SRCS,
compiler_flags = [
# TODO(#13984): memset emulation required for graphs.
"--iree-stream-emulate-memset",
Expand All @@ -318,7 +175,7 @@ iree_check_single_backend_test_suite(

iree_check_single_backend_test_suite(
name = "check_cuda_stream",
srcs = CUDA_SRCS,
srcs = ROCM_AND_CUDA_SRCS,
driver = "cuda",
input_type = "tosa",
runner_args = ["--cuda_use_streams=true"],
Expand All @@ -333,13 +190,13 @@ iree_check_single_backend_test_suite(
target_backend = "cuda",
)

test_suite(
name = "check",
tests = [
":check_cuda_graph",
":check_cuda_stream",
":check_llvm-cpu_local-task",
":check_vmvx_local-task",
":check_vulkan-spirv_vulkan",
iree_check_single_backend_test_suite(
name = "check_rocm_hip_stream",
srcs = ROCM_AND_CUDA_SRCS,
driver = "hip",
input_type = "tosa",
runner_args = [
"--hip_use_streams=true",
],
target_backend = "rocm",
)
Loading

0 comments on commit 6145b65

Please sign in to comment.