Skip to content

Commit

Permalink
Add --main-branch argument to verify-copyright (#20)
Browse files Browse the repository at this point in the history
* Add --main-branch argument to verify-copyright

* Review feedback

* Clarify both --main-branch and --target-branch may be specified

* Fix typo.

---------

Co-authored-by: Bradley Dice <[email protected]>
  • Loading branch information
KyleFromNVIDIA and bdice authored Mar 13, 2024
1 parent 4d14f4c commit 6bc928c
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 42 deletions.
39 changes: 26 additions & 13 deletions src/rapids_pre_commit_hooks/copyright.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def apply_copyright_check(linter, old_content):
linter.add_warning((0, 0), "no copyright notice found")


def get_target_branch(repo, target_branch_arg=None):
def get_target_branch(repo, args):
"""Determine which branch is the "target" branch.
The target branch is determined in the following order:
Expand All @@ -123,16 +123,18 @@ def get_target_branch(repo, target_branch_arg=None):
This allows GitHub Actions to easily use this tool.
* If the ``$RAPIDS_BASE_BRANCH`` environment variable is defined, that branch is
used. This allows GitHub Actions inside ``copy-pr-bot`` to easily use this tool.
* If the configuration option ``rapidsai.baseBranch`` is defined, that branch is
* If the Git configuration option ``rapidsai.baseBranch`` is defined, that branch is
used. This allows users to locally set a base branch on a long-term basis.
* If the ``--main-branch`` argument is passed, that branch is used. This allows
projects to use a branching strategy other than ``branch-<major>.<minor>``.
* If a ``branch-<major>.<minor>`` branch exists, that branch is used. If more than
one such branch exists, the one with the latest version is used. This supports the
expected default.
* Otherwise, None is returned and a warning is issued.
"""
# Try command line
if target_branch_arg:
return target_branch_arg
# Try --target-branch
if args.target_branch:
return args.target_branch

# Try environment
if target_branch_name := os.getenv("TARGET_BRANCH"):
Expand All @@ -148,6 +150,10 @@ def get_target_branch(repo, target_branch_arg=None):
if target_branch_name:
return target_branch_name

# Try --main-branch
if args.main_branch:
return args.main_branch

# Try newest branch-xx.yy
try:
return max(
Expand All @@ -170,8 +176,8 @@ def get_target_branch(repo, target_branch_arg=None):
return None


def get_target_branch_upstream_commit(repo, target_branch_arg=None):
target_branch_name = get_target_branch(repo, target_branch_arg)
def get_target_branch_upstream_commit(repo, args):
target_branch_name = get_target_branch(repo, args)
if target_branch_name is None:
try:
return repo.head.commit
Expand Down Expand Up @@ -209,7 +215,7 @@ def try_get_ref(remote):
return None


def get_changed_files(target_branch_arg):
def get_changed_files(args):
try:
repo = git.Repo()
except git.InvalidGitRepositoryError:
Expand All @@ -220,9 +226,7 @@ def get_changed_files(target_branch_arg):
}

changed_files = {f: None for f in repo.untracked_files}
target_branch_upstream_commit = get_target_branch_upstream_commit(
repo, target_branch_arg
)
target_branch_upstream_commit = get_target_branch_upstream_commit(repo, args)
if target_branch_upstream_commit is None:
changed_files.update({blob.path: None for _, blob in repo.index.iter_blobs()})
return changed_files
Expand Down Expand Up @@ -271,7 +275,7 @@ def find_blob(tree, filename):


def check_copyright(args):
changed_files = get_changed_files(args.target_branch)
changed_files = get_changed_files(args)

def the_check(linter, args):
if not (git_filename := normalize_git_filename(linter.filename)):
Expand Down Expand Up @@ -303,7 +307,16 @@ def main():
"Verify that all files have had their copyright notices updated. Each file "
"will be compared against the target branch (determined automatically or with "
"the --target-branch argument) to decide whether or not they need a copyright "
"update."
"update.\n\n"
"--main-branch and --target-branch effectively control the same thing, but "
"--target-branch has higher precedence and is meant only for a user-local "
"override, while --main-branch is a project-wide setting. Both --main-branch "
"and --target-branch may be specified."
)
m.argparser.add_argument(
"--main-branch",
metavar="<main branch>",
help="main branch to use instead of branch-<major>.<minor>",
)
m.argparser.add_argument(
"--target-branch",
Expand Down
75 changes: 46 additions & 29 deletions test/rapids_pre_commit_hooks/test_copyright.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ def git_repo():

def test_get_target_branch(git_repo):
with patch.dict("os.environ", {}, clear=True):
args = Mock(main_branch=None, target_branch=None)

with open(os.path.join(git_repo.working_tree_dir, "file.txt"), "w") as f:
f.write("File\n")
git_repo.index.add(["file.txt"])
Expand All @@ -183,43 +185,49 @@ def test_get_target_branch(git_repo):
r"TARGET_BRANCH environment variable, or setting the rapidsai.baseBranch "
r"configuration option[.]$",
):
assert copyright.get_target_branch(git_repo) is None
assert copyright.get_target_branch(git_repo, args) is None

git_repo.create_head("branch-24.02")
assert copyright.get_target_branch(git_repo) == "branch-24.02"
assert copyright.get_target_branch(git_repo, args) == "branch-24.02"

args.main_branch = ""
args.target_branch = ""

git_repo.create_head("branch-24.04")
git_repo.create_head("branch-24.03")
assert copyright.get_target_branch(git_repo) == "branch-24.04"
assert copyright.get_target_branch(git_repo, args) == "branch-24.04"

git_repo.create_head("branch-25.01")
assert copyright.get_target_branch(git_repo) == "branch-25.01"
assert copyright.get_target_branch(git_repo, args) == "branch-25.01"

args.main_branch = "main"
assert copyright.get_target_branch(git_repo, args) == "main"

with git_repo.config_writer() as w:
w.set_value("rapidsai", "baseBranch", "nonexistent")
assert copyright.get_target_branch(git_repo) == "nonexistent"
assert copyright.get_target_branch(git_repo, args) == "nonexistent"

with git_repo.config_writer() as w:
w.set_value("rapidsai", "baseBranch", "branch-24.03")
assert copyright.get_target_branch(git_repo) == "branch-24.03"
assert copyright.get_target_branch(git_repo, args) == "branch-24.03"

with patch.dict("os.environ", {"RAPIDS_BASE_BRANCH": "nonexistent"}):
assert copyright.get_target_branch(git_repo) == "nonexistent"
assert copyright.get_target_branch(git_repo, args) == "nonexistent"

with patch.dict("os.environ", {"RAPIDS_BASE_BRANCH": "master"}):
assert copyright.get_target_branch(git_repo) == "master"
assert copyright.get_target_branch(git_repo, args) == "master"

with patch.dict(
"os.environ",
{"GITHUB_BASE_REF": "nonexistent", "RAPIDS_BASE_BRANCH": "master"},
):
assert copyright.get_target_branch(git_repo) == "nonexistent"
assert copyright.get_target_branch(git_repo, args) == "nonexistent"

with patch.dict(
"os.environ",
{"GITHUB_BASE_REF": "branch-24.02", "RAPIDS_BASE_BRANCH": "master"},
):
assert copyright.get_target_branch(git_repo) == "branch-24.02"
assert copyright.get_target_branch(git_repo, args) == "branch-24.02"

with patch.dict(
"os.environ",
Expand All @@ -229,7 +237,7 @@ def test_get_target_branch(git_repo):
"TARGET_BRANCH": "nonexistent",
},
):
assert copyright.get_target_branch(git_repo) == "nonexistent"
assert copyright.get_target_branch(git_repo, args) == "nonexistent"

with patch.dict(
"os.environ",
Expand All @@ -239,9 +247,11 @@ def test_get_target_branch(git_repo):
"TARGET_BRANCH": "branch-24.04",
},
):
assert copyright.get_target_branch(git_repo) == "branch-24.04"
assert copyright.get_target_branch(git_repo, "nonexistent") == "nonexistent"
assert copyright.get_target_branch(git_repo, "master") == "master"
assert copyright.get_target_branch(git_repo, args) == "branch-24.04"
args.target_branch = "nonexistent"
assert copyright.get_target_branch(git_repo, args) == "nonexistent"
args.target_branch = "master"
assert copyright.get_target_branch(git_repo, args) == "master"


def test_get_target_branch_upstream_commit(git_repo):
Expand Down Expand Up @@ -365,10 +375,10 @@ def mock_target_branch(branch):
remote_repo_2.index.commit("Update file5.txt")

with mock_target_branch(None):
assert copyright.get_target_branch_upstream_commit(git_repo) is None
assert copyright.get_target_branch_upstream_commit(git_repo, None) is None

with mock_target_branch("branch-1"):
assert copyright.get_target_branch_upstream_commit(git_repo) is None
assert copyright.get_target_branch_upstream_commit(git_repo, None) is None

remote_1 = git_repo.create_remote("unconventional/remote/name/1", remote_dir_1)
remote_1.fetch([
Expand Down Expand Up @@ -403,44 +413,51 @@ def mock_target_branch(branch):

with mock_target_branch("branch-1"):
assert (
copyright.get_target_branch_upstream_commit(git_repo)
copyright.get_target_branch_upstream_commit(git_repo, None)
== remote_1.refs["branch-1-renamed"].commit
)

with mock_target_branch("branch-2"):
assert (
copyright.get_target_branch_upstream_commit(git_repo)
copyright.get_target_branch_upstream_commit(git_repo, None)
== remote_1.refs["branch-2"].commit
)

with mock_target_branch("branch-3"):
assert (
copyright.get_target_branch_upstream_commit(git_repo)
copyright.get_target_branch_upstream_commit(git_repo, None)
== remote_1.refs["branch-3"].commit
)

with mock_target_branch("branch-4"):
assert (
copyright.get_target_branch_upstream_commit(git_repo)
copyright.get_target_branch_upstream_commit(git_repo, None)
== remote_2.refs["branch-4"].commit
)

with mock_target_branch("branch-5"):
assert (
copyright.get_target_branch_upstream_commit(git_repo)
copyright.get_target_branch_upstream_commit(git_repo, None)
== remote_2.refs["branch-5"].commit
)

with mock_target_branch("branch-6"):
assert (
copyright.get_target_branch_upstream_commit(git_repo) == branch_6.commit
copyright.get_target_branch_upstream_commit(git_repo, None)
== branch_6.commit
)

with mock_target_branch("branch-7"):
assert copyright.get_target_branch_upstream_commit(git_repo) == main.commit
assert (
copyright.get_target_branch_upstream_commit(git_repo, None)
== main.commit
)

with mock_target_branch(None):
assert copyright.get_target_branch_upstream_commit(git_repo) == main.commit
assert (
copyright.get_target_branch_upstream_commit(git_repo, None)
== main.commit
)


def test_get_changed_files(git_repo):
Expand Down Expand Up @@ -470,7 +487,7 @@ def mock_os_walk(top):
os.mkdir(os.path.join(non_git_dir, "subdir1/subdir2"))
with open(os.path.join(non_git_dir, "subdir1", "subdir2", "sub.txt"), "w") as f:
f.write("Subdir file\n")
assert copyright.get_changed_files(Mock(target_branch=None)) == {
assert copyright.get_changed_files(None) == {
"top.txt": None,
"subdir1/subdir2/sub.txt": None,
}
Expand Down Expand Up @@ -515,7 +532,7 @@ def file_contents(verbed):
"rapids_pre_commit_hooks.copyright.get_target_branch_upstream_commit",
Mock(return_value=None),
):
assert copyright.get_changed_files(Mock(target_branch=None)) == {
assert copyright.get_changed_files(None) == {
"untouched.txt": None,
"copied.txt": None,
"modified_and_copied.txt": None,
Expand Down Expand Up @@ -609,7 +626,7 @@ def file_contents(verbed):
"rapids_pre_commit_hooks.copyright.get_target_branch_upstream_commit",
Mock(return_value=target_branch.commit),
):
changed_files = copyright.get_changed_files(Mock(target_branch=None))
changed_files = copyright.get_changed_files(None)
assert {
path: old_blob.path if old_blob else None
for path, old_blob in changed_files.items()
Expand Down Expand Up @@ -730,8 +747,8 @@ def mock_repo_cwd():
return patch("os.getcwd", Mock(return_value=git_repo.working_tree_dir))

def mock_target_branch_upstream_commit(target_branch):
def func(repo, target_branch_arg):
assert target_branch == target_branch_arg
def func(repo, args):
assert target_branch == args.target_branch
return repo.heads[target_branch].commit

return patch(
Expand Down

0 comments on commit 6bc928c

Please sign in to comment.