diff --git a/.pre-commit-hooks.yaml b/.pre-commit-hooks.yaml index 9cfc074..3084b92 100644 --- a/.pre-commit-hooks.yaml +++ b/.pre-commit-hooks.yaml @@ -18,3 +18,14 @@ language: python types: [shell] args: [--fix] +- id: verify-rapids-metadata + name: RAPIDS metadata + description: make sure RAPIDS metadata has template file or isn't hard-coded + entry: verify-rapids-metadata + language: python + types: [text] + exclude: | + (?x) + ^VERSION$| + [.]rapids_metadata_template$ + args: [--fix] diff --git a/pyproject.toml b/pyproject.toml index 487bf24..3a5964d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ classifiers = [ requires-python = ">=3.9" dependencies = [ "bashlex", + "packaging", ] [project.optional-dependencies] @@ -42,6 +43,7 @@ dev = [ [project.scripts] verify-conda-yes = "rapids_pre_commit_hooks.shell.verify_conda_yes:main" +verify-rapids-metadata = "rapids_pre_commit_hooks.rapids_metadata:main" [tool.setuptools] packages = { "find" = { where = ["src"] } } diff --git a/src/rapids_pre_commit_hooks/lint.py b/src/rapids_pre_commit_hooks/lint.py index c688a79..4ef1459 100644 --- a/src/rapids_pre_commit_hooks/lint.py +++ b/src/rapids_pre_commit_hooks/lint.py @@ -64,7 +64,7 @@ def __repr__(self): return ( "LintWarning(" + f"pos={self.pos}, " - + f"msg={self.msg}, " + + f"msg={repr(self.msg)}, " + f"replacements={self.replacements})" ) diff --git a/src/rapids_pre_commit_hooks/rapids_metadata.py b/src/rapids_pre_commit_hooks/rapids_metadata.py new file mode 100644 index 0000000..3591f76 --- /dev/null +++ b/src/rapids_pre_commit_hooks/rapids_metadata.py @@ -0,0 +1,76 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +from packaging.version import Version + +from .lint import LintMain + + +def check_rapids_version(linter, rapids_version_re): + for match in rapids_version_re.finditer(linter.content): + linter.add_warning( + match.span(), + "do not hard-code RAPIDS version; dynamically read from VERSION file or " + f'write a "{linter.filename}.rapids_metadata_template" file', + ) + + +def check_rapids_metadata(): + with open("VERSION") as f: + version = Version(f.read()) + + rapids_version_re = re.compile( + rf"{version.major:02}\.{version.minor:02}(\.{version.micro:02})?" + ) + + def the_check(linter, args): + try: + with open(f"{linter.filename}.rapids_metadata_template") as f: + template_content = f.read() + except FileNotFoundError: + template_content = None + + if template_content is None: + check_rapids_version(linter, rapids_version_re) + else: + template_replacement = template_content.format( + RAPIDS_VERSION_MAJOR=f"{version.major:02}", + RAPIDS_VERSION_MINOR=f"{version.minor:02}", + RAPIDS_VERSION_PATCH=f"{version.micro:02}", + RAPIDS_VERSION=( + f"{version.major:02}.{version.minor:02}.{version.micro:02}" + ), + RAPIDS_VERSION_MAJOR_MINOR=f"{version.major:02}.{version.minor:02}", + ) + + if linter.content != template_replacement: + linter.add_warning( + (0, len(linter.content)), + f'file does not match template replacement from "{linter.filename}' + '.rapids_metadata_template"', + ).add_replacement((0, len(linter.content)), template_replacement) + + return the_check + + +def main(): + m = LintMain() + with m.execute() as ctx: + ctx.add_check(check_rapids_metadata()) + + +if __name__ == "__main__": + main() diff --git a/test/rapids_pre_commit_hooks/test_rapids_metadata.py b/test/rapids_pre_commit_hooks/test_rapids_metadata.py new file mode 100644 index 0000000..212612b --- /dev/null +++ b/test/rapids_pre_commit_hooks/test_rapids_metadata.py @@ -0,0 +1,140 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +from unittest.mock import patch + +import pytest + +from rapids_pre_commit_hooks.lint import Linter +from rapids_pre_commit_hooks.rapids_metadata import check_rapids_metadata + + +class TestRAPIDSMetadata: + def mock_files(self, files): + def new_open(filename, mode="r"): + assert mode == "r" + + try: + content = files[filename] + except KeyError: + raise FileNotFoundError + + return io.StringIO(content) + + return patch("builtins.open", new_open) + + def test_mock_files(self): + with self.mock_files({"file.txt": "Hello"}): + with open("file.txt") as f: + assert f.read() == "Hello" + with pytest.raises(FileNotFoundError): + open("nonexistent.txt") + with pytest.raises(AssertionError): + open("file.txt", "rb") + with pytest.raises(AssertionError): + open("file.txt", "w") + + def test_template_file(self): + FILES = { + "VERSION": "24.04.00\n", + "file.txt.rapids_metadata_template": """This file contains RAPIDS metadata +Full version is {RAPIDS_VERSION} +Major-minor version is {RAPIDS_VERSION_MAJOR_MINOR} +Major version is {RAPIDS_VERSION_MAJOR} +Minor version is {RAPIDS_VERSION_MINOR} +Patch version is {RAPIDS_VERSION_PATCH} +Hard-coded version is 24.04.00 +Old hard-coded version is 24.02.00 +Brace literal is {{RAPIDS_VERSION}} +""", + } + + CORRECT_CONTENT = CONTENT = """This file contains RAPIDS metadata +Full version is 24.04.00 +Major-minor version is 24.04 +Major version is 24 +Minor version is 04 +Patch version is 00 +Hard-coded version is 24.04.00 +Old hard-coded version is 24.02.00 +Brace literal is {RAPIDS_VERSION} +""" + + linter = Linter("file.txt", CONTENT) + with self.mock_files(FILES): + checker = check_rapids_metadata() + checker(linter, None) + assert linter.warnings == [] + + CONTENT = """This file contains RAPIDS metadata +Full version is 24.02.00 +Major-minor version is 24.02 +Major version is 24 +Minor version is 02 +Patch version is 00 +Hard-coded version is 24.04.00 +Old hard-coded version is 24.02.00 +Brace literal is {RAPIDS_VERSION} +""" + expected_linter = Linter("file.txt", CONTENT) + expected_linter.add_warning( + (0, len(CONTENT)), + "file does not match template replacement from " + '"file.txt.rapids_metadata_template"', + ).add_replacement((0, len(CONTENT)), CORRECT_CONTENT) + + linter = Linter("file.txt", CONTENT) + with self.mock_files(FILES): + checker = check_rapids_metadata() + checker(linter, None) + assert linter.warnings == expected_linter.warnings + + def test_no_template_file(self): + FILES = { + "VERSION": "24.04.00\n", + } + + CONTENT = """This file contains RAPIDS metadata +Full version is 24.04.00 +Major-minor version is 24.04 +Major version is 24 +Minor version is 04 +Patch version is 00 +Hard-coded version is 24.04.00 +Old hard-coded version is 24.02.00 +Brace literal is {RAPIDS_VERSION} +""" + expected_linter = Linter("file.txt", CONTENT) + expected_linter.add_warning( + (51, 59), + "do not hard-code RAPIDS version; dynamically read from VERSION file or " + 'write a "file.txt.rapids_metadata_template" file', + ) + expected_linter.add_warning( + (83, 88), + "do not hard-code RAPIDS version; dynamically read from VERSION file or " + 'write a "file.txt.rapids_metadata_template" file', + ) + expected_linter.add_warning( + (171, 179), + "do not hard-code RAPIDS version; dynamically read from VERSION file or " + 'write a "file.txt.rapids_metadata_template" file', + ) + + linter = Linter("file.txt", CONTENT) + with self.mock_files(FILES): + checker = check_rapids_metadata() + checker(linter, None) + assert linter.warnings == expected_linter.warnings