diff --git a/.gitignore b/.gitignore
index 23b7d09..cb7efbf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -13,3 +13,4 @@ build/
 wheels/
 *.egg-info/
 *.egg
+*.whl
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..d666ad1
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,24 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.
+
+# Always exclude vendored files from linting
+exclude: ".*__rdd_patch_.*"
+
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.5.0
+    hooks:
+      - id: trailing-whitespace
+      - id: end-of-file-fixer
+  - repo: https://github.com/codespell-project/codespell
+    rev: v2.2.6
+    hooks:
+      - id: codespell
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.2.2
+    hooks:
+      - id: ruff
+        args: ["--fix"]
+      - id: ruff-format
+
+default_language_version:
+  python: python3
diff --git a/README.md b/README.md
index dc0ebc5..ed8bb75 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,7 @@ The `rapids-dask-dependency` package encodes both `dask` and `distributed` requi
 # Versioning the Metapackage Itself
 
 This package is versioned just like the rest of RAPIDS: using CalVer, with alpha tags (trailing a\*) for nightlies.
-Nightlies of the metapackage should be consumed just like nightlies of any other RAPIDS package: 
+Nightlies of the metapackage should be consumed just like nightlies of any other RAPIDS package:
 - conda packages should pin up to the minor version with a trailing `.*`, i.e. `==23.10.*`. Conda will allow nightlies to match, so no further intervention is needed.
 - pip packages should have the same pin, but wheel building scripts must add an alpha spec `>=0.0.0a0` when building nightlies to allow rapids-dask-dependency nightlies. This is the same strategy used to have RAPIDS repositories pull nightly versions of other RAPIDS dependencies (e.g. `cudf` requires `rmm` nightlies).
 
@@ -31,3 +31,11 @@
 At release, these dependencies will be pinned to the desired versions.
 Note that encoding direct URLs as above is technically prohibited by the [Python packaging specifications](https://packaging.python.org/en/latest/specifications/version-specifiers/#direct-references).
 However, while PyPI enforces this, the RAPIDS nightly index does not. Therefore, use of this versioning strategy currently prohibits rapids-dask-dependency nightlies from being uploaded to PyPI, and they must be hosted on the RAPIDS nightly pip index.
+
+# Patching
+
+In addition to functioning as a metapackage, `rapids-dask-dependency` also includes code for patching dask itself.
+This package is not intended to be imported manually by the user.
+Instead, upon installation it installs a `.pth` file (see the [site module documentation](https://docs.python.org/3.11/library/site.html) for how these work) that is run whenever the Python interpreter starts.
+This file installs a custom [meta path loader](https://docs.python.org/3/reference/import.html#the-meta-path) that intercepts all calls to import dask modules.
+This loader applies RAPIDS-specific patches to those modules, ensuring that, regardless of import order, dask modules are always patched for RAPIDS compatibility in environments where RAPIDS packages are installed.
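The mechanism the new README section describes can be shown in miniature. The sketch below is illustrative only — it is not this package's code (the real implementation is `rapids_dask_dependency/dask_loader.py`, later in this diff), and the choice of `colorsys` as the intercepted module is arbitrary. A `.pth` file in `site-packages` whose line begins with `import` is executed by the `site` module at every interpreter startup, which is what gives the package a chance to register its finder before any user code runs:

```python
# What a .pth-registered package can do: install a meta path finder that
# intercepts an import, delegates the real loading, then patches the result.
import importlib
import importlib.abc
import importlib.machinery
import sys


class ToyFinder(importlib.abc.MetaPathFinder, importlib.abc.Loader):
    """Intercept imports of `colorsys` and tag the module after loading."""

    def find_spec(self, fullname, path, target=None):
        if fullname != "colorsys":
            return None
        return importlib.machinery.ModuleSpec(name=fullname, loader=self)

    def create_module(self, spec):
        # Step aside while delegating to the regular import machinery;
        # otherwise find_spec would hand the import right back to us forever.
        sys.meta_path.remove(self)
        try:
            mod = importlib.import_module(spec.name)
        finally:
            sys.meta_path.insert(0, self)
        mod._toy_patched = True  # a real finder would apply patches here
        return mod

    def exec_module(self, module):
        pass  # the module body already ran inside import_module


sys.meta_path.insert(0, ToyFinder())

import colorsys  # noqa: E402  # resolved through ToyFinder

assert colorsys._toy_patched
```

Unlike this toy, the real `DaskLoader` below disables itself per-name with a blocklist rather than removing itself from `sys.meta_path`, so that transitive dask imports triggered while one module loads are still intercepted.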
diff --git a/_rapids_dask_dependency.pth b/_rapids_dask_dependency.pth
new file mode 100644
index 0000000..33349fb
--- /dev/null
+++ b/_rapids_dask_dependency.pth
@@ -0,0 +1 @@
+import rapids_dask_dependency
diff --git a/ci/build_wheel.sh b/ci/build_wheel.sh
index 9c17c01..0fc36fb 100755
--- a/ci/build_wheel.sh
+++ b/ci/build_wheel.sh
@@ -1,18 +1,19 @@
 #!/bin/bash
-# Copyright (c) 2023, NVIDIA CORPORATION.
+# Copyright (c) 2023-2024, NVIDIA CORPORATION.
 
 set -euo pipefail
 
 source rapids-configure-sccache
 source rapids-date-string
 
-package_name=rapids-dask-dependency
-package_dir="pip/${package_name}"
-
 version=$(rapids-generate-version)
 
-sed -i "s/^version = .*/version = \"${version}\"/g" "${package_dir}/pyproject.toml"
+sed -i "s/^version = .*/version = \"${version}\"/g" "pyproject.toml"
 
-cd "${package_dir}"
-python -m pip wheel . -w dist -vvv --no-deps --disable-pip-version-check
+python -m pip wheel . -w dist -vv --no-deps --disable-pip-version-check
 
-RAPIDS_PY_WHEEL_NAME="${package_name}" RAPIDS_PY_WHEEL_PURE="1" rapids-upload-wheels-to-s3 dist
+RAPIDS_PY_WHEEL_NAME="rapids-dask-dependency" RAPIDS_PY_WHEEL_PURE="1" rapids-upload-wheels-to-s3 dist
+
+# Run tests
+python -m pip install $(ls dist/*.whl)[test]
+python -m pytest -v tests/
diff --git a/conda/recipes/rapids-dask-dependency/meta.yaml b/conda/recipes/rapids-dask-dependency/meta.yaml
index 68c56e2..3da9f2c 100644
--- a/conda/recipes/rapids-dask-dependency/meta.yaml
+++ b/conda/recipes/rapids-dask-dependency/meta.yaml
@@ -11,9 +11,22 @@ source:
 
 build:
   number: 0
-  noarch: generic
+  noarch: python
+  script: python -m pip install . -vv --no-deps
+
+test:
+  requires:
+    - pip
+    - pytest
+  source_files:
+    - tests/
 
 requirements:
+  host:
+    - pip
+    - python >=3.9
+    - setuptools
+    - conda-verify
   run:
     - dask ==2024.1.1
     - dask-core ==2024.1.1
@@ -29,7 +42,7 @@ about:
     This metapackage encodes the standard Dask version pinning used for a
     particular release of RAPIDS. The metapackage adds an extra release
    segment to the RAPIDS CalVer to allow pinnings in this metapackage to be updated
-    for a given release and automatically propagate to its dependents.
+    for a given release and automatically propagate to its dependents. It also
+    includes any patches to dask required for RAPIDS to function correctly.
   doc_url: https://docs.rapids.ai/
   dev_url: https://github.com/rapidsai/rapids_dask_dependency
-
diff --git a/conda/recipes/rapids-dask-dependency/run_test.sh b/conda/recipes/rapids-dask-dependency/run_test.sh
new file mode 100755
index 0000000..7541a82
--- /dev/null
+++ b/conda/recipes/rapids-dask-dependency/run_test.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+# Copyright (c) 2024, NVIDIA CORPORATION.
+
+python -m pytest -v tests/
diff --git a/conda/recipes/rapids-dask-dependency/tests b/conda/recipes/rapids-dask-dependency/tests
new file mode 120000
index 0000000..d41566a
--- /dev/null
+++ b/conda/recipes/rapids-dask-dependency/tests
@@ -0,0 +1 @@
+../../../tests/
\ No newline at end of file
diff --git a/pip/rapids-dask-dependency/LICENSE b/pip/rapids-dask-dependency/LICENSE
deleted file mode 120000
index 30cff74..0000000
--- a/pip/rapids-dask-dependency/LICENSE
+++ /dev/null
@@ -1 +0,0 @@
-../../LICENSE
\ No newline at end of file
diff --git a/pip/rapids-dask-dependency/README.md b/pip/rapids-dask-dependency/README.md
deleted file mode 120000
index fe84005..0000000
--- a/pip/rapids-dask-dependency/README.md
+++ /dev/null
@@ -1 +0,0 @@
-../../README.md
\ No newline at end of file
diff --git a/pip/rapids-dask-dependency/pyproject.toml b/pyproject.toml
similarity index 64%
rename from pip/rapids-dask-dependency/pyproject.toml
rename to pyproject.toml
index 076d31e..a9510ee 100644
--- a/pip/rapids-dask-dependency/pyproject.toml
+++ b/pyproject.toml
@@ -1,4 +1,4 @@
-# Copyright (c) 2023, NVIDIA CORPORATION.
+# Copyright (c) 2023-2024, NVIDIA CORPORATION.
 
 [build-system]
 build-backend = "setuptools.build_meta"
@@ -19,5 +19,17 @@ dependencies = [
 license = { text = "Apache 2.0" }
 readme = { file = "README.md", content-type = "text/markdown" }
 
+[project.optional-dependencies]
+test = [
+    "pytest",
+]
+
 [tool.setuptools]
 license-files = ["LICENSE"]
+
+[tool.setuptools.packages.find]
+include = ["rapids_dask_dependency*"]
+
+[tool.ruff]
+lint.select = ["E", "F", "W", "I", "N", "UP"]
+lint.fixable = ["ALL"]
diff --git a/rapids_dask_dependency/__init__.py b/rapids_dask_dependency/__init__.py
new file mode 100644
index 0000000..07ae32d
--- /dev/null
+++ b/rapids_dask_dependency/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.
+
+from .dask_loader import DaskLoader
+
+DaskLoader.install()
diff --git a/rapids_dask_dependency/dask_loader.py b/rapids_dask_dependency/dask_loader.py
new file mode 100644
index 0000000..0a4aba4
--- /dev/null
+++ b/rapids_dask_dependency/dask_loader.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.
+
+import importlib
+import importlib.abc
+import importlib.machinery
+import importlib.util
+import sys
+from contextlib import contextmanager
+
+from rapids_dask_dependency.utils import patch_warning_stacklevel, update_spec
+
+
+class DaskLoader(importlib.abc.MetaPathFinder, importlib.abc.Loader):
+    def __init__(self):
+        self._blocklist = set()
+
+    def create_module(self, spec):
+        if spec.name.startswith("dask") or spec.name.startswith("distributed"):
+            with self.disable(spec.name):
+                try:
+                    # Absolute import is important here to avoid shadowing the real dask
+                    # and distributed modules in sys.modules. Bad things will happen if
+                    # we use relative imports here.
+                    proxy = importlib.import_module(
+                        f"rapids_dask_dependency.patches.{spec.name}"
+                    )
+                    if hasattr(proxy, "load_module"):
+                        return proxy.load_module(spec)
+                except ModuleNotFoundError:
+                    pass
+
+                # Three extra stack frames: 1) DaskLoader.create_module,
+                # 2) importlib.import_module, and 3) the patched warnings function (not
+                # including the internal frames, which warnings ignores).
+                with patch_warning_stacklevel(3):
+                    mod = importlib.import_module(spec.name)
+
+                update_spec(spec, mod.__spec__)
+                return mod
+
+    def exec_module(self, _):
+        pass
+
+    @contextmanager
+    def disable(self, name):
+        # This is a context manager that prevents this finder from intercepting calls to
+        # import a specific name. We must do this to avoid infinite recursion when
+        # calling import_module in create_module. However, we cannot blanket disable the
+        # finder because that causes it to be bypassed when transitive imports occur
+        # within import_module.
+        try:
+            self._blocklist.add(name)
+            yield
+        finally:
+            self._blocklist.remove(name)
+
+    def find_spec(self, fullname: str, _, __=None):
+        if fullname in self._blocklist:
+            return None
+        if (
+            fullname in ("dask", "distributed")
+            or fullname.startswith("dask.")
+            or fullname.startswith("distributed.")
+        ):
+            return importlib.machinery.ModuleSpec(
+                name=fullname,
+                loader=self,
+                # Set these parameters dynamically in create_module
+                origin=None,
+                loader_state=None,
+                is_package=True,
+            )
+        return None
+
+    @classmethod
+    def install(cls):
+        try:
+            (self,) = (obj for obj in sys.meta_path if isinstance(obj, cls))
+        except ValueError:
+            self = cls()
+            sys.meta_path.insert(0, self)
+        return self
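To make the loader's behavior concrete, here is a hedged sketch of a session in a fresh interpreter where this wheel and `dask` are both installed. None of this is public API; the `_rapids_patched` marker comes from `MonkeyPatchImporter` in the next file.

```python
import sys

import rapids_dask_dependency  # noqa: F401  # normally triggered by the .pth file
from rapids_dask_dependency.dask_loader import DaskLoader

# install() is idempotent: if a DaskLoader is already on sys.meta_path it is
# returned as-is, so running the .pth hook and importing the package manually
# cannot stack two finders.
finder = DaskLoader.install()
assert finder is DaskLoader.install()
assert sum(isinstance(f, DaskLoader) for f in sys.meta_path) == 1

import dask  # noqa: E402  # assumes dask is installed and not yet imported

# Marker set by MonkeyPatchImporter once the patch function has run.
assert dask._rapids_patched
```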
diff --git a/rapids_dask_dependency/importer.py b/rapids_dask_dependency/importer.py
new file mode 100644
index 0000000..613a808
--- /dev/null
+++ b/rapids_dask_dependency/importer.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.
+
+import importlib
+import importlib.util
+from abc import abstractmethod
+
+from rapids_dask_dependency.utils import patch_warning_stacklevel, update_spec
+
+
+class BaseImporter:
+    @abstractmethod
+    def load_module(self, spec):
+        pass
+
+
+class MonkeyPatchImporter(BaseImporter):
+    """The base importer for modules that are monkey-patched."""
+
+    def __init__(self, name, patch_func):
+        self.name = name.replace("rapids_dask_dependency.patches.", "")
+        self.patch_func = patch_func
+
+    def load_module(self, spec):
+        # Four extra stack frames: 1) DaskLoader.create_module, 2)
+        # MonkeyPatchImporter.load_module, 3) importlib.import_module, and 4) the
+        # patched warnings function (not including the internal frames, which warnings
+        # ignores).
+        with patch_warning_stacklevel(4):
+            mod = importlib.import_module(self.name)
+        self.patch_func(mod)
+        update_spec(spec, mod.__spec__)
+        mod._rapids_patched = True
+        return mod
+
+
+class VendoredImporter(BaseImporter):
+    """The base importer for vendored modules."""
+
+    # Vendored files use a standard prefix to avoid name collisions.
+    default_prefix = "__rdd_patch_"
+
+    def __init__(self, module):
+        self.real_module_name = module.replace("rapids_dask_dependency.patches.", "")
+        module_parts = module.split(".")
+        module_parts[-1] = self.default_prefix + module_parts[-1]
+        self.vendored_module_name = ".".join(module_parts)
+
+    def load_module(self, spec):
+        vendored_module = importlib.import_module(self.vendored_module_name)
+        # At this stage the module loader must have been disabled for this module, so we
+        # can access the original module. We don't want to actually import it, we just
+        # want enough information on it to update the spec.
+        original_spec = importlib.util.find_spec(self.real_module_name)
+        update_spec(spec, original_spec)
+        return vendored_module
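For illustration, this is what a new monkey-patch module would look like under this scheme. The file path and the patch body are hypothetical inventions; the real patch modules added below for `dask` and `distributed` install a no-op `lambda _: None`.

```python
# Hypothetical new file: rapids_dask_dependency/patches/dask/base.py
# (the module name and the patch are invented for illustration).
from rapids_dask_dependency.importer import MonkeyPatchImporter


def _patch(mod):
    # Receives the real, fully imported dask.base module and mutates it in
    # place; anything set here is visible to all subsequent importers.
    mod._example_rapids_marker = True


# MonkeyPatchImporter strips the "rapids_dask_dependency.patches." prefix
# from __name__ to recover the real module name ("dask.base" here).
_importer = MonkeyPatchImporter(__name__, _patch)
load_module = _importer.load_module
```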
diff --git a/rapids_dask_dependency/patches/__init__.py b/rapids_dask_dependency/patches/__init__.py
new file mode 100644
index 0000000..3c827d4
--- /dev/null
+++ b/rapids_dask_dependency/patches/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.
diff --git a/rapids_dask_dependency/patches/dask/__init__.py b/rapids_dask_dependency/patches/dask/__init__.py
new file mode 100644
index 0000000..5f7dc38
--- /dev/null
+++ b/rapids_dask_dependency/patches/dask/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.
+
+from rapids_dask_dependency.importer import MonkeyPatchImporter
+
+_importer = MonkeyPatchImporter(__name__, lambda _: None)
+load_module = _importer.load_module
diff --git a/rapids_dask_dependency/patches/dask/dataframe/__rdd_patch_accessor.py b/rapids_dask_dependency/patches/dask/dataframe/__rdd_patch_accessor.py
new file mode 100644
index 0000000..54f72cd
--- /dev/null
+++ b/rapids_dask_dependency/patches/dask/dataframe/__rdd_patch_accessor.py
@@ -0,0 +1,450 @@
+from __future__ import annotations
+
+import functools
+import warnings
+
+import numpy as np
+import pandas as pd
+
+from dask.dataframe._compat import check_to_pydatetime_deprecation
+from dask.utils import derived_from
+
+
+def _bind_method(cls, pd_cls, attr, min_version=None):
+    def func(self, *args, **kwargs):
+        return self._function_map(attr, *args, **kwargs)
+
+    func.__name__ = attr
+    func.__qualname__ = f"{cls.__name__}.{attr}"
+    try:
+        func.__wrapped__ = getattr(pd_cls, attr)
+    except Exception:
+        pass
+    setattr(cls, attr, derived_from(pd_cls, version=min_version)(func))
+
+
+def _bind_property(cls, pd_cls, attr, min_version=None):
+    def func(self):
+        return self._property_map(attr)
+
+    func.__name__ = attr
+    func.__qualname__ = f"{cls.__name__}.{attr}"
+    original_prop = getattr(pd_cls, attr)
+    if isinstance(original_prop, property):
+        method = original_prop.fget
+    elif isinstance(original_prop, functools.cached_property):
+        method = original_prop.func
+    else:
+        method = original_prop
+    try:
+        func.__wrapped__ = method
+    except Exception:
+        pass
+    setattr(cls, attr, property(derived_from(pd_cls, version=min_version)(func)))
+
+
+def maybe_wrap_pandas(obj, x):
+    if isinstance(x, np.ndarray):
+        if isinstance(obj, pd.Series):
+            return pd.Series(x, index=obj.index, dtype=x.dtype)
+        return pd.Index(x)
+    return x
+
+
+class Accessor:
+    """
+    Base class for pandas Accessor objects cat, dt, and str.
+
+    Notes
+    -----
+    Subclasses should define ``_accessor_name``, ``_accessor_methods``, and
+    ``_accessor_properties``.
+ """ + + def __init__(self, series): + from dask.dataframe.core import Series + + if not isinstance(series, Series): + raise ValueError("Accessor cannot be initialized") + + series_meta = series._meta + if hasattr(series_meta, "to_series"): # is index-like + series_meta = series_meta.to_series() + meta = getattr(series_meta, self._accessor_name) + + self._meta = meta + self._series = series + + def __init_subclass__(cls, **kwargs): + """Bind all auto-generated methods & properties""" + super().__init_subclass__(**kwargs) + pd_cls = getattr(pd.Series, cls._accessor_name) + for item in cls._accessor_methods: + attr, min_version = item if isinstance(item, tuple) else (item, None) + if not hasattr(cls, attr): + _bind_method(cls, pd_cls, attr, min_version) + for item in cls._accessor_properties: + attr, min_version = item if isinstance(item, tuple) else (item, None) + if not hasattr(cls, attr): + _bind_property(cls, pd_cls, attr, min_version) + + @staticmethod + def _delegate_property(obj, accessor, attr): + out = getattr(getattr(obj, accessor, obj), attr) + return maybe_wrap_pandas(obj, out) + + @staticmethod + def _delegate_method( + obj, accessor, attr, args, kwargs, catch_deprecation_warnings: bool = False + ): + with check_to_pydatetime_deprecation(catch_deprecation_warnings): + with warnings.catch_warnings(): + # Falling back on a non-pyarrow code path which may decrease performance + warnings.simplefilter("ignore", pd.errors.PerformanceWarning) + out = getattr(getattr(obj, accessor, obj), attr)(*args, **kwargs) + return maybe_wrap_pandas(obj, out) + + def _property_map(self, attr): + meta = self._delegate_property(self._series._meta, self._accessor_name, attr) + token = f"{self._accessor_name}-{attr}" + return self._series.map_partitions( + self._delegate_property, self._accessor_name, attr, token=token, meta=meta + ) + + def _function_map(self, attr, *args, **kwargs): + if "meta" in kwargs: + meta = kwargs.pop("meta") + else: + meta = self._delegate_method( + self._series._meta_nonempty, self._accessor_name, attr, args, kwargs + ) + token = f"{self._accessor_name}-{attr}" + return self._series.map_partitions( + self._delegate_method, + self._accessor_name, + attr, + args, + kwargs, + catch_deprecation_warnings=True, + meta=meta, + token=token, + ) + + +class DatetimeAccessor(Accessor): + """Accessor object for datetimelike properties of the Series values. 
+
+    Examples
+    --------
+
+    >>> s.dt.microsecond  # doctest: +SKIP
+    """
+
+    _accessor_name = "dt"
+
+    _accessor_methods = (
+        "asfreq",
+        "ceil",
+        "day_name",
+        "floor",
+        "month_name",
+        "normalize",
+        "round",
+        "strftime",
+        "to_period",
+        "to_pydatetime",
+        "to_pytimedelta",
+        "to_timestamp",
+        "total_seconds",
+        "tz_convert",
+        "tz_localize",
+    )
+
+    _accessor_properties = (
+        "components",
+        "date",
+        "day",
+        "day_of_week",
+        "day_of_year",
+        "dayofweek",
+        "dayofyear",
+        "days",
+        "days_in_month",
+        "daysinmonth",
+        "end_time",
+        "freq",
+        "hour",
+        "is_leap_year",
+        "is_month_end",
+        "is_month_start",
+        "is_quarter_end",
+        "is_quarter_start",
+        "is_year_end",
+        "is_year_start",
+        "microsecond",
+        "microseconds",
+        "minute",
+        "month",
+        "nanosecond",
+        "nanoseconds",
+        "quarter",
+        "qyear",
+        "second",
+        "seconds",
+        "start_time",
+        "time",
+        "timetz",
+        "tz",
+        "week",
+        "weekday",
+        "weekofyear",
+        "year",
+    )
+
+    @derived_from(pd.Series.dt)
+    def isocalendar(self):
+        # Sphinx can't solve types with dask-expr available so define explicitly, see
+        # https://github.com/sphinx-doc/sphinx/issues/4961
+        return self._function_map("isocalendar")
+
+
+class StringAccessor(Accessor):
+    """Accessor object for string properties of the Series values.
+
+    Examples
+    --------
+
+    >>> s.str.lower()  # doctest: +SKIP
+    """
+
+    _accessor_name = "str"
+
+    _accessor_methods = (
+        "capitalize",
+        "casefold",
+        "center",
+        "contains",
+        "count",
+        "decode",
+        "encode",
+        "find",
+        "findall",
+        "fullmatch",
+        "get",
+        "index",
+        "isalnum",
+        "isalpha",
+        "isdecimal",
+        "isdigit",
+        "islower",
+        "isnumeric",
+        "isspace",
+        "istitle",
+        "isupper",
+        "join",
+        "len",
+        "ljust",
+        "lower",
+        "lstrip",
+        "match",
+        "normalize",
+        "pad",
+        "partition",
+        ("removeprefix", "1.4"),
+        ("removesuffix", "1.4"),
+        "repeat",
+        "replace",
+        "rfind",
+        "rindex",
+        "rjust",
+        "rpartition",
+        "rstrip",
+        "slice",
+        "slice_replace",
+        "strip",
+        "swapcase",
+        "title",
+        "translate",
+        "upper",
+        "wrap",
+        "zfill",
+    )
+    _accessor_properties = ()
+
+    def _split(self, method, pat=None, n=-1, expand=False):
+        if expand:
+            if n == -1:
+                raise NotImplementedError(
+                    "To use the expand parameter you must specify the number of "
+                    "expected splits with the n= parameter. Usually n splits "
+                    "result in n+1 output columns."
+                )
+            else:
+                delimiter = " " if pat is None else pat
+                meta = self._series._meta._constructor(
+                    [delimiter.join(["a"] * (n + 1))],
+                    index=self._series._meta_nonempty.iloc[:1].index,
+                )
+                meta = getattr(meta.str, method)(n=n, expand=expand, pat=pat)
+        else:
+            meta = (self._series.name, object)
+        return self._function_map(method, pat=pat, n=n, expand=expand, meta=meta)
+
+    @derived_from(
+        pd.Series.str,
+        inconsistencies="``expand=True`` with unknown ``n`` will raise a ``NotImplementedError``",
+    )
+    def split(self, pat=None, n=-1, expand=False):
+        """Known inconsistencies: ``expand=True`` with unknown ``n`` will raise a ``NotImplementedError``."""
+        return self._split("split", pat=pat, n=n, expand=expand)
+
+    @derived_from(pd.Series.str)
+    def rsplit(self, pat=None, n=-1, expand=False):
+        return self._split("rsplit", pat=pat, n=n, expand=expand)
+
+    @derived_from(pd.Series.str)
+    def cat(self, others=None, sep=None, na_rep=None):
+        from dask.dataframe.core import Index, Series
+
+        if others is None:
+
+            def str_cat_none(x):
+                if isinstance(x, (Series, Index)):
+                    x = x.compute()
+
+                return x.str.cat(sep=sep, na_rep=na_rep)
+
+            return self._series.reduction(chunk=str_cat_none, aggregate=str_cat_none)
+
+        valid_types = (Series, Index, pd.Series, pd.Index)
+        if isinstance(others, valid_types):
+            others = [others]
+        elif not all(isinstance(a, valid_types) for a in others):
+            raise TypeError("others must be Series/Index")
+
+        return self._series.map_partitions(
+            str_cat, *others, sep=sep, na_rep=na_rep, meta=self._series._meta
+        )
+
+    @derived_from(pd.Series.str)
+    def extractall(self, pat, flags=0):
+        return self._series.map_partitions(
+            str_extractall, pat, flags, token="str-extractall"
+        )
+
+    def __getitem__(self, index):
+        return self._series.map_partitions(str_get, index, meta=self._series._meta)
+
+    @derived_from(pd.Series.str)
+    def extract(self, *args, **kwargs):
+        # Sphinx can't solve types with dask-expr available so define explicitly, see
+        # https://github.com/sphinx-doc/sphinx/issues/4961
+        return self._function_map("extract", *args, **kwargs)
+
+    @derived_from(pd.Series.str)
+    def startswith(self, *args, **kwargs):
+        # Sphinx can't solve types with dask-expr available so define explicitly, see
+        # https://github.com/sphinx-doc/sphinx/issues/4961
+        return self._function_map("startswith", *args, **kwargs)
+
+    @derived_from(pd.Series.str)
+    def endswith(self, *args, **kwargs):
+        # Sphinx can't solve types with dask-expr available so define explicitly, see
+        # https://github.com/sphinx-doc/sphinx/issues/4961
+        return self._function_map("endswith", *args, **kwargs)
+
+
+def str_extractall(series, pat, flags):
+    return series.str.extractall(pat, flags=flags)
+
+
+def str_get(series, index):
+    """Implements series.str[index]"""
+    return series.str[index]
+
+
+def str_cat(self, *others, **kwargs):
+    return self.str.cat(others=others, **kwargs)
+
+
+# Ported from pandas
+# https://github.com/pandas-dev/pandas/blob/master/pandas/core/accessor.py
+class CachedAccessor:
+    """
+    Custom property-like object (descriptor) for caching accessors.
+
+    Parameters
+    ----------
+    name : str
+        The namespace this will be accessed under, e.g. ``df.foo``
+    accessor : cls
+        The class with the extension methods. The class' __init__ method
+        should expect one of a ``Series``, ``DataFrame`` or ``Index`` as
+        the single argument ``data``
+    """
+
+    def __init__(self, name, accessor):
+        self._name = name
+        self._accessor = accessor
+
+    def __get__(self, obj, cls):
+        if obj is None:
+            # we're accessing the attribute of the class, i.e., Dataset.geo
+            return self._accessor
+        accessor_obj = self._accessor(obj)
+        # Replace the property with the accessor object. Inspired by:
+        # http://www.pydanny.com/cached-property.html
+        # We need to use object.__setattr__ because we overwrite __setattr__ on
+        # NDFrame
+        object.__setattr__(obj, self._name, accessor_obj)
+        return accessor_obj
+
+
+def _register_accessor(name, cls):
+    def decorator(accessor):
+        if hasattr(cls, name):
+            warnings.warn(
+                "registration of accessor {!r} under name {!r} for type "
+                "{!r} is overriding a preexisting attribute with the same "
+                "name.".format(accessor, name, cls),
+                UserWarning,
+                stacklevel=2,
+            )
+        setattr(cls, name, CachedAccessor(name, accessor))
+        cls._accessors.add(name)
+        return accessor
+
+    return decorator
+
+
+def register_dataframe_accessor(name):
+    """
+    Register a custom accessor on :class:`dask.dataframe.DataFrame`.
+
+    See :func:`pandas.api.extensions.register_dataframe_accessor` for more.
+    """
+    from dask.dataframe import DataFrame
+
+    return _register_accessor(name, DataFrame)
+
+
+def register_series_accessor(name):
+    """
+    Register a custom accessor on :class:`dask.dataframe.Series`.
+
+    See :func:`pandas.api.extensions.register_series_accessor` for more.
+    """
+    from dask.dataframe import Series
+
+    return _register_accessor(name, Series)
+
+
+def register_index_accessor(name):
+    """
+    Register a custom accessor on :class:`dask.dataframe.Index`.
+
+    See :func:`pandas.api.extensions.register_index_accessor` for more.
+    """
+    from dask.dataframe import Index
+
+    return _register_accessor(name, Index)
diff --git a/rapids_dask_dependency/patches/dask/dataframe/accessor.py b/rapids_dask_dependency/patches/dask/dataframe/accessor.py
new file mode 100644
index 0000000..1968f97
--- /dev/null
+++ b/rapids_dask_dependency/patches/dask/dataframe/accessor.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.
+import sys
+
+if sys.version_info >= (3, 11, 9):
+    from dask import __version__
+    from packaging.version import Version
+
+    if Version(__version__) < Version("2024.4.1"):
+        from rapids_dask_dependency.importer import VendoredImporter
+
+        # Currently vendoring this module due to https://github.com/dask/dask/pull/11035
+        _importer = VendoredImporter(__name__)
+        load_module = _importer.load_module
diff --git a/rapids_dask_dependency/patches/distributed/__init__.py b/rapids_dask_dependency/patches/distributed/__init__.py
new file mode 100644
index 0000000..5f7dc38
--- /dev/null
+++ b/rapids_dask_dependency/patches/distributed/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.
+
+from rapids_dask_dependency.importer import MonkeyPatchImporter
+
+_importer = MonkeyPatchImporter(__name__, lambda _: None)
+load_module = _importer.load_module
diff --git a/rapids_dask_dependency/utils.py b/rapids_dask_dependency/utils.py
new file mode 100644
index 0000000..36d35bb
--- /dev/null
+++ b/rapids_dask_dependency/utils.py
@@ -0,0 +1,38 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.
+
+import warnings
+from contextlib import contextmanager
+from functools import lru_cache
+
+original_warn = warnings.warn
+
+
+@lru_cache
+def _make_warning_func(level):
+    def _warning_with_increased_stacklevel(
+        message, category=None, stacklevel=1, source=None, **kwargs
+    ):
+        # Patch warnings to have the right stacklevel
+        original_warn(message, category, stacklevel + level, source, **kwargs)
+
+    return _warning_with_increased_stacklevel
+
+
+@contextmanager
+def patch_warning_stacklevel(level):
+    previous_warn = warnings.warn
+    warnings.warn = _make_warning_func(level)
+    try:
+        yield
+    finally:
+        # Restore the original warn even if the patched block raised, so a
+        # failed import cannot leave warnings.warn permanently patched.
+        warnings.warn = previous_warn
+
+
+# Note: The Python documentation does not make it clear whether we're guaranteed that
+# spec is not a copy of the original spec, but that is the case for now. We need to
+# assign this because the spec is used to update module attributes after it is
+# initialized by create_module.
+def update_spec(spec, original_spec):
+    spec.origin = original_spec.origin
+    spec.submodule_search_locations = original_spec.submodule_search_locations
+    return spec
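The stacklevel machinery above exists so that warnings emitted while the loader re-imports a module still point at the user's code rather than at the loader internals. A small, self-contained demonstration — the frame count here is illustrative, not one of the package's calibrated values (those are the 3 and 4 documented in `dask_loader.py` and `importer.py`):

```python
import warnings

from rapids_dask_dependency.utils import patch_warning_stacklevel


def library_code():
    # Pretend this is a module being imported: it warns with stacklevel=1.
    warnings.warn("deprecated", DeprecationWarning, stacklevel=1)


with warnings.catch_warnings(record=True) as captured:
    warnings.simplefilter("always")
    # Inside the context, warnings.warn adds 2 to every stacklevel, so the
    # reported location moves two frames up the call stack; the message and
    # category are untouched.
    with patch_warning_stacklevel(2):
        library_code()

assert len(captured) == 1
assert str(captured[0].message) == "deprecated"
```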
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..87119dc
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,21 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.
+import os
+
+from setuptools import setup
+from setuptools.command.build_py import build_py
+
+
+# Adapted from https://stackoverflow.com/a/71137790
+class build_py_with_pth_file(build_py):  # noqa: N801
+    """Include the .pth file in the generated wheel."""
+
+    def run(self):
+        super().run()
+
+        fn = "_rapids_dask_dependency.pth"
+
+        outfile = os.path.join(self.build_lib, fn)
+        self.copy_file(fn, outfile, preserve_mode=0)
+
+
+setup(cmdclass={"build_py": build_py_with_pth_file})
diff --git a/tests/test_patch.py b/tests/test_patch.py
new file mode 100644
index 0000000..4c611fd
--- /dev/null
+++ b/tests/test_patch.py
@@ -0,0 +1,51 @@
+import contextlib
+import tempfile
+from functools import wraps
+from multiprocessing import Process
+
+
+def run_test_in_subprocess(func):
+    def redirect_stdout_stderr(func, stdout, stderr, *args, **kwargs):
+        with open(stdout, "w") as stdout_file, open(stderr, "w") as stderr_file:
+            with contextlib.redirect_stdout(stdout_file), contextlib.redirect_stderr(
+                stderr_file
+            ):
+                func(*args, **kwargs)
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        with tempfile.NamedTemporaryFile(
+            mode="w+"
+        ) as stdout, tempfile.NamedTemporaryFile(mode="w+") as stderr:
+            p = Process(
+                target=redirect_stdout_stderr,
+                args=(func, stdout.name, stderr.name, *args),
+                kwargs=kwargs,
+            )
+            p.start()
+            p.join()
+            stdout_log = stdout.file.read()
+            stderr_log = stderr.file.read()
+            if p.exitcode != 0:
+                msg = f"Process exited {p.exitcode}."
+                if stdout_log:
+                    msg += f"\nstdout:\n{stdout_log}"
+                if stderr_log:
+                    msg += f"\nstderr:\n{stderr_log}"
+                raise RuntimeError(msg)
+
+    return wrapper
+
+
+@run_test_in_subprocess
+def test_dask():
+    import dask
+
+    assert hasattr(dask, "_rapids_patched")
+
+
+@run_test_in_subprocess
+def test_distributed():
+    import distributed
+
+    assert hasattr(distributed, "_rapids_patched")
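Because the meta path hook is registered at interpreter startup by the `.pth` file, each test above runs in its own subprocess; importing dask directly in the pytest process could hit an already-imported, unpatched module. The same end-to-end check can be done by hand — a sketch assuming `rapids-dask-dependency` and `dask` are installed in the current environment:

```python
# Launch a fresh interpreter (so the .pth hook runs at startup) and confirm
# the imported dask module carries the patch marker.
import subprocess
import sys

result = subprocess.run(
    [sys.executable, "-c", "import dask; print(hasattr(dask, '_rapids_patched'))"],
    capture_output=True,
    text=True,
    check=True,
)
assert result.stdout.strip() == "True"
```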