-
Notifications
You must be signed in to change notification settings - Fork 12.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][index] Implement folders for CastSOp and CastUOp #66960
Merged
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-index ChangesFixes #66402 Full diff: https://github.com/llvm/llvm-project/pull/66960.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
index 61cdf4ed0877a0f..c6079cb8a98c813 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
@@ -446,7 +446,7 @@ def Index_XOrOp : IndexBinaryOp<"xor", [Commutative, Pure]> {
// CastSOp
//===----------------------------------------------------------------------===//
-def Index_CastSOp : IndexOp<"casts", [Pure,
+def Index_CastSOp : IndexOp<"casts", [Pure,
DeclareOpInterfaceMethods<CastOpInterface>]> {
let summary = "index signed cast";
let description = [{
@@ -469,13 +469,14 @@ def Index_CastSOp : IndexOp<"casts", [Pure,
let arguments = (ins AnyTypeOf<[AnyInteger, Index]>:$input);
let results = (outs AnyTypeOf<[AnyInteger, Index]>:$output);
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// CastUOp
//===----------------------------------------------------------------------===//
-def Index_CastUOp : IndexOp<"castu", [Pure,
+def Index_CastUOp : IndexOp<"castu", [Pure,
DeclareOpInterfaceMethods<CastOpInterface>]> {
let summary = "index unsigned cast";
let description = [{
@@ -498,6 +499,7 @@ def Index_CastUOp : IndexOp<"castu", [Pure,
let arguments = (ins AnyTypeOf<[AnyInteger, Index]>:$input);
let results = (outs AnyTypeOf<[AnyInteger, Index]>:$output);
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index b6d802876c15ede..b506397742772a7 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -444,11 +444,63 @@ OpFoldResult XOrOp::fold(FoldAdaptor adaptor) {
// CastSOp
//===----------------------------------------------------------------------===//
+static OpFoldResult
+foldCastOp(Attribute input, Type type,
+ function_ref<APInt(const APInt &, unsigned)> extFn,
+ function_ref<APInt(const APInt &, unsigned)> extOrTruncFn) {
+ auto attr = dyn_cast_if_present<IntegerAttr>(input);
+ if (!attr)
+ return {};
+ const APInt &value = attr.getValue();
+
+ if (isa<IndexType>(type)) {
+ // When casting to an index type, perform the cast assuming a 64-bit target.
+ // The result can be truncated to 32 bits as needed and always be correct.
+ // This is because `cast32(cast64(value)) == cast32(value)`.
+ APInt result = extOrTruncFn(value, 64);
+ return IntegerAttr::get(type, result);
+ }
+
+ // When casting from an index type, we must ensure the results respect
+ // `cast_t(value) == cast_t(trunc32(value))`.
+ auto intType = cast<IntegerType>(type);
+ unsigned width = intType.getWidth();
+
+ // If the result type is at most 32 bits, then the cast can always be folded
+ // because it is always a truncation.
+ if (width <= 32) {
+ APInt result = value.trunc(width);
+ return IntegerAttr::get(type, result);
+ }
+
+ // If the result type is at least 64 bits, then the cast is always a
+ // extension. The results will differ if `trunc32(value) != value)`.
+ if (width >= 64) {
+ if (extFn(value.trunc(32), 64) != value)
+ return {};
+ APInt result = extFn(value, width);
+ return IntegerAttr::get(type, result);
+ }
+
+ // Otherwise, we just have to check the property directly.
+ APInt result = value.trunc(width);
+ if (result != extFn(value.trunc(32), width))
+ return {};
+ return IntegerAttr::get(type, result);
+}
+
bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
return llvm::isa<IndexType>(lhsTypes.front()) !=
llvm::isa<IndexType>(rhsTypes.front());
}
+OpFoldResult CastSOp::fold(FoldAdaptor adaptor) {
+ return foldCastOp(
+ adaptor.getInput(), getType(),
+ [](const APInt &x, unsigned width) { return x.sext(width); },
+ [](const APInt &x, unsigned width) { return x.sextOrTrunc(width); });
+}
+
//===----------------------------------------------------------------------===//
// CastUOp
//===----------------------------------------------------------------------===//
@@ -458,6 +510,13 @@ bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
llvm::isa<IndexType>(rhsTypes.front());
}
+OpFoldResult CastUOp::fold(FoldAdaptor adaptor) {
+ return foldCastOp(
+ adaptor.getInput(), getType(),
+ [](const APInt &x, unsigned width) { return x.zext(width); },
+ [](const APInt &x, unsigned width) { return x.zextOrTrunc(width); });
+}
+
//===----------------------------------------------------------------------===//
// CmpOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
index 67308ffbe55ac6d..f3eae3605b3b64f 100644
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -556,3 +556,11 @@ func.func @sub_identity(%arg0: index) -> index {
// CHECK-NEXT: return %arg0
return %0 : index
}
+
+// CHECK-LABEL: @castu_to_index
+func.func @castu_to_index() -> index {
+ // CHECK: index.constant 8000000000000
+ %0 = arith.constant 8000000000000 : i48
+ %1 = index.castu %0 : i48 to index
+ return %1 : index
+}
|
jpienaar
reviewed
Sep 21, 2023
jpienaar
approved these changes
Sep 21, 2023
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Fixes #66402