Backport patching functionality from 24.06 (#41)
This PR backports #27, #37, and #39 to 24.04

---------

Signed-off-by: Vyas Ramasubramani <[email protected]>
Co-authored-by: Richard (Rick) Zamora <[email protected]>
Co-authored-by: Bradley Dice <[email protected]>
3 people authored Apr 5, 2024
1 parent 3d6efa0 commit 99d37eb
Showing 22 changed files with 803 additions and 14 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -13,3 +13,4 @@ build/
 wheels/
 *.egg-info/
 *.egg
+*.whl
24 changes: 24 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,24 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

# Always exclude vendored files from linting
exclude: ".*__rdd_patch_.*"

repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
  - repo: https://github.com/codespell-project/codespell
    rev: v2.2.6
    hooks:
      - id: codespell
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.2.2
    hooks:
      - id: ruff
        args: ["--fix"]
      - id: ruff-format

default_language_version:
  python: python3
10 changes: 9 additions & 1 deletion README.md
@@ -8,7 +8,7 @@ The `rapids-dask-dependency` package encodes both `dask` and `distributed` requi
 # Versioning the Metapackage Itself

 This package is versioned just like the rest of RAPIDS: using CalVer, with alpha tags (trailing a\*) for nightlies.
-Nightlies of the metapackage should be consumed just like nightlies of any other RAPIDS package:
+Nightlies of the metapackage should be consumed just like nightlies of any other RAPIDS package:
 - conda packages should pin up to the minor version with a trailing `.*`, i.e. `==23.10.*`. Conda will allow nightlies to match, so no further intervention is needed.
 - pip packages should have the same pin, but wheel building scripts must add an alpha spec `>=0.0.0a0` when building nightlies to allow rapids-dask-dependency nightlies. This is the same strategy used to have RAPIDS repositories pull nightly versions of other RAPIDS dependencies (e.g. `cudf` requires `rmm` nightlies).

@@ -31,3 +31,11 @@ At release, these dependencies will be pinned to the desired versions.
 Note that encoding direct URLs as above is technically prohibited by the [Python packaging specifications](https://packaging.python.org/en/latest/specifications/version-specifiers/#direct-references).
 However, while PyPI enforces this, the RAPIDS nightly index does not.
 Therefore, use of this versioning strategy currently prohibits rapids-dask-dependency nightlies from being uploaded to PyPI, and they must be hosted on the RAPIDS nightly pip index.
+
+# Patching
+
+In addition to functioning as a metapackage, `rapids-dask-dependency` also includes code for patching dask itself.
+This package is never intended to be manually imported by the user.
+Instead, upon installation it installs a `.pth` file (see the [site module documentation](https://docs.python.org/3.11/library/site.html) for how these work) that will be run whenever the Python interpreter starts.
+This file installs a custom [meta path loader](https://docs.python.org/3/reference/import.html#the-meta-path) that intercepts all calls to import dask modules.
+This loader is set up to apply RAPIDS-specific patches to the modules, ensuring that regardless of import order issues dask modules will always be patched for RAPIDS-compatibility in environments where RAPIDS packages are installed.
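
The meta path mechanism described in the README text above can be sketched in isolation. This is a minimal illustration under stated assumptions, not the package's implementation: `RecordingFinder` is a hypothetical name, and the standard-library `colorsys` module stands in for dask. The finder only observes import requests and returns `None`, so the regular import machinery still loads the module.

```python
import importlib
import importlib.abc
import sys


class RecordingFinder(importlib.abc.MetaPathFinder):
    """Hypothetical finder that records import requests for one package."""

    def __init__(self, prefix):
        self.prefix = prefix
        self.seen = []

    def find_spec(self, fullname, path, target=None):
        # Match the package itself or any of its submodules.
        if fullname == self.prefix or fullname.startswith(self.prefix + "."):
            self.seen.append(fullname)
        # Returning None defers to the remaining finders on sys.meta_path.
        return None


finder = RecordingFinder("colorsys")
# Position 0 means this finder is consulted before the default finders.
sys.meta_path.insert(0, finder)
try:
    # Drop any cached copy so the import system actually runs the finders.
    sys.modules.pop("colorsys", None)
    importlib.import_module("colorsys")
finally:
    sys.meta_path.remove(finder)
```

A finder that returns a `ModuleSpec` instead of `None` takes over loading, which is the step a patching loader relies on.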
1 change: 1 addition & 0 deletions _rapids_dask_dependency.pth
@@ -0,0 +1 @@
import rapids_dask_dependency
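
As a rough illustration of why a one-line `.pth` file is enough to trigger the package import, the sketch below reproduces the mechanism in a temporary directory. The names `pth_demo_hook` and `pth_demo.pth` are hypothetical stand-ins; `site.addsitedir` is used here to process the directory the same way the interpreter processes site-packages at startup.

```python
import os
import site
import sys
import tempfile

with tempfile.TemporaryDirectory() as site_dir:
    # Hypothetical module standing in for rapids_dask_dependency.
    with open(os.path.join(site_dir, "pth_demo_hook.py"), "w") as f:
        f.write("LOADED = True\n")

    # A .pth line beginning with "import" is executed rather than being
    # treated as a path entry.
    with open(os.path.join(site_dir, "pth_demo.pth"), "w") as f:
        f.write("import pth_demo_hook\n")

    # addsitedir appends the directory to sys.path and processes its .pth files.
    site.addsitedir(site_dir)
    hook_loaded = "pth_demo_hook" in sys.modules

    # Undo the sys.path change before the directory is deleted.
    if site_dir in sys.path:
        sys.path.remove(site_dir)
```

Because the import runs as a side effect of interpreter startup, any code executed by the imported package (here, nothing; in the commit, `DaskLoader.install()`) takes effect before user code runs.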
15 changes: 8 additions & 7 deletions ci/build_wheel.sh
@@ -1,18 +1,19 @@
 #!/bin/bash
-# Copyright (c) 2023, NVIDIA CORPORATION.
+# Copyright (c) 2023-2024, NVIDIA CORPORATION.

 set -euo pipefail

 source rapids-configure-sccache
 source rapids-date-string

-package_name=rapids-dask-dependency
-package_dir="pip/${package_name}"
 version=$(rapids-generate-version)

-sed -i "s/^version = .*/version = \"${version}\"/g" "${package_dir}/pyproject.toml"
+sed -i "s/^version = .*/version = \"${version}\"/g" "pyproject.toml"

-cd "${package_dir}"
-python -m pip wheel . -w dist -vvv --no-deps --disable-pip-version-check
+python -m pip wheel . -w dist -vv --no-deps --disable-pip-version-check

-RAPIDS_PY_WHEEL_NAME="${package_name}" RAPIDS_PY_WHEEL_PURE="1" rapids-upload-wheels-to-s3 dist
+RAPIDS_PY_WHEEL_NAME="rapids-dask-dependency" RAPIDS_PY_WHEEL_PURE="1" rapids-upload-wheels-to-s3 dist
+
+# Run tests
+python -m pip install $(ls dist/*.whl)[test]
+python -m pytest -v tests/
19 changes: 16 additions & 3 deletions conda/recipes/rapids-dask-dependency/meta.yaml
@@ -11,9 +11,22 @@ source:

 build:
   number: 0
-  noarch: generic
+  noarch: python
+  script: python -m pip install . -vv --no-deps
+
+test:
+  requires:
+    - pip
+    - pytest
+  source_files:
+    - tests/

 requirements:
+  host:
+    - pip
+    - python >=3.9
+    - setuptools
+    - conda-verify
   run:
     - dask ==2024.1.1
     - dask-core ==2024.1.1
@@ -29,7 +42,7 @@ about:
     This metapackage encodes the standard Dask version pinning used for a
     particular release of RAPIDS. The metapackage adds an extra release segment
     to the RAPIDS CalVer to allow pinnings in this metapackage to be updated
-    for a given release and automatically propagate to its dependents.
+    for a given release and automatically propagate to its dependents. It also
+    includes any patches to dask required for RAPIDS to function correctly.
   doc_url: https://docs.rapids.ai/
   dev_url: https://github.com/rapidsai/rapids_dask_dependency

4 changes: 4 additions & 0 deletions conda/recipes/rapids-dask-dependency/run_test.sh
@@ -0,0 +1,4 @@
#!/bin/bash
# Copyright (c) 2024, NVIDIA CORPORATION.

python -m pytest -v tests/
1 change: 1 addition & 0 deletions conda/recipes/rapids-dask-dependency/tests
1 change: 0 additions & 1 deletion pip/rapids-dask-dependency/LICENSE

This file was deleted.

1 change: 0 additions & 1 deletion pip/rapids-dask-dependency/README.md

This file was deleted.

14 changes: 13 additions & 1 deletion pip/rapids-dask-dependency/pyproject.toml → pyproject.toml
@@ -1,4 +1,4 @@
-# Copyright (c) 2023, NVIDIA CORPORATION.
+# Copyright (c) 2023-2024, NVIDIA CORPORATION.

 [build-system]
 build-backend = "setuptools.build_meta"
@@ -19,5 +19,17 @@ dependencies = [
 license = { text = "Apache 2.0" }
 readme = { file = "README.md", content-type = "text/markdown" }

+[project.optional-dependencies]
+test = [
+    "pytest",
+]
+
 [tool.setuptools]
 license-files = ["LICENSE"]
+
+[tool.setuptools.packages.find]
+include = ["rapids_dask_dependency*"]
+
+[tool.ruff]
+lint.select = ["E", "F", "W", "I", "N", "UP"]
+lint.fixable = ["ALL"]
5 changes: 5 additions & 0 deletions rapids_dask_dependency/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from .dask_loader import DaskLoader

DaskLoader.install()
82 changes: 82 additions & 0 deletions rapids_dask_dependency/dask_loader.py
@@ -0,0 +1,82 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

import importlib
import importlib.abc
import importlib.machinery
import importlib.util
import sys
from contextlib import contextmanager

from rapids_dask_dependency.utils import patch_warning_stacklevel, update_spec


class DaskLoader(importlib.abc.MetaPathFinder, importlib.abc.Loader):
    def __init__(self):
        self._blocklist = set()

    def create_module(self, spec):
        if spec.name.startswith("dask") or spec.name.startswith("distributed"):
            with self.disable(spec.name):
                try:
                    # Absolute import is important here to avoid shadowing the real dask
                    # and distributed modules in sys.modules. Bad things will happen if
                    # we use relative imports here.
                    proxy = importlib.import_module(
                        f"rapids_dask_dependency.patches.{spec.name}"
                    )
                    if hasattr(proxy, "load_module"):
                        return proxy.load_module(spec)
                except ModuleNotFoundError:
                    pass

                # Three extra stack frames: 1) DaskLoader.create_module,
                # 2) importlib.import_module, and 3) the patched warnings function (not
                # including the internal frames, which warnings ignores).
                with patch_warning_stacklevel(3):
                    mod = importlib.import_module(spec.name)

                update_spec(spec, mod.__spec__)
                return mod

    def exec_module(self, _):
        pass

    @contextmanager
    def disable(self, name):
        # This is a context manager that prevents this finder from intercepting calls to
        # import a specific name. We must do this to avoid infinite recursion when
        # calling import_module in create_module. However, we cannot blanket disable the
        # finder because that causes it to be bypassed when transitive imports occur
        # within import_module.
        try:
            self._blocklist.add(name)
            yield
        finally:
            self._blocklist.remove(name)

    def find_spec(self, fullname: str, _, __=None):
        if fullname in self._blocklist:
            return None
        if (
            fullname in ("dask", "distributed")
            or fullname.startswith("dask.")
            or fullname.startswith("distributed.")
        ):
            return importlib.machinery.ModuleSpec(
                name=fullname,
                loader=self,
                # Set these parameters dynamically in create_module
                origin=None,
                loader_state=None,
                is_package=True,
            )
        return None

    @classmethod
    def install(cls):
        try:
            (self,) = (obj for obj in sys.meta_path if isinstance(obj, cls))
        except ValueError:
            self = cls()
            sys.meta_path.insert(0, self)
        return self
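
The per-name blocklist in `disable` is the part of this file that is easiest to misread, so here is a reduced sketch of the same pattern. It is an illustration only: `PatchingFinder` and `add_marker` are hypothetical names, the standard-library `colorsys` module stands in for dask, and the sketch omits the `update_spec` and `patch_warning_stacklevel` steps because their definitions are not part of this diff.

```python
import importlib
import importlib.abc
import importlib.machinery
import sys
from contextlib import contextmanager


class PatchingFinder(importlib.abc.MetaPathFinder, importlib.abc.Loader):
    """Hypothetical finder/loader that patches a single module on import."""

    def __init__(self, target, patch_func):
        self.target = target
        self.patch_func = patch_func
        self._blocklist = set()

    @contextmanager
    def disable(self, name):
        # Block only this name, so the nested import below bypasses this
        # finder while unrelated imports would still be intercepted.
        self._blocklist.add(name)
        try:
            yield
        finally:
            self._blocklist.discard(name)

    def find_spec(self, fullname, path, target=None):
        if fullname in self._blocklist or fullname != self.target:
            return None
        return importlib.machinery.ModuleSpec(fullname, self)

    def create_module(self, spec):
        # Without disable(), this import would re-enter find_spec and recurse.
        with self.disable(spec.name):
            mod = importlib.import_module(spec.name)
        self.patch_func(mod)
        return mod

    def exec_module(self, module):
        # The real module was already executed inside create_module.
        pass


def add_marker(mod):
    mod._patched_by_sketch = True


finder = PatchingFinder("colorsys", add_marker)
sys.meta_path.insert(0, finder)
try:
    sys.modules.pop("colorsys", None)
    patched = importlib.import_module("colorsys")
finally:
    sys.meta_path.remove(finder)
```

One consequence of this design is that the import system assigns the placeholder spec to the returned module's `__spec__`, which is presumably why the committed code copies fields from the real spec with `update_spec` before returning.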
55 changes: 55 additions & 0 deletions rapids_dask_dependency/importer.py
@@ -0,0 +1,55 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

import importlib
import importlib.util
from abc import abstractmethod

from rapids_dask_dependency.utils import patch_warning_stacklevel, update_spec


class BaseImporter:
    @abstractmethod
    def load_module(self, spec):
        pass


class MonkeyPatchImporter(BaseImporter):
    """The base importer for modules that are monkey-patched."""

    def __init__(self, name, patch_func):
        self.name = name.replace("rapids_dask_dependency.patches.", "")
        self.patch_func = patch_func

    def load_module(self, spec):
        # Four extra stack frames: 1) DaskLoader.create_module, 2)
        # MonkeyPatchImporter.load_module, 3) importlib.import_module, and 4) the
        # patched warnings function (not including the internal frames, which warnings
        # ignores).
        with patch_warning_stacklevel(4):
            mod = importlib.import_module(self.name)
        self.patch_func(mod)
        update_spec(spec, mod.__spec__)
        mod._rapids_patched = True
        return mod


class VendoredImporter(BaseImporter):
    """The base importer for vendored modules."""

    # Vendored files use a standard prefix to avoid name collisions.
    default_prefix = "__rdd_patch_"

    def __init__(self, module):
        self.real_module_name = module.replace("rapids_dask_dependency.patches.", "")
        module_parts = module.split(".")
        module_parts[-1] = self.default_prefix + module_parts[-1]
        self.vendored_module_name = ".".join(module_parts)

    def load_module(self, spec):
        vendored_module = importlib.import_module(self.vendored_module_name)
        # At this stage the module loader must have been disabled for this module, so we
        # can access the original module. We don't want to actually import it, we just
        # want enough information on it to update the spec.
        original_spec = importlib.util.find_spec(self.real_module_name)
        update_spec(spec, original_spec)
        return vendored_module
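
The name handling in `VendoredImporter.__init__` can be checked on its own. The sketch below repeats the same string operations in a standalone function; `split_names` and the module name `dask.example` are hypothetical and used only for illustration.

```python
PATCH_ROOT = "rapids_dask_dependency.patches."
VENDOR_PREFIX = "__rdd_patch_"


def split_names(patch_module_name):
    """Return (real module name, vendored module name) for a patch module."""
    # Stripping the patch root recovers the module being replaced.
    real = patch_module_name.replace(PATCH_ROOT, "")
    # Prefixing only the last component points at the vendored file that
    # sits next to the patch module.
    parts = patch_module_name.split(".")
    parts[-1] = VENDOR_PREFIX + parts[-1]
    return real, ".".join(parts)


real, vendored = split_names("rapids_dask_dependency.patches.dask.example")
```

The `__rdd_patch_` prefix is also what the `exclude` pattern in `.pre-commit-config.yaml` matches, so vendored files are skipped by the linters.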
1 change: 1 addition & 0 deletions rapids_dask_dependency/patches/__init__.py
@@ -0,0 +1 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
6 changes: 6 additions & 0 deletions rapids_dask_dependency/patches/dask/__init__.py
@@ -0,0 +1,6 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from rapids_dask_dependency.importer import MonkeyPatchImporter

_importer = MonkeyPatchImporter(__name__, lambda _: None)
load_module = _importer.load_module
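
The patch function in this file is a no-op (`lambda _: None`), so the only visible effect is the `_rapids_patched` marker set by `MonkeyPatchImporter`. For comparison, the sketch below shows what a non-trivial patch function could look like. Everything in it is hypothetical: `make_patch`, `doubled`, and the stand-in module `fake_dask_module` are illustrative names, not part of this commit or of dask.

```python
import functools
import types


def make_patch(attr, wrapper_factory):
    """Build a patch function that wraps one attribute of a module."""

    def patch(mod):
        original = getattr(mod, attr)
        # functools.wraps keeps the original name and docstring on the wrapper.
        setattr(mod, attr, functools.wraps(original)(wrapper_factory(original)))

    return patch


def doubled(original):
    def wrapper(x):
        return 2 * original(x)

    return wrapper


# Stand-in for a real module, so the sketch runs without dask installed.
fake = types.ModuleType("fake_dask_module")
fake.compute = lambda x: x + 1

patch = make_patch("compute", doubled)
patch(fake)
result = fake.compute(3)  # 2 * (3 + 1)
```

A function shaped like `patch` takes the imported module as its only argument, which matches how `self.patch_func(mod)` is called in `MonkeyPatchImporter.load_module`.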