From 81e0deb5685ae5698278e7b1b7d8523c9d6a09b9 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Mon, 21 Feb 2022 17:45:49 -0500 Subject: [PATCH 1/6] iUpdating some of the ci check scripts --- ci/checks/copyright.py | 174 +++++++++++++++++++--- ci/checks/style.sh | 2 +- cpp/scripts/gitutils.py | 197 +++++++++++++++++++++--- cpp/scripts/include_checker.py | 263 +++++++++++++++++++++++++++++---- 4 files changed, 565 insertions(+), 71 deletions(-) diff --git a/ci/checks/copyright.py b/ci/checks/copyright.py index 79a0d70005..b0e11014cf 100644 --- a/ci/checks/copyright.py +++ b/ci/checks/copyright.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,8 +15,20 @@ import datetime import re -import gitutils +import argparse +import io +import os +import sys +SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) + +# Add the scripts dir for gitutils +sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, + "../../cpp/scripts"))) + +# Now import gitutils. Ignore flake8 error here since there is no other way to +# set up imports +import gitutils # noqa: E402 FilesToCheck = [ re.compile(r"[.](cmake|cpp|cu|cuh|h|hpp|sh|pxd|py|pyx)$"), @@ -26,11 +38,25 @@ re.compile(r"[.]flake8[.]cython$"), re.compile(r"meta[.]yaml$") ] +ExemptFiles = [] + +# this will break starting at year 10000, which is probably OK :) +CheckSimple = re.compile( + r"Copyright *(?:\(c\))? *(\d{4}),? *NVIDIA C(?:ORPORATION|orporation)") +CheckDouble = re.compile( + r"Copyright *(?:\(c\))? *(\d{4})-(\d{4}),? *NVIDIA C(?:ORPORATION|orporation)" # noqa: E501 +) def checkThisFile(f): - if gitutils.isFileEmpty(f): + # This check covers things like symlinks which point to files that DNE + if not (os.path.exists(f)): return False + if gitutils and gitutils.isFileEmpty(f): + return False + for exempt in ExemptFiles: + if exempt.search(f): + return False for checker in FilesToCheck: if checker.search(f): return True @@ -38,17 +64,25 @@ def checkThisFile(f): def getCopyrightYears(line): - res = re.search(r"Copyright \(c\) (\d{4}), NVIDIA CORPORATION", line) + res = CheckSimple.search(line) if res: return (int(res.group(1)), int(res.group(1))) - res = re.search(r"Copyright \(c\) (\d{4})-(\d{4}), NVIDIA CORPORATION", - line) + res = CheckDouble.search(line) if res: return (int(res.group(1)), int(res.group(2))) return (None, None) -def checkCopyright(f): +def replaceCurrentYear(line, start, end): + # first turn a simple regex into double (if applicable). then update years + res = CheckSimple.sub(r"Copyright (c) \1-\1, NVIDIA CORPORATION", line) + res = CheckDouble.sub( + r"Copyright (c) {:04d}-{:04d}, NVIDIA CORPORATION".format(start, end), + res) + return res + + +def checkCopyright(f, update_current_year): """ Checks for copyright headers and their years """ @@ -57,48 +91,152 @@ def checkCopyright(f): lineNum = 0 crFound = False yearMatched = False - fp = open(f, "r") - for line in fp.readlines(): + with io.open(f, "r", encoding="utf-8") as fp: + lines = fp.readlines() + for line in lines: lineNum += 1 start, end = getCopyrightYears(line) if start is None: continue crFound = True + if start > end: + e = [ + f, + lineNum, + "First year after second year in the copyright " + "header (manual fix required)", + None + ] + errs.append(e) if thisYear < start or thisYear > end: - errs.append((f, lineNum, - "Current year not included in the copyright header")) + e = [ + f, + lineNum, + "Current year not included in the " + "copyright header", + None + ] + if thisYear < start: + e[-1] = replaceCurrentYear(line, thisYear, end) + if thisYear > end: + e[-1] = replaceCurrentYear(line, start, thisYear) + errs.append(e) else: yearMatched = True fp.close() # copyright header itself not found if not crFound: - errs.append((f, 0, - "Copyright header missing or formatted incorrectly")) + e = [ + f, + 0, + "Copyright header missing or formatted incorrectly " + "(manual fix required)", + None + ] + errs.append(e) # even if the year matches a copyright header, make the check pass if yearMatched: errs = [] + + if update_current_year: + errs_update = [x for x in errs if x[-1] is not None] + if len(errs_update) > 0: + print("File: {}. Changing line(s) {}".format( + f, ', '.join(str(x[1]) for x in errs if x[-1] is not None))) + for _, lineNum, __, replacement in errs_update: + lines[lineNum - 1] = replacement + with io.open(f, "w", encoding="utf-8") as out_file: + for new_line in lines: + out_file.write(new_line) + errs = [x for x in errs if x[-1] is None] + return errs -def checkCopyrightForAll(): +def getAllFilesUnderDir(root, pathFilter=None): + retList = [] + for (dirpath, dirnames, filenames) in os.walk(root): + for fn in filenames: + filePath = os.path.join(dirpath, fn) + if pathFilter(filePath): + retList.append(filePath) + return retList + + +def checkCopyright_main(): """ Checks for copyright headers in all the modified files. In case of local repo, this script will just look for uncommitted files and in case of CI it compares between branches "$PR_TARGET_BRANCH" and "current-pr-branch" """ - files = gitutils.modifiedFiles(filter=checkThisFile) + retVal = 0 + global ExemptFiles + + argparser = argparse.ArgumentParser( + "Checks for a consistent copyright header in git's modified files") + argparser.add_argument("--update-current-year", + dest='update_current_year', + action="store_true", + required=False, + help="If set, " + "update the current year if a header is already " + "present and well formatted.") + argparser.add_argument("--git-modified-only", + dest='git_modified_only', + action="store_true", + required=False, + help="If set, " + "only files seen as modified by git will be " + "processed.") + argparser.add_argument("--exclude", + dest='exclude', + action="append", + required=False, + default=["python/cuml/_thirdparty/"], + help=("Exclude the paths specified (regexp). " + "Can be specified multiple times.")) + + (args, dirs) = argparser.parse_known_args() + try: + ExemptFiles = ExemptFiles + [pathName for pathName in args.exclude] + ExemptFiles = [re.compile(file) for file in ExemptFiles] + except re.error as reException: + print("Regular expression error:") + print(reException) + return 1 + + if args.git_modified_only: + files = gitutils.modifiedFiles(pathFilter=checkThisFile) + else: + files = [] + for d in [os.path.abspath(d) for d in dirs]: + if not (os.path.isdir(d)): + raise ValueError(f"{d} is not a directory.") + files += getAllFilesUnderDir(d, pathFilter=checkThisFile) + errors = [] for f in files: - errors += checkCopyright(f) + errors += checkCopyright(f, args.update_current_year) + if len(errors) > 0: print("Copyright headers incomplete in some of the files!") for e in errors: print(" %s:%d Issue: %s" % (e[0], e[1], e[2])) print("") - raise Exception("Copyright check failed! Check above to know more") + n_fixable = sum(1 for e in errors if e[-1] is not None) + path_parts = os.path.abspath(__file__).split(os.sep) + file_from_repo = os.sep.join(path_parts[path_parts.index("ci"):]) + if n_fixable > 0: + print(("You can run `python {} --git-modified-only " + "--update-current-year` to fix {} of these " + "errors.\n").format(file_from_repo, n_fixable)) + retVal = 1 else: print("Copyright check passed") + return retVal + if __name__ == "__main__": - checkCopyrightForAll() + import sys + sys.exit(checkCopyright_main()) diff --git a/ci/checks/style.sh b/ci/checks/style.sh index e928ccb186..b57b3f3557 100644 --- a/ci/checks/style.sh +++ b/ci/checks/style.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. ##################### # RAFT Style Tester # ##################### diff --git a/cpp/scripts/gitutils.py b/cpp/scripts/gitutils.py index cde5571871..8d4af79129 100644 --- a/cpp/scripts/gitutils.py +++ b/cpp/scripts/gitutils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -26,7 +26,7 @@ def __git(*opts): """Runs a git command and returns its output""" cmd = "git " + " ".join(list(opts)) ret = subprocess.check_output(cmd, shell=True) - return ret.decode("UTF-8") + return ret.decode("UTF-8").rstrip("\n") def __gitdiff(*opts): @@ -41,6 +41,111 @@ def branch(): return name +def repo_version(): + """ + Determines the version of the repo by using `git describe` + + Returns + ------- + str + The full version of the repo in the format 'v#.#.#{a|b|rc}' + """ + return __git("describe", "--tags", "--abbrev=0") + + +def repo_version_major_minor(): + """ + Determines the version of the repo using `git describe` and returns only + the major and minor portion + + Returns + ------- + str + The partial version of the repo in the format '{major}.{minor}' + """ + + full_repo_version = repo_version() + + match = re.match(r"^v?(?P[0-9]+)(?:\.(?P[0-9]+))?", + full_repo_version) + + if (match is None): + print(" [DEBUG] Could not determine repo major minor version. " + f"Full repo version: {full_repo_version}.") + return None + + out_version = match.group("major") + + if (match.group("minor")): + out_version += "." + match.group("minor") + + return out_version + + +def determine_merge_commit(current_branch="HEAD"): + """ + When running outside of CI, this will estimate the target merge commit hash + of `current_branch` by finding a common ancester with the remote branch + 'branch-{major}.{minor}' where {major} and {minor} are determined from the + repo version. + + Parameters + ---------- + current_branch : str, optional + Which branch to consider as the current branch, by default "HEAD" + + Returns + ------- + str + The common commit hash ID + """ + + try: + # Try to determine the target branch from the most recent tag + head_branch = __git("describe", + "--all", + "--tags", + "--match='branch-*'", + "--abbrev=0") + except subprocess.CalledProcessError: + print(" [DEBUG] Could not determine target branch from most recent " + "tag. Falling back to 'branch-{major}.{minor}.") + head_branch = None + + if (head_branch is not None): + # Convert from head to branch name + head_branch = __git("name-rev", "--name-only", head_branch) + else: + # Try and guess the target branch as "branch-." + version = repo_version_major_minor() + + if (version is None): + return None + + head_branch = "branch-{}".format(version) + + try: + # Now get the remote tracking branch + remote_branch = __git("rev-parse", + "--abbrev-ref", + "--symbolic-full-name", + head_branch + "@{upstream}") + except subprocess.CalledProcessError: + print(" [DEBUG] Could not remote tracking reference for " + f"branch {head_branch}.") + remote_branch = None + + if (remote_branch is None): + return None + + print(f" [DEBUG] Determined TARGET_BRANCH as: '{remote_branch}'. " + "Finding common ancestor.") + + common_commit = __git("merge-base", remote_branch, current_branch) + + return common_commit + + def uncommittedFiles(): """ Returns a list of all changed files that are not yet committed. This @@ -59,14 +164,25 @@ def uncommittedFiles(): return ret -def changedFilesBetween(b1, b2): - """Returns a list of files changed between branches b1 and b2""" +def changedFilesBetween(baseName, branchName, commitHash): + """ + Returns a list of files changed between branches baseName and latest commit + of branchName. + """ current = branch() - __git("checkout", "--quiet", b1) - __git("checkout", "--quiet", b2) - files = __gitdiff("--name-only", "--ignore-submodules", "%s...%s" % - (b1, b2)) - __git("checkout", "--quiet", current) + # checkout "base" branch + __git("checkout", "--force", baseName) + # checkout branch for comparing + __git("checkout", "--force", branchName) + # checkout latest commit from branch + __git("checkout", "-fq", commitHash) + + files = __gitdiff("--name-only", + "--ignore-submodules", + f"{baseName}..{branchName}") + + # restore the original branch + __git("checkout", "--force", current) return files.splitlines() @@ -75,8 +191,13 @@ def changesInFileBetween(file, b1, b2, filter=None): current = branch() __git("checkout", "--quiet", b1) __git("checkout", "--quiet", b2) - diffs = __gitdiff("--ignore-submodules", "-w", "--minimal", "-U0", - "%s...%s" % (b1, b2), "--", file) + diffs = __gitdiff("--ignore-submodules", + "-w", + "--minimal", + "-U0", + "%s...%s" % (b1, b2), + "--", + file) __git("checkout", "--quiet", current) lines = [] for line in diffs.splitlines(): @@ -85,12 +206,14 @@ def changesInFileBetween(file, b1, b2, filter=None): return lines -def modifiedFiles(filter=None): +def modifiedFiles(pathFilter=None): """ - If inside a CI-env (ie. currentBranch=current-pr-branch and the env-var - PR_TARGET_BRANCH is defined), then lists out all files modified between - these 2 branches. Else, lists out all the uncommitted files in the current - branch. + If inside a CI-env (ie. TARGET_BRANCH and COMMIT_HASH are defined, and + current branch is "current-pr-branch"), then lists out all files modified + between these 2 branches. Locally, TARGET_BRANCH will try to be determined + from the current repo version and finding a coresponding branch named + 'branch-{major}.{minor}'. If this fails, this functino will list out all + the uncommitted files in the current branch. Such utility function is helpful while putting checker scripts as part of cmake, as well as CI process. This way, during development, only the files @@ -98,15 +221,41 @@ def modifiedFiles(filter=None): process ALL files modified by the dev, as submiited in the PR, will be checked. This happens, all the while using the same script. """ - if "PR_TARGET_BRANCH" in os.environ and branch() == "current-pr-branch": - allFiles = changedFilesBetween(os.environ["PR_TARGET_BRANCH"], - branch()) + targetBranch = os.environ.get("TARGET_BRANCH") + commitHash = os.environ.get("COMMIT_HASH") + currentBranch = branch() + print( + f" [DEBUG] TARGET_BRANCH={targetBranch}, COMMIT_HASH={commitHash}, " + f"currentBranch={currentBranch}") + + if targetBranch and commitHash and (currentBranch == "current-pr-branch"): + print(" [DEBUG] Assuming a CI environment.") + allFiles = changedFilesBetween(targetBranch, currentBranch, commitHash) else: - allFiles = uncommittedFiles() + print(" [DEBUG] Did not detect CI environment. " + "Determining TARGET_BRANCH locally.") + + common_commit = determine_merge_commit(currentBranch) + + if (common_commit is not None): + + # Now get the diff. Use --staged to get both diff between + # common_commit..HEAD and any locally staged files + allFiles = __gitdiff("--name-only", + "--ignore-submodules", + "--staged", + f"{common_commit}").splitlines() + else: + # Fallback to just uncommitted files + allFiles = uncommittedFiles() + files = [] for f in allFiles: - if filter is None or filter(f): + if pathFilter is None or pathFilter(f): files.append(f) + + filesToCheckString = "\n\t".join(files) if files else "" + print(f" [DEBUG] Found files to check:\n\t{filesToCheckString}\n") return files @@ -131,7 +280,7 @@ def listFilesToCheck(filesDirs, filter=None): allFiles.append(f) elif os.path.isdir(f): files = listAllFilesInDir(f) - for f in files: - if filter is None or filter(f): - allFiles.append(f) + for f_ in files: + if filter is None or filter(f_): + allFiles.append(f_) return allFiles diff --git a/cpp/scripts/include_checker.py b/cpp/scripts/include_checker.py index 1ced05e743..2f8e6e4781 100644 --- a/cpp/scripts/include_checker.py +++ b/cpp/scripts/include_checker.py @@ -1,4 +1,5 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. +# +# Copyright (c) 2020-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,27 +18,83 @@ import sys import re import os -import subprocess import argparse +import io +from functools import reduce +import operator +import dataclasses +import typing - -IncludeRegex = re.compile(r"\s*#include\s*(\S+)") -RemoveComments = re.compile(r"//.*") +# file names could (in theory) contain simple white-space +IncludeRegex = re.compile(r"(\s*#include\s*)([\"<])([\S ]+)([\">])") +PragmaRegex = re.compile(r"^ *\#pragma\s+once *$") def parse_args(): argparser = argparse.ArgumentParser( "Checks for a consistent '#include' syntax") - argparser.add_argument("--regex", type=str, + argparser.add_argument("--regex", + type=str, default=r"[.](cu|cuh|h|hpp|hxx|cpp)$", help="Regex string to filter in sources") - argparser.add_argument("dirs", type=str, nargs="*", + argparser.add_argument( + "--inplace", + action="store_true", + required=False, + help="If set, perform the required changes inplace.") + argparser.add_argument("--top_include_dirs", + required=False, + default='include,src,test', + help="comma-separated list of directories used as " + "search dirs on build and which should not be " + "crossed in relative includes") + argparser.add_argument("dirs", + type=str, + nargs="*", help="List of dirs where to find sources") args = argparser.parse_args() args.regex_compiled = re.compile(args.regex) return args +@dataclasses.dataclass() +class Issue: + is_error: bool + msg: str + file: str + line: int + fixed_str: str = None + was_fixed: bool = False + + def get_msg_str(self) -> str: + if (self.is_error and not self.was_fixed): + return make_error_msg( + self.file, + self.line, + self.msg + (". Fixed!" if self.was_fixed else "")) + else: + return make_warn_msg( + self.file, + self.line, + self.msg + (". Fixed!" if self.was_fixed else "")) + + +def make_msg(err_or_warn: str, file: str, line: int, msg: str): + """ + Formats the error message with a file and line number that can be used by + IDEs to quickly go to the exact line + """ + return "{}: {}:{}, {}".format(err_or_warn, file, line, msg) + + +def make_error_msg(file: str, line: int, msg: str): + return make_msg("ERROR", file, line, msg) + + +def make_warn_msg(file: str, line: int, msg: str): + return make_msg("WARN", file, line, msg) + + def list_all_source_file(file_regex, srcdirs): all_files = [] for srcdir in srcdirs: @@ -49,41 +106,191 @@ def list_all_source_file(file_regex, srcdirs): return all_files -def check_includes_in(src): - errs = [] +def rel_include_warnings(dir, src, line_num, inc_file, + top_inc_dirs) -> typing.List[Issue]: + warn: typing.List[Issue] = [] + inc_folders = inc_file.split(os.path.sep)[:-1] + inc_folders_alt = inc_file.split(os.path.altsep)[:-1] + + if len(inc_folders) != 0 and len(inc_folders_alt) != 0: + w = "using %s and %s as path separators" % (os.path.sep, + os.path.altsep) + warn.append(Issue(False, w, src, line_num)) + + if len(inc_folders) == 0: + inc_folders = inc_folders_alt + + abs_inc_folders = [ + os.path.abspath(os.path.join(dir, *inc_folders[:i + 1])) + for i in range(len(inc_folders)) + ] + + if os.path.curdir in inc_folders: + w = "rel include containing reference to current folder '{}'".format( + os.path.curdir) + warn.append(Issue(False, w, src, line_num)) + + if any( + any([os.path.basename(p) == f for f in top_inc_dirs]) + for p in abs_inc_folders): + + w = "rel include going over %s folders" % ("/".join( + "'" + f + "'" for f in top_inc_dirs)) + + warn.append(Issue(False, w, src, line_num)) + + if (len(inc_folders) >= 3 and os.path.pardir in inc_folders + and any(p != os.path.pardir for p in inc_folders)): + + w = ("rel include with more than " + "2 folders that aren't in a straight heritage line") + warn.append(Issue(False, w, src, line_num)) + + return warn + + +def check_includes_in(src, inplace, top_inc_dirs) -> typing.List[Issue]: + issues: typing.List[Issue] = [] dir = os.path.dirname(src) - for line_number, line in enumerate(open(src)): - line = RemoveComments.sub("", line) + found_pragma_once = False + include_count = 0 + + # Read all lines + with io.open(src, encoding="utf-8") as file_obj: + lines = list(enumerate(file_obj)) + + for line_number, line in lines: + line_num = line_number + 1 + match = IncludeRegex.search(line) if match is None: + # Check to see if its a pragma once + if not found_pragma_once: + pragma_match = PragmaRegex.search(line) + + if pragma_match is not None: + found_pragma_once = True + + if include_count > 0: + issues.append( + Issue( + True, + "`#pragma once` must be before any `#include`", + src, + line_num)) continue - val = match.group(1) - inc_file = val[1:-1] # strip out " or < + + include_count += 1 + + val_type = match.group(2) # " or < + inc_file = match.group(3) full_path = os.path.join(dir, inc_file) - line_num = line_number + 1 - if val[0] == "\"" and not os.path.exists(full_path): - errs.append("Line:%d use #include <...>" % line_num) - elif val[0] == "<" and os.path.exists(full_path): - errs.append("Line:%d use #include \"...\"" % line_num) - return errs + + if val_type == "\"" and not os.path.isfile(full_path): + new_line, n = IncludeRegex.subn(r"\1<\3>", line) + assert n == 1, "inplace only handles one include match per line" + + issues.append( + Issue(True, "use #include <...>", src, line_num, new_line)) + + elif val_type == "<" and os.path.isfile(full_path): + new_line, n = IncludeRegex.subn(r'\1"\3"', line) + assert n == 1, "inplace only handles one include match per line" + + issues.append( + Issue(True, "use #include \"...\"", src, line_num, new_line)) + + # output warnings for some cases + # 1. relative include containing current folder + # 2. relative include going over src / src_prims folders + # 3. relative include longer than 2 folders and containing + # both ".." and "non-.." + # 4. absolute include used but rel include possible without warning + if val_type == "\"": + issues += rel_include_warnings(dir, + src, + line_num, + inc_file, + top_inc_dirs) + if val_type == "<": + # try to make a relative import using the top folders + for top_folder in top_inc_dirs: + full_dir = os.path.abspath(dir) + fs = full_dir.split(os.path.sep) + fs_alt = full_dir.split(os.path.altsep) + if len(fs) <= 1: + fs = fs_alt + if top_folder not in fs: + continue + if fs[0] == "": # full dir was absolute + fs[0] = os.path.sep + full_top = os.path.join(*fs[:fs.index(top_folder) + 1]) + full_inc = os.path.join(full_top, inc_file) + if not os.path.isfile(full_inc): + continue + new_rel_inc = os.path.relpath(full_inc, full_dir) + warn = rel_include_warnings(dir, + src, + line_num, + new_rel_inc, + top_inc_dirs) + if len(warn) == 0: + issues.append( + Issue( + False, + "absolute include could be transformed to relative", + src, + line_num, + f"#include \"{new_rel_inc}\"\n")) + else: + issues += warn + + if inplace and len(issues) > 0: + had_fixes = False + + for issue in issues: + if (issue.fixed_str is not None): + lines[issue.line - 1] = (lines[issue.line - 1][0], + issue.fixed_str) + issue.was_fixed = True + had_fixes = True + + if (had_fixes): + with io.open(src, "w", encoding="utf-8") as out_file: + for _, new_line in lines: + out_file.write(new_line) + + return issues def main(): args = parse_args() + top_inc_dirs = args.top_include_dirs.split(',') all_files = list_all_source_file(args.regex_compiled, args.dirs) - all_errs = {} + all_issues: typing.List[Issue] = [] + errs: typing.List[Issue] = [] + for f in all_files: - errs = check_includes_in(f) - if len(errs) > 0: - all_errs[f] = errs - if len(all_errs) == 0: + issues = check_includes_in(f, args.inplace, top_inc_dirs) + + all_issues += issues + + for i in all_issues: + if (i.is_error and not i.was_fixed): + errs.append(i) + else: + print(i.get_msg_str()) + + if len(errs) == 0: print("include-check PASSED") else: print("include-check FAILED! See below for errors...") - for f, errs in all_errs.items(): - print("File: %s" % f) - for e in errs: - print(" %s" % e) + for err in errs: + print(err.get_msg_str()) + + path_parts = os.path.abspath(__file__).split(os.sep) + print("You can run '{} --inplace' to bulk fix these errors".format( + os.sep.join(path_parts[path_parts.index("cpp"):]))) sys.exit(-1) return From 96b66ca29a83e35ed4e2a5dc19756f3678e47faf Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Mon, 21 Feb 2022 17:53:52 -0500 Subject: [PATCH 2/6] Reverting include checker. Not sure we want to use the relavtive includes everywhere yet. --- ci/checks/copyright.py | 4 +- ci/checks/style.sh | 2 +- cpp/scripts/include_checker.py | 263 ++++----------------------------- 3 files changed, 31 insertions(+), 238 deletions(-) diff --git a/ci/checks/copyright.py b/ci/checks/copyright.py index b0e11014cf..b2acc3027a 100644 --- a/ci/checks/copyright.py +++ b/ci/checks/copyright.py @@ -179,8 +179,8 @@ def checkCopyright_main(): action="store_true", required=False, help="If set, " - "update the current year if a header is already " - "present and well formatted.") + "update the current year if a header " + "is already present and well formatted.") argparser.add_argument("--git-modified-only", dest='git_modified_only', action="store_true", diff --git a/ci/checks/style.sh b/ci/checks/style.sh index b57b3f3557..839ce3a83f 100644 --- a/ci/checks/style.sh +++ b/ci/checks/style.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2022, NVIDIA CORPORATION. ##################### # RAFT Style Tester # ##################### diff --git a/cpp/scripts/include_checker.py b/cpp/scripts/include_checker.py index 2f8e6e4781..1ced05e743 100644 --- a/cpp/scripts/include_checker.py +++ b/cpp/scripts/include_checker.py @@ -1,5 +1,4 @@ -# -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,83 +17,27 @@ import sys import re import os +import subprocess import argparse -import io -from functools import reduce -import operator -import dataclasses -import typing -# file names could (in theory) contain simple white-space -IncludeRegex = re.compile(r"(\s*#include\s*)([\"<])([\S ]+)([\">])") -PragmaRegex = re.compile(r"^ *\#pragma\s+once *$") + +IncludeRegex = re.compile(r"\s*#include\s*(\S+)") +RemoveComments = re.compile(r"//.*") def parse_args(): argparser = argparse.ArgumentParser( "Checks for a consistent '#include' syntax") - argparser.add_argument("--regex", - type=str, + argparser.add_argument("--regex", type=str, default=r"[.](cu|cuh|h|hpp|hxx|cpp)$", help="Regex string to filter in sources") - argparser.add_argument( - "--inplace", - action="store_true", - required=False, - help="If set, perform the required changes inplace.") - argparser.add_argument("--top_include_dirs", - required=False, - default='include,src,test', - help="comma-separated list of directories used as " - "search dirs on build and which should not be " - "crossed in relative includes") - argparser.add_argument("dirs", - type=str, - nargs="*", + argparser.add_argument("dirs", type=str, nargs="*", help="List of dirs where to find sources") args = argparser.parse_args() args.regex_compiled = re.compile(args.regex) return args -@dataclasses.dataclass() -class Issue: - is_error: bool - msg: str - file: str - line: int - fixed_str: str = None - was_fixed: bool = False - - def get_msg_str(self) -> str: - if (self.is_error and not self.was_fixed): - return make_error_msg( - self.file, - self.line, - self.msg + (". Fixed!" if self.was_fixed else "")) - else: - return make_warn_msg( - self.file, - self.line, - self.msg + (". Fixed!" if self.was_fixed else "")) - - -def make_msg(err_or_warn: str, file: str, line: int, msg: str): - """ - Formats the error message with a file and line number that can be used by - IDEs to quickly go to the exact line - """ - return "{}: {}:{}, {}".format(err_or_warn, file, line, msg) - - -def make_error_msg(file: str, line: int, msg: str): - return make_msg("ERROR", file, line, msg) - - -def make_warn_msg(file: str, line: int, msg: str): - return make_msg("WARN", file, line, msg) - - def list_all_source_file(file_regex, srcdirs): all_files = [] for srcdir in srcdirs: @@ -106,191 +49,41 @@ def list_all_source_file(file_regex, srcdirs): return all_files -def rel_include_warnings(dir, src, line_num, inc_file, - top_inc_dirs) -> typing.List[Issue]: - warn: typing.List[Issue] = [] - inc_folders = inc_file.split(os.path.sep)[:-1] - inc_folders_alt = inc_file.split(os.path.altsep)[:-1] - - if len(inc_folders) != 0 and len(inc_folders_alt) != 0: - w = "using %s and %s as path separators" % (os.path.sep, - os.path.altsep) - warn.append(Issue(False, w, src, line_num)) - - if len(inc_folders) == 0: - inc_folders = inc_folders_alt - - abs_inc_folders = [ - os.path.abspath(os.path.join(dir, *inc_folders[:i + 1])) - for i in range(len(inc_folders)) - ] - - if os.path.curdir in inc_folders: - w = "rel include containing reference to current folder '{}'".format( - os.path.curdir) - warn.append(Issue(False, w, src, line_num)) - - if any( - any([os.path.basename(p) == f for f in top_inc_dirs]) - for p in abs_inc_folders): - - w = "rel include going over %s folders" % ("/".join( - "'" + f + "'" for f in top_inc_dirs)) - - warn.append(Issue(False, w, src, line_num)) - - if (len(inc_folders) >= 3 and os.path.pardir in inc_folders - and any(p != os.path.pardir for p in inc_folders)): - - w = ("rel include with more than " - "2 folders that aren't in a straight heritage line") - warn.append(Issue(False, w, src, line_num)) - - return warn - - -def check_includes_in(src, inplace, top_inc_dirs) -> typing.List[Issue]: - issues: typing.List[Issue] = [] +def check_includes_in(src): + errs = [] dir = os.path.dirname(src) - found_pragma_once = False - include_count = 0 - - # Read all lines - with io.open(src, encoding="utf-8") as file_obj: - lines = list(enumerate(file_obj)) - - for line_number, line in lines: - line_num = line_number + 1 - + for line_number, line in enumerate(open(src)): + line = RemoveComments.sub("", line) match = IncludeRegex.search(line) if match is None: - # Check to see if its a pragma once - if not found_pragma_once: - pragma_match = PragmaRegex.search(line) - - if pragma_match is not None: - found_pragma_once = True - - if include_count > 0: - issues.append( - Issue( - True, - "`#pragma once` must be before any `#include`", - src, - line_num)) continue - - include_count += 1 - - val_type = match.group(2) # " or < - inc_file = match.group(3) + val = match.group(1) + inc_file = val[1:-1] # strip out " or < full_path = os.path.join(dir, inc_file) - - if val_type == "\"" and not os.path.isfile(full_path): - new_line, n = IncludeRegex.subn(r"\1<\3>", line) - assert n == 1, "inplace only handles one include match per line" - - issues.append( - Issue(True, "use #include <...>", src, line_num, new_line)) - - elif val_type == "<" and os.path.isfile(full_path): - new_line, n = IncludeRegex.subn(r'\1"\3"', line) - assert n == 1, "inplace only handles one include match per line" - - issues.append( - Issue(True, "use #include \"...\"", src, line_num, new_line)) - - # output warnings for some cases - # 1. relative include containing current folder - # 2. relative include going over src / src_prims folders - # 3. relative include longer than 2 folders and containing - # both ".." and "non-.." - # 4. absolute include used but rel include possible without warning - if val_type == "\"": - issues += rel_include_warnings(dir, - src, - line_num, - inc_file, - top_inc_dirs) - if val_type == "<": - # try to make a relative import using the top folders - for top_folder in top_inc_dirs: - full_dir = os.path.abspath(dir) - fs = full_dir.split(os.path.sep) - fs_alt = full_dir.split(os.path.altsep) - if len(fs) <= 1: - fs = fs_alt - if top_folder not in fs: - continue - if fs[0] == "": # full dir was absolute - fs[0] = os.path.sep - full_top = os.path.join(*fs[:fs.index(top_folder) + 1]) - full_inc = os.path.join(full_top, inc_file) - if not os.path.isfile(full_inc): - continue - new_rel_inc = os.path.relpath(full_inc, full_dir) - warn = rel_include_warnings(dir, - src, - line_num, - new_rel_inc, - top_inc_dirs) - if len(warn) == 0: - issues.append( - Issue( - False, - "absolute include could be transformed to relative", - src, - line_num, - f"#include \"{new_rel_inc}\"\n")) - else: - issues += warn - - if inplace and len(issues) > 0: - had_fixes = False - - for issue in issues: - if (issue.fixed_str is not None): - lines[issue.line - 1] = (lines[issue.line - 1][0], - issue.fixed_str) - issue.was_fixed = True - had_fixes = True - - if (had_fixes): - with io.open(src, "w", encoding="utf-8") as out_file: - for _, new_line in lines: - out_file.write(new_line) - - return issues + line_num = line_number + 1 + if val[0] == "\"" and not os.path.exists(full_path): + errs.append("Line:%d use #include <...>" % line_num) + elif val[0] == "<" and os.path.exists(full_path): + errs.append("Line:%d use #include \"...\"" % line_num) + return errs def main(): args = parse_args() - top_inc_dirs = args.top_include_dirs.split(',') all_files = list_all_source_file(args.regex_compiled, args.dirs) - all_issues: typing.List[Issue] = [] - errs: typing.List[Issue] = [] - + all_errs = {} for f in all_files: - issues = check_includes_in(f, args.inplace, top_inc_dirs) - - all_issues += issues - - for i in all_issues: - if (i.is_error and not i.was_fixed): - errs.append(i) - else: - print(i.get_msg_str()) - - if len(errs) == 0: + errs = check_includes_in(f) + if len(errs) > 0: + all_errs[f] = errs + if len(all_errs) == 0: print("include-check PASSED") else: print("include-check FAILED! See below for errors...") - for err in errs: - print(err.get_msg_str()) - - path_parts = os.path.abspath(__file__).split(os.sep) - print("You can run '{} --inplace' to bulk fix these errors".format( - os.sep.join(path_parts[path_parts.index("cpp"):]))) + for f, errs in all_errs.items(): + print("File: %s" % f) + for e in errs: + print(" %s" % e) sys.exit(-1) return From 2ac97128020cfa966ea06f6a587b51a5cd8f4bb3 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Mon, 21 Feb 2022 17:57:28 -0500 Subject: [PATCH 3/6] Updating copyright checker --- ci/checks/style.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/checks/style.sh b/ci/checks/style.sh index 839ce3a83f..89381d6778 100644 --- a/ci/checks/style.sh +++ b/ci/checks/style.sh @@ -26,7 +26,7 @@ else fi # Check for copyright headers in the files modified currently -COPYRIGHT=`env PYTHONPATH=cpp/scripts python ci/checks/copyright.py 2>&1` +COPYRIGHT=`python ci/checks/copyright.py --git-modified-only 2>&1` CR_RETVAL=$? if [ "$RETVAL" = "0" ]; then RETVAL=$CR_RETVAL From cef9d43ab1dfdc58bbfde4b1c593431cd5ef82bc Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Mon, 21 Feb 2022 18:09:00 -0500 Subject: [PATCH 4/6] Testing copyright checker fails --- cpp/include/raft.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/include/raft.hpp b/cpp/include/raft.hpp index 08f836d3a8..52e084906a 100644 --- a/cpp/include/raft.hpp +++ b/cpp/include/raft.hpp @@ -15,6 +15,7 @@ */ #include +#include namespace raft { From 340927bf8b04a568c41d02a7225ea1b0edf20a2d Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Mon, 21 Feb 2022 18:10:47 -0500 Subject: [PATCH 5/6] Reverting copyright failed file --- cpp/include/raft.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/include/raft.hpp b/cpp/include/raft.hpp index 52e084906a..08f836d3a8 100644 --- a/cpp/include/raft.hpp +++ b/cpp/include/raft.hpp @@ -15,7 +15,6 @@ */ #include -#include namespace raft { From eef09f809d41fe2cd944a3f103e0a615a204f8a2 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Mon, 21 Feb 2022 18:11:54 -0500 Subject: [PATCH 6/6] Using proper copyright ranges for older files --- ci/checks/copyright.py | 2 +- ci/checks/style.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ci/checks/copyright.py b/ci/checks/copyright.py index b2acc3027a..2440e61cb1 100644 --- a/ci/checks/copyright.py +++ b/ci/checks/copyright.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ci/checks/style.sh b/ci/checks/style.sh index 89381d6778..2ce8b446b8 100644 --- a/ci/checks/style.sh +++ b/ci/checks/style.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. ##################### # RAFT Style Tester # #####################