Enable patching #27

Merged 37 commits on Mar 19, 2024

Commits
fdb9cbc  Move pyproject to top-level (vyasr, Feb 26, 2024)
966d2be  Make it a Python package (vyasr, Feb 26, 2024)
7fc5fd5  Add initial version of loader (vyasr, Feb 27, 2024)
3e08e51  Rename package so that it's importable (vyasr, Feb 27, 2024)
e0606cc  Move installation into the package for testing (vyasr, Feb 27, 2024)
3cce186  Implement basic patching (vyasr, Feb 27, 2024)
1e7c735  Merge the finder and loader and import the true packages under the hood (vyasr, Feb 27, 2024)
e2fac54  Generalize to work for an arbitrary number of patches (vyasr, Feb 27, 2024)
611832f  Add comment (vyasr, Feb 27, 2024)
8afa0fc  Add pth file and install it so that the package automatically configu… (vyasr, Feb 27, 2024)
4803505  Also ignore wheel files (vyasr, Feb 27, 2024)
c48b28a  Update wheel-building CI script (vyasr, Feb 27, 2024)
896b3b8  Add pre-commit config (vyasr, Feb 27, 2024)
36fceb2  Apply linter (vyasr, Feb 27, 2024)
d86b788  Move patching to create_module (vyasr, Feb 27, 2024)
4e286ea  Also load submodules (vyasr, Feb 27, 2024)
09e84f8  Add test module (vyasr, Feb 27, 2024)
c0fabe3  Fix formatting (vyasr, Feb 27, 2024)
ef6a2ab  Patch warning levels (vyasr, Mar 6, 2024)
74aed8f  Change module name and set real attribute (vyasr, Mar 6, 2024)
a4681a0  Add documentation of patching (vyasr, Mar 6, 2024)
99d4fe5  Install package for conda build (vyasr, Mar 6, 2024)
ada0aa9  Fix tests (vyasr, Mar 6, 2024)
de03ad3  Add test runner for conda (vyasr, Mar 6, 2024)
b5c0c37  Some recipe fixes (vyasr, Mar 6, 2024)
b146685  Try adding a host section for the build (vyasr, Mar 6, 2024)
f863f8b  Add conda-verify (vyasr, Mar 6, 2024)
c56f961  Remove tests section and rely on run_tests.sh for testing (vyasr, Mar 6, 2024)
0c8d6a4  Rename file to match conda conventions (vyasr, Mar 6, 2024)
80ca4c1  Add test requires to conda testing env (vyasr, Mar 6, 2024)
d7629c1  Add pyproject extra for testing and run tests in wheel (vyasr, Mar 6, 2024)
88b9603  Try to fix test paths (vyasr, Mar 6, 2024)
6a26fe2  Test path (vyasr, Mar 6, 2024)
7d2b9fb  More debugging (vyasr, Mar 6, 2024)
bb83337  Make conda add test files to test tree (vyasr, Mar 6, 2024)
04e4763  Merge remote-tracking branch 'upstream/branch-24.06' into feat/patching (vyasr, Mar 18, 2024)
c97a7bc  Apply suggestions from code review (vyasr, Mar 18, 2024)
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -13,3 +13,4 @@ build/
 wheels/
 *.egg-info/
 *.egg
+*.whl
21 changes: 21 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,21 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
  - repo: https://github.com/codespell-project/codespell
    rev: v2.2.6
    hooks:
      - id: codespell
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.2.2
    hooks:
      - id: ruff
        args: ["--fix"]
      - id: ruff-format

default_language_version:
  python: python3
10 changes: 9 additions & 1 deletion README.md
@@ -8,7 +8,7 @@ The `rapids-dask-dependency` package encodes both `dask` and `distributed` requi
 # Versioning the Metapackage Itself

 This package is versioned just like the rest of RAPIDS: using CalVer, with alpha tags (trailing a\*) for nightlies.
-Nightlies of the metapackage should be consumed just like nightlies of any other RAPIDS package:
+Nightlies of the metapackage should be consumed just like nightlies of any other RAPIDS package:
 - conda packages should pin up to the minor version with a trailing `.*`, e.g. `==23.10.*`. Conda will allow nightlies to match, so no further intervention is needed.
 - pip packages should have the same pin, but wheel building scripts must add an alpha spec `>=0.0.0a0` when building nightlies to allow rapids-dask-dependency nightlies. This is the same strategy used to have RAPIDS repositories pull nightly versions of other RAPIDS dependencies (e.g. `cudf` requires `rmm` nightlies).

@@ -31,3 +31,11 @@ At release, these dependencies will be pinned to the desired versions.
 Note that encoding direct URLs as above is technically prohibited by the [Python packaging specifications](https://packaging.python.org/en/latest/specifications/version-specifiers/#direct-references).
 However, while PyPI enforces this, the RAPIDS nightly index does not.
 Therefore, use of this versioning strategy currently prohibits rapids-dask-dependency nightlies from being uploaded to PyPI, and they must be hosted on the RAPIDS nightly pip index.
+
+# Patching
+
+In addition to functioning as a metapackage, `rapids-dask-dependency` also contains code that patches dask and distributed.
+Users should never need to import this package manually.
+Instead, installation places a `.pth` file in `site-packages` (see the [site module documentation](https://docs.python.org/3.11/library/site.html) for how these files work), and the interpreter processes that file at startup whenever the `site` module is enabled.
+The `.pth` file imports `rapids_dask_dependency`, which installs a custom finder and loader on the [meta path](https://docs.python.org/3/reference/import.html#the-meta-path) that intercepts every import of a `dask` or `distributed` module.
+The loader applies RAPIDS-specific patches as each module is created, so the modules are patched for RAPIDS compatibility regardless of import order in any environment where RAPIDS packages are installed.
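
The interception described in the Patching section can be tried in isolation. The following is a minimal sketch, not this package's loader: `PatchingFinder`, `TARGET`, and the `_patched` attribute are hypothetical names, and the standard-library module `colorsys` stands in for dask so that the example runs without dask installed.

```python
import importlib
import importlib.abc
import importlib.machinery
import sys
from contextlib import contextmanager

TARGET = "colorsys"  # hypothetical stand-in for "dask"


class PatchingFinder(importlib.abc.MetaPathFinder, importlib.abc.Loader):
    """Intercept the import of TARGET and tag the real module."""

    @contextmanager
    def disabled(self):
        # Step aside so the nested import resolves through the normal finders.
        sys.meta_path.remove(self)
        try:
            yield
        finally:
            sys.meta_path.insert(0, self)

    def find_spec(self, fullname, path=None, target=None):
        if fullname == TARGET:
            return importlib.machinery.ModuleSpec(fullname, loader=self)
        return None  # let the remaining finders handle everything else

    def create_module(self, spec):
        with self.disabled():
            mod = importlib.import_module(spec.name)
        mod._patched = True  # the "patch": mark the real module
        return mod

    def exec_module(self, module):
        # The real module was already executed by the nested import.
        pass


sys.modules.pop(TARGET, None)  # make sure the import below is not cached
sys.meta_path.insert(0, PatchingFinder())

import colorsys  # noqa: E402

print(colorsys._patched)
```

Because the finder sits first on `sys.meta_path`, it sees the import before the default path-based finder does, which is why import order elsewhere in the program does not matter.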
1 change: 1 addition & 0 deletions _rapids_dask_dependency.pth
@@ -0,0 +1 @@
import rapids_dask_dependency
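
How a one-line `.pth` file ends up running code can be checked with the standard-library `site` module. This is a small sketch under stated assumptions: the file name `demo_hook.pth` and the `sys._pth_demo_ran` flag are made up for the demonstration, and `site.addsitedir` is called by hand on a temporary directory to mimic what the interpreter does for `site-packages` at startup.

```python
import site
import sys
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    # A line that starts with "import" is executed; any other line is treated
    # as a directory to append to sys.path.
    Path(tmp, "demo_hook.pth").write_text("import sys; sys._pth_demo_ran = True\n")
    site.addsitedir(tmp)  # processes every .pth file in the directory

# addsitedir also put the temporary directory on sys.path; remove the stale entry.
if tmp in sys.path:
    sys.path.remove(tmp)

print(getattr(sys, "_pth_demo_ran", False))
```

Starting Python with `-S` disables the `site` module, so `.pth` files are not processed in that mode.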
15 changes: 8 additions & 7 deletions ci/build_wheel.sh
@@ -1,18 +1,19 @@
 #!/bin/bash
-# Copyright (c) 2023, NVIDIA CORPORATION.
+# Copyright (c) 2023-2024, NVIDIA CORPORATION.

 set -euo pipefail

 source rapids-configure-sccache
 source rapids-date-string

-package_name=rapids-dask-dependency
-package_dir="pip/${package_name}"
 version=$(rapids-generate-version)

-sed -i "s/^version = .*/version = \"${version}\"/g" "${package_dir}/pyproject.toml"
+sed -i "s/^version = .*/version = \"${version}\"/g" "pyproject.toml"

-cd "${package_dir}"
-python -m pip wheel . -w dist -vvv --no-deps --disable-pip-version-check
+python -m pip wheel . -w dist -vv --no-deps --disable-pip-version-check

-RAPIDS_PY_WHEEL_NAME="${package_name}" RAPIDS_PY_WHEEL_PURE="1" rapids-upload-wheels-to-s3 dist
+RAPIDS_PY_WHEEL_NAME="rapids-dask-dependency" RAPIDS_PY_WHEEL_PURE="1" rapids-upload-wheels-to-s3 dist
+
+# Run tests
+python -m pip install "$(ls dist/*.whl)[test]"
+python -m pytest -v tests/
19 changes: 16 additions & 3 deletions conda/recipes/rapids-dask-dependency/meta.yaml
@@ -11,9 +11,22 @@ source:

 build:
   number: 0
-  noarch: generic
+  noarch: python
+  script: python -m pip install . -vv --no-deps
+
+test:
+  requires:
+    - pip
+    - pytest
+  source_files:
+    - tests/

 requirements:
+  host:
+    - pip
+    - python >=3.9
+    - setuptools
+    - conda-verify
   run:
     - dask ==2024.1.1
     - dask-core ==2024.1.1
@@ -29,7 +42,7 @@ about:
     This metapackage encodes the standard Dask version pinning used for a
     particular release of RAPIDS. The metapackage adds an extra release segment
     to the RAPIDS CalVer to allow pinnings in this metapackage to be updated
-    for a given release and automatically propagate to its dependents.
+    for a given release and automatically propagate to its dependents. It also
+    includes any patches to dask required for RAPIDS to function correctly.
   doc_url: https://docs.rapids.ai/
   dev_url: https://github.com/rapidsai/rapids_dask_dependency

4 changes: 4 additions & 0 deletions conda/recipes/rapids-dask-dependency/run_test.sh
@@ -0,0 +1,4 @@
#!/bin/bash
# Copyright (c) 2024, NVIDIA CORPORATION.

python -m pytest -v tests/
1 change: 1 addition & 0 deletions conda/recipes/rapids-dask-dependency/tests
1 change: 0 additions & 1 deletion pip/rapids-dask-dependency/LICENSE

This file was deleted.

1 change: 0 additions & 1 deletion pip/rapids-dask-dependency/README.md

This file was deleted.

14 changes: 13 additions & 1 deletion pip/rapids-dask-dependency/pyproject.toml → pyproject.toml
@@ -1,4 +1,4 @@
-# Copyright (c) 2023, NVIDIA CORPORATION.
+# Copyright (c) 2023-2024, NVIDIA CORPORATION.

 [build-system]
 build-backend = "setuptools.build_meta"
@@ -19,5 +19,17 @@ dependencies = [
 license = { text = "Apache 2.0" }
 readme = { file = "README.md", content-type = "text/markdown" }

+[project.optional-dependencies]
+test = [
+    "pytest",
+]
+
 [tool.setuptools]
 license-files = ["LICENSE"]
+
+[tool.setuptools.packages.find]
+include = ["rapids_dask_dependency*"]
+
+[tool.ruff]
+lint.select = ["E", "F", "W", "I", "N", "UP"]
+lint.fixable = ["ALL"]
5 changes: 5 additions & 0 deletions rapids_dask_dependency/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from .dask_loader import DaskLoader

DaskLoader.install()
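
Because this `__init__.py` runs on every interpreter start, the install step should be safe to call more than once. The sketch below shows one way to keep a meta path hook idempotent; `NoOpFinder` is a hypothetical finder that declines every import, used only so the `install()` logic can be exercised on its own.

```python
import importlib.abc
import sys


class NoOpFinder(importlib.abc.MetaPathFinder):
    """Hypothetical finder that declines every import; only install() matters."""

    def find_spec(self, fullname, path=None, target=None):
        return None

    @classmethod
    def install(cls):
        # Reuse an existing instance so repeated calls never stack hooks.
        for finder in sys.meta_path:
            if isinstance(finder, cls):
                return finder
        finder = cls()
        sys.meta_path.insert(0, finder)
        return finder


first = NoOpFinder.install()
second = NoOpFinder.install()
print(first is second)
```

Returning the first matching instance also behaves sensibly if more than one instance is somehow already registered, a case in which tuple unpacking of the matches would raise.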
87 changes: 87 additions & 0 deletions rapids_dask_dependency/dask_loader.py
@@ -0,0 +1,87 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

import importlib
import importlib.abc
import importlib.machinery
import sys
import warnings
from contextlib import contextmanager

from .patches.dask import patches as dask_patches
from .patches.distributed import patches as distributed_patches

original_warn = warnings.warn


def _warning_with_increased_stacklevel(
    message, category=None, stacklevel=1, source=None, **kwargs
):
    # Patch warnings to have the right stacklevel
    # Add 3 to the stacklevel to account for the 3 extra frames added by the loader: one
    # in this warnings function, one in the actual loader, and one in the importlib
    # call (not including all internal frames).
    original_warn(message, category, stacklevel + 3, source, **kwargs)


@contextmanager
def patch_warning_stacklevel():
    warnings.warn = _warning_with_increased_stacklevel
    try:
        yield
    finally:
        # Restore the original function even if the wrapped import raises.
        warnings.warn = original_warn


class DaskLoader(importlib.abc.MetaPathFinder, importlib.abc.Loader):
    def create_module(self, spec):
        if spec.name.startswith("dask") or spec.name.startswith("distributed"):
            with self.disable(), patch_warning_stacklevel():
                mod = importlib.import_module(spec.name)

            # Note: The spec does not make it clear whether we're guaranteed that spec
            # is not a copy of the original spec, but that is the case for now. We need
            # to assign this because the spec is used to update module attributes after
            # it is initialized by create_module.
            spec.origin = mod.__spec__.origin
            spec.submodule_search_locations = mod.__spec__.submodule_search_locations

            # TODO: I assume we'll want to only apply patches to specific submodules,
            # that'll be up to RAPIDS dask devs to decide.
            # Select patches by top-level package. A substring test would misclassify
            # names such as distributed.cli.dask_worker.
            is_dask = spec.name.split(".")[0] == "dask"
            patches = dask_patches if is_dask else distributed_patches
            for patch in patches:
                patch(mod)
            return mod

    def exec_module(self, _):
        pass

    @contextmanager
    def disable(self):
        sys.meta_path.remove(self)
        try:
            yield
        finally:
            sys.meta_path.insert(0, self)

    def find_spec(self, fullname: str, _, __=None):
        if (
            fullname in ("dask", "distributed")
            or fullname.startswith("dask.")
            or fullname.startswith("distributed.")
        ):
            return importlib.machinery.ModuleSpec(
                name=fullname,
                loader=self,
                # Set these parameters dynamically in create_module
                origin=None,
                loader_state=None,
                is_package=True,
            )
        return None

    @classmethod
    def install(cls):
        try:
            (self,) = (obj for obj in sys.meta_path if isinstance(obj, cls))
        except ValueError:
            self = cls()
            sys.meta_path.insert(0, self)
        return self
1 change: 1 addition & 0 deletions rapids_dask_dependency/patches/__init__.py
@@ -0,0 +1 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
5 changes: 5 additions & 0 deletions rapids_dask_dependency/patches/dask/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from .add_patch_attr import add_patch_attr

patches = [add_patch_attr]
5 changes: 5 additions & 0 deletions rapids_dask_dependency/patches/dask/add_patch_attr.py
@@ -0,0 +1,5 @@
# Copyright (c) 2024, NVIDIA CORPORATION.


def add_patch_attr(mod):
    mod._rapids_patched = True
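
A patch is any callable that takes the imported module and mutates it. As a rough sketch of what a less trivial patch could look like, the example below wraps a function on a stand-in module. `wrap_describe`, the `describe` attribute, and the `fake_dask` module are all hypothetical and exist only for this illustration; no real dask API is patched.

```python
import types


def add_patch_attr(mod):
    mod._rapids_patched = True


def wrap_describe(mod):
    """Hypothetical patch: decorate `mod.describe` if the module defines it."""
    original = getattr(mod, "describe", None)
    if original is None:
        return  # nothing to patch in this (sub)module

    def describe(*args, **kwargs):
        return original(*args, **kwargs) + " (patched)"

    mod.describe = describe


patches = [add_patch_attr, wrap_describe]

# Stand-in for a module returned by importlib.import_module.
fake = types.ModuleType("fake_dask")
fake.describe = lambda: "original"

for patch in patches:
    patch(fake)

print(fake.describe())
```

Guarding on the attribute's presence matters because the same patch list is applied to every submodule of the package, most of which will not define the target.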
5 changes: 5 additions & 0 deletions rapids_dask_dependency/patches/distributed/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from .add_patch_attr import add_patch_attr

patches = [add_patch_attr]
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) 2024, NVIDIA CORPORATION.


def add_patch_attr(mod):
    mod._rapids_patched = True
21 changes: 21 additions & 0 deletions setup.py
@@ -0,0 +1,21 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
import os

from setuptools import setup
from setuptools.command.build_py import build_py


# Adapted from https://stackoverflow.com/a/71137790
class build_py_with_pth_file(build_py):  # noqa: N801
    """Include the .pth file in the generated wheel."""

    def run(self):
        super().run()

        fn = "_rapids_dask_dependency.pth"

        outfile = os.path.join(self.build_lib, fn)
        self.copy_file(fn, outfile, preserve_mode=0)


setup(cmdclass={"build_py": build_py_with_pth_file})
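
Since a wheel is a zip archive, one quick way to confirm that the custom `build_py` step did its job is to look for the `.pth` file at the archive's top level. The helper below is a hypothetical check, not part of this PR; it is exercised against a small in-memory archive that mimics the expected layout instead of a real built wheel.

```python
import io
import zipfile


def wheel_contains(wheel, member):
    """Return True if `member` is an entry in the wheel (zip) archive."""
    with zipfile.ZipFile(wheel) as zf:
        return member in zf.namelist()


# Build a tiny in-memory archive that mimics the expected layout.
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as zf:
    zf.writestr("_rapids_dask_dependency.pth", "import rapids_dask_dependency\n")
    zf.writestr("rapids_dask_dependency/__init__.py", "")
buffer.seek(0)

print(wheel_contains(buffer, "_rapids_dask_dependency.pth"))
```

For a real build, passing the path of the file under `dist/` in place of `buffer` would perform the same check.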
26 changes: 26 additions & 0 deletions tests/test_patch.py
@@ -0,0 +1,26 @@
from functools import wraps
from multiprocessing import Process


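def run_test_in_subprocess(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        p = Process(target=func, args=args, kwargs=kwargs)
        p.start()
        p.join()
        # A failed assertion in the child only shows up as a nonzero exit code, so
        # check it here; otherwise the test passes unconditionally.
        assert p.exitcode == 0, f"{func.__name__} failed in subprocess"

    return wrapper


@run_test_in_subprocess
def test_dask():
    import dask

    assert hasattr(dask, "_rapids_patched")


@run_test_in_subprocess
def test_distributed():
    import distributed

    assert hasattr(distributed, "_rapids_patched")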
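
An alternative to `multiprocessing` for this kind of isolation is to launch a fresh interpreter with `subprocess`. This is a different technique from the decorator above, shown only as a sketch; `run_in_fresh_interpreter` is a hypothetical helper. The point it illustrates is the same: a failure in the child is visible to the parent only through the exit status.

```python
import subprocess
import sys


def run_in_fresh_interpreter(code):
    """Run `code` with `python -c` and return the CompletedProcess."""
    return subprocess.run(
        [sys.executable, "-c", code],
        capture_output=True,
        text=True,
        check=False,
    )


ok = run_in_fresh_interpreter("assert 1 + 1 == 2")
bad = run_in_fresh_interpreter("assert 1 + 1 == 3, 'arithmetic is broken'")

print(ok.returncode, bad.returncode)
```

A fresh interpreter also goes through normal startup, including `.pth` processing, so a check such as `import dask; assert dask._rapids_patched` run this way would exercise the installed hook end to end.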