diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td index 889a29908de09..4d1176e4a469b 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td @@ -829,7 +829,8 @@ def Torch_DerefineOp : Torch_Op<"derefine", [ } def Torch_OperatorOp : Torch_Op<"operator", [ - AllowsTypeRefinement + AllowsTypeRefinement, + HasValueSemantics ]> { let summary = "Opaque torch operator"; let description = [{ diff --git a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp index 73a9f2e5ccd11..acf3797957e29 100644 --- a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp +++ b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp @@ -52,14 +52,6 @@ static Type getContainerOrTensorTypeWithValueSemantics(Type type) { } } -static bool operatorOpHasValueSemantics(OperatorOp opOp) { - if (!opOp->hasAttr("has_value_semantics")) - return false; - auto hasValueSemantics = - opOp->getAttr("has_value_semantics").cast().getValue(); - return hasValueSemantics; -} - namespace { // Convert value semantic ops operating on mutable arrays to instead operate on // immutable tensors. @@ -69,13 +61,8 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - if (isa(op)) { - if (!operatorOpHasValueSemantics(cast(op))) { + if (!op->hasTrait()) return rewriter.notifyMatchFailure(op, "does not have value semantics"); - } - } else if (!op->hasTrait()) { - return rewriter.notifyMatchFailure(op, "does not have value semantics"); - } rewriter.startRootUpdate(op); // Convert all operands. @@ -267,9 +254,7 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase { target.addIllegalOp(); target.addIllegalOp(); target.markUnknownOpDynamicallyLegal([](Operation *op) { - if (op->hasTrait() || - (isa(op) && - operatorOpHasValueSemantics(cast(op)))) { + if (op->hasTrait()) { auto hasValueSemantics = [](Type t) { // TODO: Make this an allowlist based on a closed torch dialect // type system.