diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 1f21a36568ef0..f07137613d9ca 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -197,6 +197,24 @@ static bool satisfiesBackendContract(ModuleOp module, if (walkResult0.wasInterrupted()) return false; + // Check for unimplemented operators first to give more direct diagnostics. + walkResult0 = module.walk([&](Torch::OperatorOp op) { + if (llvm::all_of(op.getResults(), [&op](auto res) { + return succeeded( + checkType(op.getOperation(), res.getType(), /*actuallyEmitDiagnostics=*/false)); + })) { + return WalkResult::advance(); + } + + if (actuallyEmitDiagnostics) { + op->emitError("unsupported by backend contract: Unimplemented operator '" + + op.getName() + "'"); + } + return WalkResult::interrupt(); + }); + if (walkResult0.wasInterrupted()) + return false; + // Check all the types of all Value's in the program and the legality of all // the ops. // diff --git a/test/Dialect/Torch/verify-backend-contract-unimplemented-op.mlir b/test/Dialect/Torch/verify-backend-contract-unimplemented-op.mlir new file mode 100644 index 0000000000000..9c8a3575494a5 --- /dev/null +++ b/test/Dialect/Torch/verify-backend-contract-unimplemented-op.mlir @@ -0,0 +1,10 @@ +// RUN: torch-mlir-opt -torch-verify-backend-contract-no-decompositions -split-input-file -verify-diagnostics %s +func.func @forward(%arg0: !torch.vtensor<[3,5],f32>) -> !torch.vtensor { + %none = torch.constant.none + %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[3,5],f32> to !torch.vtensor<*,f32> + %1 = torch.copy.to_tensor %0 : !torch.tensor<*,f32> + // expected-error @+1 {{unsupported by backend contract: Unimplemented operator 'an.unimplemented.op'}} + %2 = torch.operator "an.unimplemented.op"(%1, %1, %none) : (!torch.tensor<*,f32>, !torch.tensor<*,f32>, !torch.none) -> !torch.tensor + %3 = torch.copy.to_vtensor %2 : !torch.vtensor + return %3 : !torch.vtensor +}