Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add --main-branch argument to verify-copyright #20

Merged
merged 4 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 26 additions & 13 deletions src/rapids_pre_commit_hooks/copyright.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def apply_copyright_check(linter, old_content):
linter.add_warning((0, 0), "no copyright notice found")


def get_target_branch(repo, target_branch_arg=None):
def get_target_branch(repo, args):
"""Determine which branch is the "target" branch.

The target branch is determined in the following order:
Expand All @@ -123,16 +123,18 @@ def get_target_branch(repo, target_branch_arg=None):
This allows GitHub Actions to easily use this tool.
* If the ``$RAPIDS_BASE_BRANCH`` environment variable is defined, that branch is
used. This allows GitHub Actions inside ``copy-pr-bot`` to easily use this tool.
* If the configuration option ``rapidsai.baseBranch`` is defined, that branch is
* If the Git configuration option ``rapidsai.baseBranch`` is defined, that branch is
used. This allows users to locally set a base branch on a long-term basis.
* If the ``--main-branch`` argument is passed, that branch is used. This allows
projects to use a branching strategy other than ``branch-<major>.<minor>``.
* If a ``branch-<major>.<minor>`` branch exists, that branch is used. If more than
one such branch exists, the one with the latest version is used. This supports the
expected default.
* Otherwise, None is returned and a warning is issued.
"""
# Try command line
if target_branch_arg:
return target_branch_arg
# Try --target-branch
if args.target_branch:
return args.target_branch

# Try environment
if target_branch_name := os.getenv("TARGET_BRANCH"):
Expand All @@ -148,6 +150,10 @@ def get_target_branch(repo, target_branch_arg=None):
if target_branch_name:
return target_branch_name

# Try --main-branch
if args.main_branch:
return args.main_branch

# Try newest branch-xx.yy
try:
return max(
Expand All @@ -170,8 +176,8 @@ def get_target_branch(repo, target_branch_arg=None):
return None


def get_target_branch_upstream_commit(repo, target_branch_arg=None):
target_branch_name = get_target_branch(repo, target_branch_arg)
def get_target_branch_upstream_commit(repo, args):
target_branch_name = get_target_branch(repo, args)
if target_branch_name is None:
try:
return repo.head.commit
Expand Down Expand Up @@ -209,7 +215,7 @@ def try_get_ref(remote):
return None


def get_changed_files(target_branch_arg):
def get_changed_files(args):
try:
repo = git.Repo()
except git.InvalidGitRepositoryError:
Expand All @@ -220,9 +226,7 @@ def get_changed_files(target_branch_arg):
}

changed_files = {f: None for f in repo.untracked_files}
target_branch_upstream_commit = get_target_branch_upstream_commit(
repo, target_branch_arg
)
target_branch_upstream_commit = get_target_branch_upstream_commit(repo, args)
if target_branch_upstream_commit is None:
changed_files.update({blob.path: None for _, blob in repo.index.iter_blobs()})
return changed_files
Expand Down Expand Up @@ -271,7 +275,7 @@ def find_blob(tree, filename):


def check_copyright(args):
changed_files = get_changed_files(args.target_branch)
changed_files = get_changed_files(args)

def the_check(linter, args):
if not (git_filename := normalize_git_filename(linter.filename)):
Expand Down Expand Up @@ -303,7 +307,16 @@ def main():
"Verify that all files have had their copyright notices updated. Each file "
"will be compared against the target branch (determined automatically or with "
"the --target-branch argument) to decide whether or not they need a copyright "
"update."
"update.\n\n"
"--main-branch and --target-branch effectively control the same thing, but "
"--target-branch has higher precedence and is meant only for a user-local "
"override, while --main-branch is a project-wide setting. Both --main-branch "
"and --target-branch may be specified."
)
m.argparser.add_argument(
"--main-branch",
metavar="<main branch>",
help="main branch to use instead of branch-<major>.<minor>",
KyleFromNVIDIA marked this conversation as resolved.
Show resolved Hide resolved
)
m.argparser.add_argument(
"--target-branch",
Expand Down
75 changes: 46 additions & 29 deletions test/rapids_pre_commit_hooks/test_copyright.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ def git_repo():

def test_get_target_branch(git_repo):
with patch.dict("os.environ", {}, clear=True):
args = Mock(main_branch=None, target_branch=None)

with open(os.path.join(git_repo.working_tree_dir, "file.txt"), "w") as f:
f.write("File\n")
git_repo.index.add(["file.txt"])
Expand All @@ -183,43 +185,49 @@ def test_get_target_branch(git_repo):
r"TARGET_BRANCH environment variable, or setting the rapidsai.baseBranch "
r"configuration option[.]$",
):
assert copyright.get_target_branch(git_repo) is None
assert copyright.get_target_branch(git_repo, args) is None

git_repo.create_head("branch-24.02")
assert copyright.get_target_branch(git_repo) == "branch-24.02"
assert copyright.get_target_branch(git_repo, args) == "branch-24.02"

args.main_branch = ""
args.target_branch = ""

git_repo.create_head("branch-24.04")
git_repo.create_head("branch-24.03")
assert copyright.get_target_branch(git_repo) == "branch-24.04"
assert copyright.get_target_branch(git_repo, args) == "branch-24.04"

git_repo.create_head("branch-25.01")
assert copyright.get_target_branch(git_repo) == "branch-25.01"
assert copyright.get_target_branch(git_repo, args) == "branch-25.01"

args.main_branch = "main"
assert copyright.get_target_branch(git_repo, args) == "main"

with git_repo.config_writer() as w:
w.set_value("rapidsai", "baseBranch", "nonexistent")
assert copyright.get_target_branch(git_repo) == "nonexistent"
assert copyright.get_target_branch(git_repo, args) == "nonexistent"

with git_repo.config_writer() as w:
w.set_value("rapidsai", "baseBranch", "branch-24.03")
assert copyright.get_target_branch(git_repo) == "branch-24.03"
assert copyright.get_target_branch(git_repo, args) == "branch-24.03"

with patch.dict("os.environ", {"RAPIDS_BASE_BRANCH": "nonexistent"}):
assert copyright.get_target_branch(git_repo) == "nonexistent"
assert copyright.get_target_branch(git_repo, args) == "nonexistent"

with patch.dict("os.environ", {"RAPIDS_BASE_BRANCH": "master"}):
assert copyright.get_target_branch(git_repo) == "master"
assert copyright.get_target_branch(git_repo, args) == "master"

with patch.dict(
"os.environ",
{"GITHUB_BASE_REF": "nonexistent", "RAPIDS_BASE_BRANCH": "master"},
):
assert copyright.get_target_branch(git_repo) == "nonexistent"
assert copyright.get_target_branch(git_repo, args) == "nonexistent"

with patch.dict(
"os.environ",
{"GITHUB_BASE_REF": "branch-24.02", "RAPIDS_BASE_BRANCH": "master"},
):
assert copyright.get_target_branch(git_repo) == "branch-24.02"
assert copyright.get_target_branch(git_repo, args) == "branch-24.02"

with patch.dict(
"os.environ",
Expand All @@ -229,7 +237,7 @@ def test_get_target_branch(git_repo):
"TARGET_BRANCH": "nonexistent",
},
):
assert copyright.get_target_branch(git_repo) == "nonexistent"
assert copyright.get_target_branch(git_repo, args) == "nonexistent"

with patch.dict(
"os.environ",
Expand All @@ -239,9 +247,11 @@ def test_get_target_branch(git_repo):
"TARGET_BRANCH": "branch-24.04",
},
):
assert copyright.get_target_branch(git_repo) == "branch-24.04"
assert copyright.get_target_branch(git_repo, "nonexistent") == "nonexistent"
assert copyright.get_target_branch(git_repo, "master") == "master"
assert copyright.get_target_branch(git_repo, args) == "branch-24.04"
args.target_branch = "nonexistent"
assert copyright.get_target_branch(git_repo, args) == "nonexistent"
args.target_branch = "master"
assert copyright.get_target_branch(git_repo, args) == "master"


def test_get_target_branch_upstream_commit(git_repo):
Expand Down Expand Up @@ -365,10 +375,10 @@ def mock_target_branch(branch):
remote_repo_2.index.commit("Update file5.txt")

with mock_target_branch(None):
assert copyright.get_target_branch_upstream_commit(git_repo) is None
assert copyright.get_target_branch_upstream_commit(git_repo, None) is None

with mock_target_branch("branch-1"):
assert copyright.get_target_branch_upstream_commit(git_repo) is None
assert copyright.get_target_branch_upstream_commit(git_repo, None) is None

remote_1 = git_repo.create_remote("unconventional/remote/name/1", remote_dir_1)
remote_1.fetch([
Expand Down Expand Up @@ -403,44 +413,51 @@ def mock_target_branch(branch):

with mock_target_branch("branch-1"):
assert (
copyright.get_target_branch_upstream_commit(git_repo)
copyright.get_target_branch_upstream_commit(git_repo, None)
== remote_1.refs["branch-1-renamed"].commit
)

with mock_target_branch("branch-2"):
assert (
copyright.get_target_branch_upstream_commit(git_repo)
copyright.get_target_branch_upstream_commit(git_repo, None)
== remote_1.refs["branch-2"].commit
)

with mock_target_branch("branch-3"):
assert (
copyright.get_target_branch_upstream_commit(git_repo)
copyright.get_target_branch_upstream_commit(git_repo, None)
== remote_1.refs["branch-3"].commit
)

with mock_target_branch("branch-4"):
assert (
copyright.get_target_branch_upstream_commit(git_repo)
copyright.get_target_branch_upstream_commit(git_repo, None)
== remote_2.refs["branch-4"].commit
)

with mock_target_branch("branch-5"):
assert (
copyright.get_target_branch_upstream_commit(git_repo)
copyright.get_target_branch_upstream_commit(git_repo, None)
== remote_2.refs["branch-5"].commit
)

with mock_target_branch("branch-6"):
assert (
copyright.get_target_branch_upstream_commit(git_repo) == branch_6.commit
copyright.get_target_branch_upstream_commit(git_repo, None)
== branch_6.commit
)

with mock_target_branch("branch-7"):
assert copyright.get_target_branch_upstream_commit(git_repo) == main.commit
assert (
copyright.get_target_branch_upstream_commit(git_repo, None)
== main.commit
)

with mock_target_branch(None):
assert copyright.get_target_branch_upstream_commit(git_repo) == main.commit
assert (
copyright.get_target_branch_upstream_commit(git_repo, None)
== main.commit
)


def test_get_changed_files(git_repo):
Expand Down Expand Up @@ -470,7 +487,7 @@ def mock_os_walk(top):
os.mkdir(os.path.join(non_git_dir, "subdir1/subdir2"))
with open(os.path.join(non_git_dir, "subdir1", "subdir2", "sub.txt"), "w") as f:
f.write("Subdir file\n")
assert copyright.get_changed_files(Mock(target_branch=None)) == {
assert copyright.get_changed_files(None) == {
"top.txt": None,
"subdir1/subdir2/sub.txt": None,
}
Expand Down Expand Up @@ -515,7 +532,7 @@ def file_contents(verbed):
"rapids_pre_commit_hooks.copyright.get_target_branch_upstream_commit",
Mock(return_value=None),
):
assert copyright.get_changed_files(Mock(target_branch=None)) == {
assert copyright.get_changed_files(None) == {
"untouched.txt": None,
"copied.txt": None,
"modified_and_copied.txt": None,
Expand Down Expand Up @@ -609,7 +626,7 @@ def file_contents(verbed):
"rapids_pre_commit_hooks.copyright.get_target_branch_upstream_commit",
Mock(return_value=target_branch.commit),
):
changed_files = copyright.get_changed_files(Mock(target_branch=None))
changed_files = copyright.get_changed_files(None)
assert {
path: old_blob.path if old_blob else None
for path, old_blob in changed_files.items()
Expand Down Expand Up @@ -730,8 +747,8 @@ def mock_repo_cwd():
return patch("os.getcwd", Mock(return_value=git_repo.working_tree_dir))

def mock_target_branch_upstream_commit(target_branch):
def func(repo, target_branch_arg):
assert target_branch == target_branch_arg
def func(repo, args):
assert target_branch == args.target_branch
return repo.heads[target_branch].commit

return patch(
Expand Down