From b8f5804651e4e3861d004cb50d6205eaf36de0c0 Mon Sep 17 00:00:00 2001 From: Mogball <jeff@modular.com> Date: Wed, 20 Sep 2023 16:30:40 -0700 Subject: [PATCH 1/2] [mlir][index] Implement folders for CastSOp and CastUOp --- .../include/mlir/Dialect/Index/IR/IndexOps.td | 6 +- mlir/lib/Dialect/Index/IR/IndexOps.cpp | 59 +++++++++++++++++++ .../Dialect/Index/index-canonicalize.mlir | 8 +++ 3 files changed, 71 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td index 61cdf4ed0877a0..c6079cb8a98c81 100644 --- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td +++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td @@ -446,7 +446,7 @@ def Index_XOrOp : IndexBinaryOp<"xor", [Commutative, Pure]> { // CastSOp //===----------------------------------------------------------------------===// -def Index_CastSOp : IndexOp<"casts", [Pure, +def Index_CastSOp : IndexOp<"casts", [Pure, DeclareOpInterfaceMethods<CastOpInterface>]> { let summary = "index signed cast"; let description = [{ @@ -469,13 +469,14 @@ def Index_CastSOp : IndexOp<"casts", [Pure, let arguments = (ins AnyTypeOf<[AnyInteger, Index]>:$input); let results = (outs AnyTypeOf<[AnyInteger, Index]>:$output); let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)"; + let hasFolder = 1; } //===----------------------------------------------------------------------===// // CastUOp //===----------------------------------------------------------------------===// -def Index_CastUOp : IndexOp<"castu", [Pure, +def Index_CastUOp : IndexOp<"castu", [Pure, DeclareOpInterfaceMethods<CastOpInterface>]> { let summary = "index unsigned cast"; let description = [{ @@ -498,6 +499,7 @@ def Index_CastUOp : IndexOp<"castu", [Pure, let arguments = (ins AnyTypeOf<[AnyInteger, Index]>:$input); let results = (outs AnyTypeOf<[AnyInteger, Index]>:$output); let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)"; + let hasFolder = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp index b6d802876c15ed..b506397742772a 100644 --- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp +++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp @@ -444,11 +444,63 @@ OpFoldResult XOrOp::fold(FoldAdaptor adaptor) { // CastSOp //===----------------------------------------------------------------------===// +static OpFoldResult +foldCastOp(Attribute input, Type type, + function_ref<APInt(const APInt &, unsigned)> extFn, + function_ref<APInt(const APInt &, unsigned)> extOrTruncFn) { + auto attr = dyn_cast_if_present<IntegerAttr>(input); + if (!attr) + return {}; + const APInt &value = attr.getValue(); + + if (isa<IndexType>(type)) { + // When casting to an index type, perform the cast assuming a 64-bit target. + // The result can be truncated to 32 bits as needed and always be correct. + // This is because `cast32(cast64(value)) == cast32(value)`. + APInt result = extOrTruncFn(value, 64); + return IntegerAttr::get(type, result); + } + + // When casting from an index type, we must ensure the results respect + // `cast_t(value) == cast_t(trunc32(value))`. + auto intType = cast<IntegerType>(type); + unsigned width = intType.getWidth(); + + // If the result type is at most 32 bits, then the cast can always be folded + // because it is always a truncation. + if (width <= 32) { + APInt result = value.trunc(width); + return IntegerAttr::get(type, result); + } + + // If the result type is at least 64 bits, then the cast is always a + // extension. The results will differ if `trunc32(value) != value)`. + if (width >= 64) { + if (extFn(value.trunc(32), 64) != value) + return {}; + APInt result = extFn(value, width); + return IntegerAttr::get(type, result); + } + + // Otherwise, we just have to check the property directly. + APInt result = value.trunc(width); + if (result != extFn(value.trunc(32), width)) + return {}; + return IntegerAttr::get(type, result); +} + bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) { return llvm::isa<IndexType>(lhsTypes.front()) != llvm::isa<IndexType>(rhsTypes.front()); } +OpFoldResult CastSOp::fold(FoldAdaptor adaptor) { + return foldCastOp( + adaptor.getInput(), getType(), + [](const APInt &x, unsigned width) { return x.sext(width); }, + [](const APInt &x, unsigned width) { return x.sextOrTrunc(width); }); +} + //===----------------------------------------------------------------------===// // CastUOp //===----------------------------------------------------------------------===// @@ -458,6 +510,13 @@ bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) { llvm::isa<IndexType>(rhsTypes.front()); } +OpFoldResult CastUOp::fold(FoldAdaptor adaptor) { + return foldCastOp( + adaptor.getInput(), getType(), + [](const APInt &x, unsigned width) { return x.zext(width); }, + [](const APInt &x, unsigned width) { return x.zextOrTrunc(width); }); +} + //===----------------------------------------------------------------------===// // CmpOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir index 67308ffbe55ac6..f3eae3605b3b64 100644 --- a/mlir/test/Dialect/Index/index-canonicalize.mlir +++ b/mlir/test/Dialect/Index/index-canonicalize.mlir @@ -556,3 +556,11 @@ func.func @sub_identity(%arg0: index) -> index { // CHECK-NEXT: return %arg0 return %0 : index } + +// CHECK-LABEL: @castu_to_index +func.func @castu_to_index() -> index { + // CHECK: index.constant 8000000000000 + %0 = arith.constant 8000000000000 : i48 + %1 = index.castu %0 : i48 to index + return %1 : index +} From 59c701ae7395e8e443b4e3a4ce0315fb15cd783e Mon Sep 17 00:00:00 2001 From: Mogball <jeff@modular.com> Date: Thu, 21 Sep 2023 10:59:35 -0700 Subject: [PATCH 2/2] add unit tests --- .../Dialect/Index/index-canonicalize.mlir | 8 ++ mlir/unittests/Dialect/CMakeLists.txt | 1 + mlir/unittests/Dialect/Index/CMakeLists.txt | 7 ++ .../Dialect/Index/IndexOpsFoldersTest.cpp | 104 ++++++++++++++++++ 4 files changed, 120 insertions(+) create mode 100644 mlir/unittests/Dialect/Index/CMakeLists.txt create mode 100644 mlir/unittests/Dialect/Index/IndexOpsFoldersTest.cpp diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir index f3eae3605b3b64..db03505350b77e 100644 --- a/mlir/test/Dialect/Index/index-canonicalize.mlir +++ b/mlir/test/Dialect/Index/index-canonicalize.mlir @@ -564,3 +564,11 @@ func.func @castu_to_index() -> index { %1 = index.castu %0 : i48 to index return %1 : index } + +// CHECK-LABEL: @casts_to_index +func.func @casts_to_index() -> index { + // CHECK: index.constant -1000 + %0 = arith.constant -1000 : i48 + %1 = index.casts %0 : i48 to index + return %1 : index +} diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt index 522aeca29146d1..2d2835c64b9844 100644 --- a/mlir/unittests/Dialect/CMakeLists.txt +++ b/mlir/unittests/Dialect/CMakeLists.txt @@ -6,6 +6,7 @@ target_link_libraries(MLIRDialectTests MLIRIR MLIRDialect) +add_subdirectory(Index) add_subdirectory(LLVMIR) add_subdirectory(MemRef) add_subdirectory(SparseTensor) diff --git a/mlir/unittests/Dialect/Index/CMakeLists.txt b/mlir/unittests/Dialect/Index/CMakeLists.txt new file mode 100644 index 00000000000000..c4bac2371e52fb --- /dev/null +++ b/mlir/unittests/Dialect/Index/CMakeLists.txt @@ -0,0 +1,7 @@ +add_mlir_unittest(MLIRIndexOpsTests + IndexOpsFoldersTest.cpp +) +target_link_libraries(MLIRIndexOpsTests + PRIVATE + MLIRIndexDialect +) diff --git a/mlir/unittests/Dialect/Index/IndexOpsFoldersTest.cpp b/mlir/unittests/Dialect/Index/IndexOpsFoldersTest.cpp new file mode 100644 index 00000000000000..948033ddb5934a --- /dev/null +++ b/mlir/unittests/Dialect/Index/IndexOpsFoldersTest.cpp @@ -0,0 +1,104 @@ +//===- IndexOpsFoldersTest.cpp - unit tests for index op folders ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/IR/OwningOpRef.h" +#include "gtest/gtest.h" + +using namespace mlir; + +namespace { +/// Test fixture for testing operation folders. +class IndexFolderTest : public testing::Test { +public: + IndexFolderTest() { ctx.getOrLoadDialect<index::IndexDialect>(); } + + /// Instantiate an operation, invoke its folder, and return the attribute + /// result. + template <typename OpT> + void foldOp(IntegerAttr &value, Type type, ArrayRef<Attribute> operands); + +protected: + /// The MLIR context to use. + MLIRContext ctx; + /// A builder to use. + OpBuilder b{&ctx}; +}; +} // namespace + +template <typename OpT> +void IndexFolderTest::foldOp(IntegerAttr &value, Type type, + ArrayRef<Attribute> operands) { + // This function returns null so that `ASSERT_*` works within it. + OperationState state(UnknownLoc::get(&ctx), OpT::getOperationName()); + state.addTypes(type); + OwningOpRef<OpT> op = cast<OpT>(b.create(state)); + SmallVector<OpFoldResult> results; + LogicalResult result = op->getOperation()->fold(operands, results); + // Propagate the failure to the test. + if (failed(result)) { + value = nullptr; + return; + } + ASSERT_EQ(results.size(), 1u); + value = dyn_cast_or_null<IntegerAttr>(dyn_cast<Attribute>(results.front())); + ASSERT_TRUE(value); +} + +TEST_F(IndexFolderTest, TestCastUOpFolder) { + IntegerAttr value; + auto fold = [&](Type type, Attribute input) { + foldOp<index::CastUOp>(value, type, input); + }; + + // Target width less than or equal to 32 bits. + fold(b.getIntegerType(16), b.getIndexAttr(8000000000)); + ASSERT_TRUE(value); + EXPECT_EQ(value.getInt(), 20480u); + + // Target width greater than or equal to 64 bits. + fold(b.getIntegerType(64), b.getIndexAttr(2000)); + ASSERT_TRUE(value); + EXPECT_EQ(value.getInt(), 2000u); + + // Fails to fold, because truncating to 32 bits and then extending creates a + // different value. + fold(b.getIntegerType(64), b.getIndexAttr(8000000000)); + EXPECT_FALSE(value); + + // Target width between 32 and 64 bits. + fold(b.getIntegerType(40), b.getIndexAttr(0x10000000010000)); + // Fold succeeds because the upper bits are truncated in the cast. + ASSERT_TRUE(value); + EXPECT_EQ(value.getInt(), 65536); + + // Fails to fold because the upper bits are not truncated. + fold(b.getIntegerType(60), b.getIndexAttr(0x10000000010000)); + EXPECT_FALSE(value); +} + +TEST_F(IndexFolderTest, TestCastSOpFolder) { + IntegerAttr value; + auto fold = [&](Type type, Attribute input) { + foldOp<index::CastSOp>(value, type, input); + }; + + // Just test the extension cases to ensure signs are being respected. + + // Target width greater than or equal to 64 bits. + fold(b.getIntegerType(64), b.getIndexAttr(-2000)); + ASSERT_TRUE(value); + EXPECT_EQ(value.getInt(), -2000); + + // Target width between 32 and 64 bits. + fold(b.getIntegerType(40), b.getIndexAttr(-0x10000000010000)); + // Fold succeeds because the upper bits are truncated in the cast. + ASSERT_TRUE(value); + EXPECT_EQ(value.getInt(), -65536); +}