diff --git a/ci/checks/copyright.py b/ci/checks/copyright.py new file mode 100644 index 00000000000..d72fd95fea3 --- /dev/null +++ b/ci/checks/copyright.py @@ -0,0 +1,233 @@ +# Copyright (c) 2019-2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import datetime +import re +import argparse +import io +import os +import sys + +SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) + +# Add the scripts dir for gitutils +sys.path.append(os.path.normpath(SCRIPT_DIR)) + +# Now import gitutils. Ignore flake8 error here since there is no other way to +# set up imports +import gitutils # noqa: E402 + +FilesToCheck = [ + re.compile(r"[.](cmake|cpp|cu|cuh|h|hpp|sh|pxd|py|pyx)$"), + re.compile(r"CMakeLists[.]txt$"), + re.compile(r"CMakeLists_standalone[.]txt$"), + re.compile(r"setup[.]cfg$"), + re.compile(r"[.]flake8[.]cython$"), + re.compile(r"meta[.]yaml$") +] +ExemptFiles = [] + +# this will break starting at year 10000, which is probably OK :) +CheckSimple = re.compile( + r"Copyright *(?:\(c\))? *(\d{4}),? *NVIDIA C(?:ORPORATION|orporation)") +CheckDouble = re.compile( + r"Copyright *(?:\(c\))? *(\d{4})-(\d{4}),? *NVIDIA C(?:ORPORATION|orporation)" # noqa: E501 +) + + +def checkThisFile(f): + # This check covers things like symlinks which point to files that DNE + if not (os.path.exists(f)): + return False + if gitutils and gitutils.isFileEmpty(f): + return False + for exempt in ExemptFiles: + if exempt.search(f): + return False + for checker in FilesToCheck: + if checker.search(f): + return True + return False + + +def getCopyrightYears(line): + res = CheckSimple.search(line) + if res: + return (int(res.group(1)), int(res.group(1))) + res = CheckDouble.search(line) + if res: + return (int(res.group(1)), int(res.group(2))) + return (None, None) + + +def replaceCurrentYear(line, start, end): + # first turn a simple regex into double (if applicable). then update years + res = CheckSimple.sub(r"Copyright (c) \1-\1, NVIDIA CORPORATION", line) + res = CheckDouble.sub( + r"Copyright (c) {:04d}-{:04d}, NVIDIA CORPORATION".format(start, end), + res) + return res + + +def checkCopyright(f, update_current_year): + """ + Checks for copyright headers and their years + """ + errs = [] + thisYear = datetime.datetime.now().year + lineNum = 0 + crFound = False + yearMatched = False + with io.open(f, "r", encoding="utf-8") as fp: + lines = fp.readlines() + for line in lines: + lineNum += 1 + start, end = getCopyrightYears(line) + if start is None: + continue + crFound = True + if start > end: + e = [ + f, + lineNum, + "First year after second year in the copyright " + "header (manual fix required)", + None + ] + errs.append(e) + if thisYear < start or thisYear > end: + e = [ + f, + lineNum, + "Current year not included in the " + "copyright header", + None + ] + if thisYear < start: + e[-1] = replaceCurrentYear(line, thisYear, end) + if thisYear > end: + e[-1] = replaceCurrentYear(line, start, thisYear) + errs.append(e) + else: + yearMatched = True + fp.close() + # copyright header itself not found + if not crFound: + e = [ + f, + 0, + "Copyright header missing or formatted incorrectly " + "(manual fix required)", + None + ] + errs.append(e) + # even if the year matches a copyright header, make the check pass + if yearMatched: + errs = [] + + if update_current_year: + errs_update = [x for x in errs if x[-1] is not None] + if len(errs_update) > 0: + print("File: {}. Changing line(s) {}".format( + f, ', '.join(str(x[1]) for x in errs if x[-1] is not None))) + for _, lineNum, __, replacement in errs_update: + lines[lineNum - 1] = replacement + with io.open(f, "w", encoding="utf-8") as out_file: + for new_line in lines: + out_file.write(new_line) + errs = [x for x in errs if x[-1] is None] + + return errs + + +def getAllFilesUnderDir(root, pathFilter=None): + retList = [] + for (dirpath, dirnames, filenames) in os.walk(root): + for fn in filenames: + filePath = os.path.join(dirpath, fn) + if pathFilter(filePath): + retList.append(filePath) + return retList + + +def checkCopyright_main(): + """ + Checks for copyright headers in all the modified files. In case of local + repo, this script will just look for uncommitted files and in case of CI + it compares between branches "$PR_TARGET_BRANCH" and "current-pr-branch" + """ + retVal = 0 + global ExemptFiles + + argparser = argparse.ArgumentParser( + "Checks for a consistent copyright header in git's modified files") + argparser.add_argument("--update-current-year", + dest='update_current_year', + action="store_true", + required=False, + help="If set, " + "update the current year if a header is already " + "present and well formatted.") + argparser.add_argument("--git-modified-only", + dest='git_modified_only', + action="store_true", + required=False, + help="If set, " + "only files seen as modified by git will be " + "processed.") + + (args, dirs) = argparser.parse_known_args() + try: + ExemptFiles = [re.compile(file) for file in ExemptFiles] + except re.error as reException: + print("Regular expression error:") + print(reException) + return 1 + + if args.git_modified_only: + files = gitutils.modifiedFiles(pathFilter=checkThisFile) + else: + files = [] + for d in [os.path.abspath(d) for d in dirs]: + if not (os.path.isdir(d)): + raise ValueError(f"{d} is not a directory.") + files += getAllFilesUnderDir(d, pathFilter=checkThisFile) + + errors = [] + for f in files: + errors += checkCopyright(f, args.update_current_year) + + if len(errors) > 0: + print("Copyright headers incomplete in some of the files!") + for e in errors: + print(" %s:%d Issue: %s" % (e[0], e[1], e[2])) + print("") + n_fixable = sum(1 for e in errors if e[-1] is not None) + path_parts = os.path.abspath(__file__).split(os.sep) + file_from_repo = os.sep.join(path_parts[path_parts.index("ci"):]) + if n_fixable > 0: + print(("You can run `python {} --git-modified-only " + "--update-current-year` to fix {} of these " + "errors.\n").format(file_from_repo, n_fixable)) + retVal = 1 + else: + print("Copyright check passed") + + return retVal + + +if __name__ == "__main__": + import sys + sys.exit(checkCopyright_main()) \ No newline at end of file diff --git a/ci/checks/gitutils.py b/ci/checks/gitutils.py new file mode 100644 index 00000000000..0aea1d660cb --- /dev/null +++ b/ci/checks/gitutils.py @@ -0,0 +1,286 @@ +# Copyright (c) 2019-2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import subprocess +import os +import re + + +def isFileEmpty(f): + return os.stat(f).st_size == 0 + + +def __git(*opts): + """Runs a git command and returns its output""" + cmd = "git " + " ".join(list(opts)) + ret = subprocess.check_output(cmd, shell=True) + return ret.decode("UTF-8").rstrip("\n") + + +def __gitdiff(*opts): + """Runs a git diff command with no pager set""" + return __git("--no-pager", "diff", *opts) + + +def branch(): + """Returns the name of the current branch""" + name = __git("rev-parse", "--abbrev-ref", "HEAD") + name = name.rstrip() + return name + + +def repo_version(): + """ + Determines the version of the repo by using `git describe` + + Returns + ------- + str + The full version of the repo in the format 'v#.#.#{a|b|rc}' + """ + return __git("describe", "--tags", "--abbrev=0") + + +def repo_version_major_minor(): + """ + Determines the version of the repo using `git describe` and returns only + the major and minor portion + + Returns + ------- + str + The partial version of the repo in the format '{major}.{minor}' + """ + + full_repo_version = repo_version() + + match = re.match(r"^v?(?P[0-9]+)(?:\.(?P[0-9]+))?", + full_repo_version) + + if (match is None): + print(" [DEBUG] Could not determine repo major minor version. " + f"Full repo version: {full_repo_version}.") + return None + + out_version = match.group("major") + + if (match.group("minor")): + out_version += "." + match.group("minor") + + return out_version + + +def determine_merge_commit(current_branch="HEAD"): + """ + When running outside of CI, this will estimate the target merge commit hash + of `current_branch` by finding a common ancester with the remote branch + 'branch-{major}.{minor}' where {major} and {minor} are determined from the + repo version. + + Parameters + ---------- + current_branch : str, optional + Which branch to consider as the current branch, by default "HEAD" + + Returns + ------- + str + The common commit hash ID + """ + + try: + # Try to determine the target branch from the most recent tag + head_branch = __git("describe", + "--all", + "--tags", + "--match='branch-*'", + "--abbrev=0") + except subprocess.CalledProcessError: + print(" [DEBUG] Could not determine target branch from most recent " + "tag. Falling back to 'branch-{major}.{minor}.") + head_branch = None + + if (head_branch is not None): + # Convert from head to branch name + head_branch = __git("name-rev", "--name-only", head_branch) + else: + # Try and guess the target branch as "branch-." + version = repo_version_major_minor() + + if (version is None): + return None + + head_branch = "branch-{}".format(version) + + try: + # Now get the remote tracking branch + remote_branch = __git("rev-parse", + "--abbrev-ref", + "--symbolic-full-name", + head_branch + "@{upstream}") + except subprocess.CalledProcessError: + print(" [DEBUG] Could not remote tracking reference for " + f"branch {head_branch}.") + remote_branch = None + + if (remote_branch is None): + return None + + print(f" [DEBUG] Determined TARGET_BRANCH as: '{remote_branch}'. " + "Finding common ancestor.") + + common_commit = __git("merge-base", remote_branch, current_branch) + + return common_commit + + +def uncommittedFiles(): + """ + Returns a list of all changed files that are not yet committed. This + means both untracked/unstaged as well as uncommitted files too. + """ + files = __git("status", "-u", "-s") + ret = [] + for f in files.splitlines(): + f = f.strip(" ") + f = re.sub("\s+", " ", f) # noqa: W605 + tmp = f.split(" ", 1) + # only consider staged files or uncommitted files + # in other words, ignore untracked files + if tmp[0] == "M" or tmp[0] == "A": + ret.append(tmp[1]) + return ret + + +def changedFilesBetween(baseName, branchName, commitHash): + """ + Returns a list of files changed between branches baseName and latest commit + of branchName. + """ + current = branch() + # checkout "base" branch + __git("checkout", "--force", baseName) + # checkout branch for comparing + __git("checkout", "--force", branchName) + # checkout latest commit from branch + __git("checkout", "-fq", commitHash) + + files = __gitdiff("--name-only", + "--ignore-submodules", + f"{baseName}..{branchName}") + + # restore the original branch + __git("checkout", "--force", current) + return files.splitlines() + + +def changesInFileBetween(file, b1, b2, filter=None): + """Filters the changed lines to a file between the branches b1 and b2""" + current = branch() + __git("checkout", "--quiet", b1) + __git("checkout", "--quiet", b2) + diffs = __gitdiff("--ignore-submodules", + "-w", + "--minimal", + "-U0", + "%s...%s" % (b1, b2), + "--", + file) + __git("checkout", "--quiet", current) + lines = [] + for line in diffs.splitlines(): + if filter is None or filter(line): + lines.append(line) + return lines + + +def modifiedFiles(pathFilter=None): + """ + If inside a CI-env (ie. TARGET_BRANCH and COMMIT_HASH are defined, and + current branch is "current-pr-branch"), then lists out all files modified + between these 2 branches. Locally, TARGET_BRANCH will try to be determined + from the current repo version and finding a coresponding branch named + 'branch-{major}.{minor}'. If this fails, this functino will list out all + the uncommitted files in the current branch. + + Such utility function is helpful while putting checker scripts as part of + cmake, as well as CI process. This way, during development, only the files + touched (but not yet committed) by devs can be checked. But, during the CI + process ALL files modified by the dev, as submiited in the PR, will be + checked. This happens, all the while using the same script. + """ + targetBranch = os.environ.get("TARGET_BRANCH") + commitHash = os.environ.get("COMMIT_HASH") + currentBranch = branch() + print( + f" [DEBUG] TARGET_BRANCH={targetBranch}, COMMIT_HASH={commitHash}, " + f"currentBranch={currentBranch}") + + if targetBranch and commitHash and (currentBranch == "current-pr-branch"): + print(" [DEBUG] Assuming a CI environment.") + allFiles = changedFilesBetween(targetBranch, currentBranch, commitHash) + else: + print(" [DEBUG] Did not detect CI environment. " + "Determining TARGET_BRANCH locally.") + + common_commit = determine_merge_commit(currentBranch) + + if (common_commit is not None): + + # Now get the diff. Use --staged to get both diff between + # common_commit..HEAD and any locally staged files + allFiles = __gitdiff("--name-only", + "--ignore-submodules", + "--staged", + f"{common_commit}").splitlines() + else: + # Fallback to just uncommitted files + allFiles = uncommittedFiles() + + files = [] + for f in allFiles: + if pathFilter is None or pathFilter(f): + files.append(f) + + filesToCheckString = "\n\t".join(files) if files else "" + print(f" [DEBUG] Found files to check:\n\t{filesToCheckString}\n") + return files + + +def listAllFilesInDir(folder): + """Utility function to list all files/subdirs in the input folder""" + allFiles = [] + for root, dirs, files in os.walk(folder): + for name in files: + allFiles.append(os.path.join(root, name)) + return allFiles + + +def listFilesToCheck(filesDirs, filter=None): + """ + Utility function to filter the input list of files/dirs based on the input + filter method and returns all the files that need to be checked + """ + allFiles = [] + for f in filesDirs: + if os.path.isfile(f): + if filter is None or filter(f): + allFiles.append(f) + elif os.path.isdir(f): + files = listAllFilesInDir(f) + for f_ in files: + if filter is None or filter(f_): + allFiles.append(f_) + return allFiles diff --git a/ci/checks/style.sh b/ci/checks/style.sh index 9fb86b0b3c5..a7ad260758d 100755 --- a/ci/checks/style.sh +++ b/ci/checks/style.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright (c) 2018-2021, NVIDIA CORPORATION. +# Copyright (c) 2018-2022, NVIDIA CORPORATION. ##################### # cuDF Style Tester # ##################### @@ -19,9 +19,24 @@ export RAPIDS_CMAKE_FORMAT_FILE=/tmp/rapids_cmake_ci/cmake-formats-rapids-cmake. mkdir -p $(dirname ${RAPIDS_CMAKE_FORMAT_FILE}) wget -O ${RAPIDS_CMAKE_FORMAT_FILE} ${FORMAT_FILE_URL} + pre-commit run --hook-stage manual --all-files PRE_COMMIT_RETVAL=$? +# Check for copyright headers in the files modified currently +COPYRIGHT=`python ci/checks/copyright.py --git-modified-only 2>&1` +CR_RETVAL=$? + +# Output results if failure otherwise show pass +if [ "$CR_RETVAL" != "0" ]; then + echo -e "\n\n>>>> FAILED: copyright check; begin output\n\n" + echo -e "$COPYRIGHT" + echo -e "\n\n>>>> FAILED: copyright check; end output\n\n" +else + echo -e "\n\n>>>> PASSED: copyright check\n\n" + echo -e "$COPYRIGHT" +fi + # Run clang-format and check for a consistent code format CLANG_FORMAT=`python cpp/scripts/run-clang-format.py 2>&1` CLANG_FORMAT_RETVAL=$? @@ -40,7 +55,7 @@ HEADER_META_RETVAL=$? echo -e "$HEADER_META" RETVALS=( - $PRE_COMMIT_RETVAL $CLANG_FORMAT_RETVAL + $CR_RETVAL $PRE_COMMIT_RETVAL $CLANG_FORMAT_RETVAL ) IFS=$'\n' RETVAL=`echo "${RETVALS[*]}" | sort -nr | head -n1` diff --git a/conda/environments/cudf_dev_cuda11.5.yml b/conda/environments/cudf_dev_cuda11.5.yml index b9577d937d9..b926a6cdc99 100644 --- a/conda/environments/cudf_dev_cuda11.5.yml +++ b/conda/environments/cudf_dev_cuda11.5.yml @@ -1,4 +1,4 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. +# Copyright (c) 2021-2022, NVIDIA CORPORATION. name: cudf_dev channels: @@ -17,7 +17,7 @@ dependencies: - numba>=0.54 - numpy - pandas>=1.0,<1.4.0dev0 - - pyarrow=5.0.0=*cuda + - pyarrow=6.0.1=*cuda - fastavro>=0.22.9 - python-snappy>=0.6.0 - notebook>=0.5.0 @@ -45,7 +45,7 @@ dependencies: - dask>=2021.11.1,<=2022.01.0 - distributed>=2021.11.1,<=2022.01.0 - streamz - - arrow-cpp=5.0.0 + - arrow-cpp=6.0.1 - dlpack>=0.5,<0.6.0a0 - arrow-cpp-proc * cuda - double-conversion diff --git a/conda/recipes/cudf/meta.yaml b/conda/recipes/cudf/meta.yaml index bd1412bc611..0145e2e4d01 100644 --- a/conda/recipes/cudf/meta.yaml +++ b/conda/recipes/cudf/meta.yaml @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2021, NVIDIA CORPORATION. +# Copyright (c) 2018-2022, NVIDIA CORPORATION. {% set version = environ.get('GIT_DESCRIBE_TAG', '0.0.0.dev').lstrip('v') + environ.get('VERSION_SUFFIX', '') %} {% set minor_version = version.split('.')[0] + '.' + version.split('.')[1] %} @@ -31,7 +31,7 @@ requirements: - setuptools - numba >=0.54 - dlpack>=0.5,<0.6.0a0 - - pyarrow 5.0.0 *cuda + - pyarrow 6.0.1 *cuda - libcudf {{ version }} - rmm {{ minor_version }} - cudatoolkit {{ cuda_version }} diff --git a/conda/recipes/libcudf/meta.yaml b/conda/recipes/libcudf/meta.yaml index 70c020d4abd..57205f22e57 100644 --- a/conda/recipes/libcudf/meta.yaml +++ b/conda/recipes/libcudf/meta.yaml @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2021, NVIDIA CORPORATION. +# Copyright (c) 2018-2022, NVIDIA CORPORATION. {% set version = environ.get('GIT_DESCRIBE_TAG', '0.0.0.dev').lstrip('v') + environ.get('VERSION_SUFFIX', '') %} {% set minor_version = version.split('.')[0] + '.' + version.split('.')[1] %} @@ -40,7 +40,7 @@ requirements: host: - librmm {{ minor_version }}.* - cudatoolkit {{ cuda_version }}.* - - arrow-cpp 5.0.0 *cuda + - arrow-cpp 6.0.1 *cuda - arrow-cpp-proc * cuda - dlpack>=0.5,<0.6.0a0 run: diff --git a/cpp/cmake/thirdparty/get_arrow.cmake b/cpp/cmake/thirdparty/get_arrow.cmake index ae1448da502..83c5e4c3e8f 100644 --- a/cpp/cmake/thirdparty/get_arrow.cmake +++ b/cpp/cmake/thirdparty/get_arrow.cmake @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2020-2021, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -308,7 +308,7 @@ function(find_and_configure_arrow VERSION BUILD_STATIC ENABLE_S3 ENABLE_ORC ENAB endfunction() -set(CUDF_VERSION_Arrow 5.0.0) +set(CUDF_VERSION_Arrow 6.0.1) find_and_configure_arrow( ${CUDF_VERSION_Arrow} ${CUDF_USE_ARROW_STATIC} ${CUDF_ENABLE_ARROW_S3} ${CUDF_ENABLE_ARROW_ORC} diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index 2c3951c0e5e..e5669fd44aa 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -1167,6 +1167,10 @@ def corr(self, other: ColumnBase): def nans_to_nulls(self: T) -> T: return self + @property + def contains_na_entries(self) -> bool: + return self.null_count != 0 + def _process_for_reduction( self, skipna: bool = None, min_count: int = 0 ) -> Union[ColumnBase, ScalarLike]: @@ -2093,24 +2097,16 @@ def as_column( dtype = "bool" np_type = np.dtype(dtype).type pa_type = np_to_pa_dtype(np.dtype(dtype)) - # TODO: A warning is emitted from pyarrow 5.0.0's function - # pyarrow.lib._sequence_to_array: - # "DeprecationWarning: an integer is required (got type float). - # Implicit conversion to integers using __int__ is deprecated, - # and may be removed in a future version of Python." - # This warning does not appear in pyarrow 6.0.1 and will be - # resolved by https://github.com/rapidsai/cudf/pull/9686/. - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - pa_array = pa.array( + data = as_column( + pa.array( arbitrary, type=pa_type, from_pandas=True if nan_as_null is None else nan_as_null, - ) - data = as_column( - pa_array, dtype=dtype, nan_as_null=nan_as_null, + ), + dtype=dtype, + nan_as_null=nan_as_null, ) except (pa.ArrowInvalid, pa.ArrowTypeError, TypeError): if is_categorical_dtype(dtype): diff --git a/python/cudf/cudf/core/column/numerical.py b/python/cudf/cudf/core/column/numerical.py index 9b54c4d9acd..0e32c530e69 100644 --- a/python/cudf/cudf/core/column/numerical.py +++ b/python/cudf/cudf/core/column/numerical.py @@ -325,6 +325,10 @@ def nan_count(self) -> int: self._nan_count = nan_col.sum() return self._nan_count + @property + def contains_na_entries(self) -> bool: + return (self.nan_count != 0) or (self.null_count != 0) + def _process_values_for_isin( self, values: Sequence ) -> Tuple[ColumnBase, ColumnBase]: diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py index 1cb2d00d07d..c3c79135c34 100644 --- a/python/cudf/cudf/core/dataframe.py +++ b/python/cudf/cudf/core/dataframe.py @@ -4665,10 +4665,17 @@ def to_arrow(self, preserve_index=True): a: int64 b: int64 index: int64 + ---- + a: [[1,2,3]] + b: [[4,5,6]] + index: [[1,2,3]] >>> df.to_arrow(preserve_index=False) pyarrow.Table a: int64 b: int64 + ---- + a: [[1,2,3]] + b: [[4,5,6]] """ data = self.copy(deep=False) diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py index 4cc722eda4b..2d073d46891 100644 --- a/python/cudf/cudf/core/frame.py +++ b/python/cudf/cudf/core/frame.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2021, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. from __future__ import annotations @@ -1266,7 +1266,7 @@ def fillna( value = value elif not isinstance(value, abc.Mapping): value = {name: copy.deepcopy(value) for name in self._data.names} - elif isinstance(value, abc.Mapping): + else: value = { key: value.reindex(self.index) if isinstance(value, cudf.Series) @@ -1274,18 +1274,26 @@ def fillna( for key, value in value.items() } - copy_data = self._data.copy(deep=True) - - for name in copy_data.keys(): + filled_data = {} + for col_name, col in self._data.items(): + if col_name in value and method is None: + replace_val = value[col_name] + else: + replace_val = None should_fill = ( - name in value - and not libcudf.scalar._is_null_host_scalar(value[name]) + col_name in value + and col.contains_na_entries + and not libcudf.scalar._is_null_host_scalar(replace_val) ) or method is not None if should_fill: - copy_data[name] = copy_data[name].fillna(value[name], method) - result = self._from_data(copy_data, self._index) + filled_data[col_name] = col.fillna(replace_val, method) + else: + filled_data[col_name] = col.copy(deep=True) - return self._mimic_inplace(result, inplace=inplace) + return self._mimic_inplace( + self._from_data(data=filled_data, index=self._index), + inplace=inplace, + ) @annotate("FRAME_DROPNA_COLUMNS", color="green", domain="cudf_python") def _drop_na_columns(self, how="any", subset=None, thresh=None): @@ -1950,6 +1958,10 @@ def to_arrow(self): a: int64 b: int64 index: int64 + ---- + a: [[1,2,3]] + b: [[4,5,6]] + index: [[1,2,3]] """ return pa.Table.from_pydict( {name: col.to_arrow() for name, col in self._data.items()}