Skip to content

Commit

Permalink
try to fix missing dependentDialects issue
Browse files Browse the repository at this point in the history
Signed-off-by: hanhanW <[email protected]>
  • Loading branch information
hanhanW committed Aug 7, 2024
1 parent 57eabb8 commit eb7368c
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
Expand Down Expand Up @@ -63,6 +64,8 @@
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-llvmcpu-convert-to-llvm"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_CONVERTTOLLVMPASS
Expand Down Expand Up @@ -916,13 +919,45 @@ class ExpandMulSIExtended : public OpRewritePattern<arith::MulSIExtendedOp> {
}
};

/// This DialectExtension can be attached to the context, which will invoke the
/// `apply()` method for every loaded dialect. If a dialect implements the
/// `ConvertToLLVMPatternInterface` interface, we load dependent dialects
/// through the interface. This extension is loaded in the context before
/// starting a pass pipeline that involves dialect conversion to LLVM.
class LoadDependentDialectExtension : public DialectExtensionBase {
public:
LoadDependentDialectExtension() : DialectExtensionBase(/*dialectNames=*/{}) {}

void apply(MLIRContext *context,
MutableArrayRef<Dialect *> dialects) const final {
LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM extension load\n");
for (Dialect *dialect : dialects) {
auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
if (!iface)
continue;
LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM found dialect interface for "
<< dialect->getNamespace() << "\n");
iface->loadDependentDialects(context);
}
}

/// Return a copy of this extension.
virtual std::unique_ptr<DialectExtensionBase> clone() const final {
return std::make_unique<LoadDependentDialectExtension>(*this);
}
};

class ConvertToLLVMPass
: public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
public:
using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
explicit ConvertToLLVMPass(bool reassociateFpReductions) {
this->reassociateFpReductions = reassociateFpReductions;
}
void getDependentDialects(DialectRegistry &registry) const final {
Base::getDependentDialects(registry);
registry.addExtensions<LoadDependentDialectExtension>();
}
void runOnOperation() override;
};

Expand Down

0 comments on commit eb7368c

Please sign in to comment.