[mlir][Transforms][NFC] Dialect conversion: Cache UnresolvedMaterializationRewrite#108359
Merged
matthias-springer merged 1 commit intomainfrom Sep 13, 2024
Merged
Conversation
Member
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) ChangesThe dialect conversion maintains a set of unresolved materializations ( Also delete some dead code. Full diff: https://github.com/llvm/llvm-project/pull/108359.diff 1 Files Affected:
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index b58a95c3baf70a..ed15b571f01883 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -688,9 +688,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
UnresolvedMaterializationRewrite(
ConversionPatternRewriterImpl &rewriterImpl,
UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
- MaterializationKind kind = MaterializationKind::Target)
- : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
- converterAndKind(converter, kind) {}
+ MaterializationKind kind = MaterializationKind::Target);
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -730,26 +728,6 @@ static bool hasRewrite(R &&rewrites, Operation *op) {
});
}
-/// Find the single rewrite object of the specified type and block among the
-/// given rewrites. In debug mode, asserts that there is mo more than one such
-/// object. Return "nullptr" if no object was found.
-template <typename RewriteTy, typename R>
-static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
- RewriteTy *result = nullptr;
- for (auto &rewrite : rewrites) {
- auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
- if (rewriteTy && rewriteTy->getBlock() == block) {
-#ifndef NDEBUG
- assert(!result && "expected single matching rewrite");
- result = rewriteTy;
-#else
- return rewriteTy;
-#endif // NDEBUG
- }
- }
- return result;
-}
-
//===----------------------------------------------------------------------===//
// ConversionPatternRewriterImpl
//===----------------------------------------------------------------------===//
@@ -892,10 +870,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
bool wasErased(void *ptr) const { return erased.contains(ptr); }
- bool wasErased(OperationRewrite *rewrite) const {
- return wasErased(rewrite->getOperation());
- }
-
void notifyOperationErased(Operation *op) override { erased.insert(op); }
void notifyBlockErased(Block *block) override { erased.insert(block); }
@@ -935,8 +909,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// to modify/access them is invalid rewriter API usage.
SetVector<Operation *> replacedOps;
- /// A set of all unresolved materializations.
- DenseSet<Operation *> unresolvedMaterializations;
+ /// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
+ /// to the corresponding rewrite objects.
+ DenseMap<Operation *, UnresolvedMaterializationRewrite *>
+ unresolvedMaterializations;
/// The current type converter, or nullptr if no type converter is currently
/// active.
@@ -1058,6 +1034,14 @@ void CreateOperationRewrite::rollback() {
op->erase();
}
+UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
+ ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
+ const TypeConverter *converter, MaterializationKind kind)
+ : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
+ converterAndKind(converter, kind) {
+ rewriterImpl.unresolvedMaterializations[op] = this;
+}
+
void UnresolvedMaterializationRewrite::rollback() {
if (getMaterializationKind() == MaterializationKind::Target) {
for (Value input : op->getOperands())
@@ -1345,7 +1329,6 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
auto convertOp =
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
- unresolvedMaterializations.insert(convertOp);
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
return convertOp.getResult(0);
}
@@ -2499,15 +2482,12 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
// Gather all unresolved materializations.
SmallVector<UnrealizedConversionCastOp> allCastOps;
- DenseMap<Operation *, UnresolvedMaterializationRewrite *> rewriteMap;
- for (std::unique_ptr<IRRewrite> &rewrite : rewriterImpl.rewrites) {
- auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get());
- if (!mat)
- continue;
- if (rewriterImpl.eraseRewriter.wasErased(mat))
+ const DenseMap<Operation *, UnresolvedMaterializationRewrite *>
+ &materializations = rewriterImpl.unresolvedMaterializations;
+ for (auto it : materializations) {
+ if (rewriterImpl.eraseRewriter.wasErased(it.first))
continue;
- allCastOps.push_back(mat->getOperation());
- rewriteMap[mat->getOperation()] = mat;
+ allCastOps.push_back(cast<UnrealizedConversionCastOp>(it.first));
}
// Reconcile all UnrealizedConversionCastOps that were inserted by the
@@ -2520,8 +2500,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
if (config.buildMaterializations) {
IRRewriter rewriter(rewriterImpl.context, config.listener);
for (UnrealizedConversionCastOp castOp : remainingCastOps) {
- auto it = rewriteMap.find(castOp.getOperation());
- assert(it != rewriteMap.end() && "inconsistent state");
+ auto it = materializations.find(castOp.getOperation());
+ assert(it != materializations.end() && "inconsistent state");
if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
return failure();
}
|
joker-eph
reviewed
Sep 12, 2024
joker-eph
approved these changes
Sep 12, 2024
Base automatically changed from
users/matthias-springer/replace_op_source_mat
to
main
September 12, 2024 13:30
4cb4bcf to
066359e
Compare
…izationRewrite` The dialect conversion already maintains a set of unresolved materializations (`UnrealizedConversionCastOp`). Turn that set into a map that maps from ops to `UnresolvedMaterializationRewrite *`. This improves efficiency a bit, because an iteration over `ConversionPatternRewriterImpl::rewrites` can be avoided. Also delete some dead code.
066359e to
e724e44
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
The dialect conversion maintains a set of unresolved materializations (
UnrealizedConversionCastOp). Turn that set into aDenseMapthat maps from ops toUnresolvedMaterializationRewrite *. This improves efficiency a bit, because an iteration overConversionPatternRewriterImpl::rewritescan be avoided.Also delete some dead code.