From 82f41009c7093fbc49ee9409e8638d94f2474287 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Thu, 18 Jan 2024 14:19:40 -0500 Subject: [PATCH] Add copyright check Fixes: https://github.com/rapidsai/pre-commit-hooks/issues/2 --- ci/build-test.sh | 4 +- pyproject.toml | 5 +- src/rapids_pre_commit_hooks/copyright.py | 139 ++++++ .../rapids_pre_commit_hooks/test_copyright.py | 431 ++++++++++++++++++ 4 files changed, 576 insertions(+), 3 deletions(-) create mode 100644 src/rapids_pre_commit_hooks/copyright.py create mode 100644 test/rapids_pre_commit_hooks/test_copyright.py diff --git a/ci/build-test.sh b/ci/build-test.sh index ce157d1..a784659 100755 --- a/ci/build-test.sh +++ b/ci/build-test.sh @@ -3,13 +3,13 @@ set -ue -pip install build pytest +pip install build python -m build . for PKG in dist/*; do echo "$PKG" pip uninstall -y rapids-pre-commit-hooks - pip install "$PKG" + pip install "$PKG[test]" pytest done diff --git a/pyproject.toml b/pyproject.toml index 487bf24..7ac84b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,15 +33,18 @@ classifiers = [ requires-python = ">=3.9" dependencies = [ "bashlex", + "gitpython", ] [project.optional-dependencies] -dev = [ +test = [ + "freezegun", "pytest", ] [project.scripts] verify-conda-yes = "rapids_pre_commit_hooks.shell.verify_conda_yes:main" +verify-copyright = "rapids_pre_commit_hooks.copyright:main" [tool.setuptools] packages = { "find" = { where = ["src"] } } diff --git a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py new file mode 100644 index 0000000..29749a1 --- /dev/null +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -0,0 +1,139 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import functools +import re + +import git + +from .lint import LintMain + +COPYRIGHT_RE = re.compile( + r"Copyright *(?:\(c\))? *(?P(?P\d{4})(-(?P\d{4}))?),?" + r" *NVIDIA C(?:ORPORATION|orporation)" +) + + +class ConflictingFilesError(RuntimeError): + pass + + +def match_copyright(content): + return list(COPYRIGHT_RE.finditer(content)) + + +def strip_copyright(content, copyright_matches): + lines = [] + + def append_stripped(start, item): + lines.append(content[start : item.start()]) + return item.end() + + start = functools.reduce(append_stripped, copyright_matches, 0) + lines.append(content[start:]) + return lines + + +def apply_copyright_check(linter, old_content): + if linter.content != old_content: + current_year = datetime.datetime.now().year + new_copyright_matches = match_copyright(linter.content) + + if old_content is not None: + old_copyright_matches = match_copyright(old_content) + + if old_content is not None and strip_copyright( + old_content, old_copyright_matches + ) == strip_copyright(linter.content, new_copyright_matches): + for old_match, new_match in zip( + old_copyright_matches, new_copyright_matches + ): + if old_match.group() != new_match.group(): + if old_match.group("years") == new_match.group("years"): + warning_pos = new_match.span() + else: + warning_pos = new_match.span("years") + linter.add_warning( + warning_pos, + "copyright is not out of date and should not be updated", + ).add_replacement(new_match.span(), old_match.group()) + else: + if new_copyright_matches: + for match in new_copyright_matches: + if ( + int(match.group("last_year") or match.group("first_year")) + < current_year + ): + linter.add_warning( + match.span("years"), "copyright is out of date" + ).add_replacement( + match.span(), + f"Copyright (c) {match.group('first_year')}-{current_year}" + ", NVIDIA CORPORATION", + ) + else: + linter.add_warning((0, 0), "no copyright notice found") + + +def get_target_branch(repo): + # TODO + raise NotImplementedError + + +def get_changed_files(repo, target_branch): + changed_files = {} + + diffs = target_branch.commit.diff( + other=None, + merge_base=True, + find_copies=True, + find_copies_harder=True, + find_renames=True, + ) + for diff in diffs: + if diff.change_type == "A": + changed_files[diff.b_path] = None + elif diff.change_type != "D": + changed_files[diff.b_path] = diff.a_blob + + changed_files.update({f: None for f in repo.untracked_files}) + return changed_files + + +def check_copyright(): + repo = git.Repo() + target_branch = get_target_branch(repo) + changed_files = get_changed_files(repo, target_branch) + + def the_check(linter, args): + try: + changed_file = changed_files[linter.filename] + except KeyError: + return + + old_content = changed_file.data_stream.read().decode("utf-8") + apply_copyright_check(linter, old_content) + + return the_check + + +def main(): + m = LintMain() + with m.execute() as ctx: + ctx.add_check(check_copyright()) + + +if __name__ == "__main__": + main() diff --git a/test/rapids_pre_commit_hooks/test_copyright.py b/test/rapids_pre_commit_hooks/test_copyright.py new file mode 100644 index 0000000..1214701 --- /dev/null +++ b/test/rapids_pre_commit_hooks/test_copyright.py @@ -0,0 +1,431 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os.path +import tempfile +from unittest.mock import Mock, patch + +import git +import pytest +from freezegun import freeze_time + +from rapids_pre_commit_hooks import copyright +from rapids_pre_commit_hooks.lint import Linter + + +def test_match_copyright(): + CONTENT = r""" +Copyright (c) 2024 NVIDIA CORPORATION +Copyright (c) 2021-2024 NVIDIA CORPORATION +# Copyright 2021, NVIDIA Corporation and affiliates +""" + + re_matches = copyright.match_copyright(CONTENT) + matches = [ + { + "span": match.span(), + "years": match.span("years"), + "first_year": match.span("first_year"), + "last_year": match.span("last_year"), + } + for match in re_matches + ] + assert matches == [ + { + "span": (1, 38), + "years": (15, 19), + "first_year": (15, 19), + "last_year": (-1, -1), + }, + { + "span": (39, 81), + "years": (53, 62), + "first_year": (53, 57), + "last_year": (58, 62), + }, + { + "span": (84, 119), + "years": (94, 98), + "first_year": (94, 98), + "last_year": (-1, -1), + }, + ] + + +def test_strip_copyright(): + CONTENT = r""" +This is a line before the first copyright statement +Copyright (c) 2024 NVIDIA CORPORATION +This is a line between the first two copyright statements +Copyright (c) 2021-2024 NVIDIA CORPORATION +This is a line between the next two copyright statements +# Copyright 2021, NVIDIA Corporation and affiliates +This is a line after the last copyright statement +""" + matches = copyright.match_copyright(CONTENT) + stripped = copyright.strip_copyright(CONTENT, matches) + assert stripped == [ + "\nThis is a line before the first copyright statement\n", + "\nThis is a line between the first two copyright statements\n", + "\nThis is a line between the next two copyright statements\n# ", + " and affiliates\nThis is a line after the last copyright statement\n", + ] + + stripped = copyright.strip_copyright("No copyright here", []) + assert stripped == ["No copyright here"] + + +@freeze_time("2024-01-18") +def test_apply_copyright_check(): + def run_apply_copyright_check(old_content, new_content): + linter = Linter("file.txt", new_content) + copyright.apply_copyright_check(linter, old_content) + return linter + + expected_linter = Linter("file.txt", "No copyright notice") + expected_linter.add_warning((0, 0), "no copyright notice found") + + linter = run_apply_copyright_check(None, "No copyright notice") + assert linter.warnings == expected_linter.warnings + + linter = run_apply_copyright_check("No copyright notice", "No copyright notice") + assert linter.warnings == [] + + OLD_CONTENT = r""" +Copyright (c) 2021-2023 NVIDIA CORPORATION +Copyright (c) 2023 NVIDIA CORPORATION +Copyright (c) 2024 NVIDIA CORPORATION +Copyright (c) 2025 NVIDIA CORPORATION +This file has not been changed +""" + linter = run_apply_copyright_check(OLD_CONTENT, OLD_CONTENT) + assert linter.warnings == [] + + NEW_CONTENT = r""" +Copyright (c) 2021-2023 NVIDIA CORPORATION +Copyright (c) 2023 NVIDIA CORPORATION +Copyright (c) 2024 NVIDIA CORPORATION +Copyright (c) 2025 NVIDIA CORPORATION +This file has been changed +""" + expected_linter = Linter("file.txt", NEW_CONTENT) + expected_linter.add_warning((15, 24), "copyright is out of date").add_replacement( + (1, 43), "Copyright (c) 2021-2024, NVIDIA CORPORATION" + ) + expected_linter.add_warning((58, 62), "copyright is out of date").add_replacement( + (44, 81), "Copyright (c) 2023-2024, NVIDIA CORPORATION" + ) + + linter = run_apply_copyright_check(OLD_CONTENT, NEW_CONTENT) + assert linter.warnings == expected_linter.warnings + + expected_linter = Linter("file.txt", NEW_CONTENT) + expected_linter.add_warning((15, 24), "copyright is out of date").add_replacement( + (1, 43), "Copyright (c) 2021-2024, NVIDIA CORPORATION" + ) + expected_linter.add_warning((58, 62), "copyright is out of date").add_replacement( + (44, 81), "Copyright (c) 2023-2024, NVIDIA CORPORATION" + ) + + linter = run_apply_copyright_check(None, NEW_CONTENT) + assert linter.warnings == expected_linter.warnings + + NEW_CONTENT = r""" +Copyright (c) 2021-2024 NVIDIA CORPORATION +Copyright (c) 2023 NVIDIA CORPORATION +Copyright (c) 2024 NVIDIA CORPORATION +Copyright (c) 2025 NVIDIA Corporation +This file has not been changed +""" + expected_linter = Linter("file.txt", NEW_CONTENT) + expected_linter.add_warning( + (15, 24), "copyright is not out of date and should not be updated" + ).add_replacement((1, 43), "Copyright (c) 2021-2023 NVIDIA CORPORATION") + expected_linter.add_warning( + (120, 157), "copyright is not out of date and should not be updated" + ).add_replacement((120, 157), "Copyright (c) 2025 NVIDIA CORPORATION") + + linter = run_apply_copyright_check(OLD_CONTENT, NEW_CONTENT) + assert linter.warnings == expected_linter.warnings + + +@pytest.fixture +def git_repo(): + with tempfile.TemporaryDirectory() as d: + repo = git.Repo.init(d) + config_writer = repo.config_writer() + config_writer.set_value("user", "name", "RAPIDS Test Fixtures") + config_writer.set_value("user", "email", "testfixtures@rapids.ai") + yield repo + + +@pytest.mark.skip +def test_get_target_branch(git_repo): + # TODO + pass + + +def test_get_changed_files(git_repo): + def fn(filename): + return os.path.join(git_repo.working_tree_dir, filename) + + def write_file(filename, contents): + with open(fn(filename), "w") as f: + f.write(contents) + + def file_contents(verbed): + return f"This file will be {verbed}\n" * 100 + + write_file("untouched.txt", file_contents("untouched")) + write_file("copied.txt", file_contents("copied")) + write_file("modified_and_copied.txt", file_contents("modified and copied")) + write_file("copied_and_modified.txt", file_contents("copied and modified")) + write_file("deleted.txt", file_contents("deleted")) + write_file("renamed.txt", file_contents("renamed")) + write_file("modified_and_renamed.txt", file_contents("modified and renamed")) + write_file("modified.txt", file_contents("modified")) + write_file("chmodded.txt", file_contents("chmodded")) + git_repo.index.add( + [ + "untouched.txt", + "copied.txt", + "modified_and_copied.txt", + "copied_and_modified.txt", + "deleted.txt", + "renamed.txt", + "modified_and_renamed.txt", + "modified.txt", + "chmodded.txt", + ] + ) + git_repo.index.commit("Initial commit") + + # Ensure that diff is done against merge base, not branch tip + git_repo.index.remove(["modified.txt"], working_tree=True) + git_repo.index.commit("Remove modified.txt") + + pr_branch = git_repo.create_head("pr", "HEAD~") + git_repo.head.reference = pr_branch + git_repo.head.reset(index=True, working_tree=True) + + write_file("copied_2.txt", file_contents("copied")) + git_repo.index.remove( + ["deleted.txt", "modified_and_renamed.txt"], working_tree=True + ) + git_repo.index.move(["renamed.txt", "renamed_2.txt"]) + write_file( + "modified.txt", file_contents("modified") + "This file has been modified\n" + ) + os.chmod(fn("chmodded.txt"), 0o755) + write_file("untouched.txt", file_contents("untouched") + "Oops\n") + write_file("added.txt", file_contents("added")) + write_file("added_and_deleted.txt", file_contents("added and deleted")) + write_file( + "modified_and_copied.txt", + file_contents("modified and copied") + "This file has been modified\n", + ) + write_file("modified_and_copied_2.txt", file_contents("modified and copied")) + write_file( + "copied_and_modified_2.txt", + file_contents("copied and modified") + "This file has been modified\n", + ) + write_file( + "modified_and_renamed_2.txt", + file_contents("modified and renamed") + "This file has been modified\n", + ) + git_repo.index.add( + [ + "untouched.txt", + "added.txt", + "added_and_deleted.txt", + "modified_and_copied.txt", + "modified_and_copied_2.txt", + "copied_and_modified_2.txt", + "copied_2.txt", + "modified_and_renamed_2.txt", + "modified.txt", + "chmodded.txt", + ] + ) + write_file("untracked.txt", file_contents("untracked")) + write_file("untouched.txt", file_contents("untouched")) + os.unlink(fn("added_and_deleted.txt")) + + target_branch = next(ref for ref in git_repo.refs if ref.name == "master") + merge_base = git_repo.merge_base(target_branch, "HEAD")[0] + old_files = { + blob.path: blob + for blob in merge_base.tree.traverse(lambda b, _: isinstance(b, git.Blob)) + } + + # Truly need to be checked + changed = { + "added.txt": None, + "untracked.txt": None, + "modified_and_renamed_2.txt": "modified_and_renamed.txt", + "modified.txt": "modified.txt", + "copied_and_modified_2.txt": "copied_and_modified.txt", + "modified_and_copied.txt": "modified_and_copied.txt", + } + + # Superfluous, but harmless because the content is identical + superfluous = { + "chmodded.txt": "chmodded.txt", + "modified_and_copied_2.txt": "modified_and_copied.txt", + "copied_2.txt": "copied.txt", + "renamed_2.txt": "renamed.txt", + } + + changed_files = copyright.get_changed_files(git_repo, target_branch) + assert { + path: old_blob.path if old_blob else None + for path, old_blob in changed_files.items() + } == changed | superfluous + + for new, old in changed.items(): + if old: + with open(fn(new), "rb") as f: + new_contents = f.read() + old_contents = old_files[old].data_stream.read() + assert new_contents != old_contents + assert changed_files[new].data_stream.read() == old_contents + + for new, old in superfluous.items(): + if old: + with open(fn(new), "rb") as f: + new_contents = f.read() + old_contents = old_files[old].data_stream.read() + assert new_contents == old_contents + assert changed_files[new].data_stream.read() == old_contents + + +@freeze_time("2024-01-18") +def test_check_copyright(git_repo): + def fn(filename): + return os.path.join(git_repo.working_tree_dir, filename) + + def write_file(filename, contents): + with open(fn(filename), "w") as f: + f.write(contents) + + def file_contents(num): + return rf""" +Copyright (c) 2021-2023 NVIDIA CORPORATION +File {num} +""" + + def file_contents_modified(num): + return rf""" +Copyright (c) 2021-2023 NVIDIA CORPORATION +File {num} modified +""" + + write_file("file1.txt", file_contents(1)) + write_file("file2.txt", file_contents(2)) + write_file("file3.txt", file_contents(3)) + write_file("file4.txt", file_contents(4)) + git_repo.index.add(["file1.txt", "file2.txt", "file3.txt", "file4.txt"]) + git_repo.index.commit("Initial commit") + + branch_1 = git_repo.create_head("branch-1", "master") + git_repo.head.reference = branch_1 + git_repo.head.reset(index=True, working_tree=True) + write_file("file1.txt", file_contents_modified(1)) + git_repo.index.add(["file1.txt"]) + git_repo.index.commit("Update file1.txt") + + branch_2 = git_repo.create_head("branch-2", "master") + git_repo.head.reference = branch_2 + git_repo.head.reset(index=True, working_tree=True) + write_file("file2.txt", file_contents_modified(2)) + git_repo.index.add(["file2.txt"]) + git_repo.index.commit("Update file2.txt") + + pr = git_repo.create_head("pr", "branch-1") + git_repo.head.reference = pr + git_repo.head.reset(index=True, working_tree=True) + write_file("file3.txt", file_contents_modified(3)) + git_repo.index.add(["file3.txt"]) + git_repo.index.commit("Update file3.txt") + write_file("file4.txt", file_contents_modified(4)) + git_repo.index.add(["file4.txt"]) + git_repo.index.commit("Update file4.txt") + git_repo.index.move(["file2.txt", "file5.txt"]) + git_repo.index.commit("Rename file2.txt to file5.txt") + + def mock_repo_cwd(): + return patch("os.getcwd", Mock(return_value=git_repo.working_tree_dir)) + + def mock_target_branch(branch_name): + def func(repo): + return next(ref for ref in repo.refs if ref.name == branch_name) + + return patch("rapids_pre_commit_hooks.copyright.get_target_branch", func) + + def mock_apply_copyright_check(): + return patch("rapids_pre_commit_hooks.copyright.apply_copyright_check", Mock()) + + ############################# + # branch-1 is target branch + ############################# + + with mock_repo_cwd(), mock_target_branch("branch-1"): + copyright_checker = copyright.check_copyright() + + linter = Linter("file1.txt", file_contents_modified(1)) + with mock_apply_copyright_check() as apply_copyright_check: + copyright_checker(linter, None) + apply_copyright_check.assert_not_called() + + linter = Linter("file5.txt", file_contents(2)) + with mock_apply_copyright_check() as apply_copyright_check: + copyright_checker(linter, None) + apply_copyright_check.assert_called_once_with(linter, file_contents(2)) + + linter = Linter("file3.txt", file_contents_modified(3)) + with mock_apply_copyright_check() as apply_copyright_check: + copyright_checker(linter, None) + apply_copyright_check.assert_called_once_with(linter, file_contents(3)) + + linter = Linter("file4.txt", file_contents_modified(4)) + with mock_apply_copyright_check() as apply_copyright_check: + copyright_checker(linter, None) + apply_copyright_check.assert_called_once_with(linter, file_contents(4)) + + ############################# + # branch-2 is target branch + ############################# + + with mock_repo_cwd(), mock_target_branch("branch-2"): + copyright_checker = copyright.check_copyright() + + linter = Linter("file1.txt", file_contents_modified(1)) + with mock_apply_copyright_check() as apply_copyright_check: + copyright_checker(linter, None) + apply_copyright_check.assert_called_once_with(linter, file_contents(1)) + + linter = Linter("file5.txt", file_contents(2)) + with mock_apply_copyright_check() as apply_copyright_check: + copyright_checker(linter, None) + apply_copyright_check.assert_called_once_with(linter, file_contents(2)) + + linter = Linter("file3.txt", file_contents_modified(3)) + with mock_apply_copyright_check() as apply_copyright_check: + copyright_checker(linter, None) + apply_copyright_check.assert_called_once_with(linter, file_contents(3)) + + linter = Linter("file4.txt", file_contents_modified(4)) + with mock_apply_copyright_check() as apply_copyright_check: + copyright_checker(linter, None) + apply_copyright_check.assert_called_once_with(linter, file_contents(4))