From bb804078488dc3edf4c7230085cb602b28c205e1 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Thu, 16 Mar 2023 09:14:16 +0000 Subject: [PATCH 1/2] LowerToBackendContract: Explicitly error out on unimplemented operator --- .../Torch/Transforms/LowerToBackendContract.cpp | 11 +++++++++++ .../verify-backend-contract-unimplemented-op.mlir | 10 ++++++++++ 2 files changed, 21 insertions(+) create mode 100644 test/Dialect/Torch/verify-backend-contract-unimplemented-op.mlir diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 1f21a36568ef..0289611b433b 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -197,6 +197,17 @@ static bool satisfiesBackendContract(ModuleOp module, if (walkResult0.wasInterrupted()) return false; + // Check for unimplemented operators first to give more direct diagnostics. + walkResult0 = module.walk([&](Torch::OperatorOp op) { + if (actuallyEmitDiagnostics) { + op->emitError("unsupported by backend contract: Unimplemented operator '" + + op.getName() + "'"); + } + return WalkResult::interrupt(); + }); + if (walkResult0.wasInterrupted()) + return false; + // Check all the types of all Value's in the program and the legality of all // the ops. // diff --git a/test/Dialect/Torch/verify-backend-contract-unimplemented-op.mlir b/test/Dialect/Torch/verify-backend-contract-unimplemented-op.mlir new file mode 100644 index 000000000000..9c8a3575494a --- /dev/null +++ b/test/Dialect/Torch/verify-backend-contract-unimplemented-op.mlir @@ -0,0 +1,10 @@ +// RUN: torch-mlir-opt -torch-verify-backend-contract-no-decompositions -split-input-file -verify-diagnostics %s +func.func @forward(%arg0: !torch.vtensor<[3,5],f32>) -> !torch.vtensor { + %none = torch.constant.none + %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[3,5],f32> to !torch.vtensor<*,f32> + %1 = torch.copy.to_tensor %0 : !torch.tensor<*,f32> + // expected-error @+1 {{unsupported by backend contract: Unimplemented operator 'an.unimplemented.op'}} + %2 = torch.operator "an.unimplemented.op"(%1, %1, %none) : (!torch.tensor<*,f32>, !torch.tensor<*,f32>, !torch.none) -> !torch.tensor + %3 = torch.copy.to_vtensor %2 : !torch.vtensor + return %3 : !torch.vtensor +} From b57881440920c80804edae4a1c90911948bdd19a Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Mon, 20 Mar 2023 12:42:15 +0000 Subject: [PATCH 2/2] LowerToBackendContract: Only reject torch.operator when results are invalid Otherwise it might be a custom op that the backend supports. --- lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 0289611b433b..f07137613d9c 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -199,6 +199,13 @@ static bool satisfiesBackendContract(ModuleOp module, // Check for unimplemented operators first to give more direct diagnostics. walkResult0 = module.walk([&](Torch::OperatorOp op) { + if (llvm::all_of(op.getResults(), [&op](auto res) { + return succeeded( + checkType(op.getOperation(), res.getType(), /*actuallyEmitDiagnostics=*/false)); + })) { + return WalkResult::advance(); + } + if (actuallyEmitDiagnostics) { op->emitError("unsupported by backend contract: Unimplemented operator '" + op.getName() + "'");