From 6e96cb5a8f03e27bd5ed9728ee1bb14fdf92d063 Mon Sep 17 00:00:00 2001 From: ganler Date: Mon, 24 Oct 2022 17:12:19 -0500 Subject: [PATCH] feat: catch exception at verification stage as ORT backend can return None --- nnsmith/backends/factory.py | 9 +++++++++ nnsmith/difftest.py | 15 +++++++++++++-- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/nnsmith/backends/factory.py b/nnsmith/backends/factory.py index d884a30a..6558c8f8 100644 --- a/nnsmith/backends/factory.py +++ b/nnsmith/backends/factory.py @@ -226,6 +226,15 @@ def verify_results( log=traceback.format_exc(), version=self.version, ) + except Exception: + return BugReport( + testcase=testcase, + system=self.system_name, + symptom=Symptom.EXCEPTION, + stage=Stage.VERIFICATION, + log=traceback.format_exc(), + version=self.version, + ) def verify_testcase( self, testcase: TestCase, equal_nan=True diff --git a/nnsmith/difftest.py b/nnsmith/difftest.py index ceec678f..6cba9b6b 100644 --- a/nnsmith/difftest.py +++ b/nnsmith/difftest.py @@ -19,9 +19,20 @@ def assert_allclose( raise KeyError(f"{actual_name}: {akeys} != {oracle_name}: {dkeys}") for key in akeys: + lhs = actual[key] + rhs = desired[key] + + # check if lhs is np.ndarray + if not isinstance(lhs, np.ndarray): + raise TypeError(f"{actual_name}[{key}] is not np.ndarray but {type(lhs)}") + + # check if rhs is np.ndarray + if not isinstance(rhs, np.ndarray): + raise TypeError(f"{oracle_name}[{key}] is not np.ndarray but {type(rhs)}") + testing.assert_allclose( - actual[key], - desired[key], + lhs, + rhs, equal_nan=equal_nan, rtol=rtol, atol=atol,