Skip to content

Commit

Permalink
[mlir][transform] Check for invalidated iterators on payload values (#…
Browse files Browse the repository at this point in the history
…66472)

Same as #66369 but for payload values. (#66369 added checks only for
payload operations.)

It was necessary to change the signature of `getPayloadValues` to return
an iterator. This is now similar to payload operations.

Fixes an issue in #66369 where the `LLVM_ENABLE_ABI_BREAKING_CHECKS`
check was inverted.
  • Loading branch information
matthias-springer authored Sep 25, 2023
1 parent 702608f commit 085075a
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 43 deletions.
6 changes: 3 additions & 3 deletions mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,15 @@ class SingleValueMatcherOpTrait
TransformResults &results,
TransformState &state) {
Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
ValueRange payload = state.getPayloadValues(operandHandle);
if (payload.size() != 1) {
auto payload = state.getPayloadValues(operandHandle);
if (!llvm::hasSingleElement(payload)) {
return emitDefiniteFailure(this->getOperation()->getLoc())
<< "SingleValueMatchOpTrait requires the value handle to point to "
"a single payload value";
}

return cast<OpTy>(this->getOperation())
.matchValue(payload[0], results, state);
.matchValue(*payload.begin(), results, state);
}

void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
Expand Down
67 changes: 55 additions & 12 deletions mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ class TransformState {
/// should be emitted when the value is used.
using InvalidatedHandleMap = DenseMap<Value, std::function<void(Location)>>;

#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
/// Debug only: A timestamp is associated with each transform IR value, so
/// that invalid iterator usage can be detected more reliably.
using TransformIRTimestampMapping = DenseMap<Value, int64_t>;
Expand All @@ -185,7 +185,7 @@ class TransformState {
ValueMapping values;
ValueMapping reverseValues;

#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
TransformIRTimestampMapping timestamps;
void incrementTimestamp(Value value) { ++timestamps[value]; }
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
Expand Down Expand Up @@ -220,7 +220,7 @@ class TransformState {
auto getPayloadOps(Value value) const {
ArrayRef<Operation *> view = getPayloadOpsView(value);

#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
// Memorize the current timestamp and make sure that it has not changed
// when incrementing or dereferencing the iterator returned by this
// function. The timestamp is incremented when the "direct" mapping is
Expand All @@ -231,7 +231,7 @@ class TransformState {
// When ops are replaced/erased, they are replaced with nullptr (until
// the data structure is compacted). Do not enumerate these ops.
return llvm::make_filter_range(view, [=](Operation *op) {
#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
bool sameTimestamp =
currentTimestamp == this->getMapping(value).timestamps.lookup(value);
assert(sameTimestamp && "iterator was invalidated during iteration");
Expand All @@ -244,9 +244,29 @@ class TransformState {
/// corresponds to.
ArrayRef<Attribute> getParams(Value value) const;

/// Returns the list of payload IR values that the given transform IR value
/// corresponds to.
ArrayRef<Value> getPayloadValues(Value handleValue) const;
/// Returns an iterator that enumerates all payload IR values that the given
/// transform IR value corresponds to.
auto getPayloadValues(Value handleValue) const {
ArrayRef<Value> view = getPayloadValuesView(handleValue);

#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
// Memorize the current timestamp and make sure that it has not changed
// when incrementing or dereferencing the iterator returned by this
// function. The timestamp is incremented when the "values" mapping is
// resized; this would invalidate the iterator returned by this function.
int64_t currentTimestamp =
getMapping(handleValue).timestamps.lookup(handleValue);
return llvm::make_filter_range(view, [=](Value v) {
bool sameTimestamp =
currentTimestamp ==
this->getMapping(handleValue).timestamps.lookup(handleValue);
assert(sameTimestamp && "iterator was invalidated during iteration");
return true;
});
#else
return llvm::make_range(view.begin(), view.end());
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
}

/// Populates `handles` with all handles pointing to the given Payload IR op.
/// Returns success if such handles exist, failure otherwise.
Expand Down Expand Up @@ -501,12 +521,15 @@ class TransformState {
LogicalResult updateStateFromResults(const TransformResults &results,
ResultRange opResults);

/// Returns a list of all ops that the given transform IR value corresponds to
/// at the time when this function is called. In case an op was erased, the
/// returned list contains nullptr. This function is helpful for
/// transformations that apply to a particular handle.
/// Returns a list of all ops that the given transform IR value corresponds
/// to. In case an op was erased, the returned list contains nullptr. This
/// function is helpful for transformations that apply to a particular handle.
ArrayRef<Operation *> getPayloadOpsView(Value value) const;

/// Returns a list of payload IR values that the given transform IR value
/// corresponds to.
ArrayRef<Value> getPayloadValuesView(Value handleValue) const;

/// Sets the payload IR ops associated with the given transform IR value
/// (handle). A payload op may be associated multiple handles as long as
/// at most one of them gets consumed by further transformations.
Expand Down Expand Up @@ -806,7 +829,27 @@ class TransformResults {
/// set by the transformation exactly once in case of transformation
/// succeeding. The value must have a type implementing
/// TransformValueHandleTypeInterface.
void setValues(OpResult handle, ValueRange values);
template <typename Range>
void setValues(OpResult handle, Range &&values) {
int64_t position = handle.getResultNumber();
assert(position < static_cast<int64_t>(this->values.size()) &&
"setting values for a non-existent handle");
assert(this->values[position].data() == nullptr && "values already set");
assert(operations[position].data() == nullptr &&
"another kind of results already set");
assert(params[position].data() == nullptr &&
"another kind of results already set");
this->values.replace(position, std::forward<Range>(values));
}

/// Indicates that the result of the transform IR op at the given position
/// corresponds to the given range of payload IR values. Each result must be
/// set by the transformation exactly once in case of transformation
/// succeeding. The value must have a type implementing
/// TransformValueHandleTypeInterface.
void setValues(OpResult handle, std::initializer_list<Value> values) {
setValues(handle, ArrayRef<Value>(values));
}

/// Indicates that the result of the transform IR op at the given position
/// corresponds to the given range of mapped values. All mapped values are
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation(

Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position));
if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
results.setValues(cast<OpResult>(getResult()), result);
results.setValues(cast<OpResult>(getResult()), {result});
return DiagnosedSilenceableFailure::success();
}

Expand Down
42 changes: 22 additions & 20 deletions mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ ArrayRef<Attribute> transform::TransformState::getParams(Value value) const {
}

ArrayRef<Value>
transform::TransformState::getPayloadValues(Value handleValue) const {
transform::TransformState::getPayloadValuesView(Value handleValue) const {
const ValueMapping &mapping = getMapping(handleValue).values;
auto iter = mapping.find(handleValue);
assert(iter != mapping.end() && "cannot find mapping for value handle "
Expand Down Expand Up @@ -310,7 +310,7 @@ void transform::TransformState::forgetMapping(Value opHandle,
for (Operation *op : mappings.direct[opHandle])
dropMappingEntry(mappings.reverse, op, opHandle);
mappings.direct.erase(opHandle);
#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
// Payload IR is removed from the mapping. This invalidates the respective
// iterators.
mappings.incrementTimestamp(opHandle);
Expand All @@ -322,6 +322,11 @@ void transform::TransformState::forgetMapping(Value opHandle,
for (Value resultHandle : resultHandles) {
Mappings &localMappings = getMapping(resultHandle);
dropMappingEntry(localMappings.values, resultHandle, opResult);
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
// Payload IR is removed from the mapping. This invalidates the respective
// iterators.
mappings.incrementTimestamp(resultHandle);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
dropMappingEntry(localMappings.reverseValues, opResult, resultHandle);
}
}
Expand All @@ -333,6 +338,11 @@ void transform::TransformState::forgetValueMapping(
for (Value payloadValue : mappings.reverseValues[valueHandle])
dropMappingEntry(mappings.reverseValues, payloadValue, valueHandle);
mappings.values.erase(valueHandle);
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
// Payload IR is removed from the mapping. This invalidates the respective
// iterators.
mappings.incrementTimestamp(valueHandle);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS

for (Operation *payloadOp : payloadOperations) {
SmallVector<Value> opHandles;
Expand All @@ -342,7 +352,7 @@ void transform::TransformState::forgetValueMapping(
dropMappingEntry(localMappings.direct, opHandle, payloadOp);
dropMappingEntry(localMappings.reverse, payloadOp, opHandle);

#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
// Payload IR is removed from the mapping. This invalidates the respective
// iterators.
localMappings.incrementTimestamp(opHandle);
Expand Down Expand Up @@ -439,6 +449,11 @@ transform::TransformState::replacePayloadValue(Value value, Value replacement) {
// between the handles and the IR objects
if (!replacement) {
dropMappingEntry(mappings.values, handle, value);
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
// Payload IR is removed from the mapping. This invalidates the respective
// iterators.
mappings.incrementTimestamp(handle);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
} else {
auto it = mappings.values.find(handle);
if (it == mappings.values.end())
Expand Down Expand Up @@ -647,7 +662,7 @@ void transform::TransformState::recordValueHandleInvalidation(
OpOperand &valueHandle,
transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
// Invalidate other handles to the same value.
for (Value payloadValue : getPayloadValues(valueHandle.get())) {
for (Value payloadValue : getPayloadValuesView(valueHandle.get())) {
SmallVector<Value> otherValueHandles;
(void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
for (Value otherHandle : otherValueHandles) {
Expand Down Expand Up @@ -785,7 +800,7 @@ checkRepeatedConsumptionInOperand(ArrayRef<T> payload,
void transform::TransformState::compactOpHandles() {
for (Value handle : opHandlesToCompact) {
Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
if (llvm::find(mappings.direct[handle], nullptr) !=
mappings.direct[handle].end())
// Payload IR is removed from the mapping. This invalidates the respective
Expand Down Expand Up @@ -846,7 +861,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
FULL_LDBG("--checkRepeatedConsumptionInOperand For Value\n");
DiagnosedSilenceableFailure check =
checkRepeatedConsumptionInOperand<Value>(
getPayloadValues(operand.get()), transform,
getPayloadValuesView(operand.get()), transform,
operand.getOperandNumber());
if (!check.succeeded()) {
FULL_LDBG("----FAILED\n");
Expand Down Expand Up @@ -912,7 +927,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
continue;
}
if (llvm::isa<TransformValueHandleTypeInterface>(operand.getType())) {
for (Value payloadValue : getPayloadValues(operand)) {
for (Value payloadValue : getPayloadValuesView(operand)) {
if (llvm::isa<OpResult>(payloadValue)) {
origAssociatedOps.push_back(payloadValue.getDefiningOp());
continue;
Expand Down Expand Up @@ -1170,19 +1185,6 @@ void transform::TransformResults::setParams(
this->params.replace(position, params);
}

void transform::TransformResults::setValues(OpResult handle,
ValueRange values) {
int64_t position = handle.getResultNumber();
assert(position < static_cast<int64_t>(this->values.size()) &&
"setting values for a non-existent handle");
assert(this->values[position].data() == nullptr && "values already set");
assert(operations[position].data() == nullptr &&
"another kind of results already set");
assert(params[position].data() == nullptr &&
"another kind of results already set");
this->values.replace(position, values);
}

void transform::TransformResults::setMappedValues(
OpResult handle, ArrayRef<MappedValue> values) {
DiagnosedSilenceableFailure diag = dispatchMappedValues(
Expand Down
4 changes: 1 addition & 3 deletions mlir/lib/Dialect/Transform/IR/TransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1379,9 +1379,7 @@ transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
SmallVector<Attribute> params;
ArrayRef<Value> values = state.getPayloadValues(getValue());
params.reserve(values.size());
for (Value value : values) {
for (Value value : state.getPayloadValues(getValue())) {
Type type = value.getType();
if (getElemental()) {
if (auto shaped = dyn_cast<ShapedType>(type)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ DiagnosedSilenceableFailure
mlir::test::TestProduceValueHandleToSelfOperand::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &results, transform::TransformState &state) {
results.setValues(llvm::cast<OpResult>(getOut()), getIn());
results.setValues(llvm::cast<OpResult>(getOut()), {getIn()});
return DiagnosedSilenceableFailure::success();
}

Expand Down Expand Up @@ -265,8 +265,7 @@ void mlir::test::TestPrintRemarkAtOperandOp::getEffects(
DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandValue::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &results, transform::TransformState &state) {
ArrayRef<Value> values = state.getPayloadValues(getIn());
for (Value value : values) {
for (Value value : state.getPayloadValues(getIn())) {
std::string note;
llvm::raw_string_ostream os(note);
if (auto arg = llvm::dyn_cast<BlockArgument>(value)) {
Expand Down Expand Up @@ -712,7 +711,7 @@ void mlir::test::TestProduceNullValueOp::getEffects(
DiagnosedSilenceableFailure mlir::test::TestProduceNullValueOp::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &results, transform::TransformState &state) {
results.setValues(llvm::cast<OpResult>(getOut()), Value());
results.setValues(llvm::cast<OpResult>(getOut()), {Value()});
return DiagnosedSilenceableFailure::success();
}

Expand Down

0 comments on commit 085075a

Please sign in to comment.