diff --git a/src/rapids_pre_commit_hooks/copyright.py b/src/rapids_pre_commit_hooks/copyright.py index 000ba14..232afe5 100644 --- a/src/rapids_pre_commit_hooks/copyright.py +++ b/src/rapids_pre_commit_hooks/copyright.py @@ -110,7 +110,7 @@ def apply_copyright_check(linter, old_content): linter.add_warning((0, 0), "no copyright notice found") -def get_target_branch(repo, target_branch_arg=None): +def get_target_branch(repo, args): """Determine which branch is the "target" branch. The target branch is determined in the following order: @@ -123,16 +123,18 @@ def get_target_branch(repo, target_branch_arg=None): This allows GitHub Actions to easily use this tool. * If the ``$RAPIDS_BASE_BRANCH`` environment variable is defined, that branch is used. This allows GitHub Actions inside ``copy-pr-bot`` to easily use this tool. - * If the configuration option ``rapidsai.baseBranch`` is defined, that branch is + * If the Git configuration option ``rapidsai.baseBranch`` is defined, that branch is used. This allows users to locally set a base branch on a long-term basis. + * If the ``--main-branch`` argument is passed, that branch is used. This allows + projects to use a branching strategy other than ``branch-.``. * If a ``branch-.`` branch exists, that branch is used. If more than one such branch exists, the one with the latest version is used. This supports the expected default. * Otherwise, None is returned and a warning is issued. """ - # Try command line - if target_branch_arg: - return target_branch_arg + # Try --target-branch + if args.target_branch: + return args.target_branch # Try environment if target_branch_name := os.getenv("TARGET_BRANCH"): @@ -148,6 +150,10 @@ def get_target_branch(repo, target_branch_arg=None): if target_branch_name: return target_branch_name + # Try --main-branch + if args.main_branch: + return args.main_branch + # Try newest branch-xx.yy try: return max( @@ -170,8 +176,8 @@ def get_target_branch(repo, target_branch_arg=None): return None -def get_target_branch_upstream_commit(repo, target_branch_arg=None): - target_branch_name = get_target_branch(repo, target_branch_arg) +def get_target_branch_upstream_commit(repo, args): + target_branch_name = get_target_branch(repo, args) if target_branch_name is None: try: return repo.head.commit @@ -209,7 +215,7 @@ def try_get_ref(remote): return None -def get_changed_files(target_branch_arg): +def get_changed_files(args): try: repo = git.Repo() except git.InvalidGitRepositoryError: @@ -220,9 +226,7 @@ def get_changed_files(target_branch_arg): } changed_files = {f: None for f in repo.untracked_files} - target_branch_upstream_commit = get_target_branch_upstream_commit( - repo, target_branch_arg - ) + target_branch_upstream_commit = get_target_branch_upstream_commit(repo, args) if target_branch_upstream_commit is None: changed_files.update({blob.path: None for _, blob in repo.index.iter_blobs()}) return changed_files @@ -271,7 +275,7 @@ def find_blob(tree, filename): def check_copyright(args): - changed_files = get_changed_files(args.target_branch) + changed_files = get_changed_files(args) def the_check(linter, args): if not (git_filename := normalize_git_filename(linter.filename)): @@ -303,7 +307,16 @@ def main(): "Verify that all files have had their copyright notices updated. Each file " "will be compared against the target branch (determined automatically or with " "the --target-branch argument) to decide whether or not they need a copyright " - "update." + "update.\n\n" + "--main-branch and --target-branch effectively control the same thing, but " + "--target-branch has higher precedence and is meant only for a user-local " + "override, while --main-branch is a project-wide setting. Both --main-branch " + "and --target-branch may be specified." + ) + m.argparser.add_argument( + "--main-branch", + metavar="
", + help="main branch to use instead of branch-.", ) m.argparser.add_argument( "--target-branch", diff --git a/test/rapids_pre_commit_hooks/test_copyright.py b/test/rapids_pre_commit_hooks/test_copyright.py index 9d2fdd3..166d169 100644 --- a/test/rapids_pre_commit_hooks/test_copyright.py +++ b/test/rapids_pre_commit_hooks/test_copyright.py @@ -173,6 +173,8 @@ def git_repo(): def test_get_target_branch(git_repo): with patch.dict("os.environ", {}, clear=True): + args = Mock(main_branch=None, target_branch=None) + with open(os.path.join(git_repo.working_tree_dir, "file.txt"), "w") as f: f.write("File\n") git_repo.index.add(["file.txt"]) @@ -183,43 +185,49 @@ def test_get_target_branch(git_repo): r"TARGET_BRANCH environment variable, or setting the rapidsai.baseBranch " r"configuration option[.]$", ): - assert copyright.get_target_branch(git_repo) is None + assert copyright.get_target_branch(git_repo, args) is None git_repo.create_head("branch-24.02") - assert copyright.get_target_branch(git_repo) == "branch-24.02" + assert copyright.get_target_branch(git_repo, args) == "branch-24.02" + + args.main_branch = "" + args.target_branch = "" git_repo.create_head("branch-24.04") git_repo.create_head("branch-24.03") - assert copyright.get_target_branch(git_repo) == "branch-24.04" + assert copyright.get_target_branch(git_repo, args) == "branch-24.04" git_repo.create_head("branch-25.01") - assert copyright.get_target_branch(git_repo) == "branch-25.01" + assert copyright.get_target_branch(git_repo, args) == "branch-25.01" + + args.main_branch = "main" + assert copyright.get_target_branch(git_repo, args) == "main" with git_repo.config_writer() as w: w.set_value("rapidsai", "baseBranch", "nonexistent") - assert copyright.get_target_branch(git_repo) == "nonexistent" + assert copyright.get_target_branch(git_repo, args) == "nonexistent" with git_repo.config_writer() as w: w.set_value("rapidsai", "baseBranch", "branch-24.03") - assert copyright.get_target_branch(git_repo) == "branch-24.03" + assert copyright.get_target_branch(git_repo, args) == "branch-24.03" with patch.dict("os.environ", {"RAPIDS_BASE_BRANCH": "nonexistent"}): - assert copyright.get_target_branch(git_repo) == "nonexistent" + assert copyright.get_target_branch(git_repo, args) == "nonexistent" with patch.dict("os.environ", {"RAPIDS_BASE_BRANCH": "master"}): - assert copyright.get_target_branch(git_repo) == "master" + assert copyright.get_target_branch(git_repo, args) == "master" with patch.dict( "os.environ", {"GITHUB_BASE_REF": "nonexistent", "RAPIDS_BASE_BRANCH": "master"}, ): - assert copyright.get_target_branch(git_repo) == "nonexistent" + assert copyright.get_target_branch(git_repo, args) == "nonexistent" with patch.dict( "os.environ", {"GITHUB_BASE_REF": "branch-24.02", "RAPIDS_BASE_BRANCH": "master"}, ): - assert copyright.get_target_branch(git_repo) == "branch-24.02" + assert copyright.get_target_branch(git_repo, args) == "branch-24.02" with patch.dict( "os.environ", @@ -229,7 +237,7 @@ def test_get_target_branch(git_repo): "TARGET_BRANCH": "nonexistent", }, ): - assert copyright.get_target_branch(git_repo) == "nonexistent" + assert copyright.get_target_branch(git_repo, args) == "nonexistent" with patch.dict( "os.environ", @@ -239,9 +247,11 @@ def test_get_target_branch(git_repo): "TARGET_BRANCH": "branch-24.04", }, ): - assert copyright.get_target_branch(git_repo) == "branch-24.04" - assert copyright.get_target_branch(git_repo, "nonexistent") == "nonexistent" - assert copyright.get_target_branch(git_repo, "master") == "master" + assert copyright.get_target_branch(git_repo, args) == "branch-24.04" + args.target_branch = "nonexistent" + assert copyright.get_target_branch(git_repo, args) == "nonexistent" + args.target_branch = "master" + assert copyright.get_target_branch(git_repo, args) == "master" def test_get_target_branch_upstream_commit(git_repo): @@ -365,10 +375,10 @@ def mock_target_branch(branch): remote_repo_2.index.commit("Update file5.txt") with mock_target_branch(None): - assert copyright.get_target_branch_upstream_commit(git_repo) is None + assert copyright.get_target_branch_upstream_commit(git_repo, None) is None with mock_target_branch("branch-1"): - assert copyright.get_target_branch_upstream_commit(git_repo) is None + assert copyright.get_target_branch_upstream_commit(git_repo, None) is None remote_1 = git_repo.create_remote("unconventional/remote/name/1", remote_dir_1) remote_1.fetch([ @@ -403,44 +413,51 @@ def mock_target_branch(branch): with mock_target_branch("branch-1"): assert ( - copyright.get_target_branch_upstream_commit(git_repo) + copyright.get_target_branch_upstream_commit(git_repo, None) == remote_1.refs["branch-1-renamed"].commit ) with mock_target_branch("branch-2"): assert ( - copyright.get_target_branch_upstream_commit(git_repo) + copyright.get_target_branch_upstream_commit(git_repo, None) == remote_1.refs["branch-2"].commit ) with mock_target_branch("branch-3"): assert ( - copyright.get_target_branch_upstream_commit(git_repo) + copyright.get_target_branch_upstream_commit(git_repo, None) == remote_1.refs["branch-3"].commit ) with mock_target_branch("branch-4"): assert ( - copyright.get_target_branch_upstream_commit(git_repo) + copyright.get_target_branch_upstream_commit(git_repo, None) == remote_2.refs["branch-4"].commit ) with mock_target_branch("branch-5"): assert ( - copyright.get_target_branch_upstream_commit(git_repo) + copyright.get_target_branch_upstream_commit(git_repo, None) == remote_2.refs["branch-5"].commit ) with mock_target_branch("branch-6"): assert ( - copyright.get_target_branch_upstream_commit(git_repo) == branch_6.commit + copyright.get_target_branch_upstream_commit(git_repo, None) + == branch_6.commit ) with mock_target_branch("branch-7"): - assert copyright.get_target_branch_upstream_commit(git_repo) == main.commit + assert ( + copyright.get_target_branch_upstream_commit(git_repo, None) + == main.commit + ) with mock_target_branch(None): - assert copyright.get_target_branch_upstream_commit(git_repo) == main.commit + assert ( + copyright.get_target_branch_upstream_commit(git_repo, None) + == main.commit + ) def test_get_changed_files(git_repo): @@ -470,7 +487,7 @@ def mock_os_walk(top): os.mkdir(os.path.join(non_git_dir, "subdir1/subdir2")) with open(os.path.join(non_git_dir, "subdir1", "subdir2", "sub.txt"), "w") as f: f.write("Subdir file\n") - assert copyright.get_changed_files(Mock(target_branch=None)) == { + assert copyright.get_changed_files(None) == { "top.txt": None, "subdir1/subdir2/sub.txt": None, } @@ -515,7 +532,7 @@ def file_contents(verbed): "rapids_pre_commit_hooks.copyright.get_target_branch_upstream_commit", Mock(return_value=None), ): - assert copyright.get_changed_files(Mock(target_branch=None)) == { + assert copyright.get_changed_files(None) == { "untouched.txt": None, "copied.txt": None, "modified_and_copied.txt": None, @@ -609,7 +626,7 @@ def file_contents(verbed): "rapids_pre_commit_hooks.copyright.get_target_branch_upstream_commit", Mock(return_value=target_branch.commit), ): - changed_files = copyright.get_changed_files(Mock(target_branch=None)) + changed_files = copyright.get_changed_files(None) assert { path: old_blob.path if old_blob else None for path, old_blob in changed_files.items() @@ -730,8 +747,8 @@ def mock_repo_cwd(): return patch("os.getcwd", Mock(return_value=git_repo.working_tree_dir)) def mock_target_branch_upstream_commit(target_branch): - def func(repo, target_branch_arg): - assert target_branch == target_branch_arg + def func(repo, args): + assert target_branch == args.target_branch return repo.heads[target_branch].commit return patch(