[mlir][transform] Check for invalidated iterators on payload IR mappings #66369
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core

Changes

Add extra error checking (in debug mode) to detect cases where an iterator on "direct" payload IR mappings is invalidated (due to elements being removed). Such errors are hard to debug: they are often non-deterministic; sometimes the program crashes, sometimes it produces wrong results. Even when it crashes, the stack trace often points to completely unrelated code locations.

Store a timestamp with each "direct" mapping. The timestamp is incremented whenever an operation is performed that invalidates an iterator on that mapping. A debug iterator is added that checks the timestamp before dereferencing or incrementing.

Full diff: https://github.com/llvm/llvm-project/pull/66369.diff

3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index efd8d573936c332..cec31c63ed8167b 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -18,6 +18,40 @@
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"

+#ifndef NDEBUG
+namespace {
+/// An iterator adaptor that checks an assertion before every increment and
+/// dereference.
+template <typename WrappedIteratorT, typename AssertFn>
+class AssertingIterator : public llvm::iterator_adaptor_base<
+                              AssertingIterator<WrappedIteratorT, AssertFn>,
+                              WrappedIteratorT, std::input_iterator_tag> {
+  using BaseT = typename AssertingIterator::iterator_adaptor_base;
+  using PointerT = typename std::iterator_traits<WrappedIteratorT>::pointer;
+
+  /// The assertion function.
+  AssertFn assertFn;
+
+public:
+  AssertingIterator(WrappedIteratorT I, AssertFn assertFn)
+      : BaseT(I), assertFn(assertFn) {}
+
+  using BaseT::operator*;
+  decltype(*std::declval<WrappedIteratorT>()) operator*() {
+    assertFn();
+    return *(this->I);
+  }
+
+  using BaseT::operator++;
+  AssertingIterator &operator++() {
+    assertFn();
+    this->I++;
+    return *this;
+  }
+};
+} // namespace
+#endif // NDEBUG
+
 namespace mlir {
 namespace transform {

@@ -170,6 +204,12 @@ class TransformState {
   /// should be emitted when the value is used.
   using InvalidatedHandleMap = DenseMap<Value, std::function<void(Location)>>;

+#ifndef NDEBUG
+  /// Debug only: A timestamp is associated with each transform IR value, so
+  /// that invalid iterator usage can be detected more reliably.
+  using TransformIRTimestampMapping = DenseMap<Value, int64_t>;
+#endif // NDEBUG
+
   /// The bidirectional mappings between transform IR values and payload IR
   /// operations, and the mapping between transform IR values and parameters.
   struct Mappings {
@@ -178,6 +218,11 @@ class TransformState {
     ParamMapping params;
     ValueMapping values;
     ValueMapping reverseValues;
+
+#ifndef NDEBUG
+    TransformIRTimestampMapping timestamps;
+    void incrementTimestamp(Value value) { ++timestamps[value]; }
+#endif // NDEBUG
   };

   friend LogicalResult applyTransforms(Operation *, TransformOpInterface,
@@ -207,9 +252,28 @@ class TransformState {
   /// not enumerated. This function is helpful for transformations that apply to
   /// a particular handle.
   auto getPayloadOps(Value value) const {
+    ArrayRef<Operation *> view = getPayloadOpsView(value);
+
+#ifndef NDEBUG
+    // Memorize the current timestamp and make sure that it has not changed
+    // when incrementing or dereferencing the iterator returned by this
+    // function. The timestamp is incremented when the "direct" mapping is
+    // resized; this would invalidate the iterator returned by this function.
+    int64_t currentTimestamp = getMapping(value).timestamps.lookup(value);
+    auto assertFn = [=] {
+      bool sameTimestamp =
+          currentTimestamp == this->getMapping(value).timestamps.lookup(value);
+      assert(sameTimestamp && "iterator was invalidated during iteration");
+    };
+    auto it = llvm::make_range(AssertingIterator(std::begin(view), assertFn),
+                               AssertingIterator(std::end(view), assertFn));
+#else
+    auto it = llvm::make_range(view.begin(), view.end());
+#endif // NDEBUG
+
     // When ops are replaced/erased, they are replaced with nullptr (until
     // the data structure is compacted). Do not enumerate these ops.
-    return llvm::make_filter_range(getPayloadOpsView(value),
+    return llvm::make_filter_range(it,
                                    [](Operation *op) { return op != nullptr; });
   }

diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 00450a1ff8f36cf..a091047c440de35 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -310,6 +310,11 @@ void transform::TransformState::forgetMapping(Value opHandle,
   for (Operation *op : mappings.direct[opHandle])
     dropMappingEntry(mappings.reverse, op, opHandle);
   mappings.direct.erase(opHandle);
+#ifndef NDEBUG
+  // Payload IR is removed from the mapping. This invalidates the respective
+  // iterators.
+  mappings.incrementTimestamp(opHandle);
+#endif // NDEBUG

   for (Value opResult : origOpFlatResults) {
     SmallVector<Value> resultHandles;
@@ -336,6 +341,12 @@ void transform::TransformState::forgetValueMapping(
       Mappings &localMappings = getMapping(opHandle);
       dropMappingEntry(localMappings.direct, opHandle, payloadOp);
       dropMappingEntry(localMappings.reverse, payloadOp, opHandle);
+
+#ifndef NDEBUG
+      // Payload IR is removed from the mapping. This invalidates the respective
+      // iterators.
+      localMappings.incrementTimestamp(opHandle);
+#endif // NDEBUG
     }
   }
 }
@@ -774,6 +785,13 @@ checkRepeatedConsumptionInOperand(ArrayRef<T> payload,
 void transform::TransformState::compactOpHandles() {
   for (Value handle : opHandlesToCompact) {
     Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
+#ifndef NDEBUG
+    if (llvm::find(mappings.direct[handle], nullptr) !=
+        mappings.direct[handle].end())
+      // Payload IR is removed from the mapping. This invalidates the respective
+      // iterators.
+      mappings.incrementTimestamp(handle);
+#endif // NDEBUG
     llvm::erase_value(mappings.direct[handle], nullptr);
   }
   opHandlesToCompact.clear();
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 21f9ff5999a5ed5..3e5f1baac684d42 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -360,8 +360,9 @@ DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply(
 DiagnosedSilenceableFailure mlir::test::TestReversePayloadOpsOp::apply(
     transform::TransformRewriter &rewriter,
     transform::TransformResults &results, transform::TransformState &state) {
-  auto payloadOps = state.getPayloadOps(getTarget());
-  auto reversedOps = llvm::to_vector(llvm::reverse(payloadOps));
+  auto it = state.getPayloadOps(getTarget());
+  SmallVector<Operation *> reversedOps(it.begin(), it.end());
+  llvm::reverse(reversedOps);
   results.set(llvm::cast<OpResult>(getResult()), reversedOps);
   return DiagnosedSilenceableFailure::success();
 }
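The timestamp idea in the diff above can be illustrated outside of MLIR. The following is a minimal standalone sketch, not the code from this PR: `CheckedList`, `detectsInvalidation`, and `sumWithoutMutation` are hypothetical names, the payload is a plain `std::vector<int>`, and the check throws `std::logic_error` instead of calling `assert` so that the failure can be observed in any build mode.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for one "direct" mapping entry: a payload list plus a
// timestamp that is bumped by every operation that may invalidate iterators.
class CheckedList {
  std::vector<int> items;
  std::int64_t timestamp = 0;

public:
  // Iterator that records the timestamp at creation and verifies it before
  // every dereference and every increment.
  class Iterator {
    const CheckedList *owner;
    std::size_t index;
    std::int64_t seen;

    void check() const {
      if (seen != owner->timestamp)
        throw std::logic_error("iterator was invalidated during iteration");
    }

  public:
    Iterator(const CheckedList *owner, std::size_t index)
        : owner(owner), index(index), seen(owner->timestamp) {}
    int operator*() const {
      check();
      return owner->items[index];
    }
    Iterator &operator++() {
      check();
      ++index;
      return *this;
    }
    bool operator!=(const Iterator &other) const {
      return index != other.index;
    }
  };

  void push(int v) {
    items.push_back(v);
    ++timestamp; // Growing the list may reallocate.
  }
  void erase(std::size_t i) {
    items.erase(items.begin() + i);
    ++timestamp; // Removing an element shifts the remaining ones.
  }
  Iterator begin() const { return Iterator(this, 0); }
  Iterator end() const { return Iterator(this, items.size()); }
};

// Erases an element while iterating; returns true if the check fires.
bool detectsInvalidation() {
  CheckedList list;
  list.push(1);
  list.push(2);
  list.push(3);
  try {
    for (auto it = list.begin(); it != list.end(); ++it)
      if (*it == 2)
        list.erase(0);
  } catch (const std::logic_error &) {
    return true;
  }
  return false;
}

// Iterates without mutating the list; the check never fires.
int sumWithoutMutation() {
  CheckedList list;
  list.push(1);
  list.push(2);
  list.push(3);
  int sum = 0;
  for (auto it = list.begin(); it != list.end(); ++it)
    sum += *it;
  return sum;
}
```

In this sketch the error surfaces at the first increment after the erase, at the point of misuse, rather than as a crash somewhere unrelated.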
Force-pushed from 608c6d5 to 49d368f.
@@ -360,8 +360,9 @@ DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply(
 DiagnosedSilenceableFailure mlir::test::TestReversePayloadOpsOp::apply(
     transform::TransformRewriter &rewriter,
     transform::TransformResults &results, transform::TransformState &state) {
-  auto payloadOps = state.getPayloadOps(getTarget());
This change is not needed.
LGTM with the correct macro predicate
@@ -178,6 +184,11 @@ class TransformState {
     ParamMapping params;
     ValueMapping values;
     ValueMapping reverseValues;
+
+#ifndef NDEBUG
Code that may change the size of the object should be guarded by `LLVM_ENABLE_ABI_BREAKING_CHECKS`.
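The concern behind this comment can be sketched in a few lines. A member that exists only under a preprocessor condition changes `sizeof` of the enclosing struct, so two translation units that see the header with different settings disagree about the layout. The sketch below uses a hypothetical macro, `SKETCH_ENABLE_ABI_BREAKING_CHECKS`, as a stand-in for the real LLVM configuration macro, and the struct and field types are illustrative assumptions rather than the actual `Mappings` struct.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for LLVM_ENABLE_ABI_BREAKING_CHECKS. Unlike NDEBUG,
// which each translation unit can set independently, a dedicated configuration
// macro is meant to be fixed once for the whole build.
#ifndef SKETCH_ENABLE_ABI_BREAKING_CHECKS
#define SKETCH_ENABLE_ABI_BREAKING_CHECKS 1
#endif

// Layout without the debug-only member.
struct PlainMappings {
  std::int64_t params;
  std::int64_t values;
};

// Layout with a member that exists only when the checks are enabled.
struct CheckedMappings {
  std::int64_t params;
  std::int64_t values;
#if SKETCH_ENABLE_ABI_BREAKING_CHECKS
  std::int64_t timestamps; // Extra member: changes the object size.
#endif
};

// True when the guarded member altered the object size.
constexpr bool layoutChanged() {
  return sizeof(CheckedMappings) != sizeof(PlainMappings);
}
```

If such a guard were keyed on `NDEBUG`, a library compiled in release mode and a client compiled in debug mode could end up with different definitions of the same struct.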
Force-pushed from 49d368f to 0011bbe.
Force-pushed from 0011bbe to 4eb5b8e.
…ngs (llvm#66369)

Add extra error checking (in debug mode) to detect cases where an iterator on "direct" payload IR mappings is invalidated (due to elements being removed). Such errors are hard to debug: they are often non-deterministic; sometimes the program crashes, sometimes it produces wrong results. Even when it crashes, the stack trace often points to completely unrelated code locations.

Store a timestamp with each "direct" mapping. The timestamp is incremented whenever an operation is performed that invalidates an iterator on that mapping. A debug iterator is added that checks the timestamp as payload IR is enumerated.
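The distinction between replacing an entry with `nullptr` and physically removing it matters for when a timestamp needs to change. The sketch below is a simplified illustration in standard C++, not the MLIR implementation: `Op`, `tombstone`, and `compact` are hypothetical names. Overwriting an entry keeps the container size, so existing iterators stay valid; compaction erases elements and would invalidate them, which is the only case in which this sketch reports that a timestamp bump is due.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical stand-in for a payload operation.
struct Op {
  int id;
};

// Marks an op as erased by overwriting its slot with nullptr. The vector is
// not resized, so iterators into it remain valid.
void tombstone(std::vector<Op *> &ops, Op *op) {
  std::replace(ops.begin(), ops.end(), op, static_cast<Op *>(nullptr));
}

// Physically removes all nullptr slots. Returns true if anything was removed,
// i.e. if iterators were invalidated and a timestamp would need to be bumped.
bool compact(std::vector<Op *> &ops) {
  auto newEnd = std::remove(ops.begin(), ops.end(), nullptr);
  bool removed = newEnd != ops.end();
  ops.erase(newEnd, ops.end());
  return removed;
}
```

A caller would increment the timestamp only when `compact` returns true, so that a compaction pass with nothing to remove does not trip the check for ongoing iterations.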
…66472) Same as #66369 but for payload values. (#66369 added checks only for payload operations.) It was necessary to change the signature of `getPayloadValues` to return an iterator. This is now similar to payload operations. Fixes an issue in #66369 where the `LLVM_ENABLE_ABI_BREAKING_CHECKS` check was inverted.