From dc0bdbffdcc2cb428969f7cb2a59fc1e956fb8a4 Mon Sep 17 00:00:00 2001 From: Angel Zhang Date: Fri, 21 Jun 2024 12:31:16 -0400 Subject: [PATCH] [mlir][spirv] Add a generic `convert-to-spirv` pass (#95942) This PR implements a MVP version of an MLIR lowering pipeline to SPIR-V. The goal of adding this pipeline is to have a better test coverage of SPIR-V compilation upstream, and enable writing simple kernels by hand. The dialects supported in this version include `arith`, `vector` (only 1-D vectors with size 2,3,4,8 or 16), `scf`, `ub`, `index`, `func` and `math`. New test cases for the pass are also included in this PR. **Relevant links** - [Open MLIR Meeting - YouTube Video](https://www.youtube.com/watch?v=csWPOQfgLMo) - [Discussion on LLVM Forum](https://discourse.llvm.org/t/open-mlir-meeting-12-14-2023-discussion-on-improving-handling-of-unit-dimensions-in-the-vector-dialect/75683) **Future plans** - Add conversion patterns for other dialects, e.g. `gpu`, `tensor`, etc. - Include vector transformation to unroll vectors to 1-D, and handle those with unsupported sizes. - Implement multiple-return. SPIR-V does not support multiple return values since a `spirv.func` can only return zero or one values. It might be possible to wrap the return values in a `spirv.struct`. - Add a conversion for `scf.parallel`. --- .../ConvertToSPIRV/ConvertToSPIRVPass.h | 22 + mlir/include/mlir/Conversion/Passes.h | 1 + mlir/include/mlir/Conversion/Passes.td | 12 + mlir/lib/Conversion/CMakeLists.txt | 1 + .../Conversion/ConvertToSPIRV/CMakeLists.txt | 32 ++ .../ConvertToSPIRV/ConvertToSPIRVPass.cpp | 71 +++ .../test/Conversion/ConvertToSPIRV/arith.mlir | 218 +++++++++ .../Conversion/ConvertToSPIRV/combined.mlir | 47 ++ .../test/Conversion/ConvertToSPIRV/index.mlir | 63 +++ mlir/test/Conversion/ConvertToSPIRV/scf.mlir | 47 ++ .../Conversion/ConvertToSPIRV/simple.mlir | 15 + mlir/test/Conversion/ConvertToSPIRV/ub.mlir | 9 + .../Conversion/ConvertToSPIRV/vector.mlir | 439 ++++++++++++++++++ 13 files changed, 977 insertions(+) create mode 100644 mlir/include/mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h create mode 100644 mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt create mode 100644 mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp create mode 100644 mlir/test/Conversion/ConvertToSPIRV/arith.mlir create mode 100644 mlir/test/Conversion/ConvertToSPIRV/combined.mlir create mode 100644 mlir/test/Conversion/ConvertToSPIRV/index.mlir create mode 100644 mlir/test/Conversion/ConvertToSPIRV/scf.mlir create mode 100644 mlir/test/Conversion/ConvertToSPIRV/simple.mlir create mode 100644 mlir/test/Conversion/ConvertToSPIRV/ub.mlir create mode 100644 mlir/test/Conversion/ConvertToSPIRV/vector.mlir diff --git a/mlir/include/mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h b/mlir/include/mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h new file mode 100644 index 00000000000000..38527822475272 --- /dev/null +++ b/mlir/include/mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h @@ -0,0 +1,22 @@ +//===- ConvertToSPIRVPass.h - Conversion to SPIR-V pass ---*- C++ -*-=========// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRVPASS_H +#define MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRVPASS_H + +#include + +namespace mlir { +class Pass; + +#define GEN_PASS_DECL_CONVERTTOSPIRVPASS +#include "mlir/Conversion/Passes.h.inc" + +} // namespace mlir + +#endif // MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRVPASS_H diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index 7700299b3a4f32..8c6f85d461aea3 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -30,6 +30,7 @@ #include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h" #include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h" #include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h" +#include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h" #include "mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 2315686839c20b..560b088dbe5cd2 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -31,6 +31,18 @@ def ConvertToLLVMPass : Pass<"convert-to-llvm"> { ]; } +//===----------------------------------------------------------------------===// +// ToSPIRV +//===----------------------------------------------------------------------===// + +def ConvertToSPIRVPass : Pass<"convert-to-spirv"> { + let summary = "Convert to SPIR-V"; + let description = [{ + This is a generic pass to convert to SPIR-V. + }]; + let dependentDialects = ["spirv::SPIRVDialect"]; +} + //===----------------------------------------------------------------------===// // AffineToStandard //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 0a03a2e133db18..e107738a4c50c0 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -19,6 +19,7 @@ add_subdirectory(ControlFlowToLLVM) add_subdirectory(ControlFlowToSCF) add_subdirectory(ControlFlowToSPIRV) add_subdirectory(ConvertToLLVM) +add_subdirectory(ConvertToSPIRV) add_subdirectory(FuncToEmitC) add_subdirectory(FuncToLLVM) add_subdirectory(FuncToSPIRV) diff --git a/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt new file mode 100644 index 00000000000000..f7b090acf33af3 --- /dev/null +++ b/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt @@ -0,0 +1,32 @@ +set(LLVM_OPTIONAL_SOURCES + ConvertToSPIRVPass.cpp +) + +add_mlir_conversion_library(MLIRConvertToSPIRVPass + ConvertToSPIRVPass.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ConvertToSPIRV + + DEPENDS + MLIRConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRArithToSPIRV + MLIRArithTransforms + MLIRFuncToSPIRV + MLIRIndexToSPIRV + MLIRIR + MLIRPass + MLIRRewrite + MLIRSCFToSPIRV + MLIRSPIRVConversion + MLIRSPIRVDialect + MLIRSPIRVTransforms + MLIRSupport + MLIRTransforms + MLIRTransformUtils + MLIRUBToSPIRV + MLIRVectorToSPIRV + MLIRVectorTransforms + ) diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp new file mode 100644 index 00000000000000..b5be4654bcb255 --- /dev/null +++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp @@ -0,0 +1,71 @@ +//===- ConvertToSPIRVPass.cpp - MLIR SPIR-V Conversion --------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h" +#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h" +#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h" +#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h" +#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h" +#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h" +#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" +#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include + +#define DEBUG_TYPE "convert-to-spirv" + +namespace mlir { +#define GEN_PASS_DEF_CONVERTTOSPIRVPASS +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +namespace { + +/// A pass to perform the SPIR-V conversion. +struct ConvertToSPIRVPass final + : impl::ConvertToSPIRVPassBase { + + void runOnOperation() override { + MLIRContext *context = &getContext(); + Operation *op = getOperation(); + + spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op); + SPIRVTypeConverter typeConverter(targetAttr); + + RewritePatternSet patterns(context); + ScfToSPIRVContext scfToSPIRVContext; + + // Populate patterns. + arith::populateCeilFloorDivExpandOpsPatterns(patterns); + arith::populateArithToSPIRVPatterns(typeConverter, patterns); + populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns); + populateFuncToSPIRVPatterns(typeConverter, patterns); + index::populateIndexToSPIRVPatterns(typeConverter, patterns); + populateVectorToSPIRVPatterns(typeConverter, patterns); + populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns); + ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns); + + std::unique_ptr target = + SPIRVConversionTarget::get(targetAttr); + + if (failed(applyPartialConversion(op, *target, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace diff --git a/mlir/test/Conversion/ConvertToSPIRV/arith.mlir b/mlir/test/Conversion/ConvertToSPIRV/arith.mlir new file mode 100644 index 00000000000000..a2adc0ad9c7a5a --- /dev/null +++ b/mlir/test/Conversion/ConvertToSPIRV/arith.mlir @@ -0,0 +1,218 @@ +// RUN: mlir-opt -convert-to-spirv -split-input-file %s | FileCheck %s + +//===----------------------------------------------------------------------===// +// arithmetic ops +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @int32_scalar +func.func @int32_scalar(%lhs: i32, %rhs: i32) { + // CHECK: spirv.IAdd %{{.*}}, %{{.*}}: i32 + %0 = arith.addi %lhs, %rhs: i32 + // CHECK: spirv.ISub %{{.*}}, %{{.*}}: i32 + %1 = arith.subi %lhs, %rhs: i32 + // CHECK: spirv.IMul %{{.*}}, %{{.*}}: i32 + %2 = arith.muli %lhs, %rhs: i32 + // CHECK: spirv.SDiv %{{.*}}, %{{.*}}: i32 + %3 = arith.divsi %lhs, %rhs: i32 + // CHECK: spirv.UDiv %{{.*}}, %{{.*}}: i32 + %4 = arith.divui %lhs, %rhs: i32 + // CHECK: spirv.UMod %{{.*}}, %{{.*}}: i32 + %5 = arith.remui %lhs, %rhs: i32 + return +} + +// CHECK-LABEL: @int32_scalar_srem +// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32) +func.func @int32_scalar_srem(%lhs: i32, %rhs: i32) { + // CHECK: %[[LABS:.+]] = spirv.GL.SAbs %[[LHS]] : i32 + // CHECK: %[[RABS:.+]] = spirv.GL.SAbs %[[RHS]] : i32 + // CHECK: %[[ABS:.+]] = spirv.UMod %[[LABS]], %[[RABS]] : i32 + // CHECK: %[[POS:.+]] = spirv.IEqual %[[LHS]], %[[LABS]] : i32 + // CHECK: %[[NEG:.+]] = spirv.SNegate %[[ABS]] : i32 + // CHECK: %{{.+}} = spirv.Select %[[POS]], %[[ABS]], %[[NEG]] : i1, i32 + %0 = arith.remsi %lhs, %rhs: i32 + return +} + +// ----- + +//===----------------------------------------------------------------------===// +// arith bit ops +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @bitwise_scalar +func.func @bitwise_scalar(%arg0 : i32, %arg1 : i32) { + // CHECK: spirv.BitwiseAnd + %0 = arith.andi %arg0, %arg1 : i32 + // CHECK: spirv.BitwiseOr + %1 = arith.ori %arg0, %arg1 : i32 + // CHECK: spirv.BitwiseXor + %2 = arith.xori %arg0, %arg1 : i32 + return +} + +// CHECK-LABEL: @bitwise_vector +func.func @bitwise_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) { + // CHECK: spirv.BitwiseAnd + %0 = arith.andi %arg0, %arg1 : vector<4xi32> + // CHECK: spirv.BitwiseOr + %1 = arith.ori %arg0, %arg1 : vector<4xi32> + // CHECK: spirv.BitwiseXor + %2 = arith.xori %arg0, %arg1 : vector<4xi32> + return +} + +// CHECK-LABEL: @logical_scalar +func.func @logical_scalar(%arg0 : i1, %arg1 : i1) { + // CHECK: spirv.LogicalAnd + %0 = arith.andi %arg0, %arg1 : i1 + // CHECK: spirv.LogicalOr + %1 = arith.ori %arg0, %arg1 : i1 + // CHECK: spirv.LogicalNotEqual + %2 = arith.xori %arg0, %arg1 : i1 + return +} + +// CHECK-LABEL: @logical_vector +func.func @logical_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) { + // CHECK: spirv.LogicalAnd + %0 = arith.andi %arg0, %arg1 : vector<4xi1> + // CHECK: spirv.LogicalOr + %1 = arith.ori %arg0, %arg1 : vector<4xi1> + // CHECK: spirv.LogicalNotEqual + %2 = arith.xori %arg0, %arg1 : vector<4xi1> + return +} + +// CHECK-LABEL: @shift_scalar +func.func @shift_scalar(%arg0 : i32, %arg1 : i32) { + // CHECK: spirv.ShiftLeftLogical + %0 = arith.shli %arg0, %arg1 : i32 + // CHECK: spirv.ShiftRightArithmetic + %1 = arith.shrsi %arg0, %arg1 : i32 + // CHECK: spirv.ShiftRightLogical + %2 = arith.shrui %arg0, %arg1 : i32 + return +} + +// CHECK-LABEL: @shift_vector +func.func @shift_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) { + // CHECK: spirv.ShiftLeftLogical + %0 = arith.shli %arg0, %arg1 : vector<4xi32> + // CHECK: spirv.ShiftRightArithmetic + %1 = arith.shrsi %arg0, %arg1 : vector<4xi32> + // CHECK: spirv.ShiftRightLogical + %2 = arith.shrui %arg0, %arg1 : vector<4xi32> + return +} + +// ----- + +//===----------------------------------------------------------------------===// +// arith.cmpf +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @cmpf +func.func @cmpf(%arg0 : f32, %arg1 : f32) { + // CHECK: spirv.FOrdEqual + %1 = arith.cmpf oeq, %arg0, %arg1 : f32 + return +} + +// CHECK-LABEL: @vec1cmpf +func.func @vec1cmpf(%arg0 : vector<1xf32>, %arg1 : vector<1xf32>) { + // CHECK: spirv.FOrdGreaterThan + %0 = arith.cmpf ogt, %arg0, %arg1 : vector<1xf32> + // CHECK: spirv.FUnordLessThan + %1 = arith.cmpf ult, %arg0, %arg1 : vector<1xf32> + return +} + +// ----- + +//===----------------------------------------------------------------------===// +// arith.cmpi +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @cmpi +func.func @cmpi(%arg0 : i32, %arg1 : i32) { + // CHECK: spirv.IEqual + %0 = arith.cmpi eq, %arg0, %arg1 : i32 + return +} + +// CHECK-LABEL: @indexcmpi +func.func @indexcmpi(%arg0 : index, %arg1 : index) { + // CHECK: spirv.IEqual + %0 = arith.cmpi eq, %arg0, %arg1 : index + return +} + +// CHECK-LABEL: @vec1cmpi +func.func @vec1cmpi(%arg0 : vector<1xi32>, %arg1 : vector<1xi32>) { + // CHECK: spirv.ULessThan + %0 = arith.cmpi ult, %arg0, %arg1 : vector<1xi32> + // CHECK: spirv.SGreaterThan + %1 = arith.cmpi sgt, %arg0, %arg1 : vector<1xi32> + return +} + +// CHECK-LABEL: @boolcmpi_equality +func.func @boolcmpi_equality(%arg0 : i1, %arg1 : i1) { + // CHECK: spirv.LogicalEqual + %0 = arith.cmpi eq, %arg0, %arg1 : i1 + // CHECK: spirv.LogicalNotEqual + %1 = arith.cmpi ne, %arg0, %arg1 : i1 + return +} + +// CHECK-LABEL: @boolcmpi_unsigned +func.func @boolcmpi_unsigned(%arg0 : i1, %arg1 : i1) { + // CHECK-COUNT-2: spirv.Select + // CHECK: spirv.UGreaterThanEqual + %0 = arith.cmpi uge, %arg0, %arg1 : i1 + // CHECK-COUNT-2: spirv.Select + // CHECK: spirv.ULessThan + %1 = arith.cmpi ult, %arg0, %arg1 : i1 + return +} + +// CHECK-LABEL: @vec1boolcmpi_equality +func.func @vec1boolcmpi_equality(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) { + // CHECK: spirv.LogicalEqual + %0 = arith.cmpi eq, %arg0, %arg1 : vector<1xi1> + // CHECK: spirv.LogicalNotEqual + %1 = arith.cmpi ne, %arg0, %arg1 : vector<1xi1> + return +} + +// CHECK-LABEL: @vec1boolcmpi_unsigned +func.func @vec1boolcmpi_unsigned(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) { + // CHECK-COUNT-2: spirv.Select + // CHECK: spirv.UGreaterThanEqual + %0 = arith.cmpi uge, %arg0, %arg1 : vector<1xi1> + // CHECK-COUNT-2: spirv.Select + // CHECK: spirv.ULessThan + %1 = arith.cmpi ult, %arg0, %arg1 : vector<1xi1> + return +} + +// CHECK-LABEL: @vecboolcmpi_equality +func.func @vecboolcmpi_equality(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) { + // CHECK: spirv.LogicalEqual + %0 = arith.cmpi eq, %arg0, %arg1 : vector<4xi1> + // CHECK: spirv.LogicalNotEqual + %1 = arith.cmpi ne, %arg0, %arg1 : vector<4xi1> + return +} + +// CHECK-LABEL: @vecboolcmpi_unsigned +func.func @vecboolcmpi_unsigned(%arg0 : vector<3xi1>, %arg1 : vector<3xi1>) { + // CHECK-COUNT-2: spirv.Select + // CHECK: spirv.UGreaterThanEqual + %0 = arith.cmpi uge, %arg0, %arg1 : vector<3xi1> + // CHECK-COUNT-2: spirv.Select + // CHECK: spirv.ULessThan + %1 = arith.cmpi ult, %arg0, %arg1 : vector<3xi1> + return +} diff --git a/mlir/test/Conversion/ConvertToSPIRV/combined.mlir b/mlir/test/Conversion/ConvertToSPIRV/combined.mlir new file mode 100644 index 00000000000000..9e908465cb142f --- /dev/null +++ b/mlir/test/Conversion/ConvertToSPIRV/combined.mlir @@ -0,0 +1,47 @@ +// RUN: mlir-opt -convert-to-spirv %s | FileCheck %s + +// CHECK-LABEL: @combined +// CHECK: %[[C0_F32:.*]] = spirv.Constant 0.000000e+00 : f32 +// CHECK: %[[C1_F32:.*]] = spirv.Constant 1.000000e+00 : f32 +// CHECK: %[[C0_I32:.*]] = spirv.Constant 0 : i32 +// CHECK: %[[C4_I32:.*]] = spirv.Constant 4 : i32 +// CHECK: %[[C0_I32_0:.*]] = spirv.Constant 0 : i32 +// CHECK: %[[C4_I32_0:.*]] = spirv.Constant 4 : i32 +// CHECK: %[[C1_I32:.*]] = spirv.Constant 1 : i32 +// CHECK: %[[VEC:.*]] = spirv.Constant dense<1.000000e+00> : vector<4xf32> +// CHECK: %[[VARIABLE:.*]] = spirv.Variable : !spirv.ptr +// CHECK: spirv.mlir.loop { +// CHECK: spirv.Branch ^[[HEADER:.*]](%[[C0_I32_0]], %[[C0_F32]] : i32, f32) +// CHECK: ^[[HEADER]](%[[INDVAR_0:.*]]: i32, %[[INDVAR_1:.*]]: f32): +// CHECK: %[[SLESSTHAN:.*]] = spirv.SLessThan %[[INDVAR_0]], %[[C4_I32_0]] : i32 +// CHECK: spirv.BranchConditional %[[SLESSTHAN]], ^[[BODY:.*]], ^[[MERGE:.*]] +// CHECK: ^[[BODY]]: +// CHECK: %[[FADD:.*]] = spirv.FAdd %[[INDVAR_1]], %[[C1_F32]] : f32 +// CHECK: %[[INSERT:.*]] = spirv.CompositeInsert %[[FADD]], %[[VEC]][0 : i32] : f32 into vector<4xf32> +// CHECK: spirv.Store "Function" %[[VARIABLE]], %[[FADD]] : f32 +// CHECK: %[[IADD:.*]] = spirv.IAdd %[[INDVAR_0]], %[[C1_I32]] : i32 +// CHECK: spirv.Branch ^[[HEADER]](%[[IADD]], %[[FADD]] : i32, f32) +// CHECK: ^[[MERGE]]: +// CHECK: spirv.mlir.merge +// CHECK: } +// CHECK: %[[LOAD:.*]] = spirv.Load "Function" %[[VARIABLE]] : f32 +// CHECK: %[[UNDEF:.*]] = spirv.Undef : f32 +// CHECK: spirv.ReturnValue %[[UNDEF]] : f32 +func.func @combined() -> f32 { + %c0_f32 = arith.constant 0.0 : f32 + %c1_f32 = arith.constant 1.0 : f32 + %c0_i32 = arith.constant 0 : i32 + %c4_i32 = arith.constant 4 : i32 + %lb = index.casts %c0_i32 : i32 to index + %ub = index.casts %c4_i32 : i32 to index + %step = arith.constant 1 : index + %buf = vector.broadcast %c1_f32 : f32 to vector<4xf32> + scf.for %iv = %lb to %ub step %step iter_args(%sum_iter = %c0_f32) -> f32 { + %t = vector.extract %buf[0] : f32 from vector<4xf32> + %sum_next = arith.addf %sum_iter, %t : f32 + vector.insert %sum_next, %buf[0] : f32 into vector<4xf32> + scf.yield %sum_next : f32 + } + %ret = ub.poison : f32 + return %ret : f32 +} diff --git a/mlir/test/Conversion/ConvertToSPIRV/index.mlir b/mlir/test/Conversion/ConvertToSPIRV/index.mlir new file mode 100644 index 00000000000000..db747625bc7b39 --- /dev/null +++ b/mlir/test/Conversion/ConvertToSPIRV/index.mlir @@ -0,0 +1,63 @@ +// RUN: mlir-opt %s -convert-to-spirv | FileCheck %s + +// CHECK-LABEL: @basic +func.func @basic(%a: index, %b: index) { + // CHECK: spirv.IAdd + %0 = index.add %a, %b + // CHECK: spirv.ISub + %1 = index.sub %a, %b + // CHECK: spirv.IMul + %2 = index.mul %a, %b + // CHECK: spirv.SDiv + %3 = index.divs %a, %b + // CHECK: spirv.UDiv + %4 = index.divu %a, %b + // CHECK: spirv.SRem + %5 = index.rems %a, %b + // CHECK: spirv.UMod + %6 = index.remu %a, %b + // CHECK: spirv.GL.SMax + %7 = index.maxs %a, %b + // CHECK: spirv.GL.UMax + %8 = index.maxu %a, %b + // CHECK: spirv.GL.SMin + %9 = index.mins %a, %b + // CHECK: spirv.GL.UMin + %10 = index.minu %a, %b + // CHECK: spirv.ShiftLeftLogical + %11 = index.shl %a, %b + // CHECK: spirv.ShiftRightArithmetic + %12 = index.shrs %a, %b + // CHECK: spirv.ShiftRightLogical + %13 = index.shru %a, %b + // CHECK: spirv.BitwiseAnd + %14 = index.and %a, %b + // CHECK: spirv.BitwiseOr + %15 = index.or %a, %b + // CHECK: spirv.BitwiseXor + %16 = index.xor %a, %b + return +} + +// CHECK-LABEL: @cmp +func.func @cmp(%a : index, %b : index) { + // CHECK: spirv.IEqual + %0 = index.cmp eq(%a, %b) + return +} + +// CHECK-LABEL: @ceildivs +// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32 +// CHECK: spirv.ReturnValue %{{.*}} : i32 +func.func @ceildivs(%n: index, %m: index) -> index { + %result = index.ceildivs %n, %m + return %result : index +} + +// CHECK-LABEL: @ceildivu +// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32 +// CHECK: spirv.ReturnValue %{{.*}} : i32 +func.func @ceildivu(%n: index, %m: index) -> index { + %result = index.ceildivu %n, %m + return %result : index +} diff --git a/mlir/test/Conversion/ConvertToSPIRV/scf.mlir b/mlir/test/Conversion/ConvertToSPIRV/scf.mlir new file mode 100644 index 00000000000000..f619ca5771824b --- /dev/null +++ b/mlir/test/Conversion/ConvertToSPIRV/scf.mlir @@ -0,0 +1,47 @@ +// RUN: mlir-opt -convert-to-spirv %s | FileCheck %s + +// CHECK-LABEL: @if_yield +// CHECK: %[[VAR:.*]] = spirv.Variable : !spirv.ptr +// CHECK: spirv.mlir.selection { +// CHECK-NEXT: spirv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[FALSE:\^.*]] +// CHECK-NEXT: [[TRUE]]: +// CHECK: %[[C0TRUE:.*]] = spirv.Constant 0.000000e+00 : f32 +// CHECK: %[[RETTRUE:.*]] = spirv.Constant 0.000000e+00 : f32 +// CHECK-DAG: spirv.Store "Function" %[[VAR]], %[[RETTRUE]] : f32 +// CHECK: spirv.Branch ^[[MERGE:.*]] +// CHECK-NEXT: [[FALSE]]: +// CHECK: %[[C0FALSE:.*]] = spirv.Constant 1.000000e+00 : f32 +// CHECK: %[[RETFALSE:.*]] = spirv.Constant 2.71828175 : f32 +// CHECK-DAG: spirv.Store "Function" %[[VAR]], %[[RETFALSE]] : f32 +// CHECK: spirv.Branch ^[[MERGE]] +// CHECK-NEXT: ^[[MERGE]]: +// CHECK: spirv.mlir.merge +// CHECK-NEXT: } +// CHECK-DAG: %[[OUT:.*]] = spirv.Load "Function" %[[VAR]] : f32 +// CHECK: spirv.ReturnValue %[[OUT]] : f32 +func.func @if_yield(%arg0: i1) -> f32 { + %0 = scf.if %arg0 -> f32 { + %c0 = arith.constant 0.0 : f32 + %res = math.sqrt %c0 : f32 + scf.yield %res : f32 + } else { + %c0 = arith.constant 1.0 : f32 + %res = math.exp %c0 : f32 + scf.yield %res : f32 + } + return %0 : f32 +} + +// CHECK-LABEL: @while +func.func @while(%arg0: i32, %arg1: i32) -> i32 { + %c2_i32 = arith.constant 2 : i32 + %0 = scf.while (%arg3 = %arg0) : (i32) -> (i32) { + %1 = arith.cmpi slt, %arg3, %arg1 : i32 + scf.condition(%1) %arg3 : i32 + } do { + ^bb0(%arg5: i32): + %1 = arith.muli %arg5, %c2_i32 : i32 + scf.yield %1 : i32 + } + return %0 : i32 +} diff --git a/mlir/test/Conversion/ConvertToSPIRV/simple.mlir b/mlir/test/Conversion/ConvertToSPIRV/simple.mlir new file mode 100644 index 00000000000000..20b2a42bc3975c --- /dev/null +++ b/mlir/test/Conversion/ConvertToSPIRV/simple.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-opt -convert-to-spirv %s | FileCheck %s + +// CHECK-LABEL: @return_scalar +// CHECK-SAME: %[[ARG0:.*]]: i32 +// CHECK: spirv.ReturnValue %[[ARG0]] +func.func @return_scalar(%arg0 : i32) -> i32 { + return %arg0 : i32 +} + +// CHECK-LABEL: @return_vector +// CHECK-SAME: %[[ARG0:.*]]: vector<4xi32> +// CHECK: spirv.ReturnValue %[[ARG0]] +func.func @return_vector(%arg0 : vector<4xi32>) -> vector<4xi32> { + return %arg0 : vector<4xi32> +} diff --git a/mlir/test/Conversion/ConvertToSPIRV/ub.mlir b/mlir/test/Conversion/ConvertToSPIRV/ub.mlir new file mode 100644 index 00000000000000..66528b68f58cf3 --- /dev/null +++ b/mlir/test/Conversion/ConvertToSPIRV/ub.mlir @@ -0,0 +1,9 @@ +// RUN: mlir-opt -convert-to-spirv %s | FileCheck %s + +// CHECK-LABEL: @ub +// CHECK: %[[UNDEF:.*]] = spirv.Undef : i32 +// CHECK: spirv.ReturnValue %[[UNDEF]] : i32 +func.func @ub() -> index { + %0 = ub.poison : index + return %0 : index +} diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir new file mode 100644 index 00000000000000..336f0fe10c27ef --- /dev/null +++ b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir @@ -0,0 +1,439 @@ +// RUN: mlir-opt -split-input-file -convert-to-spirv %s | FileCheck %s + +// CHECK-LABEL: @extract +// CHECK-SAME: %[[ARG:.+]]: vector<2xf32> +// CHECK: spirv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32> +// CHECK: spirv.CompositeExtract %[[ARG]][1 : i32] : vector<2xf32> +func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) { + %0 = "vector.extract"(%arg0) <{static_position = array}> : (vector<2xf32>) -> vector<1xf32> + %1 = "vector.extract"(%arg0) <{static_position = array}> : (vector<2xf32>) -> f32 + return %0, %1: vector<1xf32>, f32 +} + +// ----- + +// CHECK-LABEL: @extract_size1_vector +// CHECK-SAME: %[[ARG0:.+]]: f32 +// CHECK: spirv.ReturnValue %[[ARG0]] : f32 +func.func @extract_size1_vector(%arg0 : vector<1xf32>) -> f32 { + %0 = vector.extract %arg0[0] : f32 from vector<1xf32> + return %0: f32 +} + +// ----- + +// CHECK-LABEL: @insert +// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[S:.*]]: f32 +// CHECK: spirv.CompositeInsert %[[S]], %[[V]][2 : i32] : f32 into vector<4xf32> +func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> { + %1 = vector.insert %arg1, %arg0[2] : f32 into vector<4xf32> + return %1: vector<4xf32> +} + +// ----- + +// CHECK-LABEL: @insert_index_vector +// CHECK: spirv.CompositeInsert %{{.+}}, %{{.+}}[2 : i32] : i32 into vector<4xi32> +func.func @insert_index_vector(%arg0 : vector<4xindex>, %arg1: index) -> vector<4xindex> { + %1 = vector.insert %arg1, %arg0[2] : index into vector<4xindex> + return %1: vector<4xindex> +} + +// ----- + +// CHECK-LABEL: @insert_size1_vector +// CHECK-SAME: %[[V:.*]]: f32, %[[S:.*]]: f32 +// CHECK: spirv.ReturnValue %[[S]] : f32 +func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf32> { + %1 = vector.insert %arg1, %arg0[0] : f32 into vector<1xf32> + return %1 : vector<1xf32> +} + +// ----- + +// CHECK-LABEL: @extract_element +// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32 +// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32 +func.func @extract_element(%arg0 : vector<4xf32>, %id : i32) -> f32 { + %0 = vector.extractelement %arg0[%id : i32] : vector<4xf32> + return %0: f32 +} + +// ----- + +// CHECK-LABEL: @extract_element_cst +// CHECK-SAME: %[[V:.*]]: vector<4xf32> +// CHECK: spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32> +func.func @extract_element_cst(%arg0 : vector<4xf32>) -> f32 { + %idx = arith.constant 1 : i32 + %0 = vector.extractelement %arg0[%idx : i32] : vector<4xf32> + return %0: f32 +} + +// ----- + +// CHECK-LABEL: @extract_element_index +func.func @extract_element_index(%arg0 : vector<4xf32>, %id : index) -> f32 { + // CHECK: spirv.VectorExtractDynamic + %0 = vector.extractelement %arg0[%id : index] : vector<4xf32> + return %0: f32 +} + +// ----- + +// CHECK-LABEL: @extract_element_size1_vector +// CHECK-SAME:(%[[S:.+]]: f32, +func.func @extract_element_size1_vector(%arg0 : f32, %i: index) -> f32 { + %bcast = vector.broadcast %arg0 : f32 to vector<1xf32> + %0 = vector.extractelement %bcast[%i : index] : vector<1xf32> + // CHECK: spirv.ReturnValue %[[S]] + return %0: f32 +} + +// ----- + +// CHECK-LABEL: @extract_element_0d_vector +// CHECK-SAME: (%[[S:.+]]: f32) +func.func @extract_element_0d_vector(%arg0 : f32) -> f32 { + %bcast = vector.broadcast %arg0 : f32 to vector + %0 = vector.extractelement %bcast[] : vector + // CHECK: spirv.ReturnValue %[[S]] + return %0: f32 +} + +// ----- + +// CHECK-LABEL: @insert_element +// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32 +// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32 +func.func @insert_element(%val: f32, %arg0 : vector<4xf32>, %id : i32) -> vector<4xf32> { + %0 = vector.insertelement %val, %arg0[%id : i32] : vector<4xf32> + return %0: vector<4xf32> +} + +// ----- + +// CHECK-LABEL: @insert_element_cst +// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32> +// CHECK: spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32> +func.func @insert_element_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> { + %idx = arith.constant 2 : i32 + %0 = vector.insertelement %val, %arg0[%idx : i32] : vector<4xf32> + return %0: vector<4xf32> +} + +// ----- + +// CHECK-LABEL: @insert_element_index +func.func @insert_element_index(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> { + // CHECK: spirv.VectorInsertDynamic + %0 = vector.insertelement %val, %arg0[%id : index] : vector<4xf32> + return %0: vector<4xf32> +} + +// ----- + +// CHECK-LABEL: @insert_element_size1_vector +// CHECK-SAME: (%[[S:[a-z0-9]+]]: f32 +func.func @insert_element_size1_vector(%scalar: f32, %vector : vector<1xf32>, %i: index) -> vector<1xf32> { + %0 = vector.insertelement %scalar, %vector[%i : index] : vector<1xf32> + // CHECK: spirv.ReturnValue %[[S]] + return %0: vector<1xf32> +} + +// ----- + +// CHECK-LABEL: @insert_element_0d_vector +// CHECK-SAME: (%[[S:[a-z0-9]+]]: f32 +func.func @insert_element_0d_vector(%scalar: f32, %vector : vector) -> vector { + %0 = vector.insertelement %scalar, %vector[] : vector + // CHECK: spirv.ReturnValue %[[S]] + return %0: vector +} + +// ----- + +// CHECK-LABEL: @insert_size1_vector +// CHECK-SAME: %[[SUB:.*]]: f32, %[[FULL:.*]]: vector<3xf32> +// CHECK: %[[RET:.*]] = spirv.CompositeInsert %[[SUB]], %[[FULL]][2 : i32] : f32 into vector<3xf32> +// CHECK: spirv.ReturnValue %[[RET]] : vector<3xf32> +func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: vector<3xf32>) -> vector<3xf32> { + %1 = vector.insert_strided_slice %arg0, %arg1 {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32> + return %1 : vector<3xf32> +} + +// ----- + +// CHECK-LABEL: @fma +// CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32> +// CHECK: spirv.GL.Fma %[[A]], %[[B]], %[[C]] : vector<4xf32> +func.func @fma(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>) -> vector<4xf32> { + %0 = vector.fma %a, %b, %c: vector<4xf32> + return %0 : vector<4xf32> +} + +// ----- + +// CHECK-LABEL: @fma_size1_vector +// CHECK: spirv.GL.Fma %{{.+}} : f32 +func.func @fma_size1_vector(%a: vector<1xf32>, %b: vector<1xf32>, %c: vector<1xf32>) -> vector<1xf32> { + %0 = vector.fma %a, %b, %c: vector<1xf32> + return %0 : vector<1xf32> +} + +// ----- + +// CHECK-LABEL: func @splat +// CHECK-SAME: (%[[A:.+]]: f32) +// CHECK: %[[VAL:.+]] = spirv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]] +// CHECK: spirv.ReturnValue %[[VAL]] : vector<4xf32> +func.func @splat(%f : f32) -> vector<4xf32> { + %splat = vector.splat %f : vector<4xf32> + return %splat : vector<4xf32> +} + +// ----- + +// CHECK-LABEL: func @splat_size1_vector +// CHECK-SAME: (%[[A:.+]]: f32) +// CHECK: spirv.ReturnValue %[[A]] : f32 +func.func @splat_size1_vector(%f : f32) -> vector<1xf32> { + %splat = vector.splat %f : vector<1xf32> + return %splat : vector<1xf32> +} + +// ----- + +// CHECK-LABEL: func @shuffle +// CHECK-SAME: %[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32 +// CHECK: spirv.CompositeConstruct %[[ARG0]], %[[ARG1]], %[[ARG1]], %[[ARG0]] : (f32, f32, f32, f32) -> vector<4xf32> +func.func @shuffle(%v0 : vector<1xf32>, %v1: vector<1xf32>) -> vector<4xf32> { + %shuffle = vector.shuffle %v0, %v1 [0, 1, 1, 0] : vector<1xf32>, vector<1xf32> + return %shuffle : vector<4xf32> +} + +// ----- + +// CHECK-LABEL: func @shuffle +// CHECK-SAME: %[[V0:.+]]: vector<3xf32>, %[[V1:.+]]: vector<3xf32> +// CHECK: spirv.VectorShuffle [3 : i32, 2 : i32, 5 : i32, 1 : i32] %[[V0]], %[[V1]] : vector<3xf32>, vector<3xf32> -> vector<4xf32> +func.func @shuffle(%v0 : vector<3xf32>, %v1: vector<3xf32>) -> vector<4xf32> { + %shuffle = vector.shuffle %v0, %v1 [3, 2, 5, 1] : vector<3xf32>, vector<3xf32> + return %shuffle : vector<4xf32> +} + +// ----- + +// CHECK-LABEL: func @shuffle +// CHECK-SAME: %[[ARG0:.+]]: i32, %[[ARG1:.+]]: vector<3xi32> +// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[ARG1]][1 : i32] : vector<3xi32> +// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[ARG1]][2 : i32] : vector<3xi32> +// CHECK: %[[RES:.+]] = spirv.CompositeConstruct %[[ARG0]], %[[S1]], %[[S2]] : (i32, i32, i32) -> vector<3xi32> +// CHECK: spirv.ReturnValue %[[RES]] +func.func @shuffle(%v0 : vector<1xi32>, %v1: vector<3xi32>) -> vector<3xi32> { + %shuffle = vector.shuffle %v0, %v1 [0, 2, 3] : vector<1xi32>, vector<3xi32> + return %shuffle : vector<3xi32> +} + +// ----- + +// CHECK-LABEL: func @shuffle +// CHECK-SAME: %[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32 +// CHECK: %[[RES:.+]] = spirv.CompositeConstruct %[[ARG0]], %[[ARG1]] : (i32, i32) -> vector<2xi32> +// CHECK: spirv.ReturnValue %[[RES]] +func.func @shuffle(%v0 : vector<1xi32>, %v1: vector<1xi32>) -> vector<2xi32> { + %shuffle = vector.shuffle %v0, %v1 [0, 1] : vector<1xi32>, vector<1xi32> + return %shuffle : vector<2xi32> +} + +// ----- + +// CHECK-LABEL: func @interleave +// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: vector<2xf32>) +// CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[ARG0]], %[[ARG1]] : vector<2xf32>, vector<2xf32> -> vector<4xf32> +// CHECK: spirv.ReturnValue %[[SHUFFLE]] +func.func @interleave(%a: vector<2xf32>, %b: vector<2xf32>) -> vector<4xf32> { + %0 = vector.interleave %a, %b : vector<2xf32> -> vector<4xf32> + return %0 : vector<4xf32> +} + +// ----- + +// CHECK-LABEL: func @interleave_size1 +// CHECK-SAME: (%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32) +// CHECK: %[[RES:.*]] = spirv.CompositeConstruct %[[ARG0]], %[[ARG1]] : (f32, f32) -> vector<2xf32> +// CHECK: spirv.ReturnValue %[[RES]] +func.func @interleave_size1(%a: vector<1xf32>, %b: vector<1xf32>) -> vector<2xf32> { + %0 = vector.interleave %a, %b : vector<1xf32> -> vector<2xf32> + return %0 : vector<2xf32> +} + +// ----- + +// CHECK-LABEL: func @reduction_add +// CHECK-SAME: (%[[V:.+]]: vector<4xi32>) +// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<4xi32> +// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<4xi32> +// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<4xi32> +// CHECK: %[[S3:.+]] = spirv.CompositeExtract %[[V]][3 : i32] : vector<4xi32> +// CHECK: %[[ADD0:.+]] = spirv.IAdd %[[S0]], %[[S1]] +// CHECK: %[[ADD1:.+]] = spirv.IAdd %[[ADD0]], %[[S2]] +// CHECK: %[[ADD2:.+]] = spirv.IAdd %[[ADD1]], %[[S3]] +// CHECK: spirv.ReturnValue %[[ADD2]] +func.func @reduction_add(%v : vector<4xi32>) -> i32 { + %reduce = vector.reduction , %v : vector<4xi32> into i32 + return %reduce : i32 +} + +// ----- + +// CHECK-LABEL: func @reduction_addf_one_elem +// CHECK-SAME: (%[[ARG0:.+]]: f32) +// CHECK: spirv.ReturnValue %[[ARG0]] : f32 +func.func @reduction_addf_one_elem(%arg0: vector<1xf32>) -> f32 { + %red = vector.reduction , %arg0 : vector<1xf32> into f32 + return %red : f32 +} + +// ----- + +// CHECK-LABEL: func @reduction_addf_one_elem_acc +// CHECK-SAME: (%[[ARG0:.+]]: f32, %[[ACC:.+]]: f32) +// CHECK: %[[RES:.+]] = spirv.FAdd %[[ACC]], %[[ARG0]] : f32 +// CHECK: spirv.ReturnValue %[[RES]] : f32 +func.func @reduction_addf_one_elem_acc(%arg0: vector<1xf32>, %acc: f32) -> f32 { + %red = vector.reduction , %arg0, %acc : vector<1xf32> into f32 + return %red : f32 +} + +// ----- + +// CHECK-LABEL: func @reduction_mul +// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32) +// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32> +// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32> +// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32> +// CHECK: %[[MUL0:.+]] = spirv.FMul %[[S0]], %[[S1]] +// CHECK: %[[MUL1:.+]] = spirv.FMul %[[MUL0]], %[[S2]] +// CHECK: %[[MUL2:.+]] = spirv.FMul %[[MUL1]], %[[S]] +// CHECK: spirv.ReturnValue %[[MUL2]] +func.func @reduction_mul(%v : vector<3xf32>, %s: f32) -> f32 { + %reduce = vector.reduction , %v, %s : vector<3xf32> into f32 + return %reduce : f32 +} + +// ----- + +// CHECK-LABEL: func @reduction_maximumf +// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32) +// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32> +// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32> +// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32> +// CHECK: %[[MAX0:.+]] = spirv.GL.FMax %[[S0]], %[[S1]] +// CHECK: %[[MAX1:.+]] = spirv.GL.FMax %[[MAX0]], %[[S2]] +// CHECK: %[[MAX2:.+]] = spirv.GL.FMax %[[MAX1]], %[[S]] +// CHECK: spirv.ReturnValue %[[MAX2]] +func.func @reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 { + %reduce = vector.reduction , %v, %s : vector<3xf32> into f32 + return %reduce : f32 +} + +// ----- + +// CHECK-LABEL: func @reduction_minimumf +// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32) +// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32> +// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32> +// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32> +// CHECK: %[[MIN0:.+]] = spirv.GL.FMin %[[S0]], %[[S1]] +// CHECK: %[[MIN1:.+]] = spirv.GL.FMin %[[MIN0]], %[[S2]] +// CHECK: %[[MIN2:.+]] = spirv.GL.FMin %[[MIN1]], %[[S]] +// CHECK: spirv.ReturnValue %[[MIN2]] +func.func @reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 { + %reduce = vector.reduction , %v, %s : vector<3xf32> into f32 + return %reduce : f32 +} + +// ----- + +// CHECK-LABEL: func @reduction_maxsi +// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32) +// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xi32> +// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xi32> +// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xi32> +// CHECK: %[[MAX0:.+]] = spirv.GL.SMax %[[S0]], %[[S1]] +// CHECK: %[[MAX1:.+]] = spirv.GL.SMax %[[MAX0]], %[[S2]] +// CHECK: %[[MAX2:.+]] = spirv.GL.SMax %[[MAX1]], %[[S]] +// CHECK: spirv.ReturnValue %[[MAX2]] +func.func @reduction_maxsi(%v : vector<3xi32>, %s: i32) -> i32 { + %reduce = vector.reduction , %v, %s : vector<3xi32> into i32 + return %reduce : i32 +} + +// ----- + +// CHECK-LABEL: func @reduction_minsi +// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32) +// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xi32> +// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xi32> +// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xi32> +// CHECK: %[[MIN0:.+]] = spirv.GL.SMin %[[S0]], %[[S1]] +// CHECK: %[[MIN1:.+]] = spirv.GL.SMin %[[MIN0]], %[[S2]] +// CHECK: %[[MIN2:.+]] = spirv.GL.SMin %[[MIN1]], %[[S]] +// CHECK: spirv.ReturnValue %[[MIN2]] +func.func @reduction_minsi(%v : vector<3xi32>, %s: i32) -> i32 { + %reduce = vector.reduction , %v, %s : vector<3xi32> into i32 + return %reduce : i32 +} + +// ----- + +// CHECK-LABEL: func @reduction_maxui +// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32) +// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xi32> +// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xi32> +// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xi32> +// CHECK: %[[MAX0:.+]] = spirv.GL.UMax %[[S0]], %[[S1]] +// CHECK: %[[MAX1:.+]] = spirv.GL.UMax %[[MAX0]], %[[S2]] +// CHECK: %[[MAX2:.+]] = spirv.GL.UMax %[[MAX1]], %[[S]] +// CHECK: spirv.ReturnValue %[[MAX2]] +func.func @reduction_maxui(%v : vector<3xi32>, %s: i32) -> i32 { + %reduce = vector.reduction , %v, %s : vector<3xi32> into i32 + return %reduce : i32 +} + +// ----- + +// CHECK-LABEL: func @reduction_minui +// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32) +// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xi32> +// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xi32> +// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xi32> +// CHECK: %[[MIN0:.+]] = spirv.GL.UMin %[[S0]], %[[S1]] +// CHECK: %[[MIN1:.+]] = spirv.GL.UMin %[[MIN0]], %[[S2]] +// CHECK: %[[MIN2:.+]] = spirv.GL.UMin %[[MIN1]], %[[S]] +// CHECK: spirv.ReturnValue %[[MIN2]] +func.func @reduction_minui(%v : vector<3xi32>, %s: i32) -> i32 { + %reduce = vector.reduction , %v, %s : vector<3xi32> into i32 + return %reduce : i32 +} + +// ----- + +// CHECK-LABEL: @shape_cast_same_type +// CHECK-SAME: (%[[ARG0:.*]]: vector<2xf32>) +// CHECK: spirv.ReturnValue %[[ARG0]] +func.func @shape_cast_same_type(%arg0 : vector<2xf32>) -> vector<2xf32> { + %1 = vector.shape_cast %arg0 : vector<2xf32> to vector<2xf32> + return %arg0 : vector<2xf32> +} + +// ----- + +// CHECK-LABEL: @shape_cast_size1_vector +// CHECK-SAME: (%[[ARG0:.*]]: f32) +// CHECK: spirv.ReturnValue %[[ARG0]] +func.func @shape_cast_size1_vector(%arg0 : vector) -> vector<1xf32> { + %1 = vector.shape_cast %arg0 : vector to vector<1xf32> + return %1 : vector<1xf32> +}