Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][index] Implement folders for CastSOp and CastUOp #66960

Merged
merged 2 commits into from
Sep 21, 2023

Conversation

Mogball
Copy link
Contributor

@Mogball Mogball commented Sep 20, 2023

Fixes #66402

@llvmbot
Copy link
Member

llvmbot commented Sep 20, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-index

Changes

Fixes #66402


Full diff: https://github.com/llvm/llvm-project/pull/66960.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Index/IR/IndexOps.td (+4-2)
  • (modified) mlir/lib/Dialect/Index/IR/IndexOps.cpp (+59)
  • (modified) mlir/test/Dialect/Index/index-canonicalize.mlir (+8)
diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
index 61cdf4ed0877a0f..c6079cb8a98c813 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
@@ -446,7 +446,7 @@ def Index_XOrOp : IndexBinaryOp<"xor", [Commutative, Pure]> {
 // CastSOp
 //===----------------------------------------------------------------------===//
 
-def Index_CastSOp : IndexOp<"casts", [Pure, 
+def Index_CastSOp : IndexOp<"casts", [Pure,
     DeclareOpInterfaceMethods<CastOpInterface>]> {
   let summary = "index signed cast";
   let description = [{
@@ -469,13 +469,14 @@ def Index_CastSOp : IndexOp<"casts", [Pure,
   let arguments = (ins AnyTypeOf<[AnyInteger, Index]>:$input);
   let results = (outs AnyTypeOf<[AnyInteger, Index]>:$output);
   let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
 // CastUOp
 //===----------------------------------------------------------------------===//
 
-def Index_CastUOp : IndexOp<"castu", [Pure, 
+def Index_CastUOp : IndexOp<"castu", [Pure,
     DeclareOpInterfaceMethods<CastOpInterface>]> {
   let summary = "index unsigned cast";
   let description = [{
@@ -498,6 +499,7 @@ def Index_CastUOp : IndexOp<"castu", [Pure,
   let arguments = (ins AnyTypeOf<[AnyInteger, Index]>:$input);
   let results = (outs AnyTypeOf<[AnyInteger, Index]>:$output);
   let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index b6d802876c15ede..b506397742772a7 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -444,11 +444,63 @@ OpFoldResult XOrOp::fold(FoldAdaptor adaptor) {
 // CastSOp
 //===----------------------------------------------------------------------===//
 
+static OpFoldResult
+foldCastOp(Attribute input, Type type,
+           function_ref<APInt(const APInt &, unsigned)> extFn,
+           function_ref<APInt(const APInt &, unsigned)> extOrTruncFn) {
+  auto attr = dyn_cast_if_present<IntegerAttr>(input);
+  if (!attr)
+    return {};
+  const APInt &value = attr.getValue();
+
+  if (isa<IndexType>(type)) {
+    // When casting to an index type, perform the cast assuming a 64-bit target.
+    // The result can be truncated to 32 bits as needed and always be correct.
+    // This is because `cast32(cast64(value)) == cast32(value)`.
+    APInt result = extOrTruncFn(value, 64);
+    return IntegerAttr::get(type, result);
+  }
+
+  // When casting from an index type, we must ensure the results respect
+  // `cast_t(value) == cast_t(trunc32(value))`.
+  auto intType = cast<IntegerType>(type);
+  unsigned width = intType.getWidth();
+
+  // If the result type is at most 32 bits, then the cast can always be folded
+  // because it is always a truncation.
+  if (width <= 32) {
+    APInt result = value.trunc(width);
+    return IntegerAttr::get(type, result);
+  }
+
+  // If the result type is at least 64 bits, then the cast is always a
+  // extension. The results will differ if `trunc32(value) != value)`.
+  if (width >= 64) {
+    if (extFn(value.trunc(32), 64) != value)
+      return {};
+    APInt result = extFn(value, width);
+    return IntegerAttr::get(type, result);
+  }
+
+  // Otherwise, we just have to check the property directly.
+  APInt result = value.trunc(width);
+  if (result != extFn(value.trunc(32), width))
+    return {};
+  return IntegerAttr::get(type, result);
+}
+
 bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
   return llvm::isa<IndexType>(lhsTypes.front()) !=
          llvm::isa<IndexType>(rhsTypes.front());
 }
 
+OpFoldResult CastSOp::fold(FoldAdaptor adaptor) {
+  return foldCastOp(
+      adaptor.getInput(), getType(),
+      [](const APInt &x, unsigned width) { return x.sext(width); },
+      [](const APInt &x, unsigned width) { return x.sextOrTrunc(width); });
+}
+
 //===----------------------------------------------------------------------===//
 // CastUOp
 //===----------------------------------------------------------------------===//
@@ -458,6 +510,13 @@ bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
          llvm::isa<IndexType>(rhsTypes.front());
 }
 
+OpFoldResult CastUOp::fold(FoldAdaptor adaptor) {
+  return foldCastOp(
+      adaptor.getInput(), getType(),
+      [](const APInt &x, unsigned width) { return x.zext(width); },
+      [](const APInt &x, unsigned width) { return x.zextOrTrunc(width); });
+}
+
 //===----------------------------------------------------------------------===//
 // CmpOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
index 67308ffbe55ac6d..f3eae3605b3b64f 100644
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -556,3 +556,11 @@ func.func @sub_identity(%arg0: index) -> index {
   // CHECK-NEXT: return %arg0
   return %0 : index
 }
+
+// CHECK-LABEL: @castu_to_index
+func.func @castu_to_index() -> index {
+  // CHECK: index.constant 8000000000000
+  %0 = arith.constant 8000000000000 : i48
+  %1 = index.castu %0 : i48 to index
+  return %1 : index
+}

@Mogball Mogball merged commit 9744d39 into llvm:main Sep 21, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir][index] Implement folder for CastSOp and CastUOp
3 participants