diff --git a/scripts/abi_check.py b/scripts/abi_check.py index c2288432ce14..ac1d60ffd008 100755 --- a/scripts/abi_check.py +++ b/scripts/abi_check.py @@ -113,6 +113,8 @@ import xml.etree.ElementTree as ET +from mbedtls_dev import build_tree + class AbiChecker: """API and ABI checker.""" @@ -150,11 +152,6 @@ def __init__(self, old_version, new_version, configuration): self.git_command = "git" self.make_command = "make" - @staticmethod - def check_repo_path(): - if not all(os.path.isdir(d) for d in ["include", "library", "tests"]): - raise Exception("Must be run from Mbed TLS root") - def _setup_logger(self): self.log = logging.getLogger() if self.verbose: @@ -540,7 +537,7 @@ def get_abi_compatibility_report(self): def check_for_abi_changes(self): """Generate a report of ABI differences between self.old_rev and self.new_rev.""" - self.check_repo_path() + build_tree.check_repo_path() if self.check_api or self.check_abi: self.check_abi_tools_are_installed() self._get_abi_dump_for_ref(self.old_version) diff --git a/scripts/mbedtls_dev/build_tree.py b/scripts/mbedtls_dev/build_tree.py index 3920d0ed6c02..f52b785d95a0 100644 --- a/scripts/mbedtls_dev/build_tree.py +++ b/scripts/mbedtls_dev/build_tree.py @@ -25,6 +25,13 @@ def looks_like_mbedtls_root(path: str) -> bool: return all(os.path.isdir(os.path.join(path, subdir)) for subdir in ['include', 'library', 'programs', 'tests']) +def check_repo_path(): + """ + Check that the current working directory is the project root, and throw + an exception if not. + """ + if not all(os.path.isdir(d) for d in ["include", "library", "tests"]): + raise Exception("This script must be run from Mbed TLS root") def chdir_to_root() -> None: """Detect the root of the Mbed TLS source tree and change to it. diff --git a/tests/scripts/check_files.py b/tests/scripts/check_files.py index a0f5e1f53845..5c18702defb5 100755 --- a/tests/scripts/check_files.py +++ b/tests/scripts/check_files.py @@ -34,6 +34,9 @@ except ImportError: pass +import scripts_path # pylint: disable=unused-import +from mbedtls_dev import build_tree + class FileIssueTracker: """Base class for file-wide issue tracking. @@ -338,7 +341,7 @@ def __init__(self, log_file): """Instantiate the sanity checker. Check files under the current directory. Write a report of issues to log_file.""" - self.check_repo_path() + build_tree.check_repo_path() self.logger = None self.setup_logger(log_file) self.issues_to_check = [ @@ -353,11 +356,6 @@ def __init__(self, log_file): MergeArtifactIssueTracker(), ] - @staticmethod - def check_repo_path(): - if not all(os.path.isdir(d) for d in ["include", "library", "tests"]): - raise Exception("Must be run from Mbed TLS root") - def setup_logger(self, log_file, level=logging.INFO): self.logger = logging.getLogger() self.logger.setLevel(level) diff --git a/tests/scripts/check_names.py b/tests/scripts/check_names.py index 875d0b0f5bbb..d1e87b5c52b7 100755 --- a/tests/scripts/check_names.py +++ b/tests/scripts/check_names.py @@ -56,6 +56,10 @@ import subprocess import logging +import scripts_path # pylint: disable=unused-import +from mbedtls_dev import build_tree + + # Naming patterns to check against. These are defined outside the NameCheck # class for ease of modification. MACRO_PATTERN = r"^(MBEDTLS|PSA)_[0-9A-Z_]*[0-9A-Z]$" @@ -218,7 +222,7 @@ class CodeParser(): """ def __init__(self, log): self.log = log - self.check_repo_path() + build_tree.check_repo_path() # Memo for storing "glob expression": set(filepaths) self.files = {} @@ -227,15 +231,6 @@ def __init__(self, log): # Note that "*" can match directory separators in exclude lists. self.excluded_files = ["*/bn_mul", "*/compat-1.3.h"] - @staticmethod - def check_repo_path(): - """ - Check that the current working directory is the project root, and throw - an exception if not. - """ - if not all(os.path.isdir(d) for d in ["include", "library", "tests"]): - raise Exception("This script must be run from Mbed TLS root") - def comprehensive_parse(self): """ Comprehensive ("default") function to call each parsing function and