diff --git a/ci/checks/copyright.py b/ci/checks/copyright.py index 79a0d70005..2440e61cb1 100644 --- a/ci/checks/copyright.py +++ b/ci/checks/copyright.py @@ -15,8 +15,20 @@ import datetime import re -import gitutils +import argparse +import io +import os +import sys +SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) + +# Add the scripts dir for gitutils +sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, + "../../cpp/scripts"))) + +# Now import gitutils. Ignore flake8 error here since there is no other way to +# set up imports +import gitutils # noqa: E402 FilesToCheck = [ re.compile(r"[.](cmake|cpp|cu|cuh|h|hpp|sh|pxd|py|pyx)$"), @@ -26,11 +38,25 @@ re.compile(r"[.]flake8[.]cython$"), re.compile(r"meta[.]yaml$") ] +ExemptFiles = [] + +# this will break starting at year 10000, which is probably OK :) +CheckSimple = re.compile( + r"Copyright *(?:\(c\))? *(\d{4}),? *NVIDIA C(?:ORPORATION|orporation)") +CheckDouble = re.compile( + r"Copyright *(?:\(c\))? *(\d{4})-(\d{4}),? *NVIDIA C(?:ORPORATION|orporation)" # noqa: E501 +) def checkThisFile(f): - if gitutils.isFileEmpty(f): + # This check covers things like symlinks which point to files that DNE + if not (os.path.exists(f)): return False + if gitutils and gitutils.isFileEmpty(f): + return False + for exempt in ExemptFiles: + if exempt.search(f): + return False for checker in FilesToCheck: if checker.search(f): return True @@ -38,17 +64,25 @@ def checkThisFile(f): def getCopyrightYears(line): - res = re.search(r"Copyright \(c\) (\d{4}), NVIDIA CORPORATION", line) + res = CheckSimple.search(line) if res: return (int(res.group(1)), int(res.group(1))) - res = re.search(r"Copyright \(c\) (\d{4})-(\d{4}), NVIDIA CORPORATION", - line) + res = CheckDouble.search(line) if res: return (int(res.group(1)), int(res.group(2))) return (None, None) -def checkCopyright(f): +def replaceCurrentYear(line, start, end): + # first turn a simple regex into double (if applicable). then update years + res = CheckSimple.sub(r"Copyright (c) \1-\1, NVIDIA CORPORATION", line) + res = CheckDouble.sub( + r"Copyright (c) {:04d}-{:04d}, NVIDIA CORPORATION".format(start, end), + res) + return res + + +def checkCopyright(f, update_current_year): """ Checks for copyright headers and their years """ @@ -57,48 +91,152 @@ def checkCopyright(f): lineNum = 0 crFound = False yearMatched = False - fp = open(f, "r") - for line in fp.readlines(): + with io.open(f, "r", encoding="utf-8") as fp: + lines = fp.readlines() + for line in lines: lineNum += 1 start, end = getCopyrightYears(line) if start is None: continue crFound = True + if start > end: + e = [ + f, + lineNum, + "First year after second year in the copyright " + "header (manual fix required)", + None + ] + errs.append(e) if thisYear < start or thisYear > end: - errs.append((f, lineNum, - "Current year not included in the copyright header")) + e = [ + f, + lineNum, + "Current year not included in the " + "copyright header", + None + ] + if thisYear < start: + e[-1] = replaceCurrentYear(line, thisYear, end) + if thisYear > end: + e[-1] = replaceCurrentYear(line, start, thisYear) + errs.append(e) else: yearMatched = True fp.close() # copyright header itself not found if not crFound: - errs.append((f, 0, - "Copyright header missing or formatted incorrectly")) + e = [ + f, + 0, + "Copyright header missing or formatted incorrectly " + "(manual fix required)", + None + ] + errs.append(e) # even if the year matches a copyright header, make the check pass if yearMatched: errs = [] + + if update_current_year: + errs_update = [x for x in errs if x[-1] is not None] + if len(errs_update) > 0: + print("File: {}. Changing line(s) {}".format( + f, ', '.join(str(x[1]) for x in errs if x[-1] is not None))) + for _, lineNum, __, replacement in errs_update: + lines[lineNum - 1] = replacement + with io.open(f, "w", encoding="utf-8") as out_file: + for new_line in lines: + out_file.write(new_line) + errs = [x for x in errs if x[-1] is None] + return errs -def checkCopyrightForAll(): +def getAllFilesUnderDir(root, pathFilter=None): + retList = [] + for (dirpath, dirnames, filenames) in os.walk(root): + for fn in filenames: + filePath = os.path.join(dirpath, fn) + if pathFilter(filePath): + retList.append(filePath) + return retList + + +def checkCopyright_main(): """ Checks for copyright headers in all the modified files. In case of local repo, this script will just look for uncommitted files and in case of CI it compares between branches "$PR_TARGET_BRANCH" and "current-pr-branch" """ - files = gitutils.modifiedFiles(filter=checkThisFile) + retVal = 0 + global ExemptFiles + + argparser = argparse.ArgumentParser( + "Checks for a consistent copyright header in git's modified files") + argparser.add_argument("--update-current-year", + dest='update_current_year', + action="store_true", + required=False, + help="If set, " + "update the current year if a header " + "is already present and well formatted.") + argparser.add_argument("--git-modified-only", + dest='git_modified_only', + action="store_true", + required=False, + help="If set, " + "only files seen as modified by git will be " + "processed.") + argparser.add_argument("--exclude", + dest='exclude', + action="append", + required=False, + default=["python/cuml/_thirdparty/"], + help=("Exclude the paths specified (regexp). " + "Can be specified multiple times.")) + + (args, dirs) = argparser.parse_known_args() + try: + ExemptFiles = ExemptFiles + [pathName for pathName in args.exclude] + ExemptFiles = [re.compile(file) for file in ExemptFiles] + except re.error as reException: + print("Regular expression error:") + print(reException) + return 1 + + if args.git_modified_only: + files = gitutils.modifiedFiles(pathFilter=checkThisFile) + else: + files = [] + for d in [os.path.abspath(d) for d in dirs]: + if not (os.path.isdir(d)): + raise ValueError(f"{d} is not a directory.") + files += getAllFilesUnderDir(d, pathFilter=checkThisFile) + errors = [] for f in files: - errors += checkCopyright(f) + errors += checkCopyright(f, args.update_current_year) + if len(errors) > 0: print("Copyright headers incomplete in some of the files!") for e in errors: print(" %s:%d Issue: %s" % (e[0], e[1], e[2])) print("") - raise Exception("Copyright check failed! Check above to know more") + n_fixable = sum(1 for e in errors if e[-1] is not None) + path_parts = os.path.abspath(__file__).split(os.sep) + file_from_repo = os.sep.join(path_parts[path_parts.index("ci"):]) + if n_fixable > 0: + print(("You can run `python {} --git-modified-only " + "--update-current-year` to fix {} of these " + "errors.\n").format(file_from_repo, n_fixable)) + retVal = 1 else: print("Copyright check passed") + return retVal + if __name__ == "__main__": - checkCopyrightForAll() + import sys + sys.exit(checkCopyright_main()) diff --git a/ci/checks/style.sh b/ci/checks/style.sh index e928ccb186..2ce8b446b8 100644 --- a/ci/checks/style.sh +++ b/ci/checks/style.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. ##################### # RAFT Style Tester # ##################### @@ -26,7 +26,7 @@ else fi # Check for copyright headers in the files modified currently -COPYRIGHT=`env PYTHONPATH=cpp/scripts python ci/checks/copyright.py 2>&1` +COPYRIGHT=`python ci/checks/copyright.py --git-modified-only 2>&1` CR_RETVAL=$? if [ "$RETVAL" = "0" ]; then RETVAL=$CR_RETVAL diff --git a/cpp/scripts/gitutils.py b/cpp/scripts/gitutils.py index cde5571871..8d4af79129 100644 --- a/cpp/scripts/gitutils.py +++ b/cpp/scripts/gitutils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -26,7 +26,7 @@ def __git(*opts): """Runs a git command and returns its output""" cmd = "git " + " ".join(list(opts)) ret = subprocess.check_output(cmd, shell=True) - return ret.decode("UTF-8") + return ret.decode("UTF-8").rstrip("\n") def __gitdiff(*opts): @@ -41,6 +41,111 @@ def branch(): return name +def repo_version(): + """ + Determines the version of the repo by using `git describe` + + Returns + ------- + str + The full version of the repo in the format 'v#.#.#{a|b|rc}' + """ + return __git("describe", "--tags", "--abbrev=0") + + +def repo_version_major_minor(): + """ + Determines the version of the repo using `git describe` and returns only + the major and minor portion + + Returns + ------- + str + The partial version of the repo in the format '{major}.{minor}' + """ + + full_repo_version = repo_version() + + match = re.match(r"^v?(?P[0-9]+)(?:\.(?P[0-9]+))?", + full_repo_version) + + if (match is None): + print(" [DEBUG] Could not determine repo major minor version. " + f"Full repo version: {full_repo_version}.") + return None + + out_version = match.group("major") + + if (match.group("minor")): + out_version += "." + match.group("minor") + + return out_version + + +def determine_merge_commit(current_branch="HEAD"): + """ + When running outside of CI, this will estimate the target merge commit hash + of `current_branch` by finding a common ancester with the remote branch + 'branch-{major}.{minor}' where {major} and {minor} are determined from the + repo version. + + Parameters + ---------- + current_branch : str, optional + Which branch to consider as the current branch, by default "HEAD" + + Returns + ------- + str + The common commit hash ID + """ + + try: + # Try to determine the target branch from the most recent tag + head_branch = __git("describe", + "--all", + "--tags", + "--match='branch-*'", + "--abbrev=0") + except subprocess.CalledProcessError: + print(" [DEBUG] Could not determine target branch from most recent " + "tag. Falling back to 'branch-{major}.{minor}.") + head_branch = None + + if (head_branch is not None): + # Convert from head to branch name + head_branch = __git("name-rev", "--name-only", head_branch) + else: + # Try and guess the target branch as "branch-." + version = repo_version_major_minor() + + if (version is None): + return None + + head_branch = "branch-{}".format(version) + + try: + # Now get the remote tracking branch + remote_branch = __git("rev-parse", + "--abbrev-ref", + "--symbolic-full-name", + head_branch + "@{upstream}") + except subprocess.CalledProcessError: + print(" [DEBUG] Could not remote tracking reference for " + f"branch {head_branch}.") + remote_branch = None + + if (remote_branch is None): + return None + + print(f" [DEBUG] Determined TARGET_BRANCH as: '{remote_branch}'. " + "Finding common ancestor.") + + common_commit = __git("merge-base", remote_branch, current_branch) + + return common_commit + + def uncommittedFiles(): """ Returns a list of all changed files that are not yet committed. This @@ -59,14 +164,25 @@ def uncommittedFiles(): return ret -def changedFilesBetween(b1, b2): - """Returns a list of files changed between branches b1 and b2""" +def changedFilesBetween(baseName, branchName, commitHash): + """ + Returns a list of files changed between branches baseName and latest commit + of branchName. + """ current = branch() - __git("checkout", "--quiet", b1) - __git("checkout", "--quiet", b2) - files = __gitdiff("--name-only", "--ignore-submodules", "%s...%s" % - (b1, b2)) - __git("checkout", "--quiet", current) + # checkout "base" branch + __git("checkout", "--force", baseName) + # checkout branch for comparing + __git("checkout", "--force", branchName) + # checkout latest commit from branch + __git("checkout", "-fq", commitHash) + + files = __gitdiff("--name-only", + "--ignore-submodules", + f"{baseName}..{branchName}") + + # restore the original branch + __git("checkout", "--force", current) return files.splitlines() @@ -75,8 +191,13 @@ def changesInFileBetween(file, b1, b2, filter=None): current = branch() __git("checkout", "--quiet", b1) __git("checkout", "--quiet", b2) - diffs = __gitdiff("--ignore-submodules", "-w", "--minimal", "-U0", - "%s...%s" % (b1, b2), "--", file) + diffs = __gitdiff("--ignore-submodules", + "-w", + "--minimal", + "-U0", + "%s...%s" % (b1, b2), + "--", + file) __git("checkout", "--quiet", current) lines = [] for line in diffs.splitlines(): @@ -85,12 +206,14 @@ def changesInFileBetween(file, b1, b2, filter=None): return lines -def modifiedFiles(filter=None): +def modifiedFiles(pathFilter=None): """ - If inside a CI-env (ie. currentBranch=current-pr-branch and the env-var - PR_TARGET_BRANCH is defined), then lists out all files modified between - these 2 branches. Else, lists out all the uncommitted files in the current - branch. + If inside a CI-env (ie. TARGET_BRANCH and COMMIT_HASH are defined, and + current branch is "current-pr-branch"), then lists out all files modified + between these 2 branches. Locally, TARGET_BRANCH will try to be determined + from the current repo version and finding a coresponding branch named + 'branch-{major}.{minor}'. If this fails, this functino will list out all + the uncommitted files in the current branch. Such utility function is helpful while putting checker scripts as part of cmake, as well as CI process. This way, during development, only the files @@ -98,15 +221,41 @@ def modifiedFiles(filter=None): process ALL files modified by the dev, as submiited in the PR, will be checked. This happens, all the while using the same script. """ - if "PR_TARGET_BRANCH" in os.environ and branch() == "current-pr-branch": - allFiles = changedFilesBetween(os.environ["PR_TARGET_BRANCH"], - branch()) + targetBranch = os.environ.get("TARGET_BRANCH") + commitHash = os.environ.get("COMMIT_HASH") + currentBranch = branch() + print( + f" [DEBUG] TARGET_BRANCH={targetBranch}, COMMIT_HASH={commitHash}, " + f"currentBranch={currentBranch}") + + if targetBranch and commitHash and (currentBranch == "current-pr-branch"): + print(" [DEBUG] Assuming a CI environment.") + allFiles = changedFilesBetween(targetBranch, currentBranch, commitHash) else: - allFiles = uncommittedFiles() + print(" [DEBUG] Did not detect CI environment. " + "Determining TARGET_BRANCH locally.") + + common_commit = determine_merge_commit(currentBranch) + + if (common_commit is not None): + + # Now get the diff. Use --staged to get both diff between + # common_commit..HEAD and any locally staged files + allFiles = __gitdiff("--name-only", + "--ignore-submodules", + "--staged", + f"{common_commit}").splitlines() + else: + # Fallback to just uncommitted files + allFiles = uncommittedFiles() + files = [] for f in allFiles: - if filter is None or filter(f): + if pathFilter is None or pathFilter(f): files.append(f) + + filesToCheckString = "\n\t".join(files) if files else "" + print(f" [DEBUG] Found files to check:\n\t{filesToCheckString}\n") return files @@ -131,7 +280,7 @@ def listFilesToCheck(filesDirs, filter=None): allFiles.append(f) elif os.path.isdir(f): files = listAllFilesInDir(f) - for f in files: - if filter is None or filter(f): - allFiles.append(f) + for f_ in files: + if filter is None or filter(f_): + allFiles.append(f_) return allFiles