diff --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h index 1a6afc58fef2704..c8888f294f6ca1d 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h @@ -95,15 +95,15 @@ class SingleValueMatcherOpTrait TransformResults &results, TransformState &state) { Value operandHandle = cast(this->getOperation()).getOperandHandle(); - ValueRange payload = state.getPayloadValues(operandHandle); - if (payload.size() != 1) { + auto payload = state.getPayloadValues(operandHandle); + if (!llvm::hasSingleElement(payload)) { return emitDefiniteFailure(this->getOperation()->getLoc()) << "SingleValueMatchOpTrait requires the value handle to point to " "a single payload value"; } return cast(this->getOperation()) - .matchValue(payload[0], results, state); + .matchValue(*payload.begin(), results, state); } void getEffects(SmallVectorImpl &effects) { diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h index 2c1775b3b462cf8..0ac3c9a16e03a36 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -170,7 +170,7 @@ class TransformState { /// should be emitted when the value is used. using InvalidatedHandleMap = DenseMap>; -#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS +#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS /// Debug only: A timestamp is associated with each transform IR value, so /// that invalid iterator usage can be detected more reliably. using TransformIRTimestampMapping = DenseMap; @@ -185,7 +185,7 @@ class TransformState { ValueMapping values; ValueMapping reverseValues; -#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS +#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS TransformIRTimestampMapping timestamps; void incrementTimestamp(Value value) { ++timestamps[value]; } #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS @@ -220,7 +220,7 @@ class TransformState { auto getPayloadOps(Value value) const { ArrayRef view = getPayloadOpsView(value); -#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS +#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS // Memorize the current timestamp and make sure that it has not changed // when incrementing or dereferencing the iterator returned by this // function. The timestamp is incremented when the "direct" mapping is @@ -231,7 +231,7 @@ class TransformState { // When ops are replaced/erased, they are replaced with nullptr (until // the data structure is compacted). Do not enumerate these ops. return llvm::make_filter_range(view, [=](Operation *op) { -#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS +#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS bool sameTimestamp = currentTimestamp == this->getMapping(value).timestamps.lookup(value); assert(sameTimestamp && "iterator was invalidated during iteration"); @@ -244,9 +244,29 @@ class TransformState { /// corresponds to. ArrayRef getParams(Value value) const; - /// Returns the list of payload IR values that the given transform IR value - /// corresponds to. - ArrayRef getPayloadValues(Value handleValue) const; + /// Returns an iterator that enumerates all payload IR values that the given + /// transform IR value corresponds to. + auto getPayloadValues(Value handleValue) const { + ArrayRef view = getPayloadValuesView(handleValue); + +#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS + // Memorize the current timestamp and make sure that it has not changed + // when incrementing or dereferencing the iterator returned by this + // function. The timestamp is incremented when the "values" mapping is + // resized; this would invalidate the iterator returned by this function. + int64_t currentTimestamp = + getMapping(handleValue).timestamps.lookup(handleValue); + return llvm::make_filter_range(view, [=](Value v) { + bool sameTimestamp = + currentTimestamp == + this->getMapping(handleValue).timestamps.lookup(handleValue); + assert(sameTimestamp && "iterator was invalidated during iteration"); + return true; + }); +#else + return llvm::make_range(view.begin(), view.end()); +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS + } /// Populates `handles` with all handles pointing to the given Payload IR op. /// Returns success if such handles exist, failure otherwise. @@ -501,12 +521,15 @@ class TransformState { LogicalResult updateStateFromResults(const TransformResults &results, ResultRange opResults); - /// Returns a list of all ops that the given transform IR value corresponds to - /// at the time when this function is called. In case an op was erased, the - /// returned list contains nullptr. This function is helpful for - /// transformations that apply to a particular handle. + /// Returns a list of all ops that the given transform IR value corresponds + /// to. In case an op was erased, the returned list contains nullptr. This + /// function is helpful for transformations that apply to a particular handle. ArrayRef getPayloadOpsView(Value value) const; + /// Returns a list of payload IR values that the given transform IR value + /// corresponds to. + ArrayRef getPayloadValuesView(Value handleValue) const; + /// Sets the payload IR ops associated with the given transform IR value /// (handle). A payload op may be associated multiple handles as long as /// at most one of them gets consumed by further transformations. @@ -806,7 +829,27 @@ class TransformResults { /// set by the transformation exactly once in case of transformation /// succeeding. The value must have a type implementing /// TransformValueHandleTypeInterface. - void setValues(OpResult handle, ValueRange values); + template + void setValues(OpResult handle, Range &&values) { + int64_t position = handle.getResultNumber(); + assert(position < static_cast(this->values.size()) && + "setting values for a non-existent handle"); + assert(this->values[position].data() == nullptr && "values already set"); + assert(operations[position].data() == nullptr && + "another kind of results already set"); + assert(params[position].data() == nullptr && + "another kind of results already set"); + this->values.replace(position, std::forward(values)); + } + + /// Indicates that the result of the transform IR op at the given position + /// corresponds to the given range of payload IR values. Each result must be + /// set by the transformation exactly once in case of transformation + /// succeeding. The value must have a type implementing + /// TransformValueHandleTypeInterface. + void setValues(OpResult handle, std::initializer_list values) { + setValues(handle, ArrayRef(values)); + } /// Indicates that the result of the transform IR op at the given position /// corresponds to the given range of mapped values. All mapped values are diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp index 7b8bf6fc5d8f5a4..fb021ed76242e90 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp @@ -728,7 +728,7 @@ DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation( Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position)); if (isa(getResult().getType())) { - results.setValues(cast(getResult()), result); + results.setValues(cast(getResult()), {result}); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp index 099408399a996fd..483b0e7f7a4f998 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -75,7 +75,7 @@ ArrayRef transform::TransformState::getParams(Value value) const { } ArrayRef -transform::TransformState::getPayloadValues(Value handleValue) const { +transform::TransformState::getPayloadValuesView(Value handleValue) const { const ValueMapping &mapping = getMapping(handleValue).values; auto iter = mapping.find(handleValue); assert(iter != mapping.end() && "cannot find mapping for value handle " @@ -310,7 +310,7 @@ void transform::TransformState::forgetMapping(Value opHandle, for (Operation *op : mappings.direct[opHandle]) dropMappingEntry(mappings.reverse, op, opHandle); mappings.direct.erase(opHandle); -#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS +#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS // Payload IR is removed from the mapping. This invalidates the respective // iterators. mappings.incrementTimestamp(opHandle); @@ -322,6 +322,11 @@ void transform::TransformState::forgetMapping(Value opHandle, for (Value resultHandle : resultHandles) { Mappings &localMappings = getMapping(resultHandle); dropMappingEntry(localMappings.values, resultHandle, opResult); +#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS + // Payload IR is removed from the mapping. This invalidates the respective + // iterators. + mappings.incrementTimestamp(resultHandle); +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS dropMappingEntry(localMappings.reverseValues, opResult, resultHandle); } } @@ -333,6 +338,11 @@ void transform::TransformState::forgetValueMapping( for (Value payloadValue : mappings.reverseValues[valueHandle]) dropMappingEntry(mappings.reverseValues, payloadValue, valueHandle); mappings.values.erase(valueHandle); +#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS + // Payload IR is removed from the mapping. This invalidates the respective + // iterators. + mappings.incrementTimestamp(valueHandle); +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS for (Operation *payloadOp : payloadOperations) { SmallVector opHandles; @@ -342,7 +352,7 @@ void transform::TransformState::forgetValueMapping( dropMappingEntry(localMappings.direct, opHandle, payloadOp); dropMappingEntry(localMappings.reverse, payloadOp, opHandle); -#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS +#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS // Payload IR is removed from the mapping. This invalidates the respective // iterators. localMappings.incrementTimestamp(opHandle); @@ -439,6 +449,11 @@ transform::TransformState::replacePayloadValue(Value value, Value replacement) { // between the handles and the IR objects if (!replacement) { dropMappingEntry(mappings.values, handle, value); +#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS + // Payload IR is removed from the mapping. This invalidates the respective + // iterators. + mappings.incrementTimestamp(handle); +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS } else { auto it = mappings.values.find(handle); if (it == mappings.values.end()) @@ -647,7 +662,7 @@ void transform::TransformState::recordValueHandleInvalidation( OpOperand &valueHandle, transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const { // Invalidate other handles to the same value. - for (Value payloadValue : getPayloadValues(valueHandle.get())) { + for (Value payloadValue : getPayloadValuesView(valueHandle.get())) { SmallVector otherValueHandles; (void)getHandlesForPayloadValue(payloadValue, otherValueHandles); for (Value otherHandle : otherValueHandles) { @@ -785,7 +800,7 @@ checkRepeatedConsumptionInOperand(ArrayRef payload, void transform::TransformState::compactOpHandles() { for (Value handle : opHandlesToCompact) { Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true); -#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS +#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS if (llvm::find(mappings.direct[handle], nullptr) != mappings.direct[handle].end()) // Payload IR is removed from the mapping. This invalidates the respective @@ -846,7 +861,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { FULL_LDBG("--checkRepeatedConsumptionInOperand For Value\n"); DiagnosedSilenceableFailure check = checkRepeatedConsumptionInOperand( - getPayloadValues(operand.get()), transform, + getPayloadValuesView(operand.get()), transform, operand.getOperandNumber()); if (!check.succeeded()) { FULL_LDBG("----FAILED\n"); @@ -912,7 +927,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { continue; } if (llvm::isa(operand.getType())) { - for (Value payloadValue : getPayloadValues(operand)) { + for (Value payloadValue : getPayloadValuesView(operand)) { if (llvm::isa(payloadValue)) { origAssociatedOps.push_back(payloadValue.getDefiningOp()); continue; @@ -1170,19 +1185,6 @@ void transform::TransformResults::setParams( this->params.replace(position, params); } -void transform::TransformResults::setValues(OpResult handle, - ValueRange values) { - int64_t position = handle.getResultNumber(); - assert(position < static_cast(this->values.size()) && - "setting values for a non-existent handle"); - assert(this->values[position].data() == nullptr && "values already set"); - assert(operations[position].data() == nullptr && - "another kind of results already set"); - assert(params[position].data() == nullptr && - "another kind of results already set"); - this->values.replace(position, values); -} - void transform::TransformResults::setMappedValues( OpResult handle, ArrayRef values) { DiagnosedSilenceableFailure diag = dispatchMappedValues( diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 9bcec7ce27365bb..44626260e2f9ef3 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -1379,9 +1379,7 @@ transform::GetTypeOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { SmallVector params; - ArrayRef values = state.getPayloadValues(getValue()); - params.reserve(values.size()); - for (Value value : values) { + for (Value value : state.getPayloadValues(getValue())) { Type type = value.getType(); if (getElemental()) { if (auto shaped = dyn_cast(type)) { diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp index 59f045de3246f6b..e8c25aca237251a 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -136,7 +136,7 @@ DiagnosedSilenceableFailure mlir::test::TestProduceValueHandleToSelfOperand::apply( transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { - results.setValues(llvm::cast(getOut()), getIn()); + results.setValues(llvm::cast(getOut()), {getIn()}); return DiagnosedSilenceableFailure::success(); } @@ -265,8 +265,7 @@ void mlir::test::TestPrintRemarkAtOperandOp::getEffects( DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandValue::apply( transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { - ArrayRef values = state.getPayloadValues(getIn()); - for (Value value : values) { + for (Value value : state.getPayloadValues(getIn())) { std::string note; llvm::raw_string_ostream os(note); if (auto arg = llvm::dyn_cast(value)) { @@ -712,7 +711,7 @@ void mlir::test::TestProduceNullValueOp::getEffects( DiagnosedSilenceableFailure mlir::test::TestProduceNullValueOp::apply( transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { - results.setValues(llvm::cast(getOut()), Value()); + results.setValues(llvm::cast(getOut()), {Value()}); return DiagnosedSilenceableFailure::success(); }