Fix build errors after LLVM bump
ramiro050 committed Dec 12, 2022
1 parent e54df07 commit b9f941d
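For context, a minimal sketch of the rename this commit applies: the bumped LLVM/MLIR generates prefixed ODS accessors, so bare getters such as op.ranks(), primIf.thenRegion(), or loop.region() become op.getRanks(), primIf.getThenRegion(), and loop.getRegion() at every call site. MyOp and updateCallSites below are hypothetical names used only to illustrate the pattern; this is not code from the commit.

#include "mlir/IR/Region.h"
#include "mlir/IR/Value.h"

// Sketch only: MyOp is a hypothetical ODS-defined op with an operand named
// `ranks` and a single region, standing in for the ops touched in this diff.
void updateCallSites(MyOp op) {
  // Before the bump (no longer compiles against the new LLVM/MLIR):
  //   mlir::Value ranks = op.ranks();
  //   mlir::Region &body = op.region();
  // After the bump, ODS emits get-prefixed accessors instead:
  mlir::Value ranks = op.getRanks();   // operand accessor, now prefixed
  mlir::Region &body = op.getRegion(); // region accessor, now prefixed
  (void)ranks;
  (void)body;
}

The diff below applies this mechanical rename in each of the three changed files.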
Showing 3 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion include/torch-mlir/Dialect/Torch/IR/TorchOps.h
@@ -185,7 +185,7 @@ struct torch_list_of_optional_constant_ints_op_binder {
auto listConstruct = dyn_cast<Torch::PrimListConstructOp>(op);
if (!listConstruct)
return false;
- for (Value value : listConstruct.elements()) {
+ for (Value value : listConstruct.getElements()) {
int64_t num;
if (matchPattern(value, m_TorchConstantInt(&num)))
bind_values.push_back(num);
@@ -203,13 +203,13 @@ FailureOr<Value> Torch::adjustFunctionArg(
auto isNone = b.create<Aten__Is__Op>(loc, operand, none);
auto primIf = b.create<PrimIfOp>(loc, desiredType, isNone);
{
- Region &thenRegion = primIf.thenRegion();
+ Region &thenRegion = primIf.getThenRegion();
b.createBlock(&thenRegion, thenRegion.end());
auto derefineNone = b.create<DerefineOp>(loc, desiredType, none);
b.create<PrimIfYieldOp>(loc, ValueRange{derefineNone});
}
{
- Region &elseRegion = primIf.elseRegion();
+ Region &elseRegion = primIf.getElseRegion();
b.createBlock(&elseRegion, elseRegion.end());
auto downcasted = b.create<PrimUncheckedCastOp>(
loc, operandOptionalType.getContainedType(), operand);
@@ -258,7 +258,7 @@ FailureOr<Value> Torch::adjustFunctionArg(
{
OpBuilder::InsertionGuard guard(b);
Block *body =
- b.createBlock(&loop.region(), loop.region().begin(),
+ b.createBlock(&loop.getRegion(), loop.getRegion().begin(),
TypeRange({b.getType<Torch::IntType>()}), {loc});
Value iterationNumber = body->getArgument(0);
Value element = b.create<Aten__Getitem__TOp>(
12 changes: 6 additions & 6 deletions lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp
@@ -22,7 +22,7 @@ using namespace mlir::torch::Torch;
static LogicalResult refineDtypeCalculateResult(DtypeCalculateOp op,
int resultNum,
PatternRewriter &rewriter) {
- auto yieldDtypes = op.calculation().front().getTerminator();
+ auto yieldDtypes = op.getCalculation().front().getTerminator();
auto dtype = yieldDtypes->getOperand(resultNum);
auto result = op->getResult(resultNum);

@@ -85,12 +85,12 @@ class DecomposePromoteDtypesOp : public OpRewritePattern<PromoteDtypesOp> {
PatternRewriter &rewriter) const override {
SmallVector<Optional<int64_t>> ranks;
SmallVector<int64_t> dtypes;
- if (!matchPattern(op.ranks(), m_TorchListOfOptionalConstantInts(ranks))) {
+ if (!matchPattern(op.getRanks(), m_TorchListOfOptionalConstantInts(ranks))) {
return rewriter.notifyMatchFailure(
op, "Expected `ranks` to be a list of optional constant ints");
}

- if (!matchPattern(op.dtypes(), m_TorchListOfConstantInts(dtypes))) {
+ if (!matchPattern(op.getDtypes(), m_TorchListOfConstantInts(dtypes))) {
return rewriter.notifyMatchFailure(
op, "Expected `dtypes` to be a list of constant ints");
}
@@ -161,19 +161,19 @@ class RefineNumToTensorScalarOpType
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(PrimNumToTensorScalarOp op,
PatternRewriter &rewriter) const override {
- auto originalResultType = op.result().getType().cast<BaseTensorType>();
+ auto originalResultType = op.getResult().getType().cast<BaseTensorType>();
if (originalResultType.hasDtype())
return rewriter.notifyMatchFailure(
op, "`PrimNumToTensorScalarOp` already has a dtype");

- Type inputType = getBuiltInTypeForTorchScalar(op.a().getType());
+ Type inputType = getBuiltInTypeForTorchScalar(op.getA().getType());
auto impliedTypeFromInputType =
originalResultType.cast<BaseTensorType>()
.getWithSizesAndDtype(originalResultType.getOptionalSizes(),
inputType)
.cast<BaseTensorType>();

- op.result().setType(impliedTypeFromInputType);
+ op.getResult().setType(impliedTypeFromInputType);
return success();
}
};