diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f5c153d45..f7c9ce41f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: - id: check-hooks-apply - id: check-useless-excludes - repo: https://github.com/pre-commit/pre-commit-hooks - rev: "v4.6.0" + rev: "v5.0.0" hooks: - id: check-case-conflict - id: check-merge-conflict @@ -32,6 +32,26 @@ repos: external/ ) - repo: https://github.com/Mateusz-Grzelinski/actionlint-py - rev: "v1.7.1.15" + rev: "v1.7.7.23" hooks: - id: actionlint + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: "v0.9.4" + hooks: + - id: ruff + - repo: https://github.com/pre-commit/mirrors-mypy + # This needs to match the last version that still supports the minimum + # required Python version + rev: "v1.5.1" + hooks: + - id: mypy + additional_dependencies: + - blessings + - testtools + - types-Pygments + # Work around config file setting exclusion rules being bypassed when + # filenames are explicitly passed on the command line. See + # https://github.com/python/mypy/pull/12373#issuecomment-1071662559 + # for the idea. + args: [.] + pass_filenames: false diff --git a/pyproject.toml b/pyproject.toml index a9b7290b0..015c75edc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,10 +2,10 @@ cache_dir = "build/.mypy_cache" explicit_package_bases = true mypy_path = "$MYPY_CONFIG_FILE_DIR/src/sst/core/testingframework" -# This should be 3.6 but is not supported with the newest versions of mypy. -python_version = "3.8" +python_version = "3.6" strict = true +warn_unused_ignores = true exclude = [ '^scripts/', @@ -20,3 +20,42 @@ module = [ "testtools.testsuite", ] ignore_missing_imports = true + +[tool.pyright] +pythonVersion = "3.6" + +[tool.ruff] +cache-dir = "build/.ruff_cache" +line-length = 100 +# This should be 3.6 but is not supported with the newest versions of ruff. 
+target-version = "py37" + +[tool.ruff.lint] +ignore = [ + "E401", + "E402", + "E701", + "E703", + "E713", + "E722", + "E731", + "F401", + "F403", + "F405", + "F523", + "F841", +] + +[tool.ruff.lint.isort] +known-first-party = ["sst"] +lines-after-imports = 2 +section-order = [ + "future", + "standard-library", + "first-party", + "third-party", + "local-folder", +] + +[tool.ruff.lint.flake8-tidy-imports] +ban-relative-imports = "all" diff --git a/src/sst/core/testingframework/sst_unittest_support.py b/src/sst/core/testingframework/sst_unittest_support.py index 103521eb2..79839558b 100644 --- a/src/sst/core/testingframework/sst_unittest_support.py +++ b/src/sst/core/testingframework/sst_unittest_support.py @@ -1191,7 +1191,7 @@ def testing_stat_output_diff( ### Built in LineFilters for filtering diffs class LineFilter: - def __init__(self): + def __init__(self) -> None: self.apply_to_ref_file = True self.apply_to_out_file = True @@ -1883,7 +1883,7 @@ def os_extract_tar(tarfilepath: str, targetdir: str = ".") -> bool: try: this_tar = tarfile.open(tarfilepath) if sys.version_info.minor >= 12: - this_tar.extractall(targetdir, filter="data") + this_tar.extractall(targetdir, filter="data") # type: ignore [call-arg] else: this_tar.extractall(targetdir) this_tar.close() diff --git a/src/sst/core/testingframework/test_engine_unittest.py b/src/sst/core/testingframework/test_engine_unittest.py index c5c029477..d6186c8be 100644 --- a/src/sst/core/testingframework/test_engine_unittest.py +++ b/src/sst/core/testingframework/test_engine_unittest.py @@ -16,18 +16,25 @@ """ import sys -import unittest import traceback import threading import time import datetime -from typing import Callable, Dict, List, Optional, TextIO, Tuple, Any +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union, TYPE_CHECKING +from unittest import TestCase, TestResult, TestSuite, TextTestResult, TextTestRunner + +if TYPE_CHECKING: + from types import TracebackType + from unittest.runner import _WritelnDecorator # type: ignore [attr-defined] + + from sst_unittest import SSTTestCase + from test_engine import TestEngine if sys.version_info.minor >= 11: - def get_current_time(): - return datetime.datetime.now(datetime.UTC) + def get_current_time() -> datetime.datetime: + return datetime.datetime.now(datetime.UTC) # type: ignore [attr-defined] else: - def get_current_time(): + def get_current_time() -> datetime.datetime: return datetime.datetime.utcnow() ################################################################################ @@ -73,9 +80,9 @@ def check_module_conditional_import(module_name: str) -> bool: from testtools.testsuite import iterate_tests TestSuiteBaseClass = ConcurrentTestSuite else: - # If testtools not available, just trick the system to use unittest.TestSuite + # If testtools not available, just trick the system to use TestSuite # This allows us to continue, but not support concurrent testing - TestSuiteBaseClass = unittest.TestSuite + TestSuiteBaseClass = TestSuite import test_engine_globals from sst_unittest import * @@ -101,7 +108,7 @@ def verify_concurrent_test_engine_available() -> None: ################################################################################ -class SSTTextTestRunner(unittest.TextTestRunner): +class SSTTextTestRunner(TextTestRunner): """ A superclass to support SST required testing """ if blessings_loaded: @@ -117,11 +124,18 @@ class SSTTextTestRunner(unittest.TextTestRunner): None: str } - def __init__(self, stream=sys.stderr, descriptions=True, 
verbosity=1, - failfast=False, buffer=False, resultclass=None, - no_colour_output=False): - super(SSTTextTestRunner, self).__init__(stream, descriptions, verbosity, - failfast, buffer, resultclass) + def __init__( + self, + stream: Any = sys.stderr, + descriptions: bool = True, + verbosity: int = 1, + failfast: bool = False, + buffer: bool = False, + resultclass: Optional[Callable[[Any, bool, int], TextTestResult]] = None, + no_colour_output: bool = False, + **kwargs: Any, + ) -> None: + super().__init__(stream, descriptions, verbosity, failfast, buffer, resultclass) if not blessings_loaded or not pygments_loaded: log_info(("Full colorized output can be obtained by running") + @@ -136,22 +150,22 @@ def __init__(self, stream=sys.stderr, descriptions=True, verbosity=1, ### - def run(self, test): + def run(self, test: Union[TestSuite, TestCase]) -> TestResult: """ Run the tests.""" testing_start_time = time.time() - runresults = super(SSTTextTestRunner, self).run(test) + runresults = super().run(test) testing_stop_time = time.time() total_testing_time = testing_stop_time - testing_start_time - self._get_and_display_test_results(runresults, total_testing_time) + self._get_and_display_test_results(runresults, total_testing_time) # type: ignore [arg-type] return runresults ### - def did_tests_pass(self, run_results): + def did_tests_pass(self, run_results: TestResult) -> bool: """ Figure out if testing passed. Args: - run_results - A unittest.TestResult object + run_results - A TestResult object Returns: True if all tests passing with no errors, false otherwise @@ -164,11 +178,15 @@ def did_tests_pass(self, run_results): ### - def _get_and_display_test_results(self, run_results, total_testing_time): + def _get_and_display_test_results( + self, + run_results: "SSTTextTestResult", + total_testing_time: float, + ) -> None: """ Figure out if testing passed, and display the test results. Args: - sst_tests_results - A unittest.TestResult object + run_results - A TestResult object Returns: True if all tests passing with no errors, false otherwise @@ -238,7 +256,7 @@ def _get_and_display_test_results(self, run_results, total_testing_time): ################################################################################ -class SSTTextTestResult(unittest.TestResult): +class SSTTextTestResult(TextTestResult): """ A superclass to support SST required testing, this is a modified version of unittestTextTestResult from python 2.7 modified for SST's needs.
""" @@ -270,14 +288,20 @@ class SSTTextTestResult(unittest.TestResult): lexer = Lexer() - def __init__(self, stream, descriptions, verbosity, no_colour_output=False): - super(SSTTextTestResult, self).__init__(stream, descriptions, verbosity) + def __init__( + self, + stream: Any, + descriptions: bool, + verbosity: int, + no_colour_output: bool = False, + ) -> None: + super().__init__(stream, descriptions, verbosity) self.testsuitesresultsdict = SSTTestSuitesResultsDict() self._test_name = "undefined_testname" self._testcase_name = "undefined_testcasename" self._testsuite_name = "undefined_testsuitename" - self._junit_test_case = None - self.stream = stream + self._junit_test_case: JUnitTestCase = None # type: ignore [assignment] + self.stream: _WritelnDecorator = stream self.showAll = verbosity > 1 self.dots = verbosity == 1 self.descriptions = descriptions @@ -286,20 +310,20 @@ def __init__(self, stream, descriptions, verbosity, no_colour_output=False): else: self.no_colour_output = True - def getShortDescription(self, test): + def getShortDescription(self, test: TestCase) -> str: doc_first_line = test.shortDescription() if self.descriptions and doc_first_line: return '\n'.join((str(test), doc_first_line)) else: return str(test) - def getLongDescription(self, test): + def getLongDescription(self, test: TestCase) -> str: doc_first_line = test.shortDescription() if self.descriptions and doc_first_line: return '\n'.join((str(test), doc_first_line)) return str(test) - def getClassDescription(self, test): + def getClassDescription(self, test: TestCase) -> str: test_class = test.__class__ doc = test_class.__doc__ if self.descriptions and doc: @@ -308,9 +332,8 @@ def getClassDescription(self, test): ### - def startTest(self, test): - super(SSTTextTestResult, self).startTest(test) - #log_forced("DEBUG - startTest: Test = {0}\n".format(test)) + def startTest(self, test: TestCase) -> None: + super().startTest(test) if self.showAll: if not test_engine_globals.TESTENGINE_CONCURRENTMODE: if self._test_class != test.__class__: @@ -322,14 +345,12 @@ def startTest(self, test): self.stream.writeln(self.colours['title'](title)) self.stream.flush() - self._test_name = "undefined_testname" - _testname = getattr(test, 'testname', None) - if _testname is not None: - self._test_name = test.testname if self._is_test_of_type_ssttestcase(test): - self._testcase_name = test.get_testcase_name() - self._testsuite_name = test.get_testsuite_name() + self._test_name = test.testname # type: ignore [attr-defined] + self._testcase_name = test.get_testcase_name() # type: ignore [attr-defined] + self._testsuite_name = test.get_testsuite_name() # type: ignore [attr-defined] else: + self._test_name = "undefined_testname" self._testcase_name = "FailedTest" self._testsuite_name = "FailedTest" timestamp = get_current_time().strftime("%Y_%m%d_%H:%M:%S.%f utc") @@ -337,12 +358,11 @@ def startTest(self, test): self._testcase_name, timestamp=timestamp) - def stopTest(self, test): - super(SSTTextTestResult, self).stopTest(test) - #log_forced("DEBUG - stopTest: Test = {0}\n".format(test)) - testruntime = 0 + def stopTest(self, test: TestCase) -> None: + super().stopTest(test) + testruntime = 0.0 if self._is_test_of_type_ssttestcase(test): - testruntime = test.get_test_runtime_sec() + testruntime = test.get_test_runtime_sec() # type: ignore [attr-defined] self._junit_test_case.junit_add_elapsed_sec(testruntime) if not self._is_test_of_type_ssttestcase(test): @@ -357,13 +377,20 @@ def stopTest(self, test): ### - def 
get_testsuites_results_dict(self): + def get_testsuites_results_dict(self) -> "SSTTestSuitesResultsDict": """ Return the test suites results dict """ return self.testsuitesresultsdict ### - def printResult(self, test, short, extended, colour_key=None, showruntime=True): + def printResult( + self, + test: TestCase, + short: str, + extended: str, + colour_key: Optional[str] = None, + showruntime: bool = True, + ) -> None: if self.no_colour_output: colour = self.colours[None] else: @@ -373,13 +400,13 @@ def printResult(self, test, short, extended, colour_key=None, showruntime=True): self.stream.write(colour(extended)) self.stream.write(" -- ") self.stream.write(self.getShortDescription(test)) - testruntime = 0 + testruntime = 0.0 if self._is_test_of_type_ssttestcase(test): - testruntime = test.get_test_runtime_sec() + testruntime = test.get_test_runtime_sec() # type: ignore [attr-defined] if showruntime: self.stream.writeln(" [{0:.3f}s]".format(testruntime)) else: - self.stream.writeln(" ".format(testruntime)) + self.stream.writeln(" ") self.stream.flush() elif self.dots: self.stream.write(colour(short)) @@ -387,18 +414,16 @@ def printResult(self, test, short, extended, colour_key=None, showruntime=True): ### - def addSuccess(self, test): - super(SSTTextTestResult, self).addSuccess(test) - #log_forced("DEBUG - addSuccess: Test = {0}\n".format(test)) + def addSuccess(self, test: TestCase) -> None: + super().addSuccess(test) self.printResult(test, '.', 'PASS', 'success') if not self._is_test_of_type_ssttestcase(test): return self.testsuitesresultsdict.add_success(test) - def addError(self, test, err): - super(SSTTextTestResult, self).addError(test, err) - #log_forced("DEBUG - addError: Test = {0}, err = {1}\n".format(test, err)) + def addError(self, test: TestCase, err: "ErrorType") -> None: + super().addError(test, err) self.printResult(test, 'E', 'ERROR', 'error') if not self._is_test_of_type_ssttestcase(test): @@ -409,9 +434,8 @@ def addError(self, test, err): err_msg = self._get_err_info(err) _junit_test_case.junit_add_error_info(err_msg) - def addFailure(self, test, err): - super(SSTTextTestResult, self).addFailure(test, err) - #log_forced("DEBUG - addFailure: Test = {0}, err = {1}\n".format(test, err)) + def addFailure(self, test: TestCase, err: "ErrorType") -> None: + super().addFailure(test, err) self.printResult(test, 'F', 'FAIL', 'fail') if not self._is_test_of_type_ssttestcase(test): @@ -422,9 +446,8 @@ def addFailure(self, test, err): err_msg = self._get_err_info(err) _junit_test_case.junit_add_failure_info(err_msg) - def addSkip(self, test, reason): - super(SSTTextTestResult, self).addSkip(test, reason) - #log_forced("DEBUG - addSkip: Test = {0}, reason = {1}\n".format(test, reason)) + def addSkip(self, test: TestCase, reason: str) -> None: + super().addSkip(test, reason) if not test_engine_globals.TESTENGINE_IGNORESKIPS: self.printResult(test, 's', 'SKIPPED', 'skip', showruntime=False) @@ -435,20 +458,18 @@ def addSkip(self, test, reason): if _junit_test_case is not None: _junit_test_case.junit_add_skipped_info(reason) - def addExpectedFailure(self, test, err): + def addExpectedFailure(self, test: TestCase, err: "ErrorType") -> None: # NOTE: This is not a failure, but an identified pass # since we are expecting a failure - super(SSTTextTestResult, self).addExpectedFailure(test, err) - #log_forced("DEBUG - addExpectedFailure: Test = {0}, err = {1}\n".format(test, err)) + super().addExpectedFailure(test, err) self.printResult(test, 'x', 'EXPECTED FAILURE', 'expected') if not 
self._is_test_of_type_ssttestcase(test): return self.testsuitesresultsdict.add_expected_failure(test) - def addUnexpectedSuccess(self, test): + def addUnexpectedSuccess(self, test: TestCase) -> None: # NOTE: This is a failure, since we passed, but were expecting a failure - super(SSTTextTestResult, self).addUnexpectedSuccess(test) - #log_forced("DEBUG - addUnexpectedSuccess: Test = {0}\n".format(test)) + super().addUnexpectedSuccess(test) self.printResult(test, 'u', 'UNEXPECTED SUCCESS', 'unexpected') if not self._is_test_of_type_ssttestcase(test): @@ -471,7 +492,7 @@ def printErrors(self) -> None: self.printErrorList('ERROR', self.errors) self.printErrorList('FAIL', self.failures) - def printErrorList(self, flavour, errors): + def printErrorList(self, flavour: str, errors: "ErrorsType") -> None: if self.no_colour_output: colour = self.colours[None] else: @@ -487,7 +508,7 @@ def printErrorList(self, flavour, errors): else: self.stream.writeln(err) - def printSkipList(self, flavour, errors): + def printSkipList(self, flavour: str, errors: "ErrorsType") -> None: if self.no_colour_output: colour = self.colours[None] else: @@ -503,7 +524,7 @@ def printSkipList(self, flavour, errors): #### - def _get_err_info(self, err): + def _get_err_info(self, err: "ErrorType") -> str: """Converts a sys.exc_info() into a string.""" exctype, value, tback = err msg_lines = traceback.format_exception_only(exctype, value) @@ -512,20 +533,19 @@ def _get_err_info(self, err): #### - def _is_test_of_type_ssttestcase(self, test): - """ Determine if this is is within a valid SSTTestCase object by - checking if a unique SSTTestCase function exists - return: True if this is a test within a valid SSTTestCase object - """ - return getattr(test, 'get_testcase_name', None) is not None + def _is_test_of_type_ssttestcase(self, test: TestCase) -> bool: + """Determine if this is within a valid SSTTestCase object.""" + # This originally checked for a special method on the instance, + # but the type check is faster. + return isinstance(test, SSTTestCase) ################################################################################ # TestSuiteBaseClass will be either unitest.TestSuite or testtools.ConcurrentTestSuite # and is defined at the top of this file. -class SSTTestSuite(TestSuiteBaseClass): - """A TestSuite whose run() method can execute tests concurrently. - but also supports the python base unittest.TestSuite functionality. +class SSTTestSuite(TestSuiteBaseClass): # type: ignore [misc, valid-type] + """A TestSuite whose run() method can execute tests concurrently, + but also supports the Python base TestSuite functionality. This is a highly modified version of testtools.ConcurrentTestSuite class to support startUpModuleConcurrent() & tearDownModuleConcurrent() @@ -533,18 +553,25 @@ class to support startUpModuleConcurrent() & tearDownModuleConcurrent() This object will normally be derived from testtools.ConcurrentTestSuite class, however, if the import of testtools failed, it will be derived from - unittest.TestSuite. + TestSuite. If the user selected concurrent mode is false, then it will always make - calls to the unittest.TestSuite class EVEN IF it is derived from - testtools.ConcurrentTestSuite, which is itself derived from unittest.TestSuite. + calls to the TestSuite class EVEN IF it is derived from + testtools.ConcurrentTestSuite, which is itself derived from TestSuite. """ - def __init__(self, suite, make_tests, wrap_result=None): - """Create a ConcurrentTestSuite or unittest.TestSuite to execute the suite. 
+ def __init__( + self, + suite: TestSuite, + make_tests: Callable[["SSTTestSuite"], List[Any]], + wrap_result: Optional[ + Callable[["testtools.ThreadsafeForwardingResult", int], TestResult] + ] = None, + ) -> None: + """Create a ConcurrentTestSuite or TestSuite to execute the suite. Note: If concurrent mode is false, then it will always make calls to the - unittest.TestSuite class EVEN IF the class is derived from + TestSuite class EVEN IF the class is derived from testtools.ConcurrentTestSuite. Args: @@ -553,23 +580,26 @@ def __init__(self, suite, make_tests, wrap_result=None): ConcurrentTestSuite into some number of concurrently executing sub-suites. make_tests must take a suite, and return an iterable of TestCase-like object, each of which must have a run(result) - method. NOT USED IN unittest.TestSuite. + method. NOT USED IN TestSuite. wrap_result: An optional function that takes a thread-safe result and a thread number and must return a ``TestResult`` object. If not provided, then ``ConcurrentTestSuite`` will just use a ``ThreadsafeForwardingResult`` wrapped around the result - passed to ``run()``. NOT USED IN unittest.TestSuite + passed to ``run()``. NOT USED IN TestSuite """ + # This separation is required for the case when `testtools` is not + # installed, regardless of whether or not testing is performed + # concurrently. if not test_engine_globals.TESTENGINE_CONCURRENTMODE: # Ignore make_tests and wrap_results - super(unittest.TestSuite, self).__init__(suite) + super(TestSuite, self).__init__(suite) # type: ignore [misc] else: - super(SSTTestSuite, self).__init__(suite, make_tests, wrap_result) + super().__init__(suite, make_tests, wrap_result) #### - def run(self, result): + def run(self, result: "SSTTextTestResult") -> Optional["SSTTextTestResult"]: """Run the tests (possibly concurrently). This calls out to the provided make_tests helper, and then serialises @@ -587,50 +617,40 @@ def run(self, result): support running a limited number of concurrent threads. If concurrent mode is false, then it will always make calls to the - unittest.TestSuite class EVEN IF it is derived from + TestSuite class EVEN IF it is derived from testtools.ConcurrentTestSuite.
""" # Check to verify if we are NOT in concurrent mode, if so, then - # just call the run (this will be unittest.TestSuite's run()) + # just call the run (this will be TestSuite's run()) if not test_engine_globals.TESTENGINE_CONCURRENTMODE: - return super(unittest.TestSuite, self).run(result) + return super(TestSuite, self).run(result) # type: ignore [misc, no-any-return] # Perform the Concurrent Run tests = self.make_tests(self) thread_limit = test_engine_globals.TESTENGINE_THREADLIMIT test_index = -1 try: - threads = {} - testqueue = Queue() + threads: Dict[TestCase, Tuple[threading.Thread, Any]] = {} + testqueue: Queue[TestCase] = Queue() semaphore = threading.Semaphore(1) test_iter = iter(tests) - test = "startup_placeholder" + test: Optional[SSTTestCase] = "startup_placeholder" # type: ignore [assignment] tests_finished = False while not tests_finished: while len(threads) < thread_limit and test is not None: - #log_forced("DEBUG: CALLING FOR NEXT TEST; threads = {0}".format(len(threads))) test = next(test_iter, None) if result.shouldStop: tests_finished = True test_index += 1 - #log_forced("DEBUG: TEST = {0}; index = {1}".format(test, test_index)) if test is not None: - process_result = self._wrap_result(testtools.\ - ThreadsafeForwardingResult(result, semaphore), test_index) - reader_thread = threading.\ - Thread(target=self._run_test, args=(test, process_result, testqueue)) + process_result = self._wrap_result(testtools.ThreadsafeForwardingResult(result, semaphore), test_index) + reader_thread = threading.Thread(target=self._run_test, args=(test, process_result, testqueue)) threads[test] = reader_thread, process_result reader_thread.start() - #log_forced("DEBUG: ADDED TEST = {0}; threads = {1}".\ - #format(test, len(threads))) if threads: - #log_forced("DEBUG: IN THREADS PROESSING") finished_test = testqueue.get() - #log_forced("DEBUG: FINISHED TEST = {0}".format(finished_test)) threads[finished_test][0].join() del threads[finished_test] - #log_forced("DEBUG: FINISHED TEST NUM THREADS = {0}".format(len(threads))) - #log_forced("DEBUG: FINISHED TEST THREADS keys = {0}".format(threads.keys())) else: tests_finished = True test_engine_globals.TESTRUN_TESTRUNNINGFLAG = False @@ -640,9 +660,16 @@ def run(self, result): process_result.stop() raise + return None + ### - def _run_test(self, test, process_result, testqueue): + def _run_test( + self, + test: SSTTestCase, + process_result: "testtools.testresult.real.ThreadsafeForwardingResult", + testqueue: "Queue[TestCase]", + ) -> None: """Support running a single test concurrently NOTE: This is a slightly modified version of the @@ -658,8 +685,8 @@ def _run_test(self, test, process_result, testqueue): tearDownModuleConcurrent(test) except Exception: # The run logic itself failed. 
- case = testtools.ErrorHolder("broken-runner", error=sys.exc_info()) - case.run(process_result) + testcase = testtools.ErrorHolder("broken-runner", error=sys.exc_info()) + testcase.run(process_result) finally: testqueue.put(test) @@ -670,58 +697,58 @@ class SSTTestSuiteResultData: Results are stored as lists of test names """ def __init__(self) -> None: - self._tests_passing: List[SSTTestCase] = [] - self._tests_failing: List[SSTTestCase] = [] - self._tests_errored: List[SSTTestCase] = [] - self._tests_skipped: List[SSTTestCase] = [] - self._tests_expectedfailed: List[SSTTestCase] = [] - self._tests_unexpectedsuccess: List[SSTTestCase] = [] - - def add_success(self, test: SSTTestCase) -> None: + self._tests_passing: List[TestCase] = [] + self._tests_failing: List[TestCase] = [] + self._tests_errored: List[TestCase] = [] + self._tests_skipped: List[TestCase] = [] + self._tests_expectedfailed: List[TestCase] = [] + self._tests_unexpectedsuccess: List[TestCase] = [] + + def add_success(self, test: TestCase) -> None: """ Add a test to the success record""" self._tests_passing.append(test) - def add_failure(self, test: SSTTestCase) -> None: + def add_failure(self, test: TestCase) -> None: """ Add a test to the failure record""" self._tests_failing.append(test) - def add_error(self, test: SSTTestCase) -> None: + def add_error(self, test: TestCase) -> None: """ Add a test to the error record""" self._tests_errored.append(test) - def add_skip(self, test: SSTTestCase) -> None: + def add_skip(self, test: TestCase) -> None: """ Add a test to the skip record""" self._tests_skipped.append(test) - def add_expected_failure(self, test: SSTTestCase) -> None: + def add_expected_failure(self, test: TestCase) -> None: """ Add a test to the expected failure record""" self._tests_expectedfailed.append(test) - def add_unexpected_success(self, test: SSTTestCase) -> None: + def add_unexpected_success(self, test: TestCase) -> None: """ Add a test to the unexpected success record""" self._tests_unexpectedsuccess.append(test) - def get_passing(self) -> List[SSTTestCase]: + def get_passing(self) -> List[TestCase]: """ Return the tests passing list""" return self._tests_passing - def get_failed(self) -> List[SSTTestCase]: + def get_failed(self) -> List[TestCase]: """ Return the tests failed list""" return self._tests_failing - def get_errored(self) -> List[SSTTestCase]: + def get_errored(self) -> List[TestCase]: """ Return the tests errored list""" return self._tests_errored - def get_skipped(self) -> List[SSTTestCase]: + def get_skipped(self) -> List[TestCase]: """ Return the tests skipped list""" return self._tests_skipped - def get_expectedfailed(self) -> List[SSTTestCase]: + def get_expectedfailed(self) -> List[TestCase]: """ Return the expected failed list""" return self._tests_expectedfailed - def get_unexpectedsuccess(self) -> List[SSTTestCase]: + def get_unexpectedsuccess(self) -> List[TestCase]: """ Return the tests unexpected success list""" return self._tests_unexpectedsuccess @@ -733,27 +760,27 @@ class SSTTestSuitesResultsDict: def __init__(self) -> None: self.testsuitesresultsdict: Dict[str, SSTTestSuiteResultData] = {} - def add_success(self, test: SSTTestCase) -> None: + def add_success(self, test: TestCase) -> None: """ Add a testsuite and test to the success record""" self._get_testresult_from_testmodulecase(test).add_success(test) - def add_failure(self, test: SSTTestCase) -> None: + def add_failure(self, test: TestCase) -> None: """ Add a testsuite and test to the failure record""" 
self._get_testresult_from_testmodulecase(test).add_failure(test) - def add_error(self, test: SSTTestCase) -> None: + def add_error(self, test: TestCase) -> None: """ Add a testsuite and test to the error record""" self._get_testresult_from_testmodulecase(test).add_error(test) - def add_skip(self, test: SSTTestCase) -> None: + def add_skip(self, test: TestCase) -> None: """ Add a testsuite and test to the skip record""" self._get_testresult_from_testmodulecase(test).add_skip(test) - def add_expected_failure(self, test: SSTTestCase) -> None: + def add_expected_failure(self, test: TestCase) -> None: """ Add a testsuite and test to the expected failure record""" self._get_testresult_from_testmodulecase(test).add_expected_failure(test) - def add_unexpected_success(self, test: SSTTestCase) -> None: + def add_unexpected_success(self, test: TestCase) -> None: """ Add a testsuite and test to the unexpected success record""" self._get_testresult_from_testmodulecase(test).add_unexpected_success(test) @@ -796,18 +823,25 @@ def log_fail_error_skip_unexpeced_results(self) -> None: for testname in self.testsuitesresultsdict[tmtc_name].get_unexpectedsuccess(): log(" - UNEXPECTED SUCCESS : {0}".format(testname)) - def _get_testresult_from_testmodulecase(self, test: SSTTestCase) -> SSTTestSuiteResultData: + def _get_testresult_from_testmodulecase(self, test: TestCase) -> SSTTestSuiteResultData: tm_tc = self._get_test_module_test_case_name(test) if tm_tc not in self.testsuitesresultsdict.keys(): self.testsuitesresultsdict[tm_tc] = SSTTestSuiteResultData() return self.testsuitesresultsdict[tm_tc] - def _get_test_module_test_case_name(self, test: SSTTestCase) -> str: + def _get_test_module_test_case_name(self, test: TestCase) -> str: return "{0}.{1}".format(self._get_test_module_name(test), self._get_test_case_name(test)) - def _get_test_case_name(self, test: SSTTestCase) -> str: + def _get_test_case_name(self, test: TestCase) -> str: return strqual(test.__class__) - def _get_test_module_name(self, test: SSTTestCase) -> str: + def _get_test_module_name(self, test: TestCase) -> str: return strclass(test.__class__) + + +ErrorType = Union[ + Tuple[Type[BaseException], BaseException, "TracebackType"], + Tuple[None, None, None] +] +ErrorsType = Iterable[Tuple[TestCase, str]] diff --git a/tests/test_PortModule.py b/tests/test_PortModule.py index 15a5d855b..fd9be3e12 100644 --- a/tests/test_PortModule.py +++ b/tests/test_PortModule.py @@ -13,8 +13,9 @@ import sst import argparse +import sys -def main(): +def main() -> None: parser = argparse.ArgumentParser(description="Run PortModule test") parser.add_argument('--send', action='store_true', help="Install PortModule on send") parser.add_argument('--recv', action='store_true', help="Install PortModule on receive")
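Notes on the typing patterns used above (reviewer sketches, not part of the patch):

The `sys.version_info.minor >= 11` / `>= 12` gates keep working code paths on old interpreters, but mypy only performs reachability analysis on full-tuple comparisons such as `sys.version_info >= (3, 11)`, which is why the gated calls need `# type: ignore [attr-defined]` and `# type: ignore [call-arg]` when checking with `python_version = "3.6"`. A minimal sketch of the tuple form, mirroring the two call sites in the diff:

```python
import datetime
import sys
import tarfile


def get_current_time() -> datetime.datetime:
    if sys.version_info >= (3, 11):
        # datetime.UTC is new in 3.11; mypy marks this branch unreachable
        # when checking as Python 3.6, so no ignore comment is required.
        return datetime.datetime.now(datetime.UTC)
    return datetime.datetime.utcnow()


def extract_tar(tarfilepath: str, targetdir: str = ".") -> None:
    with tarfile.open(tarfilepath) as this_tar:
        if sys.version_info >= (3, 12):
            # Tar extraction filters (PEP 706) are new in Python 3.12.
            this_tar.extractall(targetdir, filter="data")
        else:
            this_tar.extractall(targetdir)
```

The tuple comparison also avoids misfiring on a hypothetical major-version bump, since `.minor >= 12` alone says nothing about the major component.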
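Similarly, the `# type: ignore [misc, valid-type]` on `class SSTTestSuite(TestSuiteBaseClass)` is needed because the base class is chosen at runtime and mypy refuses to subclass a variable. A hedged alternative, assuming only stdlib `TestSuite` members need to type-check (not what the patch does), is to pin the base under `TYPE_CHECKING`:

```python
import unittest
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # mypy only sees this branch; an unannotated assignment of a class is
    # an implicit type alias, so it is accepted as a base class.
    TestSuiteBaseClass = unittest.TestSuite
else:
    try:
        # At runtime, prefer testtools so suites can execute concurrently.
        from testtools import ConcurrentTestSuite as TestSuiteBaseClass
    except ImportError:
        TestSuiteBaseClass = unittest.TestSuite


class SSTTestSuite(TestSuiteBaseClass):
    """Concurrent when testtools is installed, serial otherwise."""
```

The trade-off is that testtools-only members such as `make_tests` and `_wrap_result` would then need their own ignores, which is presumably why the patch keeps the single class-level ignore instead.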
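Finally, the two spellings of the queue annotation in `run()` (`testqueue: Queue[TestCase] = Queue()`) and `_run_test()` (`testqueue: "Queue[TestCase]"`) are both deliberate: `queue.Queue` only supports subscripting from Python 3.9, parameter annotations are evaluated when the `def` statement executes, and local-variable annotations are never evaluated at runtime. A small sketch of the distinction, with hypothetical names:

```python
import queue


def consume(q: "queue.Queue[int]") -> None:
    # Quoted: parameter annotations are evaluated at definition time, and
    # queue.Queue[...] raises TypeError on interpreters older than 3.9.
    print(q.get())


def produce() -> None:
    # Unquoted is safe here: local-variable annotations exist only for
    # the type checker and are never evaluated at runtime.
    q: queue.Queue[int] = queue.Queue()
    q.put(1)
    consume(q)
```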