Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][transform] Check for invalidated iterators on payload IR mappings #66369

Merged

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Sep 14, 2023

Add extra error checking (in debug mode) to detect cases where an iterator on "direct" payload IR mappings is invalidated (due to elements being removed). Such errors are hard to debug: they are often non-deterministic; sometimes the program crashes, sometimes it produces wrong results. Even when it crashes, the stack trace often points to completely unrelated code locations.

Store a timestamp with each "direct" mapping. The timestamp is increased whenever an operation is performed that invaldiates an iterator on that mapping. A debug iterator is added that checks the timestamp as payload IR is enumerated.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Sep 14, 2023
@llvmbot
Copy link
Member

llvmbot commented Sep 14, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Changes Add extra error checking (in debug mode) to detect cases where an iterator on "direct" payload IR mappings is invalidated (due to elements being removed). Such errors are hard to debug: they are often non-deterministic; sometimes the program crashes, sometimes it produces wrong results. Even when it crashes, the stack trace often points to completely unrelated code locations.

Store a timestamp with each "direct" mapping. The timestamp is increased whenever an operation is performed that invaldiates an iterator on that mapping. A debug iterator is added that checks the timestamp before dereferencing or incrementing.

Full diff: https://github.com/llvm/llvm-project/pull/66369.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h (+65-1)
  • (modified) mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp (+18)
  • (modified) mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp (+3-2)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index efd8d573936c332..cec31c63ed8167b 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -18,6 +18,40 @@
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
 
+#ifndef NDEBUG
+namespace {
+/// An iterator adaptor that checks an assertion before every increment and
+/// dereference.
+template <typename WrappedIteratorT, typename AssertFn>
+class AssertingIterator : public llvm::iterator_adaptor_base<
+                              AssertingIterator<WrappedIteratorT, AssertFn>,
+                              WrappedIteratorT, std::input_iterator_tag> {
+  using BaseT = typename AssertingIterator::iterator_adaptor_base;
+  using PointerT = typename std::iterator_traits<WrappedIteratorT>::pointer;
+
+  /// The assertion function.
+  AssertFn assertFn;
+
+public:
+  AssertingIterator(WrappedIteratorT I, AssertFn assertFn)
+      : BaseT(I), assertFn(assertFn) {}
+
+  using BaseT::operator*;
+  decltype(*std::declval<WrappedIteratorT>()) operator*() {
+    assertFn();
+    return *(this->I);
+  }
+
+  using BaseT::operator++;
+  AssertingIterator &operator++() {
+    assertFn();
+    this->I++;
+    return *this;
+  }
+};
+} // namespace
+#endif // NDEBUG
+
 namespace mlir {
 namespace transform {
 
@@ -170,6 +204,12 @@ class TransformState {
   /// should be emitted when the value is used.
   using InvalidatedHandleMap = DenseMap<Value, std::function<void(Location)>>;
 
+#ifndef NDEBUG
+  /// Debug only: A timestamp is associated with each transform IR value, so
+  /// that invalid iterator usage can be detected more reliably.
+  using TransformIRTimestampMapping = DenseMap<Value, int64_t>;
+#endif // NDEBUG
+
   /// The bidirectional mappings between transform IR values and payload IR
   /// operations, and the mapping between transform IR values and parameters.
   struct Mappings {
@@ -178,6 +218,11 @@ class TransformState {
     ParamMapping params;
     ValueMapping values;
     ValueMapping reverseValues;
+
+#ifndef NDEBUG
+    TransformIRTimestampMapping timestamps;
+    void incrementTimestamp(Value value) { ++timestamps[value]; }
+#endif // NDEBUG
   };
 
   friend LogicalResult applyTransforms(Operation *, TransformOpInterface,
@@ -207,9 +252,28 @@ class TransformState {
   /// not enumerated. This function is helpful for transformations that apply to
   /// a particular handle.
   auto getPayloadOps(Value value) const {
+    ArrayRef<Operation *> view = getPayloadOpsView(value);
+
+#ifndef NDEBUG
+    // Memorize the current timestamp and make sure that it has not changed
+    // when incrementing or dereferencing the iterator returned by this
+    // function. The timestamp is incremented when the "direct" mapping is
+    // resized; this would invalidate the iterator returned by this function.
+    int64_t currentTimestamp = getMapping(value).timestamps.lookup(value);
+    auto assertFn = [=] {
+      bool sameTimestamp =
+          currentTimestamp == this->getMapping(value).timestamps.lookup(value);
+      assert(sameTimestamp && "iterator was invalidated during iteration");
+    };
+    auto it = llvm::make_range(AssertingIterator(std::begin(view), assertFn),
+                               AssertingIterator(std::end(view), assertFn));
+#else
+    auto it = llvm::make_range(view.begin(), view.end());
+#endif // NDEBUG
+
     // When ops are replaced/erased, they are replaced with nullptr (until
     // the data structure is compacted). Do not enumerate these ops.
-    return llvm::make_filter_range(getPayloadOpsView(value),
+    return llvm::make_filter_range(it,
                                    [](Operation *op) { return op != nullptr; });
   }
 
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 00450a1ff8f36cf..a091047c440de35 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -310,6 +310,11 @@ void transform::TransformState::forgetMapping(Value opHandle,
   for (Operation *op : mappings.direct[opHandle])
     dropMappingEntry(mappings.reverse, op, opHandle);
   mappings.direct.erase(opHandle);
+#ifndef NDEBUG
+  // Payload IR is removed from the mapping. This invalidates the respective
+  // iterators.
+  mappings.incrementTimestamp(opHandle);
+#endif // NDEBUG
 
   for (Value opResult : origOpFlatResults) {
     SmallVector<Value> resultHandles;
@@ -336,6 +341,12 @@ void transform::TransformState::forgetValueMapping(
       Mappings &localMappings = getMapping(opHandle);
       dropMappingEntry(localMappings.direct, opHandle, payloadOp);
       dropMappingEntry(localMappings.reverse, payloadOp, opHandle);
+
+#ifndef NDEBUG
+      // Payload IR is removed from the mapping. This invalidates the respective
+      // iterators.
+      localMappings.incrementTimestamp(opHandle);
+#endif // NDEBUG
     }
   }
 }
@@ -774,6 +785,13 @@ checkRepeatedConsumptionInOperand(ArrayRef<T> payload,
 void transform::TransformState::compactOpHandles() {
   for (Value handle : opHandlesToCompact) {
     Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
+#ifndef NDEBUG
+    if (llvm::find(mappings.direct[handle], nullptr) !=
+        mappings.direct[handle].end())
+      // Payload IR is removed from the mapping. This invalidates the respective
+      // iterators.
+      mappings.incrementTimestamp(handle);
+#endif // NDEBUG
     llvm::erase_value(mappings.direct[handle], nullptr);
   }
   opHandlesToCompact.clear();
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 21f9ff5999a5ed5..3e5f1baac684d42 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -360,8 +360,9 @@ DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply(
 DiagnosedSilenceableFailure mlir::test::TestReversePayloadOpsOp::apply(
     transform::TransformRewriter &rewriter,
     transform::TransformResults &results, transform::TransformState &state) {
-  auto payloadOps = state.getPayloadOps(getTarget());
-  auto reversedOps = llvm::to_vector(llvm::reverse(payloadOps));
+  auto it = state.getPayloadOps(getTarget());
+  SmallVector<Operation *> reversedOps(it.begin(), it.end());
+  llvm::reverse(reversedOps);
   results.set(llvm::cast<OpResult>(getResult()), reversedOps);
   return DiagnosedSilenceableFailure::success();
 }

@matthias-springer matthias-springer force-pushed the transform_check_invalidation branch 2 times, most recently from 608c6d5 to 49d368f Compare September 14, 2023 13:11
@@ -360,8 +360,9 @@ DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply(
DiagnosedSilenceableFailure mlir::test::TestReversePayloadOpsOp::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &results, transform::TransformState &state) {
auto payloadOps = state.getPayloadOps(getTarget());
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is not needed.

Copy link
Member

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with the correct macro predicate

@@ -178,6 +184,11 @@ class TransformState {
ParamMapping params;
ValueMapping values;
ValueMapping reverseValues;

#ifndef NDEBUG
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code that may change the size of the object should be guarded by LLVM_ENABLE_ABI_BREAKING_CHECKS.

@matthias-springer matthias-springer force-pushed the transform_check_invalidation branch from 49d368f to 0011bbe Compare September 14, 2023 14:28
Add extra error checking (in debug mode) to detect cases where an iterator on "direct" payload IR mappings is invalidated (due to elements being removed). Such errors are hard to debug: they are often non-deterministic; sometimes the program crashes, sometimes it produces wrong results. Even when it crashes, the stack trace often points to completely unrelated code locations.

Store a timestamp with each "direct" mapping. The timestamp is increased whenever an operation is performed that invaldiates an iterator on that mapping. A debug iterator is added that checks the timestamp before dereferencing or incrementing.
@matthias-springer matthias-springer force-pushed the transform_check_invalidation branch from 0011bbe to 4eb5b8e Compare September 14, 2023 14:29
@matthias-springer matthias-springer merged commit aca9019 into llvm:main Sep 14, 2023
kstoimenov pushed a commit to kstoimenov/llvm-project that referenced this pull request Sep 14, 2023
…ngs (llvm#66369)

Add extra error checking (in debug mode) to detect cases where an
iterator on "direct" payload IR mappings is invalidated (due to elements
being removed). Such errors are hard to debug: they are often
non-deterministic; sometimes the program crashes, sometimes it produces
wrong results. Even when it crashes, the stack trace often points to
completely unrelated code locations.

Store a timestamp with each "direct" mapping. The timestamp is increased
whenever an operation is performed that invaldiates an iterator on that
mapping. A debug iterator is added that checks the timestamp as payload
IR is enumerated.
matthias-springer added a commit to matthias-springer/llvm-project that referenced this pull request Sep 15, 2023
Same as llvm#66369 but for payload values. (llvm#66369 added checks only for payload operations.)

It was necessary to change the signature of `getPayloadValues` to return an iterator. This is now similar to payload operations.

Fixes an issue in llvm#66369 where the `LLVM_ENABLE_ABI_BREAKING_CHECKS` check was inverted.
ZijunZhaoCCK pushed a commit to ZijunZhaoCCK/llvm-project that referenced this pull request Sep 19, 2023
…ngs (llvm#66369)

Add extra error checking (in debug mode) to detect cases where an
iterator on "direct" payload IR mappings is invalidated (due to elements
being removed). Such errors are hard to debug: they are often
non-deterministic; sometimes the program crashes, sometimes it produces
wrong results. Even when it crashes, the stack trace often points to
completely unrelated code locations.

Store a timestamp with each "direct" mapping. The timestamp is increased
whenever an operation is performed that invaldiates an iterator on that
mapping. A debug iterator is added that checks the timestamp as payload
IR is enumerated.
matthias-springer added a commit that referenced this pull request Sep 25, 2023
…66472)

Same as #66369 but for payload values. (#66369 added checks only for
payload operations.)

It was necessary to change the signature of `getPayloadValues` to return
an iterator. This is now similar to payload operations.

Fixes an issue in #66369 where the `LLVM_ENABLE_ABI_BREAKING_CHECKS`
check was inverted.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants