Skip to content

Commit

Permalink
[mlir][spirv] Add a generic convert-to-spirv pass (llvm#95942)
Browse files Browse the repository at this point in the history
This PR implements a MVP version of an MLIR lowering pipeline to SPIR-V.
The goal of adding this pipeline is to have a better test coverage of
SPIR-V compilation upstream, and enable writing simple kernels by hand.
The dialects supported in this version include `arith`, `vector` (only
1-D vectors with size 2,3,4,8 or 16), `scf`, `ub`, `index`, `func` and
`math`. New test cases for the pass are also included in this PR.

**Relevant links**

- [Open MLIR Meeting - YouTube
Video](https://www.youtube.com/watch?v=csWPOQfgLMo)
- [Discussion on LLVM
Forum](https://discourse.llvm.org/t/open-mlir-meeting-12-14-2023-discussion-on-improving-handling-of-unit-dimensions-in-the-vector-dialect/75683)

**Future plans**

- Add conversion patterns for other dialects, e.g. `gpu`, `tensor`, etc.
- Include vector transformation to unroll vectors to 1-D, and handle
those with unsupported sizes.
- Implement multiple-return. SPIR-V does not support multiple return
values since a `spirv.func` can only return zero or one values. It might
be possible to wrap the return values in a `spirv.struct`.
- Add a conversion for `scf.parallel`.
  • Loading branch information
angelz913 committed Jun 21, 2024
1 parent 39048b6 commit dc0bdbf
Show file tree
Hide file tree
Showing 13 changed files with 977 additions and 0 deletions.
22 changes: 22 additions & 0 deletions mlir/include/mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//===- ConvertToSPIRVPass.h - Conversion to SPIR-V pass ---*- C++ -*-=========//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRVPASS_H
#define MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRVPASS_H

#include <memory>

namespace mlir {
class Pass;

#define GEN_PASS_DECL_CONVERTTOSPIRVPASS
#include "mlir/Conversion/Passes.h.inc"

} // namespace mlir

#endif // MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRVPASS_H
1 change: 1 addition & 0 deletions mlir/include/mlir/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
#include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h"
#include "mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h"
Expand Down
12 changes: 12 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@ def ConvertToLLVMPass : Pass<"convert-to-llvm"> {
];
}

//===----------------------------------------------------------------------===//
// ToSPIRV
//===----------------------------------------------------------------------===//

def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
let summary = "Convert to SPIR-V";
let description = [{
This is a generic pass to convert to SPIR-V.
}];
let dependentDialects = ["spirv::SPIRVDialect"];
}

//===----------------------------------------------------------------------===//
// AffineToStandard
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ add_subdirectory(ControlFlowToLLVM)
add_subdirectory(ControlFlowToSCF)
add_subdirectory(ControlFlowToSPIRV)
add_subdirectory(ConvertToLLVM)
add_subdirectory(ConvertToSPIRV)
add_subdirectory(FuncToEmitC)
add_subdirectory(FuncToLLVM)
add_subdirectory(FuncToSPIRV)
Expand Down
32 changes: 32 additions & 0 deletions mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
set(LLVM_OPTIONAL_SOURCES
ConvertToSPIRVPass.cpp
)

add_mlir_conversion_library(MLIRConvertToSPIRVPass
ConvertToSPIRVPass.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ConvertToSPIRV

DEPENDS
MLIRConversionPassIncGen

LINK_LIBS PUBLIC
MLIRArithToSPIRV
MLIRArithTransforms
MLIRFuncToSPIRV
MLIRIndexToSPIRV
MLIRIR
MLIRPass
MLIRRewrite
MLIRSCFToSPIRV
MLIRSPIRVConversion
MLIRSPIRVDialect
MLIRSPIRVTransforms
MLIRSupport
MLIRTransforms
MLIRTransformUtils
MLIRUBToSPIRV
MLIRVectorToSPIRV
MLIRVectorTransforms
)
71 changes: 71 additions & 0 deletions mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
//===- ConvertToSPIRVPass.cpp - MLIR SPIR-V Conversion --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h"
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <memory>

#define DEBUG_TYPE "convert-to-spirv"

namespace mlir {
#define GEN_PASS_DEF_CONVERTTOSPIRVPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

/// A pass to perform the SPIR-V conversion.
struct ConvertToSPIRVPass final
: impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {

void runOnOperation() override {
MLIRContext *context = &getContext();
Operation *op = getOperation();

spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
SPIRVTypeConverter typeConverter(targetAttr);

RewritePatternSet patterns(context);
ScfToSPIRVContext scfToSPIRVContext;

// Populate patterns.
arith::populateCeilFloorDivExpandOpsPatterns(patterns);
arith::populateArithToSPIRVPatterns(typeConverter, patterns);
populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
populateFuncToSPIRVPatterns(typeConverter, patterns);
index::populateIndexToSPIRVPatterns(typeConverter, patterns);
populateVectorToSPIRVPatterns(typeConverter, patterns);
populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);

std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);

if (failed(applyPartialConversion(op, *target, std::move(patterns))))
return signalPassFailure();
}
};

} // namespace
218 changes: 218 additions & 0 deletions mlir/test/Conversion/ConvertToSPIRV/arith.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
// RUN: mlir-opt -convert-to-spirv -split-input-file %s | FileCheck %s

//===----------------------------------------------------------------------===//
// arithmetic ops
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @int32_scalar
func.func @int32_scalar(%lhs: i32, %rhs: i32) {
// CHECK: spirv.IAdd %{{.*}}, %{{.*}}: i32
%0 = arith.addi %lhs, %rhs: i32
// CHECK: spirv.ISub %{{.*}}, %{{.*}}: i32
%1 = arith.subi %lhs, %rhs: i32
// CHECK: spirv.IMul %{{.*}}, %{{.*}}: i32
%2 = arith.muli %lhs, %rhs: i32
// CHECK: spirv.SDiv %{{.*}}, %{{.*}}: i32
%3 = arith.divsi %lhs, %rhs: i32
// CHECK: spirv.UDiv %{{.*}}, %{{.*}}: i32
%4 = arith.divui %lhs, %rhs: i32
// CHECK: spirv.UMod %{{.*}}, %{{.*}}: i32
%5 = arith.remui %lhs, %rhs: i32
return
}

// CHECK-LABEL: @int32_scalar_srem
// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
func.func @int32_scalar_srem(%lhs: i32, %rhs: i32) {
// CHECK: %[[LABS:.+]] = spirv.GL.SAbs %[[LHS]] : i32
// CHECK: %[[RABS:.+]] = spirv.GL.SAbs %[[RHS]] : i32
// CHECK: %[[ABS:.+]] = spirv.UMod %[[LABS]], %[[RABS]] : i32
// CHECK: %[[POS:.+]] = spirv.IEqual %[[LHS]], %[[LABS]] : i32
// CHECK: %[[NEG:.+]] = spirv.SNegate %[[ABS]] : i32
// CHECK: %{{.+}} = spirv.Select %[[POS]], %[[ABS]], %[[NEG]] : i1, i32
%0 = arith.remsi %lhs, %rhs: i32
return
}

// -----

//===----------------------------------------------------------------------===//
// arith bit ops
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @bitwise_scalar
func.func @bitwise_scalar(%arg0 : i32, %arg1 : i32) {
// CHECK: spirv.BitwiseAnd
%0 = arith.andi %arg0, %arg1 : i32
// CHECK: spirv.BitwiseOr
%1 = arith.ori %arg0, %arg1 : i32
// CHECK: spirv.BitwiseXor
%2 = arith.xori %arg0, %arg1 : i32
return
}

// CHECK-LABEL: @bitwise_vector
func.func @bitwise_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) {
// CHECK: spirv.BitwiseAnd
%0 = arith.andi %arg0, %arg1 : vector<4xi32>
// CHECK: spirv.BitwiseOr
%1 = arith.ori %arg0, %arg1 : vector<4xi32>
// CHECK: spirv.BitwiseXor
%2 = arith.xori %arg0, %arg1 : vector<4xi32>
return
}

// CHECK-LABEL: @logical_scalar
func.func @logical_scalar(%arg0 : i1, %arg1 : i1) {
// CHECK: spirv.LogicalAnd
%0 = arith.andi %arg0, %arg1 : i1
// CHECK: spirv.LogicalOr
%1 = arith.ori %arg0, %arg1 : i1
// CHECK: spirv.LogicalNotEqual
%2 = arith.xori %arg0, %arg1 : i1
return
}

// CHECK-LABEL: @logical_vector
func.func @logical_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
// CHECK: spirv.LogicalAnd
%0 = arith.andi %arg0, %arg1 : vector<4xi1>
// CHECK: spirv.LogicalOr
%1 = arith.ori %arg0, %arg1 : vector<4xi1>
// CHECK: spirv.LogicalNotEqual
%2 = arith.xori %arg0, %arg1 : vector<4xi1>
return
}

// CHECK-LABEL: @shift_scalar
func.func @shift_scalar(%arg0 : i32, %arg1 : i32) {
// CHECK: spirv.ShiftLeftLogical
%0 = arith.shli %arg0, %arg1 : i32
// CHECK: spirv.ShiftRightArithmetic
%1 = arith.shrsi %arg0, %arg1 : i32
// CHECK: spirv.ShiftRightLogical
%2 = arith.shrui %arg0, %arg1 : i32
return
}

// CHECK-LABEL: @shift_vector
func.func @shift_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) {
// CHECK: spirv.ShiftLeftLogical
%0 = arith.shli %arg0, %arg1 : vector<4xi32>
// CHECK: spirv.ShiftRightArithmetic
%1 = arith.shrsi %arg0, %arg1 : vector<4xi32>
// CHECK: spirv.ShiftRightLogical
%2 = arith.shrui %arg0, %arg1 : vector<4xi32>
return
}

// -----

//===----------------------------------------------------------------------===//
// arith.cmpf
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @cmpf
func.func @cmpf(%arg0 : f32, %arg1 : f32) {
// CHECK: spirv.FOrdEqual
%1 = arith.cmpf oeq, %arg0, %arg1 : f32
return
}

// CHECK-LABEL: @vec1cmpf
func.func @vec1cmpf(%arg0 : vector<1xf32>, %arg1 : vector<1xf32>) {
// CHECK: spirv.FOrdGreaterThan
%0 = arith.cmpf ogt, %arg0, %arg1 : vector<1xf32>
// CHECK: spirv.FUnordLessThan
%1 = arith.cmpf ult, %arg0, %arg1 : vector<1xf32>
return
}

// -----

//===----------------------------------------------------------------------===//
// arith.cmpi
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @cmpi
func.func @cmpi(%arg0 : i32, %arg1 : i32) {
// CHECK: spirv.IEqual
%0 = arith.cmpi eq, %arg0, %arg1 : i32
return
}

// CHECK-LABEL: @indexcmpi
func.func @indexcmpi(%arg0 : index, %arg1 : index) {
// CHECK: spirv.IEqual
%0 = arith.cmpi eq, %arg0, %arg1 : index
return
}

// CHECK-LABEL: @vec1cmpi
func.func @vec1cmpi(%arg0 : vector<1xi32>, %arg1 : vector<1xi32>) {
// CHECK: spirv.ULessThan
%0 = arith.cmpi ult, %arg0, %arg1 : vector<1xi32>
// CHECK: spirv.SGreaterThan
%1 = arith.cmpi sgt, %arg0, %arg1 : vector<1xi32>
return
}

// CHECK-LABEL: @boolcmpi_equality
func.func @boolcmpi_equality(%arg0 : i1, %arg1 : i1) {
// CHECK: spirv.LogicalEqual
%0 = arith.cmpi eq, %arg0, %arg1 : i1
// CHECK: spirv.LogicalNotEqual
%1 = arith.cmpi ne, %arg0, %arg1 : i1
return
}

// CHECK-LABEL: @boolcmpi_unsigned
func.func @boolcmpi_unsigned(%arg0 : i1, %arg1 : i1) {
// CHECK-COUNT-2: spirv.Select
// CHECK: spirv.UGreaterThanEqual
%0 = arith.cmpi uge, %arg0, %arg1 : i1
// CHECK-COUNT-2: spirv.Select
// CHECK: spirv.ULessThan
%1 = arith.cmpi ult, %arg0, %arg1 : i1
return
}

// CHECK-LABEL: @vec1boolcmpi_equality
func.func @vec1boolcmpi_equality(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) {
// CHECK: spirv.LogicalEqual
%0 = arith.cmpi eq, %arg0, %arg1 : vector<1xi1>
// CHECK: spirv.LogicalNotEqual
%1 = arith.cmpi ne, %arg0, %arg1 : vector<1xi1>
return
}

// CHECK-LABEL: @vec1boolcmpi_unsigned
func.func @vec1boolcmpi_unsigned(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) {
// CHECK-COUNT-2: spirv.Select
// CHECK: spirv.UGreaterThanEqual
%0 = arith.cmpi uge, %arg0, %arg1 : vector<1xi1>
// CHECK-COUNT-2: spirv.Select
// CHECK: spirv.ULessThan
%1 = arith.cmpi ult, %arg0, %arg1 : vector<1xi1>
return
}

// CHECK-LABEL: @vecboolcmpi_equality
func.func @vecboolcmpi_equality(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
// CHECK: spirv.LogicalEqual
%0 = arith.cmpi eq, %arg0, %arg1 : vector<4xi1>
// CHECK: spirv.LogicalNotEqual
%1 = arith.cmpi ne, %arg0, %arg1 : vector<4xi1>
return
}

// CHECK-LABEL: @vecboolcmpi_unsigned
func.func @vecboolcmpi_unsigned(%arg0 : vector<3xi1>, %arg1 : vector<3xi1>) {
// CHECK-COUNT-2: spirv.Select
// CHECK: spirv.UGreaterThanEqual
%0 = arith.cmpi uge, %arg0, %arg1 : vector<3xi1>
// CHECK-COUNT-2: spirv.Select
// CHECK: spirv.ULessThan
%1 = arith.cmpi ult, %arg0, %arg1 : vector<3xi1>
return
}
Loading

0 comments on commit dc0bdbf

Please sign in to comment.